├── hotfix
│   ├── __init__.py
│   ├── vision.py
│   ├── transforms.py
│   └── folder.py
├── ModelLoader
│   ├── geffnet
│   │   ├── version.py
│   │   ├── __init__.py
│   │   ├── model_factory.py
│   │   ├── activations
│   │   │   ├── activations_jit.py
│   │   │   ├── activations.py
│   │   │   ├── __init__.py
│   │   │   └── activations_me.py
│   │   ├── helpers.py
│   │   └── config.py
│   └── __init__.py
├── misc
│   └── ZenNet_speed.png
├── .gitignore
├── ZenNet
│   ├── zennet_imagenet1k_flops600M_res224.txt
│   ├── zennet_imagenet1k_flops400M_res224.txt
│   ├── zennet_cifar_model_size05M_res32.txt
│   ├── zennet_imagenet1k_flops900M_res224.txt
│   ├── zennet_imagenet1k_latency01ms_res224.txt
│   ├── zennet_cifar_model_size1M_res32.txt
│   ├── zennet_imagenet1k_latency02ms_res224.txt
│   ├── zennet_cifar_model_size2M_res32.txt
│   ├── zennet_imagenet1k_latency03ms_res224.txt
│   ├── zennet_imagenet1k_latency05ms_res224.txt
│   ├── zennet_imagenet1k_latency08ms_res224.txt
│   ├── zennet_imagenet1k_latency12ms_res224.txt
│   ├── masternet.py
│   └── __init__.py
├── analyze_model.py
├── scripts
│   ├── Zen_NAS_ImageNet_latency0.1ms.sh
│   ├── Zen_NAS_ImageNet_latency0.2ms.sh
│   ├── Zen_NAS_ImageNet_latency0.3ms.sh
│   ├── Zen_NAS_ImageNet_latency0.5ms.sh
│   ├── Zen_NAS_ImageNet_latency0.8ms.sh
│   ├── Zen_NAS_ImageNet_latency1.2ms.sh
│   ├── Zen_NAS_ImageNet_flops400M.sh
│   ├── Zen_NAS_ImageNet_flops600M.sh
│   ├── Zen_NAS_ImageNet_flops800M.sh
│   ├── Zen_NAS_cifar_params2M.sh
│   ├── Flops_NAS_cifar_params1M.sh
│   ├── Params_NAS_cifar_params1M.sh
│   ├── Random_NAS_cifar_params1M.sh
│   ├── TE_NAS_cifar_params1M.sh
│   ├── Zen_NAS_cifar_params1M.sh
│   ├── NASWOT_NAS_cifar_params1M.sh
│   ├── Syncflow_NAS_cifar_params1M.sh
│   └── GradNorm_NAS_cifar_params1M.sh
├── ZeroShotProxy
│   ├── compute_gradnorm_score.py
│   ├── compute_zen_score.py
│   ├── compute_NASWOT_score.py
│   ├── compute_syncflow_score.py
│   └── compute_te_nas_score.py
├── val.py
├── val_cifar.py
├── SearchSpace
│   ├── search_space_XXBL.py
│   └── search_space_IDW_fixfc.py
├── benchmark_network_latency.py
├── Masternet.py
├── PlainNet
│   ├── SuperResKXKX.py
│   ├── SuperResK1KXK1.py
│   ├── super_blocks.py
│   └── __init__.py
└── DataLoader
    └── autoaugment.py

/hotfix/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ModelLoader/geffnet/version.py:
--------------------------------------------------------------------------------
__version__ = '1.0.0'
--------------------------------------------------------------------------------
/misc/ZenNet_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idstcv/ZenNAS/HEAD/misc/ZenNet_speed.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
*.pyc
.idea/
.vscode/
__pycache__/
upload_code_to_oss.sh
download_code_from_oss.sh
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_flops600M_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,24,2,1)SuperResIDWE1K7(24,48,2,48,1)SuperResIDWE2K7(48,72,2,72,1)SuperResIDWE6K7(72,96,2,88,5)SuperResIDWE4K7(96,192,2,168,5)SuperConvK1BNRELU(192,2048,1,1)
--------------------------------------------------------------------------------
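Each of these ZenNet .txt files stores an entire architecture as a single PlainNet structure string. A minimal sketch of pulling such a string apart for inspection (hypothetical helper, not part of the repo; the field order in_channels, out_channels, stride, bottleneck_channels, sub_layers is an assumption based on the PlainNet block names, not stated in these files):

import re

def parse_structure(s):
    # Split 'SuperConvK3BNRELU(3,24,2,1)SuperResIDWE1K7(24,48,2,48,1)...'
    # into (block_name, [integer fields]) pairs.
    blocks = re.findall(r'([A-Za-z0-9]+)\(([^)]*)\)', s)
    return [(name, [int(v) for v in args.split(',')]) for name, args in blocks]

with open('ZenNet/zennet_imagenet1k_flops600M_res224.txt') as f:
    for name, fields in parse_structure(f.read()):
        print(name, fields)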
/ZenNet/zennet_imagenet1k_flops400M_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,16,2,1)SuperResIDWE1K7(16,40,2,40,1)SuperResIDWE1K7(40,64,2,64,1)SuperResIDWE4K7(64,96,2,96,5)SuperResIDWE2K7(96,224,2,224,5)SuperConvK1BNRELU(224,2048,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_cifar_model_size05M_res32.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,64,1,1)SuperResK1K5K1(64,168,1,16,3)SuperResK1K3K1(168,80,2,32,4)SuperResK1K5K1(80,112,2,16,3)SuperResK1K5K1(112,144,1,24,3)SuperResK1K3K1(144,32,2,40,1)SuperConvK1BNRELU(32,512,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_flops900M_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,16,2,1)SuperResIDWE1K7(16,48,2,72,1)SuperResIDWE2K7(48,72,2,64,3)SuperResIDWE2K7(72,152,2,144,3)SuperResIDWE2K7(152,360,2,352,4)SuperResIDWE4K7(360,288,1,264,3)SuperConvK1BNRELU(288,2048,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency01ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,24,2,1)SuperResK3K3(24,32,2,64,1)SuperResK5K5(32,64,2,32,1)SuperResK5K5(64,168,2,96,1)SuperResK1K5K1(168,320,1,120,1)SuperResK1K5K1(320,640,2,304,3)SuperResK1K5K1(640,512,1,384,1)SuperConvK1BNRELU(512,2384,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_cifar_model_size1M_res32.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,88,1,1)SuperResK1K7K1(88,120,1,16,1)SuperResK1K7K1(120,192,2,16,3)SuperResK1K5K1(192,224,1,24,4)SuperResK1K5K1(224,96,2,24,2)SuperResK1K3K1(96,168,2,40,3)SuperResK1K3K1(168,112,1,48,3)SuperConvK1BNRELU(112,512,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency02ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,24,2,1)SuperResK1K5K1(24,32,2,32,1)SuperResK1K7K1(32,104,2,64,1)SuperResK1K5K1(104,512,2,160,1)SuperResK1K5K1(512,344,1,192,1)SuperResK1K5K1(344,688,2,320,4)SuperResK1K5K1(688,680,1,304,3)SuperConvK1BNRELU(680,2552,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_cifar_model_size2M_res32.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,32,1,1)SuperResK1K5K1(32,120,1,40,1)SuperResK1K5K1(120,176,2,32,3)SuperResK1K7K1(176,272,1,24,3)SuperResK1K3K1(272,176,1,56,3)SuperResK1K3K1(176,176,1,64,4)SuperResK1K5K1(176,216,2,40,2)SuperResK1K3K1(216,72,2,56,2)SuperConvK1BNRELU(72,512,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency03ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,24,2,1)SuperResK1K5K1(24,64,2,32,1)SuperResK1K3K1(64,128,2,128,1)SuperResK1K7K1(128,432,2,128,1)SuperResK1K5K1(432,272,1,160,1)SuperResK1K5K1(272,848,2,384,4)SuperResK1K5K1(848,848,1,320,3)SuperResK1K5K1(848,456,1,320,3)SuperConvK1BNRELU(456,6704,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency05ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,8,2,1)SuperResK1K7K1(8,64,2,32,1)SuperResK1K3K1(64,152,2,128,1)SuperResK1K5K1(152,640,2,192,4)SuperResK1K5K1(640,640,1,192,2)SuperResK1K5K1(640,1536,2,384,4)SuperResK1K5K1(1536,816,1,384,3)SuperResK1K5K1(816,816,1,384,3)SuperConvK1BNRELU(816,5304,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency08ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,16,2,1)SuperResK1K5K1(16,64,2,64,1)SuperResK1K3K1(64,240,2,128,2)SuperResK1K7K1(240,640,2,160,3)SuperResK1K7K1(640,768,1,192,4)SuperResK1K5K1(768,1536,2,384,5)SuperResK1K5K1(1536,1536,1,384,3)SuperResK1K5K1(1536,2304,1,384,5)SuperConvK1BNRELU(2304,4912,1,1)
--------------------------------------------------------------------------------
/ZenNet/zennet_imagenet1k_latency12ms_res224.txt:
--------------------------------------------------------------------------------
SuperConvK3BNRELU(3,32,2,1)SuperResK1K5K1(32,80,2,32,1)SuperResK1K7K1(80,432,2,128,5)SuperResK1K7K1(432,640,2,192,3)SuperResK1K7K1(640,1008,1,160,5)SuperResK1K7K1(1008,976,1,160,4)SuperResK1K5K1(976,2304,2,384,5)SuperResK1K5K1(2304,2496,1,384,5)SuperConvK1BNRELU(2496,3072,1,1)
--------------------------------------------------------------------------------
/ModelLoader/geffnet/__init__.py:
--------------------------------------------------------------------------------
'''
The geffnet module is modified from:
https://github.com/rwightman/gen-efficientnet-pytorch
'''

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from .gen_efficientnet import *
from .mobilenetv3 import *
from .model_factory import create_model
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
from .activations import *
--------------------------------------------------------------------------------
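The export/script flags re-exported above come from geffnet's config module. A brief, hedged usage sketch (flag values here are illustrative):

from ModelLoader import geffnet  # assumes the repo root is on sys.path

geffnet.set_exportable(True)   # prefer export-friendly ops (e.g. for ONNX)
geffnet.set_scriptable(False)
print(geffnet.is_exportable(), geffnet.is_scriptable())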
/analyze_model.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.
'''


import os, sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import logging
import ModelLoader
import global_utils
from ptflops import get_model_complexity_info


def main(opt, argv):
    model = ModelLoader.get_model(opt, argv)
    flops, params = get_model_complexity_info(model, (3, opt.input_image_size, opt.input_image_size),
                                              as_strings=False,
                                              print_per_layer_stat=True)
    print('Flops: {:4g}'.format(flops))
    print('Params: {:4g}'.format(params))

if __name__ == "__main__":
    opt = global_utils.parse_cmd_options(sys.argv)

    main(opt, sys.argv)
--------------------------------------------------------------------------------
/ModelLoader/geffnet/model_factory.py:
--------------------------------------------------------------------------------
from .config import set_layer_config
from .helpers import load_checkpoint

from .gen_efficientnet import *
from .mobilenetv3 import *


def create_model(
        model_name='mnasnet_100',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):

    model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)

    if model_name in globals():
        create_fn = globals()[model_name]
        model = create_fn(**model_kwargs)
    else:
        raise RuntimeError('Unknown model (%s)' % model_name)

    if checkpoint_path and not pretrained:
        load_checkpoint(model, checkpoint_path)

    return model
--------------------------------------------------------------------------------
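A quick, hedged usage sketch for this factory. The model name must match one of the functions star-imported above; 'mobilenetv3_large_100' exists in upstream gen-efficientnet-pytorch and is assumed unchanged here:

import torch
from ModelLoader.geffnet.model_factory import create_model  # assumes the repo root is on sys.path

net = create_model('mobilenetv3_large_100', pretrained=False, num_classes=10)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])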
/scripts/Zen_NAS_ImageNet_latency0.1ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency0.1ms
mkdir -p ${save_dir}


resolution=224
budget_latency=1e-4
max_layers=10
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_latency0.2ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency0.2ms
mkdir -p ${save_dir}


resolution=224
budget_latency=2e-4
max_layers=13
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_latency0.3ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency0.3ms
mkdir -p ${save_dir}


resolution=224
budget_latency=3e-4
max_layers=16
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_latency0.5ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency0.5ms
mkdir -p ${save_dir}


resolution=224
budget_latency=5e-4
max_layers=20
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_latency0.8ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency0.8ms
mkdir -p ${save_dir}


resolution=224
budget_latency=8e-4
max_layers=25
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_latency1.2ms.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_latency1.2ms
mkdir -p ${save_dir}


resolution=224
budget_latency=12e-4
max_layers=30
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)\
SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)\
SuperConvK1BNRELU(256,512,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_latency ${budget_latency} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_flops400M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_flops400M
mkdir -p ${save_dir}


resolution=224
budget_flops=400e6
max_layers=14
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,8,2,1)SuperResIDWE6K3(8,32,2,8,1)SuperResIDWE6K3(32,48,2,32,1)\
SuperResIDWE6K3(48,96,2,48,1)SuperResIDWE6K3(96,128,2,96,1)\
SuperConvK1BNRELU(128,2048,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_IDW_fixfc.py \
  --budget_flops ${budget_flops} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --use_se \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_flops600M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_flops600M
mkdir -p ${save_dir}


resolution=224
budget_flops=600e6
max_layers=14
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,8,2,1)SuperResIDWE6K3(8,32,2,8,1)SuperResIDWE6K3(32,48,2,32,1)\
SuperResIDWE6K3(48,96,2,48,1)SuperResIDWE6K3(96,128,2,96,1)\
SuperConvK1BNRELU(128,2048,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_IDW_fixfc.py \
  --budget_flops ${budget_flops} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --use_se \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_ImageNet_flops800M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

save_dir=../save_dir/Zen_NAS_ImageNet_flops800M
mkdir -p ${save_dir}


resolution=224
budget_flops=800e6
max_layers=16
population_size=512
epochs=480
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

echo "SuperConvK3BNRELU(3,8,2,1)SuperResIDWE6K3(8,32,2,8,1)SuperResIDWE6K3(32,48,2,32,1)\
SuperResIDWE6K3(48,96,2,48,1)SuperResIDWE6K3(96,128,2,96,1)\
SuperConvK1BNRELU(128,2048,1,1)" > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_IDW_fixfc.py \
  --budget_flops ${budget_flops} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size ${resolution} \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 1000 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 224 \
  --num_classes 1000 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python ts_train_image_classification.py --dataset imagenet --num_classes 1000 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size ${resolution} --epochs ${epochs} --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 4e-5 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --teacher_arch geffnet_tf_efficientnet_b3_ns \
  --teacher_pretrained \
  --teacher_input_image_size 320 \
  --teacher_feature_weight 1.0 \
  --teacher_logit_weight 1.0 \
  --ts_proj_no_relu \
  --ts_proj_no_bn \
  --use_se \
  --target_downsample_ratio 16 \
  --batch_size_per_gpu 64 --save_dir ${save_dir}/ts_effnet_b3ns_epochs${epochs}
--------------------------------------------------------------------------------
/scripts/Zen_NAS_cifar_params2M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=2e6
max_layers=20
population_size=512
evolution_max_iter=480000

save_dir=../save_dir/Zen_NAS_cifar_params2M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/Flops_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching

save_dir=../save_dir/Flops_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Flops \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/Params_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/Params_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Params \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/Random_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/Random_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Random \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/TE_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/TE_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score TE-NAS \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs


python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
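The eight *_NAS_cifar_params1M.sh scripts differ essentially only in --zero_shot_score and save_dir (NASWOT also lowers --batch_size to 8). A hypothetical driver, not part of the repo, that launches all of them in sequence from the repo root:

import subprocess

for proxy in ['Zen', 'Flops', 'Params', 'Random', 'TE-NAS', 'NASWOT', 'Syncflow', 'GradNorm']:
    script = 'scripts/{}_NAS_cifar_params1M.sh'.format(proxy.replace('-', '_'))
    subprocess.run(['bash', script], check=True)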
/scripts/Zen_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/Zen_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Zen \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs


python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/NASWOT_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/NASWOT_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt


python evolution_search.py --gpu 0 \
  --zero_shot_score NASWOT \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 8 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/Syncflow_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/Syncflow_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score Syncflow \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt

python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs

python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/scripts/GradNorm_NAS_cifar_params1M.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd "$(dirname "$0")"
set -e

cd ../

budget_model_size=1e6
max_layers=18
population_size=512
evolution_max_iter=480000 # we suggest evolution_max_iter=480000 for sufficient searching


save_dir=../save_dir/GradNorm_NAS_cifar_params1M
mkdir -p ${save_dir}

echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \
  > ${save_dir}/init_plainnet.txt

python evolution_search.py --gpu 0 \
  --zero_shot_score GradNorm \
  --search_space SearchSpace/search_space_XXBL.py \
  --budget_model_size ${budget_model_size} \
  --max_layers ${max_layers} \
  --batch_size 64 \
  --input_image_size 32 \
  --plainnet_struct_txt ${save_dir}/init_plainnet.txt \
  --num_classes 100 \
  --evolution_max_iter ${evolution_max_iter} \
  --population_size ${population_size} \
  --save_dir ${save_dir}


python analyze_model.py \
  --input_image_size 32 \
  --num_classes 100 \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt


python train_image_classification.py --dataset cifar10 --num_classes 10 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar10_1440epochs


python train_image_classification.py --dataset cifar100 --num_classes 100 \
  --dist_mode single --workers_per_gpu 6 \
  --input_image_size 32 --epochs 1440 --warmup 5 \
  --optimizer sgd --bn_momentum 0.01 --wd 5e-4 --nesterov --weight_init custom \
  --label_smoothing --random_erase --mixup --auto_augment \
  --lr_per_256 0.1 --target_lr_per_256 0.0 --lr_mode cosine \
  --arch Masternet.py:MasterNet \
  --plainnet_struct_txt ${save_dir}/best_structure.txt \
  --batch_size_per_gpu 64 \
  --save_dir ${save_dir}/cifar100_1440epochs
--------------------------------------------------------------------------------
/ZeroShotProxy/compute_gradnorm_score.py:
--------------------------------------------------------------------------------
'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.
'''


import os, sys, time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch import nn
import numpy as np

def network_weight_gaussian_init(net: nn.Module):
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)
            else:
                continue

    return net

import torch.nn.functional as F
def cross_entropy(logit, target):
    # target must be one-hot format!!
    prob_logit = F.log_softmax(logit, dim=1)
    loss = -(target * prob_logit).sum(dim=1).mean()
    return loss

def compute_nas_score(gpu, model, resolution, batch_size):

    model.train()
    model.requires_grad_(True)

    model.zero_grad()

    if gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)

    network_weight_gaussian_init(model)
    input = torch.randn(size=[batch_size, 3, resolution, resolution])
    if gpu is not None:
        input = input.cuda(gpu)
    output = model(input)
    # y_true = torch.rand(size=[batch_size, output.shape[1]], device=torch.device('cuda:{}'.format(gpu))) + 1e-10
    # y_true = y_true / torch.sum(y_true, dim=1, keepdim=True)

    num_classes = output.shape[1]
    y = torch.randint(low=0, high=num_classes, size=[batch_size])

    one_hot_y = F.one_hot(y, num_classes).float()
    if gpu is not None:
        one_hot_y = one_hot_y.cuda(gpu)

    loss = cross_entropy(output, one_hot_y)
    loss.backward()
    norm2_sum = 0
    with torch.no_grad():
        for p in model.parameters():
            if hasattr(p, 'grad') and p.grad is not None:
                norm2_sum += torch.norm(p.grad) ** 2

    grad_norm = float(torch.sqrt(norm2_sum))

    return grad_norm
--------------------------------------------------------------------------------
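A minimal sketch of scoring an arbitrary, randomly initialized CNN with this proxy (the toy network is illustrative and not part of the repo; gpu=None keeps everything on the CPU):

import torch
from torch import nn
from ZeroShotProxy.compute_gradnorm_score import compute_nas_score  # assumes the repo root is on sys.path

toy_net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 100),
)
score = compute_nas_score(gpu=None, model=toy_net, resolution=32, batch_size=8)
print('GradNorm proxy score: {:.4g}'.format(score))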
25 | 26 | TODO Rename to SiLU with addition to PyTorch 27 | """ 28 | return x.mul(x.sigmoid()) 29 | 30 | 31 | @torch.jit.script 32 | def mish_jit(x, _inplace: bool = False): 33 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 34 | """ 35 | return x.mul(F.softplus(x).tanh()) 36 | 37 | 38 | class SwishJit(nn.Module): 39 | def __init__(self, inplace: bool = False): 40 | super(SwishJit, self).__init__() 41 | 42 | def forward(self, x): 43 | return swish_jit(x) 44 | 45 | 46 | class MishJit(nn.Module): 47 | def __init__(self, inplace: bool = False): 48 | super(MishJit, self).__init__() 49 | 50 | def forward(self, x): 51 | return mish_jit(x) 52 | 53 | 54 | @torch.jit.script 55 | def hard_sigmoid_jit(x, inplace: bool = False): 56 | # return F.relu6(x + 3.) / 6. 57 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 58 | 59 | 60 | class HardSigmoidJit(nn.Module): 61 | def __init__(self, inplace: bool = False): 62 | super(HardSigmoidJit, self).__init__() 63 | 64 | def forward(self, x): 65 | return hard_sigmoid_jit(x) 66 | 67 | 68 | @torch.jit.script 69 | def hard_swish_jit(x, inplace: bool = False): 70 | # return x * (F.relu6(x + 3.) / 6) 71 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 72 | 73 | 74 | class HardSwishJit(nn.Module): 75 | def __init__(self, inplace: bool = False): 76 | super(HardSwishJit, self).__init__() 77 | 78 | def forward(self, x): 79 | return hard_swish_jit(x) 80 | -------------------------------------------------------------------------------- /ModelLoader/geffnet/helpers.py: -------------------------------------------------------------------------------- 1 | """ Checkpoint loading / state_dict helpers 2 | Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | import os 6 | from collections import OrderedDict 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | 13 | def load_checkpoint(model, checkpoint_path): 14 | if checkpoint_path and os.path.isfile(checkpoint_path): 15 | print("=> Loading checkpoint '{}'".format(checkpoint_path)) 16 | checkpoint = torch.load(checkpoint_path) 17 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 18 | new_state_dict = OrderedDict() 19 | for k, v in checkpoint['state_dict'].items(): 20 | if k.startswith('module'): 21 | name = k[7:] # remove `module.` 22 | else: 23 | name = k 24 | new_state_dict[name] = v 25 | model.load_state_dict(new_state_dict) 26 | else: 27 | model.load_state_dict(checkpoint) 28 | print("=> Loaded checkpoint '{}'".format(checkpoint_path)) 29 | else: 30 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) 31 | raise FileNotFoundError() 32 | 33 | 34 | def load_pretrained(model, url, filter_fn=None, strict=True): 35 | if not url: 36 | print("=> Warning: Pretrained model URL is empty, using random initialization.") 37 | return 38 | 39 | state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') 40 | 41 | input_conv = 'conv_stem' 42 | classifier = 'classifier' 43 | in_chans = getattr(model, input_conv).weight.shape[1] 44 | num_classes = getattr(model, classifier).weight.shape[0] 45 | 46 | input_conv_weight = input_conv + '.weight' 47 | pretrained_in_chans = state_dict[input_conv_weight].shape[1] 48 | if in_chans != pretrained_in_chans: 49 | if in_chans == 1: 50 | print('=> Converting pretrained input conv {} from {} to 1 
channel'.format( 51 | input_conv_weight, pretrained_in_chans)) 52 | conv1_weight = state_dict[input_conv_weight] 53 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) 54 | else: 55 | print('=> Discarding pretrained input conv {} since input channel count != {}'.format( 56 | input_conv_weight, pretrained_in_chans)) 57 | del state_dict[input_conv_weight] 58 | strict = False 59 | 60 | classifier_weight = classifier + '.weight' 61 | pretrained_num_classes = state_dict[classifier_weight].shape[0] 62 | if num_classes != pretrained_num_classes: 63 | print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) 64 | del state_dict[classifier_weight] 65 | del state_dict[classifier + '.bias'] 66 | strict = False 67 | 68 | if filter_fn is not None: 69 | state_dict = filter_fn(state_dict) 70 | 71 | model.load_state_dict(state_dict, strict=strict) 72 | -------------------------------------------------------------------------------- /hotfix/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform 
is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) 81 | -------------------------------------------------------------------------------- /ModelLoader/geffnet/activations/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | 11 | 12 | def swish(x, inplace: bool = False): 13 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 14 | and also as Swish (https://arxiv.org/abs/1710.05941). 15 | 16 | TODO Rename to SiLU with addition to PyTorch 17 | """ 18 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 19 | 20 | 21 | class Swish(nn.Module): 22 | def __init__(self, inplace: bool = False): 23 | super(Swish, self).__init__() 24 | self.inplace = inplace 25 | 26 | def forward(self, x): 27 | return swish(x, self.inplace) 28 | 29 | 30 | def mish(x, inplace: bool = False): 31 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | def __init__(self, inplace: bool = False): 38 | super(Mish, self).__init__() 39 | self.inplace = inplace 40 | 41 | def forward(self, x): 42 | return mish(x, self.inplace) 43 | 44 | 45 | def sigmoid(x, inplace: bool = False): 46 | return x.sigmoid_() if inplace else x.sigmoid() 47 | 48 | 49 | # PyTorch has this, but not with a consistent inplace argument interface 50 | class Sigmoid(nn.Module): 51 | def __init__(self, inplace: bool = False): 52 | super(Sigmoid, self).__init__() 53 | self.inplace = inplace 54 | 55 | def forward(self, x): 56 | return x.sigmoid_() if self.inplace else x.sigmoid() 57 | 58 | 59 | def tanh(x, inplace: bool = False): 60 | return x.tanh_() if inplace else x.tanh() 61 | 62 | 63 | # PyTorch has this, but not with a consistent inplace argument interface 64 | class Tanh(nn.Module): 65 | def __init__(self, inplace: bool = False): 66 | super(Tanh, self).__init__() 67 | self.inplace = inplace 68 | 69 | def forward(self, x): 70 | return x.tanh_() if self.inplace else x.tanh() 71 | 72 | 73 | def hard_swish(x, inplace: bool = False): 74 | inner = F.relu6(x + 3.).div_(6.) 75 | return x.mul_(inner) if inplace else x.mul(inner) 76 | 77 | 78 | class HardSwish(nn.Module): 79 | def __init__(self, inplace: bool = False): 80 | super(HardSwish, self).__init__() 81 | self.inplace = inplace 82 | 83 | def forward(self, x): 84 | return hard_swish(x, self.inplace) 85 | 86 | 87 | def hard_sigmoid(x, inplace: bool = False): 88 | if inplace: 89 | return x.add_(3.).clamp_(0., 6.).div_(6.) 90 | else: 91 | return F.relu6(x + 3.) / 6. 
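# Quick sanity check of hard_sigmoid (an illustrative sketch; values are easy to verify by hand):
#   >>> import torch
#   >>> hard_sigmoid(torch.tensor([-4., 0., 4.]))
#   tensor([0.0000, 0.5000, 1.0000])
# i.e. the function is 0 below x = -3, 1 above x = +3, and linear with slope 1/6 in between.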
92 | 93 | 94 | class HardSigmoid(nn.Module): 95 | def __init__(self, inplace: bool = False): 96 | super(HardSigmoid, self).__init__() 97 | self.inplace = inplace 98 | 99 | def forward(self, x): 100 | return hard_sigmoid(x, self.inplace) 101 | 102 | 103 | -------------------------------------------------------------------------------- /ModelLoader/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | 4 | The geffnet module is modified from: 5 | https://github.com/rwightman/gen-efficientnet-pytorch 6 | ''' 7 | 8 | 9 | import os,sys 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | from . import geffnet 14 | from . import myresnet 15 | 16 | import torchvision.models 17 | 18 | torchvision_model_name_list = sorted(name for name in torchvision.models.__dict__ 19 | if name.islower() and not name.startswith("__") 20 | and callable(torchvision.models.__dict__[name])) 21 | 22 | def _get_model_(arch, num_classes, pretrained=False, opt=None, argv=None): 23 | 24 | # load torch vision model 25 | if arch in torchvision_model_name_list: 26 | if pretrained: 27 | print('Using pretrained model: {}'.format(arch)) 28 | model = torchvision.models.__dict__[arch](pretrained=True, num_classes=num_classes) 29 | else: 30 | print('Create torchvision model: {}'.format(arch)) 31 | model = torchvision.models.__dict__[arch](num_classes=num_classes) 32 | 33 | # my implementation of resnet 34 | elif arch == 'myresnet18': 35 | print('Create model: {}'.format(arch)) 36 | model = myresnet.resnet18(pretrained=False, opt=opt, argv=argv) 37 | elif arch == 'myresnet34': 38 | print('Create model: {}'.format(arch)) 39 | model = myresnet.resnet34(pretrained=False, opt=opt, argv=argv) 40 | elif arch == 'myresnet50': 41 | print('Create model: {}'.format(arch)) 42 | model = myresnet.resnet50(pretrained=False, opt=opt, argv=argv) 43 | elif arch == 'myresnet101': 44 | print('Create model: {}'.format(arch)) 45 | model = myresnet.resnet101(pretrained=False, opt=opt, argv=argv) 46 | elif arch == 'myresnet152': 47 | print('Create model: {}'.format(arch)) 48 | model = myresnet.resnet152(pretrained=False, opt=opt, argv=argv) 49 | 50 | # geffnet 51 | elif arch.startswith('geffnet_'): 52 | model_name = arch[len('geffnet_'):] 53 | model = geffnet.create_model(model_name, pretrained=pretrained) 54 | 55 | 56 | # PlainNet 57 | elif arch == 'PlainNet': 58 | import PlainNet 59 | print('Create model: {}'.format(arch)) 60 | model = PlainNet.PlainNet(num_classes=num_classes, opt=opt, argv=argv) 61 | 62 | # Any PlainNet 63 | elif arch.find('.py:MasterNet') >= 0: 64 | module_path = arch.split(':')[0] 65 | assert arch.split(':')[1] == 'MasterNet' 66 | my_working_dir = os.path.dirname(os.path.dirname(__file__)) 67 | module_full_path = os.path.join(my_working_dir, module_path) 68 | 69 | import importlib.util 70 | spec = importlib.util.spec_from_file_location('AnyPlainNet', module_full_path) 71 | AnyPlainNet = importlib.util.module_from_spec(spec) 72 | spec.loader.exec_module(AnyPlainNet) 73 | print('Create model: {}'.format(arch)) 74 | model = AnyPlainNet.MasterNet(num_classes=num_classes, opt=opt, argv=argv) 75 | 76 | else: 77 | raise ValueError('Unknown model arch: ' + arch) 78 | 79 | return model 80 | 81 | def get_model(opt, argv): 82 | return _get_model_(arch=opt.arch, num_classes=opt.num_classes, pretrained=opt.pretrained, opt=opt, 
argv=argv) -------------------------------------------------------------------------------- /ModelLoader/geffnet/activations/__init__.py: -------------------------------------------------------------------------------- 1 | from geffnet import config 2 | from geffnet.activations.activations_me import * 3 | from geffnet.activations.activations_jit import * 4 | from geffnet.activations.activations import * 5 | 6 | 7 | _ACT_FN_DEFAULT = dict( 8 | swish=swish, 9 | mish=mish, 10 | relu=F.relu, 11 | relu6=F.relu6, 12 | sigmoid=sigmoid, 13 | tanh=tanh, 14 | hard_sigmoid=hard_sigmoid, 15 | hard_swish=hard_swish, 16 | ) 17 | 18 | _ACT_FN_JIT = dict( 19 | swish=swish_jit, 20 | mish=mish_jit, 21 | ) 22 | 23 | _ACT_FN_ME = dict( 24 | swish=swish_me, 25 | mish=mish_me, 26 | hard_swish=hard_swish_me, 27 | hard_sigmoid=hard_sigmoid_me, 28 | ) 29 | 30 | _ACT_LAYER_DEFAULT = dict( 31 | swish=Swish, 32 | mish=Mish, 33 | relu=nn.ReLU, 34 | relu6=nn.ReLU6, 35 | sigmoid=Sigmoid, 36 | tanh=Tanh, 37 | hard_sigmoid=HardSigmoid, 38 | hard_swish=HardSwish, 39 | ) 40 | 41 | _ACT_LAYER_JIT = dict( 42 | swish=SwishJit, 43 | mish=MishJit, 44 | ) 45 | 46 | _ACT_LAYER_ME = dict( 47 | swish=SwishMe, 48 | mish=MishMe, 49 | hard_swish=HardSwishMe, 50 | hard_sigmoid=HardSigmoidMe 51 | ) 52 | 53 | _OVERRIDE_FN = dict() 54 | _OVERRIDE_LAYER = dict() 55 | 56 | 57 | def add_override_act_fn(name, fn): 58 | global _OVERRIDE_FN 59 | _OVERRIDE_FN[name] = fn 60 | 61 | 62 | def update_override_act_fn(overrides): 63 | assert isinstance(overrides, dict) 64 | global _OVERRIDE_FN 65 | _OVERRIDE_FN.update(overrides) 66 | 67 | 68 | def clear_override_act_fn(): 69 | global _OVERRIDE_FN 70 | _OVERRIDE_FN = dict() 71 | 72 | 73 | def add_override_act_layer(name, fn): 74 | _OVERRIDE_LAYER[name] = fn 75 | 76 | 77 | def update_override_act_layer(overrides): 78 | assert isinstance(overrides, dict) 79 | global _OVERRIDE_LAYER 80 | _OVERRIDE_LAYER.update(overrides) 81 | 82 | 83 | def clear_override_act_layer(): 84 | global _OVERRIDE_LAYER 85 | _OVERRIDE_LAYER = dict() 86 | 87 | 88 | def get_act_fn(name='relu'): 89 | """ Activation Function Factory 90 | Fetching activation fns by name with this function allows export or torch script friendly 91 | functions to be returned dynamically based on current config. 92 | """ 93 | if name in _OVERRIDE_FN: 94 | return _OVERRIDE_FN[name] 95 | no_me = config.is_exportable() or config.is_scriptable() or config.is_no_jit() 96 | if not no_me and name in _ACT_FN_ME: 97 | # If not exporting or scripting the model, first look for a memory optimized version 98 | # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin 99 | return _ACT_FN_ME[name] 100 | no_jit = config.is_exportable() or config.is_no_jit() 101 | # NOTE: export tracing should work with jit scripted components, but I keep running into issues 102 | if not no_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting 103 | return _ACT_FN_JIT[name] 104 | return _ACT_FN_DEFAULT[name] 105 | 106 | 107 | def get_act_layer(name='relu'): 108 | """ Activation Layer Factory 109 | Fetching activation layers by name with this function allows export or torch script friendly 110 | functions to be returned dynamically based on current config. 
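For example, with no config flags set, get_act_layer('swish') resolves to the memory-efficient SwishMe; under set_exportable(True) it falls back to the plain Swish module.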
111 | """ 112 | if name in _OVERRIDE_LAYER: 113 | return _OVERRIDE_LAYER[name] 114 | no_me = config.is_exportable() or config.is_scriptable() or config.is_no_jit() 115 | if not no_me and name in _ACT_LAYER_ME: 116 | return _ACT_LAYER_ME[name] 117 | no_jit = config.is_exportable() or config.is_no_jit() 118 | if not no_jit and name in _ACT_LAYER_JIT: # jit scripted models should be okay for export/scripting 119 | return _ACT_LAYER_JIT[name] 120 | return _ACT_LAYER_DEFAULT[name] 121 | 122 | 123 | -------------------------------------------------------------------------------- /ModelLoader/geffnet/config.py: -------------------------------------------------------------------------------- 1 | """ Global layer config state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | 117 | 118 | def layer_config_kwargs(kwargs): 119 | """ Consume config kwargs and return contextmgr obj """ 120 | return set_layer_config( 121 | scriptable=kwargs.pop('scriptable', None), 122 | exportable=kwargs.pop('exportable', None), 123 | no_jit=kwargs.pop('no_jit', None)) 124 | -------------------------------------------------------------------------------- /ZeroShotProxy/compute_zen_score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | 6 | 7 | import os, sys 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | import torch 10 | from torch import nn 11 | import numpy as np 12 | import global_utils, argparse, ModelLoader, time 13 | 14 | def network_weight_gaussian_init(net: nn.Module): 15 | with torch.no_grad(): 16 | for m in net.modules(): 17 | if isinstance(m, nn.Conv2d): 18 | nn.init.normal_(m.weight) 19 | if hasattr(m, 'bias') and m.bias is not None: 20 | nn.init.zeros_(m.bias) 21 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 22 | nn.init.ones_(m.weight) 23 | nn.init.zeros_(m.bias) 24 | elif isinstance(m, nn.Linear): 25 | nn.init.normal_(m.weight) 26 | if hasattr(m, 'bias') and m.bias is not None: 27 | nn.init.zeros_(m.bias) 28 | else: 29 | continue 30 | 31 | return net 32 | 33 | def compute_nas_score(gpu, model, mixup_gamma, resolution, batch_size, repeat, fp16=False): 34 | info = {} 35 | nas_score_list = [] 36 | if gpu is not None: 37 | device = torch.device('cuda:{}'.format(gpu)) 38 | else: 39 | device = torch.device('cpu') 40 | 41 | if fp16: 42 | dtype = torch.half 43 | else: 44 | dtype = torch.float32 45 | 46 | with torch.no_grad(): 47 | for repeat_count in range(repeat): 48 | network_weight_gaussian_init(model) 49 | input = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype) 50 | input2 = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype) 51 | mixup_input = input + mixup_gamma * input2 52 | output = model.forward_pre_GAP(input) 53 | mixup_output = model.forward_pre_GAP(mixup_input) 54 | 55 | nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3]) 56 | nas_score = torch.mean(nas_score) 57 | 58 | # compute BN scaling 59 | log_bn_scaling_factor = 0.0 60 | for m in model.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | bn_scaling_factor = torch.sqrt(torch.mean(m.running_var)) 63 | log_bn_scaling_factor += torch.log(bn_scaling_factor) 64 | pass 65 | pass 66 | nas_score = torch.log(nas_score) + log_bn_scaling_factor 
67 | nas_score_list.append(float(nas_score)) 68 | 69 | 70 | std_nas_score = np.std(nas_score_list) 71 | avg_precision = 1.96 * std_nas_score / np.sqrt(len(nas_score_list)) 72 | avg_nas_score = np.mean(nas_score_list) 73 | 74 | 75 | info['avg_nas_score'] = float(avg_nas_score) 76 | info['std_nas_score'] = float(std_nas_score) 77 | info['avg_precision'] = float(avg_precision) 78 | return info 79 | 80 | 81 | def parse_cmd_options(argv): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--batch_size', type=int, default=16, help='number of instances in one mini-batch.') 84 | parser.add_argument('--input_image_size', type=int, default=None, 85 | help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.') 86 | parser.add_argument('--repeat_times', type=int, default=32) 87 | parser.add_argument('--gpu', type=int, default=None) 88 | parser.add_argument('--mixup_gamma', type=float, default=1e-2) 89 | module_opt, _ = parser.parse_known_args(argv) 90 | return module_opt 91 | 92 | if __name__ == "__main__": 93 | opt = global_utils.parse_cmd_options(sys.argv) 94 | args = parse_cmd_options(sys.argv) 95 | the_model = ModelLoader.get_model(opt, sys.argv) 96 | if args.gpu is not None: 97 | the_model = the_model.cuda(args.gpu) 98 | 99 | 100 | start_timer = time.time() 101 | info = compute_nas_score(gpu=args.gpu, model=the_model, mixup_gamma=args.mixup_gamma, 102 | resolution=args.input_image_size, batch_size=args.batch_size, repeat=args.repeat_times, fp16=False) 103 | time_cost = (time.time() - start_timer) / args.repeat_times 104 | zen_score = info['avg_nas_score'] 105 | print(f'zen-score={zen_score:.4g}, time cost={time_cost:.4g} second(s)') -------------------------------------------------------------------------------- /ZeroShotProxy/compute_NASWOT_score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
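NASWOT ('Neural Architecture Search Without Training') scores an untrained network by the log-determinant of the kernel built from binary ReLU activation patterns over one random mini-batch.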
3 | 4 | The implementation of NASWOT score is modified from: 5 | https://github.com/BayesWatch/nas-without-training 6 | ''' 7 | 8 | 9 | 10 | import os, sys, time 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from PlainNet import basic_blocks 16 | import global_utils, argparse, ModelLoader, time 17 | 18 | def network_weight_gaussian_init(net: nn.Module): 19 | with torch.no_grad(): 20 | for m in net.modules(): 21 | if isinstance(m, nn.Conv2d): 22 | nn.init.normal_(m.weight) 23 | if hasattr(m, 'bias') and m.bias is not None: 24 | nn.init.zeros_(m.bias) 25 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 26 | nn.init.ones_(m.weight) 27 | nn.init.zeros_(m.bias) 28 | elif isinstance(m, nn.Linear): 29 | nn.init.normal_(m.weight) 30 | if hasattr(m, 'bias') and m.bias is not None: 31 | nn.init.zeros_(m.bias) 32 | else: 33 | continue 34 | 35 | return net 36 | 37 | def logdet(K): 38 | s, ld = np.linalg.slogdet(K) 39 | return ld 40 | 41 | def get_batch_jacobian(net, x): 42 | net.zero_grad() 43 | x.requires_grad_(True) 44 | y = net(x) 45 | y.backward(torch.ones_like(y)) 46 | jacob = x.grad.detach() 47 | # return jacob, target.detach(), y.detach() 48 | return jacob, y.detach() 49 | 50 | def compute_nas_score(gpu, model, resolution, batch_size): 51 | if gpu is not None: 52 | torch.cuda.set_device(gpu) 53 | model = model.cuda(gpu) 54 | 55 | network_weight_gaussian_init(model) 56 | input = torch.randn(size=[batch_size, 3, resolution, resolution]) 57 | if gpu is not None: 58 | input = input.cuda(gpu) 59 | 60 | model.K = np.zeros((batch_size, batch_size)) 61 | 62 | def counting_forward_hook(module, inp, out): 63 | try: 64 | if not module.visited_backwards: 65 | return 66 | if isinstance(inp, tuple): 67 | inp = inp[0] 68 | inp = inp.view(inp.size(0), -1) 69 | x = (inp > 0).float() 70 | K = x @ x.t() 71 | K2 = (1. - x) @ (1. 
- x.t()) 72 | model.K = model.K + K.cpu().numpy() + K2.cpu().numpy() 73 | except Exception as err: 74 | print('---- error on model : ') 75 | print(model) 76 | raise err 77 | 78 | 79 | def counting_backward_hook(module, inp, out): 80 | module.visited_backwards = True 81 | 82 | for name, module in model.named_modules(): 83 | # if 'ReLU' in str(type(module)): 84 | if isinstance(module, basic_blocks.RELU): 85 | # hooks[name] = module.register_forward_hook(counting_hook) 86 | module.visited_backwards = True 87 | module.register_forward_hook(counting_forward_hook) 88 | module.register_backward_hook(counting_backward_hook) 89 | 90 | x = input 91 | jacobs, y = get_batch_jacobian(model, x) 92 | 93 | score = logdet(model.K) 94 | 95 | return float(score) 96 | 97 | 98 | 99 | def parse_cmd_options(argv): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--batch_size', type=int, default=16, help='number of instances in one mini-batch.') 102 | parser.add_argument('--input_image_size', type=int, default=None, 103 | help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.') 104 | parser.add_argument('--repeat_times', type=int, default=32) 105 | parser.add_argument('--gpu', type=int, default=None) 106 | module_opt, _ = parser.parse_known_args(argv) 107 | return module_opt 108 | 109 | if __name__ == "__main__": 110 | opt = global_utils.parse_cmd_options(sys.argv) 111 | args = parse_cmd_options(sys.argv) 112 | the_model = ModelLoader.get_model(opt, sys.argv) 113 | if args.gpu is not None: 114 | the_model = the_model.cuda(args.gpu) 115 | 116 | 117 | start_timer = time.time() 118 | 119 | for repeat_count in range(args.repeat_times): 120 | the_score = compute_nas_score(gpu=args.gpu, model=the_model, 121 | resolution=args.input_image_size, batch_size=args.batch_size) 122 | 123 | time_cost = (time.time() - start_timer) / args.repeat_times 124 | 125 | print(f'NASWOT={the_score:.4g}, time cost={time_cost:.4g} second(s)') -------------------------------------------------------------------------------- /ZeroShotProxy/compute_syncflow_score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | This file is modified from 4 | https://github.com/SamsungLabs/zero-cost-nas 5 | ''' 6 | 7 | 8 | # Copyright 2021 Samsung Electronics Co., Ltd. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 
21 | # ============================================================================= 22 | 23 | 24 | import os, sys, time 25 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 26 | import torch 27 | from torch import nn 28 | import numpy as np 29 | 30 | 31 | 32 | import torch 33 | 34 | def network_weight_gaussian_init(net: nn.Module): 35 | with torch.no_grad(): 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | nn.init.normal_(m.weight) 39 | if hasattr(m, 'bias') and m.bias is not None: 40 | nn.init.zeros_(m.bias) 41 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 42 | nn.init.ones_(m.weight) 43 | nn.init.zeros_(m.bias) 44 | elif isinstance(m, nn.Linear): 45 | nn.init.normal_(m.weight) 46 | if hasattr(m, 'bias') and m.bias is not None: 47 | nn.init.zeros_(m.bias) 48 | else: 49 | continue 50 | 51 | return net 52 | 53 | def get_layer_metric_array(net, metric, mode): 54 | metric_array = [] 55 | 56 | for layer in net.modules(): 57 | if mode == 'channel' and hasattr(layer, 'dont_ch_prune'): 58 | continue 59 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 60 | metric_array.append(metric(layer)) 61 | 62 | return metric_array 63 | 64 | 65 | 66 | def compute_synflow_per_weight(net, inputs, mode): 67 | device = inputs.device 68 | 69 | # convert params to their abs. Keep sign for converting it back. 70 | @torch.no_grad() 71 | def linearize(net): 72 | signs = {} 73 | for name, param in net.state_dict().items(): 74 | signs[name] = torch.sign(param) 75 | param.abs_() 76 | return signs 77 | 78 | # convert to orig values 79 | @torch.no_grad() 80 | def nonlinearize(net, signs): 81 | for name, param in net.state_dict().items(): 82 | if 'weight_mask' not in name: 83 | param.mul_(signs[name]) 84 | 85 | # keep signs of all params 86 | signs = linearize(net) 87 | 88 | # Compute gradients with input of 1s 89 | net.zero_grad() 90 | net.double() 91 | input_dim = list(inputs[0, :].shape) 92 | inputs = torch.ones([1] + input_dim).double().to(device) 93 | output = net.forward(inputs) 94 | torch.sum(output).backward() 95 | 96 | # select the gradients that we want to use for search/prune 97 | def synflow(layer): 98 | if layer.weight.grad is not None: 99 | return torch.abs(layer.weight * layer.weight.grad) 100 | else: 101 | return torch.zeros_like(layer.weight) 102 | 103 | grads_abs = get_layer_metric_array(net, synflow, mode) 104 | 105 | # apply signs of all params 106 | nonlinearize(net, signs) 107 | 108 | return grads_abs 109 | 110 | def do_compute_nas_score(gpu, model, resolution, batch_size): 111 | model.train() 112 | model.requires_grad_(True) 113 | 114 | model.zero_grad() 115 | 116 | if gpu is not None: 117 | torch.cuda.set_device(gpu) 118 | model = model.cuda(gpu) 119 | 120 | network_weight_gaussian_init(model) 121 | input = torch.randn(size=[batch_size, 3, resolution, resolution]) 122 | if gpu is not None: 123 | input = input.cuda(gpu) 124 | 125 | grads_abs_list = compute_synflow_per_weight(net=model, inputs=input, mode='') 126 | score = 0 127 | for grad_abs in grads_abs_list: 128 | if len(grad_abs.shape) == 4: 129 | score += float(torch.mean(torch.sum(grad_abs, dim=[1,2,3]))) 130 | elif len(grad_abs.shape) == 2: 131 | score += float(torch.mean(torch.sum(grad_abs, dim=[1]))) 132 | else: 133 | raise RuntimeError('!!!') 134 | 135 | 136 | return -1 * score 137 | 138 | 139 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | 
''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | ''' 6 | Usage: 7 | python val.py --gpu 0 --arch zennet_imagenet1k_latency02ms_res192 8 | ''' 9 | import os, sys, argparse, math, PIL 10 | import torch 11 | from torchvision import transforms, datasets 12 | import ZenNet 13 | 14 | imagenet_data_dir = os.path.expanduser('~/data/imagenet') 15 | 16 | def accuracy(output, target, topk=(1, )): 17 | """Computes the accuracy over the k top predictions for the specified values of k""" 18 | with torch.no_grad(): 19 | maxk = max(topk) 20 | batch_size = target.size(0) 21 | 22 | _, pred = output.topk(maxk, 1, True, True) 23 | pred = pred.t() 24 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 25 | 26 | res = [] 27 | for k in topk: 28 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 29 | res.append(correct_k.mul_(100.0 / batch_size)) 30 | return res 31 | 32 | 33 | if __name__ == '__main__': 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--batch_size', type=int, default=128, 37 | help='batch size for evaluation.') 38 | parser.add_argument('--workers', type=int, default=6, 39 | help='number of workers to load dataset.') 40 | parser.add_argument('--gpu', type=int, default=None, 41 | help='GPU device ID. None for CPU.') 42 | parser.add_argument('--data', type=str, default=imagenet_data_dir, 43 | help='ImageNet data directory.') 44 | parser.add_argument('--arch', type=str, default=None, 45 | help='model to be evaluated.') 46 | parser.add_argument('--fp16', action='store_true') 47 | parser.add_argument('--apex', action='store_true', 48 | help='Use NVIDIA Apex (float16 precision).') 49 | 50 | opt, _ = parser.parse_known_args(sys.argv) 51 | 52 | if opt.apex: 53 | from apex import amp 54 | # else: 55 | # print('Warning!!! The GENets are trained by NVIDIA Apex, ' 56 | # 'it is suggested to turn on --apex in the evaluation. 
' 57 | # 'Otherwise the model accuracy might be harmed.') 58 | 59 | input_image_size = ZenNet.zennet_model_zoo[opt.arch]['resolution'] 60 | crop_image_size = ZenNet.zennet_model_zoo[opt.arch]['crop_image_size'] 61 | 62 | print('Evaluate {} at {}x{} resolution.'.format(opt.arch, input_image_size, input_image_size)) 63 | 64 | # load dataset 65 | val_dir = os.path.join(opt.data, 'val') 66 | input_image_crop = 0.875 67 | resize_image_size = int(math.ceil(crop_image_size / input_image_crop)) 68 | transforms_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 69 | transform_list = [transforms.Resize(resize_image_size, interpolation=PIL.Image.BICUBIC), 70 | transforms.CenterCrop(crop_image_size), transforms.ToTensor(), transforms_normalize] 71 | transformer = transforms.Compose(transform_list) 72 | val_dataset = datasets.ImageFolder(val_dir, transformer) 73 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=False, 74 | num_workers=opt.workers, pin_memory=True, sampler=None) 75 | 76 | # load model 77 | model = ZenNet.get_ZenNet(opt.arch, pretrained=True) 78 | if opt.gpu is not None: 79 | torch.cuda.set_device(opt.gpu) 80 | torch.backends.cudnn.benchmark = True 81 | model = model.cuda(opt.gpu) 82 | print('Using GPU {}.'.format(opt.gpu)) 83 | if opt.apex: 84 | model = amp.initialize(model, opt_level="O1") 85 | elif opt.fp16: 86 | model = model.half() 87 | 88 | model.eval() 89 | acc1_sum = 0 90 | acc5_sum = 0 91 | n = 0 92 | with torch.no_grad(): 93 | for i, (input, target) in enumerate(val_loader): 94 | 95 | if opt.gpu is not None: 96 | input = input.cuda(opt.gpu, non_blocking=True) 97 | target = target.cuda(opt.gpu, non_blocking=True) 98 | if opt.fp16: 99 | input = input.half() 100 | 101 | input = torch.nn.functional.interpolate(input, input_image_size, mode='bilinear') 102 | 103 | output = model(input) 104 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 105 | 106 | acc1_sum += acc1[0] * input.shape[0] 107 | acc5_sum += acc5[0] * input.shape[0] 108 | n += input.shape[0] 109 | 110 | if i % 100 == 0: 111 | print('mini_batch {}, top-1 acc={:4g}%, top-5 acc={:4g}%, number of evaluated images={}'.format(i, acc1[0], acc5[0], n)) 112 | pass 113 | pass 114 | pass 115 | 116 | acc1_avg = acc1_sum / n 117 | acc5_avg = acc5_sum / n 118 | 119 | print('*** arch={}, validation top-1 acc={}%, top-5 acc={}%, number of evaluated images={}'.format(opt.arch, acc1_avg, acc5_avg, n)) -------------------------------------------------------------------------------- /hotfix/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import sys 4 | import random 5 | from PIL import Image 6 | 7 | import numpy as np 8 | import numbers 9 | import types 10 | import collections 11 | import warnings 12 | 13 | def erase(img, i, j, h, w, v, inplace=False): 14 | """ Erase the input Tensor Image with given value. 15 | 16 | Args: 17 | img (Tensor Image): Tensor image of size (C, H, W) to be erased 18 | i (int): i in (i,j) i.e coordinates of the upper left corner. 19 | j (int): j in (i,j) i.e coordinates of the upper left corner. 20 | h (int): Height of the erased region. 21 | w (int): Width of the erased region. 22 | v: Erasing value. 23 | inplace(bool, optional): For in-place operations. By default is set False. 24 | 25 | Returns: 26 | Tensor Image: Erased image. 27 | """ 28 | if not isinstance(img, torch.Tensor): 29 | raise TypeError('img should be Tensor Image. 
Got {}'.format(type(img))) 30 | 31 | if not inplace: 32 | img = img.clone() 33 | 34 | img[:, i:i + h, j:j + w] = v 35 | return img 36 | 37 | 38 | class RandomErasing(object): 39 | """ Randomly selects a rectangle region in an image and erases its pixels. 40 | 'Random Erasing Data Augmentation' by Zhong et al. 41 | See https://arxiv.org/pdf/1708.04896.pdf 42 | Args: 43 | p: probability that the random erasing operation will be performed. 44 | scale: range of proportion of erased area against input image. 45 | ratio: range of aspect ratio of erased area. 46 | value: erasing value. Default is 0. If a single int, it is used to 47 | erase all pixels. If a tuple of length 3, it is used to erase 48 | R, G, B channels respectively. 49 | If a str of 'random', erasing each pixel with random values. 50 | inplace: boolean to make this transform inplace. Default set to False. 51 | 52 | Returns: 53 | Erased Image. 54 | # Examples: 55 | >>> transform = transforms.Compose([ 56 | >>> transforms.RandomHorizontalFlip(), 57 | >>> transforms.ToTensor(), 58 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 59 | >>> transforms.RandomErasing(), 60 | >>> ]) 61 | """ 62 | 63 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 64 | assert isinstance(value, (numbers.Number, str, tuple, list)) 65 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 66 | warnings.warn("range should be of kind (min, max)") 67 | if scale[0] < 0 or scale[1] > 1: 68 | raise ValueError("range of scale should be between 0 and 1") 69 | if p < 0 or p > 1: 70 | raise ValueError("range of random erasing probability should be between 0 and 1") 71 | 72 | self.p = p 73 | self.scale = scale 74 | self.ratio = ratio 75 | self.value = value 76 | self.inplace = inplace 77 | 78 | @staticmethod 79 | def get_params(img, scale, ratio, value=0): 80 | """Get parameters for ``erase`` for a random erasing. 81 | 82 | Args: 83 | img (Tensor): Tensor image of size (C, H, W) to be erased. 84 | scale: range of proportion of erased area against input image. 85 | ratio: range of aspect ratio of erased area. 86 | 87 | Returns: 88 | tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 89 | """ 90 | img_c, img_h, img_w = img.shape 91 | area = img_h * img_w 92 | 93 | for attempt in range(10): 94 | erase_area = random.uniform(scale[0], scale[1]) * area 95 | aspect_ratio = random.uniform(ratio[0], ratio[1]) 96 | 97 | h = int(round(math.sqrt(erase_area * aspect_ratio))) 98 | w = int(round(math.sqrt(erase_area / aspect_ratio))) 99 | 100 | if h < img_h and w < img_w: 101 | i = random.randint(0, img_h - h) 102 | j = random.randint(0, img_w - w) 103 | if isinstance(value, numbers.Number): 104 | v = value 105 | elif isinstance(value, torch._six.string_classes): 106 | v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() 107 | elif isinstance(value, (list, tuple)): 108 | v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w) 109 | return i, j, h, w, v 110 | 111 | # Return original image 112 | return 0, 0, img_h, img_w, img 113 | 114 | def __call__(self, img): 115 | """ 116 | Args: 117 | img (Tensor): Tensor image of size (C, H, W) to be erased. 118 | 119 | Returns: 120 | img (Tensor): Erased Tensor image. 
121 | """ 122 | if random.uniform(0, 1) < self.p: 123 | x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) 124 | return erase(img, x, y, h, w, v, self.inplace) 125 | return img -------------------------------------------------------------------------------- /val_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | ''' 6 | Usage: 7 | python val_cifar.py --dataset cifar10 --gpu 0 --arch zennet_cifar10_model_size05M_res32 8 | ''' 9 | import os, sys, argparse, math, PIL 10 | import torch 11 | from torchvision import transforms, datasets 12 | import ZenNet 13 | 14 | cifar10_data_dir = '~/data/pytorch_cifar10' 15 | cifar100_data_dir = '~/data/pytorch_cifar100' 16 | 17 | def accuracy(output, target, topk=(1, )): 18 | """Computes the accuracy over the k top predictions for the specified values of k""" 19 | with torch.no_grad(): 20 | maxk = max(topk) 21 | batch_size = target.size(0) 22 | 23 | _, pred = output.topk(maxk, 1, True, True) 24 | pred = pred.t() 25 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 26 | 27 | res = [] 28 | for k in topk: 29 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 30 | res.append(correct_k.mul_(100.0 / batch_size)) 31 | return res 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--batch_size', type=int, default=512, 38 | help='batch size for evaluation.') 39 | parser.add_argument('--dataset', type=str, default=None, 40 | help='cifar10 or cifar100') 41 | parser.add_argument('--workers', type=int, default=6, 42 | help='number of workers to load dataset.') 43 | parser.add_argument('--gpu', type=int, default=None, 44 | help='GPU device ID. None for CPU.') 45 | parser.add_argument('--arch', type=str, default=None, 46 | help='model to be evaluated.') 47 | parser.add_argument('--fp16', action='store_true') 48 | parser.add_argument('--apex', action='store_true', 49 | help='Use NVIDIA Apex (float16 precision).') 50 | 51 | opt, _ = parser.parse_known_args(sys.argv) 52 | 53 | if opt.apex: 54 | from apex import amp 55 | else: 56 | print('Warning!!! The GENets are trained by NVIDIA Apex, ' 57 | 'it is suggested to turn on --apex in the evaluation. 
' 58 | 'Otherwise the model accuracy might be harmed.') 59 | 60 | input_image_size = ZenNet.zennet_model_zoo[opt.arch]['resolution'] 61 | crop_image_size = ZenNet.zennet_model_zoo[opt.arch]['crop_image_size'] 62 | 63 | print('Evaluate {} at {}x{} resolution.'.format(opt.arch, input_image_size, input_image_size)) 64 | 65 | # load dataset 66 | transforms_normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) 67 | transform_list = [transforms.ToTensor(), transforms_normalize] 68 | transformer = transforms.Compose(transform_list) 69 | 70 | if opt.dataset == 'cifar10': 71 | the_dataset = datasets.CIFAR10(root=cifar10_data_dir, train=False, download=True, transform=transformer) 72 | elif opt.dataset == 'cifar100': 73 | the_dataset = datasets.CIFAR100(root=cifar100_data_dir, train=False, download=True, transform=transformer) 74 | else: 75 | raise ValueError('Unknown dataset_name=' + opt.dataset) 76 | val_loader = torch.utils.data.DataLoader(the_dataset, batch_size=opt.batch_size, shuffle=False, 77 | num_workers=opt.workers, pin_memory=True, sampler=None) 78 | 79 | # load model 80 | model = ZenNet.get_ZenNet(opt.arch, pretrained=True) 81 | if opt.gpu is not None: 82 | torch.cuda.set_device(opt.gpu) 83 | torch.backends.cudnn.benchmark = True 84 | model = model.cuda(opt.gpu) 85 | print('Using GPU {}.'.format(opt.gpu)) 86 | if opt.apex: 87 | model = amp.initialize(model, opt_level="O1") 88 | elif opt.fp16: 89 | model = model.half() 90 | 91 | model.eval() 92 | acc1_sum = 0 93 | acc5_sum = 0 94 | n = 0 95 | with torch.no_grad(): 96 | for i, (input, target) in enumerate(val_loader): 97 | 98 | if opt.gpu is not None: 99 | input = input.cuda(opt.gpu, non_blocking=True) 100 | target = target.cuda(opt.gpu, non_blocking=True) 101 | if opt.fp16: 102 | input = input.half() 103 | 104 | input = torch.nn.functional.interpolate(input, input_image_size, mode='bilinear') 105 | 106 | output = model(input) 107 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 108 | 109 | acc1_sum += acc1[0] * input.shape[0] 110 | acc5_sum += acc5[0] * input.shape[0] 111 | n += input.shape[0] 112 | 113 | if i % 100 == 0: 114 | print('mini_batch {}, top-1 acc={:4g}%, top-5 acc={:4g}%, number of evaluated images={}'.format(i, acc1[0], acc5[0], n)) 115 | pass 116 | pass 117 | pass 118 | 119 | acc1_avg = acc1_sum / n 120 | acc5_avg = acc5_sum / n 121 | 122 | print('*** arch={}, validation top-1 acc={}%, top-5 acc={}%, number of evaluated images={}'.format(opt.arch, acc1_avg, acc5_avg, n)) -------------------------------------------------------------------------------- /SearchSpace/search_space_XXBL.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 4 | 5 | import itertools 6 | 7 | import global_utils 8 | from PlainNet import basic_blocks, super_blocks, SuperResKXKX, SuperResK1KXK1 9 | 10 | seach_space_block_type_list_list = [ 11 | [SuperResK1KXK1.SuperResK1K3K1, SuperResK1KXK1.SuperResK1K5K1, SuperResK1KXK1.SuperResK1K7K1], 12 | [SuperResKXKX.SuperResK3K3, SuperResKXKX.SuperResK5K5, SuperResKXKX.SuperResK7K7], 13 | ] 14 | 15 | __block_type_round_channels_base_dict__ = { 16 | SuperResKXKX.SuperResK3K3: 8, 17 | SuperResKXKX.SuperResK5K5: 8, 18 | SuperResKXKX.SuperResK7K7: 8, 19 | SuperResK1KXK1.SuperResK1K3K1: 8, SuperResK1KXK1.SuperResK1K5K1: 8, SuperResK1KXK1.SuperResK1K7K1: 8, 20 | } 21 | 22 | __block_type_min_channels_base_dict__ = { 23 | 
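# smallest admissible output/bottleneck channel count per block type (uniformly 8 here)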
SuperResKXKX.SuperResK3K3: 8, 24 | SuperResKXKX.SuperResK5K5: 8, 25 | SuperResKXKX.SuperResK7K7: 8, 26 | SuperResK1KXK1.SuperResK1K3K1: 8, 27 | SuperResK1KXK1.SuperResK1K5K1: 8, 28 | SuperResK1KXK1.SuperResK1K7K1: 8, 29 | } 30 | 31 | 32 | def get_select_student_channels_list(out_channels): 33 | the_list = [out_channels * 2.5, out_channels * 2, out_channels * 1.5, out_channels * 1.25, 34 | out_channels, 35 | out_channels / 1.25, out_channels / 1.5, out_channels / 2, out_channels / 2.5] 36 | the_list = [min(2048, max(8, x)) for x in the_list] 37 | the_list = [global_utils.smart_round(x, base=8) for x in the_list] 38 | the_list = list(set(the_list)) 39 | the_list.sort(reverse=True) 40 | return the_list 41 | 42 | 43 | def get_select_student_sublayers_list(sub_layers): 44 | the_list = [sub_layers, 45 | sub_layers + 1, sub_layers + 2, 46 | sub_layers - 1, sub_layers - 2, ] 47 | the_list = [max(0, round(x)) for x in the_list] 48 | the_list = list(set(the_list)) 49 | the_list.sort(reverse=True) 50 | return the_list 51 | 52 | 53 | def gen_search_space(block_list, block_id): 54 | the_block = block_list[block_id] 55 | student_blocks_list_list = [] 56 | 57 | if isinstance(the_block, super_blocks.SuperConvKXBNRELU): 58 | student_blocks_list = [] 59 | student_out_channels_list = get_select_student_channels_list(the_block.out_channels) 60 | for student_out_channels in student_out_channels_list: 61 | tmp_block_str = type(the_block).__name__ + '({},{},{},1)'.format( 62 | the_block.in_channels, student_out_channels, the_block.stride) 63 | student_blocks_list.append(tmp_block_str) 64 | pass 65 | student_blocks_list = list(set(student_blocks_list)) 66 | assert len(student_blocks_list) >= 1 67 | student_blocks_list_list.append(student_blocks_list) 68 | else: 69 | for student_block_type_list in seach_space_block_type_list_list: 70 | student_blocks_list = [] 71 | student_out_channels_list = get_select_student_channels_list(the_block.out_channels) 72 | student_sublayers_list = get_select_student_sublayers_list(sub_layers=the_block.sub_layers) 73 | student_bottleneck_channels_list = get_select_student_channels_list(the_block.bottleneck_channels) 74 | for student_block_type in student_block_type_list: 75 | for student_out_channels, student_sublayers, student_bottleneck_channels in itertools.product( 76 | student_out_channels_list, student_sublayers_list, student_bottleneck_channels_list): 77 | 78 | # filter smallest possible channel for this block type 79 | min_possible_channels = __block_type_round_channels_base_dict__[student_block_type] 80 | channel_round_base = __block_type_round_channels_base_dict__[student_block_type] 81 | student_out_channels = global_utils.smart_round(student_out_channels, channel_round_base) 82 | student_bottleneck_channels = global_utils.smart_round(student_bottleneck_channels, 83 | channel_round_base) 84 | 85 | if student_out_channels < min_possible_channels or student_bottleneck_channels < min_possible_channels: 86 | continue 87 | if student_sublayers <= 0: # no empty layer 88 | continue 89 | tmp_block_str = student_block_type.__name__ + '({},{},{},{},{})'.format( 90 | the_block.in_channels, student_out_channels, the_block.stride, student_bottleneck_channels, 91 | student_sublayers) 92 | student_blocks_list.append(tmp_block_str) 93 | pass 94 | student_blocks_list = list(set(student_blocks_list)) 95 | assert len(student_blocks_list) >= 1 96 | student_blocks_list_list.append(student_blocks_list) 97 | pass 98 | pass # end for student_block_type_list in seach_space_block_type_list_list: 99 
| pass 100 | return student_blocks_list_list 101 | -------------------------------------------------------------------------------- /ModelLoader/geffnet/activations/activations_me.py: -------------------------------------------------------------------------------- 1 | """ Activations (memory-efficient w/ custom autograd) 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | These activations are not compatible with jit scripting or ONNX export of the model, please use either 7 | the JIT or basic versions of the activations. 8 | 9 | Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | 17 | __all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', 18 | 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit_fwd(x): 23 | return x.mul(torch.sigmoid(x)) 24 | 25 | 26 | @torch.jit.script 27 | def swish_jit_bwd(x, grad_output): 28 | x_sigmoid = torch.sigmoid(x) 29 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) 30 | 31 | 32 | class SwishJitAutoFn(torch.autograd.Function): 33 | """ torch.jit.script optimised Swish w/ memory-efficient checkpoint 34 | Inspired by conversation btw Jeremy Howard & Adam Pazske 35 | https://twitter.com/jeremyphoward/status/1188251041835315200 36 | 37 | Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 38 | and also as Swish (https://arxiv.org/abs/1710.05941). 39 | 40 | TODO Rename to SiLU with addition to PyTorch 41 | """ 42 | 43 | @staticmethod 44 | def forward(ctx, x): 45 | ctx.save_for_backward(x) 46 | return swish_jit_fwd(x) 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | x = ctx.saved_tensors[0] 51 | return swish_jit_bwd(x, grad_output) 52 | 53 | 54 | def swish_me(x, inplace=False): 55 | return SwishJitAutoFn.apply(x) 56 | 57 | 58 | class SwishMe(nn.Module): 59 | def __init__(self, inplace: bool = False): 60 | super(SwishMe, self).__init__() 61 | 62 | def forward(self, x): 63 | return SwishJitAutoFn.apply(x) 64 | 65 | 66 | @torch.jit.script 67 | def mish_jit_fwd(x): 68 | return x.mul(torch.tanh(F.softplus(x))) 69 | 70 | 71 | @torch.jit.script 72 | def mish_jit_bwd(x, grad_output): 73 | x_sigmoid = torch.sigmoid(x) 74 | x_tanh_sp = F.softplus(x).tanh() 75 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 76 | 77 | 78 | class MishJitAutoFn(torch.autograd.Function): 79 | """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 80 | A memory efficient, jit scripted variant of Mish 81 | """ 82 | @staticmethod 83 | def forward(ctx, x): 84 | ctx.save_for_backward(x) 85 | return mish_jit_fwd(x) 86 | 87 | @staticmethod 88 | def backward(ctx, grad_output): 89 | x = ctx.saved_tensors[0] 90 | return mish_jit_bwd(x, grad_output) 91 | 92 | 93 | def mish_me(x, inplace=False): 94 | return MishJitAutoFn.apply(x) 95 | 96 | 97 | class MishMe(nn.Module): 98 | def __init__(self, inplace: bool = False): 99 | super(MishMe, self).__init__() 100 | 101 | def forward(self, x): 102 | return MishJitAutoFn.apply(x) 103 | 104 | 105 | @torch.jit.script 106 | def hard_sigmoid_jit_fwd(x, inplace: bool = False): 107 | return (x + 3).clamp(min=0, max=6).div(6.) 108 | 109 | 110 | @torch.jit.script 111 | def hard_sigmoid_jit_bwd(x, grad_output): 112 | m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. 
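# i.e. d/dx hard_sigmoid(x) = 1/6 on the linear region [-3, 3] and 0 elsewhere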
113 | return grad_output * m 114 | 115 | 116 | class HardSigmoidJitAutoFn(torch.autograd.Function): 117 | @staticmethod 118 | def forward(ctx, x): 119 | ctx.save_for_backward(x) 120 | return hard_sigmoid_jit_fwd(x) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | x = ctx.saved_tensors[0] 125 | return hard_sigmoid_jit_bwd(x, grad_output) 126 | 127 | 128 | def hard_sigmoid_me(x, inplace: bool = False): 129 | return HardSigmoidJitAutoFn.apply(x) 130 | 131 | 132 | class HardSigmoidMe(nn.Module): 133 | def __init__(self, inplace: bool = False): 134 | super(HardSigmoidMe, self).__init__() 135 | 136 | def forward(self, x): 137 | return HardSigmoidJitAutoFn.apply(x) 138 | 139 | 140 | @torch.jit.script 141 | def hard_swish_jit_fwd(x): 142 | return x * (x + 3).clamp(min=0, max=6).div(6.) 143 | 144 | 145 | @torch.jit.script 146 | def hard_swish_jit_bwd(x, grad_output): 147 | m = torch.ones_like(x) * (x >= 3.) 148 | m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) 149 | return grad_output * m 150 | 151 | 152 | class HardSwishJitAutoFn(torch.autograd.Function): 153 | """A memory efficient, jit-scripted HardSwish activation""" 154 | @staticmethod 155 | def forward(ctx, x): 156 | ctx.save_for_backward(x) 157 | return hard_swish_jit_fwd(x) 158 | 159 | @staticmethod 160 | def backward(ctx, grad_output): 161 | x = ctx.saved_tensors[0] 162 | return hard_swish_jit_bwd(x, grad_output) 163 | 164 | 165 | def hard_swish_me(x, inplace=False): 166 | return HardSwishJitAutoFn.apply(x) 167 | 168 | 169 | class HardSwishMe(nn.Module): 170 | def __init__(self, inplace: bool = False): 171 | super(HardSwishMe, self).__init__() 172 | 173 | def forward(self, x): 174 | return HardSwishJitAutoFn.apply(x) 175 | -------------------------------------------------------------------------------- /benchmark_network_latency.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
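Benchmark per-image inference latency: after a few warm-up forward passes, time repeated forward passes over a random batch and report seconds per image (see __get_latency__ and get_model_latency below).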
3 | ''' 4 | 5 | import os,sys, argparse 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | import ModelLoader, global_utils 8 | import train_image_classification as tic 9 | import torch, time 10 | import numpy as np 11 | 12 | 13 | def __get_latency__(model, batch_size, resolution, channel, gpu, benchmark_repeat_times, fp16): 14 | device = torch.device('cuda:{}'.format(gpu)) 15 | torch.backends.cudnn.benchmark = True 16 | 17 | torch.cuda.set_device(gpu) 18 | model = model.cuda(gpu) 19 | if fp16: 20 | model = model.half() 21 | dtype = torch.float16 22 | else: 23 | dtype = torch.float32 24 | 25 | the_image = torch.randn(batch_size, channel, resolution, resolution, dtype=dtype, 26 | device=device) 27 | model.eval() 28 | warmup_T = 3 29 | with torch.no_grad(): 30 | for i in range(warmup_T): 31 | the_output = model(the_image) 32 | start_timer = time.time() 33 | for repeat_count in range(benchmark_repeat_times): 34 | the_output = model(the_image) 35 | 36 | end_timer = time.time() 37 | the_latency = (end_timer - start_timer) / float(benchmark_repeat_times) / batch_size 38 | return the_latency 39 | 40 | 41 | def get_robust_latency_mean_std(model, batch_size, resolution, channel, gpu, benchmark_repeat_times=30, fp16=False): 42 | robust_repeat_times = 10 43 | latency_list = [] 44 | model = model.cuda(gpu) 45 | for repeat_count in range(robust_repeat_times): 46 | try: 47 | the_latency = __get_latency__(model, batch_size, resolution, channel, gpu, benchmark_repeat_times, fp16) 48 | except Exception as e: 49 | print(e) 50 | the_latency = np.inf 51 | 52 | latency_list.append(the_latency) 53 | 54 | pass # end for 55 | latency_list.sort() 56 | avg_latency = np.mean(latency_list[2:8]) 57 | std_latency = np.std(latency_list[2:8]) 58 | return avg_latency, std_latency 59 | 60 | def main(opt, argv): 61 | global_utils.create_logging() 62 | 63 | batch_size_list = [int(x) for x in opt.batch_size_list.split(',')] 64 | opt.batch_size = 1 65 | opt = tic.config_dist_env_and_opt(opt) 66 | 67 | # create model 68 | model = ModelLoader.get_model(opt, argv) 69 | 70 | print('batch_size, latency_per_image') 71 | 72 | for the_batch_size_per_gpu in batch_size_list: 73 | 74 | the_latency, _ = get_robust_latency_mean_std(model=model, batch_size=the_batch_size_per_gpu, 75 | resolution=opt.input_image_size, channel=3, gpu=opt.gpu, 76 | benchmark_repeat_times=opt.repeat_times, 77 | fp16=opt.fp16) 78 | print('{},{:4g}'.format(the_batch_size_per_gpu, the_latency)) 79 | 80 | if opt.dist_mode == 'auto': 81 | global_utils.release_gpu(opt.gpu) 82 | 83 | 84 | def get_model_latency(model, batch_size, resolution, in_channels, gpu, repeat_times, fp16): 85 | if gpu is not None: 86 | device = torch.device('cuda:{}'.format(gpu)) 87 | else: 88 | device = torch.device('cpu') 89 | 90 | if fp16: 91 | model = model.half() 92 | dtype = torch.float16 93 | else: 94 | dtype = torch.float32 95 | 96 | the_image = torch.randn(batch_size, in_channels, resolution, resolution, dtype=dtype, 97 | device=device) 98 | 99 | model.eval() 100 | warmup_T = 3 101 | with torch.no_grad(): 102 | for i in range(warmup_T): 103 | the_output = model(the_image) 104 | start_timer = time.time() 105 | for repeat_count in range(repeat_times): 106 | the_output = model(the_image) 107 | 108 | end_timer = time.time() 109 | the_latency = (end_timer - start_timer) / float(repeat_times) / batch_size 110 | return the_latency 111 | 112 | 113 | def parse_cmd_options(argv): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--batch_size', 
type=int, default=None, help='number of instances in one mini-batch.') 116 | parser.add_argument('--input_image_size', type=int, default=None, 117 | help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.') 118 | parser.add_argument('--save_dir', type=str, default=None, 119 | help='output directory') 120 | parser.add_argument('--repeat_times', type=int, default=1) 121 | parser.add_argument('--gpu', type=int, default=None) 122 | parser.add_argument('--fp16', action='store_true') 123 | module_opt, _ = parser.parse_known_args(argv) 124 | return module_opt 125 | 126 | if __name__ == "__main__": 127 | opt = global_utils.parse_cmd_options(sys.argv) 128 | args = parse_cmd_options(sys.argv) 129 | the_model = ModelLoader.get_model(opt, sys.argv) 130 | if args.gpu is not None: 131 | the_model = the_model.cuda(args.gpu) 132 | 133 | 134 | the_latency = get_model_latency(model=the_model, batch_size=args.batch_size, 135 | resolution=args.input_image_size, 136 | in_channels=3, gpu=args.gpu, repeat_times=args.repeat_times, 137 | fp16=args.fp16) 138 | print(f'{the_latency:.4g} second(s) per image, or {1.0/the_latency:.4g} image(s) per second.') 139 | -------------------------------------------------------------------------------- /SearchSpace/search_space_IDW_fixfc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | 6 | import os, sys 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 9 | 10 | import itertools 11 | 12 | import global_utils 13 | from PlainNet import basic_blocks, super_blocks, SuperResKXKX, SuperResK1KXK1, SuperResIDWEXKX 14 | 15 | seach_space_block_type_list_list = [ 16 | [SuperResIDWEXKX.SuperResIDWE1K3, SuperResIDWEXKX.SuperResIDWE2K3, SuperResIDWEXKX.SuperResIDWE4K3, 17 | SuperResIDWEXKX.SuperResIDWE6K3, 18 | SuperResIDWEXKX.SuperResIDWE1K5, SuperResIDWEXKX.SuperResIDWE2K5, SuperResIDWEXKX.SuperResIDWE4K5, 19 | SuperResIDWEXKX.SuperResIDWE6K5, 20 | SuperResIDWEXKX.SuperResIDWE1K7, SuperResIDWEXKX.SuperResIDWE2K7, SuperResIDWEXKX.SuperResIDWE4K7, 21 | SuperResIDWEXKX.SuperResIDWE6K7], 22 | ] 23 | 24 | __block_type_round_channels_base_dict__ = { 25 | SuperResIDWEXKX.SuperResIDWE1K3: 8, 26 | SuperResIDWEXKX.SuperResIDWE2K3: 8, 27 | SuperResIDWEXKX.SuperResIDWE4K3: 8, 28 | SuperResIDWEXKX.SuperResIDWE6K3: 8, 29 | SuperResIDWEXKX.SuperResIDWE1K5: 8, 30 | SuperResIDWEXKX.SuperResIDWE2K5: 8, 31 | SuperResIDWEXKX.SuperResIDWE4K5: 8, 32 | SuperResIDWEXKX.SuperResIDWE6K5: 8, 33 | SuperResIDWEXKX.SuperResIDWE1K7: 8, 34 | SuperResIDWEXKX.SuperResIDWE2K7: 8, 35 | SuperResIDWEXKX.SuperResIDWE4K7: 8, 36 | SuperResIDWEXKX.SuperResIDWE6K7: 8, 37 | } 38 | 39 | __block_type_min_channels_base_dict__ = { 40 | SuperResIDWEXKX.SuperResIDWE1K3: 8, 41 | SuperResIDWEXKX.SuperResIDWE2K3: 8, 42 | SuperResIDWEXKX.SuperResIDWE4K3: 8, 43 | SuperResIDWEXKX.SuperResIDWE6K3: 8, 44 | SuperResIDWEXKX.SuperResIDWE1K5: 8, 45 | SuperResIDWEXKX.SuperResIDWE2K5: 8, 46 | SuperResIDWEXKX.SuperResIDWE4K5: 8, 47 | SuperResIDWEXKX.SuperResIDWE6K5: 8, 48 | SuperResIDWEXKX.SuperResIDWE1K7: 8, 49 | SuperResIDWEXKX.SuperResIDWE2K7: 8, 50 | SuperResIDWEXKX.SuperResIDWE4K7: 8, 51 | SuperResIDWEXKX.SuperResIDWE6K7: 8, 52 | } 53 | 54 | 55 | def get_select_student_channels_list(out_channels): 56 | the_list = [out_channels * 2.5, out_channels * 2, out_channels * 1.5, out_channels * 1.25, 57 | out_channels, 58 | out_channels / 1.25, out_channels / 1.5, 
out_channels / 2, out_channels / 2.5] 59 | the_list = [max(8, x) for x in the_list] 60 | the_list = [global_utils.smart_round(x, base=8) for x in the_list] 61 | the_list = list(set(the_list)) 62 | the_list.sort(reverse=True) 63 | return the_list 64 | 65 | 66 | def get_select_student_sublayers_list(sub_layers): 67 | the_list = [sub_layers, 68 | sub_layers + 1, sub_layers + 2, 69 | sub_layers - 1, sub_layers - 2, ] 70 | the_list = [max(0, round(x)) for x in the_list] 71 | the_list = list(set(the_list)) 72 | the_list.sort(reverse=True) 73 | return the_list 74 | 75 | 76 | def gen_search_space(block_list, block_id): 77 | the_block = block_list[block_id] 78 | student_blocks_list_list = [] 79 | 80 | if isinstance(the_block, super_blocks.SuperConvKXBNRELU): 81 | student_blocks_list = [] 82 | 83 | if the_block.kernel_size == 1: # last fc layer, never change fc 84 | student_out_channels_list = [the_block.out_channels] 85 | else: 86 | student_out_channels_list = get_select_student_channels_list(the_block.out_channels) 87 | 88 | for student_out_channels in student_out_channels_list: 89 | tmp_block_str = type(the_block).__name__ + '({},{},{},1)'.format( 90 | the_block.in_channels, student_out_channels, the_block.stride) 91 | student_blocks_list.append(tmp_block_str) 92 | pass 93 | student_blocks_list = list(set(student_blocks_list)) 94 | assert len(student_blocks_list) >= 1 95 | student_blocks_list_list.append(student_blocks_list) 96 | else: 97 | for student_block_type_list in seach_space_block_type_list_list: 98 | student_blocks_list = [] 99 | student_out_channels_list = get_select_student_channels_list(the_block.out_channels) 100 | student_sublayers_list = get_select_student_sublayers_list(sub_layers=the_block.sub_layers) 101 | student_bottleneck_channels_list = get_select_student_channels_list(the_block.bottleneck_channels) 102 | for student_block_type in student_block_type_list: 103 | for student_out_channels, student_sublayers, student_bottleneck_channels in itertools.product( 104 | student_out_channels_list, student_sublayers_list, student_bottleneck_channels_list): 105 | 106 | # filter smallest possible channel for this block type 107 | min_possible_channels = __block_type_round_channels_base_dict__[student_block_type] 108 | channel_round_base = __block_type_round_channels_base_dict__[student_block_type] 109 | student_out_channels = global_utils.smart_round(student_out_channels, channel_round_base) 110 | student_bottleneck_channels = global_utils.smart_round(student_bottleneck_channels, 111 | channel_round_base) 112 | 113 | if student_out_channels < min_possible_channels or student_bottleneck_channels < min_possible_channels: 114 | continue 115 | if student_sublayers <= 0: # no empty layer 116 | continue 117 | tmp_block_str = student_block_type.__name__ + '({},{},{},{},{})'.format( 118 | the_block.in_channels, student_out_channels, the_block.stride, student_bottleneck_channels, 119 | student_sublayers) 120 | student_blocks_list.append(tmp_block_str) 121 | pass 122 | student_blocks_list = list(set(student_blocks_list)) 123 | assert len(student_blocks_list) >= 1 124 | student_blocks_list_list.append(student_blocks_list) 125 | pass 126 | pass # end for student_block_type_list in seach_space_block_type_list_list: 127 | pass 128 | return student_blocks_list_list 129 | -------------------------------------------------------------------------------- /Masternet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding 
Limited. 3 | ''' 4 | 5 | import os, sys 6 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | import numpy as np 9 | import torch, argparse 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import PlainNet 13 | from PlainNet import parse_cmd_options, _create_netblock_list_from_str_, basic_blocks, super_blocks 14 | 15 | 16 | def parse_cmd_options(argv): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--no_BN', action='store_true') 19 | parser.add_argument('--no_reslink', action='store_true') 20 | parser.add_argument('--use_se', action='store_true') 21 | module_opt, _ = parser.parse_known_args(argv) 22 | return module_opt 23 | 24 | 25 | class MasterNet(PlainNet.PlainNet): 26 | def __init__(self, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, 27 | no_reslink=None, no_BN=None, use_se=None): 28 | 29 | if argv is not None: 30 | module_opt = parse_cmd_options(argv) 31 | else: 32 | module_opt = None 33 | 34 | if no_BN is None: 35 | if module_opt is not None: 36 | no_BN = module_opt.no_BN 37 | else: 38 | no_BN = False 39 | 40 | if no_reslink is None: 41 | if module_opt is not None: 42 | no_reslink = module_opt.no_reslink 43 | else: 44 | no_reslink = False 45 | 46 | if use_se is None: 47 | if module_opt is not None: 48 | use_se = module_opt.use_se 49 | else: 50 | use_se = False 51 | 52 | 53 | super().__init__(argv=argv, opt=opt, num_classes=num_classes, plainnet_struct=plainnet_struct, 54 | no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, use_se=use_se) 55 | self.last_channels = self.block_list[-1].out_channels 56 | self.fc_linear = basic_blocks.Linear(in_channels=self.last_channels, out_channels=self.num_classes, no_create=no_create) 57 | 58 | self.no_create = no_create 59 | self.no_reslink = no_reslink 60 | self.no_BN = no_BN 61 | self.use_se = use_se 62 | 63 | # bn eps 64 | for layer in self.modules(): 65 | if isinstance(layer, nn.BatchNorm2d): 66 | layer.eps = 1e-3 67 | 68 | def extract_stage_features_and_logit(self, x, target_downsample_ratio=None): 69 | stage_features_list = [] 70 | image_size = x.shape[2] 71 | output = x 72 | 73 | for block_id, the_block in enumerate(self.block_list): 74 | output = the_block(output) 75 | dowsample_ratio = round(image_size / output.shape[2]) 76 | if dowsample_ratio == target_downsample_ratio: 77 | stage_features_list.append(output) 78 | target_downsample_ratio *= 2 79 | pass 80 | pass 81 | 82 | output = F.adaptive_avg_pool2d(output, output_size=1) 83 | output = torch.flatten(output, 1) 84 | logit = self.fc_linear(output) 85 | return stage_features_list, logit 86 | 87 | def forward(self, x): 88 | output = x 89 | for block_id, the_block in enumerate(self.block_list): 90 | output = the_block(output) 91 | 92 | output = F.adaptive_avg_pool2d(output, output_size=1) 93 | 94 | output = torch.flatten(output, 1) 95 | output = self.fc_linear(output) 96 | return output 97 | 98 | def forward_pre_GAP(self, x): 99 | output = x 100 | for the_block in self.block_list: 101 | output = the_block(output) 102 | return output 103 | 104 | def get_FLOPs(self, input_resolution): 105 | the_res = input_resolution 106 | the_flops = 0 107 | for the_block in self.block_list: 108 | the_flops += the_block.get_FLOPs(the_res) 109 | the_res = the_block.get_output_resolution(the_res) 110 | 111 | the_flops += self.fc_linear.get_FLOPs(the_res) 112 | 113 | return the_flops 114 | 115 | def get_model_size(self): 116 | the_size = 0 117 | for the_block in self.block_list: 118 | the_size += 
the_block.get_model_size() 119 | 120 | the_size += self.fc_linear.get_model_size() 121 | 122 | return the_size 123 | 124 | def get_num_layers(self): 125 | num_layers = 0 126 | for block in self.block_list: 127 | assert isinstance(block, super_blocks.PlainNetSuperBlockClass) 128 | num_layers += block.sub_layers 129 | return num_layers 130 | 131 | def replace_block(self, block_id, new_block): 132 | self.block_list[block_id] = new_block 133 | 134 | if block_id < len(self.block_list) - 1: 135 | if self.block_list[block_id + 1].in_channels != new_block.out_channels: 136 | self.block_list[block_id + 1].set_in_channels(new_block.out_channels) 137 | else: 138 | assert block_id == len(self.block_list) - 1 139 | self.last_channels = self.block_list[-1].out_channels 140 | if self.fc_linear.in_channels != self.last_channels: 141 | self.fc_linear.set_in_channels(self.last_channels) 142 | 143 | self.module_list = nn.ModuleList(self.block_list) 144 | 145 | def split(self, split_layer_threshold): 146 | new_str = '' 147 | for block in self.block_list: 148 | new_str += block.split(split_layer_threshold=split_layer_threshold) 149 | return new_str 150 | 151 | def init_parameters(self): 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.xavier_normal_(m.weight.data, gain=3.26033) 155 | if hasattr(m, 'bias') and m.bias is not None: 156 | nn.init.zeros_(m.bias) 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.ones_(m.weight) 159 | nn.init.zeros_(m.bias) 160 | elif isinstance(m, nn.Linear): 161 | nn.init.normal_(m.weight, 0, 3.26033 * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1]))) 162 | if hasattr(m, 'bias') and m.bias is not None: 163 | nn.init.zeros_(m.bias) 164 | else: 165 | pass 166 | 167 | for superblock in self.block_list: 168 | if not isinstance(superblock, super_blocks.PlainNetSuperBlockClass): 169 | continue 170 | for block in superblock.block_list: 171 | if not (isinstance(block, basic_blocks.ResBlock) or isinstance(block, basic_blocks.ResBlockProj)): 172 | continue 173 | # print('---debug set bn weight zero in resblock {}:{}'.format(superblock, block)) 174 | last_bn_block = None 175 | for inner_resblock in block.block_list: 176 | if isinstance(inner_resblock, basic_blocks.BN): 177 | last_bn_block = inner_resblock 178 | pass 179 | pass # end for 180 | assert last_bn_block is not None 181 | # print('-------- last_bn_block={}'.format(last_bn_block)) 182 | nn.init.zeros_(last_bn_block.netblock.weight) -------------------------------------------------------------------------------- /hotfix/folder.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | 3 | from PIL import Image 4 | 5 | import os 6 | import os.path 7 | import sys 8 | 9 | 10 | def has_file_allowed_extension(filename, extensions): 11 | """Checks if a file is an allowed extension. 12 | 13 | Args: 14 | filename (string): path to a file 15 | extensions (tuple of strings): extensions to consider (lowercase) 16 | 17 | Returns: 18 | bool: True if the filename ends with one of given extensions 19 | """ 20 | return filename.lower().endswith(extensions) 21 | 22 | 23 | def is_image_file(filename): 24 | """Checks if a file is an allowed image extension. 
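    The check is purely name-based (a case-insensitive suffix match against
    IMG_EXTENSIONS); the file contents are not inspected.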
25 | 26 | Args: 27 | filename (string): path to a file 28 | 29 | Returns: 30 | bool: True if the filename ends with a known image extension 31 | """ 32 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 33 | 34 | 35 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 36 | images = [] 37 | dir = os.path.expanduser(dir) 38 | if not ((extensions is None) ^ (is_valid_file is None)): 39 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 40 | if extensions is not None: 41 | def is_valid_file(x): 42 | return has_file_allowed_extension(x, extensions) 43 | for target in sorted(class_to_idx.keys()): 44 | d = os.path.join(dir, target) 45 | if not os.path.isdir(d): 46 | continue 47 | for root, _, fnames in sorted(os.walk(d)): 48 | for fname in sorted(fnames): 49 | path = os.path.join(root, fname) 50 | if is_valid_file(path): 51 | item = (path, class_to_idx[target]) 52 | images.append(item) 53 | 54 | return images 55 | 56 | 57 | class DatasetFolder(VisionDataset): 58 | """A generic data loader where the samples are arranged in this way: :: 59 | 60 | root/class_x/xxx.ext 61 | root/class_x/xxy.ext 62 | root/class_x/xxz.ext 63 | 64 | root/class_y/123.ext 65 | root/class_y/nsdf3.ext 66 | root/class_y/asd932_.ext 67 | 68 | Args: 69 | root (string): Root directory path. 70 | loader (callable): A function to load a sample given its path. 71 | extensions (tuple[string]): A list of allowed extensions. 72 | both extensions and is_valid_file should not be passed. 73 | transform (callable, optional): A function/transform that takes in 74 | a sample and returns a transformed version. 75 | E.g, ``transforms.RandomCrop`` for images. 76 | target_transform (callable, optional): A function/transform that takes 77 | in the target and transforms it. 78 | is_valid_file (callable, optional): A function that takes path of an Image file 79 | and check if the file is a valid_file (used to check of corrupt files) 80 | both extensions and is_valid_file should not be passed. 81 | 82 | Attributes: 83 | classes (list): List of the class names. 84 | class_to_idx (dict): Dict with items (class_name, class_index). 85 | samples (list): List of (sample path, class_index) tuples 86 | targets (list): The class_index value for each image in the dataset 87 | """ 88 | 89 | def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None): 90 | super(DatasetFolder, self).__init__(root) 91 | self.transform = transform 92 | self.target_transform = target_transform 93 | classes, class_to_idx = self._find_classes(self.root) 94 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 95 | if len(samples) == 0: 96 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 97 | "Supported extensions are: " + ",".join(extensions))) 98 | 99 | self.loader = loader 100 | self.extensions = extensions 101 | 102 | self.classes = classes 103 | self.class_to_idx = class_to_idx 104 | self.samples = samples 105 | self.targets = [s[1] for s in samples] 106 | 107 | def _find_classes(self, dir): 108 | """ 109 | Finds the class folders in a dataset. 110 | 111 | Args: 112 | dir (string): Root directory path. 113 | 114 | Returns: 115 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 116 | 117 | Ensures: 118 | No class is a subdirectory of another. 
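        For example, a root containing only the directories 'cat' and 'dog'
        yields (['cat', 'dog'], {'cat': 0, 'dog': 1}).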
119 | """ 120 | if sys.version_info >= (3, 5): 121 | # Faster and available in Python 3.5 and above 122 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 123 | else: 124 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 125 | classes.sort() 126 | class_to_idx = {classes[i]: i for i in range(len(classes))} 127 | return classes, class_to_idx 128 | 129 | def __getitem__(self, index): 130 | """ 131 | Args: 132 | index (int): Index 133 | 134 | Returns: 135 | tuple: (sample, target) where target is class_index of the target class. 136 | """ 137 | path, target = self.samples[index] 138 | sample = self.loader(path) 139 | if self.transform is not None: 140 | sample = self.transform(sample) 141 | if self.target_transform is not None: 142 | target = self.target_transform(target) 143 | 144 | return sample, target 145 | 146 | def __len__(self): 147 | return len(self.samples) 148 | 149 | 150 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 151 | 152 | 153 | def pil_loader(path): 154 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 155 | with open(path, 'rb') as f: 156 | img = Image.open(f) 157 | return img.convert('RGB') 158 | 159 | 160 | def accimage_loader(path): 161 | import accimage 162 | try: 163 | return accimage.Image(path) 164 | except IOError: 165 | # Potentially a decoding problem, fall back to PIL.Image 166 | return pil_loader(path) 167 | 168 | 169 | def default_loader(path): 170 | from torchvision import get_image_backend 171 | if get_image_backend() == 'accimage': 172 | return accimage_loader(path) 173 | else: 174 | return pil_loader(path) 175 | 176 | 177 | class ImageFolder(DatasetFolder): 178 | """A generic data loader where the images are arranged in this way: :: 179 | 180 | root/dog/xxx.png 181 | root/dog/xxy.png 182 | root/dog/xxz.png 183 | 184 | root/cat/123.png 185 | root/cat/nsdf3.png 186 | root/cat/asd932_.png 187 | 188 | Args: 189 | root (string): Root directory path. 190 | transform (callable, optional): A function/transform that takes in an PIL image 191 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 192 | target_transform (callable, optional): A function/transform that takes in the 193 | target and transforms it. 194 | loader (callable, optional): A function to load an image given its path. 195 | is_valid_file (callable, optional): A function that takes path of an Image file 196 | and check if the file is a valid_file (used to check of corrupt files) 197 | 198 | Attributes: 199 | classes (list): List of the class names. 200 | class_to_idx (dict): Dict with items (class_name, class_index). 201 | imgs (list): List of (image path, class_index) tuples 202 | """ 203 | 204 | def __init__(self, root, transform=None, target_transform=None, 205 | loader=default_loader, is_valid_file=None): 206 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 207 | transform=transform, 208 | target_transform=target_transform, 209 | is_valid_file=is_valid_file) 210 | self.imgs = self.samples 211 | -------------------------------------------------------------------------------- /ZenNet/masternet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | ''' 4 | 5 | 6 | import os, sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | import numpy as np 10 | import torch, argparse 11 | from torch import nn 12 | import torch.nn.functional as F 13 | import PlainNet 14 | from PlainNet import parse_cmd_options, basic_blocks, super_blocks 15 | 16 | 17 | def parse_cmd_options(argv): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--no_BN', action='store_true') 20 | parser.add_argument('--no_reslink', action='store_true') 21 | parser.add_argument('--use_se', action='store_true') 22 | parser.add_argument('--dropout', type=float, default=None) 23 | module_opt, _ = parser.parse_known_args(argv) 24 | return module_opt 25 | 26 | 27 | class PlainNet(PlainNet.PlainNet): 28 | def __init__(self, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, 29 | no_reslink=None, no_BN=None, use_se=None, dropout=None, 30 | **kwargs): 31 | 32 | if argv is not None: 33 | module_opt = parse_cmd_options(argv) 34 | else: 35 | module_opt = None 36 | 37 | if no_BN is None: 38 | if module_opt is not None: 39 | no_BN = module_opt.no_BN 40 | else: 41 | no_BN = False 42 | 43 | if no_reslink is None: 44 | if module_opt is not None: 45 | no_reslink = module_opt.no_reslink 46 | else: 47 | no_reslink = False 48 | 49 | if use_se is None: 50 | if module_opt is not None: 51 | use_se = module_opt.use_se 52 | else: 53 | use_se = False 54 | 55 | if dropout is None: 56 | if module_opt is not None: 57 | self.dropout = module_opt.dropout 58 | else: 59 | self.dropout = None 60 | else: 61 | self.dropout = dropout 62 | 63 | if self.dropout is not None: 64 | print('--- using dropout={:4g}'.format(self.dropout)) 65 | 66 | super(PlainNet, self).__init__(argv=argv, opt=opt, num_classes=num_classes, plainnet_struct=plainnet_struct, 67 | no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, use_se=use_se, **kwargs) 68 | self.last_channels = self.block_list[-1].out_channels 69 | self.fc_linear = basic_blocks.Linear(in_channels=self.last_channels, out_channels=self.num_classes, no_create=no_create) 70 | 71 | self.no_create = no_create 72 | self.no_reslink = no_reslink 73 | self.no_BN = no_BN 74 | self.use_se = use_se 75 | 76 | # bn eps 77 | for layer in self.modules(): 78 | if isinstance(layer, nn.BatchNorm2d): 79 | layer.eps = 1e-3 80 | 81 | def extract_stage_features_and_logit(self, x, target_downsample_ratio=None): 82 | stage_features_list = [] 83 | image_size = x.shape[2] 84 | output = x 85 | 86 | for block_id, the_block in enumerate(self.block_list): 87 | output = the_block(output) 88 | if self.dropout is not None: 89 | dropout_p = float(block_id) / len(self.block_list) * self.dropout 90 | output = F.dropout(output, dropout_p, training=self.training, inplace=True) 91 | dowsample_ratio = round(image_size / output.shape[2]) 92 | if dowsample_ratio == target_downsample_ratio: 93 | stage_features_list.append(output) 94 | target_downsample_ratio *= 2 95 | pass 96 | pass 97 | 98 | output = F.adaptive_avg_pool2d(output, output_size=1) 99 | # if self.dropout is not None: 100 | # output = F.dropout(output, self.dropout, training=self.training, inplace=True) 101 | output = torch.flatten(output, 1) 102 | logit = self.fc_linear(output) 103 | return stage_features_list, logit 104 | 105 | def forward(self, x): 106 | output = x 107 | for block_id, the_block in enumerate(self.block_list): 108 | output = the_block(output) 109 | if self.dropout is not None: 110 | dropout_p = float(block_id) / len(self.block_list) * 
self.dropout 111 | output = F.dropout(output, dropout_p, training=self.training, inplace=True) 112 | 113 | output = F.adaptive_avg_pool2d(output, output_size=1) 114 | if self.dropout is not None: 115 | output = F.dropout(output, self.dropout, training=self.training, inplace=True) 116 | output = torch.flatten(output, 1) 117 | output = self.fc_linear(output) 118 | return output 119 | 120 | def forward_pre_GAP(self, x): 121 | output = x 122 | for the_block in self.block_list: 123 | output = the_block(output) 124 | return output 125 | 126 | def get_FLOPs(self, input_resolution): 127 | the_res = input_resolution 128 | the_flops = 0 129 | for the_block in self.block_list: 130 | the_flops += the_block.get_FLOPs(the_res) 131 | the_res = the_block.get_output_resolution(the_res) 132 | 133 | the_flops += self.fc_linear.get_FLOPs(the_res) 134 | 135 | return the_flops 136 | 137 | def get_model_size(self): 138 | the_size = 0 139 | for the_block in self.block_list: 140 | the_size += the_block.get_model_size() 141 | 142 | the_size += self.fc_linear.get_model_size() 143 | 144 | return the_size 145 | 146 | def get_num_layers(self): 147 | num_layers = 0 148 | for block in self.block_list: 149 | assert isinstance(block, super_blocks.PlainNetSuperBlockClass) 150 | num_layers += block.sub_layers 151 | return num_layers 152 | 153 | def replace_block(self, block_id, new_block): 154 | self.block_list[block_id] = new_block 155 | 156 | if block_id < len(self.block_list) - 1: 157 | if self.block_list[block_id + 1].in_channels != new_block.out_channels: 158 | self.block_list[block_id + 1].set_in_channels(new_block.out_channels) 159 | else: 160 | assert block_id == len(self.block_list) - 1 161 | self.last_channels = self.block_list[-1].out_channels 162 | if self.fc_linear.in_channels != self.last_channels: 163 | self.fc_linear.set_in_channels(self.last_channels) 164 | 165 | self.module_list = nn.ModuleList(self.block_list) 166 | 167 | def split(self, split_layer_threshold): 168 | new_str = '' 169 | for block in self.block_list: 170 | new_str += block.split(split_layer_threshold=split_layer_threshold) 171 | return new_str 172 | 173 | def init_parameters(self): 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | nn.init.xavier_normal_(m.weight.data, gain=3.26033) 177 | if hasattr(m, 'bias') and m.bias is not None: 178 | nn.init.zeros_(m.bias) 179 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 180 | nn.init.ones_(m.weight) 181 | nn.init.zeros_(m.bias) 182 | elif isinstance(m, nn.Linear): 183 | nn.init.normal_(m.weight, 0, 3.26033 * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1]))) 184 | if hasattr(m, 'bias') and m.bias is not None: 185 | nn.init.zeros_(m.bias) 186 | else: 187 | pass 188 | 189 | for superblock in self.block_list: 190 | if not isinstance(superblock, super_blocks.PlainNetSuperBlockClass): 191 | continue 192 | for block in superblock.block_list: 193 | if not isinstance(block, basic_blocks.ResBlock): 194 | continue 195 | # print('---debug set bn weight zero in resblock {}:{}'.format(superblock, block)) 196 | last_bn_block = None 197 | for inner_resblock in block.block_list: 198 | if isinstance(inner_resblock, basic_blocks.BN): 199 | last_bn_block = inner_resblock 200 | pass 201 | pass # end for 202 | assert last_bn_block is not None 203 | # print('-------- last_bn_block={}'.format(last_bn_block)) 204 | nn.init.zeros_(last_bn_block.netblock.weight) -------------------------------------------------------------------------------- /PlainNet/SuperResKXKX.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | 6 | import os, sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | import uuid 9 | 10 | import PlainNet 11 | from PlainNet import _get_right_parentheses_index_ 12 | from PlainNet.super_blocks import PlainNetSuperBlockClass 13 | from torch import nn 14 | import global_utils 15 | 16 | 17 | class SuperResKXKX(PlainNetSuperBlockClass): 18 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, kernel_size=None, 19 | no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): 20 | super(SuperResKXKX, self).__init__(**kwargs) 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.stride = stride 24 | self.bottleneck_channels = bottleneck_channels 25 | self.sub_layers = sub_layers 26 | self.kernel_size = kernel_size 27 | self.no_create = no_create 28 | self.no_reslink = no_reslink 29 | self.no_BN = no_BN 30 | self.use_se = use_se 31 | if self.use_se: 32 | print('---debug use_se in ' + str(self)) 33 | 34 | full_str = '' 35 | last_channels = in_channels 36 | current_stride = stride 37 | for i in range(self.sub_layers): 38 | inner_str = '' 39 | 40 | inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, self.kernel_size, current_stride) 41 | if not self.no_BN: 42 | inner_str += 'BN({})'.format(self.bottleneck_channels) 43 | inner_str += 'RELU({})'.format(self.bottleneck_channels) 44 | if self.use_se: 45 | inner_str += 'SE({})'.format(bottleneck_channels) 46 | 47 | inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, self.kernel_size, 1) 48 | if not self.no_BN: 49 | inner_str += 'BN({})'.format(self.out_channels) 50 | 51 | if not self.no_reslink: 52 | if i == 0: 53 | res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) 54 | else: 55 | res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) 56 | else: 57 | res_str = '{}RELU({})'.format(inner_str, out_channels) 58 | 59 | full_str += res_str 60 | 61 | last_channels = out_channels 62 | current_stride = 1 63 | pass 64 | 65 | self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) 66 | if not no_create: 67 | self.module_list = nn.ModuleList(self.block_list) 68 | else: 69 | self.module_list = None 70 | 71 | def forward_pre_relu(self, x): 72 | output = x 73 | for block in self.block_list[0:-1]: 74 | output = block(output) 75 | return output 76 | 77 | 78 | def __str__(self): 79 | return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, 80 | self.stride, self.bottleneck_channels, self.sub_layers) 81 | 82 | def __repr__(self): 83 | return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( 84 | self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size 85 | ) 86 | 87 | def encode_structure(self): 88 | return [self.out_channels, self.sub_layers, self.bottleneck_channels] 89 | 90 | def split(self, split_layer_threshold): 91 | if self.sub_layers >= split_layer_threshold: 92 | new_sublayers_1 = split_layer_threshold // 2 93 | new_sublayers_2 = self.sub_layers - new_sublayers_1 94 | new_block_str1 = type(self).__name__ + 
'({},{},{},{},{})'.format(self.in_channels, self.out_channels, 95 | self.stride, self.bottleneck_channels, new_sublayers_1) 96 | new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, 97 | 1, self.bottleneck_channels, 98 | new_sublayers_2) 99 | return new_block_str1 + new_block_str2 100 | else: 101 | return str(self) 102 | 103 | def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): 104 | if channel_scale is None: 105 | channel_scale = scale 106 | if sub_layer_scale is None: 107 | sub_layer_scale = scale 108 | 109 | new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) 110 | new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) 111 | new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) 112 | 113 | return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, 114 | self.stride, new_bottleneck_channels, new_sub_layers) 115 | 116 | 117 | @classmethod 118 | def create_from_str(cls, s, **kwargs): 119 | assert cls.is_instance_from_str(s) 120 | idx = _get_right_parentheses_index_(s) 121 | assert idx is not None 122 | param_str = s[len(cls.__name__ + '('):idx] 123 | 124 | # find block_name 125 | tmp_idx = param_str.find('|') 126 | if tmp_idx < 0: 127 | tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) 128 | else: 129 | tmp_block_name = param_str[0:tmp_idx] 130 | param_str = param_str[tmp_idx + 1:] 131 | 132 | param_str_split = param_str.split(',') 133 | in_channels = int(param_str_split[0]) 134 | out_channels = int(param_str_split[1]) 135 | stride = int(param_str_split[2]) 136 | bottleneck_channels = int(param_str_split[3]) 137 | sub_layers = int(param_str_split[4]) 138 | return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, 139 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 140 | block_name=tmp_block_name, **kwargs), s[idx + 1:] 141 | 142 | 143 | class SuperResK3K3(SuperResKXKX): 144 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 145 | super(SuperResK3K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 146 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 147 | kernel_size=3, 148 | no_create=no_create, **kwargs) 149 | 150 | class SuperResK5K5(SuperResKXKX): 151 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 152 | super(SuperResK5K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 153 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 154 | kernel_size=5, 155 | no_create=no_create, **kwargs) 156 | 157 | class SuperResK7K7(SuperResKXKX): 158 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 159 | super(SuperResK7K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 160 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 161 | kernel_size=7, 162 | no_create=no_create, **kwargs) 163 | 164 | 165 | 166 | def register_netblocks_dict(netblocks_dict: dict): 167 | this_py_file_netblocks_dict = { 168 | 'SuperResK3K3': SuperResK3K3, 169 | 'SuperResK5K5': SuperResK5K5, 170 | 'SuperResK7K7': SuperResK7K7, 171 | 172 | } 173 | 
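    # Merging into the caller's registry lets PlainNet's structure-string parser
    # (the global _all_netblocks_dict_ in PlainNet/__init__.py) resolve names
    # like 'SuperResK5K5(...)' back to these classes when a network is rebuilt
    # from its string form.
    #
    # Minimal usage sketch (names as defined in this file; no_create=True avoids
    # allocating real torch modules):
    #
    #   block, rest = SuperResK3K3.create_from_str('SuperResK3K3(24,32,2,64,1)',
    #                                              no_create=True)
    #   assert rest == ''
    #   print(repr(block))  # SuperResK3K3(uuid...|in=24,out=32,stride=2,...)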
netblocks_dict.update(this_py_file_netblocks_dict) 174 | return netblocks_dict 175 | -------------------------------------------------------------------------------- /PlainNet/SuperResK1KXK1.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | 6 | import os, sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | import uuid 9 | 10 | import PlainNet 11 | from PlainNet import _get_right_parentheses_index_ 12 | from PlainNet.super_blocks import PlainNetSuperBlockClass 13 | from torch import nn 14 | import global_utils 15 | 16 | 17 | class SuperResK1KXK1(PlainNetSuperBlockClass): 18 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None ,sub_layers=None, kernel_size=None, 19 | no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): 20 | super(SuperResK1KXK1, self).__init__(**kwargs) 21 | self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | self.stride = stride 24 | self.bottleneck_channels = bottleneck_channels 25 | self.sub_layers = sub_layers 26 | self.kernel_size = kernel_size 27 | self.no_create = no_create 28 | self.no_reslink = no_reslink 29 | self.no_BN = no_BN 30 | self.use_se = use_se 31 | if self.use_se: 32 | print('---debug use_se in ' + str(self)) 33 | 34 | full_str = '' 35 | last_channels = in_channels 36 | current_stride = stride 37 | for i in range(self.sub_layers): 38 | inner_str = '' 39 | 40 | # first bl-block with reslink 41 | inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, 1, 1) 42 | if not self.no_BN: 43 | inner_str += 'BN({})'.format(self.bottleneck_channels) 44 | inner_str += 'RELU({})'.format(self.bottleneck_channels) 45 | 46 | inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.bottleneck_channels, 47 | self.kernel_size, current_stride) 48 | if not self.no_BN: 49 | inner_str += 'BN({})'.format(self.bottleneck_channels) 50 | inner_str += 'RELU({})'.format(self.bottleneck_channels) 51 | if self.use_se: 52 | inner_str += 'SE({})'.format(bottleneck_channels) 53 | 54 | inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) 55 | if not self.no_BN: 56 | inner_str += 'BN({})'.format(self.out_channels) 57 | 58 | if not self.no_reslink: 59 | if i == 0: 60 | res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) 61 | else: 62 | res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) 63 | else: 64 | res_str = '{}RELU({})'.format(inner_str, out_channels) 65 | 66 | full_str += res_str 67 | 68 | # second bl-block with reslink 69 | inner_str = '' 70 | inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, self.bottleneck_channels, 1, 1) 71 | if not self.no_BN: 72 | inner_str += 'BN({})'.format(self.bottleneck_channels) 73 | inner_str += 'RELU({})'.format(self.bottleneck_channels) 74 | 75 | inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.bottleneck_channels, 76 | self.kernel_size, 1) 77 | if not self.no_BN: 78 | inner_str += 'BN({})'.format(self.bottleneck_channels) 79 | inner_str += 'RELU({})'.format(self.bottleneck_channels) 80 | if self.use_se: 81 | inner_str += 'SE({})'.format(bottleneck_channels) 82 | 83 | inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) 84 | if not self.no_BN: 85 | inner_str += 'BN({})'.format(self.out_channels) 86 | 87 | if not self.no_reslink: 88 | 
res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) 89 | else: 90 | res_str = '{}RELU({})'.format(inner_str, out_channels) 91 | 92 | full_str += res_str 93 | 94 | last_channels = out_channels 95 | current_stride = 1 96 | pass 97 | 98 | self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) 99 | if not no_create: 100 | self.module_list = nn.ModuleList(self.block_list) 101 | else: 102 | self.module_list = None 103 | 104 | def __str__(self): 105 | return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, 106 | self.stride, self.bottleneck_channels, self.sub_layers) 107 | 108 | def __repr__(self): 109 | return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( 110 | self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size 111 | ) 112 | 113 | def encode_structure(self): 114 | return [self.out_channels, self.sub_layers, self.bottleneck_channels] 115 | 116 | def split(self, split_layer_threshold): 117 | if self.sub_layers >= split_layer_threshold: 118 | new_sublayers_1 = split_layer_threshold // 2 119 | new_sublayers_2 = self.sub_layers - new_sublayers_1 120 | new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, 121 | self.stride, self.bottleneck_channels, new_sublayers_1) 122 | new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, 123 | 1, self.bottleneck_channels, 124 | new_sublayers_2) 125 | return new_block_str1 + new_block_str2 126 | else: 127 | return str(self) 128 | 129 | def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): 130 | if channel_scale is None: 131 | channel_scale = scale 132 | if sub_layer_scale is None: 133 | sub_layer_scale = scale 134 | 135 | new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) 136 | new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) 137 | new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) 138 | 139 | return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, 140 | self.stride, new_bottleneck_channels, new_sub_layers) 141 | 142 | 143 | @classmethod 144 | def create_from_str(cls, s, **kwargs): 145 | assert cls.is_instance_from_str(s) 146 | idx = _get_right_parentheses_index_(s) 147 | assert idx is not None 148 | param_str = s[len(cls.__name__ + '('):idx] 149 | 150 | # find block_name 151 | tmp_idx = param_str.find('|') 152 | if tmp_idx < 0: 153 | tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) 154 | else: 155 | tmp_block_name = param_str[0:tmp_idx] 156 | param_str = param_str[tmp_idx + 1:] 157 | 158 | param_str_split = param_str.split(',') 159 | in_channels = int(param_str_split[0]) 160 | out_channels = int(param_str_split[1]) 161 | stride = int(param_str_split[2]) 162 | bottleneck_channels = int(param_str_split[3]) 163 | sub_layers = int(param_str_split[4]) 164 | return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, 165 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 166 | block_name=tmp_block_name, **kwargs),s[idx + 1:] 167 | 168 | 169 | class SuperResK1K3K1(SuperResK1KXK1): 170 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 171 | 
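        # A SuperResK1KXK1 stage specialised to a 3x3 middle convolution, i.e.
        # the classic ResNet-style bottleneck (1x1 reduce -> KxK -> 1x1 expand).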
super(SuperResK1K3K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 172 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 173 | kernel_size=3, 174 | no_create=no_create, **kwargs) 175 | 176 | class SuperResK1K5K1(SuperResK1KXK1): 177 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 178 | super(SuperResK1K5K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 179 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 180 | kernel_size=5, 181 | no_create=no_create, **kwargs) 182 | 183 | 184 | class SuperResK1K7K1(SuperResK1KXK1): 185 | def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): 186 | super(SuperResK1K7K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 187 | bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, 188 | kernel_size=7, 189 | no_create=no_create, **kwargs) 190 | 191 | 192 | def register_netblocks_dict(netblocks_dict: dict): 193 | this_py_file_netblocks_dict = { 194 | 'SuperResK1K3K1': SuperResK1K3K1, 195 | 'SuperResK1K5K1': SuperResK1K5K1, 196 | 'SuperResK1K7K1': SuperResK1K7K1, 197 | } 198 | netblocks_dict.update(this_py_file_netblocks_dict) 199 | return netblocks_dict -------------------------------------------------------------------------------- /ZenNet/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | 6 | import os, sys 7 | this_script_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | from . 
import masternet 10 | import global_utils 11 | import torch 12 | import urllib.request 13 | 14 | pretrain_model_pth_dir = os.path.expanduser('~/.cache/pytorch/checkpoints/zennet_pretrained') 15 | 16 | zennet_model_zoo = { 17 | 'zennet_cifar10_model_size05M_res32': { 18 | 'plainnet_str_txt': 'zennet_cifar_model_size05M_res32.txt', 19 | 'pth_path': 'zennet_cifar10_model_size05M_res32/best-params_rank0.pth', 20 | 'num_classes': 10, 21 | 'use_SE': False, 22 | 'resolution': 32, 23 | 'crop_image_size': 32, 24 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar10_model_size05M_res32/best-params_rank0.pth', 25 | }, 26 | 27 | 'zennet_cifar10_model_size1M_res32': { 28 | 'plainnet_str_txt': 'zennet_cifar_model_size1M_res32.txt', 29 | 'pth_path': 'zennet_cifar10_model_size1M_res32/best-params_rank0.pth', 30 | 'num_classes': 10, 31 | 'use_SE': False, 32 | 'resolution': 32, 33 | 'crop_image_size': 32, 34 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar10_model_size1M_res32/best-params_rank0.pth', 35 | }, 36 | 37 | 'zennet_cifar10_model_size2M_res32': { 38 | 'plainnet_str_txt': 'zennet_cifar_model_size2M_res32.txt', 39 | 'pth_path': 'zennet_cifar10_model_size2M_res32/best-params_rank0.pth', 40 | 'num_classes': 10, 41 | 'use_SE': False, 42 | 'resolution': 32, 43 | 'crop_image_size': 32, 44 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar10_model_size2M_res32/best-params_rank0.pth', 45 | }, 46 | 47 | 'zennet_cifar100_model_size05M_res32': { 48 | 'plainnet_str_txt': 'zennet_cifar_model_size05M_res32.txt', 49 | 'pth_path': 'zennet_cifar100_model_size05M_res32/best-params_rank0.pth', 50 | 'num_classes': 100, 51 | 'use_SE': False, 52 | 'resolution': 32, 53 | 'crop_image_size': 32, 54 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar100_model_size05M_res32/best-params_rank0.pth', 55 | }, 56 | 57 | 'zennet_cifar100_model_size1M_res32': { 58 | 'plainnet_str_txt': 'zennet_cifar_model_size1M_res32.txt', 59 | 'pth_path': 'zennet_cifar100_model_size1M_res32/best-params_rank0.pth', 60 | 'num_classes': 100, 61 | 'use_SE': False, 62 | 'resolution': 32, 63 | 'crop_image_size': 32, 64 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar100_model_size1M_res32/best-params_rank0.pth', 65 | }, 66 | 67 | 'zennet_cifar100_model_size2M_res32': { 68 | 'plainnet_str_txt': 'zennet_cifar_model_size2M_res32.txt', 69 | 'pth_path': 'zennet_cifar100_model_size2M_res32/best-params_rank0.pth', 70 | 'num_classes': 100, 71 | 'use_SE': False, 72 | 'resolution': 32, 73 | 'crop_image_size': 32, 74 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_cifar100_model_size2M_res32/best-params_rank0.pth', 75 | }, 76 | 77 | 'zennet_imagenet1k_flops400M_SE_res224': { 78 | 'plainnet_str_txt': 'zennet_imagenet1k_flops400M_res224.txt', 79 | 'pth_path': 'iccv2021_zennet_imagenet1k_flops400M_SE_res224/student_best-params_rank0.pth', 80 | 'num_classes': 1000, 81 | 'use_SE': True, 82 | 'resolution': 224, 83 | 'crop_image_size': 320, 84 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_flops400M_SE_res224/student_best-params_rank0.pth', 85 | }, 86 | 87 | 'zennet_imagenet1k_flops600M_SE_res224': { 88 | 'plainnet_str_txt': 
'zennet_imagenet1k_flops600M_res224.txt', 89 | 'pth_path': 'iccv2021_zennet_imagenet1k_flops600M_SE_res224/student_best-params_rank0.pth', 90 | 'num_classes': 1000, 91 | 'use_SE': True, 92 | 'resolution': 224, 93 | 'crop_image_size': 320, 94 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_flops600M_SE_res224/student_best-params_rank0.pth', 95 | }, 96 | 97 | 'zennet_imagenet1k_flops900M_SE_res224': { 98 | 'plainnet_str_txt': 'zennet_imagenet1k_flops900M_res224.txt', 99 | 'pth_path': 'zennet_imagenet1k_flops900M_SE_res224/student_best-params_rank0.pth', 100 | 'num_classes': 1000, 101 | 'use_SE': True, 102 | 'resolution': 224, 103 | 'crop_image_size': 320, 104 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/zennet_imagenet1k_flops900M_SE_res224/student_best-params_rank0.pth', 105 | }, 106 | 107 | 'zennet_imagenet1k_latency01ms_res224': { 108 | 'plainnet_str_txt': 'zennet_imagenet1k_latency01ms_res224.txt', 109 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency01ms_res224/student_best-params_rank0.pth', 110 | 'num_classes': 1000, 111 | 'use_SE': False, 112 | 'resolution': 224, 113 | 'crop_image_size': 320, 114 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency01ms_res224/student_best-params_rank0.pth', 115 | }, 116 | 117 | 'zennet_imagenet1k_latency02ms_res224': { 118 | 'plainnet_str_txt': 'zennet_imagenet1k_latency02ms_res224.txt', 119 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency02ms_res224/student_best-params_rank0.pth', 120 | 'num_classes': 1000, 121 | 'use_SE': False, 122 | 'resolution': 224, 123 | 'crop_image_size': 320, 124 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency02ms_res224/student_best-params_rank0.pth', 125 | }, 126 | 127 | 'zennet_imagenet1k_latency03ms_res224': { 128 | 'plainnet_str_txt': 'zennet_imagenet1k_latency03ms_res224.txt', 129 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency03ms_res224/student_best-params_rank0.pth', 130 | 'num_classes': 1000, 131 | 'use_SE': False, 132 | 'resolution': 224, 133 | 'crop_image_size': 320, 134 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency03ms_res224/student_best-params_rank0.pth', 135 | }, 136 | 137 | 'zennet_imagenet1k_latency05ms_res224': { 138 | 'plainnet_str_txt': 'zennet_imagenet1k_latency05ms_res224.txt', 139 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency05ms_res224/student_best-params_rank0.pth', 140 | 'num_classes': 1000, 141 | 'use_SE': False, 142 | 'resolution': 224, 143 | 'crop_image_size': 320, 144 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency05ms_res224/student_best-params_rank0.pth', 145 | }, 146 | 147 | 'zennet_imagenet1k_latency08ms_res224': { 148 | 'plainnet_str_txt': 'zennet_imagenet1k_latency08ms_res224.txt', 149 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency08ms_res224/student_best-params_rank0.pth', 150 | 'num_classes': 1000, 151 | 'use_SE': False, 152 | 'resolution': 224, 153 | 'crop_image_size': 320, 154 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency08ms_res224/student_best-params_rank0.pth', 155 | }, 156 | 'zennet_imagenet1k_latency12ms_res224': { 157 | 
'plainnet_str_txt': 'zennet_imagenet1k_latency12ms_res224.txt', 158 | 'pth_path': 'iccv2021_zennet_imagenet1k_latency12ms_res224/student_best-params_rank0.pth', 159 | 'num_classes': 1000, 160 | 'use_SE': False, 161 | 'resolution': 224, 162 | 'crop_image_size': 380, 163 | 'pretrained_pth_url': 'https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/ZenNet/pretrained_models/iccv2021_zennet_imagenet1k_latency12ms_res224/student_best-params_rank0.pth', 164 | }, 165 | } 166 | 167 | def get_ZenNet(model_name, pretrained=False): 168 | if model_name not in zennet_model_zoo: 169 | print('Error! Cannot find ZenNet model name! Please choose one in the following list:') 170 | 171 | for key in zennet_model_zoo: 172 | print(key) 173 | raise ValueError('ZenNet Model Name not found: ' + model_name) 174 | 175 | model_plainnet_str_txt = os.path.join(this_script_dir, zennet_model_zoo[model_name]['plainnet_str_txt']) 176 | model_pth_path = os.path.join(pretrain_model_pth_dir, zennet_model_zoo[model_name]['pth_path']) 177 | pretrained_pth_url = zennet_model_zoo[model_name]['pretrained_pth_url'] 178 | use_SE = zennet_model_zoo[model_name]['use_SE'] 179 | num_classes = zennet_model_zoo[model_name]['num_classes'] 180 | 181 | with open(model_plainnet_str_txt, 'r') as fid: 182 | model_plainnet_str = fid.readline().strip() 183 | 184 | model = masternet.PlainNet(num_classes=num_classes, plainnet_struct=model_plainnet_str, use_se=use_SE) 185 | 186 | if pretrained: 187 | 188 | if not os.path.isfile(model_pth_path): 189 | print('downloading pretrained parameters from ' + pretrained_pth_url) 190 | global_utils.mkfilepath(model_pth_path) 191 | urllib.request.urlretrieve(url=pretrained_pth_url, filename=model_pth_path) 192 | print('pretrained model parameters cached at ' + model_pth_path) 193 | 194 | print('loading pretrained parameters...') 195 | checkpoint = torch.load(model_pth_path, map_location='cpu') 196 | if 'state_dict' in checkpoint: 197 | state_dict = checkpoint['state_dict'] 198 | else: 199 | state_dict = checkpoint 200 | 201 | model.load_state_dict(state_dict, strict=True) 202 | 203 | return model -------------------------------------------------------------------------------- /PlainNet/super_blocks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
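Super-block classes. Each super block encodes a whole stage as one compact token of
the structure string and expands it into basic blocks at construction time; for
example, SuperConvK3BNRELU(3,32,1,1) expands to ConvKX(3,32,3,1)BN(32)RELU(32).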
3 | ''' 4 | 5 | 6 | import os, sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import uuid 14 | import global_utils 15 | 16 | import PlainNet 17 | from PlainNet import _get_right_parentheses_index_, basic_blocks 18 | 19 | class PlainNetSuperBlockClass(basic_blocks.PlainNetBasicBlockClass): 20 | def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): 21 | super(PlainNetSuperBlockClass, self).__init__() 22 | self.in_channels = in_channels 23 | self.out_channels = out_channels 24 | self.stride = stride 25 | self.sub_layers = sub_layers 26 | self.no_create = no_create 27 | self.block_list = None 28 | self.module_list = None 29 | 30 | def forward(self, x): 31 | output = x 32 | for block in self.block_list: 33 | output = block(output) 34 | return output 35 | 36 | def __str__(self): 37 | return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, 38 | self.stride, self.sub_layers) 39 | 40 | def __repr__(self): 41 | return type(self).__name__ + '({}|{},{},{},{})'.format(self.block_name, self.in_channels, self.out_channels, 42 | self.stride, self.sub_layers) 43 | 44 | def get_output_resolution(self, input_resolution): 45 | resolution = input_resolution 46 | for block in self.block_list: 47 | resolution = block.get_output_resolution(resolution) 48 | return resolution 49 | 50 | def get_FLOPs(self, input_resolution): 51 | resolution = input_resolution 52 | flops = 0.0 53 | for block in self.block_list: 54 | flops += block.get_FLOPs(resolution) 55 | resolution = block.get_output_resolution(resolution) 56 | return flops 57 | 58 | 59 | def get_model_size(self): 60 | model_size = 0.0 61 | for block in self.block_list: 62 | model_size += block.get_model_size() 63 | return model_size 64 | 65 | def set_in_channels(self, c): 66 | self.in_channels = c 67 | if len(self.block_list) == 0: 68 | self.out_channels = c 69 | return 70 | 71 | self.block_list[0].set_in_channels(c) 72 | last_channels = self.block_list[0].out_channels 73 | if len(self.block_list) >= 2 and \ 74 | (isinstance(self.block_list[0], basic_blocks.ConvKX) or isinstance(self.block_list[0], basic_blocks.ConvDW)) and \ 75 | isinstance(self.block_list[1], basic_blocks.BN): 76 | self.block_list[1].set_in_channels(last_channels) 77 | 78 | def encode_structure(self): 79 | return [self.out_channels, self.sub_layers] 80 | 81 | @classmethod 82 | def create_from_str(cls, s, no_create=False, **kwargs): 83 | assert cls.is_instance_from_str(s) 84 | idx = _get_right_parentheses_index_(s) 85 | assert idx is not None 86 | param_str = s[len(cls.__name__ + '('):idx] 87 | 88 | # find block_name 89 | tmp_idx = param_str.find('|') 90 | if tmp_idx < 0: 91 | tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) 92 | else: 93 | tmp_block_name = param_str[0:tmp_idx] 94 | param_str = param_str[tmp_idx + 1:] 95 | 96 | param_str_split = param_str.split(',') 97 | in_channels = int(param_str_split[0]) 98 | out_channels = int(param_str_split[1]) 99 | stride = int(param_str_split[2]) 100 | sub_layers = int(param_str_split[3]) 101 | return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, 102 | sub_layers=sub_layers, block_name=tmp_block_name, no_create=no_create, 103 | **kwargs),\ 104 | s[idx + 1:] 105 | 106 | 107 | class SuperConvKXBNRELU(PlainNetSuperBlockClass): 108 | def __init__(self, in_channels=None, out_channels=None, 
stride=None, sub_layers=None, kernel_size=None, 109 | no_create=False, no_reslink=False, no_BN=False, **kwargs): 110 | super(SuperConvKXBNRELU, self).__init__(**kwargs) 111 | self.in_channels = in_channels 112 | self.out_channels = out_channels 113 | self.stride = stride 114 | self.sub_layers = sub_layers 115 | self.kernel_size = kernel_size 116 | self.no_create = no_create 117 | self.no_reslink = no_reslink 118 | self.no_BN = no_BN 119 | 120 | # if self.no_reslink: 121 | # print('Warning! {} use no_reslink'.format(str(self))) 122 | # if self.no_BN: 123 | # print('Warning! {} use no_BN'.format(str(self))) 124 | 125 | full_str = '' 126 | last_channels = in_channels 127 | current_stride = stride 128 | for i in range(self.sub_layers): 129 | if not self.no_BN: 130 | inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format(last_channels, self.out_channels, 131 | self.kernel_size, 132 | current_stride, 133 | self.out_channels, self.out_channels) 134 | else: 135 | inner_str = 'ConvKX({},{},{},{})RELU({})'.format(last_channels, self.out_channels, 136 | self.kernel_size, 137 | current_stride, 138 | self.out_channels) 139 | full_str += inner_str 140 | 141 | last_channels = out_channels 142 | current_stride = 1 143 | pass 144 | 145 | self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, 146 | no_reslink=no_reslink, no_BN=no_BN) 147 | if not no_create: 148 | self.module_list = nn.ModuleList(self.block_list) 149 | else: 150 | self.module_list = None 151 | 152 | def forward_pre_relu(self, x): 153 | output = x 154 | for block in self.block_list[0:-1]: 155 | output = block(output) 156 | return output 157 | 158 | def __str__(self): 159 | return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, 160 | self.stride, self.sub_layers) 161 | 162 | def __repr__(self): 163 | return type(self).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( 164 | self.block_name, self.in_channels, self.out_channels, self.stride, self.sub_layers, self.kernel_size) 165 | 166 | def split(self, split_layer_threshold): 167 | return str(self) 168 | 169 | def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): 170 | if channel_scale is None: 171 | channel_scale = scale 172 | if sub_layer_scale is None: 173 | sub_layer_scale = scale 174 | 175 | new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) 176 | new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) 177 | 178 | return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, new_out_channels, 179 | self.stride, new_sub_layers) 180 | 181 | 182 | 183 | class SuperConvK1BNRELU(SuperConvKXBNRELU): 184 | def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): 185 | super(SuperConvK1BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 186 | sub_layers=sub_layers, 187 | kernel_size=1, 188 | no_create=no_create, **kwargs) 189 | 190 | class SuperConvK3BNRELU(SuperConvKXBNRELU): 191 | def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): 192 | super(SuperConvK3BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, 193 | sub_layers=sub_layers, 194 | kernel_size=3, 195 | no_create=no_create, **kwargs) 196 | 197 | class SuperConvK5BNRELU(SuperConvKXBNRELU): 198 | def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, 
no_create=False, **kwargs):
199 |         super(SuperConvK5BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride,
200 |                                                 sub_layers=sub_layers,
201 |                                                 kernel_size=5,
202 |                                                 no_create=no_create, **kwargs)
203 | 
204 | 
205 | class SuperConvK7BNRELU(SuperConvKXBNRELU):
206 |     def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs):
207 |         super(SuperConvK7BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride,
208 |                                                 sub_layers=sub_layers,
209 |                                                 kernel_size=7,
210 |                                                 no_create=no_create, **kwargs)
211 | 
212 | 
213 | def register_netblocks_dict(netblocks_dict: dict):
214 |     this_py_file_netblocks_dict = {
215 |         'SuperConvK1BNRELU': SuperConvK1BNRELU,
216 |         'SuperConvK3BNRELU': SuperConvK3BNRELU,
217 |         'SuperConvK5BNRELU': SuperConvK5BNRELU,
218 |         'SuperConvK7BNRELU': SuperConvK7BNRELU,
219 | 
220 |     }
221 |     netblocks_dict.update(this_py_file_netblocks_dict)
222 |     return netblocks_dict
--------------------------------------------------------------------------------
/PlainNet/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 | 
5 | 
6 | import os, sys
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | 
9 | import torch, argparse
10 | from torch import nn
11 | 
12 | _all_netblocks_dict_ = {}
13 | 
14 | def parse_cmd_options(argv, opt=None):
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument('--plainnet_struct', type=str, default=None, help='PlainNet structure string')
17 |     parser.add_argument('--plainnet_struct_txt', type=str, default=None, help='PlainNet structure file name')
18 |     parser.add_argument('--num_classes', type=int, default=None, help='number of classes')
19 |     module_opt, _ = parser.parse_known_args(argv)
20 | 
21 |     return module_opt
22 | 
23 | def _get_right_parentheses_index_(s):
24 |     # assert s[0] == '('
25 |     left_paren_count = 0
26 |     for index, x in enumerate(s):
27 | 
28 |         if x == '(':
29 |             left_paren_count += 1
30 |         elif x == ')':
31 |             left_paren_count -= 1
32 |             if left_paren_count == 0:
33 |                 return index
34 |         else:
35 |             pass
36 |     return None
37 | 
38 | def pretty_format(plainnet_str, indent=2):
39 |     the_formated_str = ''
40 |     indent_str = ''
41 |     if indent >= 1:
42 |         indent_str = ''.join([' '] * indent)
43 | 
44 |     # print(indent_str, end='')
45 |     the_formated_str += indent_str
46 | 
47 |     s = plainnet_str
48 |     while len(s) > 0:
49 |         if s[0] == ';':
50 |             # print(';\n' + indent_str, end='')
51 |             the_formated_str += ';\n' + indent_str
52 |             s = s[1:]
53 | 
54 |         left_par_idx = s.find('(')
55 |         assert left_par_idx >= 0
56 |         right_par_idx = _get_right_parentheses_index_(s)
57 |         the_block_class_name = s[0:left_par_idx]
58 | 
59 |         if the_block_class_name in ['MultiSumBlock', 'MultiCatBlock','MultiGroupBlock']:
60 |             # print('\n' + indent_str + the_block_class_name + '(')
61 |             sub_str = s[left_par_idx + 1:right_par_idx]
62 | 
63 |             # find block_name
64 |             tmp_idx = sub_str.find('|')
65 |             if tmp_idx < 0:
66 |                 tmp_block_name = 'no_name'
67 |             else:
68 |                 tmp_block_name = sub_str[0:tmp_idx]
69 |                 sub_str = sub_str[tmp_idx+1:]
70 | 
71 |             if len(tmp_block_name) > 8:
72 |                 tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:]
73 | 
74 |             the_formated_str += '\n' + indent_str + the_block_class_name + '({}|\n'.format(tmp_block_name)
75 | 
76 |             the_formated_str += pretty_format(sub_str, indent + 1)
77 |             # print('\n' + indent_str + ')')
78 |             # 
print(indent_str, end='') 79 | the_formated_str += '\n' + indent_str + ')\n' + indent_str 80 | elif the_block_class_name in ['ResBlock']: 81 | # print('\n' + indent_str + the_block_class_name + '(') 82 | in_channels = None 83 | the_stride = None 84 | sub_str = s[left_par_idx + 1:right_par_idx] 85 | # find block_name 86 | tmp_idx = sub_str.find('|') 87 | if tmp_idx < 0: 88 | tmp_block_name = 'no_name' 89 | else: 90 | tmp_block_name = sub_str[0:tmp_idx] 91 | sub_str = sub_str[tmp_idx + 1:] 92 | 93 | first_comma_index = sub_str.find(',') 94 | if first_comma_index < 0 or not sub_str[0:first_comma_index].isdigit(): 95 | in_channels = None 96 | else: 97 | in_channels = int(sub_str[0:first_comma_index]) 98 | sub_str = sub_str[first_comma_index+1:] 99 | second_comma_index = sub_str.find(',') 100 | if second_comma_index < 0 or not sub_str[0:second_comma_index].isdigit(): 101 | the_stride = None 102 | else: 103 | the_stride = int(sub_str[0:second_comma_index]) 104 | sub_str = sub_str[second_comma_index + 1:] 105 | pass 106 | pass 107 | 108 | if len(tmp_block_name) > 8: 109 | tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] 110 | 111 | the_formated_str += '\n' + indent_str + the_block_class_name + '({}|'.format(tmp_block_name) 112 | if in_channels is not None: 113 | the_formated_str += '{},'.format(in_channels) 114 | else: 115 | the_formated_str += ',' 116 | 117 | if the_stride is not None: 118 | the_formated_str += '{},'.format(the_stride) 119 | else: 120 | the_formated_str += ',' 121 | 122 | the_formated_str += '\n' 123 | 124 | the_formated_str += pretty_format(sub_str, indent + 1) 125 | # print('\n' + indent_str + ')') 126 | # print(indent_str, end='') 127 | the_formated_str += '\n' + indent_str + ')\n' + indent_str 128 | else: 129 | # print(s[0:right_par_idx+1], end='') 130 | sub_str = s[left_par_idx + 1:right_par_idx] 131 | # find block_name 132 | tmp_idx = sub_str.find('|') 133 | if tmp_idx < 0: 134 | tmp_block_name = 'no_name' 135 | else: 136 | tmp_block_name = sub_str[0:tmp_idx] 137 | sub_str = sub_str[tmp_idx + 1:] 138 | 139 | if len(tmp_block_name) > 8: 140 | tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] 141 | 142 | the_formated_str += the_block_class_name + '({}|'.format(tmp_block_name) + sub_str + ')' 143 | 144 | s = s[right_par_idx+1:] 145 | pass # end while 146 | 147 | return the_formated_str 148 | 149 | def _create_netblock_list_from_str_(s, no_create=False, **kwargs): 150 | block_list = [] 151 | while len(s) > 0: 152 | is_found_block_class = False 153 | for the_block_class_name in _all_netblocks_dict_.keys(): 154 | tmp_idx = s.find('(') 155 | if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: 156 | is_found_block_class = True 157 | the_block_class = _all_netblocks_dict_[the_block_class_name] 158 | the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) 159 | if the_block is not None: 160 | block_list.append(the_block) 161 | s = remaining_s 162 | if len(s) > 0 and s[0] == ';': 163 | return block_list, s[1:] 164 | break 165 | pass # end if 166 | pass # end for 167 | assert is_found_block_class 168 | pass # end while 169 | return block_list, '' 170 | 171 | def create_netblock_list_from_str(s, no_create=False, **kwargs): 172 | the_list, remaining_s = _create_netblock_list_from_str_(s, no_create=no_create, **kwargs) 173 | assert len(remaining_s) == 0 174 | return the_list 175 | 176 | def add_SE_block(structure_str: str): 177 | new_str = '' 178 | RELU = 'RELU' 179 | offset = 4 180 | 181 | idx = structure_str.find(RELU) 182 | 
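    # Editorial note (added): the loop below rewrites every 'RELU(c)' in the
    # structure string as 'RELU(c)SE(c)', appending a squeeze-and-excitation
    # block with matching channel count c. Illustrative trace of this logic:
    #   add_SE_block('ConvKX(3,24,3,2)BN(24)RELU(24)')
    #       -> 'ConvKX(3,24,3,2)BN(24)RELU(24)SE(24)'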
while idx >= 0:
183 |         new_str += structure_str[0: idx]
184 |         structure_str = structure_str[idx:]
185 |         r_idx = _get_right_parentheses_index_(structure_str[offset:]) + offset
186 |         channels = structure_str[offset + 1:r_idx]
187 |         new_str += 'RELU({})SE({})'.format(channels, channels)
188 |         structure_str = structure_str[r_idx + 1:]
189 |         idx = structure_str.find(RELU)
190 |         pass
191 | 
192 |     new_str += structure_str
193 |     return new_str
194 | 
195 | class PlainNet(nn.Module):
196 |     def __init__(self, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False,
197 |                  **kwargs):
198 |         super(PlainNet, self).__init__()
199 |         self.argv = argv
200 |         self.opt = opt
201 |         self.num_classes = num_classes
202 |         self.plainnet_struct = plainnet_struct
203 | 
204 |         self.module_opt = parse_cmd_options(self.argv)
205 | 
206 |         if self.num_classes is None:
207 |             self.num_classes = self.module_opt.num_classes
208 | 
209 |         if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None:
210 |             self.plainnet_struct = self.module_opt.plainnet_struct
211 | 
212 |         if self.plainnet_struct is None:
213 |             # load structure from text file
214 |             if hasattr(opt, 'plainnet_struct_txt') and opt.plainnet_struct_txt is not None:
215 |                 plainnet_struct_txt = opt.plainnet_struct_txt
216 |             else:
217 |                 plainnet_struct_txt = self.module_opt.plainnet_struct_txt
218 | 
219 |             if plainnet_struct_txt is not None:
220 |                 with open(plainnet_struct_txt, 'r') as fid:
221 |                     the_line = fid.readlines()[0].strip()
222 |                     self.plainnet_struct = the_line
223 |                 pass
224 | 
225 |         if self.plainnet_struct is None:
226 |             return
227 | 
228 |         the_s = self.plainnet_struct  # type: str
229 | 
230 |         block_list, remaining_s = _create_netblock_list_from_str_(the_s, no_create=no_create, **kwargs)
231 |         assert len(remaining_s) == 0
232 | 
233 |         self.block_list = block_list
234 |         if not no_create:
235 |             self.module_list = nn.ModuleList(block_list)  # register
236 | 
237 |     def forward(self, x):
238 |         output = x
239 |         for the_block in self.block_list:
240 |             output = the_block(output)
241 |         return output
242 | 
243 |     def __str__(self):
244 |         s = ''
245 |         for the_block in self.block_list:
246 |             s += str(the_block)
247 |         return s
248 | 
249 |     def __repr__(self):
250 |         return str(self)
251 | 
252 |     def get_FLOPs(self, input_resolution):
253 |         the_res = input_resolution
254 |         the_flops = 0
255 |         for the_block in self.block_list:
256 |             the_flops += the_block.get_FLOPs(the_res)
257 |             the_res = the_block.get_output_resolution(the_res)
258 | 
259 |         return the_flops
260 | 
261 |     def get_model_size(self):
262 |         the_size = 0
263 |         for the_block in self.block_list:
264 |             the_size += the_block.get_model_size()
265 | 
266 |         return the_size
267 | 
268 |     def replace_block(self, block_id, new_block):
269 |         self.block_list[block_id] = new_block
270 |         if block_id + 1 < len(self.block_list):
271 |             self.block_list[block_id + 1].set_in_channels(new_block.out_channels)
272 | 
273 |         self.module_list = nn.ModuleList(self.block_list)
274 | 
275 | 
276 | 
277 | from PlainNet import basic_blocks
278 | _all_netblocks_dict_ = basic_blocks.register_netblocks_dict(_all_netblocks_dict_)
279 | 
280 | from PlainNet import super_blocks
281 | _all_netblocks_dict_ = super_blocks.register_netblocks_dict(_all_netblocks_dict_)
282 | from PlainNet import SuperResKXKX
283 | _all_netblocks_dict_ = SuperResKXKX.register_netblocks_dict(_all_netblocks_dict_)
284 | 
285 | from PlainNet import SuperResK1KXK1
286 | _all_netblocks_dict_ = SuperResK1KXK1.register_netblocks_dict(_all_netblocks_dict_)
287 | 
288 | from PlainNet import SuperResIDWEXKX
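# Editorial note (added): each imported block module registers its classes in
# _all_netblocks_dict_ via register_netblocks_dict(); _create_netblock_list_from_str_
# then resolves the class names found in a structure string through this dict.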
289 | _all_netblocks_dict_ = SuperResIDWEXKX.register_netblocks_dict(_all_netblocks_dict_)
--------------------------------------------------------------------------------
/ZeroShotProxy/compute_te_nas_score.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | 
4 | This file is modified from:
5 | https://github.com/VITA-Group/TENAS
6 | '''
7 | 
8 | import os, sys
9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10 | 
11 | import torch
12 | from torch import nn
13 | import global_utils, argparse, ModelLoader, time
14 | 
15 | class LinearRegionCount(object):
16 |     """Counts the number of linear regions induced by the ReLU activation patterns of sampled inputs."""
17 |     def __init__(self, n_samples, gpu=None):
18 |         self.ActPattern = {}
19 |         self.n_LR = -1
20 |         self.n_samples = n_samples
21 |         self.ptr = 0
22 |         self.activations = None
23 |         self.gpu = gpu
24 | 
25 | 
26 |     @torch.no_grad()
27 |     def update2D(self, activations):
28 |         n_batch = activations.size()[0]
29 |         n_neuron = activations.size()[1]
30 |         self.n_neuron = n_neuron
31 |         if self.activations is None:
32 |             self.activations = torch.zeros(self.n_samples, n_neuron)
33 |             if self.gpu is not None:
34 |                 self.activations = self.activations.cuda(self.gpu)
35 |         self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations)  # after ReLU
36 |         self.ptr += n_batch
37 | 
38 |     @torch.no_grad()
39 |     def calc_LR(self):
40 |         res = torch.matmul(self.activations.half(), (1-self.activations).T.half())
41 |         res += res.T
42 |         res = 1 - torch.sign(res)
43 |         res = res.sum(1)
44 |         res = 1. / res.float()
45 |         self.n_LR = res.sum().item()
46 |         del self.activations, res
47 |         self.activations = None
48 |         if self.gpu is not None:
49 |             torch.cuda.empty_cache()
50 | 
51 |     @torch.no_grad()
52 |     def update1D(self, activationList):
53 |         code_string = ''
54 |         for key, value in activationList.items():
55 |             n_neuron = value.size()[0]
56 |             for i in range(n_neuron):
57 |                 if value[i] > 0:
58 |                     code_string += '1'
59 |                 else:
60 |                     code_string += '0'
61 |         if code_string not in self.ActPattern:
62 |             self.ActPattern[code_string] = 1
63 | 
64 |     def getLinearReginCount(self):
65 |         if self.n_LR == -1:
66 |             self.calc_LR()
67 |         return self.n_LR
68 | 
69 | class Linear_Region_Collector:
70 |     def __init__(self, models=[], input_size=(64, 3, 32, 32), gpu=None,
71 |                  sample_batch=1, dataset=None, data_path=None, seed=0):
72 |         self.models = []
73 |         self.input_size = input_size  # BCHW
74 |         self.sample_batch = sample_batch
75 |         # self.input_numel = reduce(mul, self.input_size, 1)
76 |         self.interFeature = []
77 |         self.dataset = dataset
78 |         self.data_path = data_path
79 |         self.seed = seed
80 |         self.gpu = gpu
81 |         self.device = torch.device('cuda:{}'.format(self.gpu)) if self.gpu is not None else torch.device('cpu')
82 |         # print('Using device:{}'.format(self.device))
83 | 
84 |         self.reinit(models, input_size, sample_batch, seed)
85 | 
86 | 
87 |     def reinit(self, models=None, input_size=None, sample_batch=None, seed=None):
88 |         if models is not None:
89 |             assert isinstance(models, list)
90 |             del self.models
91 |             self.models = models
92 |             for model in self.models:
93 |                 self.register_hook(model)
94 |             self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch, gpu=self.gpu) for _ in range(len(models))]
95 |         if input_size is not None or sample_batch is not None:
96 |             if input_size is not None:
97 |                 self.input_size = input_size  # BCHW
98 |                 # self.input_numel = reduce(mul, 
self.input_size, 1) 99 | if sample_batch is not None: 100 | self.sample_batch = sample_batch 101 | # if self.data_path is not None: 102 | # self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1) 103 | # self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True) 104 | # self.loader = iter(self.train_loader) 105 | if seed is not None and seed != self.seed: 106 | self.seed = seed 107 | torch.manual_seed(seed) 108 | if self.gpu is not None: 109 | torch.cuda.manual_seed(seed) 110 | del self.interFeature 111 | self.interFeature = [] 112 | if self.gpu is not None: 113 | torch.cuda.empty_cache() 114 | 115 | def clear(self): 116 | self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))] 117 | del self.interFeature 118 | self.interFeature = [] 119 | if self.gpu is not None: 120 | torch.cuda.empty_cache() 121 | 122 | def register_hook(self, model): 123 | for m in model.modules(): 124 | if isinstance(m, nn.ReLU): 125 | m.register_forward_hook(hook=self.hook_in_forward) 126 | 127 | def hook_in_forward(self, module, input, output): 128 | if isinstance(input, tuple) and len(input[0].size()) == 4: 129 | self.interFeature.append(output.detach()) # for ReLU 130 | 131 | def forward_batch_sample(self): 132 | for _ in range(self.sample_batch): 133 | # try: 134 | # inputs, targets = self.loader.next() 135 | # except Exception: 136 | # del self.loader 137 | # self.loader = iter(self.train_loader) 138 | # inputs, targets = self.loader.next() 139 | inputs = torch.randn(self.input_size, device=self.device) 140 | 141 | for model, LRCount in zip(self.models, self.LRCounts): 142 | self.forward(model, LRCount, inputs) 143 | return [LRCount.getLinearReginCount() for LRCount in self.LRCounts] 144 | 145 | def forward(self, model, LRCount, input_data): 146 | self.interFeature = [] 147 | with torch.no_grad(): 148 | # model.forward(input_data.cuda()) 149 | model.forward(input_data) 150 | if len(self.interFeature) == 0: return 151 | feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1) 152 | LRCount.update2D(feature_data) 153 | 154 | 155 | def compute_RN_score(model: nn.Module, batch_size=None, image_size=None, num_batch=None, gpu=None): 156 | # # just debug 157 | # gpu = 0 158 | # import ModelLoader 159 | # model = ModelLoader._get_model_(arch='resnet18', num_classes=1000, pretrained=False, opt=None, argv=None) 160 | # 161 | # if gpu is not None: 162 | # model = model.cuda(gpu) 163 | lrc_model = Linear_Region_Collector(models=[model], input_size=(batch_size, 3, image_size, image_size), 164 | gpu=gpu, sample_batch=num_batch) 165 | num_linear_regions = float(lrc_model.forward_batch_sample()[0]) 166 | del lrc_model 167 | torch.cuda.empty_cache() 168 | return num_linear_regions 169 | 170 | 171 | 172 | import numpy as np 173 | import torch 174 | 175 | 176 | def recal_bn(network, xloader, recalbn, device): 177 | for m in network.modules(): 178 | if isinstance(m, torch.nn.BatchNorm2d): 179 | m.running_mean.data.fill_(0) 180 | m.running_var.data.fill_(0) 181 | m.num_batches_tracked.data.zero_() 182 | m.momentum = None 183 | network.train() 184 | with torch.no_grad(): 185 | for i, (inputs, targets) in enumerate(xloader): 186 | if i >= recalbn: break 187 | inputs = inputs.cuda(device=device, non_blocking=True) 188 | _, _ = network(inputs) 189 | return network 190 | 191 | 192 | def get_ntk_n(networks, recalbn=0, 
train_mode=False, num_batch=None,
193 |               batch_size=None, image_size=None, gpu=None):
194 |     if gpu is not None:
195 |         device = torch.device('cuda:{}'.format(gpu))
196 |     else:
197 |         device = torch.device('cpu')
198 | 
199 |     # if recalbn > 0:
200 |     #     network = recal_bn(network, xloader, recalbn, device)
201 |     #     if network_2 is not None:
202 |     #         network_2 = recal_bn(network_2, xloader, recalbn, device)
203 |     ntks = []
204 |     for network in networks:
205 |         if train_mode:
206 |             network.train()
207 |         else:
208 |             network.eval()
209 |     ######
210 |     grads = [[] for _ in range(len(networks))]
211 | 
212 |     # for i, (inputs, targets) in enumerate(xloader):
213 |     #     if num_batch > 0 and i >= num_batch: break
214 |     for i in range(num_batch):
215 |         inputs = torch.randn((batch_size, 3, image_size, image_size), device=device)
216 |         # inputs = inputs.cuda(device=device, non_blocking=True)
217 |         for net_idx, network in enumerate(networks):
218 |             network.zero_grad()
219 |             if gpu is not None:
220 |                 inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
221 |             else:
222 |                 inputs_ = inputs.clone()
223 | 
224 |             logit = network(inputs_)
225 |             if isinstance(logit, tuple):
226 |                 logit = logit[1]  # 201 networks: return features and logits
227 |             for _idx in range(len(inputs_)):
228 |                 logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
229 |                 grad = []
230 |                 for name, W in network.named_parameters():
231 |                     if 'weight' in name and W.grad is not None:
232 |                         grad.append(W.grad.view(-1).detach())
233 |                 grads[net_idx].append(torch.cat(grad, -1))
234 |                 network.zero_grad()
235 |                 if gpu is not None:
236 |                     torch.cuda.empty_cache()
237 | 
238 |     ######
239 |     grads = [torch.stack(_grads, 0) for _grads in grads]
240 |     ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
241 |     conds = []
242 |     for ntk in ntks:
243 |         eigenvalues = torch.linalg.eigvalsh(ntk)  # ascending; torch.symeig was removed from recent PyTorch
244 |         # conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0))
245 |         conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True))
246 |     return conds
247 | 
248 | 
249 | 
250 | def compute_NTK_score(gpu, model, resolution, batch_size):
251 |     ntk_score = get_ntk_n([model], recalbn=0, train_mode=True, num_batch=1,
252 |                           batch_size=batch_size, image_size=resolution, gpu=gpu)[0]
253 |     return -1 * ntk_score
254 | 
255 | 
256 | 
257 | def parse_cmd_options(argv):
258 |     parser = argparse.ArgumentParser()
259 |     parser.add_argument('--batch_size', type=int, default=16, help='number of instances in one mini-batch.')
260 |     parser.add_argument('--input_image_size', type=int, default=None,
261 |                         help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.')
262 |     parser.add_argument('--repeat_times', type=int, default=32)
263 |     parser.add_argument('--gpu', type=int, default=None)
264 |     module_opt, _ = parser.parse_known_args(argv)
265 |     return module_opt
266 | 
267 | if __name__ == "__main__":
268 |     opt = global_utils.parse_cmd_options(sys.argv)
269 |     args = parse_cmd_options(sys.argv)
270 |     the_model = ModelLoader.get_model(opt, sys.argv)
271 |     if args.gpu is not None:
272 |         the_model = the_model.cuda(args.gpu)
273 | 
274 | 
275 |     start_timer = time.time()
276 | 
277 |     for repeat_count in range(args.repeat_times):
278 |         ntk = compute_NTK_score(gpu=args.gpu, model=the_model,
279 |                                 resolution=args.input_image_size, batch_size=args.batch_size)
280 |         RN = compute_RN_score(model=the_model, batch_size=args.batch_size, image_size=args.input_image_size,
281 |                               num_batch=1, gpu=args.gpu)
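        # Editorial note (added): the TE-NAS proxy combines expressivity and
        # trainability: RN is the linear-region count and ntk is the negated
        # NTK condition number, so a larger combined score is better.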
282 |         the_score = RN + ntk
283 |     time_cost = (time.time() - start_timer) / args.repeat_times
284 | 
285 |     print(f'te_nas_score={the_score:.4g}, time cost={time_cost:.4g} second(s)')
--------------------------------------------------------------------------------
/DataLoader/autoaugment.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
4 | '''
5 | 
6 | from PIL import Image, ImageEnhance, ImageOps
7 | import numpy as np
8 | import random
9 | 
10 | fillcolor=(128, 128, 128)
11 | 
12 | class ImageNetPolicy(object):
13 |     """ Randomly choose one of the best 24 Sub-policies on ImageNet.
14 | 
15 |         Example:
16 |         >>> policy = ImageNetPolicy()
17 |         >>> transformed = policy(image)
18 | 
19 |         Example as a PyTorch Transform:
20 |         >>> transform=transforms.Compose([
21 |         >>>     transforms.Resize(256),
22 |         >>>     ImageNetPolicy(),
23 |         >>>     transforms.ToTensor()])
24 |     """
25 |     def __init__(self, fillcolor=(128, 128, 128)):
26 |         self.policies = [
27 |             SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
28 |             SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
29 |             SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
30 |             SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
31 |             SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
32 | 
33 |             SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
34 |             SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
35 |             SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
36 |             SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
37 |             SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
38 | 
39 |             SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
40 |             SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
41 |             SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
42 |             SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
43 |             SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
44 | 
45 |             SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
46 |             SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
47 |             SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
48 |             SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
49 |             SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
50 | 
51 |             SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
52 |             SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
53 |             SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
54 |             SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
55 |             SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
56 |         ]
57 | 
58 | 
59 |     def __call__(self, img):
60 |         policy_idx = random.randint(0, len(self.policies) - 1)
61 |         return self.policies[policy_idx](img)
62 | 
63 |     def __repr__(self):
64 |         return "AutoAugment ImageNet Policy"
65 | 
66 | 
67 | class CIFAR10Policy(object):
68 |     """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
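        (Editorial note, added: each sub-policy chains two PIL operations;
        each fires with its own probability, and the magnitude index picks a
        strength from the 10-step ranges defined in SubPolicy.__init__ at the
        bottom of this file.)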
69 | 70 | Example: 71 | >>> policy = CIFAR10Policy() 72 | >>> transformed = policy(image) 73 | 74 | Example as a PyTorch Transform: 75 | >>> transform=transforms.Compose([ 76 | >>> transforms.Resize(256), 77 | >>> CIFAR10Policy(), 78 | >>> transforms.ToTensor()]) 79 | """ 80 | def __init__(self, fillcolor=(128, 128, 128)): 81 | self.policies = [ 82 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 83 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 84 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 85 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 86 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 87 | 88 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 89 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 90 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 91 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 92 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 93 | 94 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 95 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 96 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 97 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 98 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 99 | 100 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 101 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 102 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 103 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 104 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 105 | 106 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 107 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 108 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 109 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 110 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 111 | ] 112 | 113 | 114 | def __call__(self, img): 115 | policy_idx = random.randint(0, len(self.policies) - 1) 116 | return self.policies[policy_idx](img) 117 | 118 | def __repr__(self): 119 | return "AutoAugment CIFAR10 Policy" 120 | 121 | 122 | class SVHNPolicy(object): 123 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
124 | 125 | Example: 126 | >>> policy = SVHNPolicy() 127 | >>> transformed = policy(image) 128 | 129 | Example as a PyTorch Transform: 130 | >>> transform=transforms.Compose([ 131 | >>> transforms.Resize(256), 132 | >>> SVHNPolicy(), 133 | >>> transforms.ToTensor()]) 134 | """ 135 | def __init__(self, fillcolor=(128, 128, 128)): 136 | self.policies = [ 137 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 139 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 141 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 144 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 145 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 146 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 147 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 148 | 149 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 150 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 151 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 152 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 153 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 154 | 155 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 156 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 157 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 158 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 159 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 160 | 161 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 162 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 163 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 164 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 165 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 166 | ] 167 | 168 | 169 | def __call__(self, img): 170 | policy_idx = random.randint(0, len(self.policies) - 1) 171 | return self.policies[policy_idx](img) 172 | 173 | def __repr__(self): 174 | return "AutoAugment SVHN Policy" 175 | 176 | 177 | 178 | def _shearX(img, magnitude): 179 | return img.transform( 180 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 181 | Image.BICUBIC, fillcolor=fillcolor) 182 | 183 | def _shearY(img, magnitude): 184 | return img.transform( 185 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 186 | Image.BICUBIC, fillcolor=fillcolor) 187 | 188 | def _translateX(img, magnitude): 189 | return img.transform( 190 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 191 | fillcolor=fillcolor) 192 | 193 | def _translateY(img, magnitude): 194 | return img.transform( 195 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 196 | fillcolor=fillcolor) 197 | 198 | def _rotate(img, magnitude): 199 | return rotate_with_fill(img, magnitude) 200 | 201 | def _color(img, magnitude): 202 | return ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])) 203 | 204 | def _posterize(img, magnitude): 205 | return ImageOps.posterize(img, magnitude) 206 | 207 | def _solarize(img, magnitude): 208 | return ImageOps.solarize(img, magnitude) 209 | 210 | def _contrast(img, magnitude): 211 | return ImageEnhance.Contrast(img).enhance( 212 | 1 
+ magnitude * random.choice([-1, 1]))
213 | 
214 | def _sharpness(img, magnitude):
215 |     return ImageEnhance.Sharpness(img).enhance(
216 |         1 + magnitude * random.choice([-1, 1]))
217 | 
218 | def _brightness(img, magnitude):
219 |     return ImageEnhance.Brightness(img).enhance(
220 |         1 + magnitude * random.choice([-1, 1]))
221 | 
222 | def _autocontrast(img, magnitude):
223 |     return ImageOps.autocontrast(img)
224 | 
225 | def _equalize(img, magnitude):
226 |     return ImageOps.equalize(img)
227 | 
228 | def _invert(img, magnitude):
229 |     return ImageOps.invert(img)
230 | 
231 | def rotate_with_fill(img, magnitude):
232 |     rot = img.convert("RGBA").rotate(magnitude)
233 |     return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
234 | 
235 | class SubPolicy(object):
236 |     def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
237 |         ranges = {
238 |             "shearX": np.linspace(0, 0.3, 10),
239 |             "shearY": np.linspace(0, 0.3, 10),
240 |             "translateX": np.linspace(0, 150 / 331, 10),
241 |             "translateY": np.linspace(0, 150 / 331, 10),
242 |             "rotate": np.linspace(0, 30, 10),
243 |             "color": np.linspace(0.0, 0.9, 10),
244 |             "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
245 |             "solarize": np.linspace(256, 0, 10),
246 |             "contrast": np.linspace(0.0, 0.9, 10),
247 |             "sharpness": np.linspace(0.0, 0.9, 10),
248 |             "brightness": np.linspace(0.0, 0.9, 10),
249 |             "autocontrast": [0] * 10,
250 |             "equalize": [0] * 10,
251 |             "invert": [0] * 10
252 |         }
253 | 
254 |         # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
255 | 
256 |         func = {
257 |             "shearX": _shearX,
258 |             "shearY": _shearY,
259 |             "translateX": _translateX,
260 |             "translateY": _translateY,
261 |             "rotate": _rotate,
262 |             "color": _color,
263 |             "posterize": _posterize,
264 |             "solarize": _solarize,
265 |             "contrast": _contrast,
266 |             "sharpness": _sharpness,
267 |             "brightness": _brightness,
268 |             "autocontrast": _autocontrast,
269 |             "equalize": _equalize,
270 |             "invert": _invert
271 |         }
272 | 
273 |         # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
274 |         #     operation1, ranges[operation1][magnitude_idx1],
275 |         #     operation2, ranges[operation2][magnitude_idx2])
276 |         self.p1 = p1
277 |         self.operation1 = func[operation1]
278 |         self.magnitude1 = ranges[operation1][magnitude_idx1]
279 |         self.p2 = p2
280 |         self.operation2 = func[operation2]
281 |         self.magnitude2 = ranges[operation2][magnitude_idx2]
282 | 
283 | 
284 |     def __call__(self, img):
285 |         if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
286 |         if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
287 |         return img
--------------------------------------------------------------------------------
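Editorial appendix (added; not part of the repository): a minimal usage sketch
tying together the APIs shown above. It assumes the repository root is on
sys.path, that get_ZenNet is importable from the ZenNet package (as the file
layout suggests), and that 'zennet_imagenet1k_latency12ms_res224' is a valid
key of zennet_model_zoo; treat the snippet as an illustration under those
assumptions, not as official usage.

import torch
from ZenNet import get_ZenNet  # assumed export; get_ZenNet is defined in the ZenNet package above

# Build the 1.2ms-latency ImageNet model from its structure string; pass
# pretrained=True to download the checkpoint from the URL in zennet_model_zoo.
model = get_ZenNet('zennet_imagenet1k_latency12ms_res224', pretrained=False)
model.eval()

# Analytical statistics, computed by walking block_list (see PlainNet above).
print('FLOPs at 224x224:', model.get_FLOPs(224))
print('parameter count :', model.get_model_size())

# Forward a dummy batch at the search resolution.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print('output shape    :', tuple(logits.shape))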