├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── configs ├── search │ ├── AttentiveNAS │ │ ├── eval.yaml │ │ └── train.yaml │ ├── BigNAS │ │ ├── eval.yaml │ │ ├── search.yaml │ │ └── train.yaml │ ├── DrNAS │ │ ├── drnas_darts_cifar10.yaml │ │ ├── drnas_darts_imagenet.yaml │ │ ├── drnas_nasbench201_cifar10.yaml │ │ ├── drnas_nasbench201_cifar100.yaml │ │ └── drnas_nasbench201_cifar10_progressive.yaml │ ├── OFA │ │ └── mbv3 │ │ │ ├── depth_1.yaml │ │ │ ├── depth_2.yaml │ │ │ ├── expand_1.yaml │ │ │ ├── expand_2.yaml │ │ │ ├── kernel_1.yaml │ │ │ └── normal_1.yaml │ ├── RMINAS │ │ ├── rminas_darts_cifar10.yaml │ │ ├── rminas_darts_cifar100.yaml │ │ ├── rminas_nasbench201_cifar10.yaml │ │ ├── rminas_nasbench201_cifar100.yaml │ │ ├── rminas_nasbench201_imagenet16.yaml │ │ ├── rminas_nasbenchmacro_cifar10.yaml │ │ └── rminas_proxyless_imagenet.yaml │ ├── SNG │ │ ├── ASNG_darts_cifar10.yaml │ │ ├── DDPNAS_darts_cifar10.yaml │ │ ├── DDPNAS_nasbench201_cifar10.yaml │ │ ├── DynamicASNG_darts_cifar10.yaml │ │ ├── DynamicSNG_darts_cifar10.yaml │ │ ├── MIGO_darts_cifar10.yaml │ │ └── SNG_darts_cifar10.yaml │ ├── darts_darts_cifar10.yaml │ ├── darts_nasbench201_cifar10.yaml │ ├── dropnas_darts_cifar10.yaml │ ├── gdas_nasbench201_cifar10.yaml │ ├── nasbenchmacro_nasbenchmacro_cifar10.yaml │ ├── pcdarts_pcdarts_cifar10.yaml │ ├── pdarts_pdarts_cifar10.yaml │ ├── snas_nasbench201_cifar10.yaml │ ├── spos_nasbench201_cifar10.yaml │ └── spos_spos_cifar10.yaml └── train │ ├── darts_darts_cifar10.yaml │ └── spos_spos_cifar10.yaml ├── docs ├── conf.py ├── data_preparation.md ├── get_started.md ├── index.rst ├── notes.md └── requirements.txt ├── examples ├── search │ ├── DrNAS │ │ ├── drnas_darts_cifar10.sh │ │ ├── drnas_darts_imagenet.sh │ │ ├── drnas_nasbench201_cifar10.sh │ │ ├── drnas_nasbench201_cifar100.sh │ │ └── drnas_nasbench201_cifar10_progressive.sh │ ├── OFA │ │ └── train_supernet.sh │ ├── RMINAS │ │ ├── rminas_darts_cifar10.sh │ │ ├── rminas_nasbench201_cifar10.sh │ │ ├── rminas_nasbench201_cifar100.sh │ │ └── rminas_nasbench201_imagenet16.sh │ ├── darts_darts_cifar10.sh │ ├── darts_nasbench201_cifar10.sh │ ├── dropnas_darts_cifar10.sh │ ├── gdas_nasbench201_cifar10.sh │ ├── ofa_mbv3_imagenet.sh │ ├── pcdarts_pcdarts_cifar10.sh │ ├── pdarts_pdarts_cifar10.sh │ └── snas_nasbench201_cifar10.sh └── train │ └── darts_darts_cifar10.sh ├── requirements.txt ├── scripts ├── search │ ├── AttentiveNAS │ │ └── train_supernet.py │ ├── BigNAS │ │ ├── eval.py │ │ ├── search.py │ │ └── train_supernet.py │ ├── DARTS.py │ ├── DrNAS.py │ ├── DropNAS.py │ ├── NBMacro.py │ ├── OFA │ │ ├── eval_supernet.py │ │ ├── search.py │ │ ├── train_supernet.py │ │ └── train_teacher_model.py │ ├── PCDARTS.py │ ├── PDARTS.py │ ├── RMINAS.py │ ├── SNG │ │ ├── func_optimize.py │ │ ├── nb1shot1.py │ │ └── search.py │ └── SPOS.py └── train │ └── DARTS.py ├── tests ├── MultiSizeRandomCrop.py ├── adjust_lr_per_iter.py ├── configs │ ├── datasets.yaml │ └── imagenet.yaml ├── dataloader.py ├── imagenet.py ├── nb101space.py ├── nb1shot1space.py ├── nb201eval.py ├── nb301eval.py ├── ofa_matrices_test.py └── test_REA.py └── xnas ├── algorithms ├── AttentiveNAS │ └── sampler.py ├── DARTS.py ├── DrNAS.py ├── DropNAS.py ├── PCDARTS.py ├── PDARTS.py ├── RMINAS │ ├── README.md │ ├── download_weight.sh │ ├── sampler │ │ ├── RF_sampling.py │ │ ├── available_archs.txt │ │ └── sampling.py │ ├── teacher_model │ │ ├── fbresnet_imagenet │ │ │ └── fbresnet.py │ │ ├── resnet101_cifar100 │ │ │ └── resnet.py │ │ └── resnet20_cifar10 │ │ │ └── 
resnet.py │ └── utils │ │ ├── RMI_torch.py │ │ ├── get_accuracy.ipynb │ │ └── random_data.py ├── SNG │ ├── ASNG.py │ ├── DDPNAS.py │ ├── GridSearch.py │ ├── MDENAS.py │ ├── MIGO.py │ ├── RAND.py │ ├── SNG.py │ └── categorical.py └── SPOS.py ├── core ├── builder.py ├── config.py └── utils.py ├── datasets ├── __init__.py ├── auto_augment_tf.py ├── imagenet.py ├── imagenet16.py ├── loader.py ├── transforms.py ├── transforms_imagenet.py └── utils │ ├── msrc_loader.py │ └── msrc_worker.py ├── evaluations ├── NASBench201.py ├── NASBench301.py ├── NASBenchMacro │ ├── evaluate.py │ └── nas-bench-macro_cifar10.json └── NASBenchmacro │ ├── evaluate.py │ └── nas-bench-macro_cifar10.json ├── logger ├── checkpoint.py ├── logging.py ├── meter.py └── timer.py ├── runner ├── __init__.py ├── criterion.py ├── optimizer.py ├── scheduler.py ├── trainer.py └── trainer_spos.py └── spaces ├── AttentiveNAS └── cnn.py ├── BigNAS ├── cnn.py ├── dynamic_layers.py ├── dynamic_ops.py ├── ops.py └── utils.py ├── DARTS ├── __init__.py ├── cnn.py ├── genos.py ├── ops.py └── utils.py ├── DrNAS ├── darts_cnn.py ├── nb201_cnn.py └── utils.py ├── DropNAS └── cnn.py ├── NASBench1Shot1 ├── cnn.py └── ops.py ├── NASBench201 ├── cnn.py ├── genos.py ├── ops.py └── utils.py ├── NASBenchMacro └── cnn.py ├── OFA ├── MobileNetV3 │ ├── cnn.py │ └── ofa_cnn.py ├── ProxylessNet │ ├── cnn.py │ └── ofa_cnn.py ├── ResNets │ ├── cnn.py │ └── ofa_cnn.py ├── dynamic_ops.py ├── ops.py └── utils.py ├── PCDARTS └── cnn.py ├── PDARTS └── cnn.py ├── ProxylessNAS └── cnn.py ├── SPOS └── cnn.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | data/ 3 | experiment/ 4 | exp/ 5 | __pycache__/ 6 | tmp/ 7 | .idea/ 8 | .vscode/ 9 | *.py[cod] 10 | *.pdf 11 | *$py.class 12 | *.DS_Store 13 | *hyper.md 14 | *debug.py 15 | benchmark/ 16 | test.py 17 | test.ipynb 18 | nohup*.out 19 | 20 | # model weights 21 | *.pth 22 | *.th 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | # lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | docs/build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | 125 | #exp_out 126 | *.pkl 127 | *.png 128 | 129 | 301model/ -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | python: 7 | version: 3.7 8 | install: 9 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 SherwoodZheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/search/AttentiveNAS/eval.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | RNG_SEED: 2 3 | SPACE: 4 | NAME: 'attentivenas' 5 | LOADER: 6 | DATASET: 'imagenet' 7 | NUM_CLASSES: 1000 8 | BATCH_SIZE: 256 9 | NUM_WORKERS: 4 10 | USE_VAL: True 11 | TRANSFORM: "auto_augment_tf" 12 | SEARCH: 13 | IM_SIZE: 224 14 | ATTENTIVENAS: 15 | BN_MOMENTUM: 0.1 16 | BN_EPS: 1.e-5 17 | POST_BN_CALIBRATION_BATCH_NUM: 64 18 | ACTIVE_SUBNET: # chosen from following settings 19 | # attentive_nas_a0 20 | RESOLUTION: 192 21 | WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] 22 | KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] 23 | EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] 24 | DEPTH: [1, 3, 3, 3, 3, 3, 1] 25 | 26 | # # attentive_nas_a1 27 | # RESOLUTION: 224 28 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984] 29 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3] 30 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] 31 | # DEPTH: [1, 3, 3, 3, 3, 3, 1] 32 | 33 | # # attentive_nas_a2 34 | # RESOLUTION: 224 35 | # WIDTH: [16, 16, 24, 32, 64, 112, 200, 224, 1984] 36 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3] 37 | # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6] 38 | # DEPTH: [1, 3, 3, 3, 3, 4, 1] 39 | 40 | # # attentive_nas_a3 41 | # RESOLUTION: 224 42 | # WIDTH: [16, 16, 24, 32, 64, 112, 208, 224, 1984] 43 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3] 44 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] 45 | # DEPTH: [2, 3, 3, 4, 3, 5, 1] 46 | 47 | # # attentive_nas_a4 48 | # RESOLUTION: 256 49 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984] 50 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3] 51 | # EXPAND_RATIO: [1, 4, 4, 5, 4, 6, 6] 52 | # DEPTH: [1, 3, 3, 4, 3, 5, 1] 53 | 54 | # # attentive_nas_a5 55 | # RESOLUTION: 256 56 | # WIDTH: [16, 16, 24, 32, 72, 112, 192, 216, 1792] 57 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3] 58 | # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6] 59 | # DEPTH: [1, 3, 3, 3, 4, 6, 1] 60 | 61 | # # attentive_nas_a6 62 | # RESOLUTION: 288 63 | # WIDTH: [16, 16, 24, 32, 64, 112, 216, 224, 1984] 64 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3] 65 | # EXPAND_RATIO: [1, 4, 6, 5, 4, 6, 6] 66 | # DEPTH: [1, 3, 3, 4, 4, 6, 1] 67 | SUPERNET_CFG: 68 | use_v3_head: True 69 | resolutions: [192, 224, 256, 288] 70 | first_conv: 71 | c: [16, 24] 72 | act_func: 'swish' 73 | s: 2 74 | mb1: 75 | c: [16, 24] 76 | d: [1, 2] 77 | k: [3, 5] 78 | t: [1] 79 | s: 1 80 | act_func: 'swish' 81 | se: False 82 | mb2: 83 | c: [24, 32] 84 | d: [3, 4, 5] 85 | k: [3, 5] 86 | t: [4, 5, 6] 87 | s: 2 88 | act_func: 'swish' 89 | se: False 90 | mb3: 91 | c: [32, 40] 92 | d: [3, 4, 5, 6] 93 | k: [3, 5] 94 | t: [4, 5, 6] 95 | s: 2 96 | act_func: 'swish' 97 | se: True 98 | mb4: 99 | c: [64, 72] 100 | d: [3, 4, 5, 6] 101 | k: [3, 5] 102 | t: [4, 5, 6] 103 | s: 2 104 | act_func: 'swish' 105 | se: False 106 | mb5: 107 | c: [112, 120, 128] 108 | d: [3, 4, 5, 6, 7, 8] 109 | k: [3, 5] 110 | t: [4, 5, 6] 111 | s: 1 112 | act_func: 'swish' 113 | se: True 114 | mb6: 115 | c: [192, 200, 208, 216] 116 | d: [3, 4, 5, 6, 7, 8] 117 | k: [3, 5] 118 | t: [6] 119 | s: 2 120 | act_func: 'swish' 121 | se: True 122 | mb7: 123 | c: [216, 224] 124 | d: [1, 2] 125 | k: [3, 5] 126 | t: [6] 127 | s: 1 128 | act_func: 'swish' 129 | se: True 130 | last_conv: 131 | c: [1792, 1984] 132 | act_func: 'swish' 133 | -------------------------------------------------------------------------------- /configs/search/AttentiveNAS/train.yaml: 
-------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | RNG_SEED: 0 3 | SPACE: 4 | NAME: 'attentivenas' 5 | LOADER: 6 | DATASET: 'imagenet' 7 | NUM_CLASSES: 1000 8 | BATCH_SIZE: 64 # 32*8 in total 9 | NUM_WORKERS: 4 10 | USE_VAL: True 11 | TRANSFORM: "auto_augment_tf" 12 | OPTIM: 13 | GRAD_CLIP: 1. 14 | WARMUP_EPOCH: 5 15 | MAX_EPOCH: 360 16 | WEIGHT_DECAY: 1.e-5 17 | BASE_LR: 0.2 18 | NESTEROV: True 19 | SEARCH: 20 | LOSS_FUN: "cross_entropy_smooth" 21 | LABEL_SMOOTH: 0.1 22 | TRAIN: 23 | DROP_PATH_PROB: 0.2 24 | ATTENTIVENAS: 25 | SANDWICH_NUM: 4 # max + 2*middle + min 26 | DROP_CONNECT: 0.2 27 | BN_MOMENTUM: 0. 28 | BN_EPS: 1.e-5 29 | POST_BN_CALIBRATION_BATCH_NUM: 64 30 | SAMPLER: 31 | METHOD: 'bestup' 32 | MAP_PATH: 'xnas/algorithms/AttentiveNAS/flops_archs_off_table.map' 33 | DISCRETIZE_STEP: 25 34 | NUM_TRIALS: 3 35 | SUPERNET_CFG: 36 | use_v3_head: True 37 | resolutions: [192, 224, 256, 288] 38 | first_conv: 39 | c: [16, 24] 40 | act_func: 'swish' 41 | s: 2 42 | mb1: 43 | c: [16, 24] 44 | d: [1, 2] 45 | k: [3, 5] 46 | t: [1] 47 | s: 1 48 | act_func: 'swish' 49 | se: False 50 | mb2: 51 | c: [24, 32] 52 | d: [3, 4, 5] 53 | k: [3, 5] 54 | t: [4, 5, 6] 55 | s: 2 56 | act_func: 'swish' 57 | se: False 58 | mb3: 59 | c: [32, 40] 60 | d: [3, 4, 5, 6] 61 | k: [3, 5] 62 | t: [4, 5, 6] 63 | s: 2 64 | act_func: 'swish' 65 | se: True 66 | mb4: 67 | c: [64, 72] 68 | d: [3, 4, 5, 6] 69 | k: [3, 5] 70 | t: [4, 5, 6] 71 | s: 2 72 | act_func: 'swish' 73 | se: False 74 | mb5: 75 | c: [112, 120, 128] 76 | d: [3, 4, 5, 6, 7, 8] 77 | k: [3, 5] 78 | t: [4, 5, 6] 79 | s: 1 80 | act_func: 'swish' 81 | se: True 82 | mb6: 83 | c: [192, 200, 208, 216] 84 | d: [3, 4, 5, 6, 7, 8] 85 | k: [3, 5] 86 | t: [6] 87 | s: 2 88 | act_func: 'swish' 89 | se: True 90 | mb7: 91 | c: [216, 224] 92 | d: [1, 2] 93 | k: [3, 5] 94 | t: [6] 95 | s: 1 96 | act_func: 'swish' 97 | se: True 98 | last_conv: 99 | c: [1792, 1984] 100 | act_func: 'swish' -------------------------------------------------------------------------------- /configs/search/BigNAS/eval.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | RNG_SEED: 2 3 | SPACE: 4 | NAME: 'infer_bignas' 5 | LOADER: 6 | DATASET: 'imagenet' 7 | NUM_CLASSES: 1000 8 | BATCH_SIZE: 128 9 | NUM_WORKERS: 4 10 | USE_VAL: True 11 | TRANSFORM: "auto_augment_tf" 12 | SEARCH: 13 | IM_SIZE: 224 14 | BIGNAS: 15 | BN_MOMENTUM: 0.1 16 | BN_EPS: 1.e-5 17 | POST_BN_CALIBRATION_BATCH_NUM: 64 18 | ACTIVE_SUBNET: # subnet for evaluation 19 | RESOLUTION: 192 20 | WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] 21 | KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] 22 | EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] 23 | DEPTH: [1, 3, 3, 3, 3, 3, 1] 24 | SUPERNET_CFG: 25 | use_v3_head: True 26 | resolutions: [192, 224, 256, 288] 27 | first_conv: 28 | c: [16, 24] 29 | act_func: 'swish' 30 | s: 2 31 | mb1: 32 | c: [16, 24] 33 | d: [1, 2] 34 | k: [3, 5] 35 | t: [1] 36 | s: 1 37 | act_func: 'swish' 38 | se: False 39 | mb2: 40 | c: [24, 32] 41 | d: [3, 4, 5] 42 | k: [3, 5] 43 | t: [4, 5, 6] 44 | s: 2 45 | act_func: 'swish' 46 | se: False 47 | mb3: 48 | c: [32, 40] 49 | d: [3, 4, 5, 6] 50 | k: [3, 5] 51 | t: [4, 5, 6] 52 | s: 2 53 | act_func: 'swish' 54 | se: True 55 | mb4: 56 | c: [64, 72] 57 | d: [3, 4, 5, 6] 58 | k: [3, 5] 59 | t: [4, 5, 6] 60 | s: 2 61 | act_func: 'swish' 62 | se: False 63 | mb5: 64 | c: [112, 120, 128] 65 | d: [3, 4, 5, 6, 7, 8] 66 | k: [3, 5] 67 | t: [4, 5, 6] 68 | s: 1 69 | act_func: 'swish' 70 | se: True 71 | 
mb6: 72 | c: [192, 200, 208, 216] 73 | d: [3, 4, 5, 6, 7, 8] 74 | k: [3, 5] 75 | t: [6] 76 | s: 2 77 | act_func: 'swish' 78 | se: True 79 | mb7: 80 | c: [216, 224] 81 | d: [1, 2] 82 | k: [3, 5] 83 | t: [6] 84 | s: 1 85 | act_func: 'swish' 86 | se: True 87 | last_conv: 88 | c: [1792, 1984] 89 | act_func: 'swish' 90 | -------------------------------------------------------------------------------- /configs/search/BigNAS/search.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | RNG_SEED: 2 3 | SPACE: 4 | NAME: 'bignas' 5 | LOADER: 6 | DATASET: 'imagenet' 7 | NUM_CLASSES: 1000 8 | BATCH_SIZE: 128 9 | NUM_WORKERS: 8 10 | USE_VAL: True 11 | TRANSFORM: "auto_augment_tf" 12 | SEARCH: 13 | IM_SIZE: 224 14 | WEIGHTS: "exp/search/test/checkpoints/best_model_epoch_0009.pyth" 15 | BIGNAS: 16 | CONSTRAINT_FLOPS: 6.e+8 # 600M 17 | NUM_MUTATE: 200 18 | BN_MOMENTUM: 0.1 19 | BN_EPS: 1.e-5 20 | POST_BN_CALIBRATION_BATCH_NUM: 64 21 | # ACTIVE_SUBNET: # subnet for evaluation 22 | # RESOLUTION: 192 23 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] 24 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] 25 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] 26 | # DEPTH: [1, 3, 3, 3, 3, 3, 1] 27 | SEARCH_CFG_SETS: 28 | resolutions: [224, 256] 29 | first_conv: 30 | c: [16] 31 | mb1: 32 | c: [16] 33 | d: [2] 34 | k: [3] 35 | t: [1] 36 | mb2: 37 | c: [24] 38 | d: [3] 39 | k: [3] 40 | t: [5] 41 | mb3: 42 | c: [32] 43 | d: [4] 44 | k: [3] 45 | t: [5] 46 | mb4: 47 | c: [64] 48 | d: [5] 49 | k: [3] 50 | t: [5] 51 | mb5: 52 | c: [120] 53 | d: [6] 54 | k: [3] 55 | t: [5] 56 | mb6: 57 | c: [192] 58 | d: [6] 59 | k: [3, 5] 60 | t: [6] 61 | mb7: 62 | c: [216] 63 | d: [2] 64 | k: [3] 65 | t: [6] 66 | last_conv: 67 | c: [1792] 68 | SUPERNET_CFG: 69 | use_v3_head: True 70 | resolutions: [192, 224, 256, 288] 71 | first_conv: 72 | c: [16, 24] 73 | act_func: 'swish' 74 | s: 2 75 | mb1: 76 | c: [16, 24] 77 | d: [1, 2] 78 | k: [3, 5] 79 | t: [1] 80 | s: 1 81 | act_func: 'swish' 82 | se: False 83 | mb2: 84 | c: [24, 32] 85 | d: [3, 4, 5] 86 | k: [3, 5] 87 | t: [4, 5, 6] 88 | s: 2 89 | act_func: 'swish' 90 | se: False 91 | mb3: 92 | c: [32, 40] 93 | d: [3, 4, 5, 6] 94 | k: [3, 5] 95 | t: [4, 5, 6] 96 | s: 2 97 | act_func: 'swish' 98 | se: True 99 | mb4: 100 | c: [64, 72] 101 | d: [3, 4, 5, 6] 102 | k: [3, 5] 103 | t: [4, 5, 6] 104 | s: 2 105 | act_func: 'swish' 106 | se: False 107 | mb5: 108 | c: [112, 120, 128] 109 | d: [3, 4, 5, 6, 7, 8] 110 | k: [3, 5] 111 | t: [4, 5, 6] 112 | s: 1 113 | act_func: 'swish' 114 | se: True 115 | mb6: 116 | c: [192, 200, 208, 216] 117 | d: [3, 4, 5, 6, 7, 8] 118 | k: [3, 5] 119 | t: [6] 120 | s: 2 121 | act_func: 'swish' 122 | se: True 123 | mb7: 124 | c: [216, 224] 125 | d: [1, 2] 126 | k: [3, 5] 127 | t: [6] 128 | s: 1 129 | act_func: 'swish' 130 | se: True 131 | last_conv: 132 | c: [1792, 1984] 133 | act_func: 'swish' 134 | -------------------------------------------------------------------------------- /configs/search/BigNAS/train.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | RNG_SEED: 0 3 | SPACE: 4 | NAME: 'bignas' 5 | LOADER: 6 | DATASET: 'imagenet' 7 | NUM_CLASSES: 1000 8 | BATCH_SIZE: 48 9 | NUM_WORKERS: 4 10 | USE_VAL: True 11 | TRANSFORM: "auto_augment_tf" 12 | OPTIM: 13 | GRAD_CLIP: 1. 
14 | WARMUP_EPOCH: 5 15 | MAX_EPOCH: 360 16 | WEIGHT_DECAY: 1.e-5 17 | BASE_LR: 0.15 18 | NESTEROV: True 19 | SEARCH: 20 | LOSS_FUN: "cross_entropy_smooth" 21 | LABEL_SMOOTH: 0.1 22 | TRAIN: 23 | DROP_PATH_PROB: 0.2 24 | BIGNAS: 25 | SANDWICH_NUM: 4 # max + 2*middle + min 26 | DROP_CONNECT: 0.2 27 | BN_MOMENTUM: 0. 28 | BN_EPS: 1.e-5 29 | POST_BN_CALIBRATION_BATCH_NUM: 64 30 | SUPERNET_CFG: 31 | use_v3_head: True 32 | resolutions: [192, 224, 256, 288] 33 | first_conv: 34 | c: [16, 24] 35 | act_func: 'swish' 36 | s: 2 37 | mb1: 38 | c: [16, 24] 39 | d: [1, 2] 40 | k: [3, 5] 41 | t: [1] 42 | s: 1 43 | act_func: 'swish' 44 | se: False 45 | mb2: 46 | c: [24, 32] 47 | d: [3, 4, 5] 48 | k: [3, 5] 49 | t: [4, 5, 6] 50 | s: 2 51 | act_func: 'swish' 52 | se: False 53 | mb3: 54 | c: [32, 40] 55 | d: [3, 4, 5, 6] 56 | k: [3, 5] 57 | t: [4, 5, 6] 58 | s: 2 59 | act_func: 'swish' 60 | se: True 61 | mb4: 62 | c: [64, 72] 63 | d: [3, 4, 5, 6] 64 | k: [3, 5] 65 | t: [4, 5, 6] 66 | s: 2 67 | act_func: 'swish' 68 | se: False 69 | mb5: 70 | c: [112, 120, 128] 71 | d: [3, 4, 5, 6, 7, 8] 72 | k: [3, 5] 73 | t: [4, 5, 6] 74 | s: 1 75 | act_func: 'swish' 76 | se: True 77 | mb6: 78 | c: [192, 200, 208, 216] 79 | d: [3, 4, 5, 6, 7, 8] 80 | k: [3, 5] 81 | t: [6] 82 | s: 2 83 | act_func: 'swish' 84 | se: True 85 | mb7: 86 | c: [216, 224] 87 | d: [1, 2] 88 | k: [3, 5] 89 | t: [6] 90 | s: 1 91 | act_func: 'swish' 92 | se: True 93 | last_conv: 94 | c: [1792, 1984] 95 | act_func: 'swish' -------------------------------------------------------------------------------- /configs/search/DrNAS/drnas_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_darts' 3 | CHANNELS: 36 4 | LAYERS: 20 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | MIN_LR: 0.0 15 | WEIGHT_DECAY: 3.e-4 16 | LR_POLICY: 'cos' 17 | DARTS: 18 | UNROLLED: True 19 | ALPHA_LR: 6.e-4 20 | ALPHA_WEIGHT_DECAY: 1.e-3 21 | DRNAS: 22 | K: 6 23 | REG_TYPE: "l2" 24 | REG_SCALE: 1.e-3 25 | METHOD: 'dirichlet' 26 | TAU: [1, 10] 27 | OUT_DIR: 'exp/drnas' 28 | -------------------------------------------------------------------------------- /configs/search/DrNAS/drnas_darts_imagenet.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_darts' 3 | CHANNELS: 48 4 | LAYERS: 14 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'imagenet' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 512 10 | NUM_WORKERS: 16 11 | NUM_CLASSES: 1000 12 | SEARCH: 13 | IM_SIZE: 32 14 | OPTIM: 15 | BASE_LR: 0.5 16 | MIN_LR: 0.0 17 | WEIGHT_DECAY: 3.e-4 18 | LR_POLICY: 'cos' 19 | WARMUP_EPOCH: 5 20 | DARTS: 21 | UNROLLED: False 22 | ALPHA_LR: 6.e-3 23 | ALPHA_WEIGHT_DECAY: 1.e-3 24 | DRNAS: 25 | K: 6 26 | REG_TYPE: "l2" 27 | REG_SCALE: 1.e-3 28 | METHOD: 'dirichlet' 29 | TAU: [1, 10] 30 | OUT_DIR: 'exp/drnas' -------------------------------------------------------------------------------- /configs/search/DrNAS/drnas_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.025 15 | MIN_LR: 0.001 16 | WEIGHT_DECAY: 3.e-4 17 | LR_POLICY: 'cos' 18 | MAX_EPOCH: 100 19 | DARTS: 20 | UNROLLED: True 21 | 
ALPHA_LR: 3.e-4 22 | ALPHA_WEIGHT_DECAY: 1.e-3 23 | DRNAS: 24 | K: 1 25 | REG_TYPE: "l2" 26 | REG_SCALE: 1.e-3 27 | METHOD: 'dirichlet' 28 | TAU: [1, 10] 29 | PROGRESSIVE: False 30 | OUT_DIR: 'exp/drnas' -------------------------------------------------------------------------------- /configs/search/DrNAS/drnas_nasbench201_cifar100.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar100' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 100 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.025 15 | MIN_LR: 0.001 16 | WEIGHT_DECAY: 3.e-4 17 | LR_POLICY: 'cos' 18 | MAX_EPOCH: 100 19 | DARTS: 20 | UNROLLED: True 21 | ALPHA_LR: 3.e-4 22 | ALPHA_WEIGHT_DECAY: 1.e-3 23 | DRNAS: 24 | K: 1 25 | REG_TYPE: "l2" 26 | REG_SCALE: 1.e-3 27 | METHOD: 'dirichlet' 28 | TAU: [1, 10] 29 | PROGRESSIVE: False 30 | OUT_DIR: 'exp/drnas' -------------------------------------------------------------------------------- /configs/search/DrNAS/drnas_nasbench201_cifar10_progressive.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.025 15 | MIN_LR: 0.001 16 | WEIGHT_DECAY: 3.e-4 17 | LR_POLICY: 'cos' 18 | MAX_EPOCH: 100 19 | DARTS: 20 | UNROLLED: True 21 | ALPHA_LR: 3.e-4 22 | ALPHA_WEIGHT_DECAY: 1.e-3 23 | DRNAS: 24 | K: 4 25 | REG_TYPE: "l2" 26 | REG_SCALE: 1.e-3 27 | METHOD: 'dirichlet' 28 | TAU: [1, 10] 29 | PROGRESSIVE: True 30 | OUT_DIR: 'exp/drnas' -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/depth_1.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | SPACE: 3 | NAME: 'ofa_mbv3' 4 | LOADER: 5 | DATASET: 'imagenet' 6 | NUM_CLASSES: 1000 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | USE_VAL: True 10 | SEARCH: 11 | MULTI_SIZES: [128,160,192,224] 12 | LOSS_FUN: 'cross_entropy_smooth' 13 | LABEL_SMOOTH: 0.1 14 | AUTO_RESUME: True 15 | OFA: 16 | TASK: 'depth' 17 | PHASE: 1 18 | WIDTH_MULTI_LIST: [1.0] 19 | KS_LIST: [3,5,7] 20 | EXPAND_LIST: [6] 21 | DEPTH_LIST: [3,4] 22 | CHANNEL_DIVISIBLE: 8 23 | SUBNET_BATCH_SIZE: 2 24 | KD_RATIO: 0. 25 | # KD_PATH: "exp/OFA/teacher_model.pyth" 26 | OPTIM: 27 | MAX_EPOCH: 25 28 | BASE_LR: 5.e-3 29 | MIN_LR: 0.0 30 | WARMUP_EPOCH: 0 31 | WARMUP_FACTOR: 1. 32 | LR_POLICY: 'cos' 33 | MOMENTUM: 0.9 34 | WEIGHT_DECAY: 3.e-5 35 | NESTEROV: True 36 | TEST: 37 | BATCH_SIZE: 256 38 | IM_SIZE: 224 -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/depth_2.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | SPACE: 3 | NAME: 'ofa_mbv3' 4 | LOADER: 5 | DATASET: 'imagenet' 6 | NUM_CLASSES: 1000 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | USE_VAL: True 10 | SEARCH: 11 | MULTI_SIZES: [128,160,192,224] 12 | LOSS_FUN: 'cross_entropy_smooth' 13 | LABEL_SMOOTH: 0.1 14 | AUTO_RESUME: True 15 | OFA: 16 | TASK: 'depth' 17 | PHASE: 2 18 | WIDTH_MULTI_LIST: [1.0] 19 | KS_LIST: [3,5,7] 20 | EXPAND_LIST: [6] 21 | DEPTH_LIST: [2,3,4] 22 | CHANNEL_DIVISIBLE: 8 23 | SUBNET_BATCH_SIZE: 2 24 | KD_RATIO: 0. 
25 | # KD_PATH: "exp/OFA/teacher_model.pyth" 26 | OPTIM: 27 | MAX_EPOCH: 120 28 | BASE_LR: 1.5e-2 29 | MIN_LR: 1.e-3 30 | WARMUP_EPOCH: 5 31 | WARMUP_FACTOR: 1. 32 | LR_POLICY: 'cos' 33 | MOMENTUM: 0.9 34 | WEIGHT_DECAY: 3.e-5 35 | NESTEROV: True 36 | TEST: 37 | BATCH_SIZE: 256 38 | IM_SIZE: 224 39 | 40 | -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/expand_1.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | SPACE: 3 | NAME: 'ofa_mbv3' 4 | LOADER: 5 | DATASET: 'imagenet' 6 | NUM_CLASSES: 1000 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | USE_VAL: True 10 | SEARCH: 11 | MULTI_SIZES: [128,160,192,224] 12 | LOSS_FUN: 'cross_entropy_smooth' 13 | LABEL_SMOOTH: 0.1 14 | AUTO_RESUME: True 15 | OFA: 16 | TASK: 'expand' 17 | PHASE: 1 18 | WIDTH_MULTI_LIST: [1.0] 19 | KS_LIST: [3,5,7] 20 | EXPAND_LIST: [4,6] 21 | DEPTH_LIST: [2,3,4] 22 | CHANNEL_DIVISIBLE: 8 23 | SUBNET_BATCH_SIZE: 4 24 | KD_RATIO: 0. 25 | # KD_PATH: "exp/OFA/teacher_model.pyth" 26 | OPTIM: 27 | MAX_EPOCH: 25 28 | BASE_LR: 5.e-3 29 | MIN_LR: 0.0 30 | WARMUP_EPOCH: 0 31 | WARMUP_FACTOR: 1. 32 | LR_POLICY: 'cos' 33 | MOMENTUM: 0.9 34 | WEIGHT_DECAY: 3.e-5 35 | NESTEROV: True 36 | TEST: 37 | BATCH_SIZE: 256 38 | IM_SIZE: 224 -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/expand_2.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | SPACE: 3 | NAME: 'ofa_mbv3' 4 | LOADER: 5 | DATASET: 'imagenet' 6 | NUM_CLASSES: 1000 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | USE_VAL: True 10 | SEARCH: 11 | MULTI_SIZES: [128,160,192,224] 12 | LOSS_FUN: 'cross_entropy_smooth' 13 | LABEL_SMOOTH: 0.1 14 | AUTO_RESUME: True 15 | OFA: 16 | TASK: 'expand' 17 | PHASE: 2 18 | WIDTH_MULTI_LIST: [1.0] 19 | KS_LIST: [3,5,7] 20 | EXPAND_LIST: [3,4,6] 21 | DEPTH_LIST: [2,3,4] 22 | CHANNEL_DIVISIBLE: 8 23 | SUBNET_BATCH_SIZE: 4 24 | KD_RATIO: 0. 25 | # KD_PATH: "exp/OFA/teacher_model.pyth" 26 | OPTIM: 27 | MAX_EPOCH: 120 28 | BASE_LR: 1.5e-2 29 | MIN_LR: 1.e-3 30 | WARMUP_EPOCH: 5 31 | WARMUP_FACTOR: 1. 32 | LR_POLICY: 'cos' 33 | MOMENTUM: 0.9 34 | WEIGHT_DECAY: 3.e-5 35 | NESTEROV: True 36 | TEST: 37 | BATCH_SIZE: 256 38 | IM_SIZE: 224 39 | 40 | -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/kernel_1.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 4 2 | SPACE: 3 | NAME: 'ofa_mbv3' 4 | LOADER: 5 | DATASET: 'imagenet' 6 | NUM_CLASSES: 1000 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | USE_VAL: True 10 | SEARCH: 11 | MULTI_SIZES: [128,160,192,224] 12 | LOSS_FUN: 'cross_entropy_smooth' 13 | LABEL_SMOOTH: 0.1 14 | AUTO_RESUME: True 15 | OFA: 16 | TASK: 'kernel' 17 | PHASE: 1 18 | WIDTH_MULTI_LIST: [1.0] 19 | KS_LIST: [3,5,7] 20 | EXPAND_LIST: [6] 21 | DEPTH_LIST: [4] 22 | CHANNEL_DIVISIBLE: 8 23 | SUBNET_BATCH_SIZE: 1 24 | KD_RATIO: 0. 25 | # KD_PATH: "exp/OFA/teacher_model.pyth" 26 | OPTIM: 27 | MAX_EPOCH: 120 28 | BASE_LR: 6.0e-2 29 | MIN_LR: 1.e-3 30 | WARMUP_EPOCH: 5 31 | WARMUP_FACTOR: 1. 
32 | LR_POLICY: 'cos' 33 | MOMENTUM: 0.9 34 | WEIGHT_DECAY: 3.e-5 35 | NESTEROV: True 36 | TEST: 37 | BATCH_SIZE: 256 38 | IM_SIZE: 224 -------------------------------------------------------------------------------- /configs/search/OFA/mbv3/normal_1.yaml: -------------------------------------------------------------------------------- 1 | # -------------- 2 | # refer from: 3 | # https://github.com/skhu101/GM-NAS/blob/main/once-for-all-GM/train_ofa_net.py 4 | # -------------- 5 | 6 | NUM_GPUS: 4 7 | SPACE: 8 | NAME: 'ofa_mbv3' 9 | LOADER: 10 | DATASET: 'imagenet' 11 | NUM_CLASSES: 1000 12 | BATCH_SIZE: 128 13 | NUM_WORKERS: 4 14 | USE_VAL: True 15 | SEARCH: 16 | MULTI_SIZES: [128,160,192,224] 17 | LOSS_FUN: 'cross_entropy_smooth' 18 | LABEL_SMOOTH: 0.1 19 | WEIGHTS: '' 20 | AUTO_RESUME: True 21 | OFA: 22 | TASK: 'normal' 23 | PHASE: 1 24 | WIDTH_MULTI_LIST: [1.0] 25 | KS_LIST: [7] 26 | EXPAND_LIST: [6] 27 | DEPTH_LIST: [4] 28 | CHANNEL_DIVISIBLE: 8 29 | SUBNET_BATCH_SIZE: 1 30 | KD_RATIO: 0. 31 | # KD_PATH: "exp/OFA/teacher_model.pyth" 32 | OPTIM: 33 | MAX_EPOCH: 180 34 | BASE_LR: 0.325 35 | MIN_LR: 1.e-3 36 | WARMUP_EPOCH: 5 37 | WARMUP_FACTOR: 1. 38 | LR_POLICY: 'cos' 39 | MOMENTUM: 0.9 40 | WEIGHT_DECAY: 3.e-5 41 | NESTEROV: True 42 | TEST: 43 | BATCH_SIZE: 256 44 | IM_SIZE: 224 45 | 46 | -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | BATCH_SIZE: 128 7 | OPTIM: 8 | BASE_LR: 0.025 9 | MOMENTUM: 0.9 10 | WEIGHT_DECAY: 0.0003 11 | MAX_EPOCH: 250 12 | TRAIN: 13 | CHANNELS: 16 14 | LAYERS: 8 15 | RMINAS: 16 | LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_darts_cifar100.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_darts' 3 | LOADER: 4 | DATASET: 'cifar100' 5 | NUM_CLASSES: 100 6 | BATCH_SIZE: 128 7 | OPTIM: 8 | BASE_LR: 0.025 9 | MOMENTUM: 0.9 10 | WEIGHT_DECAY: 0.0003 11 | MAX_EPOCH: 250 12 | TRAIN: 13 | CHANNELS: 16 14 | LAYERS: 8 15 | RMINAS: 16 | LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_nb201' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | BATCH_SIZE: 32 7 | OPTIM: 8 | BASE_LR: 0.1 9 | MOMENTUM: 0.9 10 | WEIGHT_DECAY: 0.0005 11 | MAX_EPOCH: 150 12 | TRAIN: 13 | CHANNELS: 16 14 | LAYERS: 8 15 | RMINAS: 16 | LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_nasbench201_cifar100.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_nb201' 3 | LOADER: 4 | DATASET: 'cifar100' 5 | NUM_CLASSES: 100 6 | BATCH_SIZE: 32 7 | OPTIM: 8 | BASE_LR: 0.1 9 | MOMENTUM: 0.9 10 | WEIGHT_DECAY: 0.0005 11 | MAX_EPOCH: 150 12 | TRAIN: 13 | CHANNELS: 16 14 | LAYERS: 8 15 | RMINAS: 16 | 
LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_nasbench201_imagenet16.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_nb201' 3 | LOADER: 4 | DATASET: 'imagenet16' 5 | NUM_CLASSES: 120 6 | BATCH_SIZE: 32 7 | OPTIM: 8 | BASE_LR: 0.1 9 | MOMENTUM: 0.9 10 | WEIGHT_DECAY: 0.0005 11 | MAX_EPOCH: 150 12 | TRAIN: 13 | CHANNELS: 16 14 | LAYERS: 8 15 | RMINAS: 16 | LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_nasbenchmacro_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'nasbenchmacro' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | BATCH_SIZE: 64 7 | OPTIM: 8 | # BASE_LR: 0.1 9 | # MOMENTUM: 0.9 10 | # WEIGHT_DECAY: 0.0005 11 | MAX_EPOCH: 30 12 | # TRAIN: 13 | # CHANNELS: 16 14 | # LAYERS: 8 15 | RMINAS: 16 | LOSS_BETA: 0.8 17 | RF_WARMUP: 100 18 | RF_THRESRATE: 0.05 19 | RF_SUCC: 100 20 | SEARCH: 21 | IM_SIZE: 32 22 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/RMINAS/rminas_proxyless_imagenet.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'proxyless' 3 | LOADER: 4 | DATASET: 'imagenet' 5 | NUM_CLASSES: 10 6 | NUM_WORKERS: 0 7 | BATCH_SIZE: 128 8 | OPTIM: 9 | BASE_LR: 0.025 10 | MOMENTUM: 0.9 11 | WEIGHT_DECAY: 0.0003 12 | MAX_EPOCH: 500 13 | TRAIN: 14 | CHANNELS: 16 15 | LAYERS: 8 16 | RMINAS: 17 | LOSS_BETA: 0.8 18 | RF_WARMUP: 100 19 | RF_THRESRATE: 0.05 20 | RF_SUCC: 100 21 | OUT_DIR: 'exp/rminas' -------------------------------------------------------------------------------- /configs/search/SNG/ASNG_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.8, 0.2] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | STEPS: [48, 96] 12 | SNG: 13 | NAME: 'ASNG' 14 | THETA_LR: 0.1 15 | PRUNING: True 16 | PRUNING_STEP: 3 17 | PROB_SAMPLING: False 18 | UTILITY: 'log' 19 | UTILITY_FACTOR: 0.4 20 | LAMBDA: -1 21 | MOMENTUM: True 22 | GAMMA: 0.9 23 | SAMPLING_PER_EDGE: 1 24 | RANDOM_SAMPLE: True 25 | WARMUP_RANDOM_SAMPLE: True 26 | BIGMODEL_SAMPLE_PROB: 0.5 27 | BIGMODEL_NON_PARA: 2 28 | EDGE_SAMPLING: False 29 | EDGE_SAMPLING_EPOCH: -1 30 | OUT_DIR: 'exp/ASNG' -------------------------------------------------------------------------------- /configs/search/SNG/DDPNAS_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.8, 0.2] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | STEPS: [48, 96] 12 | BASE_LR: 0.025 13 | LR_POLICY: 'cos' 14 | SNG: 15 | NAME: 'DDPNAS' 16 | THETA_LR: 0.01 17 | PRUNING: True 18 | PRUNING_STEP: 3 19 | PROB_SAMPLING: False 20 | UTILITY: 'log' 21 | UTILITY_FACTOR: 0.4 22 | LAMBDA: -1 23 | MOMENTUM: True 24 | GAMMA: 0.9 25 | SAMPLING_PER_EDGE: 1 26 | RANDOM_SAMPLE: True 27 | WARMUP_RANDOM_SAMPLE: True 28 | BIGMODEL_SAMPLE_PROB: 0.5 29 | BIGMODEL_NON_PARA: 2 30 | 
EDGE_SAMPLING: False 31 | EDGE_SAMPLING_EPOCH: -1 32 | OUT_DIR: 'exp/DDPNAS' 33 | -------------------------------------------------------------------------------- /configs/search/SNG/DDPNAS_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'nasbench201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.8, 0.2] 9 | BATCH_SIZE: 256 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | EVALUATION: 'nasbench201' 14 | OPTIM: 15 | STEPS: [48, 96] 16 | BASE_LR: 0.025 17 | LR_POLICY: 'cos' 18 | SNG: 19 | NAME: 'DDPNAS' 20 | THETA_LR: 0.01 21 | PRUNING: True 22 | PRUNING_STEP: 3 23 | PROB_SAMPLING: False 24 | UTILITY: 'log' 25 | UTILITY_FACTOR: 0.4 26 | LAMBDA: -1 27 | MOMENTUM: True 28 | GAMMA: 0.9 29 | SAMPLING_PER_EDGE: 1 30 | RANDOM_SAMPLE: True 31 | WARMUP_RANDOM_SAMPLE: True 32 | BIGMODEL_SAMPLE_PROB: 0. 33 | BIGMODEL_NON_PARA: 2 34 | EDGE_SAMPLING: False 35 | EDGE_SAMPLING_EPOCH: -1 36 | OUT_DIR: 'exp/DDPNAS' 37 | -------------------------------------------------------------------------------- /configs/search/SNG/DynamicASNG_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.8, 0.2] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | STEPS: [48, 96] 12 | SNG: 13 | NAME: 'dynamic_ASNG' 14 | THETA_LR: 0.1 15 | PRUNING: True 16 | PRUNING_STEP: 3 17 | PROB_SAMPLING: False 18 | UTILITY: 'log' 19 | UTILITY_FACTOR: 0.4 20 | LAMBDA: -1 21 | MOMENTUM: True 22 | GAMMA: 0.9 23 | SAMPLING_PER_EDGE: 1 24 | RANDOM_SAMPLE: True 25 | WARMUP_RANDOM_SAMPLE: True 26 | BIGMODEL_SAMPLE_PROB: 0.5 27 | BIGMODEL_NON_PARA: 2 28 | EDGE_SAMPLING: False 29 | EDGE_SAMPLING_EPOCH: -1 30 | OUT_DIR: 'exp/Dynamic_ASNG' -------------------------------------------------------------------------------- /configs/search/SNG/DynamicSNG_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.8, 0.2] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | STEPS: [48, 96] 12 | SNG: 13 | NAME: 'dynamic_SNG' 14 | THETA_LR: 0.1 15 | PRUNING: True 16 | PRUNING_STEP: 3 17 | PROB_SAMPLING: False 18 | UTILITY: 'log' 19 | UTILITY_FACTOR: 0.4 20 | LAMBDA: -1 21 | MOMENTUM: True 22 | GAMMA: 0.9 23 | SAMPLING_PER_EDGE: 1 24 | RANDOM_SAMPLE: True 25 | WARMUP_RANDOM_SAMPLE: True 26 | BIGMODEL_SAMPLE_PROB: 0.5 27 | BIGMODEL_NON_PARA: 2 28 | EDGE_SAMPLING: False 29 | EDGE_SAMPLING_EPOCH: -1 30 | OUT_DIR: 'exp/Dynamic_SNG' -------------------------------------------------------------------------------- /configs/search/SNG/MIGO_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.5, 0.5] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | BASE_LR: 0.1 12 | MOMENTUM: 0.9 13 | STEPS: [60, 120] 14 | LR_POLICY: 'step' 15 | WARMUP_EPOCH: 0 16 | MAX_EPOCH: 200 17 | FINAL_EPOCH: 50 18 | SNG: 19 | NAME: 'MIGO' 20 | THETA_LR: 0.1 21 | PRUNING: False 22 | PRUNING_STEP: 2 23 | PROB_SAMPLING: True 24 | UTILITY: 'log' 25 | UTILITY_FACTOR: 0.4 26 | LAMBDA: -1 27 | MOMENTUM: True 28 | GAMMA: 0.5 29 | SAMPLING_PER_EDGE: 1 30 | RANDOM_SAMPLE: True 31 | WARMUP_RANDOM_SAMPLE: True 32 | 
BIGMODEL_SAMPLE_PROB: 0.5 33 | BIGMODEL_NON_PARA: 2 34 | EDGE_SAMPLING: True 35 | EDGE_SAMPLING_EPOCH: -1 36 | OUT_DIR: 'exp/MIGO' -------------------------------------------------------------------------------- /configs/search/SNG/SNG_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | SPLIT: [0.8, 0.2] 6 | BATCH_SIZE: 256 7 | NUM_CLASSES: 10 8 | SEARCH: 9 | IM_SIZE: 32 10 | OPTIM: 11 | STEPS: [48, 96] 12 | SNG: 13 | NAME: 'SNG' 14 | THETA_LR: 0.1 15 | PRUNING: True 16 | PRUNING_STEP: 3 17 | PROB_SAMPLING: False 18 | UTILITY: 'log' 19 | UTILITY_FACTOR: 0.4 20 | LAMBDA: -1 21 | MOMENTUM: True 22 | GAMMA: 0.9 23 | SAMPLING_PER_EDGE: 1 24 | RANDOM_SAMPLE: True 25 | WARMUP_RANDOM_SAMPLE: True 26 | BIGMODEL_SAMPLE_PROB: 0.5 27 | BIGMODEL_NON_PARA: 2 28 | EDGE_SAMPLING: False 29 | EDGE_SAMPLING_EPOCH: -1 30 | OUT_DIR: 'exp/SNG' -------------------------------------------------------------------------------- /configs/search/darts_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'darts' 3 | CHANNELS: 16 4 | LAYERS: 8 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | MAX_EPOCH: 50 15 | LR_POLICY: 'cos' 16 | DARTS: 17 | UNROLLED: True 18 | ALPHA_LR: 3.e-4 19 | ALPHA_WEIGHT_DECAY: 1.e-3 20 | OUT_DIR: 'exp/darts' -------------------------------------------------------------------------------- /configs/search/darts_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'nasbench201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 32 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | MAX_EPOCH: 50 15 | LR_POLICY: 'cos' 16 | DARTS: 17 | UNROLLED: True 18 | ALPHA_LR: 3.e-4 19 | ALPHA_WEIGHT_DECAY: 1.e-3 20 | OUT_DIR: 'exp/darts' -------------------------------------------------------------------------------- /configs/search/dropnas_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'dropnas' 3 | CHANNELS: 16 4 | LAYERS: 8 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | NUM_CLASSES: 10 9 | SPLIT: [0.5, 0.5] 10 | BATCH_SIZE: 64 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.0375 15 | MIN_LR: 0.0015 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 3.e-4 18 | MAX_EPOCH: 50 19 | LR_POLICY: 'cos' 20 | WARMUP_EPOCH: 0 21 | DARTS: 22 | ALPHA_WEIGHT_DECAY: 0 23 | ALPHA_LR: 3.e-4 24 | DROPNAS: 25 | DROP_RATE: 3.e-5 26 | OUT_DIR: 'exp/dropnas' -------------------------------------------------------------------------------- /configs/search/gdas_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'gdas_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.025 15 | MIN_LR: 0.001 16 | WEIGHT_DECAY: 3.e-4 17 | LR_POLICY: 'cos' 18 | MAX_EPOCH: 100 19 | DARTS: 20 | UNROLLED: True 21 | ALPHA_LR: 3.e-4 22 | ALPHA_WEIGHT_DECAY: 1.e-3 23 | DRNAS: 24 | K: 1 25 | REG_TYPE: "l2" 26 | REG_SCALE: 1.e-3 27 | METHOD: 'gdas' 28 | TAU: [1, 10] 29 | PROGRESSIVE: False 30 | OUT_DIR: 
'exp/gdas' -------------------------------------------------------------------------------- /configs/search/nasbenchmacro_nasbenchmacro_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'nasbenchmacro' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | SPLIT: [0.5, 0.5] 7 | BATCH_SIZE: 64 8 | SEARCH: 9 | IM_SIZE: 32 10 | OUT_DIR: 'exp/nbm' -------------------------------------------------------------------------------- /configs/search/pcdarts_pcdarts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'pcdarts' 3 | CHANNELS: 16 4 | LAYERS: 8 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 256 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | MAX_EPOCH: 50 15 | LR_POLICY: 'cos' 16 | WEIGHT_DECAY: 3.e-4 17 | MIN_LR: 0.0 18 | DARTS: 19 | UNROLLED: True 20 | ALPHA_LR: 6.e-4 21 | ALPHA_WEIGHT_DECAY: 1.e-3 22 | OUT_DIR: 'exp/pcdarts' -------------------------------------------------------------------------------- /configs/search/pdarts_pdarts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'pdarts' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | BASIC_OP: [ 7 | 'none', 8 | 'max_pool_3x3', 9 | 'avg_pool_3x3', 10 | 'skip_connect', 11 | 'sep_conv_3x3', 12 | 'sep_conv_5x5', 13 | 'dil_conv_3x3', 14 | 'dil_conv_5x5' 15 | ] 16 | LOADER: 17 | DATASET: 'cifar10' 18 | SPLIT: [0.5, 0.5] 19 | BATCH_SIZE: 64 20 | NUM_CLASSES: 10 21 | SEARCH: 22 | IM_SIZE: 32 23 | OPTIM: 24 | MAX_EPOCH: 25 25 | MIN_LR: 0.0 26 | BASE_LR: 0.025 27 | WEIGHT_DECAY: 3.e-4 28 | DARTS: 29 | UNROLLED: False 30 | ALPHA_LR: 6.e-4 31 | ALPHA_WEIGHT_DECAY: 1.e-3 32 | PDARTS: 33 | ADD_LAYERS: [0, 6, 12] 34 | ADD_WIDTH: [0, 0, 0] 35 | DROPOUT_RATE: [0.1, 0.4, 0.7] 36 | NUM_TO_KEEP: [5, 3, 1] 37 | EPS_NO_ARCHS: [10, 10, 10] 38 | SCALE_FACTOR: 0.2 39 | OUT_DIR: 'exp/pdarts' -------------------------------------------------------------------------------- /configs/search/snas_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'drnas_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | SPLIT: [0.5, 0.5] 9 | BATCH_SIZE: 64 10 | NUM_CLASSES: 10 11 | SEARCH: 12 | IM_SIZE: 32 13 | OPTIM: 14 | BASE_LR: 0.025 15 | MIN_LR: 0.001 16 | WEIGHT_DECAY: 3.e-4 17 | LR_POLICY: 'cos' 18 | MAX_EPOCH: 100 19 | DARTS: 20 | UNROLLED: True 21 | ALPHA_LR: 3.e-4 22 | ALPHA_WEIGHT_DECAY: 1.e-3 23 | DRNAS: 24 | K: 1 25 | REG_TYPE: "l2" 26 | REG_SCALE: 1.e-3 27 | METHOD: 'snas' 28 | TAU: [1, 10] 29 | PROGRESSIVE: False 30 | OUT_DIR: 'exp/snas' -------------------------------------------------------------------------------- /configs/search/spos_nasbench201_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'spos_nb201' 3 | CHANNELS: 16 4 | LAYERS: 5 5 | NODES: 4 6 | LOADER: 7 | DATASET: 'cifar10' 8 | NUM_CLASSES: 10 9 | # SPLIT: [0.5, 0.5] 10 | BATCH_SIZE: 128 11 | SEARCH: 12 | IM_SIZE: 32 13 | AUTO_RESUME: False 14 | SPOS: 15 | RESIZE: True 16 | LAYERS: 90 # 6 * 15 17 | NUM_CHOICE: 5 18 | OPTIM: 19 | BASE_LR: 0.025 20 | MIN_LR: 0.001 21 | WEIGHT_DECAY: 1.e-4 22 | LR_POLICY: 'cos' 23 | MAX_EPOCH: 50 24 | OUT_DIR: 'exp/spos_201' --------------------------------------------------------------------------------
/configs/search/spos_spos_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'spos' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | # SPLIT: [0.5, 0.5] 7 | BATCH_SIZE: 96 8 | SEARCH: 9 | IM_SIZE: 32 10 | # LABEL_SMOOTH: 0.1 11 | SPOS: 12 | RESIZE: True 13 | LAYERS: 20 14 | NUM_CHOICE: 4 15 | OPTIM: 16 | MAX_EPOCH: 600 17 | BASE_LR: 0.025 18 | WEIGHT_DECAY: 3.e-4 19 | OUT_DIR: 'exp/spos' -------------------------------------------------------------------------------- /configs/train/darts_darts_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_darts' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | BATCH_SIZE: 96 6 | NUM_CLASSES: 10 7 | TEST: 8 | BATCH_SIZE: 96 # stay same. 9 | OPTIM: 10 | BASE_LR: 0.025 11 | WEIGHT_DECAY: 3.e-4 12 | MAX_EPOCH: 600 13 | TRAIN: 14 | CHANNELS: 36 15 | LAYERS: 20 16 | AUX_WEIGHT: 0.4 17 | DROP_PATH_PROB: 0.3 18 | GENOTYPE: "Genotype(normal=[('sep_conv_5x5', 1), ('sep_conv_3x3', 0),('skip_connect', 0), ('sep_conv_5x5', 1),('sep_conv_5x5', 3), ('sep_conv_3x3', 1),('dil_conv_5x5', 3), ('max_pool_3x3', 4)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1),('skip_connect', 0), ('skip_connect', 1),('sep_conv_3x3', 3), ('skip_connect', 2),('dil_conv_3x3', 3), ('sep_conv_5x5', 0)], reduce_concat=range(2, 6))" -------------------------------------------------------------------------------- /configs/train/spos_spos_cifar10.yaml: -------------------------------------------------------------------------------- 1 | SPACE: 2 | NAME: 'infer_spos' 3 | LOADER: 4 | DATASET: 'cifar10' 5 | NUM_CLASSES: 10 6 | SPLIT: [0.5, 0.5] 7 | BATCH_SIZE: 64 8 | SEARCH: 9 | IM_SIZE: 32 10 | SPOS: 11 | RESIZE: True 12 | LAYERS: 20 13 | NUM_CHOICE: 4 14 | CHOICE: [1, 0, 3, 1, 3, 0, 3, 0, 0, 3, 3, 0, 1, 0, 1, 2, 2, 1, 1, 3] 15 | OUT_DIR: 'username/project/XNAS/experiment/spos' -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'XNAS' 21 | copyright = '2022, PCL_AutoML' 22 | author = 'PCL_AutoML' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = 'v0.0.1' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'recommonmark', 35 | 'sphinx_markdown_tables' 36 | ] 37 | 38 | # Add any paths that contain templates here, relative to this directory. 
39 | templates_path = ['_templates'] 40 | 41 | # The language for content autogenerated by Sphinx. Refer to documentation 42 | # for a list of supported languages. 43 | # 44 | # This is also used if you do content translation via gettext catalogs. 45 | # Usually you set "language" from the command line for these cases. 46 | language = 'en' 47 | # language = 'zh_CN' 48 | 49 | # List of patterns, relative to source directory, that match files and 50 | # directories to ignore when looking for source files. 51 | # This pattern also affects html_static_path and html_extra_path. 52 | exclude_patterns = [] 53 | 54 | 55 | # -- Options for HTML output ------------------------------------------------- 56 | 57 | # The theme to use for HTML and HTML Help pages. See the documentation for 58 | # a list of builtin themes. 59 | # 60 | 61 | # html_theme = 'alabaster' 62 | # html_theme = 'sphinx_rtd_theme' 63 | html_theme = 'furo' 64 | 65 | 66 | 67 | # Add any paths that contain custom static files (such as style sheets) here, 68 | # relative to this directory. They are copied after the builtin static files, 69 | # so a file named "default.css" will overwrite the builtin "default.css". 70 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/data_preparation.md: -------------------------------------------------------------------------------- 1 | # Common Settings 2 | 3 | It is highly recommended to save or link datasets to the `$XNAS/data` folder, so that no additional configuration is required. 4 | 5 | However, you can also set the dataset path manually by modifying the `cfg.LOADER.DATAPATH` attribute in the configuration file `$XNAS/xnas/core/config.py`. 6 | 7 | Additionally, files required by benchmarks are also kept in the `$XNAS/data` folder. You can modify the related attributes under `cfg.BENCHMARK` in the configuration file to match your actual file locations. 8 | 9 | 10 | # Dataset Preparation 11 | 12 | The dataloaders of XNAS read dataset files from `$XNAS/data/$DATASET_NAME` by default, where dataset names are lowercase with hyphens removed. For example, files for CIFAR-10 should be placed (or automatically downloaded) under the `$XNAS/data/cifar/` directory. 13 | 14 | XNAS currently supports the following datasets. 15 | 16 | - CIFAR-10 17 | - CIFAR-100 18 | - ImageNet 19 | - ImageNet16 (Downsampled) 20 | - SVHN 21 | - MNIST 22 | - FashionMNIST 23 | 24 | # Benchmark Preparation 25 | 26 | Some search spaces or algorithms supported by XNAS require specific APIs provided by NAS benchmarks. These benchmarks must be installed and properly configured before running such code. 27 | 28 | The benchmarks supported by XNAS, with links, are as follows. 29 | - nasbench101: [GitHub](https://github.com/google-research/nasbench) 30 | - nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1) 31 | - nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201) 32 | - nasbench301: [GitHub](https://github.com/automl/nasbench301) 33 | 34 | -------------------------------------------------------------------------------- /docs/get_started.md: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | 3 | XNAS does not currently provide installation via `pip`. To run XNAS, `python>=3.7` and `pytorch==1.9` are required. Other versions of `PyTorch` may also work, but potential API differences can generate warnings. 4 | 5 | We have listed the other requirements in the `requirements.txt` file.
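Before installing, you can quickly confirm that your environment meets these requirements. This is a minimal sketch using only standard Python and PyTorch; it is not an XNAS utility:

```sh
# Print the interpreter and PyTorch versions (expect python>=3.7, pytorch==1.9)
python -c "import sys, torch; print(sys.version.split()[0], torch.__version__)"
```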
6 | 7 | # Installation 8 | 9 | 1. Clone this repo. 10 | 2. (Optional) Create a virtualenv for this library. 11 | ```sh 12 | virtualenv venv 13 | ``` 14 | 3. Install dependencies. 15 | ```sh 16 | pip install -r requirements.txt 17 | ``` 18 | 4. Set the `$PYTHONPATH` environment variable. 19 | ```sh 20 | export PYTHONPATH=$PYTHONPATH:/Path/to/XNAS 21 | ``` 22 | 5. Set the visible GPU device for XNAS. Currently, XNAS supports a single GPU only; we will provide multi-GPU support soon. 23 | ```sh 24 | export CUDA_VISIBLE_DEVICES=0 25 | ``` 26 | 27 | Note that environment variables are **valid only for the current terminal**. For ease of use, we recommend adding these commands to your shell profile (e.g. `~/.bashrc` for `bash`) so that the environment variables are configured automatically after login: 28 | 29 | ```sh 30 | echo "export PYTHONPATH=$PYTHONPATH:/Path/to/XNAS" >> ~/.bashrc 31 | ``` 32 | 33 | Some search spaces or algorithms supported by XNAS require specific APIs provided by NAS benchmarks. These benchmarks must be installed and properly configured before running such code. 34 | 35 | The benchmarks supported by XNAS, with links, are as follows. 36 | - nasbench101: [GitHub](https://github.com/google-research/nasbench) 37 | - nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1) 38 | - nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201) 39 | - nasbench301: [GitHub](https://github.com/automl/nasbench301) 40 | 41 | For detailed instructions on installing these benchmarks, please refer to [**Data Preparation**](./data_preparation.md). 42 | 43 | 44 | # Usage 45 | 46 | Before running code in XNAS, please make sure you have followed the instructions in [**Data Preparation**](./data_preparation.md) to prepare the necessary data. 47 | 48 | The main program entries for the search and training processes are in the `$XNAS/scripts` folder. To modify or add NAS code, please place files in this folder. 49 | 50 | ## Configuration Files 51 | 52 | XNAS uses the `.yaml` file format to organize its configuration files. All configuration files are placed under the `$XNAS/configs` directory. To keep files uniform and clear, we strongly recommend the following naming convention: 53 | 54 | ```sh 55 | Algorithm_Space_Dataset[_Evaluation][_ModifiedParams].yaml 56 | ``` 57 | 58 | For example, when searching the `NASBench201` space on the `CIFAR-10` dataset with the `DARTS` algorithm, evaluating with `NASBench301`, and modifying the `MAX_EPOCH` parameter to `75`, the file should be named as follows: 59 | 60 | ```sh 61 | darts_nasbench201_cifar10_nasbench301_maxepoch75.yaml 62 | ``` 63 | 64 | ## Running Examples 65 | 66 | XNAS reads configuration files from the command line. A simple running example: 67 | 68 | ```sh 69 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml 70 | ``` 71 | 72 | The configuration file can be overridden by appending or modifying parameters on the command line. For example, to run with a different output directory: 73 | 74 | ```sh 75 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/another_folder 76 | ``` 77 | 78 | Saving commands in `.sh` files is convenient when you need to run and modify parameters repeatedly. We provide shell scripts under the `$XNAS/examples` folder, together with other potential test code added in the future.
They can be run directly with the following command:
79 | 
80 | ```sh
81 | ./examples/darts_darts_cifar10.sh
82 | ```
83 | 
84 | A common mistake is forgetting to add execute permissions to these files:
85 | 
86 | ```sh
87 | chmod +x examples/*/*.sh examples/*/*/*.sh
88 | ```
89 | 
90 | The script files follow the same naming convention as the configuration files above and set the output directory to a matching folder. You can run a continuous search/training process by adding multiple commands to a script file at once.
91 | 
-------------------------------------------------------------------------------- /docs/index.rst: --------------------------------------------------------------------------------
1 | .. XNAS documentation master file, created by
2 |    sphinx-quickstart on Sat May 21 14:09:17 2022.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | 
7 | Welcome to XNAS' documentation!
8 | ====================================
9 | 
10 | 
11 | .. toctree::
12 |    :maxdepth: 2
13 |    :caption: Get Started
14 | 
15 |    get_started.md
16 | 
17 | .. toctree::
18 |    :maxdepth: 2
19 |    :caption: Data Preparation
20 | 
21 |    data_preparation.md
22 | 
23 | 
24 | .. toctree::
25 |    :maxdepth: 2
26 |    :caption: Notes
27 | 
28 |    notes.md
29 | 
30 | 
31 | Indices and tables
32 | ==================
33 | 
34 | * :ref:`genindex`
35 | * :ref:`search`
-------------------------------------------------------------------------------- /docs/notes.md: --------------------------------------------------------------------------------
1 | # Contributing
2 | 
3 | We welcome contributions to the library, as well as issue reports and suggestions.
4 | 
5 | ## Pull Requests
6 | 
7 | We actively welcome your pull requests.
8 | 
9 | 1. Fork the repo and create your branch from the `dev` branch.
10 | 2. If you've added code that should be tested, please add tests in the `tests` folder.
11 | 3. If you've changed APIs, please update the descriptions in the relevant functions.
12 | 4. Create a pull request.
13 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | furo 3 | recommonmark 4 | sphinx_markdown_tables 5 | -------------------------------------------------------------------------------- /examples/search/DrNAS/drnas_darts_cifar10.sh: -------------------------------------------------------------------------------- 1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_darts_cifar10.yaml OUT_DIR exp/search/drnas_darts_cifar10 2 | -------------------------------------------------------------------------------- /examples/search/DrNAS/drnas_darts_imagenet.sh: -------------------------------------------------------------------------------- 1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_darts_imagenet.yaml OUT_DIR exp/search/drnas_darts_imagenet 2 | -------------------------------------------------------------------------------- /examples/search/DrNAS/drnas_nasbench201_cifar10.sh: -------------------------------------------------------------------------------- 1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar10.yaml OUT_DIR exp/search/drnas_nasbench201_cifar10 2 | -------------------------------------------------------------------------------- /examples/search/DrNAS/drnas_nasbench201_cifar100.sh: -------------------------------------------------------------------------------- 1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar100.yaml OUT_DIR exp/search/drnas_nasbench201_cifar100 2 | -------------------------------------------------------------------------------- /examples/search/DrNAS/drnas_nasbench201_cifar10_progressive.sh: -------------------------------------------------------------------------------- 1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar10_progressive.yaml OUT_DIR exp/search/drnas_nasbench201_cifar10_progressive 2 | -------------------------------------------------------------------------------- /examples/search/OFA/train_supernet.sh: -------------------------------------------------------------------------------- 1 | OUT_NAME="OFA_trial_25" 2 | TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2" 3 | 4 | for loop in $TASKS 5 | do 6 | # echo `torchrun --nproc_per_node 2 scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3/$loop.yaml OUT_DIR exp/search/$OUT_NAME/$loop OPTIM.MAX_EPOCH 2 OPTIM.WARMUP_EPOCH 2 LOADER.BATCH_SIZE 128` 7 | echo "using gpus: $@" 8 | echo `python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3/$loop.yaml OUT_DIR exp/search/$OUT_NAME/$loop NUM_GPUS $@` 9 | echo `sleep 5s` 10 | done 11 | 12 | # # full supernet 13 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_normal_phase1.yaml OUT_DIR exp/$OUT_NAME/normal_1 14 | # # elastic kernel size 15 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_kernel_phase1.yaml OUT_DIR exp/$OUT_NAME/kernel_1 16 | # # elastic depth 17 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_depth_phase1.yaml OUT_DIR exp/$OUT_NAME/depth_1 18 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_depth_phase2.yaml OUT_DIR exp/$OUT_NAME/depth_2 19 | # # elastic width 20 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_expand_phase1.yaml OUT_DIR exp/$OUT_NAME/expand_1 21 | # python 
scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_expand_phase2.yaml OUT_DIR exp/$OUT_NAME/expand_2
22 | 
-------------------------------------------------------------------------------- /examples/search/RMINAS/rminas_darts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_darts_cifar10.yaml OUT_DIR exp/search/rminas_darts_cifar10
-------------------------------------------------------------------------------- /examples/search/RMINAS/rminas_nasbench201_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_cifar10.yaml OUT_DIR exp/search/rminas_nasbench201_cifar10
-------------------------------------------------------------------------------- /examples/search/RMINAS/rminas_nasbench201_cifar100.sh: --------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_cifar100.yaml OUT_DIR exp/search/rminas_nasbench201_cifar100
-------------------------------------------------------------------------------- /examples/search/RMINAS/rminas_nasbench201_imagenet16.sh: --------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_imagenet16.yaml OUT_DIR exp/search/rminas_nasbench201_imagenet16
-------------------------------------------------------------------------------- /examples/search/darts_darts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/search/darts_darts_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/darts_nasbench201_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/DARTS.py --cfg configs/search/darts_nasbench201_cifar10.yaml OUT_DIR exp/search/darts_nasbench201_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/dropnas_darts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/DropNAS.py --cfg configs/search/dropnas_darts_cifar10.yaml OUT_DIR exp/search/dropnas_darts_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/gdas_nasbench201_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/gdas_nasbench201_cifar10.yaml OUT_DIR exp/search/gdas_nasbench201_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/ofa_mbv3_imagenet.sh: --------------------------------------------------------------------------------
1 | # 1. train the largest network (full supernet)
2 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_normal_phase1.yaml
3 | # 2. elastic kernel size
4 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_kernel_phase1.yaml
5 | # 3. elastic depth
6 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_depth_phase1.yaml
7 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_depth_phase2.yaml
8 | # 4. elastic width
9 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_expand_phase1.yaml
10 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_expand_phase2.yaml
11 | 
-------------------------------------------------------------------------------- /examples/search/pcdarts_pcdarts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/PCDARTS.py --cfg configs/search/pcdarts_pcdarts_cifar10.yaml OUT_DIR exp/search/pcdarts_pcdarts_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/pdarts_pdarts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/PDARTS.py --cfg configs/search/pdarts_pdarts_cifar10.yaml OUT_DIR exp/search/pdarts_pdarts_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/search/snas_nasbench201_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/snas_nasbench201_cifar10.yaml OUT_DIR exp/search/snas_nasbench201_cifar10
2 | 
-------------------------------------------------------------------------------- /examples/train/darts_darts_cifar10.sh: --------------------------------------------------------------------------------
1 | python scripts/train/DARTS.py --cfg configs/train/darts_darts_cifar10.yaml OUT_DIR exp/train/darts_darts_cifar10
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | configspace
2 | numpy
3 | # python >= 3.7 (interpreter requirement; not installable via pip)
4 | torch == 1.9.0
5 | pillow
6 | simplejson
7 | scikit-learn
8 | tensorboard
9 | yacs
-------------------------------------------------------------------------------- /scripts/search/BigNAS/eval.py: --------------------------------------------------------------------------------
1 | """BigNAS subnet searching: Coarse-to-fine Architecture Selection"""
2 | 
3 | import numpy as np
4 | from itertools import product
5 | 
6 | import torch
7 | 
8 | import xnas.core.config as config
9 | import xnas.logger.meter as meter
10 | import xnas.logger.logging as logging
11 | from xnas.core.builder import *
12 | from xnas.core.config import cfg
13 | from xnas.datasets.loader import get_normal_dataloader
14 | from xnas.logger.meter import TestMeter
15 | 
16 | 
17 | # Load config and check
18 | config.load_configs()
19 | logger = logging.get_logger(__name__)
20 | 
21 | 
22 | def main():
23 |     setup_env()
24 |     net = space_builder().cuda()
25 | 
26 |     [train_loader, valid_loader] = get_normal_dataloader()
27 | 
28 |     test_meter = TestMeter(len(valid_loader))
29 |     # Validate
30 |     top1_err, top5_err = validate(net, train_loader, valid_loader, test_meter)
31 |     flops = net.compute_active_subnet_flops()
32 | 
33 |     logger.info("flops:{} top1_err:{} top5_err:{}".format(
34 |         flops, top1_err, top5_err
35 |     ))
36 | 
37 | 
38 | @torch.no_grad()
39 | def validate(subnet, train_loader, valid_loader, test_meter):
40 |     # BN calibration
41 |     subnet.eval()
42 |     logger.info("Calibrating BN running statistics.")
43 | 
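    # Reset the BN running mean/var, then forward a limited number of training
    # batches (POST_BN_CALIBRATION_BATCH_NUM) so that the BatchNorm statistics
    # match the currently active subnet before it is evaluated.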
subnet.reset_running_stats_for_calibration() 44 | for cur_iter, (inputs, _) in enumerate(train_loader): 45 | if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM: 46 | break 47 | inputs = inputs.cuda() 48 | subnet(inputs) # forward only 49 | 50 | top1_err, top5_err = test_epoch(subnet, valid_loader, test_meter) 51 | return top1_err, top5_err 52 | 53 | 54 | def test_epoch(subnet, test_loader, test_meter): 55 | subnet.eval() 56 | test_meter.reset(True) 57 | test_meter.iter_tic() 58 | for cur_iter, (inputs, labels) in enumerate(test_loader): 59 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) 60 | preds = subnet(inputs) 61 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) 62 | top1_err, top5_err = top1_err.item(), top5_err.item() 63 | 64 | test_meter.iter_toc() 65 | test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) 66 | test_meter.log_iter_stats(0, cur_iter) 67 | test_meter.iter_tic() 68 | top1_err = test_meter.mb_top1_err.get_win_avg() 69 | top5_err = test_meter.mb_top5_err.get_win_avg() 70 | # self.writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_avg(), cur_epoch) 71 | # self.writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_avg(), cur_epoch) 72 | # Log epoch stats 73 | test_meter.log_epoch_stats(0) 74 | # test_meter.reset() 75 | return top1_err, top5_err 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /scripts/search/DARTS.py: -------------------------------------------------------------------------------- 1 | """DARTS searching""" 2 | 3 | import xnas.core.config as config 4 | import xnas.logger.logging as logging 5 | from xnas.core.config import cfg 6 | from xnas.core.builder import * 7 | 8 | # DARTS 9 | from xnas.algorithms.DARTS import * 10 | from xnas.runner.trainer import DartsTrainer 11 | from xnas.runner.optimizer import darts_alpha_optimizer 12 | 13 | 14 | # Load config and check 15 | config.load_configs() 16 | logger = logging.get_logger(__name__) 17 | 18 | def main(): 19 | device = setup_env() 20 | search_space = space_builder() 21 | criterion = criterion_builder().to(device) 22 | evaluator = evaluator_builder() 23 | 24 | [train_loader, valid_loader] = construct_loader() 25 | 26 | # init models 27 | darts_controller = DartsCNNController(search_space, criterion).to(device) 28 | architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY) 29 | 30 | # init optimizers 31 | w_optim = optimizer_builder("SGD", darts_controller.weights()) 32 | a_optim = darts_alpha_optimizer("Adam", darts_controller.alphas()) 33 | lr_scheduler = lr_scheduler_builder(w_optim) 34 | 35 | # init recorders 36 | darts_trainer = DartsTrainer( 37 | darts_controller=darts_controller, 38 | architect=architect, 39 | criterion=criterion, 40 | lr_scheduler=lr_scheduler, 41 | w_optim=w_optim, 42 | a_optim=a_optim, 43 | train_loader=train_loader, 44 | valid_loader=valid_loader, 45 | ) 46 | 47 | # load checkpoint or initial weights 48 | start_epoch = darts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0 49 | 50 | # start training 51 | darts_trainer.start() 52 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 53 | # train epoch 54 | darts_trainer.train_epoch(cur_epoch) 55 | # test epoch 56 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: 57 | darts_trainer.test_epoch(cur_epoch) 58 | # recording genotype and alpha to logger 59 | logger.info("=== Optimal genotype at 
epoch: {} ===".format(cur_epoch)) 60 | logger.info(darts_trainer.model.genotype()) 61 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch)) 62 | darts_trainer.model.print_alphas(logger) 63 | # evaluate model 64 | if evaluator: 65 | evaluator(darts_trainer.model.genotype()) 66 | darts_trainer.finish() 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /scripts/search/DropNAS.py: -------------------------------------------------------------------------------- 1 | """DropNAS searching""" 2 | 3 | from torch import device 4 | import torch.nn as nn 5 | 6 | import xnas.core.config as config 7 | import xnas.logger.logging as logging 8 | import xnas.logger.meter as meter 9 | from xnas.core.config import cfg 10 | from xnas.core.builder import * 11 | 12 | # DropNAS 13 | from xnas.algorithms.DropNAS import * 14 | from xnas.runner.trainer import DartsTrainer 15 | from xnas.runner.optimizer import darts_alpha_optimizer 16 | 17 | 18 | # Load config and check 19 | config.load_configs() 20 | logger = logging.get_logger(__name__) 21 | 22 | def main(): 23 | device = setup_env() 24 | search_space = space_builder() 25 | criterion = criterion_builder().to(device) 26 | evaluator = evaluator_builder() 27 | 28 | [train_loader, valid_loader] = construct_loader() 29 | 30 | # init models 31 | darts_controller = DropNAS_CNNController(search_space, criterion).to(device) 32 | 33 | # init optimizers 34 | w_optim = optimizer_builder("SGD", darts_controller.weights()) 35 | a_optim = darts_alpha_optimizer("Adam", darts_controller.alphas()) 36 | lr_scheduler = lr_scheduler_builder(w_optim) 37 | 38 | # init recorders 39 | dropnas_trainer = DropNAS_Trainer( 40 | darts_controller=darts_controller, 41 | architect=None, 42 | criterion=criterion, 43 | lr_scheduler=lr_scheduler, 44 | w_optim=w_optim, 45 | a_optim=a_optim, 46 | train_loader=train_loader, 47 | valid_loader=valid_loader, 48 | ) 49 | 50 | # load checkpoint or initial weights 51 | start_epoch = dropnas_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0 52 | 53 | # start training 54 | dropnas_trainer.start() 55 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 56 | # train epoch 57 | drop_rate = 0. if cur_epoch < cfg.OPTIM.WARMUP_EPOCH else cfg.DROPNAS.DROP_RATE 58 | logger.info("Current drop rate: {:.6f}".format(drop_rate)) 59 | dropnas_trainer.train_epoch(cur_epoch, drop_rate) 60 | # test epoch 61 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: 62 | # NOTE: the source code of DropNAS does not use test codes. 63 | # recording genotype and alpha to logger 64 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch)) 65 | logger.info(dropnas_trainer.model.genotype()) 66 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch)) 67 | dropnas_trainer.model.print_alphas(logger) 68 | if evaluator: 69 | evaluator(dropnas_trainer.model.genotype()) 70 | dropnas_trainer.finish() 71 | 72 | 73 | class DropNAS_Trainer(DartsTrainer): 74 | """Trainer for DropNAS. 75 | Rewrite the train_epoch with DropNAS's double-losses policy. 
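    The first backward pass optimizes the task loss and steps the weight
    optimizer (plus the alpha optimizer after warm-up); the second pass
    optimizes the explicit weight/alpha L2-decay losses separately.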
76 | """ 77 | def __init__(self, darts_controller, architect, criterion, w_optim, a_optim, lr_scheduler, train_loader, valid_loader): 78 | super().__init__(darts_controller, architect, criterion, w_optim, a_optim, lr_scheduler, train_loader, valid_loader) 79 | 80 | def train_epoch(self, cur_epoch, drop_rate): 81 | self.model.train() 82 | lr = self.lr_scheduler.get_last_lr()[0] 83 | cur_step = cur_epoch * len(self.train_loader) 84 | self.writer.add_scalar('train/lr', lr, cur_step) 85 | self.train_meter.iter_tic() 86 | for cur_iter, (trn_X, trn_y) in enumerate(self.train_loader): 87 | trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device, non_blocking=True) 88 | 89 | # forward pass loss 90 | self.a_optimizer.zero_grad() 91 | self.optimizer.zero_grad() 92 | preds = self.model(trn_X, drop_rate=drop_rate) 93 | loss1 = self.criterion(preds, trn_y) 94 | loss1.backward() 95 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP) 96 | self.optimizer.step() 97 | if cur_epoch >= cfg.OPTIM.WARMUP_EPOCH: 98 | self.a_optimizer.step() 99 | 100 | # weight decay loss 101 | self.a_optimizer.zero_grad() 102 | self.optimizer.zero_grad() 103 | loss2 = self.model.weight_decay_loss(cfg.OPTIM.WEIGHT_DECAY) \ 104 | + self.model.alpha_decay_loss(cfg.DARTS.ALPHA_WEIGHT_DECAY) 105 | loss2.backward() 106 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP) 107 | self.optimizer.step() 108 | self.a_optimizer.step() 109 | 110 | self.model.adjust_alphas() 111 | loss = loss1 + loss2 112 | 113 | # Compute the errors 114 | top1_err, top5_err = meter.topk_errors(preds, trn_y, [1, 5]) 115 | loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() 116 | self.train_meter.iter_toc() 117 | # Update and log stats 118 | self.train_meter.update_stats(top1_err, top5_err, loss, lr, trn_X.size(0)) 119 | self.train_meter.log_iter_stats(cur_epoch, cur_iter) 120 | self.train_meter.iter_tic() 121 | self.writer.add_scalar('train/loss', loss, cur_step) 122 | self.writer.add_scalar('train/top1_error', top1_err, cur_step) 123 | self.writer.add_scalar('train/top5_error', top5_err, cur_step) 124 | cur_step += 1 125 | # Log epoch stats 126 | self.train_meter.log_epoch_stats(cur_epoch) 127 | self.train_meter.reset() 128 | # saving model 129 | if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: 130 | self.saving(cur_epoch) 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /scripts/search/NBMacro.py: -------------------------------------------------------------------------------- 1 | """ 2 | NAS-Bench-Macro: only (8 layers * 3 choices) + CIFAR10 3 | """ 4 | 5 | import xnas.core.config as config 6 | import xnas.logger.logging as logging 7 | from xnas.core.config import cfg 8 | from xnas.core.builder import * 9 | from xnas.runner.trainer import OneShotTrainer 10 | from xnas.algorithms.SPOS import RAND, REA 11 | 12 | # Load config and check 13 | config.load_configs() 14 | logger = logging.get_logger(__name__) 15 | 16 | 17 | def main(): 18 | setup_env() 19 | criterion = criterion_builder().cuda() 20 | [train_loader, valid_loader] = construct_loader() 21 | model = space_builder().cuda() 22 | optimizer = optimizer_builder("SGD", model.parameters()) 23 | lr_scheduler = lr_scheduler_builder(optimizer) 24 | 25 | # init sampler 26 | train_sampler = RAND(3, 8) 27 | evaluate_sampler = REA(3, 8) 28 | 29 | # init recorders 30 | nbm_trainer = OneShotTrainer( 31 | supernet=model, 32 | criterion=criterion, 33 | optimizer=optimizer, 
34 |         lr_scheduler=lr_scheduler,
35 |         train_loader=train_loader,
36 |         test_loader=valid_loader,
37 |         sample_type='iter'
38 |     )
39 |     nbm_trainer.register_iter_sample(train_sampler)
40 | 
41 |     # load checkpoint or initial weights
42 |     start_epoch = nbm_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
43 | 
44 |     # start training
45 |     nbm_trainer.start()
46 |     for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
47 |         # train epoch
48 |         top1_err = nbm_trainer.train_epoch(cur_epoch)
49 |         # test epoch
50 |         if (cur_epoch + 1) % cfg.EVAL_PERIOD == 0 or (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH:
51 |             top1_err = nbm_trainer.test_epoch(cur_epoch)
52 |     nbm_trainer.finish()
53 | 
54 |     # # sample best architecture from supernet
55 |     # for cycle in range(200):  # NOTE: this should be a hyperparameter
56 |     #     sample = evaluate_sampler.suggest()
57 |     #     top1_err = nbm_trainer.evaluate_epoch(sample)
58 |     #     evaluate_sampler.record(sample, top1_err)
59 |     # best_arch, best_top1err = evaluate_sampler.final_best()
60 |     # logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err))
61 | 
62 |     # from xnas.evaluations.NASBenchMacro.evaluate import evaluate
63 |     # # for example : arch = '00000000'
64 |     # arch = ''
65 |     # Nbm_Eva(arch)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 | 
-------------------------------------------------------------------------------- /scripts/search/OFA/search.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/scripts/search/OFA/search.py -------------------------------------------------------------------------------- /scripts/search/OFA/train_teacher_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/scripts/search/OFA/train_teacher_model.py -------------------------------------------------------------------------------- /scripts/search/PCDARTS.py: --------------------------------------------------------------------------------
1 | """PCDARTS searching"""
2 | 
3 | import xnas.core.config as config
4 | import xnas.logger.logging as logging
5 | from xnas.core.config import cfg
6 | from xnas.core.builder import *
7 | 
8 | # DARTS
9 | from xnas.algorithms.PCDARTS import *
10 | from xnas.runner.trainer import DartsTrainer
11 | from xnas.runner.optimizer import darts_alpha_optimizer
12 | 
13 | 
14 | # Load config and check
15 | config.load_configs()
16 | logger = logging.get_logger(__name__)
17 | 
18 | def main():
19 |     device = setup_env()
20 |     search_space = space_builder()
21 |     criterion = criterion_builder().to(device)
22 |     evaluator = evaluator_builder()
23 | 
24 |     [train_loader, valid_loader] = construct_loader()
25 | 
26 |     # init models
27 |     pcdarts_controller = PCDartsCNNController(search_space, criterion).to(device)
28 |     architect = Architect(pcdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
29 | 
30 |     # init optimizers
31 |     w_optim = optimizer_builder("SGD", pcdarts_controller.weights())
32 |     a_optim = darts_alpha_optimizer("Adam", pcdarts_controller.alphas())
33 |     lr_scheduler = lr_scheduler_builder(w_optim)
34 | 
35 |     # init recorders
36 |     pcdarts_trainer = DartsTrainer(
37 |         darts_controller=pcdarts_controller,
38 |         architect=architect,
39 |         criterion=criterion,
40 |         lr_scheduler=lr_scheduler,
41 |         w_optim=w_optim,
42 |         a_optim=a_optim,
43 |         train_loader=train_loader,
44 |         valid_loader=valid_loader,
45 | 
) 46 | 47 | # Load checkpoint or initial weights 48 | start_epoch = pcdarts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0 49 | 50 | # start training 51 | pcdarts_trainer.start() 52 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 53 | # train epoch 54 | pcdarts_trainer.train_epoch(cur_epoch, cur_epoch>=15) 55 | # test epoch 56 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: 57 | pcdarts_trainer.test_epoch(cur_epoch) 58 | # recording genotype and alpha to logger 59 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch)) 60 | logger.info(pcdarts_trainer.model.genotype()) 61 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch)) 62 | pcdarts_trainer.model.print_alphas(logger) 63 | # evaluate model 64 | if evaluator: 65 | evaluator(pcdarts_trainer.model.genotype()) 66 | pcdarts_trainer.finish() 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /scripts/search/PDARTS.py: -------------------------------------------------------------------------------- 1 | """PDARTS searching""" 2 | 3 | from cmath import phase 4 | import xnas.core.config as config 5 | import xnas.logger.logging as logging 6 | from xnas.core.config import cfg 7 | from xnas.core.builder import * 8 | 9 | # PDARTS 10 | from xnas.algorithms.PDARTS import * 11 | from xnas.runner.trainer import DartsTrainer 12 | from xnas.runner.optimizer import darts_alpha_optimizer 13 | 14 | 15 | # Load config and check 16 | config.load_configs() 17 | logger = logging.get_logger(__name__) 18 | 19 | def main(): 20 | device = setup_env() 21 | criterion = criterion_builder().to(device) 22 | evaluator = evaluator_builder() 23 | 24 | [train_loader, valid_loader] = construct_loader() 25 | 26 | num_edges = (cfg.SPACE.NODES+3)*cfg.SPACE.NODES//2 27 | basic_op = [] 28 | for i in range(num_edges * 2): 29 | basic_op.append(cfg.SPACE.BASIC_OP) 30 | 31 | for sp in range(len(cfg.PDARTS.NUM_TO_KEEP)): 32 | search_space = space_builder( 33 | add_layers=cfg.PDARTS.ADD_LAYERS[sp], 34 | add_width=cfg.PDARTS.ADD_WIDTH[sp], 35 | dropout_rate=float(cfg.PDARTS.DROPOUT_RATE[sp]), 36 | basic_op=basic_op, 37 | ) 38 | 39 | # init models 40 | pdarts_controller = PDartsCNNController(search_space, criterion).to(device) 41 | architect = Architect(pdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY) 42 | 43 | # init optimizers 44 | w_optim = optimizer_builder("SGD", pdarts_controller.subnet_weights()) 45 | a_optim = darts_alpha_optimizer("Adam", pdarts_controller.alphas()) 46 | lr_scheduler = lr_scheduler_builder(w_optim) 47 | 48 | # init recorders 49 | pdarts_trainer = DartsTrainer( 50 | darts_controller=pdarts_controller, 51 | architect=architect, 52 | criterion=criterion, 53 | lr_scheduler=lr_scheduler, 54 | w_optim=w_optim, 55 | a_optim=a_optim, 56 | train_loader=train_loader, 57 | valid_loader=valid_loader, 58 | ) 59 | 60 | # Load checkpoint or initial weights 61 | start_epoch = pdarts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0 62 | 63 | # start training 64 | pdarts_trainer.start() 65 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 66 | # train epoch 67 | if cur_epoch < cfg.PDARTS.EPS_NO_ARCHS[sp]: 68 | pdarts_trainer.model.update_p( 69 | float(cfg.PDARTS.DROPOUT_RATE[sp]) * 70 | (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) / 71 | cfg.OPTIM.MAX_EPOCH 72 | ) 73 | pdarts_trainer.train_epoch(cur_epoch, alpha_step=False) 74 | else: 75 | pdarts_trainer.model.update_p( 76 | 
float(cfg.PDARTS.DROPOUT_RATE[sp]) * 77 | np.exp(-(cur_epoch - cfg.PDARTS.EPS_NO_ARCHS[sp]) * cfg.PDARTS.SCALE_FACTOR) 78 | ) 79 | pdarts_trainer.train_epoch(cur_epoch, alpha_step=True) 80 | 81 | # test epoch 82 | if (cur_epoch+1) >= cfg.OPTIM.MAX_EPOCH - 5: 83 | pdarts_trainer.test_epoch(cur_epoch) 84 | # recording genotype and alpha to logger 85 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch)) 86 | logger.info(pdarts_trainer.model.genotype()) 87 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch)) 88 | pdarts_trainer.model.print_alphas(logger) 89 | # evaluate model 90 | if evaluator: 91 | evaluator(pdarts_trainer.model.genotype()) 92 | pdarts_trainer.finish() 93 | logger.info("Top-{} primitive: {}".format( 94 | cfg.PDARTS.NUM_TO_KEEP[sp], 95 | pdarts_trainer.model.get_topk_op(cfg.PDARTS.NUM_TO_KEEP[sp])) 96 | ) 97 | if sp == len(cfg.PDARTS.NUM_TO_KEEP) - 1: 98 | phase_ending(pdarts_trainer) 99 | else: 100 | basic_op = pdarts_trainer.model.get_topk_op(cfg.PDARTS.NUM_TO_KEEP[sp]) 101 | phase_ending(pdarts_trainer) 102 | 103 | 104 | def phase_ending(pdarts_trainer, final=False): 105 | phase = "Final" if final else "Stage" 106 | logger.info("=== {} optimal genotype ===".format(phase)) 107 | logger.info(pdarts_trainer.model.genotype(final=True)) 108 | logger.info("=== {} alphas ===".format(phase)) 109 | pdarts_trainer.model.print_alphas(logger) 110 | # restrict skip connect 111 | logger.info('Restricting skip-connect') 112 | for sks in range(0, 9): 113 | max_sk = 8-sks 114 | num_sk = pdarts_trainer.model.get_skip_number() 115 | if not num_sk > max_sk: 116 | continue 117 | while num_sk > max_sk: 118 | pdarts_trainer.model.delete_skip() 119 | num_sk = pdarts_trainer.model.get_skip_number() 120 | 121 | logger.info('Number of skip-connect: %d', max_sk) 122 | logger.info(pdarts_trainer.model.genotype(final=True)) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /scripts/search/SPOS.py: -------------------------------------------------------------------------------- 1 | """Single Path One-Shot""" 2 | 3 | import xnas.core.config as config 4 | import xnas.logger.logging as logging 5 | from xnas.core.config import cfg 6 | from xnas.core.builder import * 7 | 8 | # SPOS 9 | from xnas.algorithms.SPOS import RAND, REA 10 | from xnas.runner.trainer import OneShotTrainer 11 | 12 | 13 | # Load config and check 14 | config.load_configs() 15 | logger = logging.get_logger(__name__) 16 | 17 | def main(): 18 | device = setup_env() 19 | criterion = criterion_builder().to(device) 20 | [train_loader, valid_loader] = construct_loader() 21 | model = space_builder().cuda() #to(device) 22 | optimizer = optimizer_builder("SGD", model.parameters()) 23 | lr_scheduler = lr_scheduler_builder(optimizer) 24 | 25 | # init sampler 26 | train_sampler = RAND(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS) 27 | evaluate_sampler = REA(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS) 28 | 29 | # init recorders 30 | spos_trainer = OneShotTrainer( 31 | supernet=model, 32 | criterion=criterion, 33 | optimizer=optimizer, 34 | lr_scheduler=lr_scheduler, 35 | train_loader=train_loader, 36 | test_loader=valid_loader, 37 | sample_type='iter' 38 | ) 39 | spos_trainer.register_iter_sample(train_sampler) 40 | 41 | # load checkpoint or initial weights 42 | start_epoch = spos_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 43 | 44 | # start training 45 | spos_trainer.start() 46 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 
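        # Each training iteration samples a single random path from the
        # supernet via the registered RAND sampler (single-path one-shot).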
47 | # train epoch 48 | top1_err = spos_trainer.train_epoch(cur_epoch) 49 | # test epoch 50 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: 51 | top1_err = spos_trainer.test_epoch(cur_epoch) 52 | spos_trainer.finish() 53 | 54 | # sample best architecture from supernet 55 | for cycle in range(200): # NOTE: this should be a hyperparameter 56 | sample = evaluate_sampler.suggest() 57 | top1_err = spos_trainer.evaluate_epoch(sample) 58 | evaluate_sampler.record(sample, top1_err) 59 | best_arch, best_top1err = evaluate_sampler.final_best() 60 | logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err)) 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /scripts/train/DARTS.py: -------------------------------------------------------------------------------- 1 | """DARTS retraining""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | import xnas.core.config as config 6 | import xnas.logger.logging as logging 7 | import xnas.logger.meter as meter 8 | from xnas.core.config import cfg 9 | from xnas.core.builder import * 10 | from xnas.datasets.loader import get_normal_dataloader 11 | 12 | from xnas.runner.trainer import Trainer 13 | 14 | # Load config and check 15 | config.load_configs() 16 | logger = logging.get_logger(__name__) 17 | 18 | def main(): 19 | device = setup_env() 20 | 21 | model = space_builder().to(device) 22 | criterion = criterion_builder().to(device) 23 | # evaluator = evaluator_builder() 24 | 25 | [train_loader, valid_loader] = get_normal_dataloader() 26 | optimizer = optimizer_builder("SGD", model.parameters()) 27 | lr_scheduler = lr_scheduler_builder(optimizer) 28 | 29 | darts_retrainer = Darts_Retrainer( 30 | model, criterion, optimizer, lr_scheduler, train_loader, valid_loader 31 | ) 32 | 33 | start_epoch = darts_retrainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 34 | 35 | darts_retrainer.start() 36 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 37 | darts_retrainer.model.drop_path_prob = cfg.TRAIN.DROP_PATH_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH 38 | darts_retrainer.train_epoch(cur_epoch) 39 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: 40 | darts_retrainer.test_epoch(cur_epoch) 41 | darts_retrainer.finish() 42 | 43 | 44 | # overwrite training & validating with auxiliary 45 | class Darts_Retrainer(Trainer): 46 | def __init__(self, model, criterion, optimizer, lr_scheduler, train_loader, test_loader): 47 | super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader) 48 | 49 | def train_epoch(self, cur_epoch): 50 | self.model.train() 51 | lr = self.lr_scheduler.get_last_lr()[0] 52 | cur_step = cur_epoch * len(self.train_loader) 53 | self.writer.add_scalar('train/lr', lr, cur_step) 54 | self.train_meter.iter_tic() 55 | for cur_iter, (inputs, labels) in enumerate(self.train_loader): 56 | inputs, labels = inputs.to(self.device), labels.to(self.device, non_blocking=True) 57 | preds, preds_aux = self.model(inputs) 58 | loss = self.criterion(preds, labels) 59 | self.optimizer.zero_grad() 60 | if cfg.TRAIN.AUX_WEIGHT > 0.: 61 | loss += cfg.TRAIN.AUX_WEIGHT * self.criterion(preds_aux, labels) 62 | loss.backward() 63 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP) 64 | self.optimizer.step() 65 | 66 | # Compute the errors 67 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) 68 | loss, top1_err, top5_err = loss.item(), top1_err.item(), 
top5_err.item() 69 | self.train_meter.iter_toc() 70 | # Update and log stats 71 | self.train_meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0) * cfg.NUM_GPUS) 72 | self.train_meter.log_iter_stats(cur_epoch, cur_iter) 73 | self.train_meter.iter_tic() 74 | self.writer.add_scalar('train/loss', loss, cur_step) 75 | self.writer.add_scalar('train/top1_error', top1_err, cur_step) 76 | self.writer.add_scalar('train/top5_error', top5_err, cur_step) 77 | cur_step += 1 78 | # Log epoch stats 79 | self.train_meter.log_epoch_stats(cur_epoch) 80 | self.train_meter.reset() 81 | self.lr_scheduler.step() 82 | # Saving checkpoint 83 | if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: 84 | self.saving(cur_epoch) 85 | 86 | @torch.no_grad() 87 | def test_epoch(self, cur_epoch): 88 | self.model.eval() 89 | self.test_meter.iter_tic() 90 | for cur_iter, (inputs, labels) in enumerate(self.test_loader): 91 | inputs, labels = inputs.to(self.device), labels.to(self.device, non_blocking=True) 92 | preds, _ = self.model(inputs) 93 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) 94 | top1_err, top5_err = top1_err.item(), top5_err.item() 95 | 96 | self.test_meter.iter_toc() 97 | self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) 98 | self.test_meter.log_iter_stats(cur_epoch, cur_iter) 99 | self.test_meter.iter_tic() 100 | top1_err = self.test_meter.mb_top1_err.get_win_avg() 101 | self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) 102 | self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) 103 | # Log epoch stats 104 | self.test_meter.log_epoch_stats(cur_epoch) 105 | self.test_meter.reset() 106 | # Saving best model 107 | if self.best_err > top1_err: 108 | self.best_err = top1_err 109 | self.saving(cur_epoch, best=True) 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /tests/MultiSizeRandomCrop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision.datasets as dset 4 | import torchvision.transforms as transforms 5 | from xnas.datasets.transforms import MultiSizeRandomCrop 6 | 7 | msrc = MultiSizeRandomCrop([4,224]) 8 | # print(MultiSizeRandomCrop.CANDIDATE_SIZES) 9 | 10 | for i in range(5): 11 | MultiSizeRandomCrop.sample_image_size() 12 | print(MultiSizeRandomCrop.ACTIVE_SIZE) 13 | 14 | def my_collate(batch): 15 | msrc.sample_image_size() 16 | xs = torch.stack([i[0] for i in batch]) 17 | ys = torch.Tensor([i[1] for i in batch]) 18 | return [xs, ys] 19 | 20 | T = transforms.Compose([ 21 | msrc, 22 | transforms.ToTensor(), 23 | ]) 24 | 25 | _data = dset.CIFAR10( 26 | root='./data/cifar10', 27 | train=True, 28 | transform=T, 29 | ) 30 | 31 | loader = data.DataLoader( 32 | dataset=_data, 33 | batch_size=128, 34 | collate_fn=my_collate, 35 | ) 36 | 37 | for i, (trn_X, trn_y) in enumerate(loader): 38 | print(trn_X.shape, trn_y.shape) 39 | if i==5: 40 | break 41 | -------------------------------------------------------------------------------- /tests/adjust_lr_per_iter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class model1(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.conv = nn.Conv2d(3,3,3) 9 | 10 | def forward(self, x): 11 | x = self.conv(x) 12 | return x 13 | 14 | net = model1() 
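# Reference run: CosineAnnealingLR stepped once per epoch, to compare against
# the manual per-iteration warmup + cosine schedule in the second run below.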
15 | 16 | opr = torch.optim.SGD(net.parameters(), 0.1) 17 | lrs = torch.optim.lr_scheduler.CosineAnnealingLR(opr, 10) 18 | 19 | from xnas.runner.scheduler import adjust_learning_rate_per_batch 20 | 21 | for cur_epoch in range(10): 22 | print("epoch:{} lr:{}".format(cur_epoch, lrs.get_last_lr()[0])) 23 | opr.zero_grad() 24 | opr.step() 25 | lrs.step() 26 | 27 | print("*"*20) 28 | del opr, lrs 29 | 30 | opr = torch.optim.SGD(net.parameters(), 0.1) 31 | lrs = torch.optim.lr_scheduler.CosineAnnealingLR(opr, 10) 32 | 33 | for cur_epoch in range(10): 34 | opr.zero_grad() 35 | for cur_iter in range(5): 36 | new_lr = adjust_learning_rate_per_batch( 37 | init_lr=0.1, 38 | n_epochs=10, 39 | n_warmup_epochs=3, 40 | epoch=cur_epoch, 41 | n_iter=5, 42 | iter=cur_iter, 43 | warmup=(cur_epoch < 3), 44 | warmup_lr=0.1, # use base_lr as warmup_lr 45 | ) 46 | for param_group in opr.param_groups: 47 | param_group["lr"] = new_lr 48 | if cur_iter == 0: 49 | print("epoch:{} iter:{} lr:{}".format(cur_epoch, cur_iter, new_lr)) 50 | opr.step() 51 | -------------------------------------------------------------------------------- /tests/configs/datasets.yaml: -------------------------------------------------------------------------------- 1 | LOADER: 2 | DATASET: 'cifar10' 3 | SPLIT: [0.5, 0.5] 4 | BATCH_SIZE: 64 5 | NUM_CLASSES: 10 6 | SEARCH: 7 | IM_SIZE: 32 8 | -------------------------------------------------------------------------------- /tests/configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | LOADER: 2 | DATASET: 'imagenet' 3 | BATCH_SIZE: 64 4 | NUM_CLASSES: 1000 5 | USE_VAL: True 6 | NUM_WORKERS: 4 7 | NUM_GPUS: 2 8 | SEARCH: 9 | MULTI_SIZES: [160, 192, 224] 10 | -------------------------------------------------------------------------------- /tests/dataloader.py: -------------------------------------------------------------------------------- 1 | import xnas.core.config as config 2 | from xnas.datasets.loader import construct_loader 3 | from xnas.core.config import cfg 4 | 5 | config.load_configs() 6 | 7 | # cifar10 8 | [train_loader, valid_loader] = construct_loader() 9 | 10 | # cifar100 11 | cfg.LOADER.DATASET = 'cifar100' 12 | cfg.LOADER.NUM_CLASSES = 100 13 | [train_loader, valid_loader] = construct_loader() 14 | 15 | # imagenet16 16 | cfg.LOADER.DATASET = 'imagenet16' 17 | cfg.LOADER.NUM_CLASSES = 120 18 | [train_loader, valid_loader] = construct_loader() 19 | 20 | -------------------------------------------------------------------------------- /tests/imagenet.py: -------------------------------------------------------------------------------- 1 | import xnas.core.config as config 2 | from xnas.datasets.loader import construct_loader 3 | from xnas.core.config import cfg 4 | 5 | config.load_configs() 6 | 7 | [train_loader, valid_loader] = construct_loader() 8 | 9 | for i, (trn_X, trn_y) in enumerate(train_loader): 10 | print(trn_X.shape, trn_y.shape) 11 | if i==9: 12 | break 13 | 14 | # cfg.SEARCH.MULTI_SIZES = [] 15 | 16 | print("===") 17 | for i, (trn_X, trn_y) in enumerate(valid_loader): 18 | print(trn_X.shape, trn_y.shape) 19 | if i==9: 20 | break 21 | -------------------------------------------------------------------------------- /tests/nb101space.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb101space.py -------------------------------------------------------------------------------- 
/tests/nb1shot1space.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb1shot1space.py
-------------------------------------------------------------------------------- /tests/nb201eval.py: --------------------------------------------------------------------------------
1 | from xnas.evaluations.NASBench201 import evaluate, index_to_genotype, distill
2 | 
3 | arch = index_to_genotype(2333)
4 | result = evaluate(arch)
5 | (
6 |     cifar10_train,
7 |     cifar10_test,
8 |     cifar100_train,
9 |     cifar100_valid,
10 |     cifar100_test,
11 |     imagenet16_train,
12 |     imagenet16_valid,
13 |     imagenet16_test,
14 | ) = distill(result)
15 | 
16 | print("cifar10 train %f test %f" % (cifar10_train, cifar10_test))
17 | print("cifar100 train %f valid %f test %f" % (cifar100_train, cifar100_valid, cifar100_test))
18 | print("imagenet16 train %f valid %f test %f" % (imagenet16_train, imagenet16_valid, imagenet16_test))
19 | 
20 | 
21 | result = evaluate(arch, epoch=200)
22 | (
23 |     cifar10_train,
24 |     cifar10_test,
25 |     cifar100_train,
26 |     cifar100_valid,
27 |     cifar100_test,
28 |     imagenet16_train,
29 |     imagenet16_valid,
30 |     imagenet16_test,
31 | ) = distill(result)
32 | 
33 | print("cifar10 train %f test %f" % (cifar10_train, cifar10_test))
34 | print("cifar100 train %f valid %f test %f" % (cifar100_train, cifar100_valid, cifar100_test))
35 | print("imagenet16 train %f valid %f test %f" % (imagenet16_train, imagenet16_valid, imagenet16_test))
36 | 
-------------------------------------------------------------------------------- /tests/nb301eval.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb301eval.py
-------------------------------------------------------------------------------- /tests/ofa_matrices_test.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | def test_local():
5 |     root = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/'
6 |     filename_prefix = 'model_epoch_'
7 |     filename_postfix = '.pyth'
8 | 
9 |     for i in range(101, 110):
10 |         ckpt = torch.load(root+filename_prefix+'0'+str(i)+filename_postfix)
11 |         for k,v in ckpt['model_state'].items():
12 |             # print(k)
13 |             # if k.endswith('5to3_matrix'):
14 |             if k == 'blocks.6.conv.depth_conv.conv.5to3_matrix':
15 |                 print(k, v[0:3, 0:3])
16 |                 break
17 |         for k,v in ckpt['model_state'].items():
18 |             # print(k)
19 |             # if k.endswith('7to5_matrix'):
20 |             if k == 'blocks.6.conv.depth_conv.conv.7to5_matrix':
21 |                 print(k, v[0:3, 0:3])
22 |                 break
23 | 
24 | def test_original():
25 |     root = "/home/xfey/XNAS/tests/weights/"
26 |     ckpt = torch.load(root+"ofa_D4_E6_K357")
27 |     for k,v in ckpt['state_dict'].items():
28 |         # print(k)
29 |         if k.endswith('5to3_matrix') and k.startswith('blocks.6'):
30 |             print(k, v[0:3, 0:3])
31 |             break
32 |     for k,v in ckpt['state_dict'].items():
33 |         # print(k)
34 |         if k.endswith('7to5_matrix') and k.startswith('blocks.6'):
35 |             print(k, v[0:3, 0:3])
36 |             break
37 | 
38 | if __name__ == '__main__':
39 |     # test_local()
40 |     test_original()
41 | 
-------------------------------------------------------------------------------- /tests/test_REA.py: --------------------------------------------------------------------------------
1 | from xnas.algorithms.SPOS import REA
2 | 
3 | rea = REA(num_choice=4, child_len=10, population_size=5)
4 | 
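# Evolutionary-sampler smoke test: repeatedly ask REA for a child architecture,
# score it with a dummy fitness value (sum(child)), and record it back so the
# population evolves; input() pauses after each round for manual inspection.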
for i in range(30):
6 |     child = rea.suggest()
7 |     print("suggest: {}".format(child))
8 |     rea.record(child, value=sum(child))
9 |     print("===")
10 |     for p in rea.population:
11 |         print(p)
12 |     input()
13 | print(rea.final_best())
-------------------------------------------------------------------------------- /xnas/algorithms/AttentiveNAS/sampler.py: --------------------------------------------------------------------------------
1 | import random
2 | 
3 | def count_helper(v, flops, m):
4 |     if flops not in m:
5 |         m[flops] = {}
6 |     if v not in m[flops]:
7 |         m[flops][v] = 0
8 |     m[flops][v] += 1
9 | 
10 | 
11 | def round_flops(flops, step):
12 |     return int(round(flops / step) * step)
13 | 
14 | 
15 | def convert_count_to_prob(m):
16 |     if isinstance(m[list(m.keys())[0]], dict):
17 |         for k in m:
18 |             convert_count_to_prob(m[k])
19 |     else:
20 |         t = sum(m.values())
21 |         for k in m:
22 |             m[k] = 1.0 * m[k] / t
23 | 
24 | 
25 | def sample_helper(flops, m):
26 |     keys = list(m[flops].keys())
27 |     probs = list(m[flops].values())
28 |     return random.choices(keys, weights=probs)[0]
29 | 
30 | 
31 | def build_trasition_prob_matrix(file_handler, step):
32 |     # initialize
33 |     prob_map = {}
34 |     prob_map['discretize_step'] = step
35 |     for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
36 |         prob_map[k] = {}
37 | 
38 |     cc = 0
39 |     for line in file_handler:
40 |         vals = eval(line.strip())
41 | 
42 |         # discretize
43 |         flops = round_flops(vals['flops'], step)
44 |         prob_map['flops'][flops] = prob_map['flops'].get(flops, 0) + 1
45 | 
46 |         # resolution
47 |         r = vals['resolution']
48 |         count_helper(r, flops, prob_map['resolution'])
49 | 
50 |         for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
51 |             for idx, v in enumerate(vals[k]):
52 |                 if idx not in prob_map[k]:
53 |                     prob_map[k][idx] = {}
54 |                 count_helper(v, flops, prob_map[k][idx])
55 | 
56 |         cc += 1
57 | 
58 |     # convert count to probability
59 |     for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
60 |         convert_count_to_prob(prob_map[k])
61 |     prob_map['n_observations'] = cc
62 |     return prob_map
63 | 
64 | 
65 | 
66 | class ArchSampler():
67 |     def __init__(self, arch_to_flops_map_file_path, discretize_step, model, acc_predictor=None):
68 |         super(ArchSampler, self).__init__()
69 | 
70 |         with open(arch_to_flops_map_file_path, 'r') as fp:
71 |             self.prob_map = build_trasition_prob_matrix(fp, discretize_step)
72 | 
73 |         self.discretize_step = discretize_step
74 |         self.model = model
75 | 
76 |         self.acc_predictor = acc_predictor
77 | 
78 |         self.min_flops = min(list(self.prob_map['flops'].keys()))
79 |         self.max_flops = max(list(self.prob_map['flops'].keys()))
80 | 
81 |         self.curr_sample_pool = None  # TODO: architecture samples could be generated in an asynchronous way
82 | 
83 | 
84 |     def sample_one_target_flops(self, flops_uniform=False):
85 |         f_vals = list(self.prob_map['flops'].keys())
86 |         f_probs = list(self.prob_map['flops'].values())
87 | 
88 |         if flops_uniform:
89 |             return random.choice(f_vals)
90 |         else:
91 |             return random.choices(f_vals, weights=f_probs)[0]
92 | 
93 | 
94 |     def sample_archs_according_to_flops(self, target_flops, n_samples=1, max_trials=100, return_flops=True, return_trials=False):
95 |         archs = []
96 |         #for _ in range(n_samples):
97 |         while len(archs) < n_samples:
98 |             for _trial in range(max_trials+1):
99 |                 arch = {}
100 |                 arch['resolution'] = sample_helper(target_flops, self.prob_map['resolution'])
101 |                 for k in ['width', 'kernel_size', 'depth', 'expand_ratio']:
102 |                     arch[k] = []
103 |                     for idx in sorted(list(self.prob_map[k].keys())):
104 |                         arch[k].append(sample_helper(target_flops, self.prob_map[k][idx]))
105 |                 if self.model:
106 |                     self.model.set_active_subnet(**arch)
107 |                     flops = self.model.compute_active_subnet_flops()
108 |                     if return_flops:
109 |                         arch['flops'] = flops
110 |                     if round_flops(flops, self.discretize_step) == target_flops:
111 |                         break
112 |                 else:
113 |                     raise NotImplementedError
114 |             # accept the sample anyway
115 |             archs.append(arch)
116 |         return archs
117 | 
-------------------------------------------------------------------------------- /xnas/algorithms/DARTS.py: --------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | '''
8 | Darts: largely copied from https://github.com/khanrc/pt.darts
9 | '''
10 | 
11 | 
12 | class DartsCNNController(nn.Module):
13 |     """SearchCNN Controller"""
14 | 
15 |     def __init__(self, net, criterion, device_ids=None):
16 |         super().__init__()
17 |         if device_ids is None:
18 |             device_ids = list(range(torch.cuda.device_count()))
19 |         self.net = net
20 |         self.device_ids = device_ids
21 |         self.n_ops = self.net.num_ops
22 |         self.alpha = nn.Parameter(
23 |             1e-3*torch.randn(self.net.all_edges, self.n_ops))
24 |         self.criterion = criterion
25 | 
26 |         # Setup alphas list
27 |         self._alphas = []
28 |         for n, p in self.named_parameters():
29 |             if 'alpha' in n:
30 |                 self._alphas.append((n, p))
31 | 
32 |     def forward(self, x):
33 |         weights_ = F.softmax(self.alpha, dim=-1)
34 |         if len(self.device_ids) == 1:
35 |             return self.net(x, weights_)
36 |         else:
37 |             raise NotImplementedError
38 | 
39 |     def genotype(self):
40 |         """return genotype of DARTS CNN"""
41 |         return self.net.genotype(self.alpha.cpu().detach().numpy())
42 | 
43 |     def weights(self):
44 |         return self.net.parameters()
45 | 
46 |     def named_weights(self):
47 |         return self.net.named_parameters()
48 | 
49 |     def alphas(self):
50 |         for n, p in self._alphas:
51 |             yield p
52 | 
53 |     def named_alphas(self):
54 |         for n, p in self._alphas:
55 |             yield n, p
56 | 
57 |     def print_alphas(self, logger):
58 |         logger.info("####### ALPHA #######")
59 |         for alpha in self.alpha:
60 |             logger.info(F.softmax(alpha, dim=-1).cpu().detach().numpy())
61 |         logger.info("#####################")
62 | 
63 |     def loss(self, X, y):
64 |         logits = self.forward(X)
65 |         return self.criterion(logits, y)
66 | 
67 | 
68 | class Architect():
69 |     """ Compute gradients of alphas """
70 | 
71 |     def __init__(self, net, w_momentum, w_weight_decay):
72 |         """
73 |         Args:
74 |             net
75 |             w_momentum: weights momentum
76 |         """
77 |         self.net = net
78 |         self.v_net = copy.deepcopy(net)
79 |         self.w_momentum = w_momentum
80 |         self.w_weight_decay = w_weight_decay
81 | 
82 |     def virtual_step(self, trn_X, trn_y, xi, w_optim):
83 |         """
84 |         Compute unrolled weight w' (virtual step)
85 |         Step process:
86 |             1) forward
87 |             2) calc loss
88 |             3) compute gradient (by backprop)
89 |             4) update gradient
90 |         Args:
91 |             xi: learning rate for virtual gradient step (same as weights lr)
92 |             w_optim: weights optimizer
93 |         """
94 |         # forward & calc loss
95 |         loss = self.net.loss(trn_X, trn_y)  # L_trn(w)
96 | 
97 |         # compute gradient
98 |         gradients = torch.autograd.grad(loss, self.net.weights())
99 | 
100 |         # virtual step (update gradient)
101 |         # operations below do not need gradient tracking
102 |         with torch.no_grad():
103 |             # dict key is not the value, but the pointer,
104 |             # so the original network weights have to be iterated over as well.
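            # The update below simulates one SGD-with-momentum step:
            # vw <- w - xi * (momentum * m + dL/dw + weight_decay * w)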
105 |             for w, vw, g in zip(self.net.weights(), self.v_net.weights(), gradients):
106 |                 m = w_optim.state[w].get(
107 |                     'momentum_buffer', 0.) * self.w_momentum
108 |                 vw.copy_(w - xi * (m + g + self.w_weight_decay*w))
109 | 
110 |             # synchronize alphas
111 |             for a, va in zip(self.net.alphas(), self.v_net.alphas()):
112 |                 va.copy_(a)
113 | 
114 |     def unrolled_backward(self, trn_X, trn_y, val_X, val_y, xi, w_optim, unrolled=True):
115 |         """ Compute unrolled loss and backward its gradients
116 |         Args:
117 |             xi: learning rate for virtual gradient step (same as net lr)
118 |             w_optim: weights optimizer - for virtual step
119 |         """
120 |         # do virtual step (calc w`)
121 |         if unrolled:
122 |             self.virtual_step(trn_X, trn_y, xi, w_optim)
123 | 
124 |             # calc unrolled loss
125 |             loss = self.v_net.loss(val_X, val_y)  # L_val(w`)
126 | 
127 |             # compute gradient
128 |             v_alphas = tuple(self.v_net.alphas())
129 |             v_weights = tuple(self.v_net.weights())
130 |             v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
131 |             dalpha = v_grads[:len(v_alphas)]
132 |             dw = v_grads[len(v_alphas):]
133 | 
134 |             hessian = self.compute_hessian(dw, trn_X, trn_y)
135 | 
136 |             # update final gradient = dalpha - xi*hessian
137 |             with torch.no_grad():
138 |                 for alpha, da, h in zip(self.net.alphas(), dalpha, hessian):
139 |                     alpha.grad = da - xi * h
140 | 
141 |         else:
142 |             loss = self.net.loss(val_X, val_y)  # L_val(w)
143 |             loss.backward()
144 | 
145 | 
146 |     def compute_hessian(self, dw, trn_X, trn_y):
147 |         """
148 |         dw = dw` { L_val(w`, alpha) }
149 |         w+ = w + eps * dw
150 |         w- = w - eps * dw
151 |         hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
152 |         eps = 0.01 / ||dw||
153 |         """
154 | 
155 |         norm = torch.cat([w.view(-1) for w in dw]).norm()
156 |         eps = 0.01 / norm
157 | 
158 |         # w+ = w + eps*dw`
159 |         with torch.no_grad():
160 |             for p, d in zip(self.net.weights(), dw):
161 |                 p += eps * d
162 |         loss = self.net.loss(trn_X, trn_y)
163 |         dalpha_pos = torch.autograd.grad(
164 |             loss, self.net.alphas())  # dalpha { L_trn(w+) }
165 | 
166 |         # w- = w - eps*dw`
167 |         with torch.no_grad():
168 |             for p, d in zip(self.net.weights(), dw):
169 |                 p -= 2. * eps * d
170 |         loss = self.net.loss(trn_X, trn_y)
171 |         dalpha_neg = torch.autograd.grad(
172 |             loss, self.net.alphas())  # dalpha { L_trn(w-) }
173 | 
174 |         # recover w
175 |         with torch.no_grad():
176 |             for p, d in zip(self.net.weights(), dw):
177 |                 p += eps * d
178 | 
179 |         hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
180 |         return hessian
181 | 
-------------------------------------------------------------------------------- /xnas/algorithms/DrNAS.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | 
5 | 
6 | def _concat(xs):
7 |     return torch.cat([x.view(-1) for x in xs])
8 | 
9 | 
10 | class Architect(object):
11 |     def __init__(self, net, cfg):
12 |         self.network_momentum = cfg.OPTIM.MOMENTUM
13 |         self.network_weight_decay = cfg.OPTIM.WEIGHT_DECAY
14 |         self.net = net
15 |         if cfg.DRNAS.REG_TYPE == "l2":
16 |             weight_decay = cfg.DRNAS.REG_SCALE
17 |         elif cfg.DRNAS.REG_TYPE == "kl":
18 |             weight_decay = 0
19 |         self.optimizer = torch.optim.Adam(
20 |             self.net.arch_parameters(),
21 |             lr=cfg.DARTS.ALPHA_LR,
22 |             betas=(0.5, 0.999),
23 |             weight_decay=weight_decay,
24 |         )
25 |         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | 
27 |     def _compute_unrolled_model(self, input, target, eta, network_optimizer):
28 |         loss = self.net._loss(input, target)
29 |         theta = _concat(self.net.parameters()).data
30 |         try:
31 |             moment = _concat(
32 |                 network_optimizer.state[v]["momentum_buffer"]
33 |                 for v in self.net.parameters()
34 |             ).mul_(self.network_momentum)
35 |         except:
36 |             moment = torch.zeros_like(theta)
37 |         dtheta = (
38 |             _concat(torch.autograd.grad(loss, self.net.parameters())).data
39 |             + self.network_weight_decay * theta
40 |         )
41 |         unrolled_model = self._construct_model_from_theta(
42 |             theta.sub(eta, moment + dtheta)
43 |         ).to(self.device)
44 |         return unrolled_model
45 | 
46 |     def unrolled_backward(
47 |         self,
48 |         input_train,
49 |         target_train,
50 |         input_valid,
51 |         target_valid,
52 |         eta,
53 |         network_optimizer,
54 |         unrolled,
55 |     ):
56 |         self.optimizer.zero_grad()
57 |         if unrolled:
58 |             self._backward_step_unrolled(
59 |                 input_train,
60 |                 target_train,
61 |                 input_valid,
62 |                 target_valid,
63 |                 eta,
64 |                 network_optimizer,
65 |             )
66 |         else:
67 |             self._backward_step(input_valid, target_valid)
68 |         self.optimizer.step()
69 | 
70 |     # def pruning(self, masks):
71 |     #     for i, p in enumerate(self.optimizer.param_groups[0]['params']):
72 |     #         if masks[i] is None:
73 |     #             continue
74 |     #         state = self.optimizer.state[p]
75 |     #         mask = masks[i]
76 |     #         state['exp_avg'][~mask] = 0.0
77 |     #         state['exp_avg_sq'][~mask] = 0.0
78 | 
79 |     def _backward_step(self, input_valid, target_valid):
80 |         loss = self.net._loss(input_valid, target_valid)
81 |         loss.backward()
82 | 
83 |     def _backward_step_unrolled(
84 |         self,
85 |         input_train,
86 |         target_train,
87 |         input_valid,
88 |         target_valid,
89 |         eta,
90 |         network_optimizer,
91 |     ):
92 |         unrolled_model = self._compute_unrolled_model(
93 |             input_train, target_train, eta, network_optimizer
94 |         )
95 |         unrolled_loss = unrolled_model._loss(input_valid, target_valid)
96 | 
97 |         unrolled_loss.backward()
98 |         dalpha = [v.grad for v in unrolled_model.arch_parameters()]
99 |         vector = [v.grad.data for v in unrolled_model.parameters()]
100 |         implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
101 | 
102 |         for g, ig in zip(dalpha, implicit_grads):
103 |             g.data.sub_(eta, ig.data)
104 | 
105 |         for v, g in zip(self.net.arch_parameters(), dalpha):
in zip(self.net.arch_parameters(), dalpha): 106 | if v.grad is None: 107 | v.grad = Variable(g.data) 108 | else: 109 | v.grad.data.copy_(g.data) 110 | 111 | def _construct_model_from_theta(self, theta): 112 | model_new = self.net.new() 113 | model_dict = self.net.state_dict() 114 | 115 | params, offset = {}, 0 116 | for k, v in self.net.named_parameters(): 117 | v_length = np.prod(v.size()) 118 | params[k] = theta[offset : offset + v_length].view(v.size()) 119 | offset += v_length 120 | 121 | assert offset == len(theta) 122 | model_dict.update(params) 123 | model_new.load_state_dict(model_dict) 124 | return model_new 125 | 126 | def _hessian_vector_product(self, vector, input, target, r=1e-2): 127 | R = r / _concat(vector).norm() 128 | for p, v in zip(self.net.parameters(), vector): 129 | p.data.add_(v, alpha=R) # w+ = w + R * v 130 | loss = self.net._loss(input, target) 131 | grads_p = torch.autograd.grad(loss, self.net.arch_parameters()) 132 | 133 | for p, v in zip(self.net.parameters(), vector): 134 | p.data.sub_(v, alpha=2 * R) # w- = w - R * v 135 | loss = self.net._loss(input, target) 136 | grads_n = torch.autograd.grad(loss, self.net.arch_parameters()) 137 | 138 | for p, v in zip(self.net.parameters(), vector): 139 | p.data.add_(v, alpha=R) # recover w 140 | 141 | return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] 142 |
-------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | Code for the paper: **Neural Architecture Search with Representation Mutual Information** 4 | 5 | RMI-NAS is an efficient architecture search method based on Representation Mutual Information (RMI) theory. It speeds up performance evaluation by ranking architectures with RMI, an accurate and effective indicator for NAS. RMI-NAS uses only one batch of data to complete training and generalizes well to different search spaces. For more details, please refer to our paper. 6 | 7 | ## Results 8 | 9 | ### Results on NAS-Bench-201 10 | 11 | | Method | Search Cost<br>(seconds) | CIFAR-10<br>Test Acc.(%) | CIFAR-100<br>Test Acc.(%) | ImageNet16-120<br>Test Acc.(%) | 12 | | ----------- | ------------------------ | ------------------------ | ------------------------- | ------------------------------ | 13 | | RL | 27870.7 | 93.85±0.37 | 71.71±1.09 | 45.24±1.18 | 14 | | DARTS-V2 | 35781.8 | 54.30±0.00 | 15.61±0.00 | 16.32±0.00 | 15 | | GDAS | 31609.8 | 93.61±0.09 | 70.70±0.30 | 41.71±0.98 | 16 | | FairNAS | 9845.0 | 93.23±0.18 | 71.00±1.46 | 42.19±0.31 | 17 | | **RMI-NAS** | **1258.2** | **94.28±0.10** | **73.36±0.19** | **46.34±0.00** | 18 | 19 | Our method shows significant improvements in both efficiency and accuracy. 20 | 21 | ### Results on DARTS 22 | 23 | | Method | Search Cost<br>(GPU days) | CIFAR-10<br>Test Error (%)<br>(paper) | CIFAR-10<br>Test Error (%)<br>(retrain) | 24 | | ----------- | ------------------------- | ------------------------------------- | --------------------------------------- | 25 | | AmoebaNet-B | 3150 | 2.55±0.05 | - | 26 | | NASNet-A | 1800 | 2.65 | - | 27 | | DARTS (1st) | 0.4 | 3.00±0.14 | 2.75 | 28 | | DARTS (2nd) | 1 | 2.76±0.09 | 2.60 | 29 | | SNAS | 1.5 | 2.85±0.02 | 2.68 | 30 | | PC-DARTS | 1 | 2.57±0.07 | 2.71±0.11 | 31 | | FairDARTS-D | 0.4 | 2.54±0.05 | 2.71 | 32 | | **RMI-NAS** | **0.08** | - | 2.64±0.04 | 33 | 34 | Comparison with other methods in the DARTS search space. We also report retrained results under exactly the same settings to ensure a fair comparison. Our method delivers comparable accuracy with a substantial improvement in time consumption. 35 | 36 | 37 | 38 | ## Usage 39 | 40 | #### Install RMI-NAS 41 | 42 | Our code relies on functions from the XNAS repository, which must be installed first. 43 | 44 | ```bash 45 | # install XNAS 46 | git clone https://github.com/MAC-AutoML/XNAS.git 47 | export PYTHONPATH=$PYTHONPATH:/PATH/to/XNAS 48 | 49 | # prepare environment for RMI-NAS (conda) 50 | conda env create --file environment.yaml 51 | 52 | # download weight files for teacher models 53 | chmod +x xnas/algorithms/RMINAS/download_weight.sh 54 | bash xnas/algorithms/RMINAS/download_weight.sh 55 | ``` 56 | 57 | The file [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) is also required, since we use a previous version of `NAS-Bench-201`. Download it and put it into the `utils` directory. 58 | 59 | #### Search 60 | 61 | ```bash 62 | # NAS-Bench-201 + CIFAR-10 63 | python search/RMINAS/RMINAS_nb201.py --cfg configs/search/RMINAS/nb201_cifar10.yaml 64 | 65 | # DARTS + CIFAR-100 + specific exp path 66 | python search/RMINAS/RMINAS_darts.py --cfg configs/search/RMINAS/darts_cifar100.yaml OUT_DIR experiments/ 67 | ``` 68 | 69 | 70 | ## Related work 71 | 72 | [NAS-Bench-201](https://github.com/D-X-Y/NAS-Bench-201) 73 | 74 | [XNAS](https://github.com/MAC-AutoML/XNAS)
-------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/download_weight.sh: -------------------------------------------------------------------------------- 1 | echo `cd xnas/algorithms/RMINAS/teacher_model/resnet20_cifar10 && wget http://cdn.thrase.cn/rmi/resnet20.th` 2 | echo `cd xnas/algorithms/RMINAS/teacher_model/nb201model_imagenet16120 && wget http://cdn.thrase.cn/rmi/009930-FULL.pth` 3 | echo `cd xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet && wget http://cdn.thrase.cn/rmi/fbresnet152.pth` 4 | echo `cd xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100 && wget http://cdn.thrase.cn/rmi/resnet101.pth` 5 | echo "Finish downloading weight files."
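As the README above explains, RMI-NAS ranks candidate architectures by the representation mutual information between their features and those of a trained teacher; the repository computes this score in `xnas/algorithms/RMINAS/utils/RMI_torch.py` (not shown in this excerpt). As a rough illustration only — assuming same-shaped feature maps and a Gaussian mutual-information estimate, which is not necessarily the paper's exact formulation — such a score could look like:

```python
import torch

def gaussian_mi_score(feat_s: torch.Tensor, feat_t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Toy MI proxy between same-shaped student/teacher feature maps of shape (N, C, H, W)."""
    s, t = feat_s.flatten(1), feat_t.flatten(1)       # treat the N batch entries as samples
    s = (s - s.mean(0)) / (s.std(0) + eps)            # standardize every unit over the batch
    t = (t - t.mean(0)) / (t.std(0) + eps)
    rho = (s * t).mean(0).clamp(-0.999, 0.999)        # per-unit Pearson correlation
    return (-0.5 * torch.log(1.0 - rho ** 2)).mean()  # Gaussian MI: -1/2 * log(1 - rho^2)
```

A candidate whose one-batch features share more information with the teacher's features would rank higher; `RMI_torch.py` remains the authoritative implementation.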
6 | -------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/sampler/sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # true_list = [] 5 | # with open('xnas/search_algorithm/RMINAS/sampler/available_archs.txt', 'r') as f: 6 | # true_list = eval(f.readline()) 7 | 8 | # def random_sampling(times): 9 | # sample_list = [] 10 | # if times > sum(true_list): 11 | # print('can only sample {} times.'.format(sum(true_list))) 12 | # times = sum(true_list) 13 | # for _ in range(times): 14 | # i = random.randint(0, 15624) 15 | # while (not true_list[i]) or (i in sample_list): 16 | # i = random.randint(0, 15624) 17 | # sample_list.append(i) 18 | # return sample_list 19 | 20 | def darts_sug2alpha(suggest_sample): 21 | b = np.c_[suggest_sample, np.zeros(14)] 22 | return torch.from_numpy(np.r_[b,b]) 23 | 24 | def nb201genostr2array(geno_str): 25 | # |none~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|skip_connect~1|nor_conv_3x3~2| 26 | OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] 27 | _tmp = geno_str.split('|') 28 | _tmp2 = [] 29 | for i in range(len(_tmp)): 30 | if i in [1,3,4,6,7,8]: 31 | _tmp2.append(_tmp[i][:-2]) 32 | _tmp_np = np.array([0]*6) 33 | for i in range(6): 34 | _tmp_np[i] = OPS.index(_tmp2[i]) 35 | _tmp_oh = np.zeros((_tmp_np.size, 5)) 36 | _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1 37 | return _tmp_oh 38 | 39 | # def array2genostr(arr): 40 | # OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] 41 | # """[[1. 0. 0. 0. 0.] 42 | # [0. 0. 1. 0. 0.] 43 | # [1. 0. 0. 0. 0.] 44 | # [0. 0. 0. 0. 1.] 45 | # [0. 1. 0. 0. 0.] 46 | # [0. 0. 0. 1. 0.]]""" 47 | # idx = [list(i).index(1.) for i in arr] 48 | # op = [OPS[x] for x in idx] 49 | # mixed = '|' + op[0] + '~0|+|' + op[1] + '~0|' + op[2] + '~1|+|' + op[3] + '~0|' + op[4] + '~1|' + op[5] + '~2|' 50 | # return mixed 51 | 52 | # def base_transform(n, x): 53 | # a=[0,1,2,3,4,5,6,7,8,9,'A','b','C','D','E','F'] 54 | # b=[] 55 | # while True: 56 | # s=n//x 57 | # y=n%x 58 | # b=b+[y] 59 | # if s==0: 60 | # break 61 | # n=s 62 | # b.reverse() 63 | # zero_arr = [0]*(6-len(b)) 64 | # return zero_arr+b 65 | 66 | # def array_morearch(arr, distance): 67 | # """[[1. 0. 0. 0. 0.] 68 | # [0. 0. 1. 0. 0.] 69 | # [1. 0. 0. 0. 0.] 70 | # [0. 0. 0. 0. 1.] 71 | # [0. 1. 0. 0. 0.] 72 | # [0. 0. 0. 1. 
0.]]""" 73 | # am = list(arr.argmax(axis=1)) # [0,2,0,4,1,3] 74 | # morearch = [] 75 | # if distance == 1: 76 | # for i in range(len(am)): 77 | # for j in range(5): 78 | # if am[i]!=j: 79 | # _tmp = am[:] 80 | # _tmp[i] = j 81 | # _tmp_np = np.array(_tmp) 82 | # _tmp_oh = np.zeros((_tmp_np.size, 5)) 83 | # _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1 84 | # morearch.append(_tmp_oh) 85 | # else: 86 | # for i in range(15625): 87 | # arr = base_transform(i, 5) 88 | # if distance == 6-sum([arr[i]==am[i] for i in range(6)]): 89 | # _tmp_np = np.array(arr) 90 | # _tmp_oh = np.zeros((_tmp_np.size, 5)) 91 | # _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1 92 | # morearch.append(_tmp_oh) 93 | # # morearch.append(arr) 94 | # return morearch 95 | 96 | 97 | -------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | """ 6 | 7 | import torch.nn as nn 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_channels, out_channels, stride=1): 13 | super().__init__() 14 | 15 | self.residual_function = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 21 | ) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 28 | ) 29 | 30 | def forward(self, x): 31 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 32 | 33 | class BottleNeck(nn.Module): 34 | expansion = 4 35 | def __init__(self, in_channels, out_channels, stride=1): 36 | super().__init__() 37 | self.residual_function = nn.Sequential( 38 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 39 | nn.BatchNorm2d(out_channels), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 42 | nn.BatchNorm2d(out_channels), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 45 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 46 | ) 47 | 48 | self.shortcut = nn.Sequential() 49 | 50 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 53 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 54 | ) 55 | 56 | def forward(self, x): 57 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 58 | 59 | class ResNet(nn.Module): 60 | 61 | def __init__(self, block, num_block, num_classes=100): 62 | super().__init__() 63 | 64 | self.in_channels = 64 65 | 66 | self.conv1 = nn.Sequential( 67 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 68 | nn.BatchNorm2d(64), 69 | nn.ReLU(inplace=True)) 70 | #we use a different 
input size than the original paper 71 | #so conv2_x's stride is 1 72 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 73 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 74 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 75 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 76 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 77 | self.fc = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, out_channels, num_blocks, stride): 80 | strides = [stride] + [1] * (num_blocks - 1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_channels, out_channels, stride)) 84 | self.in_channels = out_channels * block.expansion 85 | 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | output = self.conv1(x) 90 | output = self.conv2_x(output) 91 | output = self.conv3_x(output) 92 | output = self.conv4_x(output) 93 | output = self.conv5_x(output) 94 | output = self.avg_pool(output) 95 | output = output.view(output.size(0), -1) 96 | output = self.fc(output) 97 | 98 | return output 99 | 100 | def feature_extractor(self, x): 101 | features = [] 102 | output = self.conv1(x) 103 | output = self.conv2_x(output) 104 | features.append(output) 105 | output = self.conv3_x(output) 106 | output = self.conv4_x(output) 107 | features.append(output) 108 | output = self.conv5_x(output) 109 | features.append(output) 110 | # output = self.avg_pool(output) 111 | # output = output.view(output.size(0), -1) 112 | # output = self.fc(output) 113 | 114 | return features 115 | 116 | def resnet18(): 117 | return ResNet(BasicBlock, [2, 2, 2, 2]) 118 | 119 | def resnet34(): 120 | return ResNet(BasicBlock, [3, 4, 6, 3]) 121 | 122 | def resnet50(): 123 | return ResNet(BottleNeck, [3, 4, 6, 3]) 124 | 125 | def resnet101(): 126 | return ResNet(BottleNeck, [3, 4, 23, 3]) 127 | 128 | def resnet152(): 129 | return ResNet(BottleNeck, [3, 8, 36, 3]) 130 | 131 | 132 | 133 |
-------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/teacher_model/resnet20_cifar10/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file are hugely influenced by [2], 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most implementations on the web are copy-pasted from 7 | torchvision's resnet and have the wrong number of params. 8 | 9 | A proper ResNet-s for CIFAR10 (for fair comparison, etc.) has the following 10 | numbers of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4M 19 | 20 | which this implementation indeed has. 21 | 22 | Reference: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in your work, please don't forget to mention the 28 | author, Yerlan Idelbayev.
29 | ''' 30 | 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import torch.nn.init as init 34 | 35 | 36 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 37 | 38 | def _weights_init(m): 39 | classname = m.__class__.__name__ 40 | #print(classname) 41 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 42 | init.kaiming_normal_(m.weight) 43 | 44 | class LambdaLayer(nn.Module): 45 | def __init__(self, lambd): 46 | super(LambdaLayer, self).__init__() 47 | self.lambd = lambd 48 | 49 | def forward(self, x): 50 | return self.lambd(x) 51 | 52 | 53 | class BasicBlock(nn.Module): 54 | expansion = 1 55 | 56 | def __init__(self, in_planes, planes, stride=1, option='A'): 57 | super(BasicBlock, self).__init__() 58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | 63 | self.shortcut = nn.Sequential() 64 | if stride != 1 or in_planes != planes: 65 | if option == 'A': 66 | """ 67 | For CIFAR10 ResNet paper uses option A. 68 | """ 69 | self.shortcut = LambdaLayer(lambda x: 70 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 71 | elif option == 'B': 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 74 | nn.BatchNorm2d(self.expansion * planes) 75 | ) 76 | 77 | def forward(self, x): 78 | out = F.relu(self.bn1(self.conv1(x))) 79 | out = self.bn2(self.conv2(out)) 80 | out += self.shortcut(x) 81 | out = F.relu(out) 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | def __init__(self, block, num_blocks, num_classes=10): 87 | super(ResNet, self).__init__() 88 | self.in_planes = 16 89 | 90 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 91 | self.bn1 = nn.BatchNorm2d(16) 92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 95 | self.linear = nn.Linear(64, num_classes) 96 | 97 | self.apply(_weights_init) 98 | 99 | def _make_layer(self, block, planes, num_blocks, stride): 100 | strides = [stride] + [1]*(num_blocks-1) 101 | layers = [] 102 | for stride in strides: 103 | layers.append(block(self.in_planes, planes, stride)) 104 | self.in_planes = planes * block.expansion 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | out = F.relu(self.bn1(self.conv1(x))) 110 | out = self.layer1(out) 111 | out = self.layer2(out) 112 | out = self.layer3(out) 113 | out = F.avg_pool2d(out, out.size()[3]) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | def feature_extractor(self, x): 119 | features = [] 120 | out = F.relu(self.bn1(self.conv1(x))) 121 | out = self.layer1(out) 122 | features.append(out) 123 | out = self.layer2(out) 124 | features.append(out) 125 | out = self.layer3(out) 126 | features.append(out) 127 | # out = F.avg_pool2d(out, out.size()[3]) 128 | # out = out.view(out.size(0), -1) 129 | # out = self.linear(out) 130 | return features 131 | 132 | 133 | def resnet20(): 134 | return ResNet(BasicBlock, [3, 3, 3]) 135 | 136 | 137 | def resnet32(): 138 | return ResNet(BasicBlock, [5, 5, 5]) 139 | 140 | 141 | def resnet44(): 142 | return 
ResNet(BasicBlock, [7, 7, 7]) 143 | 144 | 145 | def resnet56(): 146 | return ResNet(BasicBlock, [9, 9, 9]) 147 | 148 | 149 | def resnet110(): 150 | return ResNet(BasicBlock, [18, 18, 18]) 151 | 152 | 153 | def resnet1202(): 154 | return ResNet(BasicBlock, [200, 200, 200]) 155 | 156 | 157 | def test(net): 158 | import numpy as np 159 | total_params = 0 160 | 161 | for x in filter(lambda p: p.requires_grad, net.parameters()): 162 | total_params += np.prod(x.data.numpy().shape) 163 | print("Total number of params", total_params) 164 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 165 | 166 | 167 | if __name__ == "__main__": 168 | for net_name in __all__: 169 | if net_name.startswith('resnet'): 170 | print(net_name) 171 | test(globals()[net_name]()) 172 | print() -------------------------------------------------------------------------------- /xnas/algorithms/RMINAS/utils/random_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from xnas.datasets.loader import get_normal_dataloader 5 | from xnas.datasets.imagenet import ImageFolder 6 | 7 | 8 | def get_random_data(batchsize, name): 9 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 10 | if name == 'imagenet': 11 | train_loader, _ = ImageFolder( 12 | datapath="./data/imagenet/ILSVRC2012_img_train/", 13 | batch_size=batchsize*16, 14 | split=[0.5, 0.5], 15 | ).generate_data_loader() 16 | else: 17 | train_loader, _ = get_normal_dataloader(name, batchsize*16) 18 | 19 | random_idxs = np.random.randint(0, len(train_loader.dataset), size=train_loader.batch_size) 20 | (more_data_X, more_data_y) = zip(*[train_loader.dataset[idx] for idx in random_idxs]) 21 | more_data_X = torch.stack(more_data_X, dim=0).to(device) 22 | more_data_y = torch.Tensor(more_data_y).long().to(device) 23 | return more_data_X, more_data_y 24 | -------------------------------------------------------------------------------- /xnas/algorithms/SNG/DDPNAS.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | from xnas.algorithms.SNG.categorical import Categorical 6 | from xnas.core.utils import index_to_one_hot, softmax 7 | 8 | 9 | class CategoricalDDPNAS: 10 | def __init__(self, category, steps, theta_lr, gamma): 11 | self.p_model = Categorical(categories=category) 12 | # how many steps to pruning the distribution 13 | self.steps = steps 14 | self.current_step = 1 15 | self.ignore_index = [] 16 | self.sample_index = [] 17 | self.pruned_index = [] 18 | self.val_performance = [] 19 | self.sample = [] 20 | self.init_record() 21 | self.score_decay = 0.5 22 | self.learning_rate = 0.2 23 | self.training_finish = False 24 | self.training_epoch = self.get_training_epoch() 25 | self.non_param_index = [0, 1, 2, 7] 26 | self.param_index = [3, 4, 5, 6] 27 | self.non_param_index_num = len(self.non_param_index) 28 | self.pruned_index_num = len(self.param_index) 29 | self.non_param_index_count = [0] * self.p_model.d 30 | self.param_index_count = [0] * self.p_model.d 31 | self.gamma = gamma 32 | self.theta_lr = theta_lr 33 | self.velocity = np.zeros(self.p_model.theta.shape) 34 | 35 | def init_record(self): 36 | for i in range(self.p_model.d): 37 | self.ignore_index.append([]) 38 | self.sample_index.append(list(range(self.p_model.Cmax))) 39 | self.pruned_index.append([]) 40 | 
self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax])) 41 | 42 | def get_training_epoch(self): 43 | return self.steps * sum(list(range(self.p_model.Cmax))) 44 | 45 | def sampling(self): 46 | # return self.sampling_index() 47 | self.sample = self.sampling_index() 48 | return index_to_one_hot(self.sample, self.p_model.Cmax) 49 | 50 | def sampling_index(self): 51 | sample = [] 52 | for i in range(self.p_model.d): 53 | sample.append(random.choice(self.sample_index[i])) 54 | if len(self.sample_index[i]) > 0: 55 | self.sample_index[i].remove(sample[i]) 56 | return np.array(sample) 57 | 58 | def sample_with_constrains(self): 59 | # pass 60 | raise NotImplementedError 61 | 62 | def record_information(self, sample, performance): 63 | # self.sample = sample 64 | for i in range(self.p_model.d): 65 | self.val_performance[-1][i, self.sample[i]] = performance 66 | 67 | def update_sample_index(self): 68 | for i in range(self.p_model.d): 69 | self.sample_index[i] = list(set(range(self.p_model.Cmax)) - set(self.pruned_index[i])) 70 | 71 | def update(self): 72 | # when a search epoch for operations is over and not to the total search epoch 73 | if len(self.sample_index[0]) == 0: 74 | if self.current_step < self.steps: 75 | self.current_step += 1 76 | # append new val performance 77 | self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax])) 78 | # when the total search is down 79 | else: 80 | self.current_step += 1 81 | expectation = np.zeros([self.p_model.d, self.p_model.Cmax]) 82 | for i in range(self.steps): 83 | # multi 100 to ignore 0 84 | expectation += softmax(self.val_performance[i] * 100, axis=1) 85 | expectation = expectation / float(self.steps) 86 | # self.p_model.theta = expectation + self.score_decay * self.p_model.theta 87 | self.velocity = self.gamma * self.velocity + (1 - self.gamma) * expectation 88 | # NOTE: THETA_LR not applied. 
89 | self.p_model.theta += self.velocity 90 | # self.p_model.theta = self.p_model.theta + self.theta_lr * expectation 91 | # prune the index 92 | pruned_weight = deepcopy(self.p_model.theta) 93 | for index in range(self.p_model.d): 94 | if not len(self.pruned_index[index]) == 0: 95 | pruned_weight[index, self.pruned_index[index]] = np.nan 96 | pruned_index = np.nanargmin(pruned_weight[index, :]) 97 | # if self.non_param_index_count[index] == 3 and pruned_index in self.non_param_index: 98 | # pruned_weight[index, pruned_index] = np.nan 99 | # pruned_index = np.nanargmin(pruned_weight[index, :]) 100 | # if self.param_index_count[index] == 3 and pruned_index in self.param_index: 101 | # pruned_weight[index, pruned_index] = np.nan 102 | # pruned_index = np.nanargmin(pruned_weight[index, :]) 103 | # if pruned_index in self.param_index: 104 | # self.param_index_count[index] += 1 105 | # if pruned_index in self.non_param_index: 106 | # self.non_param_index_count[index] += 1 107 | self.pruned_index[index].append(pruned_index) 108 | # self.p_model.theta[index, pruned_index] = 0 109 | self.p_model.theta /= np.sum(self.p_model.theta, axis=1)[:, np.newaxis] 110 | # if self.param_index_count[0] == 3 and self.non_param_index_count[0] == 3: 111 | # self.training_finish = True 112 | self.current_step = 1 113 | # init val_performance 114 | self.val_performance = [] 115 | self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax])) 116 | self.update_sample_index() 117 | -------------------------------------------------------------------------------- /xnas/algorithms/SNG/GridSearch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from xnas.algorithms.SNG.categorical import Categorical 4 | 5 | 6 | class GridSearch: 7 | def __init__(self, categories, fresh_size=4, init_theta=None, max_mize=True): 8 | 9 | self.p_model = Categorical(categories) 10 | self.p_model.C = np.array(self.p_model.C) 11 | self.valid_d = len(self.p_model.C[self.p_model.C > 1]) 12 | 13 | # Refresh theta 14 | for k in range(self.p_model.d): 15 | self.p_model.theta[k, 0] = 1 16 | self.p_model.theta[k, 1:self.p_model.C[k]] = 0 17 | 18 | if init_theta is not None: 19 | self.p_model.theta = init_theta 20 | 21 | self.fresh_size = fresh_size 22 | self.sample = [] 23 | self.objective = [] 24 | self.maxmize = -1 if max_mize else 1 25 | self.obj_optim = float('inf') 26 | self.training_finish = False 27 | 28 | # Record point to move 29 | self.sample_point = self.p_model.theta 30 | self.point = [self.p_model.d-1, 0] 31 | 32 | def sampling(self): 33 | return self.sample_point 34 | 35 | def record_information(self, sample, objective): 36 | self.sample.append(sample) 37 | self.objective.append(objective * self.maxmize) 38 | 39 | def update(self): 40 | """ 41 | Update sampling by grid search 42 | e.g. 
43 | categories = [3, 2, 4] 44 | sample point as = [1, 1, 1] 45 | [0, 0, 0] 46 | [0, , 0] 47 | [ , , 0] 48 | point as now searching = [2, 0] 49 | """ 50 | if len(self.sample) == self.fresh_size: 51 | # update sample 52 | if self.point[1] == self.p_model.C[self.point[0]] - 1: 53 | for i in range(self.point[0] + 1, self.p_model.d): 54 | self.sample_point[i] = np.zeros(self.p_model.Cmax) 55 | self.sample_point[i][0] = 1 56 | for j in range(self.point[0], -1, -1): 57 | k = np.argmax(self.sample_point[j]) 58 | if k < self.p_model.C[j] - 1: 59 | self.sample_point[j][k] = 0 60 | self.sample_point[j][k+1] = 1 61 | break 62 | else: 63 | self.sample_point[j][k] = 0 64 | self.sample_point[j][0] = 1 65 | if j == 0: 66 | self.training_finish = True 67 | break 68 | self.point = [self.p_model.d-1, 0] 69 | else: 70 | self.sample_point[self.point[0], self.point[1]] = 0 71 | self.sample_point[self.point[0], self.point[1]+1] = 1 72 | self.point[1] += 1 73 | 74 | # update optim and theta 75 | if min(self.objective) < self.obj_optim: 76 | self.obj_optim = min(self.objective) 77 | seq = np.argmin(np.array(self.objective)) 78 | self.p_model.theta = self.sample[seq] 79 | 80 | # update record 81 | self.sample = [] 82 | self.objective = [] -------------------------------------------------------------------------------- /xnas/algorithms/SNG/MDENAS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from xnas.algorithms.SNG.categorical import Categorical 4 | 5 | 6 | class CategoricalMDENAS: 7 | def __init__(self, category, learning_rate): 8 | self.p_model = Categorical(categories=category) 9 | self.leaning_rate = learning_rate 10 | self.information_recoder = {'epoch': np.zeros((self.p_model.d, self.p_model.Cmax)), 11 | 'performance': np.zeros((self.p_model.d, self.p_model.Cmax))} 12 | 13 | def sampling(self): 14 | return self.p_model.sampling() 15 | 16 | def sampling_index(self): 17 | return self.p_model.sampling_index() 18 | 19 | def record_information(self, sample, performance): 20 | for i in range(len(sample)): 21 | self.information_recoder['epoch'][i, sample[i]] += 1 22 | self.information_recoder['performance'][i, sample[i]] = performance 23 | 24 | def update(self): 25 | # update the probability 26 | for edges_index in range(self.p_model.d): 27 | for i in range(self.p_model.Cmax): 28 | for j in range(i+1, self.p_model.Cmax): 29 | if (self.information_recoder['epoch'][edges_index, i] >= self.information_recoder['epoch'][edges_index, j])\ 30 | and (self.information_recoder['performance'][edges_index, i] < self.information_recoder['performance'][edges_index, j]): 31 | if self.p_model.theta[edges_index, i] > self.leaning_rate: 32 | self.p_model.theta[edges_index, i] -= self.leaning_rate 33 | self.p_model.theta[edges_index, j] += self.leaning_rate 34 | else: 35 | self.p_model.theta[edges_index, j] += self.p_model.theta[edges_index, i] 36 | self.p_model.theta[edges_index, i] = 0 37 | 38 | if (self.information_recoder['epoch'][edges_index, i] <= self.information_recoder['epoch'][edges_index, j]) \ 39 | and (self.information_recoder['performance'][edges_index, i] > self.information_recoder['performance'][edges_index, j]): 40 | if self.p_model.theta[edges_index, j] > self.leaning_rate: 41 | self.p_model.theta[edges_index, j] -= self.leaning_rate 42 | self.p_model.theta[edges_index, i] += self.leaning_rate 43 | else: 44 | self.p_model.theta[edges_index, i] += self.p_model.theta[edges_index, j] 45 | self.p_model.theta[edges_index, j] = 0 46 | 
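`CategoricalMDENAS` above is driven through a sampling / record / update loop. Below is a minimal driver sketch; the edge and operation counts, the epoch budget, and the random stand-in accuracy are illustrative assumptions rather than code from this repository:

```python
import numpy as np
from xnas.algorithms.SNG.MDENAS import CategoricalMDENAS

# e.g. a cell with 6 edges and 5 candidate operations per edge
sampler = CategoricalMDENAS(category=[5] * 6, learning_rate=0.01)

for epoch in range(100):
    sample = sampler.sampling_index()            # one operation index per edge
    val_acc = float(np.random.rand())            # stand-in for a real supernet validation accuracy
    sampler.record_information(sample, val_acc)
    sampler.update()                             # move probability toward better-performing ops

best_arch = sampler.p_model.mle()                # most likely architecture (one-hot, shape (6, 5))
```

Note that `update()` only moves probability mass away from an operation that has been sampled at least as often as a competitor yet records worse performance, which keeps the distribution stable while the epoch counts are still small.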
-------------------------------------------------------------------------------- /xnas/algorithms/SNG/RAND.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from xnas.algorithms.SNG.categorical import Categorical 4 | from xnas.core.utils import one_hot_to_index 5 | 6 | 7 | class RandomSample: 8 | """ 9 | Random Sample for Categorical Distribution 10 | """ 11 | 12 | def __init__(self, categories, delta_init=1., opt_type="best", init_theta=None, max_mize=True): 13 | # Categorical distribution 14 | self.p_model = Categorical(categories) 15 | # valid dimension size 16 | self.p_model.C = np.array(self.p_model.C) 17 | 18 | if init_theta is not None: 19 | self.p_model.theta = init_theta 20 | 21 | self.sample_list = [] 22 | self.obj_list = [] 23 | self.max_mize = -1 if max_mize else 1 24 | 25 | self.select = opt_type 26 | self.best_object = 1e10 * self.max_mize 27 | 28 | def record_information(self, sample, objective): 29 | self.sample_list.append(sample) 30 | self.obj_list.append(objective*self.max_mize) 31 | 32 | def sampling(self): 33 | """ 34 | Draw a sample from the categorical distribution (one-hot). 35 | Samples one architecture at a time. 36 | """ 37 | c = np.zeros(self.p_model.theta.shape, dtype=bool) # plain bool: np.bool was removed in NumPy 1.24 38 | for i, upper in enumerate(self.p_model.C): 39 | j = np.random.randint(upper) 40 | c[i, j] = True 41 | return c 42 | 43 | def sampling_index(self): 44 | return one_hot_to_index(np.array(self.sampling())) 45 | 46 | def mle(self): 47 | """ 48 | Get most likely categorical variables (one-hot) 49 | """ 50 | m = self.p_model.theta.argmax(axis=1) 51 | x = np.zeros((self.p_model.d, self.p_model.Cmax)) 52 | for i, c in enumerate(m): 53 | x[i, c] = 1 54 | return x 55 | 56 | def update(self): 57 | objective = np.array(self.obj_list[-1]) 58 | sample = np.array(self.sample_list[-1]) 59 | if (self.select == 'best'): 60 | if (objective > self.best_object): 61 | self.best_object = objective 62 | # refresh theta to best one 63 | self.p_model.theta = np.array(sample) 64 | else: 65 | raise NotImplementedError 66 | self.sample_list = [] 67 | self.obj_list = [] 68 |
-------------------------------------------------------------------------------- /xnas/algorithms/SNG/categorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | class Categorical(object): 6 | """ 7 | Categorical distribution for categorical variables parametrized by :math:`\\{ \\theta \\}_{i=1}^{(d \\times K)}`. 8 | 9 | :param categories: the numbers of categories 10 | :type categories: array_like, shape(d), dtype=int 11 | """ 12 | 13 | def __init__(self, categories): 14 | self.d = len(categories) 15 | self.C = categories 16 | self.Cmax = np.max(categories) 17 | self.theta = np.zeros((self.d, self.Cmax)) 18 | # initialize theta by 1/C for each dimension 19 | for i in range(self.d): 20 | self.theta[i, :self.C[i]] = 1./self.C[i] 21 | # pad zeros to unused elements 22 | for i in range(self.d): 23 | self.theta[i, self.C[i]:] = 0. 24 | # number of valid parameters 25 | # self.valid_param_num = int(np.sum(self.C - 1)) 26 | # valid dimension size 27 | # self.valid_d = len(self.C[self.C > 1]) 28 | 29 | def sampling_lam(self, lam): 30 | """ 31 | Draw :math:`\\lambda` samples from the categorical distribution.
32 | :param int lam: sample size :math:`\\lambda` 33 | :return: sampled variables from the categorical distribution (one-hot representation) 34 | :rtype: array_like, shape=(lam, d, Cmax), dtype=bool 35 | """ 36 | rand = np.random.rand(lam, self.d, 1) # range of random number is [0, 1) 37 | cum_theta = self.theta.cumsum(axis=1) # (d, Cmax) 38 | X = (cum_theta - self.theta <= rand) & (rand < cum_theta) 39 | return X 40 | 41 | def sampling(self): 42 | """ 43 | Draw a sample from the categorical distribution. 44 | :return: sampled variables from the categorical distribution (one-hot representation) 45 | :rtype: array_like, shape=(d, Cmax), dtype=bool 46 | """ 47 | rand = np.random.rand(self.d, 1) # range of random number is [0, 1) 48 | cum_theta = self.theta.cumsum(axis=1) # (d, Cmax) 49 | 50 | # x[i, j] becomes 1 iff cum_theta[i, j] - theta[i, j] <= rand[i] < cum_theta[i, j] 51 | x = (cum_theta - self.theta <= rand) & (rand < cum_theta) 52 | return x 53 | 54 | def sampling_index(self): 55 | """ 56 | 57 | :return: a numpy index list 58 | """ 59 | index_list = [] 60 | for prob in self.theta: 61 | index_list.append(np.random.choice(a=list(range(prob.shape[0])), p=prob)) 62 | return np.array(index_list) 63 | 64 | def mle(self): 65 | """ 66 | Return the most likely categories. 67 | :return: categorical variables (one-hot representation) 68 | :rtype: array_like, shape=(d, Cmax), dtype=bool 69 | """ 70 | m = self.theta.argmax(axis=1) 71 | x = np.zeros((self.d, self.Cmax)) 72 | for i, c in enumerate(m): 73 | x[i, c] = 1 74 | return x 75 | 76 | def loglikelihood(self, X): 77 | """ 78 | Calculate log likelihood. 79 | 80 | :param X: samples (one-hot representation) 81 | :type X: array_like, shape=(lam, d, maxK), dtype=bool 82 | :return: log likelihoods 83 | :rtype: array_like, shape=(lam), dtype=float 84 | """ 85 | return (X * np.log(self.theta)).sum(axis=2).sum(axis=1) 86 | 87 | def log_header(self): 88 | header_list = [] 89 | for i in range(self.d): 90 | header_list += ['theta%d_%d' % (i, j) for j in range(self.C[i])] 91 | return header_list 92 | 93 | def log(self): 94 | theta_list = [] 95 | for i in range(self.d): 96 | theta_list += ['%f' % self.theta[i, j] for j in range(self.C[i])] 97 | return theta_list 98 | 99 | def load_theta_from_log(self, theta): 100 | self.theta = np.zeros((self.d, self.Cmax)) 101 | k = 0 102 | for i in range(self.d): 103 | for j in range(self.C[i]): 104 | self.theta[i, j] = theta[k] 105 | k += 1 106 | 107 | def print_theta(self, logger): 108 | # remove formats 109 | org_formatters = [] 110 | for handler in logger.handlers: 111 | org_formatters.append(handler.formatter) 112 | handler.setFormatter(logging.Formatter("%(message)s")) 113 | 114 | logger.info("####### theta #######") 115 | logger.info("# Theta value") 116 | for alpha in self.theta: 117 | logger.info(alpha) 118 | logger.info("#####################") 119 | 120 | # restore formats 121 | for handler, formatter in zip(logger.handlers, org_formatters): 122 | handler.setFormatter(formatter) 123 | -------------------------------------------------------------------------------- /xnas/algorithms/SPOS.py: -------------------------------------------------------------------------------- 1 | """Samplers for Single Path One-Shot Search Space.""" 2 | 3 | import numpy as np 4 | from copy import deepcopy 5 | from collections import deque 6 | 7 | 8 | class RAND(): 9 | """Random choice""" 10 | def __init__(self, num_choice, layers): 11 | self.num_choice = num_choice 12 | self.child_len = layers 13 | self.history = [] 14 | 15 | def 
record(self, child, value): 16 | self.history.append({"child":child, "value":value}) 17 | 18 | def suggest(self): 19 | return list(np.random.randint(self.num_choice, size=self.child_len)) 20 | 21 | def final_best(self): 22 | best_child = min(self.history, key=lambda i:i["value"]) 23 | return best_child['child'], best_child['value'] 24 | 25 | 26 | class REA(): 27 | """Regularized Evolution Algorithm""" 28 | def __init__(self, num_choice, layers, population_size=20, better=min): 29 | self.num_choice = num_choice 30 | self.population_size = population_size 31 | self.child_len = layers 32 | self.better = better 33 | self.population = deque() 34 | self.history = [] 35 | # init population 36 | self.init_pop = np.random.randint( 37 | self.num_choice, size=(self.population_size, self.child_len) 38 | ) 39 | 40 | def _get_mutated_parent(self): 41 | parent = self.better(self.population, key=lambda i:i["value"]) # default: min(error) 42 | return self._mutate(parent['child']) 43 | 44 | def _mutate(self, parent): 45 | parent = deepcopy(parent) 46 | idx = np.random.randint(0, len(parent)) 47 | prev_value, new_value = parent[idx], parent[idx] 48 | while new_value == prev_value: 49 | new_value = np.random.randint(self.num_choice) 50 | parent[idx] = new_value 51 | return parent 52 | 53 | def record(self, child, value): 54 | self.history.append({"child":child, "value":value}) 55 | self.population.append({"child":child, "value":value}) 56 | if len(self.population) > self.population_size: 57 | self.population.popleft() 58 | 59 | def suggest(self): 60 | if len(self.history) < self.population_size: 61 | return list(self.init_pop[len(self.history)]) 62 | else: 63 | return self._get_mutated_parent() 64 | 65 | def final_best(self): 66 | best_child = self.better(self.history, key=lambda i:i["value"]) 67 | return best_child['child'], best_child['value'] 68 | -------------------------------------------------------------------------------- /xnas/core/config.py: -------------------------------------------------------------------------------- 1 | """Configuration file (powered by YACS).""" 2 | 3 | import os 4 | import sys 5 | import argparse 6 | from yacs.config import CfgNode 7 | 8 | 9 | # Global config object 10 | _C = CfgNode(new_allowed=True) 11 | cfg = _C 12 | 13 | 14 | # -------------------------------------------------------- # 15 | # Data Loader options 16 | # -------------------------------------------------------- # 17 | _C.LOADER = CfgNode(new_allowed=True) 18 | 19 | _C.LOADER.DATASET = "cifar10" 20 | 21 | # stay empty to use "./data/$dataset" as default 22 | _C.LOADER.DATAPATH = "" 23 | 24 | _C.LOADER.SPLIT = [0.8, 0.2] 25 | 26 | # whether using val dataset (imagenet only) 27 | _C.LOADER.USE_VAL = False 28 | 29 | _C.LOADER.NUM_CLASSES = 10 30 | 31 | _C.LOADER.NUM_WORKERS = 8 32 | 33 | _C.LOADER.PIN_MEMORY = True 34 | 35 | # batch size of training and validation 36 | # type: int or list(different during validation) 37 | # _C.LOADER.BATCH_SIZE = [256, 128] 38 | _C.LOADER.BATCH_SIZE = 256 39 | 40 | # augment type using by ImageNet only 41 | # chosen from ['default', 'auto_augment_tf'] 42 | _C.LOADER.TRANSFORM = "default" 43 | 44 | 45 | # ------------------------------------------------------------------------------------ # 46 | # Search Space options 47 | # ------------------------------------------------------------------------------------ # 48 | _C.SPACE = CfgNode(new_allowed=True) 49 | 50 | _C.SPACE.NAME = 'darts' 51 | 52 | # first layer's channels, not channels for input image. 
53 | _C.SPACE.CHANNELS = 16 54 | 55 | _C.SPACE.LAYERS = 8 56 | 57 | _C.SPACE.NODES = 4 58 | 59 | _C.SPACE.BASIC_OP = [] 60 | 61 | 62 | 63 | # ------------------------------------------------------------------------------------ # 64 | # Optimizer options in network 65 | # ------------------------------------------------------------------------------------ # 66 | _C.OPTIM = CfgNode(new_allowed=True) 67 | 68 | # Base learning rate, init_lr = OPTIM.BASE_LR * NUM_GPUS 69 | _C.OPTIM.BASE_LR = 0.1 70 | 71 | _C.OPTIM.MIN_LR = 1.e-3 72 | 73 | 74 | # Learning rate policy, selected from {'cos', 'exp', 'steps'} 75 | _C.OPTIM.LR_POLICY = "cos" 76 | # Steps for 'steps' policy (in epochs) 77 | _C.OPTIM.STEPS = [30, 60, 90] 78 | # Learning rate multiplier for 'steps' policy 79 | _C.OPTIM.LR_MULT = 0.1 80 | 81 | 82 | # Momentum 83 | _C.OPTIM.MOMENTUM = 0.9 84 | # Momentum dampening 85 | _C.OPTIM.DAMPENING = 0.0 86 | # Nesterov momentum 87 | _C.OPTIM.NESTEROV = False 88 | 89 | 90 | _C.OPTIM.WEIGHT_DECAY = 5e-4 91 | 92 | _C.OPTIM.GRAD_CLIP = 5.0 93 | 94 | _C.OPTIM.MAX_EPOCH = 200 95 | # Warm up epochs 96 | _C.OPTIM.WARMUP_EPOCH = 0 97 | # Start the warm up from init_lr * OPTIM.WARMUP_FACTOR 98 | _C.OPTIM.WARMUP_FACTOR = 0.1 99 | # Ending epochs 100 | _C.OPTIM.FINAL_EPOCH = 0 101 | 102 | 103 | 104 | # -------------------------------------------------------- # 105 | # Searching options 106 | # -------------------------------------------------------- # 107 | _C.SEARCH = CfgNode(new_allowed=True) 108 | 109 | _C.SEARCH.IM_SIZE = 32 110 | 111 | # Multi-sized Crop 112 | # NOTE: IM_SIZE for ImageNet will be overridden if this is set. 113 | _C.SEARCH.MULTI_SIZES = [] 114 | 115 | # channels of input images, 3 for rgb 116 | _C.SEARCH.INPUT_CHANNELS = 3 117 | 118 | _C.SEARCH.LOSS_FUN = 'cross_entropy' 119 | # label smoothing for cross entropy loss 120 | _C.SEARCH.LABEL_SMOOTH = 0.
121 | 122 | # resume and path of checkpoints 123 | _C.SEARCH.AUTO_RESUME = True 124 | 125 | _C.SEARCH.WEIGHTS = "" 126 | 127 | _C.SEARCH.EVALUATION = "" 128 | 129 | 130 | # ------------------------------------------------------------------------------------ # 131 | # Options for model training 132 | # ------------------------------------------------------------------------------------ # 133 | _C.TRAIN = CfgNode(new_allowed=True) 134 | 135 | _C.TRAIN.IM_SIZE = 32 136 | 137 | # channels of input images, 3 for rgb 138 | _C.TRAIN.INPUT_CHANNELS = 3 139 | 140 | _C.TRAIN.DROP_PATH_PROB = 0.2 141 | 142 | _C.TRAIN.LAYERS = 20 143 | 144 | _C.TRAIN.CHANNELS = 36 145 | 146 | _C.TRAIN.GENOTYPE = "" 147 | 148 | 149 | # -------------------------------------------------------- # 150 | # Model testing options 151 | # -------------------------------------------------------- # 152 | _C.TEST = CfgNode(new_allowed=True) 153 | 154 | _C.TEST.IM_SIZE = 224 155 | 156 | # using specific batchsize for testing 157 | # using search.batch_size if this value keeps -1 158 | _C.TEST.BATCH_SIZE = -1 159 | 160 | 161 | 162 | # -------------------------------------------------------- # 163 | # Benchmarks options 164 | # -------------------------------------------------------- # 165 | _C.BENCHMARK = CfgNode(new_allowed=True) 166 | 167 | # Path to NAS-Bench-201 weights file 168 | _C.BENCHMARK.NB201PATH = "./data/NAS-Bench-201-v1_1-096897.pth" 169 | 170 | # path to NAS-Bench-301 folder 171 | _C.BENCHMARK.NB301PATH = "./data/nb301models/" 172 | 173 | 174 | # -------------------------------------------------------- # 175 | # Misc options 176 | # -------------------------------------------------------- # 177 | 178 | _C.CUDNN_BENCH = True 179 | 180 | _C.LOG_PERIOD = 10 181 | 182 | _C.EVAL_PERIOD = 1 183 | 184 | _C.SAVE_PERIOD = 1 185 | 186 | _C.NUM_GPUS = 1 187 | 188 | _C.OUT_DIR = "exp/" 189 | 190 | _C.DETERMINSTIC = True 191 | 192 | _C.RNG_SEED = 1 193 | 194 | 195 | 196 | # -------------------------------------------------------- # 197 | 198 | def dump_cfgfile(cfg_dest="config.yaml"): 199 | """Dumps the config to the output directory.""" 200 | cfg_file = os.path.join(_C.OUT_DIR, cfg_dest) 201 | with open(cfg_file, "w") as f: 202 | _C.dump(stream=f) 203 | 204 | 205 | def load_cfgfile(out_dir, cfg_dest="config.yaml"): 206 | """Loads config from specified output directory.""" 207 | cfg_file = os.path.join(out_dir, cfg_dest) 208 | _C.merge_from_file(cfg_file) 209 | 210 | 211 | def load_configs(): 212 | """Load config from command line arguments and set any specified options. 
213 | How to use: python xx.py --cfg path_to_your_config.cfg test1 0 test2 True 214 | opts is collected as a list, e.g. ['test1', '0', 'test2', 'True']; yacs converts each value to its corresponding type 215 | """ 216 | parser = argparse.ArgumentParser(description="Config file options.") 217 | parser.add_argument("--cfg", required=True, type=str) 218 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER) 219 | if len(sys.argv) == 1: 220 | parser.print_help() 221 | sys.exit(1) 222 | args = parser.parse_args() 223 | _C.merge_from_file(args.cfg) 224 | _C.merge_from_list(args.opts) 225 |
-------------------------------------------------------------------------------- /xnas/core/utils.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | import random 3 | import string 4 | import time 5 | import numpy as np 6 | 7 | 8 | def random_time_string(stringLength=8): 9 | letters = string.ascii_lowercase 10 | return str(time.time()) + ''.join(random.choice(letters) for _ in range(stringLength)) # timestamp followed by random lowercase letters 11 | 12 | 13 | def one_hot_to_index(one_hot_matrix): 14 | return np.array([np.where(r == 1)[0][0] for r in one_hot_matrix]) 15 | 16 | 17 | def index_to_one_hot(index_vector, C): 18 | return np.eye(C)[index_vector.reshape(-1)] 19 | 20 | 21 | def float_to_decimal(data, prec=4): 22 | """Convert floats to decimals which allows for fixed width json.""" 23 | if isinstance(data, dict): 24 | return {k: float_to_decimal(v, prec) for k, v in data.items()} 25 | if isinstance(data, float): 26 | return decimal.Decimal(("{:." + str(prec) + "f}").format(data)) 27 | else: 28 | return data 29 | 30 | 31 | def softmax(x, axis=None): 32 | x = x - x.max(axis=axis, keepdims=True) 33 | y = np.exp(x) 34 | return y / y.sum(axis=axis, keepdims=True) 35 |
-------------------------------------------------------------------------------- /xnas/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/datasets/__init__.py -------------------------------------------------------------------------------- /xnas/datasets/imagenet16.py: -------------------------------------------------------------------------------- 1 | import os, sys, hashlib 2 | import numpy as np 3 | from PIL import Image 4 | import torch.utils.data as data 5 | import pickle 6 | 7 | 8 | # def ImageNet16Loader(): 9 | # mean = [x / 255 for x in [122.68, 116.66, 104.01]] 10 | # std = [x / 255 for x in [63.22, 61.26, 65.09]] 11 | # lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] 12 | # train_transform = transforms.Compose(lists) 13 | # train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120) 14 | # assert len(train_data) == 151700 15 | 16 | def calculate_md5(fpath, chunk_size=1024 * 1024): 17 | md5 = hashlib.md5() 18 | with open(fpath, "rb") as f: 19 | for chunk in iter(lambda: f.read(chunk_size), b""): 20 | md5.update(chunk) 21 | return md5.hexdigest() 22 | 23 | 24 | def check_md5(fpath, md5, **kwargs): 25 | return md5 == calculate_md5(fpath, **kwargs) 26 | 27 | 28 | def check_integrity(fpath, md5=None): 29 | if not os.path.isfile(fpath): 30 | return False 31 | if md5 is None: 32 | return True 33 | else: 34 | return check_md5(fpath, md5) 35 | 36 | 37 | class ImageNet16(data.Dataset): 38 | # http://image-net.org/download-images 39 | #
A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets 40 | # https://arxiv.org/pdf/1707.08819.pdf 41 | 42 | train_list = [ 43 | ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"], 44 | ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"], 45 | ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"], 46 | ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"], 47 | ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"], 48 | ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"], 49 | ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"], 50 | ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"], 51 | ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"], 52 | ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"], 53 | ] 54 | valid_list = [ 55 | ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"], 56 | ] 57 | 58 | def __init__(self, root, train, transform, use_num_of_class_only=None): 59 | self.root = root 60 | self.transform = transform 61 | self.train = train # training set or valid set 62 | if not self._check_integrity(): 63 | raise RuntimeError("Dataset not found or corrupted.") 64 | 65 | if self.train: 66 | downloaded_list = self.train_list 67 | else: 68 | downloaded_list = self.valid_list 69 | self.data = [] 70 | self.targets = [] 71 | 72 | # now load the picked numpy arrays 73 | for i, (file_name, checksum) in enumerate(downloaded_list): 74 | file_path = os.path.join(self.root, file_name) 75 | # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) 76 | with open(file_path, "rb") as f: 77 | if sys.version_info[0] == 2: 78 | entry = pickle.load(f) 79 | else: 80 | entry = pickle.load(f, encoding="latin1") 81 | self.data.append(entry["data"]) 82 | self.targets.extend(entry["labels"]) 83 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) 84 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 85 | if use_num_of_class_only is not None: 86 | assert ( 87 | isinstance(use_num_of_class_only, int) 88 | and use_num_of_class_only > 0 89 | and use_num_of_class_only < 1000 90 | ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only) 91 | new_data, new_targets = [], [] 92 | for I, L in zip(self.data, self.targets): 93 | if 1 <= L <= use_num_of_class_only: 94 | new_data.append(I) 95 | new_targets.append(L) 96 | self.data = new_data 97 | self.targets = new_targets 98 | # self.mean.append(entry['mean']) 99 | # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) 100 | # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) 101 | # print ('Mean : {:}'.format(self.mean)) 102 | # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3)) 103 | # std_data = np.std(temp, axis=0) 104 | # std_data = np.mean(np.mean(std_data, axis=0), axis=0) 105 | # print ('Std : {:}'.format(std_data)) 106 | 107 | def __getitem__(self, index): 108 | img, target = self.data[index], self.targets[index] - 1 109 | img = Image.fromarray(img) 110 | if self.transform is not None: 111 | img = self.transform(img) 112 | return img, target 113 | 114 | def __len__(self): 115 | return len(self.data) 116 | 117 | def _check_integrity(self): 118 | root = self.root 119 | for fentry in self.train_list + self.valid_list: 120 | filename, md5 = fentry[0], fentry[1] 121 | fpath = os.path.join(root, filename) 122 | if not check_integrity(fpath, md5): 123 | return False 124 | return True 125 | -------------------------------------------------------------------------------- 
/xnas/datasets/transforms_imagenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | 6 | from xnas.core.config import cfg 7 | from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment 8 | 9 | 10 | IMAGENET_RGB_MEAN = [0.485, 0.456, 0.406] 11 | IMAGENET_RGB_STD = [0.229, 0.224, 0.225] 12 | 13 | 14 | def get_data_transform(augment, **kwargs): 15 | if len(cfg.SEARCH.MULTI_SIZES)==0: 16 | # using single image_size for training 17 | train_crop_size = cfg.SEARCH.IM_SIZE 18 | else: 19 | # using MultiSize_RandomCrop 20 | train_crop_size = cfg.SEARCH.MULTI_SIZES 21 | min_train_scale = 0.08 22 | test_scale = math.ceil(cfg.TEST.IM_SIZE / 0.875) # 224 / 0.875 = 256 23 | test_crop_size = cfg.TEST.IM_SIZE # do not crop and using 224 by default. 24 | 25 | interpolation = transforms.InterpolationMode.BICUBIC 26 | if 'interpolation' in kwargs.keys() and kwargs['interpolation'] == 'bilinear': 27 | interpolation = transforms.InterpolationMode.BILINEAR 28 | 29 | da_args = { 30 | 'train_crop_size': train_crop_size, 31 | 'train_min_scale': min_train_scale, 32 | 'test_scale': test_scale, 33 | 'test_crop_size': test_crop_size, 34 | 'interpolation': interpolation, 35 | } 36 | 37 | if augment == 'default': 38 | return build_default_transform(**da_args) 39 | elif augment == 'auto_augment_tf': 40 | policy = 'v0' if 'policy' not in kwargs.keys() else kwargs['policy'] 41 | return build_imagenet_auto_augment_tf_transform(policy=policy, **da_args) 42 | else: 43 | raise ValueError(augment) 44 | 45 | 46 | def get_normalize(): 47 | return transforms.Normalize( 48 | mean=torch.Tensor(IMAGENET_RGB_MEAN), 49 | std=torch.Tensor(IMAGENET_RGB_STD), 50 | ) 51 | 52 | 53 | def get_randomResizedCrop(train_crop_size=224, train_min_scale=0.08, interpolation=transforms.InterpolationMode.BICUBIC): 54 | if isinstance(train_crop_size, int): 55 | return transforms.RandomResizedCrop(train_crop_size, scale=(train_min_scale, 1.0), interpolation=interpolation) 56 | elif isinstance(train_crop_size, list): 57 | from xnas.datasets.transforms import MultiSizeRandomCrop 58 | msrc = MultiSizeRandomCrop(train_crop_size) 59 | return msrc 60 | else: 61 | raise TypeError(train_crop_size) 62 | 63 | 64 | def build_default_transform( 65 | train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC 66 | ): 67 | normalize = get_normalize() 68 | train_crop_transform = get_randomResizedCrop( 69 | train_crop_size, train_min_scale, interpolation 70 | ) 71 | train_transform = transforms.Compose( 72 | [ 73 | # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation), 74 | train_crop_transform, 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | normalize, 78 | ] 79 | ) 80 | test_transform = transforms.Compose( 81 | [ 82 | transforms.Resize(test_scale, interpolation=interpolation), 83 | transforms.CenterCrop(test_crop_size), 84 | transforms.ToTensor(), 85 | normalize, 86 | ] 87 | ) 88 | return train_transform, test_transform 89 | 90 | 91 | def build_imagenet_auto_augment_tf_transform( 92 | policy='v0', train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC 93 | ): 94 | 95 | normalize = get_normalize() 96 | img_size = train_crop_size 97 | aa_params = { 98 | "translate_const": int(img_size * 0.45), 99 | "img_mean": 
tuple(min(255, round(255 * x)) for x in IMAGENET_RGB_MEAN), # fill color must be 0-255 ints; rounding the 0-1 mean directly would give (0, 0, 0) 100 | } 101 | 102 | aa_policy = AutoAugment(auto_augment_policy(policy, aa_params)) 103 | train_crop_transform = get_randomResizedCrop( 104 | train_crop_size, train_min_scale, interpolation 105 | ) 106 | train_transform = transforms.Compose( 107 | [ 108 | # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation), 109 | train_crop_transform, 110 | transforms.RandomHorizontalFlip(), 111 | aa_policy, 112 | transforms.ToTensor(), 113 | normalize, 114 | ] 115 | ) 116 | test_transform = transforms.Compose( 117 | [ 118 | transforms.Resize(test_scale, interpolation=interpolation), 119 | transforms.CenterCrop(test_crop_size), 120 | transforms.ToTensor(), 121 | normalize, 122 | ] 123 | ) 124 | return train_transform, test_transform 125 |
-------------------------------------------------------------------------------- /xnas/evaluations/NASBench201.py: -------------------------------------------------------------------------------- 1 | """Evaluate model by NAS-Bench-201""" 2 | 3 | from xnas.core.config import cfg 4 | import xnas.logger.logging as logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | try: 10 | from nas_201_api import NASBench201API as API 11 | api = API(cfg.BENCHMARK.NB201PATH) 12 | except ImportError: 13 | print('Could not import NASBench201.') 14 | exit(1) 15 | 16 | 17 | def index_to_genotype(index): 18 | return api.arch(index) 19 | 20 | def evaluate(genotype, epoch=12, **kwargs): 21 | """Require info from NAS-Bench-201 API. 22 | 23 | Implemented following the source code of DrNAS. 24 | """ 25 | result = api.query_by_arch(genotype, str(epoch)) 26 | ( 27 | cifar10_train, 28 | cifar10_test, 29 | cifar100_train, 30 | cifar100_valid, 31 | cifar100_test, 32 | imagenet16_train, 33 | imagenet16_valid, 34 | imagenet16_test, 35 | ) = distill(result) 36 | 37 | logger.info("Evaluate with NAS-Bench-201 (Bench epoch:{})".format(epoch)) 38 | logger.info("cifar10 train %f test %f", cifar10_train, cifar10_test) 39 | logger.info("cifar100 train %f valid %f test %f", cifar100_train, cifar100_valid, cifar100_test) 40 | logger.info("imagenet16 train %f valid %f test %f", imagenet16_train, imagenet16_valid, imagenet16_test) 41 | 42 | if "writer" in kwargs.keys(): 43 | writer = kwargs["writer"] 44 | cur_epoch = kwargs["cur_epoch"] 45 | writer.add_scalars("nasbench201/cifar10", {"train": cifar10_train, "test": cifar10_test}, cur_epoch) 46 | writer.add_scalars("nasbench201/cifar100", {"train": cifar100_train, "valid": cifar100_valid, "test": cifar100_test}, cur_epoch) 47 | writer.add_scalars("nasbench201/imagenet16", {"train": imagenet16_train, "valid": imagenet16_valid, "test": imagenet16_test}, cur_epoch) 48 | return result 49 | 50 | # distill 201api's results 51 | def distill(result): 52 | result = result.split("\n") 53 | cifar10 = result[5].replace(" ", "").split(":") 54 | cifar100 = result[7].replace(" ", "").split(":") 55 | imagenet16 = result[9].replace(" ", "").split(":") 56 | 57 | cifar10_train = float(cifar10[1].strip(",test")[-7:-2].strip("=")) 58 | cifar10_test = float(cifar10[2][-7:-2].strip("=")) 59 | cifar100_train = float(cifar100[1].strip(",valid")[-7:-2].strip("=")) 60 | cifar100_valid = float(cifar100[2].strip(",test")[-7:-2].strip("=")) 61 | cifar100_test = float(cifar100[3][-7:-2].strip("=")) 62 | imagenet16_train = float(imagenet16[1].strip(",valid")[-7:-2].strip("=")) 63 | imagenet16_valid = float(imagenet16[2].strip(",test")[-7:-2].strip("=")) 64 | imagenet16_test = float(imagenet16[3][-7:-2].strip("=")) 65 |
66 | return ( 67 | cifar10_train, 68 | cifar10_test, 69 | cifar100_train, 70 | cifar100_valid, 71 | cifar100_test, 72 | imagenet16_train, 73 | imagenet16_valid, 74 | imagenet16_test, 75 | ) -------------------------------------------------------------------------------- /xnas/evaluations/NASBench301.py: -------------------------------------------------------------------------------- 1 | """Evaluate model by NAS-Bench-301""" 2 | 3 | import os 4 | from collections import namedtuple 5 | from xnas.core.config import cfg 6 | import xnas.logger.logging as logging 7 | 8 | try: 9 | import nasbench301 as nb 10 | except ImportError: 11 | print('Could not import NASBench301.') 12 | exit(1) 13 | 14 | 15 | __all__ = ['evaluate'] 16 | 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | download_dir = cfg.BENCHMARK.NB301PATH 21 | version = '0.9' 22 | 23 | 24 | def init_model(version='0.9', download_dir=download_dir): 25 | 26 | # Note: Uses the 0.9 models by default; pass version='1.0' to use the 1.0 models 27 | models_0_9_dir = os.path.join(download_dir, 'nb_models_0.9') 28 | model_paths_0_9 = { 29 | model_name: os.path.join(models_0_9_dir, '{}_v0.9'.format(model_name)) 30 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime'] 31 | } 32 | models_1_0_dir = os.path.join(download_dir, 'nb_models_1.0') 33 | model_paths_1_0 = { 34 | model_name: os.path.join(models_1_0_dir, '{}_v1.0'.format(model_name)) 35 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime'] 36 | } 37 | model_paths = model_paths_0_9 if version == '0.9' else model_paths_1_0 38 | 39 | # If the models are not available at the paths, automatically download 40 | # the models 41 | # Note: If you would like to provide your own model locations, comment this out 42 | if not all(os.path.exists(model) for model in model_paths.values()): 43 | nb.download_models(version=version, delete_zip=True, 44 | download_dir=download_dir) 45 | 46 | # Load the performance surrogate model 47 | # NOTE: Loading the ensemble will set the seed to the same as used during training (logged in the model_configs.json) 48 | # NOTE: Defaults to using the default model download path 49 | ensemble_dir_performance = model_paths['xgb'] 50 | performance_model = nb.load_ensemble(ensemble_dir_performance) 51 | 52 | # Load the runtime surrogate model 53 | # NOTE: Defaults to using the default model download path 54 | ensemble_dir_runtime = model_paths['lgb_runtime'] 55 | runtime_model = nb.load_ensemble(ensemble_dir_runtime) 56 | 57 | return performance_model, runtime_model 58 | 59 | def evaluate(genotype): 60 | """ 61 | Evaluate with nasbench301, space=DARTS/nasbench301 62 | 63 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 64 | genotype = Genotype( 65 | normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], 66 | normal_concat=[2, 3, 4, 5], 67 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], 68 | reduce_concat=[2, 3, 4, 5] 69 | ) 70 | """ 71 | 72 | assert cfg.SPACE.NAME == "darts", "NAS-Bench-301 supports DARTS space only."
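# NB301 is a surrogate benchmark: the 'xgb' ensemble below predicts the architecture's
# CIFAR-10 accuracy and 'lgb_runtime' its training time; with_noise=True samples the
# fitted noise model, so repeated queries may return slightly different predictions.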
73 | 74 | performance_model, runtime_model = init_model(version, download_dir) 75 | 76 | # reformat the output of DartsCNN.genotype() 77 | genotype = reformat_DARTS(genotype) 78 | 79 | prediction_genotype = performance_model.predict( 80 | config=genotype, representation="genotype", with_noise=True) 81 | runtime_genotype = runtime_model.predict( 82 | config=genotype, representation="genotype") 83 | 84 | logger.info("Genotype architecture performance: %f, runtime %f" % 85 | (prediction_genotype, runtime_genotype)) 86 | 87 | """ 88 | The code below is used when sampling from a ConfigSpace. 89 | Rewrite it before use. 90 | """ 91 | # configspace_path = os.path.join(download_dir, 'configspace.json') 92 | # with open(configspace_path, "r") as f: 93 | # json_string = f.read() 94 | # configspace = cs_json.read(json_string) 95 | # configspace_config = configspace.sample_configuration() 96 | # prediction_configspace = performance_model.predict(config=configspace_config, representation="configspace", with_noise=True) 97 | # runtime_configspace = runtime_model.predict(config=configspace_config, representation="configspace") 98 | # print("Configspace architecture performance: %f, runtime %f" %(prediction_configspace, runtime_configspace)) 99 | 100 | def reformat_DARTS(genotype): 101 | """ 102 | format genotype for DARTS-like 103 | from: 104 | Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], [('sep_conv_3x3', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], [('dil_conv_5x5', 4), ('dil_conv_5x5', 3)]], normal_concat=range(2, 6), reduce=[[('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('max_pool_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('dil_conv_5x5', 4), ('max_pool_3x3', 0)]], reduce_concat=range(2, 6)) 105 | to: 106 | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 107 | """ 108 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 109 | if len(genotype.normal) == 1: 110 | return genotype 111 | 112 | _normal = [] 113 | _reduce = [] 114 | for i in genotype.normal: 115 | for j in i: 116 | _normal.append(j) 117 | for i in genotype.reduce: 118 | for j in i: 119 | _reduce.append(j) 120 | _normal_concat = [i for i in genotype.normal_concat] 121 | _reduce_concat = [i for i in genotype.reduce_concat] 122 | r_genotype = Genotype( 123 | normal=_normal, 124 | normal_concat=_normal_concat, 125 | reduce=_reduce, 126 | reduce_concat=_reduce_concat 127 | ) 128 | return r_genotype 129 | -------------------------------------------------------------------------------- /xnas/evaluations/NASBenchMacro/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | with open('xnas/evaluations/NASBenchmacro/nas-bench-macro_cifar10.json', 'r+', encoding='utf-8') as data_file: 4 | data = json.load(data_file) 5 | 6 | 7 | def evaluate(arch): 8 | # print(data[arch]) 9 | print("{} accuracies over three runs on CIFAR10: {}".format(arch, data[arch]['test_acc'])) 10 | print("mean accuracy : {}, std : {}, params : {}, flops : {}".format(data[arch]['mean_acc'], data[arch]['std'], 11 | data[arch]['params'],
data[arch]['flops'])) 12 | -------------------------------------------------------------------------------- /xnas/evaluations/NASBenchmacro/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | with open('xnas/evaluations/NASBenchmacro/nas-bench-macro_cifar10.json', 'r+', encoding='utf-8') as data_file: 4 | data = json.load(data_file) 5 | 6 | 7 | def evaluate(arch): 8 | # print(data[arch]) 9 | print("{} accuracies over three runs on CIFAR10: {}".format(arch, data[arch]['test_acc'])) 10 | print("mean accuracy : {}, std : {}, params : {}, flops : {}".format(data[arch]['mean_acc'], data[arch]['std'], 11 | data[arch]['params'], data[arch]['flops'])) 12 | -------------------------------------------------------------------------------- /xnas/logger/checkpoint.py: -------------------------------------------------------------------------------- 1 | """Functions that handle saving and loading of checkpoints""" 2 | 3 | import os 4 | import torch 5 | from xnas.core.config import cfg 6 | 7 | # Checkpoints directory name 8 | _DIR_NAME = "checkpoints" 9 | # Common prefix for checkpoint file names 10 | _NAME_PREFIX = "model_epoch_" 11 | 12 | 13 | def get_checkpoint_dir(out_dir=None): 14 | """Retrieves the location for storing checkpoints.""" 15 | if out_dir is None: 16 | return os.path.join(cfg.OUT_DIR, _DIR_NAME) 17 | else: 18 | return os.path.join(out_dir, _DIR_NAME) 19 | 20 | 21 | def get_checkpoint_name(epoch, checkpoint_dir=None, best=False): 22 | """Retrieves the path to a checkpoint file.""" 23 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 24 | name = "best_" + name if best else name 25 | if checkpoint_dir is None: 26 | return os.path.join(get_checkpoint_dir(), name) 27 | else: 28 | return os.path.join(checkpoint_dir, name) 29 | 30 | 31 | def get_last_checkpoint(checkpoint_dir=None, best=False): 32 | """Retrieves the most recent checkpoint (highest epoch number).""" 33 | if checkpoint_dir is None: 34 | checkpoint_dir = get_checkpoint_dir() 35 | # Checkpoint file names are in lexicographic order 36 | filename = "best_" + _NAME_PREFIX if best else _NAME_PREFIX 37 | checkpoints = [f for f in os.listdir(checkpoint_dir) if filename in f] 38 | last_checkpoint_name = sorted(checkpoints)[-1] 39 | return os.path.join(checkpoint_dir, last_checkpoint_name) 40 | 41 | 42 | def has_checkpoint(checkpoint_dir=None): 43 | """Determines if there are checkpoints available.""" 44 | if checkpoint_dir is None: 45 | checkpoint_dir = get_checkpoint_dir() 46 | if not os.path.exists(checkpoint_dir): 47 | return False 48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) 49 | 50 | 51 | def save_checkpoint(model, epoch, checkpoint_dir=None, best=False, **kwargs): 52 | """Saves a checkpoint.""" 53 | if checkpoint_dir is None: 54 | checkpoint_dir = get_checkpoint_dir() 55 | os.makedirs(checkpoint_dir, exist_ok=True) 56 | 57 | ms = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model 58 | sd = ms.state_dict() 59 | checkpoint = { 60 | "epoch": epoch, 61 | "model_state": sd, 62 | } 63 | for k,v in kwargs.items(): 64 | vsd = v.state_dict() 65 | checkpoint[k] = vsd 66 | # Write the checkpoint 67 | checkpoint_file = get_checkpoint_name(epoch + 1, checkpoint_dir, best=best) 68 | torch.save(checkpoint, checkpoint_file) 69 | return checkpoint_file 70 | 71 | 72 | def load_checkpoint(checkpoint_file, model): 73 | """Loads the checkpoint from the given file.""" 74 | err_str = "Checkpoint '{}' not found" 75 | assert
os.path.exists(checkpoint_file), err_str.format(checkpoint_file) 76 | # Load the checkpoint on CPU to avoid GPU mem spike 77 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 78 | # Account for the DDP wrapper in the multi-gpu setting 79 | ms = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model 80 | ms.load_state_dict(checkpoint["model_state"]) 81 | # Load the optimizer state (commonly not done when fine-tuning) 82 | others = {} 83 | for k,v in checkpoint.items(): 84 | if k not in ["epoch", "model_state"]: 85 | others[k] = v 86 | return checkpoint["epoch"], others 87 | -------------------------------------------------------------------------------- /xnas/logger/logging.py: -------------------------------------------------------------------------------- 1 | """Logging.""" 2 | 3 | import os 4 | import logging 5 | import simplejson 6 | from xnas.core.config import cfg 7 | from xnas.core.utils import float_to_decimal 8 | 9 | 10 | # Show filename and line number in logs 11 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" 12 | 13 | # Log file name 14 | _LOG_FILE = "stdout.log" 15 | 16 | # Data output with dump_log_data(data, data_type) will be tagged w/ this 17 | _TAG = "json_stats: " 18 | 19 | # Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type 20 | _TYPE = "_type" 21 | 22 | 23 | def setup_logging(): 24 | """Sets up the logging.""" 25 | # Clear the root logger to prevent any existing logging config 26 | # (e.g. set by another module) from messing with our setup 27 | logging.root.handlers = [] 28 | # Construct logging configuration 29 | logging_config = {"level": logging.INFO, "format": _FORMAT} 30 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) 31 | # Configure logging 32 | logging.basicConfig(**logging_config) 33 | 34 | 35 | def get_logger(name): 36 | """Retrieves the logger.""" 37 | return logging.getLogger(name) 38 | 39 | 40 | def dump_log_data(data, data_type, prec=4): 41 | """Convert data (a dictionary) into a tagged json string for logging.""" 42 | data[_TYPE] = data_type 43 | data = float_to_decimal(data, prec) 44 | data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True) 45 | return "{:s}{:s}".format(_TAG, data_json) 46 | -------------------------------------------------------------------------------- /xnas/logger/timer.py: -------------------------------------------------------------------------------- 1 | """Timer.""" 2 | 3 | import time 4 | 5 | 6 | class Timer(object): 7 | """A simple timer (adapted from Detectron).""" 8 | 9 | def __init__(self): 10 | self.total_time = None 11 | self.calls = None 12 | self.start_time = None 13 | self.diff = None 14 | self.average_time = None 15 | self.reset() 16 | 17 | def tic(self): 18 | # using time.time since time.clock does not normalize for multithreading 19 | self.start_time = time.time() 20 | 21 | def toc(self): 22 | self.diff = time.time() - self.start_time 23 | self.total_time += self.diff 24 | self.calls += 1 25 | self.average_time = self.total_time / self.calls 26 | 27 | def reset(self): 28 | self.total_time = 0.0 29 | self.calls = 0 30 | self.start_time = 0.0 31 | self.diff = 0.0 32 | self.average_time = 0.0 33 | -------------------------------------------------------------------------------- /xnas/runner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/runner/__init__.py
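A minimal usage sketch tying the xnas/logger utilities above together (a hedged illustration, not repo code: it assumes cfg.OUT_DIR points at a writable directory and uses a stand-in nn.Linear model; any object exposing state_dict(), such as the optimizer below, can ride along through save_checkpoint's **kwargs):

import torch
import torch.nn as nn
from xnas.logger.checkpoint import save_checkpoint, has_checkpoint, get_last_checkpoint, load_checkpoint
from xnas.logger.timer import Timer

model = nn.Linear(8, 2)                  # stand-in for a search/eval network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

timer = Timer()
timer.tic()
save_checkpoint(model, epoch=0, optimizer=optimizer)   # extra kwargs are saved under their names
timer.toc()                              # timer.average_time now holds the mean save time

if has_checkpoint():
    epoch, others = load_checkpoint(get_last_checkpoint(), model)
    optimizer.load_state_dict(others["optimizer"])     # 'others' maps kwarg name -> saved state dict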
-------------------------------------------------------------------------------- /xnas/runner/criterion.py: -------------------------------------------------------------------------------- 1 | """Loss functions.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from xnas.core.config import cfg 7 | 8 | 9 | __all__ = ['criterion_builder'] 10 | 11 | 12 | def _label_smooth(target, n_classes: int, label_smoothing): 13 | # convert to one-hot 14 | batch_size = target.size(0) 15 | target = torch.unsqueeze(target, 1) 16 | soft_target = torch.zeros((batch_size, n_classes), device=target.device) 17 | soft_target.scatter_(1, target, 1) 18 | # label smoothing 19 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 20 | return soft_target 21 | 22 | 23 | def CrossEntropyLoss_soft_target(pred, soft_target): 24 | """CELoss with soft target, mainly used during KD""" 25 | logsoftmax = nn.LogSoftmax(dim=1) 26 | return torch.mean(torch.sum(-soft_target * logsoftmax(pred), dim=1)) 27 | 28 | 29 | def CrossEntropyLoss_label_smoothed(pred, target, label_smoothing=0.): 30 | label_smoothing = cfg.SEARCH.LABEL_SMOOTH if label_smoothing == 0. else label_smoothing 31 | soft_target = _label_smooth(target, pred.size(1), label_smoothing) 32 | return CrossEntropyLoss_soft_target(pred, soft_target) 33 | 34 | 35 | class KLLossSoft(torch.nn.modules.loss._Loss): 36 | """ In-place distillation for image classification. 37 | output: output logits of the student network 38 | target: output logits of the teacher network 39 | T: temperature 40 | KL(p||q) = E_p[log p] - E_p[log q] 41 | """ 42 | def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9): 43 | output, soft_logits = output / temperature, soft_logits / temperature 44 | soft_target_prob = F.softmax(soft_logits, dim=1) 45 | output_log_prob = F.log_softmax(output, dim=1) 46 | kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1) 47 | if target is not None: 48 | n_class = output.size(1) 49 | target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1) 50 | target = target.unsqueeze(1) 51 | output_log_prob = output_log_prob.unsqueeze(2) 52 | ce_loss = -torch.bmm(target, output_log_prob).squeeze() 53 | loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss 54 | else: 55 | loss = kd_loss 56 | 57 | if self.reduction == 'mean': 58 | return loss.mean() 59 | elif self.reduction == 'sum': 60 | return loss.sum() 61 | return loss 62 | 63 | 64 | class MultiHeadCrossEntropyLoss(nn.Module): 65 | def forward(self, preds, targets): 66 | assert preds.dim() == 3, preds 67 | assert targets.dim() == 2, targets 68 | 69 | assert preds.size(1) == targets.size(1), (preds, targets) 70 | num_heads = targets.size(1) 71 | 72 | loss = 0 73 | for k in range(num_heads): 74 | loss += F.cross_entropy(preds[:, k, :], targets[:, k]) / num_heads 75 | return loss 76 | 77 | 78 | # ---------- 79 | 80 | SUPPORTED_CRITERIONS = { 81 | "cross_entropy": torch.nn.CrossEntropyLoss(), 82 | "cross_entropy_soft": CrossEntropyLoss_soft_target, 83 | "cross_entropy_smooth": CrossEntropyLoss_label_smoothed, 84 | "cross_entropy_multihead": MultiHeadCrossEntropyLoss(), 85 | "kl_soft": KLLossSoft(), 86 | } 87 | 88 | 89 | def criterion_builder(criterion=None): 90 | err_str = "Loss function type '{}' not supported" 91 | loss_fun = cfg.SEARCH.LOSS_FUN if criterion is None else criterion 92 | assert loss_fun in SUPPORTED_CRITERIONS.keys(), err_str.format(loss_fun) 93 | return SUPPORTED_CRITERIONS[loss_fun] 94
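A short usage sketch for the loss registry above (hedged: tensor shapes and the smoothing value are illustrative; passing an explicit name to criterion_builder bypasses the cfg.SEARCH.LOSS_FUN lookup):

import torch
from xnas.runner.criterion import criterion_builder, CrossEntropyLoss_label_smoothed

logits = torch.randn(4, 10)              # batch of 4 samples, 10 classes
target = torch.randint(0, 10, (4,))

ce = criterion_builder("cross_entropy")  # plain CE pulled from SUPPORTED_CRITERIONS
loss_plain = ce(logits, target)

# label smoothing: _label_smooth builds a soft one-hot target, then soft-target CE is applied
loss_smooth = CrossEntropyLoss_label_smoothed(logits, target, label_smoothing=0.1)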
| -------------------------------------------------------------------------------- /xnas/runner/optimizer.py: -------------------------------------------------------------------------------- 1 | """Optimizers.""" 2 | 3 | import torch 4 | from xnas.core.config import cfg 5 | 6 | 7 | __all__ = [ 8 | 'optimizer_builder', 9 | 'darts_alpha_optimizer', 10 | ] 11 | 12 | 13 | SUPPORTED_OPTIMIZERS = { 14 | "SGD", 15 | "Adam", 16 | } 17 | 18 | 19 | def optimizer_builder(name, param): 20 | """optimizer builder 21 | 22 | Args: 23 | name (str): name of optimizer 24 | param (iterable): parameters to optimize 25 | 26 | Returns: 27 | optimizer: optimizer 28 | """ 29 | assert name in SUPPORTED_OPTIMIZERS, "optimizer not supported." 30 | if name == "SGD": 31 | return torch.optim.SGD( 32 | param, 33 | cfg.OPTIM.BASE_LR, 34 | cfg.OPTIM.MOMENTUM, 35 | cfg.OPTIM.DAMPENING, # 0.0 following default 36 | cfg.OPTIM.WEIGHT_DECAY, 37 | cfg.OPTIM.NESTEROV, # False following default 38 | ) 39 | elif name == "Adam": 40 | return torch.optim.Adam( 41 | param, 42 | cfg.OPTIM.BASE_LR, 43 | betas=(0.5, 0.999), 44 | weight_decay=cfg.OPTIM.WEIGHT_DECAY, 45 | ) 46 | 47 | 48 | def darts_alpha_optimizer(name, param): 49 | """alpha optimizer for DARTS-like methods. 50 | Make sure cfg.DARTS has been initialized. 51 | 52 | Args: 53 | name (str): name of optimizer 54 | param (iterable): parameters to optimize 55 | 56 | Returns: 57 | optimizer: optimizer 58 | """ 59 | assert name in SUPPORTED_OPTIMIZERS, "optimizer not supported." 60 | if name == "Adam": 61 | return torch.optim.Adam( 62 | param, 63 | cfg.DARTS.ALPHA_LR, 64 | betas=(0.5, 0.999), 65 | weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY, 66 | ) 67 | -------------------------------------------------------------------------------- /xnas/runner/scheduler.py: -------------------------------------------------------------------------------- 1 | """Learning rate schedulers.""" 2 | 3 | import math 4 | import torch 5 | from xnas.core.config import cfg 6 | 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | 9 | 10 | __all__ = ['lr_scheduler_builder', 'adjust_learning_rate_per_batch'] 11 | 12 | 13 | def lr_scheduler_builder(optimizer, last_epoch=-1, **kwargs): 14 | """Learning rate scheduler builder, with support for warmup epochs.""" 15 | actual_scheduler = None 16 | if cfg.OPTIM.LR_POLICY == "cos": 17 | actual_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 18 | optimizer, 19 | T_max=kwargs['T_max'] if 'T_max' in kwargs.keys() else cfg.OPTIM.MAX_EPOCH, 20 | eta_min=cfg.OPTIM.MIN_LR, 21 | last_epoch=last_epoch) 22 | elif cfg.OPTIM.LR_POLICY == "step": 23 | actual_scheduler = torch.optim.lr_scheduler.MultiStepLR( 24 | optimizer, 25 | cfg.OPTIM.STEPS, 26 | gamma=cfg.OPTIM.LR_MULT, 27 | last_epoch=last_epoch) 28 | else: 29 | raise NotImplementedError 30 | 31 | if cfg.OPTIM.WARMUP_EPOCH > 0: 32 | return GradualWarmupScheduler( 33 | optimizer, 34 | actual_scheduler, 35 | cfg.OPTIM.WARMUP_EPOCH, 36 | cfg.OPTIM.WARMUP_FACTOR, 37 | last_epoch) 38 | else: 39 | return actual_scheduler 40 | 41 | 42 | class GradualWarmupScheduler(_LRScheduler): 43 | """ 44 | Implementation Reference: 45 | https://github.com/ildoonet/pytorch-gradual-warmup-lr 46 | Args: 47 | optimizer (Optimizer): Wrapped optimizer.
48 | factor: start the warmup from init_lr * factor 49 | actual_scheduler: the scheduler used after warmup_epochs 50 | warmup_epochs: init_lr is reached at warmup_epochs; during warmup the lr ramps linearly as base_lr * ((1 - factor) * last_epoch / warmup_epochs + factor) 51 | """ 52 | 53 | def __init__(self, 54 | optimizer: torch.optim.Optimizer, 55 | actual_scheduler: _LRScheduler, 56 | warmup_epochs: int, 57 | factor: float, 58 | last_epoch=-1): 59 | # attributes must be set before super().__init__(), which calls self.step() 60 | self.actual_scheduler = actual_scheduler 61 | self.warmup_epochs = warmup_epochs 62 | self.factor = factor 63 | self.last_epoch = last_epoch 64 | super().__init__(optimizer, last_epoch) 65 | 66 | 67 | def get_lr(self): 68 | if self.last_epoch > self.warmup_epochs: 69 | return self.actual_scheduler.get_lr() 70 | else: 71 | return [base_lr * ((1. - self.factor) * self.last_epoch / self.warmup_epochs + self.factor) for base_lr in self.base_lrs] 72 | 73 | 74 | def step(self, epoch=None): 75 | if self.last_epoch > self.warmup_epochs: 76 | if epoch is None: 77 | self.actual_scheduler.step(None) 78 | else: 79 | self.actual_scheduler.step(epoch - self.warmup_epochs) 80 | self._last_lr = self.actual_scheduler.get_last_lr() 81 | else: 82 | return super(GradualWarmupScheduler, self).step(epoch) 83 | 84 | 85 | def _calc_learning_rate( 86 | init_lr, n_epochs, epoch, n_iter=None, iter=0, 87 | ): 88 | if cfg.OPTIM.LR_POLICY == "cos": 89 | t_total = n_epochs * n_iter 90 | t_cur = epoch * n_iter + iter 91 | lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total)) 92 | else: 93 | raise ValueError("do not support: {}".format(cfg.OPTIM.LR_POLICY)) 94 | return lr 95 | 96 | 97 | def _warmup_adjust_learning_rate( 98 | init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0 99 | ): 100 | """adjust lr during warming-up. Changes linearly from `warmup_lr` to `init_lr`.""" 101 | T_cur = epoch * n_iter + iter + 1 102 | t_total = n_epochs * n_iter 103 | new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr 104 | return new_lr 105 | 106 | 107 | def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False): 108 | """Adjust the learning rate of a given optimizer and return the new learning rate.""" 109 | 110 | init_lr = cfg.OPTIM.BASE_LR * cfg.NUM_GPUS 111 | n_epochs = cfg.OPTIM.MAX_EPOCH 112 | n_warmup_epochs = cfg.OPTIM.WARMUP_EPOCH 113 | warmup_lr = init_lr * cfg.OPTIM.WARMUP_FACTOR 114 | 115 | if warmup: 116 | new_lr = _warmup_adjust_learning_rate( 117 | init_lr, n_warmup_epochs, epoch, n_iter, iter, warmup_lr 118 | ) 119 | else: 120 | new_lr = _calc_learning_rate( 121 | init_lr, n_epochs, epoch, n_iter, iter 122 | ) 123 | return new_lr 124 | -------------------------------------------------------------------------------- /xnas/spaces/BigNAS/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | from xnas.spaces.OFA.ops import SEModule, build_activation 6 | from xnas.spaces.OFA.utils import ( 7 | get_same_padding, 8 | ) 9 | 10 | class MBConvLayer(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size=3, 16 | stride=1, 17 | expand_ratio=6, 18 | mid_channels=None, 19 | act_func="relu6", 20 | use_se=False, 21 | channels_per_group=1, 22 | ): 23 | super(MBConvLayer, self).__init__() 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | 28 | self.kernel_size = kernel_size 29 | self.stride = stride 30 | self.expand_ratio = expand_ratio 31 | self.mid_channels = mid_channels 32 | self.act_func = act_func 33 | self.use_se = use_se 34 |
self.channels_per_group = channels_per_group 35 | 36 | if self.mid_channels is None: 37 | feature_dim = round(self.in_channels * self.expand_ratio) 38 | else: 39 | feature_dim = self.mid_channels 40 | 41 | if self.expand_ratio == 1: 42 | self.inverted_bottleneck = None 43 | else: 44 | self.inverted_bottleneck = nn.Sequential(OrderedDict([ 45 | ("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)), 46 | ("bn", nn.BatchNorm2d(feature_dim)), 47 | ("act", build_activation(self.act_func, inplace=True)), 48 | ])) 49 | 50 | assert feature_dim % self.channels_per_group == 0 51 | active_groups = feature_dim // self.channels_per_group 52 | pad = get_same_padding(self.kernel_size) 53 | 54 | # assert feature_dim % self.groups == 0 55 | # active_groups = feature_dim // self.groups 56 | depth_conv_modules = [ 57 | ( 58 | "conv", 59 | nn.Conv2d( 60 | feature_dim, 61 | feature_dim, 62 | kernel_size, 63 | stride, 64 | pad, 65 | groups=active_groups, 66 | bias=False, 67 | ), 68 | ), 69 | ("bn", nn.BatchNorm2d(feature_dim)), 70 | ("act", build_activation(self.act_func, inplace=True)), 71 | ] 72 | if self.use_se: 73 | depth_conv_modules.append(("se", SEModule(feature_dim))) 74 | self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules)) 75 | 76 | self.point_linear = nn.Sequential( 77 | OrderedDict( 78 | [ 79 | ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)), 80 | ("bn", nn.BatchNorm2d(out_channels)), 81 | ] 82 | ) 83 | ) 84 | 85 | def forward(self, x): 86 | if self.inverted_bottleneck: 87 | x = self.inverted_bottleneck(x) 88 | x = self.depth_conv(x) 89 | x = self.point_linear(x) 90 | return x 91 | 92 | @property 93 | def module_str(self): 94 | if self.mid_channels is None: 95 | expand_ratio = self.expand_ratio 96 | else: 97 | expand_ratio = self.mid_channels // self.in_channels 98 | layer_str = "%dx%d_MBConv%d_%s" % ( 99 | self.kernel_size, 100 | self.kernel_size, 101 | expand_ratio, 102 | self.act_func.upper(), 103 | ) 104 | if self.use_se: 105 | layer_str = "SE_" + layer_str 106 | layer_str += "_O%d" % self.out_channels 107 | if self.channels_per_group is not None: 108 | layer_str += "_G%d" % self.channels_per_group 109 | if isinstance(self.point_linear.bn, nn.GroupNorm): 110 | layer_str += "_GN%d" % self.point_linear.bn.num_groups 111 | elif isinstance(self.point_linear.bn, nn.BatchNorm2d): 112 | layer_str += "_BN" 113 | 114 | return layer_str 115 | 116 | @property 117 | def config(self): 118 | return { 119 | "name": MBConvLayer.__name__, 120 | "in_channels": self.in_channels, 121 | "out_channels": self.out_channels, 122 | "kernel_size": self.kernel_size, 123 | "stride": self.stride, 124 | "expand_ratio": self.expand_ratio, 125 | "mid_channels": self.mid_channels, 126 | "act_func": self.act_func, 127 | "use_se": self.use_se, 128 | "channels_per_group": self.channels_per_group, # key must match the __init__ argument so build_from_config works 129 | } 130 | 131 | @staticmethod 132 | def build_from_config(config): 133 | return MBConvLayer(**config) 134 | -------------------------------------------------------------------------------- /xnas/spaces/BigNAS/utils.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from attentiveNAS - https://github.com/facebookresearch/AttentiveNAS 2 | 3 | import torch 4 | import torch.nn as nn 5 | import copy 6 | import math 7 | 8 | multiply_adds = 1 9 | 10 | 11 | def count_convNd(m, _, y): 12 | cin = m.in_channels 13 | 14 | kernel_ops = m.weight.size()[2] * m.weight.size()[3] 15 | ops_per_element = kernel_ops 16 | output_elements =
y.nelement() 17 | 18 | # cout x oW x oH 19 | total_ops = cin * output_elements * ops_per_element // m.groups 20 | m.total_ops = torch.Tensor([int(total_ops)]) 21 | 22 | 23 | def count_linear(m, _, __): 24 | total_ops = m.in_features * m.out_features 25 | 26 | m.total_ops = torch.Tensor([int(total_ops)]) 27 | 28 | 29 | register_hooks = { 30 | nn.Conv1d: count_convNd, 31 | nn.Conv2d: count_convNd, 32 | nn.Conv3d: count_convNd, 33 | ###################################### 34 | nn.Linear: count_linear, 35 | ###################################### 36 | nn.Dropout: None, 37 | nn.Dropout2d: None, 38 | nn.Dropout3d: None, 39 | nn.BatchNorm2d: None, 40 | } 41 | 42 | 43 | def profile(model, input_size=(1, 3, 224, 224), custom_ops=None): 44 | handler_collection = [] 45 | custom_ops = {} if custom_ops is None else custom_ops 46 | 47 | def add_hooks(m_): 48 | if len(list(m_.children())) > 0: 49 | return 50 | 51 | m_.register_buffer('total_ops', torch.zeros(1)) 52 | m_.register_buffer('total_params', torch.zeros(1)) 53 | 54 | for p in m_.parameters(): 55 | m_.total_params += torch.Tensor([p.numel()]) 56 | 57 | m_type = type(m_) 58 | fn = None 59 | 60 | if m_type in custom_ops: 61 | fn = custom_ops[m_type] 62 | elif m_type in register_hooks: 63 | fn = register_hooks[m_type] 64 | else: 65 | # print("Not implemented for ", m_) 66 | pass 67 | 68 | if fn is not None: 69 | # print("Register FLOP counter for module %s" % str(m_)) 70 | _handler = m_.register_forward_hook(fn) 71 | handler_collection.append(_handler) 72 | 73 | original_device = model.parameters().__next__().device 74 | training = model.training 75 | 76 | model.eval() 77 | model.apply(add_hooks) 78 | 79 | x = torch.zeros(input_size).to(original_device) 80 | with torch.no_grad(): 81 | model(x) 82 | 83 | total_ops = 0 84 | total_params = 0 85 | for m in model.modules(): 86 | if len(list(m.children())) > 0: # skip for non-leaf module 87 | continue 88 | total_ops += m.total_ops 89 | total_params += m.total_params 90 | 91 | total_ops = total_ops.item() 92 | total_params = total_params.item() 93 | 94 | model.train(training) 95 | model.to(original_device) 96 | 97 | for handler in handler_collection: 98 | handler.remove() 99 | 100 | return total_ops, total_params 101 | 102 | 103 | def count_net_flops_and_params(net, data_shape=(1, 3, 224, 224)): 104 | if isinstance(net, nn.DataParallel): 105 | net = net.module 106 | 107 | net = copy.deepcopy(net) 108 | flop, nparams = profile(net, data_shape) 109 | return flop /1e6, nparams /1e6 110 | 111 | 112 | def init_model(self, model_init="he_fout"): 113 | """ Conv2d, BatchNorm2d, BatchNorm1d, Linear, """ 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | if model_init == 'he_fout': 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. / n)) 119 | elif model_init == 'he_fin': 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | else: 123 | raise NotImplementedError 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 127 | if m.affine: 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | elif isinstance(m, nn.Linear): 131 | stdv = 1. 
/ math.sqrt(m.weight.size(1)) 132 | m.weight.data.uniform_(-stdv, stdv) 133 | if m.bias is not None: 134 | m.bias.data.zero_() -------------------------------------------------------------------------------- /xnas/spaces/DARTS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/DARTS/__init__.py -------------------------------------------------------------------------------- /xnas/spaces/DARTS/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from .ops import * 3 | 4 | 5 | basic_op_list = ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', 'none'] 6 | 7 | 8 | def geno_from_alpha(theta): 9 | Genotype = namedtuple( 10 | 'Genotype', 'normal normal_concat reduce reduce_concat') 11 | theta_norm = darts_weight_unpack( 12 | theta[0:14], 4) 13 | theta_reduce = darts_weight_unpack( 14 | theta[14:], 4) 15 | gene_normal = parse_from_numpy( 16 | theta_norm, k=2, basic_op_list=basic_op_list) 17 | gene_reduce = parse_from_numpy( 18 | theta_reduce, k=2, basic_op_list=basic_op_list) 19 | concat = range(2, 6) # concat all intermediate nodes 20 | return Genotype(normal=gene_normal, normal_concat=concat, 21 | reduce=gene_reduce, reduce_concat=concat) 22 | 23 | def reformat_DARTS(genotype): 24 | """ 25 | format genotype for DARTS-like 26 | from: 27 | Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], [('sep_conv_3x3', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], [('dil_conv_5x5', 4), ('dil_conv_5x5', 3)]], normal_concat=range(2, 6), reduce=[[('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('max_pool_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('dil_conv_5x5', 4), ('max_pool_3x3', 0)]], reduce_concat=range(2, 6)) 28 | to: 29 | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 30 | """ 31 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 32 | _normal = [] 33 | _reduce = [] 34 | for i in genotype.normal: 35 | for j in i: 36 | _normal.append(j) 37 | for i in genotype.reduce: 38 | for j in i: 39 | _reduce.append(j) 40 | _normal_concat = [i for i in genotype.normal_concat] 41 | _reduce_concat = [i for i in genotype.reduce_concat] 42 | r_genotype = Genotype( 43 | normal=_normal, 44 | normal_concat=_normal_concat, 45 | reduce=_reduce, 46 | reduce_concat=_reduce_concat 47 | ) 48 | return r_genotype 49 | -------------------------------------------------------------------------------- /xnas/spaces/DrNAS/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def optimizer_transfer(optimizer_old, optimizer_new): 6 | for i, p in enumerate(optimizer_new.param_groups[0]['params']): 7 | if not hasattr(p, 'raw_id'): 8 | optimizer_new.state[p] = optimizer_old.state[p] 9 | continue 10 | state_old = optimizer_old.state_dict()['state'][p.raw_id] 11 | state_new = optimizer_new.state[p] 12 | 13 
| state_new['momentum_buffer'] = state_old['momentum_buffer'] 14 | if p.t == 'bn': 15 | # BN layer 16 | state_new['momentum_buffer'] = torch.cat( 17 | [state_new['momentum_buffer'], state_new['momentum_buffer'][p.out_index].clone()], dim=0) 18 | # clean to enable multiple calls 19 | del p.t, p.raw_id, p.out_index 20 | 21 | elif p.t == 'conv': 22 | # conv layer 23 | if hasattr(p, 'in_index'): 24 | state_new['momentum_buffer'] = torch.cat( 25 | [state_new['momentum_buffer'], state_new['momentum_buffer'][:, p.in_index, :, :].clone()], dim=1) 26 | if hasattr(p, 'out_index'): 27 | state_new['momentum_buffer'] = torch.cat( 28 | [state_new['momentum_buffer'], state_new['momentum_buffer'][p.out_index, :, :, :].clone()], dim=0) 29 | # clean to enable multiple calls 30 | del p.t, p.raw_id 31 | if hasattr(p, 'in_index'): 32 | del p.in_index 33 | if hasattr(p, 'out_index'): 34 | del p.out_index 35 | print('%d momentum buffers loaded' % (i+1)) 36 | return optimizer_new 37 | 38 | def scheduler_transfer(scheduler_old, scheduler_new): 39 | scheduler_new.load_state_dict(scheduler_old.state_dict()) 40 | print('scheduler loaded') 41 | return scheduler_new 42 | 43 | 44 | def process_step_vector(x, method, mask, tau=None): 45 | if method == "softmax": 46 | output = F.softmax(x, dim=-1) 47 | elif method == "dirichlet": 48 | output = torch.distributions.dirichlet.Dirichlet(F.elu(x) + 1).rsample() 49 | elif method == "gumbel": 50 | output = F.gumbel_softmax(x, tau=tau, hard=False, dim=-1) 51 | 52 | if mask is None: 53 | return output 54 | else: 55 | output_pruned = torch.zeros_like(output) 56 | output_pruned[mask] = output[mask] 57 | output_pruned /= output_pruned.sum() 58 | assert (output_pruned[~mask] == 0.0).all() 59 | return output_pruned 60 | 61 | 62 | def process_step_matrix(x, method, mask, tau=None): 63 | weights = [] 64 | if mask is None: 65 | for line in x: 66 | weights.append(process_step_vector(line, method, None, tau)) 67 | else: 68 | for i, line in enumerate(x): 69 | weights.append(process_step_vector(line, method, mask[i], tau)) 70 | return torch.stack(weights) 71 | 72 | 73 | def prune(x, num_keep, mask, reset=False): 74 | if mask is not None: 75 | x.data[~mask] -= 1000000 76 | src, index = x.topk(k=num_keep, dim=-1) 77 | if not reset: 78 | x.data.copy_(torch.zeros_like(x).scatter(dim=1, index=index, src=src)) 79 | else: 80 | x.data.copy_( 81 | torch.zeros_like(x).scatter( 82 | dim=1, index=index, src=1e-3 * torch.randn_like(src) 83 | ) 84 | ) 85 | mask = torch.zeros_like(x, dtype=torch.bool).scatter( 86 | dim=1, index=index, src=torch.ones_like(src, dtype=torch.bool) 87 | ) 88 | return mask 89 | -------------------------------------------------------------------------------- /xnas/spaces/DropNAS/cnn.py: -------------------------------------------------------------------------------- 1 | from xnas.spaces.DARTS.ops import * 2 | import xnas.spaces.DARTS.genos as gt 3 | 4 | 5 | class Drop_MixedOp(nn.Module): 6 | """ Mixed operation """ 7 | 8 | def __init__(self, C_in, C_out, stride): 9 | super().__init__() 10 | self._ops = nn.ModuleList() 11 | for primitive in gt.PRIMITIVES: 12 | op = OPS[primitive](C_in, C_out, stride, affine=False) 13 | self._ops.append(op) 14 | 15 | def forward(self, x, weights, masks): 16 | """ 17 | Args: 18 | x: input 19 | weights: weight for each operation 20 | masks: list of booleans 21 | """ 22 | return sum(w * op(x) for w, op, mask in zip(weights, self._ops, masks) if mask) 23 | # return sum(w * op(x) for w, op in zip(weights, self._ops)) 24 | 25 | 26 | class
DropNASCell(nn.Module): 27 | """ Cell for search 28 | Each edge is mixed and continuously relaxed. 29 | """ 30 | 31 | def __init__(self, n_nodes, C_pp, C_p, C, reduction_p, reduction): 32 | """ 33 | Args: 34 | n_nodes: # of intermediate nodes 35 | C_pp: C_out[k-2] 36 | C_p : C_out[k-1] 37 | C : C_in[k] (current) 38 | reduction_p: flag for whether the previous cell is reduction cell or not 39 | reduction: flag for whether the current cell is reduction cell or not 40 | """ 41 | super().__init__() 42 | self.reduction = reduction 43 | self.n_nodes = n_nodes 44 | 45 | # If the previous cell is a reduction cell, the current input size does not match the 46 | # output size of cell[k-2], so output[k-2] should be reduced by preprocessing. 47 | if reduction_p: 48 | self.preproc0 = FactorizedReduce(C_pp, C, affine=False) 49 | else: 50 | self.preproc0 = ReluConvBn(C_pp, C, 1, 1, 0, affine=False) 51 | self.preproc1 = ReluConvBn(C_p, C, 1, 1, 0, affine=False) 52 | 53 | # generate dag 54 | self.dag = nn.ModuleList() 55 | for i in range(self.n_nodes): 56 | self.dag.append(nn.ModuleList()) 57 | for j in range(2 + i): # include 2 input nodes 58 | # reduction should be used only for input nodes 59 | stride = 2 if reduction and j < 2 else 1 60 | op = Drop_MixedOp(C, C, stride) 61 | self.dag[i].append(op) 62 | 63 | def forward(self, s0, s1, w_dag, masks): 64 | s0 = self.preproc0(s0) 65 | s1 = self.preproc1(s1) 66 | 67 | states = [s0, s1] 68 | for edges, w_list, m_list in zip(self.dag, w_dag, masks): 69 | s_cur = sum(edges[i](s, w, m) for i, (s, w, m) in enumerate(zip(states, w_list, m_list))) 70 | states.append(s_cur) 71 | 72 | s_out = torch.cat(states[2:], dim=1) 73 | return s_out 74 | 75 | 76 | class DropNASCNN(nn.Module): 77 | """ Search CNN model """ 78 | 79 | def __init__(self, C_in, C, n_classes, n_layers, n_nodes=4, stem_multiplier=3): 80 | """ 81 | Args: 82 | C_in: # of input channels 83 | C: # of starting model channels 84 | n_classes: # of classes 85 | n_layers: # of layers 86 | n_nodes: # of intermediate nodes in Cell 87 | """ 88 | super().__init__() 89 | self.C_in = C_in 90 | self.C = C 91 | self.n_classes = n_classes 92 | self.n_layers = n_layers 93 | 94 | C_cur = stem_multiplier * C 95 | self.stem = nn.Sequential( 96 | nn.Conv2d(C_in, C_cur, 3, 1, 1, bias=False), 97 | nn.BatchNorm2d(C_cur) 98 | ) 99 | 100 | # for the first cell, stem is used for both s0 and s1 101 | # [!] C_pp and C_p are output channel sizes, while C_cur is the input channel size. 102 | C_pp, C_p, C_cur = C_cur, C_cur, C 103 | 104 | self.cells = nn.ModuleList() 105 | reduction_p = False 106 | for i in range(n_layers): 107 | # Reduce feature map size and double channels at 1/3 and 2/3 of the depth.
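# (for example, with n_layers = 8 the check below picks indices 2 and 5, so the
#  reduction cells sit at 1/3 and 2/3 of the depth and C_cur doubles at each)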
108 | if i in [n_layers // 3, 2 * n_layers // 3]: 109 | C_cur *= 2 110 | reduction = True 111 | else: 112 | reduction = False 113 | 114 | cell = DropNASCell( 115 | n_nodes=n_nodes, 116 | C_pp=C_pp, 117 | C_p=C_p, 118 | C=C_cur, 119 | reduction_p=reduction_p, 120 | reduction=reduction 121 | ) 122 | 123 | reduction_p = reduction 124 | self.cells.append(cell) 125 | C_cur_out = C_cur * n_nodes 126 | C_pp, C_p = C_p, C_cur_out 127 | 128 | self.gap = nn.AdaptiveAvgPool2d(1) 129 | self.linear = nn.Linear(C_p, n_classes) 130 | 131 | def forward(self, x, weights_normal, weights_reduce, masks_normal, masks_reduce): 132 | """ 133 | Args: 134 | weights_xxx: probability contribution of each operation 135 | masks_xxx: decide whether to drop an operation 136 | """ 137 | s0 = s1 = self.stem(x) 138 | 139 | for i, cell in enumerate(self.cells): 140 | weights = weights_reduce if cell.reduction else weights_normal 141 | masks = masks_reduce if cell.reduction else masks_normal 142 | s0, s1 = s1, cell(s0, s1, weights, masks) 143 | 144 | out = self.gap(s1) 145 | out = out.view(out.size(0), -1) # flatten 146 | logits = self.linear(out) 147 | return logits 148 | 149 | 150 | # build API 151 | def _DropNASCNN(): 152 | from xnas.core.config import cfg 153 | return DropNASCNN( 154 | C_in=cfg.SEARCH.INPUT_CHANNELS, 155 | C=cfg.SPACE.CHANNELS, 156 | n_classes=cfg.LOADER.NUM_CLASSES, 157 | n_layers=cfg.SPACE.LAYERS, 158 | n_nodes=cfg.SPACE.NODES, 159 | ) 160 | -------------------------------------------------------------------------------- /xnas/spaces/NASBench1Shot1/cnn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/NASBench1Shot1/cnn.py -------------------------------------------------------------------------------- /xnas/spaces/NASBench1Shot1/ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/NASBench1Shot1/ops.py -------------------------------------------------------------------------------- /xnas/spaces/NASBench201/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from .genos import Structure as CellStructure 4 | from .cnn import TinyNetwork 5 | 6 | 7 | def dict2config(xdict, logger): 8 | assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) 9 | Arguments = namedtuple("Configure", " ".join(xdict.keys())) 10 | content = Arguments(**xdict) 11 | if hasattr(logger, "log"): 12 | logger.log("{:}".format(content)) 13 | return content 14 | 15 | def get_cell_based_tiny_net(config): 16 | if hasattr(config, "genotype"): 17 | genotype = config.genotype 18 | elif hasattr(config, "arch_str"): 19 | genotype = CellStructure.str2structure(config.arch_str) 20 | else: 21 | raise ValueError( 22 | "Can not find genotype from this config : {:}".format(config) 23 | ) 24 | return TinyNetwork(config.C, config.N, genotype, config.num_classes) 25 | -------------------------------------------------------------------------------- /xnas/spaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/__init__.py --------------------------------------------------------------------------------
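A closing sketch of how the NASBench201 helpers above fit together (hedged: the arch_str below is one syntactically valid NB201 architecture string, and the C/N/num_classes values are illustrative, not prescribed by the repo):

from xnas.spaces.NASBench201.utils import dict2config, get_cell_based_tiny_net

arch_config = dict2config(
    {
        "C": 16,           # initial channel width
        "N": 5,            # number of cells per stage
        "arch_str": "|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|avg_pool_3x3~2|",
        "num_classes": 10,
    },
    logger=None,           # dict2config only logs when the logger has a .log method
)
net = get_cell_based_tiny_net(arch_config)   # parses arch_str and builds a TinyNetwork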