├── .gitignore
├── .readthedocs.yml
├── LICENSE
├── README.md
├── configs
│   ├── search
│   │   ├── AttentiveNAS
│   │   │   ├── eval.yaml
│   │   │   └── train.yaml
│   │   ├── BigNAS
│   │   │   ├── eval.yaml
│   │   │   ├── search.yaml
│   │   │   └── train.yaml
│   │   ├── DrNAS
│   │   │   ├── drnas_darts_cifar10.yaml
│   │   │   ├── drnas_darts_imagenet.yaml
│   │   │   ├── drnas_nasbench201_cifar10.yaml
│   │   │   ├── drnas_nasbench201_cifar100.yaml
│   │   │   └── drnas_nasbench201_cifar10_progressive.yaml
│   │   ├── OFA
│   │   │   └── mbv3
│   │   │       ├── depth_1.yaml
│   │   │       ├── depth_2.yaml
│   │   │       ├── expand_1.yaml
│   │   │       ├── expand_2.yaml
│   │   │       ├── kernel_1.yaml
│   │   │       └── normal_1.yaml
│   │   ├── RMINAS
│   │   │   ├── rminas_darts_cifar10.yaml
│   │   │   ├── rminas_darts_cifar100.yaml
│   │   │   ├── rminas_nasbench201_cifar10.yaml
│   │   │   ├── rminas_nasbench201_cifar100.yaml
│   │   │   ├── rminas_nasbench201_imagenet16.yaml
│   │   │   ├── rminas_nasbenchmacro_cifar10.yaml
│   │   │   └── rminas_proxyless_imagenet.yaml
│   │   ├── SNG
│   │   │   ├── ASNG_darts_cifar10.yaml
│   │   │   ├── DDPNAS_darts_cifar10.yaml
│   │   │   ├── DDPNAS_nasbench201_cifar10.yaml
│   │   │   ├── DynamicASNG_darts_cifar10.yaml
│   │   │   ├── DynamicSNG_darts_cifar10.yaml
│   │   │   ├── MIGO_darts_cifar10.yaml
│   │   │   └── SNG_darts_cifar10.yaml
│   │   ├── darts_darts_cifar10.yaml
│   │   ├── darts_nasbench201_cifar10.yaml
│   │   ├── dropnas_darts_cifar10.yaml
│   │   ├── gdas_nasbench201_cifar10.yaml
│   │   ├── nasbenchmacro_nasbenchmacro_cifar10.yaml
│   │   ├── pcdarts_pcdarts_cifar10.yaml
│   │   ├── pdarts_pdarts_cifar10.yaml
│   │   ├── snas_nasbench201_cifar10.yaml
│   │   ├── spos_nasbench201_cifar10.yaml
│   │   └── spos_spos_cifar10.yaml
│   └── train
│       ├── darts_darts_cifar10.yaml
│       └── spos_spos_cifar10.yaml
├── docs
│   ├── conf.py
│   ├── data_preparation.md
│   ├── get_started.md
│   ├── index.rst
│   ├── notes.md
│   └── requirements.txt
├── examples
│   ├── search
│   │   ├── DrNAS
│   │   │   ├── drnas_darts_cifar10.sh
│   │   │   ├── drnas_darts_imagenet.sh
│   │   │   ├── drnas_nasbench201_cifar10.sh
│   │   │   ├── drnas_nasbench201_cifar100.sh
│   │   │   └── drnas_nasbench201_cifar10_progressive.sh
│   │   ├── OFA
│   │   │   └── train_supernet.sh
│   │   ├── RMINAS
│   │   │   ├── rminas_darts_cifar10.sh
│   │   │   ├── rminas_nasbench201_cifar10.sh
│   │   │   ├── rminas_nasbench201_cifar100.sh
│   │   │   └── rminas_nasbench201_imagenet16.sh
│   │   ├── darts_darts_cifar10.sh
│   │   ├── darts_nasbench201_cifar10.sh
│   │   ├── dropnas_darts_cifar10.sh
│   │   ├── gdas_nasbench201_cifar10.sh
│   │   ├── ofa_mbv3_imagenet.sh
│   │   ├── pcdarts_pcdarts_cifar10.sh
│   │   ├── pdarts_pdarts_cifar10.sh
│   │   └── snas_nasbench201_cifar10.sh
│   └── train
│       └── darts_darts_cifar10.sh
├── requirements.txt
├── scripts
│   ├── search
│   │   ├── AttentiveNAS
│   │   │   └── train_supernet.py
│   │   ├── BigNAS
│   │   │   ├── eval.py
│   │   │   ├── search.py
│   │   │   └── train_supernet.py
│   │   ├── DARTS.py
│   │   ├── DrNAS.py
│   │   ├── DropNAS.py
│   │   ├── NBMacro.py
│   │   ├── OFA
│   │   │   ├── eval_supernet.py
│   │   │   ├── search.py
│   │   │   ├── train_supernet.py
│   │   │   └── train_teacher_model.py
│   │   ├── PCDARTS.py
│   │   ├── PDARTS.py
│   │   ├── RMINAS.py
│   │   ├── SNG
│   │   │   ├── func_optimize.py
│   │   │   ├── nb1shot1.py
│   │   │   └── search.py
│   │   └── SPOS.py
│   └── train
│       └── DARTS.py
├── tests
│   ├── MultiSizeRandomCrop.py
│   ├── adjust_lr_per_iter.py
│   ├── configs
│   │   ├── datasets.yaml
│   │   └── imagenet.yaml
│   ├── dataloader.py
│   ├── imagenet.py
│   ├── nb101space.py
│   ├── nb1shot1space.py
│   ├── nb201eval.py
│   ├── nb301eval.py
│   ├── ofa_matrices_test.py
│   └── test_REA.py
└── xnas
    ├── algorithms
    │   ├── AttentiveNAS
    │   │   └── sampler.py
    │   ├── DARTS.py
    │   ├── DrNAS.py
    │   ├── DropNAS.py
    │   ├── PCDARTS.py
    │   ├── PDARTS.py
    │   ├── RMINAS
    │   │   ├── README.md
    │   │   ├── download_weight.sh
    │   │   ├── sampler
    │   │   │   ├── RF_sampling.py
    │   │   │   ├── available_archs.txt
    │   │   │   └── sampling.py
    │   │   ├── teacher_model
    │   │   │   ├── fbresnet_imagenet
    │   │   │   │   └── fbresnet.py
    │   │   │   ├── resnet101_cifar100
    │   │   │   │   └── resnet.py
    │   │   │   └── resnet20_cifar10
    │   │   │       └── resnet.py
    │   │   └── utils
    │   │       ├── RMI_torch.py
    │   │       ├── get_accuracy.ipynb
    │   │       └── random_data.py
    │   ├── SNG
    │   │   ├── ASNG.py
    │   │   ├── DDPNAS.py
    │   │   ├── GridSearch.py
    │   │   ├── MDENAS.py
    │   │   ├── MIGO.py
    │   │   ├── RAND.py
    │   │   ├── SNG.py
    │   │   └── categorical.py
    │   └── SPOS.py
    ├── core
    │   ├── builder.py
    │   ├── config.py
    │   └── utils.py
    ├── datasets
    │   ├── __init__.py
    │   ├── auto_augment_tf.py
    │   ├── imagenet.py
    │   ├── imagenet16.py
    │   ├── loader.py
    │   ├── transforms.py
    │   ├── transforms_imagenet.py
    │   └── utils
    │       ├── msrc_loader.py
    │       └── msrc_worker.py
    ├── evaluations
    │   ├── NASBench201.py
    │   ├── NASBench301.py
    │   └── NASBenchMacro
    │       ├── evaluate.py
    │       └── nas-bench-macro_cifar10.json
    ├── logger
    │   ├── checkpoint.py
    │   ├── logging.py
    │   ├── meter.py
    │   └── timer.py
    ├── runner
    │   ├── __init__.py
    │   ├── criterion.py
    │   ├── optimizer.py
    │   ├── scheduler.py
    │   ├── trainer.py
    │   └── trainer_spos.py
    └── spaces
        ├── AttentiveNAS
        │   └── cnn.py
        ├── BigNAS
        │   ├── cnn.py
        │   ├── dynamic_layers.py
        │   ├── dynamic_ops.py
        │   ├── ops.py
        │   └── utils.py
        ├── DARTS
        │   ├── __init__.py
        │   ├── cnn.py
        │   ├── genos.py
        │   ├── ops.py
        │   └── utils.py
        ├── DrNAS
        │   ├── darts_cnn.py
        │   ├── nb201_cnn.py
        │   └── utils.py
        ├── DropNAS
        │   └── cnn.py
        ├── NASBench1Shot1
        │   ├── cnn.py
        │   └── ops.py
        ├── NASBench201
        │   ├── cnn.py
        │   ├── genos.py
        │   ├── ops.py
        │   └── utils.py
        ├── NASBenchMacro
        │   └── cnn.py
        ├── OFA
        │   ├── MobileNetV3
        │   │   ├── cnn.py
        │   │   └── ofa_cnn.py
        │   ├── ProxylessNet
        │   │   ├── cnn.py
        │   │   └── ofa_cnn.py
        │   ├── ResNets
        │   │   ├── cnn.py
        │   │   └── ofa_cnn.py
        │   ├── dynamic_ops.py
        │   ├── ops.py
        │   └── utils.py
        ├── PCDARTS
        │   └── cnn.py
        ├── PDARTS
        │   └── cnn.py
        ├── ProxylessNAS
        │   └── cnn.py
        ├── SPOS
        │   └── cnn.py
        └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | data/
3 | experiment/
4 | exp/
5 | __pycache__/
6 | tmp/
7 | .idea/
8 | .vscode/
9 | *.py[cod]
10 | *.pdf
11 | *$py.class
12 | *.DS_Store
13 | *hyper.md
14 | *debug.py
15 | benchmark/
16 | test.py
17 | test.ipynb
18 | nohup*.out
19 |
20 | # model weights
21 | *.pth
22 | *.th
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | # lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 | docs/build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # celery beat schedule file
98 | celerybeat-schedule
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 |
125 | #exp_out
126 | *.pkl
127 | *.png
128 |
129 | 301model/
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | sphinx:
4 | configuration: docs/conf.py
5 |
6 | python:
7 | version: 3.7
8 | install:
9 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 SherwoodZheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/search/AttentiveNAS/eval.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | RNG_SEED: 2
3 | SPACE:
4 | NAME: 'attentivenas'
5 | LOADER:
6 | DATASET: 'imagenet'
7 | NUM_CLASSES: 1000
8 | BATCH_SIZE: 256
9 | NUM_WORKERS: 4
10 | USE_VAL: True
11 | TRANSFORM: "auto_augment_tf"
12 | SEARCH:
13 | IM_SIZE: 224
14 | ATTENTIVENAS:
15 | BN_MOMENTUM: 0.1
16 | BN_EPS: 1.e-5
17 | POST_BN_CALIBRATION_BATCH_NUM: 64
18 | ACTIVE_SUBNET: # chosen from following settings
19 | # attentive_nas_a0
20 | RESOLUTION: 192
21 | WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
22 | KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
23 | EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
24 | DEPTH: [1, 3, 3, 3, 3, 3, 1]
25 |
26 | # # attentive_nas_a1
27 | # RESOLUTION: 224
28 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
29 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
30 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
31 | # DEPTH: [1, 3, 3, 3, 3, 3, 1]
32 |
33 | # # attentive_nas_a2
34 | # RESOLUTION: 224
35 | # WIDTH: [16, 16, 24, 32, 64, 112, 200, 224, 1984]
36 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
37 | # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
38 | # DEPTH: [1, 3, 3, 3, 3, 4, 1]
39 |
40 | # # attentive_nas_a3
41 | # RESOLUTION: 224
42 | # WIDTH: [16, 16, 24, 32, 64, 112, 208, 224, 1984]
43 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
44 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
45 | # DEPTH: [2, 3, 3, 4, 3, 5, 1]
46 |
47 | # # attentive_nas_a4
48 | # RESOLUTION: 256
49 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
50 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
51 | # EXPAND_RATIO: [1, 4, 4, 5, 4, 6, 6]
52 | # DEPTH: [1, 3, 3, 4, 3, 5, 1]
53 |
54 | # # attentive_nas_a5
55 | # RESOLUTION: 256
56 | # WIDTH: [16, 16, 24, 32, 72, 112, 192, 216, 1792]
57 | # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
58 | # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
59 | # DEPTH: [1, 3, 3, 3, 4, 6, 1]
60 |
61 | # # attentive_nas_a6
62 | # RESOLUTION: 288
63 | # WIDTH: [16, 16, 24, 32, 64, 112, 216, 224, 1984]
64 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
65 | # EXPAND_RATIO: [1, 4, 6, 5, 4, 6, 6]
66 | # DEPTH: [1, 3, 3, 4, 4, 6, 1]
67 | SUPERNET_CFG:
68 | use_v3_head: True
69 | resolutions: [192, 224, 256, 288]
70 | first_conv:
71 | c: [16, 24]
72 | act_func: 'swish'
73 | s: 2
74 | mb1:
75 | c: [16, 24]
76 | d: [1, 2]
77 | k: [3, 5]
78 | t: [1]
79 | s: 1
80 | act_func: 'swish'
81 | se: False
82 | mb2:
83 | c: [24, 32]
84 | d: [3, 4, 5]
85 | k: [3, 5]
86 | t: [4, 5, 6]
87 | s: 2
88 | act_func: 'swish'
89 | se: False
90 | mb3:
91 | c: [32, 40]
92 | d: [3, 4, 5, 6]
93 | k: [3, 5]
94 | t: [4, 5, 6]
95 | s: 2
96 | act_func: 'swish'
97 | se: True
98 | mb4:
99 | c: [64, 72]
100 | d: [3, 4, 5, 6]
101 | k: [3, 5]
102 | t: [4, 5, 6]
103 | s: 2
104 | act_func: 'swish'
105 | se: False
106 | mb5:
107 | c: [112, 120, 128]
108 | d: [3, 4, 5, 6, 7, 8]
109 | k: [3, 5]
110 | t: [4, 5, 6]
111 | s: 1
112 | act_func: 'swish'
113 | se: True
114 | mb6:
115 | c: [192, 200, 208, 216]
116 | d: [3, 4, 5, 6, 7, 8]
117 | k: [3, 5]
118 | t: [6]
119 | s: 2
120 | act_func: 'swish'
121 | se: True
122 | mb7:
123 | c: [216, 224]
124 | d: [1, 2]
125 | k: [3, 5]
126 | t: [6]
127 | s: 1
128 | act_func: 'swish'
129 | se: True
130 | last_conv:
131 | c: [1792, 1984]
132 | act_func: 'swish'
133 |
--------------------------------------------------------------------------------
/configs/search/AttentiveNAS/train.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | RNG_SEED: 0
3 | SPACE:
4 | NAME: 'attentivenas'
5 | LOADER:
6 | DATASET: 'imagenet'
7 | NUM_CLASSES: 1000
8 | BATCH_SIZE: 64 # 32*8 in total
9 | NUM_WORKERS: 4
10 | USE_VAL: True
11 | TRANSFORM: "auto_augment_tf"
12 | OPTIM:
13 | GRAD_CLIP: 1.
14 | WARMUP_EPOCH: 5
15 | MAX_EPOCH: 360
16 | WEIGHT_DECAY: 1.e-5
17 | BASE_LR: 0.2
18 | NESTEROV: True
19 | SEARCH:
20 | LOSS_FUN: "cross_entropy_smooth"
21 | LABEL_SMOOTH: 0.1
22 | TRAIN:
23 | DROP_PATH_PROB: 0.2
24 | ATTENTIVENAS:
25 | SANDWICH_NUM: 4 # max + 2*middle + min
26 | DROP_CONNECT: 0.2
27 | BN_MOMENTUM: 0.
28 | BN_EPS: 1.e-5
29 | POST_BN_CALIBRATION_BATCH_NUM: 64
30 | SAMPLER:
31 | METHOD: 'bestup'
32 | MAP_PATH: 'xnas/algorithms/AttentiveNAS/flops_archs_off_table.map'
33 | DISCRETIZE_STEP: 25
34 | NUM_TRIALS: 3
35 | SUPERNET_CFG:
36 | use_v3_head: True
37 | resolutions: [192, 224, 256, 288]
38 | first_conv:
39 | c: [16, 24]
40 | act_func: 'swish'
41 | s: 2
42 | mb1:
43 | c: [16, 24]
44 | d: [1, 2]
45 | k: [3, 5]
46 | t: [1]
47 | s: 1
48 | act_func: 'swish'
49 | se: False
50 | mb2:
51 | c: [24, 32]
52 | d: [3, 4, 5]
53 | k: [3, 5]
54 | t: [4, 5, 6]
55 | s: 2
56 | act_func: 'swish'
57 | se: False
58 | mb3:
59 | c: [32, 40]
60 | d: [3, 4, 5, 6]
61 | k: [3, 5]
62 | t: [4, 5, 6]
63 | s: 2
64 | act_func: 'swish'
65 | se: True
66 | mb4:
67 | c: [64, 72]
68 | d: [3, 4, 5, 6]
69 | k: [3, 5]
70 | t: [4, 5, 6]
71 | s: 2
72 | act_func: 'swish'
73 | se: False
74 | mb5:
75 | c: [112, 120, 128]
76 | d: [3, 4, 5, 6, 7, 8]
77 | k: [3, 5]
78 | t: [4, 5, 6]
79 | s: 1
80 | act_func: 'swish'
81 | se: True
82 | mb6:
83 | c: [192, 200, 208, 216]
84 | d: [3, 4, 5, 6, 7, 8]
85 | k: [3, 5]
86 | t: [6]
87 | s: 2
88 | act_func: 'swish'
89 | se: True
90 | mb7:
91 | c: [216, 224]
92 | d: [1, 2]
93 | k: [3, 5]
94 | t: [6]
95 | s: 1
96 | act_func: 'swish'
97 | se: True
98 | last_conv:
99 | c: [1792, 1984]
100 | act_func: 'swish'
--------------------------------------------------------------------------------
/configs/search/BigNAS/eval.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | RNG_SEED: 2
3 | SPACE:
4 | NAME: 'infer_bignas'
5 | LOADER:
6 | DATASET: 'imagenet'
7 | NUM_CLASSES: 1000
8 | BATCH_SIZE: 128
9 | NUM_WORKERS: 4
10 | USE_VAL: True
11 | TRANSFORM: "auto_augment_tf"
12 | SEARCH:
13 | IM_SIZE: 224
14 | BIGNAS:
15 | BN_MOMENTUM: 0.1
16 | BN_EPS: 1.e-5
17 | POST_BN_CALIBRATION_BATCH_NUM: 64
18 | ACTIVE_SUBNET: # subnet for evaluation
19 | RESOLUTION: 192
20 | WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
21 | KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
22 | EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
23 | DEPTH: [1, 3, 3, 3, 3, 3, 1]
24 | SUPERNET_CFG:
25 | use_v3_head: True
26 | resolutions: [192, 224, 256, 288]
27 | first_conv:
28 | c: [16, 24]
29 | act_func: 'swish'
30 | s: 2
31 | mb1:
32 | c: [16, 24]
33 | d: [1, 2]
34 | k: [3, 5]
35 | t: [1]
36 | s: 1
37 | act_func: 'swish'
38 | se: False
39 | mb2:
40 | c: [24, 32]
41 | d: [3, 4, 5]
42 | k: [3, 5]
43 | t: [4, 5, 6]
44 | s: 2
45 | act_func: 'swish'
46 | se: False
47 | mb3:
48 | c: [32, 40]
49 | d: [3, 4, 5, 6]
50 | k: [3, 5]
51 | t: [4, 5, 6]
52 | s: 2
53 | act_func: 'swish'
54 | se: True
55 | mb4:
56 | c: [64, 72]
57 | d: [3, 4, 5, 6]
58 | k: [3, 5]
59 | t: [4, 5, 6]
60 | s: 2
61 | act_func: 'swish'
62 | se: False
63 | mb5:
64 | c: [112, 120, 128]
65 | d: [3, 4, 5, 6, 7, 8]
66 | k: [3, 5]
67 | t: [4, 5, 6]
68 | s: 1
69 | act_func: 'swish'
70 | se: True
71 | mb6:
72 | c: [192, 200, 208, 216]
73 | d: [3, 4, 5, 6, 7, 8]
74 | k: [3, 5]
75 | t: [6]
76 | s: 2
77 | act_func: 'swish'
78 | se: True
79 | mb7:
80 | c: [216, 224]
81 | d: [1, 2]
82 | k: [3, 5]
83 | t: [6]
84 | s: 1
85 | act_func: 'swish'
86 | se: True
87 | last_conv:
88 | c: [1792, 1984]
89 | act_func: 'swish'
90 |
--------------------------------------------------------------------------------
/configs/search/BigNAS/search.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | RNG_SEED: 2
3 | SPACE:
4 | NAME: 'bignas'
5 | LOADER:
6 | DATASET: 'imagenet'
7 | NUM_CLASSES: 1000
8 | BATCH_SIZE: 128
9 | NUM_WORKERS: 8
10 | USE_VAL: True
11 | TRANSFORM: "auto_augment_tf"
12 | SEARCH:
13 | IM_SIZE: 224
14 | WEIGHTS: "exp/search/test/checkpoints/best_model_epoch_0009.pyth"
15 | BIGNAS:
16 | CONSTRAINT_FLOPS: 6.e+8 # 600M
17 | NUM_MUTATE: 200
18 | BN_MOMENTUM: 0.1
19 | BN_EPS: 1.e-5
20 | POST_BN_CALIBRATION_BATCH_NUM: 64
21 | # ACTIVE_SUBNET: # subnet for evaluation
22 | # RESOLUTION: 192
23 | # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
24 | # KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
25 | # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
26 | # DEPTH: [1, 3, 3, 3, 3, 3, 1]
27 | SEARCH_CFG_SETS:
28 | resolutions: [224, 256]
29 | first_conv:
30 | c: [16]
31 | mb1:
32 | c: [16]
33 | d: [2]
34 | k: [3]
35 | t: [1]
36 | mb2:
37 | c: [24]
38 | d: [3]
39 | k: [3]
40 | t: [5]
41 | mb3:
42 | c: [32]
43 | d: [4]
44 | k: [3]
45 | t: [5]
46 | mb4:
47 | c: [64]
48 | d: [5]
49 | k: [3]
50 | t: [5]
51 | mb5:
52 | c: [120]
53 | d: [6]
54 | k: [3]
55 | t: [5]
56 | mb6:
57 | c: [192]
58 | d: [6]
59 | k: [3, 5]
60 | t: [6]
61 | mb7:
62 | c: [216]
63 | d: [2]
64 | k: [3]
65 | t: [6]
66 | last_conv:
67 | c: [1792]
68 | SUPERNET_CFG:
69 | use_v3_head: True
70 | resolutions: [192, 224, 256, 288]
71 | first_conv:
72 | c: [16, 24]
73 | act_func: 'swish'
74 | s: 2
75 | mb1:
76 | c: [16, 24]
77 | d: [1, 2]
78 | k: [3, 5]
79 | t: [1]
80 | s: 1
81 | act_func: 'swish'
82 | se: False
83 | mb2:
84 | c: [24, 32]
85 | d: [3, 4, 5]
86 | k: [3, 5]
87 | t: [4, 5, 6]
88 | s: 2
89 | act_func: 'swish'
90 | se: False
91 | mb3:
92 | c: [32, 40]
93 | d: [3, 4, 5, 6]
94 | k: [3, 5]
95 | t: [4, 5, 6]
96 | s: 2
97 | act_func: 'swish'
98 | se: True
99 | mb4:
100 | c: [64, 72]
101 | d: [3, 4, 5, 6]
102 | k: [3, 5]
103 | t: [4, 5, 6]
104 | s: 2
105 | act_func: 'swish'
106 | se: False
107 | mb5:
108 | c: [112, 120, 128]
109 | d: [3, 4, 5, 6, 7, 8]
110 | k: [3, 5]
111 | t: [4, 5, 6]
112 | s: 1
113 | act_func: 'swish'
114 | se: True
115 | mb6:
116 | c: [192, 200, 208, 216]
117 | d: [3, 4, 5, 6, 7, 8]
118 | k: [3, 5]
119 | t: [6]
120 | s: 2
121 | act_func: 'swish'
122 | se: True
123 | mb7:
124 | c: [216, 224]
125 | d: [1, 2]
126 | k: [3, 5]
127 | t: [6]
128 | s: 1
129 | act_func: 'swish'
130 | se: True
131 | last_conv:
132 | c: [1792, 1984]
133 | act_func: 'swish'
134 |
--------------------------------------------------------------------------------
/configs/search/BigNAS/train.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | RNG_SEED: 0
3 | SPACE:
4 | NAME: 'bignas'
5 | LOADER:
6 | DATASET: 'imagenet'
7 | NUM_CLASSES: 1000
8 | BATCH_SIZE: 48
9 | NUM_WORKERS: 4
10 | USE_VAL: True
11 | TRANSFORM: "auto_augment_tf"
12 | OPTIM:
13 | GRAD_CLIP: 1.
14 | WARMUP_EPOCH: 5
15 | MAX_EPOCH: 360
16 | WEIGHT_DECAY: 1.e-5
17 | BASE_LR: 0.15
18 | NESTEROV: True
19 | SEARCH:
20 | LOSS_FUN: "cross_entropy_smooth"
21 | LABEL_SMOOTH: 0.1
22 | TRAIN:
23 | DROP_PATH_PROB: 0.2
24 | BIGNAS:
25 | SANDWICH_NUM: 4 # max + 2*middle + min
26 | DROP_CONNECT: 0.2
27 | BN_MOMENTUM: 0.
28 | BN_EPS: 1.e-5
29 | POST_BN_CALIBRATION_BATCH_NUM: 64
30 | SUPERNET_CFG:
31 | use_v3_head: True
32 | resolutions: [192, 224, 256, 288]
33 | first_conv:
34 | c: [16, 24]
35 | act_func: 'swish'
36 | s: 2
37 | mb1:
38 | c: [16, 24]
39 | d: [1, 2]
40 | k: [3, 5]
41 | t: [1]
42 | s: 1
43 | act_func: 'swish'
44 | se: False
45 | mb2:
46 | c: [24, 32]
47 | d: [3, 4, 5]
48 | k: [3, 5]
49 | t: [4, 5, 6]
50 | s: 2
51 | act_func: 'swish'
52 | se: False
53 | mb3:
54 | c: [32, 40]
55 | d: [3, 4, 5, 6]
56 | k: [3, 5]
57 | t: [4, 5, 6]
58 | s: 2
59 | act_func: 'swish'
60 | se: True
61 | mb4:
62 | c: [64, 72]
63 | d: [3, 4, 5, 6]
64 | k: [3, 5]
65 | t: [4, 5, 6]
66 | s: 2
67 | act_func: 'swish'
68 | se: False
69 | mb5:
70 | c: [112, 120, 128]
71 | d: [3, 4, 5, 6, 7, 8]
72 | k: [3, 5]
73 | t: [4, 5, 6]
74 | s: 1
75 | act_func: 'swish'
76 | se: True
77 | mb6:
78 | c: [192, 200, 208, 216]
79 | d: [3, 4, 5, 6, 7, 8]
80 | k: [3, 5]
81 | t: [6]
82 | s: 2
83 | act_func: 'swish'
84 | se: True
85 | mb7:
86 | c: [216, 224]
87 | d: [1, 2]
88 | k: [3, 5]
89 | t: [6]
90 | s: 1
91 | act_func: 'swish'
92 | se: True
93 | last_conv:
94 | c: [1792, 1984]
95 | act_func: 'swish'
--------------------------------------------------------------------------------
/configs/search/DrNAS/drnas_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_darts'
3 | CHANNELS: 36
4 | LAYERS: 20
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | MIN_LR: 0.0
15 | WEIGHT_DECAY: 3.e-4
16 | LR_POLICY: 'cos'
17 | DARTS:
18 | UNROLLED: True
19 | ALPHA_LR: 6.e-4
20 | ALPHA_WEIGHT_DECAY: 1.e-3
21 | DRNAS:
22 | K: 6
23 | REG_TYPE: "l2"
24 | REG_SCALE: 1.e-3
25 | METHOD: 'dirichlet'
26 | TAU: [1, 10]
27 | OUT_DIR: 'exp/drnas'
28 |
--------------------------------------------------------------------------------
/configs/search/DrNAS/drnas_darts_imagenet.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_darts'
3 | CHANNELS: 48
4 | LAYERS: 14
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'imagenet'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 512
10 | NUM_WORKERS: 16
11 | NUM_CLASSES: 1000
12 | SEARCH:
13 | IM_SIZE: 32
14 | OPTIM:
15 | BASE_LR: 0.5
16 | MIN_LR: 0.0
17 | WEIGHT_DECAY: 3.e-4
18 | LR_POLICY: 'cos'
19 | WARMUP_EPOCH: 5
20 | DARTS:
21 | UNROLLED: False
22 | ALPHA_LR: 6.e-3
23 | ALPHA_WEIGHT_DECAY: 1.e-3
24 | DRNAS:
25 | K: 6
26 | REG_TYPE: "l2"
27 | REG_SCALE: 1.e-3
28 | METHOD: 'dirichlet'
29 | TAU: [1, 10]
30 | OUT_DIR: 'exp/drnas'
--------------------------------------------------------------------------------
/configs/search/DrNAS/drnas_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.025
15 | MIN_LR: 0.001
16 | WEIGHT_DECAY: 3.e-4
17 | LR_POLICY: 'cos'
18 | MAX_EPOCH: 100
19 | DARTS:
20 | UNROLLED: True
21 | ALPHA_LR: 3.e-4
22 | ALPHA_WEIGHT_DECAY: 1.e-3
23 | DRNAS:
24 | K: 1
25 | REG_TYPE: "l2"
26 | REG_SCALE: 1.e-3
27 | METHOD: 'dirichlet'
28 | TAU: [1, 10]
29 | PROGRESSIVE: False
30 | OUT_DIR: 'exp/drnas'
--------------------------------------------------------------------------------
/configs/search/DrNAS/drnas_nasbench201_cifar100.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar100'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 100
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.025
15 | MIN_LR: 0.001
16 | WEIGHT_DECAY: 3.e-4
17 | LR_POLICY: 'cos'
18 | MAX_EPOCH: 100
19 | DARTS:
20 | UNROLLED: True
21 | ALPHA_LR: 3.e-4
22 | ALPHA_WEIGHT_DECAY: 1.e-3
23 | DRNAS:
24 | K: 1
25 | REG_TYPE: "l2"
26 | REG_SCALE: 1.e-3
27 | METHOD: 'dirichlet'
28 | TAU: [1, 10]
29 | PROGRESSIVE: False
30 | OUT_DIR: 'exp/drnas'
--------------------------------------------------------------------------------
/configs/search/DrNAS/drnas_nasbench201_cifar10_progressive.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.025
15 | MIN_LR: 0.001
16 | WEIGHT_DECAY: 3.e-4
17 | LR_POLICY: 'cos'
18 | MAX_EPOCH: 100
19 | DARTS:
20 | UNROLLED: True
21 | ALPHA_LR: 3.e-4
22 | ALPHA_WEIGHT_DECAY: 1.e-3
23 | DRNAS:
24 | K: 4
25 | REG_TYPE: "l2"
26 | REG_SCALE: 1.e-3
27 | METHOD: 'dirichlet'
28 | TAU: [1, 10]
29 | PROGRESSIVE: True
30 | OUT_DIR: 'exp/drnas'
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/depth_1.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | SPACE:
3 | NAME: 'ofa_mbv3'
4 | LOADER:
5 | DATASET: 'imagenet'
6 | NUM_CLASSES: 1000
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 | USE_VAL: True
10 | SEARCH:
11 | MULTI_SIZES: [128,160,192,224]
12 | LOSS_FUN: 'cross_entropy_smooth'
13 | LABEL_SMOOTH: 0.1
14 | AUTO_RESUME: True
15 | OFA:
16 | TASK: 'depth'
17 | PHASE: 1
18 | WIDTH_MULTI_LIST: [1.0]
19 | KS_LIST: [3,5,7]
20 | EXPAND_LIST: [6]
21 | DEPTH_LIST: [3,4]
22 | CHANNEL_DIVISIBLE: 8
23 | SUBNET_BATCH_SIZE: 2
24 | KD_RATIO: 0.
25 | # KD_PATH: "exp/OFA/teacher_model.pyth"
26 | OPTIM:
27 | MAX_EPOCH: 25
28 | BASE_LR: 5.e-3
29 | MIN_LR: 0.0
30 | WARMUP_EPOCH: 0
31 | WARMUP_FACTOR: 1.
32 | LR_POLICY: 'cos'
33 | MOMENTUM: 0.9
34 | WEIGHT_DECAY: 3.e-5
35 | NESTEROV: True
36 | TEST:
37 | BATCH_SIZE: 256
38 | IM_SIZE: 224
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/depth_2.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | SPACE:
3 | NAME: 'ofa_mbv3'
4 | LOADER:
5 | DATASET: 'imagenet'
6 | NUM_CLASSES: 1000
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 | USE_VAL: True
10 | SEARCH:
11 | MULTI_SIZES: [128,160,192,224]
12 | LOSS_FUN: 'cross_entropy_smooth'
13 | LABEL_SMOOTH: 0.1
14 | AUTO_RESUME: True
15 | OFA:
16 | TASK: 'depth'
17 | PHASE: 2
18 | WIDTH_MULTI_LIST: [1.0]
19 | KS_LIST: [3,5,7]
20 | EXPAND_LIST: [6]
21 | DEPTH_LIST: [2,3,4]
22 | CHANNEL_DIVISIBLE: 8
23 | SUBNET_BATCH_SIZE: 2
24 | KD_RATIO: 0.
25 | # KD_PATH: "exp/OFA/teacher_model.pyth"
26 | OPTIM:
27 | MAX_EPOCH: 120
28 | BASE_LR: 1.5e-2
29 | MIN_LR: 1.e-3
30 | WARMUP_EPOCH: 5
31 | WARMUP_FACTOR: 1.
32 | LR_POLICY: 'cos'
33 | MOMENTUM: 0.9
34 | WEIGHT_DECAY: 3.e-5
35 | NESTEROV: True
36 | TEST:
37 | BATCH_SIZE: 256
38 | IM_SIZE: 224
39 |
40 |
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/expand_1.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | SPACE:
3 | NAME: 'ofa_mbv3'
4 | LOADER:
5 | DATASET: 'imagenet'
6 | NUM_CLASSES: 1000
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 | USE_VAL: True
10 | SEARCH:
11 | MULTI_SIZES: [128,160,192,224]
12 | LOSS_FUN: 'cross_entropy_smooth'
13 | LABEL_SMOOTH: 0.1
14 | AUTO_RESUME: True
15 | OFA:
16 | TASK: 'expand'
17 | PHASE: 1
18 | WIDTH_MULTI_LIST: [1.0]
19 | KS_LIST: [3,5,7]
20 | EXPAND_LIST: [4,6]
21 | DEPTH_LIST: [2,3,4]
22 | CHANNEL_DIVISIBLE: 8
23 | SUBNET_BATCH_SIZE: 4
24 | KD_RATIO: 0.
25 | # KD_PATH: "exp/OFA/teacher_model.pyth"
26 | OPTIM:
27 | MAX_EPOCH: 25
28 | BASE_LR: 5.e-3
29 | MIN_LR: 0.0
30 | WARMUP_EPOCH: 0
31 | WARMUP_FACTOR: 1.
32 | LR_POLICY: 'cos'
33 | MOMENTUM: 0.9
34 | WEIGHT_DECAY: 3.e-5
35 | NESTEROV: True
36 | TEST:
37 | BATCH_SIZE: 256
38 | IM_SIZE: 224
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/expand_2.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | SPACE:
3 | NAME: 'ofa_mbv3'
4 | LOADER:
5 | DATASET: 'imagenet'
6 | NUM_CLASSES: 1000
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 | USE_VAL: True
10 | SEARCH:
11 | MULTI_SIZES: [128,160,192,224]
12 | LOSS_FUN: 'cross_entropy_smooth'
13 | LABEL_SMOOTH: 0.1
14 | AUTO_RESUME: True
15 | OFA:
16 | TASK: 'expand'
17 | PHASE: 2
18 | WIDTH_MULTI_LIST: [1.0]
19 | KS_LIST: [3,5,7]
20 | EXPAND_LIST: [3,4,6]
21 | DEPTH_LIST: [2,3,4]
22 | CHANNEL_DIVISIBLE: 8
23 | SUBNET_BATCH_SIZE: 4
24 | KD_RATIO: 0.
25 | # KD_PATH: "exp/OFA/teacher_model.pyth"
26 | OPTIM:
27 | MAX_EPOCH: 120
28 | BASE_LR: 1.5e-2
29 | MIN_LR: 1.e-3
30 | WARMUP_EPOCH: 5
31 | WARMUP_FACTOR: 1.
32 | LR_POLICY: 'cos'
33 | MOMENTUM: 0.9
34 | WEIGHT_DECAY: 3.e-5
35 | NESTEROV: True
36 | TEST:
37 | BATCH_SIZE: 256
38 | IM_SIZE: 224
39 |
40 |
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/kernel_1.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | SPACE:
3 | NAME: 'ofa_mbv3'
4 | LOADER:
5 | DATASET: 'imagenet'
6 | NUM_CLASSES: 1000
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 | USE_VAL: True
10 | SEARCH:
11 | MULTI_SIZES: [128,160,192,224]
12 | LOSS_FUN: 'cross_entropy_smooth'
13 | LABEL_SMOOTH: 0.1
14 | AUTO_RESUME: True
15 | OFA:
16 | TASK: 'kernel'
17 | PHASE: 1
18 | WIDTH_MULTI_LIST: [1.0]
19 | KS_LIST: [3,5,7]
20 | EXPAND_LIST: [6]
21 | DEPTH_LIST: [4]
22 | CHANNEL_DIVISIBLE: 8
23 | SUBNET_BATCH_SIZE: 1
24 | KD_RATIO: 0.
25 | # KD_PATH: "exp/OFA/teacher_model.pyth"
26 | OPTIM:
27 | MAX_EPOCH: 120
28 | BASE_LR: 6.0e-2
29 | MIN_LR: 1.e-3
30 | WARMUP_EPOCH: 5
31 | WARMUP_FACTOR: 1.
32 | LR_POLICY: 'cos'
33 | MOMENTUM: 0.9
34 | WEIGHT_DECAY: 3.e-5
35 | NESTEROV: True
36 | TEST:
37 | BATCH_SIZE: 256
38 | IM_SIZE: 224
--------------------------------------------------------------------------------
/configs/search/OFA/mbv3/normal_1.yaml:
--------------------------------------------------------------------------------
1 | # --------------
 2 | # adapted from:
3 | # https://github.com/skhu101/GM-NAS/blob/main/once-for-all-GM/train_ofa_net.py
4 | # --------------
5 |
6 | NUM_GPUS: 4
7 | SPACE:
8 | NAME: 'ofa_mbv3'
9 | LOADER:
10 | DATASET: 'imagenet'
11 | NUM_CLASSES: 1000
12 | BATCH_SIZE: 128
13 | NUM_WORKERS: 4
14 | USE_VAL: True
15 | SEARCH:
16 | MULTI_SIZES: [128,160,192,224]
17 | LOSS_FUN: 'cross_entropy_smooth'
18 | LABEL_SMOOTH: 0.1
19 | WEIGHTS: ''
20 | AUTO_RESUME: True
21 | OFA:
22 | TASK: 'normal'
23 | PHASE: 1
24 | WIDTH_MULTI_LIST: [1.0]
25 | KS_LIST: [7]
26 | EXPAND_LIST: [6]
27 | DEPTH_LIST: [4]
28 | CHANNEL_DIVISIBLE: 8
29 | SUBNET_BATCH_SIZE: 1
30 | KD_RATIO: 0.
31 | # KD_PATH: "exp/OFA/teacher_model.pyth"
32 | OPTIM:
33 | MAX_EPOCH: 180
34 | BASE_LR: 0.325
35 | MIN_LR: 1.e-3
36 | WARMUP_EPOCH: 5
37 | WARMUP_FACTOR: 1.
38 | LR_POLICY: 'cos'
39 | MOMENTUM: 0.9
40 | WEIGHT_DECAY: 3.e-5
41 | NESTEROV: True
42 | TEST:
43 | BATCH_SIZE: 256
44 | IM_SIZE: 224
45 |
46 |
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | BATCH_SIZE: 128
7 | OPTIM:
8 | BASE_LR: 0.025
9 | MOMENTUM: 0.9
10 | WEIGHT_DECAY: 0.0003
11 | MAX_EPOCH: 250
12 | TRAIN:
13 | CHANNELS: 16
14 | LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_darts_cifar100.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_darts'
3 | LOADER:
4 | DATASET: 'cifar100'
5 | NUM_CLASSES: 100
6 | BATCH_SIZE: 128
7 | OPTIM:
8 | BASE_LR: 0.025
9 | MOMENTUM: 0.9
10 | WEIGHT_DECAY: 0.0003
11 | MAX_EPOCH: 250
12 | TRAIN:
13 | CHANNELS: 16
14 | LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_nb201'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | BATCH_SIZE: 32
7 | OPTIM:
8 | BASE_LR: 0.1
9 | MOMENTUM: 0.9
10 | WEIGHT_DECAY: 0.0005
11 | MAX_EPOCH: 150
12 | TRAIN:
13 | CHANNELS: 16
14 | LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_nasbench201_cifar100.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_nb201'
3 | LOADER:
4 | DATASET: 'cifar100'
5 | NUM_CLASSES: 100
6 | BATCH_SIZE: 32
7 | OPTIM:
8 | BASE_LR: 0.1
9 | MOMENTUM: 0.9
10 | WEIGHT_DECAY: 0.0005
11 | MAX_EPOCH: 150
12 | TRAIN:
13 | CHANNELS: 16
14 | LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_nasbench201_imagenet16.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_nb201'
3 | LOADER:
4 | DATASET: 'imagenet16'
5 | NUM_CLASSES: 120
6 | BATCH_SIZE: 32
7 | OPTIM:
8 | BASE_LR: 0.1
9 | MOMENTUM: 0.9
10 | WEIGHT_DECAY: 0.0005
11 | MAX_EPOCH: 150
12 | TRAIN:
13 | CHANNELS: 16
14 | LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_nasbenchmacro_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'nasbenchmacro'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | BATCH_SIZE: 64
7 | OPTIM:
8 | # BASE_LR: 0.1
9 | # MOMENTUM: 0.9
10 | # WEIGHT_DECAY: 0.0005
11 | MAX_EPOCH: 30
12 | # TRAIN:
13 | # CHANNELS: 16
14 | # LAYERS: 8
15 | RMINAS:
16 | LOSS_BETA: 0.8
17 | RF_WARMUP: 100
18 | RF_THRESRATE: 0.05
19 | RF_SUCC: 100
20 | SEARCH:
21 | IM_SIZE: 32
22 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/RMINAS/rminas_proxyless_imagenet.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'proxyless'
3 | LOADER:
4 | DATASET: 'imagenet'
 5 | NUM_CLASSES: 1000
6 | NUM_WORKERS: 0
7 | BATCH_SIZE: 128
8 | OPTIM:
9 | BASE_LR: 0.025
10 | MOMENTUM: 0.9
11 | WEIGHT_DECAY: 0.0003
12 | MAX_EPOCH: 500
13 | TRAIN:
14 | CHANNELS: 16
15 | LAYERS: 8
16 | RMINAS:
17 | LOSS_BETA: 0.8
18 | RF_WARMUP: 100
19 | RF_THRESRATE: 0.05
20 | RF_SUCC: 100
21 | OUT_DIR: 'exp/rminas'
--------------------------------------------------------------------------------
/configs/search/SNG/ASNG_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.8, 0.2]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | STEPS: [48, 96]
12 | SNG:
13 | NAME: 'ASNG'
14 | THETA_LR: 0.1
15 | PRUNING: True
16 | PRUNING_STEP: 3
17 | PROB_SAMPLING: False
18 | UTILITY: 'log'
19 | UTILITY_FACTOR: 0.4
20 | LAMBDA: -1
21 | MOMENTUM: True
22 | GAMMA: 0.9
23 | SAMPLING_PER_EDGE: 1
24 | RANDOM_SAMPLE: True
25 | WARMUP_RANDOM_SAMPLE: True
26 | BIGMODEL_SAMPLE_PROB: 0.5
27 | BIGMODEL_NON_PARA: 2
28 | EDGE_SAMPLING: False
29 | EDGE_SAMPLING_EPOCH: -1
30 | OUT_DIR: 'exp/ASNG'
--------------------------------------------------------------------------------
/configs/search/SNG/DDPNAS_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.8, 0.2]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | STEPS: [48, 96]
12 | BASE_LR: 0.025
13 | LR_POLICY: 'cos'
14 | SNG:
15 | NAME: 'DDPNAS'
16 | THETA_LR: 0.01
17 | PRUNING: True
18 | PRUNING_STEP: 3
19 | PROB_SAMPLING: False
20 | UTILITY: 'log'
21 | UTILITY_FACTOR: 0.4
22 | LAMBDA: -1
23 | MOMENTUM: True
24 | GAMMA: 0.9
25 | SAMPLING_PER_EDGE: 1
26 | RANDOM_SAMPLE: True
27 | WARMUP_RANDOM_SAMPLE: True
28 | BIGMODEL_SAMPLE_PROB: 0.5
29 | BIGMODEL_NON_PARA: 2
30 | EDGE_SAMPLING: False
31 | EDGE_SAMPLING_EPOCH: -1
32 | OUT_DIR: 'exp/DDPNAS'
33 |
--------------------------------------------------------------------------------
/configs/search/SNG/DDPNAS_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'nasbench201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.8, 0.2]
9 | BATCH_SIZE: 256
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | EVALUATION: 'nasbench201'
14 | OPTIM:
15 | STEPS: [48, 96]
16 | BASE_LR: 0.025
17 | LR_POLICY: 'cos'
18 | SNG:
19 | NAME: 'DDPNAS'
20 | THETA_LR: 0.01
21 | PRUNING: True
22 | PRUNING_STEP: 3
23 | PROB_SAMPLING: False
24 | UTILITY: 'log'
25 | UTILITY_FACTOR: 0.4
26 | LAMBDA: -1
27 | MOMENTUM: True
28 | GAMMA: 0.9
29 | SAMPLING_PER_EDGE: 1
30 | RANDOM_SAMPLE: True
31 | WARMUP_RANDOM_SAMPLE: True
32 | BIGMODEL_SAMPLE_PROB: 0.
33 | BIGMODEL_NON_PARA: 2
34 | EDGE_SAMPLING: False
35 | EDGE_SAMPLING_EPOCH: -1
36 | OUT_DIR: 'exp/DDPNAS'
37 |
--------------------------------------------------------------------------------
/configs/search/SNG/DynamicASNG_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.8, 0.2]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | STEPS: [48, 96]
12 | SNG:
13 | NAME: 'dynamic_ASNG'
14 | THETA_LR: 0.1
15 | PRUNING: True
16 | PRUNING_STEP: 3
17 | PROB_SAMPLING: False
18 | UTILITY: 'log'
19 | UTILITY_FACTOR: 0.4
20 | LAMBDA: -1
21 | MOMENTUM: True
22 | GAMMA: 0.9
23 | SAMPLING_PER_EDGE: 1
24 | RANDOM_SAMPLE: True
25 | WARMUP_RANDOM_SAMPLE: True
26 | BIGMODEL_SAMPLE_PROB: 0.5
27 | BIGMODEL_NON_PARA: 2
28 | EDGE_SAMPLING: False
29 | EDGE_SAMPLING_EPOCH: -1
30 | OUT_DIR: 'exp/Dynamic_ASNG'
--------------------------------------------------------------------------------
/configs/search/SNG/DynamicSNG_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.8, 0.2]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | STEPS: [48, 96]
12 | SNG:
13 | NAME: 'dynamic_SNG'
14 | THETA_LR: 0.1
15 | PRUNING: True
16 | PRUNING_STEP: 3
17 | PROB_SAMPLING: False
18 | UTILITY: 'log'
19 | UTILITY_FACTOR: 0.4
20 | LAMBDA: -1
21 | MOMENTUM: True
22 | GAMMA: 0.9
23 | SAMPLING_PER_EDGE: 1
24 | RANDOM_SAMPLE: True
25 | WARMUP_RANDOM_SAMPLE: True
26 | BIGMODEL_SAMPLE_PROB: 0.5
27 | BIGMODEL_NON_PARA: 2
28 | EDGE_SAMPLING: False
29 | EDGE_SAMPLING_EPOCH: -1
30 | OUT_DIR: 'exp/Dynamic_SNG'
--------------------------------------------------------------------------------
/configs/search/SNG/MIGO_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.5, 0.5]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | BASE_LR: 0.1
12 | MOMENTUM: 0.9
13 | STEPS: [60, 120]
14 | LR_POLICY: 'step'
15 | WARMUP_EPOCH: 0
16 | MAX_EPOCH: 200
17 | FINAL_EPOCH: 50
18 | SNG:
19 | NAME: 'MIGO'
20 | THETA_LR: 0.1
21 | PRUNING: False
22 | PRUNING_STEP: 2
23 | PROB_SAMPLING: True
24 | UTILITY: 'log'
25 | UTILITY_FACTOR: 0.4
26 | LAMBDA: -1
27 | MOMENTUM: True
28 | GAMMA: 0.5
29 | SAMPLING_PER_EDGE: 1
30 | RANDOM_SAMPLE: True
31 | WARMUP_RANDOM_SAMPLE: True
32 | BIGMODEL_SAMPLE_PROB: 0.5
33 | BIGMODEL_NON_PARA: 2
34 | EDGE_SAMPLING: True
35 | EDGE_SAMPLING_EPOCH: -1
36 | OUT_DIR: 'exp/MIGO'
--------------------------------------------------------------------------------
/configs/search/SNG/SNG_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | SPLIT: [0.8, 0.2]
6 | BATCH_SIZE: 256
7 | NUM_CLASSES: 10
8 | SEARCH:
9 | IM_SIZE: 32
10 | OPTIM:
11 | STEPS: [48, 96]
12 | SNG:
13 | NAME: 'SNG'
14 | THETA_LR: 0.1
15 | PRUNING: True
16 | PRUNING_STEP: 3
17 | PROB_SAMPLING: False
18 | UTILITY: 'log'
19 | UTILITY_FACTOR: 0.4
20 | LAMBDA: -1
21 | MOMENTUM: True
22 | GAMMA: 0.9
23 | SAMPLING_PER_EDGE: 1
24 | RANDOM_SAMPLE: True
25 | WARMUP_RANDOM_SAMPLE: True
26 | BIGMODEL_SAMPLE_PROB: 0.5
27 | BIGMODEL_NON_PARA: 2
28 | EDGE_SAMPLING: False
29 | EDGE_SAMPLING_EPOCH: -1
30 | OUT_DIR: 'exp/SNG'
--------------------------------------------------------------------------------
/configs/search/darts_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'darts'
3 | CHANNELS: 16
4 | LAYERS: 8
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | MAX_EPOCH: 50
15 | LR_POLICY: 'cos'
16 | DARTS:
17 | UNROLLED: True
18 | ALPHA_LR: 3.e-4
19 | ALPHA_WEIGHT_DECAY: 1.e-3
20 | OUT_DIR: 'exp/darts'
--------------------------------------------------------------------------------
/configs/search/darts_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'nasbench201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 32
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | MAX_EPOCH: 50
15 | LR_POLICY: 'cos'
16 | DARTS:
17 | UNROLLED: True
18 | ALPHA_LR: 3.e-4
19 | ALPHA_WEIGHT_DECAY: 1.e-3
20 | OUT_DIR: 'exp/darts'
--------------------------------------------------------------------------------
/configs/search/dropnas_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'dropnas'
3 | CHANNELS: 16
4 | LAYERS: 8
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | NUM_CLASSES: 10
9 | SPLIT: [0.5, 0.5]
10 | BATCH_SIZE: 64
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.0375
15 | MIN_LR: 0.0015
16 | MOMENTUM: 0.9
17 | WEIGHT_DECAY: 3.e-4
18 | MAX_EPOCH: 50
19 | LR_POLICY: 'cos'
20 | WARMUP_EPOCH: 0
21 | DARTS:
22 | ALPHA_WEIGHT_DECAY: 0
23 | ALPHA_LR: 3.e-4
24 | DROPNAS:
25 | DROP_RATE: 3.e-5
26 | OUT_DIR: 'exp/dropnas'
--------------------------------------------------------------------------------
/configs/search/gdas_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'gdas_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.025
15 | MIN_LR: 0.001
16 | WEIGHT_DECAY: 3.e-4
17 | LR_POLICY: 'cos'
18 | MAX_EPOCH: 100
19 | DARTS:
20 | UNROLLED: True
21 | ALPHA_LR: 3.e-4
22 | ALPHA_WEIGHT_DECAY: 1.e-3
23 | DRNAS:
24 | K: 1
25 | REG_TYPE: "l2"
26 | REG_SCALE: 1.e-3
27 | METHOD: 'gdas'
28 | TAU: [1, 10]
29 | PROGRESSIVE: False
30 | OUT_DIR: 'exp/gdas'
--------------------------------------------------------------------------------
/configs/search/nasbenchmacro_nasbenchmacro_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'nasbenchmacro'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | SPLIT: [0.5, 0.5]
7 | BATCH_SIZE: 64
8 | SEARCH:
9 | IM_SIZE: 32
10 | OUT_DIR: 'exp/nbm'
--------------------------------------------------------------------------------
/configs/search/pcdarts_pcdarts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'pcdarts'
3 | CHANNELS: 16
4 | LAYERS: 8
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 256
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | MAX_EPOCH: 50
15 | LR_POLICY: 'cos'
16 | WEIGHT_DECAY: 3e-4
17 | MIN_LR: 0.0
18 | DARTS:
19 | UNROLLED: True
20 | ALPHA_LR: 6.e-4
21 | ALPHA_WEIGHT_DECAY: 1.e-3
22 | OUT_DIR: 'exp/pcdarts'
--------------------------------------------------------------------------------
/configs/search/pdarts_pdarts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'pdarts'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | BASIC_OP: [
7 | 'none',
8 | 'max_pool_3x3',
9 | 'avg_pool_3x3',
10 | 'skip_connect',
11 | 'sep_conv_3x3',
12 | 'sep_conv_5x5',
13 | 'dil_conv_3x3',
14 | 'dil_conv_5x5'
15 | ]
16 | LOADER:
17 | DATASET: 'cifar10'
18 | SPLIT: [0.5, 0.5]
19 | BATCH_SIZE: 64
20 | NUM_CLASSES: 10
21 | SEARCH:
22 | IM_SIZE: 32
23 | OPTIM:
24 | MAX_EPOCH: 25
25 | MIN_LR: 0.0
26 | BASE_LR: 0.025
27 | WEIGHT_DECAY: 3.e-4
28 | DARTS:
29 | UNROLLED: False
30 | ALPHA_LR: 6.e-4
31 | ALPHA_WEIGHT_DECAY: 1.e-3
32 | PDARTS:
33 | ADD_LAYERS: [0, 6, 12]
34 | ADD_WIDTH: [0, 0, 0]
35 | DROPOUT_RATE: [0.1, 0.4, 0.7]
36 | NUM_TO_KEEP: [5, 3, 1]
37 | EPS_NO_ARCHS: [10, 10, 10]
38 | SCALE_FACTOR: 0.2
39 | OUT_DIR: 'exp/pdarts'
--------------------------------------------------------------------------------
/configs/search/snas_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'drnas_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | SPLIT: [0.5, 0.5]
9 | BATCH_SIZE: 64
10 | NUM_CLASSES: 10
11 | SEARCH:
12 | IM_SIZE: 32
13 | OPTIM:
14 | BASE_LR: 0.025
15 | MIN_LR: 0.001
16 | WEIGHT_DECAY: 3.e-4
17 | LR_POLICY: 'cos'
18 | MAX_EPOCH: 100
19 | DARTS:
20 | UNROLLED: True
21 | ALPHA_LR: 3.e-4
22 | ALPHA_WEIGHT_DECAY: 1.e-3
23 | DRNAS:
24 | K: 1
25 | REG_TYPE: "l2"
26 | REG_SCALE: 1.e-3
27 | METHOD: 'snas'
28 | TAU: [1, 10]
29 | PROGRESSIVE: False
30 | OUT_DIR: 'exp/snas'
--------------------------------------------------------------------------------
/configs/search/spos_nasbench201_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'spos_nb201'
3 | CHANNELS: 16
4 | LAYERS: 5
5 | NODES: 4
6 | LOADER:
7 | DATASET: 'cifar10'
8 | NUM_CLASSES: 10
9 | # SPLIT: [0.5, 0.5]
10 | BATCH_SIZE: 128
11 | SEARCH:
12 | IM_SIZE: 32
13 | AUTO_RESUME: False
14 | SPOS:
15 | RESIZE: True
16 | LAYERS: 90 # 6 * 15
17 | NUM_CHOICE: 5
18 | OPTIM:
19 | BASE_LR: 0.025
20 | MIN_LR: 0.001
21 | WEIGHT_DECAY: 1.e-4
22 | LR_POLICY: 'cos'
23 | MAX_EPOCH: 50
24 | OUT_DIR: 'exp/spos_201'
--------------------------------------------------------------------------------
/configs/search/spos_spos_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'spos'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | # SPLIT: [0.5, 0.5]
7 | BATCH_SIZE: 96
8 | SEARCH:
9 | IM_SIZE: 32
10 | # LABEL_SMOOTH: 0.1
11 | SPOS:
12 | RESIZE: True
13 | LAYERS: 20
14 | NUM_CHOICE: 4
15 | OPTIM:
16 | MAX_EPOCH: 600
17 | BASE_LR: 0.025
18 | WEIGHT_DECAY: 3.e-4
19 | OUT_DIR: 'exp/spos'
--------------------------------------------------------------------------------
/configs/train/darts_darts_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_darts'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | BATCH_SIZE: 96
6 | NUM_CLASSES: 10
7 | TEST:
 8 | BATCH_SIZE: 96 # keep the same as LOADER.BATCH_SIZE
9 | OPTIM:
10 | BASE_LR: 0.025
11 | WEIGHT_DECAY: 3.e-4
12 | MAX_EPOCH: 600
13 | TRAIN:
14 | CHANNELS: 36
15 | LAYERS: 20
16 | AUX_WEIGHT: 0.4
17 | DROP_PATH_PROB: 0.3
18 | GENOTYPE: "Genotype(normal=[('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_5x5', 1), ('sep_conv_5x5', 3), ('sep_conv_3x3', 1), ('dil_conv_5x5', 3), ('max_pool_3x3', 4)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1), ('skip_connect', 0), ('skip_connect', 1), ('sep_conv_3x3', 3), ('skip_connect', 2), ('dil_conv_3x3', 3), ('sep_conv_5x5', 0)], reduce_concat=range(2, 6))"
--------------------------------------------------------------------------------
/configs/train/spos_spos_cifar10.yaml:
--------------------------------------------------------------------------------
1 | SPACE:
2 | NAME: 'infer_spos'
3 | LOADER:
4 | DATASET: 'cifar10'
5 | NUM_CLASSES: 10
6 | SPLIT: [0.5, 0.5]
7 | BATCH_SIZE: 64
8 | SEARCH:
9 | IM_SIZE: 32
10 | SPOS:
11 | RESIZE: True
12 | LAYERS: 20
13 | NUM_CHOICE: 4
14 | CHOICE: [1, 0, 3, 1, 3, 0, 3, 0, 0, 3, 3, 0, 1, 0, 1, 2, 2, 1, 1, 3]
15 | OUT_DIR: 'exp/spos'
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'XNAS'
21 | copyright = '2022, PCL_AutoML'
22 | author = 'PCL_AutoML'
23 |
24 | # The full version, including alpha/beta/rc tags
25 | release = 'v0.0.1'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | 'recommonmark',
35 | 'sphinx_markdown_tables'
36 | ]
37 |
38 | # Add any paths that contain templates here, relative to this directory.
39 | templates_path = ['_templates']
40 |
41 | # The language for content autogenerated by Sphinx. Refer to documentation
42 | # for a list of supported languages.
43 | #
44 | # This is also used if you do content translation via gettext catalogs.
45 | # Usually you set "language" from the command line for these cases.
46 | language = 'en'
47 | # language = 'zh_CN'
48 |
49 | # List of patterns, relative to source directory, that match files and
50 | # directories to ignore when looking for source files.
51 | # This pattern also affects html_static_path and html_extra_path.
52 | exclude_patterns = []
53 |
54 |
55 | # -- Options for HTML output -------------------------------------------------
56 |
57 | # The theme to use for HTML and HTML Help pages. See the documentation for
58 | # a list of builtin themes.
59 | #
60 |
61 | # html_theme = 'alabaster'
62 | # html_theme = 'sphinx_rtd_theme'
63 | html_theme = 'furo'
64 |
65 |
66 |
67 | # Add any paths that contain custom static files (such as style sheets) here,
68 | # relative to this directory. They are copied after the builtin static files,
69 | # so a file named "default.css" will overwrite the builtin "default.css".
70 | html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/data_preparation.md:
--------------------------------------------------------------------------------
1 | # Common Settings
2 |
 3 | It is highly recommended to save or link datasets to the `$XNAS/data` folder, so that no additional configuration is required.
4 |
 5 | However, you can also set the dataset path manually by modifying the `cfg.LOADER.DATAPATH` attribute defined in `$XNAS/xnas/core/config.py` (see the example at the end of this page).
6 |
 7 | Additionally, files required by benchmarks are also expected in the `$XNAS/data` folder. You can modify the related attributes under `cfg.BENCHMARK` to match your actual file locations.
8 |
9 |
10 | # Dataset Preparation
11 |
12 | The dataloaders of XNAS read dataset files from `$XNAS/data/$DATASET_NAME` by default, where dataset names are lowercase with hyphens removed. For example, files for CIFAR-10 should be placed (or automatically downloaded) under the `$XNAS/data/cifar/` directory.
13 |
14 | XNAS currently supports the following datasets.
15 |
16 | - CIFAR-10
17 | - CIFAR-100
18 | - ImageNet
19 | - ImageNet16 (Downsampled)
20 | - SVHN
21 | - MNIST
22 | - FashionMNIST
23 |
24 | # Benchmark Preparation
25 |
26 | Some search spaces and algorithms supported by XNAS require specific APIs provided by NAS benchmarks. These benchmarks must be installed and set up properly before running the corresponding code.
27 | 
28 | The benchmarks supported by XNAS, with links, are as follows.
29 | - nasbench101: [GitHub](https://github.com/google-research/nasbench)
30 | - nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1)
31 | - nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201)
32 | - nasbench301: [GitHub](https://github.com/automl/nasbench301)
33 |
34 |
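35 | # Path Setting Example
36 | 
37 | A minimal sketch of the settings above, assuming an existing CIFAR-10 copy at the hypothetical path `/data/cifar`. The command-line override in the second option uses the same mechanism described in [Get Started](./get_started.md).
38 | 
39 | ```sh
40 | # Option 1: link an existing dataset copy into the default location
41 | mkdir -p $XNAS/data
42 | ln -s /data/cifar $XNAS/data/cifar
43 | 
44 | # Option 2: override the dataset path for a single run instead
45 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml LOADER.DATAPATH /data/cifar
46 | ```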
--------------------------------------------------------------------------------
/docs/get_started.md:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 |
 3 | XNAS cannot currently be installed via `pip`. To run XNAS, `python>=3.7` and `pytorch==1.9` are required. Other versions of `PyTorch` may also work, but API differences can cause warnings.
4 |
 5 | Other requirements are listed in the `requirements.txt` file.
6 |
7 | # Installation
8 |
9 | 1. Clone this repo.
10 | 2. (Optional) Create a virtualenv for this library.
11 | ```sh
12 | virtualenv venv
13 | ```
14 | 3. Install dependencies.
15 | ```sh
16 | pip install -r requirements.txt
17 | ```
18 | 4. Set the `$PYTHONPATH` environment variable.
19 | ```sh
20 | export PYTHONPATH=$PYTHONPATH:/Path/to/XNAS
21 | ```
22 | 5. Set the visible GPU device for XNAS. Currently XNAS supports only a single GPU; multi-GPU support will be provided soon.
23 | ```sh
24 | export CUDA_VISIBLE_DEVICES=0
25 | ```
26 |
27 | Note that environment variables set this way are **valid only for the current terminal session**. For convenience, we recommend adding these commands to your shell profile (e.g. `~/.bashrc` for `bash`) so that the environment is configured automatically at login:
28 |
29 | ```sh
30 | echo "export PYTHONPATH=$PYTHONPATH:/Path/to/XNAS" >> ~/.bashrc
31 | ```
32 |
33 | Some search spaces and algorithms supported by XNAS require specific APIs provided by NAS benchmarks. These benchmarks must be installed and set up properly before running the corresponding code.
34 | 
35 | The benchmarks supported by XNAS, with links, are as follows.
36 | - nasbench101: [GitHub](https://github.com/google-research/nasbench)
37 | - nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1)
38 | - nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201)
39 | - nasbench301: [GitHub](https://github.com/automl/nasbench301)
40 |
41 | For detailed instructions to install these benchmarks, please refer to [**Data Preparation**](./data_preparation.md).
42 |
43 |
44 | # Usage
45 |
46 | Before running code in XNAS, please make sure you have followed instructions in [**Data Preparation**](./data_preparation.md) in our docs to complete preparing the necessary data.
47 |
48 | The main program entries for the search and training processes are in the `$XNAS/scripts` folder. To modify or add NAS code, please place files in this folder.
49 |
50 | ## Configuration Files
51 |
52 | XNAS uses the `.yaml` format for its configuration files. All configuration files are placed under the `$XNAS/configs` directory. To keep files uniform and clear, we strongly recommend the following naming convention:
53 |
54 | ```sh
55 | Algorithm_Space_Dataset[_Evaluation][_ModifiedParams].yaml
56 | ```
57 |
58 | For example, when searching with the `DARTS` algorithm on the `NASBench201` space and the `CIFAR-10` dataset, evaluating with `NASBench301`, and modifying the `MAX_EPOCH` parameter to `75`, the file should be named as follows:
59 |
60 | ```sh
61 | darts_nasbench201_cifar10_nasbench301_maxepoch75.yaml
62 | ```
63 |
64 | ## Running Examples
65 |
66 | XNAS reads configuration files from the command line. A simple running example:
67 |
68 | ```sh
69 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml
70 | ```
71 |
72 | Values in the configuration file can be overridden by appending additional parameters on the command line. For example, to run with a modified output directory:
73 |
74 | ```sh
75 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/another_folder
76 | ```
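Multiple key-value pairs can be appended in the same way. As a sketch, the following combines an output-directory override with the `OPTIM.MAX_EPOCH` key used by the provided configs:

```sh
python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/another_folder OPTIM.MAX_EPOCH 75
```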
77 |
78 | Saving commands in `.sh` files is convenient when you need to run and modify parameters repeatedly. We provide shell scripts under the `$XNAS/examples` folder, together with other test code that may be added in the future. A script can be run with a single command:
79 |
80 | ```sh
81 | ./examples/darts_darts_cifar10.sh
82 | ```
83 |
84 | A common mistake is forgetting to add execute permissions to these files:
85 |
86 | ```sh
87 | chmod +x examples/*/*.sh examples/*/*/*.sh
88 | ```
89 |
90 | The script files follow the same naming convention as the configuration files above, and set the output directory to a folder of the same name. You can chain search and training runs by adding multiple commands to one script file, as shown in the sketch below.
91 |
92 |
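For instance, a single script could run the `DARTS` search and then the retraining step, reusing the two example commands shipped under `$XNAS/examples`:

```sh
python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/search/darts_darts_cifar10
python scripts/train/DARTS.py --cfg configs/train/darts_darts_cifar10.yaml OUT_DIR exp/train/darts_darts_cifar10
```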
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. XNAS documentation master file, created by
2 | sphinx-quickstart on Sat May 21 14:09:17 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 |
7 | Welcome to XNAS' documentation!
8 | ====================================
9 |
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: Get Started
14 |
15 | get_started.md
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Data Preparation
20 |
21 | data_preparation.md
22 |
23 |
24 | .. toctree::
25 | :maxdepth: 2
26 | :caption: Notes
27 |
28 | notes.md
29 |
30 |
31 | Indices and tables
32 | ==================
33 |
34 | * :ref:`genindex`
35 | * :ref:`search`
--------------------------------------------------------------------------------
/docs/notes.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | We welcome contributions to the library along with any potential issues or suggestions.
4 |
5 | ## Pull Requests
6 |
7 | We actively welcome your pull requests.
8 |
9 | 1. Fork the repo and create your branch from `dev` branch.
10 | 2. If you've added code that should be tested, please add tests in the `tests` folder.
11 | 3. If you've changed APIs, please update the documentation of the relevant functions.
12 | 4. Create a pull request.
13 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | furo
3 | recommonmark
4 | sphinx_markdown_tables
5 |
--------------------------------------------------------------------------------
/examples/search/DrNAS/drnas_darts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_darts_cifar10.yaml OUT_DIR exp/search/drnas_darts_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/DrNAS/drnas_darts_imagenet.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_darts_imagenet.yaml OUT_DIR exp/search/drnas_darts_imagenet
2 |
--------------------------------------------------------------------------------
/examples/search/DrNAS/drnas_nasbench201_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar10.yaml OUT_DIR exp/search/drnas_nasbench201_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/DrNAS/drnas_nasbench201_cifar100.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar100.yaml OUT_DIR exp/search/drnas_nasbench201_cifar100
2 |
--------------------------------------------------------------------------------
/examples/search/DrNAS/drnas_nasbench201_cifar10_progressive.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/DrNAS/drnas_nasbench201_cifar10_progressive.yaml OUT_DIR exp/search/drnas_nasbench201_cifar10_progressive
2 |
--------------------------------------------------------------------------------
/examples/search/OFA/train_supernet.sh:
--------------------------------------------------------------------------------
1 | OUT_NAME="OFA_trial_25"
2 | TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2"
3 |
4 | for loop in $TASKS
5 | do
6 |     # torchrun --nproc_per_node 2 scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3/$loop.yaml OUT_DIR exp/search/$OUT_NAME/$loop OPTIM.MAX_EPOCH 2 OPTIM.WARMUP_EPOCH 2 LOADER.BATCH_SIZE 128
7 |     echo "using gpus: $@"
8 |     python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3/$loop.yaml OUT_DIR exp/search/$OUT_NAME/$loop NUM_GPUS $@
9 |     sleep 5s
10 | done
11 |
12 | # # full supernet
13 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_normal_phase1.yaml OUT_DIR exp/$OUT_NAME/normal_1
14 | # # elastic kernel size
15 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_kernel_phase1.yaml OUT_DIR exp/$OUT_NAME/kernel_1
16 | # # elastic depth
17 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_depth_phase1.yaml OUT_DIR exp/$OUT_NAME/depth_1
18 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_depth_phase2.yaml OUT_DIR exp/$OUT_NAME/depth_2
19 | # # elastic width
20 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_expand_phase1.yaml OUT_DIR exp/$OUT_NAME/expand_1
21 | # python scripts/search/OFA/train_supernet.py --cfg configs/OFA/mbv3/OFA_expand_phase2.yaml OUT_DIR exp/$OUT_NAME/expand_2
22 |
--------------------------------------------------------------------------------
/examples/search/RMINAS/rminas_darts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_darts_cifar10.yaml OUT_DIR exp/search/rminas_darts_cifar10
--------------------------------------------------------------------------------
/examples/search/RMINAS/rminas_nasbench201_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_cifar10.yaml OUT_DIR exp/search/rminas_nasbench201_cifar10
--------------------------------------------------------------------------------
/examples/search/RMINAS/rminas_nasbench201_cifar100.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_cifar100.yaml OUT_DIR exp/search/rminas_nasbench201_cifar100
--------------------------------------------------------------------------------
/examples/search/RMINAS/rminas_nasbench201_imagenet16.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/RMINAS.py --cfg configs/search/RMINAS/rminas_nasbench201_imagenet16.yaml OUT_DIR exp/search/rminas_nasbench201_imagenet16
--------------------------------------------------------------------------------
/examples/search/darts_darts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DARTS.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/search/darts_darts_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/darts_nasbench201_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DARTS.py --cfg configs/search/darts_nasbench201_cifar10.yaml OUT_DIR exp/search/darts_nasbench201_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/dropnas_darts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DropNAS.py --cfg configs/search/dropnas_darts_cifar10.yaml OUT_DIR exp/search/dropnas_darts_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/gdas_nasbench201_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/gdas_nasbench201_cifar10.yaml OUT_DIR exp/search/gdas_nasbench201_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/ofa_mbv3_imagenet.sh:
--------------------------------------------------------------------------------
1 | # 1. full supernet training
2 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_normal_phase1.yaml
3 | # 2. elastic kernel size
4 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_kernel_phase1.yaml
5 | # 3. elastic depth
6 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_depth_phase1.yaml
7 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_depth_phase2.yaml
8 | # 4. elastic width
9 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_expand_phase1.yaml
10 | python scripts/search/OFA/train_supernet.py --cfg configs/search/OFA/mbv3_cifar10/OFA_expand_phase2.yaml
11 |
--------------------------------------------------------------------------------
/examples/search/pcdarts_pcdarts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/PCDARTS.py --cfg configs/search/pcdarts_pcdarts_cifar10.yaml OUT_DIR exp/search/pcdarts_pcdarts_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/pdarts_pdarts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/PDARTS.py --cfg configs/search/pdarts_pdarts_cifar10.yaml OUT_DIR exp/search/pdarts_pdarts_cifar10
2 |
--------------------------------------------------------------------------------
/examples/search/snas_nasbench201_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/search/DrNAS.py --cfg configs/search/snas_nasbench201_cifar10.yaml OUT_DIR exp/search/snas_nasbench201_cifar10
2 |
--------------------------------------------------------------------------------
/examples/train/darts_darts_cifar10.sh:
--------------------------------------------------------------------------------
1 | python scripts/train/DARTS.py --cfg configs/train/darts_darts_cifar10.yaml OUT_DIR exp/train/darts_darts_cifar10
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | configspace
2 | numpy
3 | # python >= 3.7 is an interpreter requirement, not a pip package
4 | torch == 1.9  # PyPI package name for PyTorch
5 | pillow
6 | simplejson
7 | scikit-learn
8 | tensorboard
9 | yacs
--------------------------------------------------------------------------------
/scripts/search/BigNAS/eval.py:
--------------------------------------------------------------------------------
1 | """BigNAS subnet searching: Coarse-to-fine Architecture Selection"""
2 |
3 | import numpy as np
4 | from itertools import product
5 |
6 | import torch
7 |
8 | import xnas.core.config as config
9 | import xnas.logger.meter as meter
10 | import xnas.logger.logging as logging
11 | from xnas.core.builder import *
12 | from xnas.core.config import cfg
13 | from xnas.datasets.loader import get_normal_dataloader
14 | from xnas.logger.meter import TestMeter
15 |
16 |
17 | # Load config and check
18 | config.load_configs()
19 | logger = logging.get_logger(__name__)
20 |
21 |
22 | def main():
23 | setup_env()
24 | net = space_builder().cuda()
25 |
26 | [train_loader, valid_loader] = get_normal_dataloader()
27 |
28 | test_meter = TestMeter(len(valid_loader))
29 | # Validate
30 | top1_err, top5_err = validate(net, train_loader, valid_loader, test_meter)
31 | flops = net.compute_active_subnet_flops()
32 |
33 | logger.info("flops:{} top1_err:{} top5_err:{}".format(
34 | flops, top1_err, top5_err
35 | ))
36 |
37 |
38 | @torch.no_grad()
39 | def validate(subnet, train_loader, valid_loader, test_meter):
40 | # BN calibration
41 | subnet.eval()
42 | logger.info("Calibrating BN running statistics.")
43 | subnet.reset_running_stats_for_calibration()
44 | for cur_iter, (inputs, _) in enumerate(train_loader):
45 | if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM:
46 | break
47 | inputs = inputs.cuda()
48 | subnet(inputs) # forward only
49 |
50 | top1_err, top5_err = test_epoch(subnet, valid_loader, test_meter)
51 | return top1_err, top5_err
52 |
53 |
54 | def test_epoch(subnet, test_loader, test_meter):
55 | subnet.eval()
56 | test_meter.reset(True)
57 | test_meter.iter_tic()
58 | for cur_iter, (inputs, labels) in enumerate(test_loader):
59 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
60 | preds = subnet(inputs)
61 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
62 | top1_err, top5_err = top1_err.item(), top5_err.item()
63 |
64 | test_meter.iter_toc()
65 | test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
66 | test_meter.log_iter_stats(0, cur_iter)
67 | test_meter.iter_tic()
68 | top1_err = test_meter.mb_top1_err.get_win_avg()
69 | top5_err = test_meter.mb_top5_err.get_win_avg()
70 | # self.writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_avg(), cur_epoch)
71 | # self.writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_avg(), cur_epoch)
72 | # Log epoch stats
73 | test_meter.log_epoch_stats(0)
74 | # test_meter.reset()
75 | return top1_err, top5_err
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/scripts/search/DARTS.py:
--------------------------------------------------------------------------------
1 | """DARTS searching"""
2 |
3 | import xnas.core.config as config
4 | import xnas.logger.logging as logging
5 | from xnas.core.config import cfg
6 | from xnas.core.builder import *
7 |
8 | # DARTS
9 | from xnas.algorithms.DARTS import *
10 | from xnas.runner.trainer import DartsTrainer
11 | from xnas.runner.optimizer import darts_alpha_optimizer
12 |
13 |
14 | # Load config and check
15 | config.load_configs()
16 | logger = logging.get_logger(__name__)
17 |
18 | def main():
19 | device = setup_env()
20 | search_space = space_builder()
21 | criterion = criterion_builder().to(device)
22 | evaluator = evaluator_builder()
23 |
24 | [train_loader, valid_loader] = construct_loader()
25 |
26 | # init models
27 | darts_controller = DartsCNNController(search_space, criterion).to(device)
28 | architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
29 |
30 | # init optimizers
31 | w_optim = optimizer_builder("SGD", darts_controller.weights())
32 | a_optim = darts_alpha_optimizer("Adam", darts_controller.alphas())
33 | lr_scheduler = lr_scheduler_builder(w_optim)
34 |
35 | # init recorders
36 | darts_trainer = DartsTrainer(
37 | darts_controller=darts_controller,
38 | architect=architect,
39 | criterion=criterion,
40 | lr_scheduler=lr_scheduler,
41 | w_optim=w_optim,
42 | a_optim=a_optim,
43 | train_loader=train_loader,
44 | valid_loader=valid_loader,
45 | )
46 |
47 | # load checkpoint or initial weights
48 | start_epoch = darts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0
49 |
50 | # start training
51 | darts_trainer.start()
52 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
53 | # train epoch
54 | darts_trainer.train_epoch(cur_epoch)
55 | # test epoch
56 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
57 | darts_trainer.test_epoch(cur_epoch)
58 | # recording genotype and alpha to logger
59 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch))
60 | logger.info(darts_trainer.model.genotype())
61 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch))
62 | darts_trainer.model.print_alphas(logger)
63 | # evaluate model
64 | if evaluator:
65 | evaluator(darts_trainer.model.genotype())
66 | darts_trainer.finish()
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/scripts/search/DropNAS.py:
--------------------------------------------------------------------------------
1 | """DropNAS searching"""
2 |
3 |
4 | import torch.nn as nn
5 |
6 | import xnas.core.config as config
7 | import xnas.logger.logging as logging
8 | import xnas.logger.meter as meter
9 | from xnas.core.config import cfg
10 | from xnas.core.builder import *
11 |
12 | # DropNAS
13 | from xnas.algorithms.DropNAS import *
14 | from xnas.runner.trainer import DartsTrainer
15 | from xnas.runner.optimizer import darts_alpha_optimizer
16 |
17 |
18 | # Load config and check
19 | config.load_configs()
20 | logger = logging.get_logger(__name__)
21 |
22 | def main():
23 | device = setup_env()
24 | search_space = space_builder()
25 | criterion = criterion_builder().to(device)
26 | evaluator = evaluator_builder()
27 |
28 | [train_loader, valid_loader] = construct_loader()
29 |
30 | # init models
31 | darts_controller = DropNAS_CNNController(search_space, criterion).to(device)
32 |
33 | # init optimizers
34 | w_optim = optimizer_builder("SGD", darts_controller.weights())
35 | a_optim = darts_alpha_optimizer("Adam", darts_controller.alphas())
36 | lr_scheduler = lr_scheduler_builder(w_optim)
37 |
38 | # init recorders
39 | dropnas_trainer = DropNAS_Trainer(
40 | darts_controller=darts_controller,
41 | architect=None,
42 | criterion=criterion,
43 | lr_scheduler=lr_scheduler,
44 | w_optim=w_optim,
45 | a_optim=a_optim,
46 | train_loader=train_loader,
47 | valid_loader=valid_loader,
48 | )
49 |
50 | # load checkpoint or initial weights
51 | start_epoch = dropnas_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0
52 |
53 | # start training
54 | dropnas_trainer.start()
55 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
56 | # train epoch
57 | drop_rate = 0. if cur_epoch < cfg.OPTIM.WARMUP_EPOCH else cfg.DROPNAS.DROP_RATE
58 | logger.info("Current drop rate: {:.6f}".format(drop_rate))
59 | dropnas_trainer.train_epoch(cur_epoch, drop_rate)
60 | # test epoch
61 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
62 | # NOTE: the source code of DropNAS does not use test codes.
63 | # recording genotype and alpha to logger
64 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch))
65 | logger.info(dropnas_trainer.model.genotype())
66 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch))
67 | dropnas_trainer.model.print_alphas(logger)
68 | if evaluator:
69 | evaluator(dropnas_trainer.model.genotype())
70 | dropnas_trainer.finish()
71 |
72 |
73 | class DropNAS_Trainer(DartsTrainer):
74 | """Trainer for DropNAS.
75 | Rewrite the train_epoch with DropNAS's double-losses policy.
76 | """
77 | def __init__(self, darts_controller, architect, criterion, w_optim, a_optim, lr_scheduler, train_loader, valid_loader):
78 | super().__init__(darts_controller, architect, criterion, w_optim, a_optim, lr_scheduler, train_loader, valid_loader)
79 |
80 | def train_epoch(self, cur_epoch, drop_rate):
81 | self.model.train()
82 | lr = self.lr_scheduler.get_last_lr()[0]
83 | cur_step = cur_epoch * len(self.train_loader)
84 | self.writer.add_scalar('train/lr', lr, cur_step)
85 | self.train_meter.iter_tic()
86 | for cur_iter, (trn_X, trn_y) in enumerate(self.train_loader):
87 | trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device, non_blocking=True)
88 |
89 | # forward pass loss
90 | self.a_optimizer.zero_grad()
91 | self.optimizer.zero_grad()
92 | preds = self.model(trn_X, drop_rate=drop_rate)
93 | loss1 = self.criterion(preds, trn_y)
94 | loss1.backward()
95 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP)
96 | self.optimizer.step()
97 | if cur_epoch >= cfg.OPTIM.WARMUP_EPOCH:
98 | self.a_optimizer.step()
99 |
100 | # weight decay loss
101 | self.a_optimizer.zero_grad()
102 | self.optimizer.zero_grad()
103 | loss2 = self.model.weight_decay_loss(cfg.OPTIM.WEIGHT_DECAY) \
104 | + self.model.alpha_decay_loss(cfg.DARTS.ALPHA_WEIGHT_DECAY)
105 | loss2.backward()
106 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP)
107 | self.optimizer.step()
108 | self.a_optimizer.step()
109 |
110 | self.model.adjust_alphas()
111 | loss = loss1 + loss2
112 |
113 | # Compute the errors
114 | top1_err, top5_err = meter.topk_errors(preds, trn_y, [1, 5])
115 | loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
116 | self.train_meter.iter_toc()
117 | # Update and log stats
118 | self.train_meter.update_stats(top1_err, top5_err, loss, lr, trn_X.size(0))
119 | self.train_meter.log_iter_stats(cur_epoch, cur_iter)
120 | self.train_meter.iter_tic()
121 | self.writer.add_scalar('train/loss', loss, cur_step)
122 | self.writer.add_scalar('train/top1_error', top1_err, cur_step)
123 | self.writer.add_scalar('train/top5_error', top5_err, cur_step)
124 | cur_step += 1
125 | # Log epoch stats
126 | self.train_meter.log_epoch_stats(cur_epoch)
127 | self.train_meter.reset()
128 | # saving model
129 | if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
130 | self.saving(cur_epoch)
131 |
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
--------------------------------------------------------------------------------
/scripts/search/NBMacro.py:
--------------------------------------------------------------------------------
1 | """
2 | NAS-Bench-Macro: only (8 layers * 3 choices) + CIFAR10
3 | """
4 |
5 | import xnas.core.config as config
6 | import xnas.logger.logging as logging
7 | from xnas.core.config import cfg
8 | from xnas.core.builder import *
9 | from xnas.runner.trainer import OneShotTrainer
10 | from xnas.algorithms.SPOS import RAND, REA
11 |
12 | # Load config and check
13 | config.load_configs()
14 | logger = logging.get_logger(__name__)
15 |
16 |
17 | def main():
18 | setup_env()
19 | criterion = criterion_builder().cuda()
20 | [train_loader, valid_loader] = construct_loader()
21 | model = space_builder().cuda()
22 | optimizer = optimizer_builder("SGD", model.parameters())
23 | lr_scheduler = lr_scheduler_builder(optimizer)
24 |
25 | # init sampler
26 | train_sampler = RAND(3, 8)
27 | evaluate_sampler = REA(3, 8)
28 |
29 | # init recorders
30 | nbm_trainer = OneShotTrainer(
31 | supernet=model,
32 | criterion=criterion,
33 | optimizer=optimizer,
34 | lr_scheduler=lr_scheduler,
35 | train_loader=train_loader,
36 | test_loader=valid_loader,
37 | sample_type='iter'
38 | )
39 | nbm_trainer.register_iter_sample(train_sampler)
40 |
41 | # load checkpoint or initial weights
42 | start_epoch = nbm_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
43 |
44 | # start training
45 | nbm_trainer.start()
46 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
47 | # train epoch
48 | top1_err = nbm_trainer.train_epoch(cur_epoch)
49 | # test epoch
50 | if (cur_epoch + 1) % cfg.EVAL_PERIOD == 0 or (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH:
51 | top1_err = nbm_trainer.test_epoch(cur_epoch)
52 | nbm_trainer.finish()
53 |
54 | # # sample best architecture from supernet
55 | # for cycle in range(200): # NOTE: this should be a hyperparameter
56 | # sample = evaluate_sampler.suggest()
57 | # top1_err = nbm_trainer.evaluate_epoch(sample)
58 | # evaluate_sampler.record(sample, top1_err)
59 | # best_arch, best_top1err = evaluate_sampler.final_best()
60 | # logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err))
61 |
62 | # from xnas.evaluations.NASBenchMacro.evaluate import evaluate
63 | # # for example : arch = '00000000'
64 | # arch = ''
65 | # evaluate(arch)
66 |
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/scripts/search/OFA/search.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/scripts/search/OFA/search.py
--------------------------------------------------------------------------------
/scripts/search/OFA/train_teacher_model.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/scripts/search/OFA/train_teacher_model.py
--------------------------------------------------------------------------------
/scripts/search/PCDARTS.py:
--------------------------------------------------------------------------------
1 | """PCDARTS searching"""
2 |
3 | import xnas.core.config as config
4 | import xnas.logger.logging as logging
5 | from xnas.core.config import cfg
6 | from xnas.core.builder import *
7 |
8 | # PCDARTS
9 | from xnas.algorithms.PCDARTS import *
10 | from xnas.runner.trainer import DartsTrainer
11 | from xnas.runner.optimizer import darts_alpha_optimizer
12 |
13 |
14 | # Load config and check
15 | config.load_configs()
16 | logger = logging.get_logger(__name__)
17 |
18 | def main():
19 | device = setup_env()
20 | search_space = space_builder()
21 | criterion = criterion_builder().to(device)
22 | evaluator = evaluator_builder()
23 |
24 | [train_loader, valid_loader] = construct_loader()
25 |
26 | # init models
27 | pcdarts_controller = PCDartsCNNController(search_space, criterion).to(device)
28 | architect = Architect(pcdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
29 |
30 | # init optimizers
31 | w_optim = optimizer_builder("SGD", pcdarts_controller.weights())
32 | a_optim = darts_alpha_optimizer("Adam", pcdarts_controller.alphas())
33 | lr_scheduler = lr_scheduler_builder(w_optim)
34 |
35 | # init recorders
36 | pcdarts_trainer = DartsTrainer(
37 | darts_controller=pcdarts_controller,
38 | architect=architect,
39 | criterion=criterion,
40 | lr_scheduler=lr_scheduler,
41 | w_optim=w_optim,
42 | a_optim=a_optim,
43 | train_loader=train_loader,
44 | valid_loader=valid_loader,
45 | )
46 |
47 | # Load checkpoint or initial weights
48 | start_epoch = pcdarts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0
49 |
50 | # start training
51 | pcdarts_trainer.start()
52 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
53 | # train epoch
54 | pcdarts_trainer.train_epoch(cur_epoch, cur_epoch>=15)
55 | # test epoch
56 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
57 | pcdarts_trainer.test_epoch(cur_epoch)
58 | # recording genotype and alpha to logger
59 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch))
60 | logger.info(pcdarts_trainer.model.genotype())
61 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch))
62 | pcdarts_trainer.model.print_alphas(logger)
63 | # evaluate model
64 | if evaluator:
65 | evaluator(pcdarts_trainer.model.genotype())
66 | pcdarts_trainer.finish()
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/scripts/search/PDARTS.py:
--------------------------------------------------------------------------------
1 | """PDARTS searching"""
2 |
3 | import numpy as np
4 | import xnas.core.config as config
5 | import xnas.logger.logging as logging
6 | from xnas.core.config import cfg
7 | from xnas.core.builder import *
8 |
9 | # PDARTS
10 | from xnas.algorithms.PDARTS import *
11 | from xnas.runner.trainer import DartsTrainer
12 | from xnas.runner.optimizer import darts_alpha_optimizer
13 |
14 |
15 | # Load config and check
16 | config.load_configs()
17 | logger = logging.get_logger(__name__)
18 |
19 | def main():
20 | device = setup_env()
21 | criterion = criterion_builder().to(device)
22 | evaluator = evaluator_builder()
23 |
24 | [train_loader, valid_loader] = construct_loader()
25 |
26 | num_edges = (cfg.SPACE.NODES+3)*cfg.SPACE.NODES//2
27 | basic_op = []
28 | for i in range(num_edges * 2):
29 | basic_op.append(cfg.SPACE.BASIC_OP)
30 |
31 | for sp in range(len(cfg.PDARTS.NUM_TO_KEEP)):
32 | search_space = space_builder(
33 | add_layers=cfg.PDARTS.ADD_LAYERS[sp],
34 | add_width=cfg.PDARTS.ADD_WIDTH[sp],
35 | dropout_rate=float(cfg.PDARTS.DROPOUT_RATE[sp]),
36 | basic_op=basic_op,
37 | )
38 |
39 | # init models
40 | pdarts_controller = PDartsCNNController(search_space, criterion).to(device)
41 | architect = Architect(pdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
42 |
43 | # init optimizers
44 | w_optim = optimizer_builder("SGD", pdarts_controller.subnet_weights())
45 | a_optim = darts_alpha_optimizer("Adam", pdarts_controller.alphas())
46 | lr_scheduler = lr_scheduler_builder(w_optim)
47 |
48 | # init recorders
49 | pdarts_trainer = DartsTrainer(
50 | darts_controller=pdarts_controller,
51 | architect=architect,
52 | criterion=criterion,
53 | lr_scheduler=lr_scheduler,
54 | w_optim=w_optim,
55 | a_optim=a_optim,
56 | train_loader=train_loader,
57 | valid_loader=valid_loader,
58 | )
59 |
60 | # Load checkpoint or initial weights
61 | start_epoch = pdarts_trainer.darts_loading() if cfg.SEARCH.AUTO_RESUME else 0
62 |
63 | # start training
64 | pdarts_trainer.start()
65 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
66 | # train epoch
67 | if cur_epoch < cfg.PDARTS.EPS_NO_ARCHS[sp]:
68 | pdarts_trainer.model.update_p(
69 | float(cfg.PDARTS.DROPOUT_RATE[sp]) *
70 | (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) /
71 | cfg.OPTIM.MAX_EPOCH
72 | )
73 | pdarts_trainer.train_epoch(cur_epoch, alpha_step=False)
74 | else:
75 | pdarts_trainer.model.update_p(
76 | float(cfg.PDARTS.DROPOUT_RATE[sp]) *
77 | np.exp(-(cur_epoch - cfg.PDARTS.EPS_NO_ARCHS[sp]) * cfg.PDARTS.SCALE_FACTOR)
78 | )
79 | pdarts_trainer.train_epoch(cur_epoch, alpha_step=True)
80 |
81 | # test epoch
82 | if (cur_epoch+1) >= cfg.OPTIM.MAX_EPOCH - 5:
83 | pdarts_trainer.test_epoch(cur_epoch)
84 | # recording genotype and alpha to logger
85 | logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch))
86 | logger.info(pdarts_trainer.model.genotype())
87 | logger.info("=== alphas at epoch: {} ===".format(cur_epoch))
88 | pdarts_trainer.model.print_alphas(logger)
89 | # evaluate model
90 | if evaluator:
91 | evaluator(pdarts_trainer.model.genotype())
92 | pdarts_trainer.finish()
93 | logger.info("Top-{} primitive: {}".format(
94 | cfg.PDARTS.NUM_TO_KEEP[sp],
95 | pdarts_trainer.model.get_topk_op(cfg.PDARTS.NUM_TO_KEEP[sp]))
96 | )
97 | if sp == len(cfg.PDARTS.NUM_TO_KEEP) - 1:
98 | phase_ending(pdarts_trainer)
99 | else:
100 | basic_op = pdarts_trainer.model.get_topk_op(cfg.PDARTS.NUM_TO_KEEP[sp])
101 | phase_ending(pdarts_trainer)
102 |
103 |
104 | def phase_ending(pdarts_trainer, final=False):
105 | phase = "Final" if final else "Stage"
106 | logger.info("=== {} optimal genotype ===".format(phase))
107 | logger.info(pdarts_trainer.model.genotype(final=True))
108 | logger.info("=== {} alphas ===".format(phase))
109 | pdarts_trainer.model.print_alphas(logger)
110 | # restrict skip connect
111 | logger.info('Restricting skip-connect')
112 | for sks in range(0, 9):
113 | max_sk = 8-sks
114 | num_sk = pdarts_trainer.model.get_skip_number()
115 | if not num_sk > max_sk:
116 | continue
117 | while num_sk > max_sk:
118 | pdarts_trainer.model.delete_skip()
119 | num_sk = pdarts_trainer.model.get_skip_number()
120 |
121 | logger.info('Number of skip-connect: %d', max_sk)
122 | logger.info(pdarts_trainer.model.genotype(final=True))
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
/scripts/search/SPOS.py:
--------------------------------------------------------------------------------
1 | """Single Path One-Shot"""
2 |
3 | import xnas.core.config as config
4 | import xnas.logger.logging as logging
5 | from xnas.core.config import cfg
6 | from xnas.core.builder import *
7 |
8 | # SPOS
9 | from xnas.algorithms.SPOS import RAND, REA
10 | from xnas.runner.trainer import OneShotTrainer
11 |
12 |
13 | # Load config and check
14 | config.load_configs()
15 | logger = logging.get_logger(__name__)
16 |
17 | def main():
18 | device = setup_env()
19 | criterion = criterion_builder().to(device)
20 | [train_loader, valid_loader] = construct_loader()
21 | model = space_builder().cuda() #to(device)
22 | optimizer = optimizer_builder("SGD", model.parameters())
23 | lr_scheduler = lr_scheduler_builder(optimizer)
24 |
25 | # init sampler
26 | train_sampler = RAND(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS)
27 | evaluate_sampler = REA(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS)
28 |
29 | # init recorders
30 | spos_trainer = OneShotTrainer(
31 | supernet=model,
32 | criterion=criterion,
33 | optimizer=optimizer,
34 | lr_scheduler=lr_scheduler,
35 | train_loader=train_loader,
36 | test_loader=valid_loader,
37 | sample_type='iter'
38 | )
39 | spos_trainer.register_iter_sample(train_sampler)
40 |
41 | # load checkpoint or initial weights
42 | start_epoch = spos_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
43 |
44 | # start training
45 | spos_trainer.start()
46 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
47 | # train epoch
48 | top1_err = spos_trainer.train_epoch(cur_epoch)
49 | # test epoch
50 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
51 | top1_err = spos_trainer.test_epoch(cur_epoch)
52 | spos_trainer.finish()
53 |
54 | # sample best architecture from supernet
55 | for cycle in range(200): # NOTE: this should be a hyperparameter
56 | sample = evaluate_sampler.suggest()
57 | top1_err = spos_trainer.evaluate_epoch(sample)
58 | evaluate_sampler.record(sample, top1_err)
59 | best_arch, best_top1err = evaluate_sampler.final_best()
60 | logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err))
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/scripts/train/DARTS.py:
--------------------------------------------------------------------------------
1 | """DARTS retraining"""
2 | import torch
3 | import torch.nn as nn
4 |
5 | import xnas.core.config as config
6 | import xnas.logger.logging as logging
7 | import xnas.logger.meter as meter
8 | from xnas.core.config import cfg
9 | from xnas.core.builder import *
10 | from xnas.datasets.loader import get_normal_dataloader
11 |
12 | from xnas.runner.trainer import Trainer
13 |
14 | # Load config and check
15 | config.load_configs()
16 | logger = logging.get_logger(__name__)
17 |
18 | def main():
19 | device = setup_env()
20 |
21 | model = space_builder().to(device)
22 | criterion = criterion_builder().to(device)
23 | # evaluator = evaluator_builder()
24 |
25 | [train_loader, valid_loader] = get_normal_dataloader()
26 | optimizer = optimizer_builder("SGD", model.parameters())
27 | lr_scheduler = lr_scheduler_builder(optimizer)
28 |
29 | darts_retrainer = Darts_Retrainer(
30 | model, criterion, optimizer, lr_scheduler, train_loader, valid_loader
31 | )
32 |
33 | start_epoch = darts_retrainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
34 |
35 | darts_retrainer.start()
36 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
37 | darts_retrainer.model.drop_path_prob = cfg.TRAIN.DROP_PATH_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH
38 | darts_retrainer.train_epoch(cur_epoch)
39 | if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
40 | darts_retrainer.test_epoch(cur_epoch)
41 | darts_retrainer.finish()
42 |
43 |
44 | # overwrite training & validating with auxiliary
45 | class Darts_Retrainer(Trainer):
46 | def __init__(self, model, criterion, optimizer, lr_scheduler, train_loader, test_loader):
47 | super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader)
48 |
49 | def train_epoch(self, cur_epoch):
50 | self.model.train()
51 | lr = self.lr_scheduler.get_last_lr()[0]
52 | cur_step = cur_epoch * len(self.train_loader)
53 | self.writer.add_scalar('train/lr', lr, cur_step)
54 | self.train_meter.iter_tic()
55 | for cur_iter, (inputs, labels) in enumerate(self.train_loader):
56 | inputs, labels = inputs.to(self.device), labels.to(self.device, non_blocking=True)
57 | preds, preds_aux = self.model(inputs)
58 | loss = self.criterion(preds, labels)
59 | self.optimizer.zero_grad()
60 | if cfg.TRAIN.AUX_WEIGHT > 0.:
61 | loss += cfg.TRAIN.AUX_WEIGHT * self.criterion(preds_aux, labels)
62 | loss.backward()
63 | nn.utils.clip_grad_norm_(self.model.weights(), cfg.OPTIM.GRAD_CLIP)
64 | self.optimizer.step()
65 |
66 | # Compute the errors
67 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
68 | loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
69 | self.train_meter.iter_toc()
70 | # Update and log stats
71 | self.train_meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0) * cfg.NUM_GPUS)
72 | self.train_meter.log_iter_stats(cur_epoch, cur_iter)
73 | self.train_meter.iter_tic()
74 | self.writer.add_scalar('train/loss', loss, cur_step)
75 | self.writer.add_scalar('train/top1_error', top1_err, cur_step)
76 | self.writer.add_scalar('train/top5_error', top5_err, cur_step)
77 | cur_step += 1
78 | # Log epoch stats
79 | self.train_meter.log_epoch_stats(cur_epoch)
80 | self.train_meter.reset()
81 | self.lr_scheduler.step()
82 | # Saving checkpoint
83 | if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
84 | self.saving(cur_epoch)
85 |
86 | @torch.no_grad()
87 | def test_epoch(self, cur_epoch):
88 | self.model.eval()
89 | self.test_meter.iter_tic()
90 | for cur_iter, (inputs, labels) in enumerate(self.test_loader):
91 | inputs, labels = inputs.to(self.device), labels.to(self.device, non_blocking=True)
92 | preds, _ = self.model(inputs)
93 | top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
94 | top1_err, top5_err = top1_err.item(), top5_err.item()
95 |
96 | self.test_meter.iter_toc()
97 | self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
98 | self.test_meter.log_iter_stats(cur_epoch, cur_iter)
99 | self.test_meter.iter_tic()
100 | top1_err = self.test_meter.mb_top1_err.get_win_avg()
101 | self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
102 | self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
103 | # Log epoch stats
104 | self.test_meter.log_epoch_stats(cur_epoch)
105 | self.test_meter.reset()
106 | # Saving best model
107 | if self.best_err > top1_err:
108 | self.best_err = top1_err
109 | self.saving(cur_epoch, best=True)
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
--------------------------------------------------------------------------------
/tests/MultiSizeRandomCrop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | from xnas.datasets.transforms import MultiSizeRandomCrop
6 |
7 | msrc = MultiSizeRandomCrop([4,224])
8 | # print(MultiSizeRandomCrop.CANDIDATE_SIZES)
9 |
10 | for i in range(5):
11 | MultiSizeRandomCrop.sample_image_size()
12 | print(MultiSizeRandomCrop.ACTIVE_SIZE)
13 |
14 | def my_collate(batch):
15 | msrc.sample_image_size()
16 | xs = torch.stack([i[0] for i in batch])
17 | ys = torch.Tensor([i[1] for i in batch])
18 | return [xs, ys]
19 |
20 | T = transforms.Compose([
21 | msrc,
22 | transforms.ToTensor(),
23 | ])
24 |
25 | _data = dset.CIFAR10(
26 | root='./data/cifar10',
27 | train=True,
28 | transform=T,
29 | )
30 |
31 | loader = data.DataLoader(
32 | dataset=_data,
33 | batch_size=128,
34 | collate_fn=my_collate,
35 | )
36 |
37 | for i, (trn_X, trn_y) in enumerate(loader):
38 | print(trn_X.shape, trn_y.shape)
39 | if i==5:
40 | break
41 |
--------------------------------------------------------------------------------
/tests/adjust_lr_per_iter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class model1(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 | self.conv = nn.Conv2d(3,3,3)
9 |
10 | def forward(self, x):
11 | x = self.conv(x)
12 | return x
13 |
14 | net = model1()
15 |
16 | opr = torch.optim.SGD(net.parameters(), 0.1)
17 | lrs = torch.optim.lr_scheduler.CosineAnnealingLR(opr, 10)
18 |
19 | from xnas.runner.scheduler import adjust_learning_rate_per_batch
20 |
21 | for cur_epoch in range(10):
22 | print("epoch:{} lr:{}".format(cur_epoch, lrs.get_last_lr()[0]))
23 | opr.zero_grad()
24 | opr.step()
25 | lrs.step()
26 |
27 | print("*"*20)
28 | del opr, lrs
29 |
30 | opr = torch.optim.SGD(net.parameters(), 0.1)
31 | lrs = torch.optim.lr_scheduler.CosineAnnealingLR(opr, 10)
32 |
33 | for cur_epoch in range(10):
34 | opr.zero_grad()
35 | for cur_iter in range(5):
36 | new_lr = adjust_learning_rate_per_batch(
37 | init_lr=0.1,
38 | n_epochs=10,
39 | n_warmup_epochs=3,
40 | epoch=cur_epoch,
41 | n_iter=5,
42 | iter=cur_iter,
43 | warmup=(cur_epoch < 3),
44 | warmup_lr=0.1, # use base_lr as warmup_lr
45 | )
46 | for param_group in opr.param_groups:
47 | param_group["lr"] = new_lr
48 | if cur_iter == 0:
49 | print("epoch:{} iter:{} lr:{}".format(cur_epoch, cur_iter, new_lr))
50 | opr.step()
51 |
--------------------------------------------------------------------------------
/tests/configs/datasets.yaml:
--------------------------------------------------------------------------------
1 | LOADER:
2 | DATASET: 'cifar10'
3 | SPLIT: [0.5, 0.5]
4 | BATCH_SIZE: 64
5 | NUM_CLASSES: 10
6 | SEARCH:
7 | IM_SIZE: 32
8 |
--------------------------------------------------------------------------------
/tests/configs/imagenet.yaml:
--------------------------------------------------------------------------------
1 | LOADER:
2 | DATASET: 'imagenet'
3 | BATCH_SIZE: 64
4 | NUM_CLASSES: 1000
5 | USE_VAL: True
6 | NUM_WORKERS: 4
7 | NUM_GPUS: 2
8 | SEARCH:
9 | MULTI_SIZES: [160, 192, 224]
10 |
--------------------------------------------------------------------------------
/tests/dataloader.py:
--------------------------------------------------------------------------------
1 | import xnas.core.config as config
2 | from xnas.datasets.loader import construct_loader
3 | from xnas.core.config import cfg
4 |
5 | config.load_configs()
6 |
7 | # cifar10
8 | [train_loader, valid_loader] = construct_loader()
9 |
10 | # cifar100
11 | cfg.LOADER.DATASET = 'cifar100'
12 | cfg.LOADER.NUM_CLASSES = 100
13 | [train_loader, valid_loader] = construct_loader()
14 |
15 | # imagenet16
16 | cfg.LOADER.DATASET = 'imagenet16'
17 | cfg.LOADER.NUM_CLASSES = 120
18 | [train_loader, valid_loader] = construct_loader()
19 |
20 |
--------------------------------------------------------------------------------
/tests/imagenet.py:
--------------------------------------------------------------------------------
1 | import xnas.core.config as config
2 | from xnas.datasets.loader import construct_loader
3 | from xnas.core.config import cfg
4 |
5 | config.load_configs()
6 |
7 | [train_loader, valid_loader] = construct_loader()
8 |
9 | for i, (trn_X, trn_y) in enumerate(train_loader):
10 | print(trn_X.shape, trn_y.shape)
11 | if i==9:
12 | break
13 |
14 | # cfg.SEARCH.MULTI_SIZES = []
15 |
16 | print("===")
17 | for i, (trn_X, trn_y) in enumerate(valid_loader):
18 | print(trn_X.shape, trn_y.shape)
19 | if i==9:
20 | break
21 |
--------------------------------------------------------------------------------
/tests/nb101space.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb101space.py
--------------------------------------------------------------------------------
/tests/nb1shot1space.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb1shot1space.py
--------------------------------------------------------------------------------
/tests/nb201eval.py:
--------------------------------------------------------------------------------
1 | from xnas.evaluations.NASBench201 import evaluate, index_to_genotype, distill
2 |
3 | arch = index_to_genotype(2333)
4 | result = evaluate(arch)
5 | (
6 | cifar10_train,
7 | cifar10_test,
8 | cifar100_train,
9 | cifar100_valid,
10 | cifar100_test,
11 | imagenet16_train,
12 | imagenet16_valid,
13 | imagenet16_test,
14 | ) = distill(result)
15 |
16 | print("cifar10 train %f test %f" % (cifar10_train, cifar10_test))
17 | print("cifar100 train %f valid %f test %f" % (cifar100_train, cifar100_valid, cifar100_test))
18 | print("imagenet16 train %f valid %f test %f" % (imagenet16_train, imagenet16_valid, imagenet16_test))
19 |
20 |
21 | result = evaluate(arch, epoch=200)
22 | (
23 | cifar10_train,
24 | cifar10_test,
25 | cifar100_train,
26 | cifar100_valid,
27 | cifar100_test,
28 | imagenet16_train,
29 | imagenet16_valid,
30 | imagenet16_test,
31 | ) = distill(result)
32 |
33 | print("cifar10 train %f test %f" % (cifar10_train, cifar10_test))
34 | print("cifar100 train %f valid %f test %f" % (cifar100_train, cifar100_valid, cifar100_test))
35 | print("imagenet16 train %f valid %f test %f" % (imagenet16_train, imagenet16_valid, imagenet16_test))
36 |
--------------------------------------------------------------------------------
/tests/nb301eval.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/tests/nb301eval.py
--------------------------------------------------------------------------------
/tests/ofa_matrices_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | def test_local():
5 | root = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/'
6 | filename_prefix = 'model_epoch_'
7 | filename_postfix = '.pyth'
8 |
9 | for i in range(101, 110):
10 | ckpt = torch.load(root+filename_prefix+'0'+str(i)+filename_postfix)
11 | for k,v in ckpt['model_state'].items():
12 | # print(k)
13 | # if k.endswith('5to3_matrix'):
14 | if k == 'blocks.6.conv.depth_conv.conv.5to3_matrix':
15 | print(k, v[0:3, 0:3])
16 | break
17 | for k,v in ckpt['model_state'].items():
18 | # print(k)
19 | # if k.endswith('7to5_matrix'):
20 | if k == 'blocks.6.conv.depth_conv.conv.7to5_matrix':
21 | print(k, v[0:3, 0:3])
22 | break
23 |
24 | def test_original():
25 | root = "/home/xfey/XNAS/tests/weights/"
26 | ckpt = torch.load(root+"ofa_D4_E6_K357")
27 | for k,v in ckpt['state_dict'].items():
28 | # print(k)
29 | if k.endswith('5to3_matrix') and k.startswith('blocks.6'):
30 | print(k, v[0:3, 0:3])
31 | break
32 | for k,v in ckpt['state_dict'].items():
33 | # print(k)
34 | if k.endswith('7to5_matrix') and k.startswith('blocks.6'):
35 | print(k, v[0:3, 0:3])
36 | break
37 |
38 | if __name__ == '__main__':
39 | # test_local()
40 | test_original()
41 |
--------------------------------------------------------------------------------
/tests/test_REA.py:
--------------------------------------------------------------------------------
1 | from xnas.algorithms.SPOS import REA
2 |
3 | rea = REA(num_choice=4, child_len=10, population_size=5)
4 |
5 | for i in range(30):
6 | child = rea.suggest()
7 | print("suggest: {}".format(child))
8 | rea.record(child, value=sum(child))
9 | print("===")
10 | for p in rea.population:
11 | print(p)
12 | input()
13 | print(rea.final_best())
--------------------------------------------------------------------------------
/xnas/algorithms/AttentiveNAS/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | def count_helper(v, flops, m):
4 | if flops not in m:
5 | m[flops] = {}
6 | if v not in m[flops]:
7 | m[flops][v] = 0
8 | m[flops][v] += 1
9 |
10 |
11 | def round_flops(flops, step):
12 | return int(round(flops / step) * step)
13 |
14 |
15 | def convert_count_to_prob(m):
16 | if isinstance(m[list(m.keys())[0]], dict):
17 | for k in m:
18 | convert_count_to_prob(m[k])
19 | else:
20 | t = sum(m.values())
21 | for k in m:
22 | m[k] = 1.0 * m[k] / t
23 |
24 |
25 | def sample_helper(flops, m):
26 | keys = list(m[flops].keys())
27 | probs = list(m[flops].values())
28 | return random.choices(keys, weights=probs)[0]
29 |
30 |
31 | def build_trasition_prob_matrix(file_handler, step):
32 |     # initialize
33 | prob_map = {}
34 | prob_map['discretize_step'] = step
35 | for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
36 | prob_map[k] = {}
37 |
38 | cc = 0
39 | for line in file_handler:
40 | vals = eval(line.strip())
41 |
42 | # discretize
43 | flops = round_flops(vals['flops'], step)
44 | prob_map['flops'][flops] = prob_map['flops'].get(flops, 0) + 1
45 |
46 | # resolution
47 | r = vals['resolution']
48 | count_helper(r, flops, prob_map['resolution'])
49 |
50 | for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
51 | for idx, v in enumerate(vals[k]):
52 | if idx not in prob_map[k]:
53 | prob_map[k][idx] = {}
54 | count_helper(v, flops, prob_map[k][idx])
55 |
56 | cc += 1
57 |
58 | # convert count to probability
59 | for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
60 | convert_count_to_prob(prob_map[k])
61 | prob_map['n_observations'] = cc
62 | return prob_map
63 |
64 |
65 |
66 | class ArchSampler():
67 | def __init__(self, arch_to_flops_map_file_path, discretize_step, model, acc_predictor=None):
68 | super(ArchSampler, self).__init__()
69 |
70 | with open(arch_to_flops_map_file_path, 'r') as fp:
71 | self.prob_map = build_trasition_prob_matrix(fp, discretize_step)
72 |
73 | self.discretize_step = discretize_step
74 | self.model = model
75 |
76 | self.acc_predictor = acc_predictor
77 |
78 | self.min_flops = min(list(self.prob_map['flops'].keys()))
79 | self.max_flops = max(list(self.prob_map['flops'].keys()))
80 |
81 |         self.curr_sample_pool = None  # TODO: architecture samples could be generated asynchronously
82 |
83 |
84 | def sample_one_target_flops(self, flops_uniform=False):
85 | f_vals = list(self.prob_map['flops'].keys())
86 | f_probs = list(self.prob_map['flops'].values())
87 |
88 | if flops_uniform:
89 | return random.choice(f_vals)
90 | else:
91 | return random.choices(f_vals, weights=f_probs)[0]
92 |
93 |
94 | def sample_archs_according_to_flops(self, target_flops, n_samples=1, max_trials=100, return_flops=True, return_trials=False):
95 | archs = []
96 | #for _ in range(n_samples):
97 | while len(archs) < n_samples:
98 | for _trial in range(max_trials+1):
99 | arch = {}
100 | arch['resolution'] = sample_helper(target_flops, self.prob_map['resolution'])
101 | for k in ['width', 'kernel_size', 'depth', 'expand_ratio']:
102 | arch[k] = []
103 | for idx in sorted(list(self.prob_map[k].keys())):
104 | arch[k].append(sample_helper(target_flops, self.prob_map[k][idx]))
105 | if self.model:
106 | self.model.set_active_subnet(**arch)
107 | flops = self.model.compute_active_subnet_flops()
108 | if return_flops:
109 | arch['flops'] = flops
110 | if round_flops(flops, self.discretize_step) == target_flops:
111 | break
112 | else:
113 | raise NotImplementedError
114 |             # accept the sample anyway
115 | archs.append(arch)
116 | return archs
117 |
118 |
--------------------------------------------------------------------------------
/xnas/algorithms/DARTS.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | '''
8 | DARTS: largely copied from https://github.com/khanrc/pt.darts
9 | '''
10 |
11 |
12 | class DartsCNNController(nn.Module):
13 | """SearchCNN Controller"""
14 |
15 | def __init__(self, net, criterion, device_ids=None):
16 | super().__init__()
17 | if device_ids is None:
18 | device_ids = list(range(torch.cuda.device_count()))
19 | self.net = net
20 | self.device_ids = device_ids
21 | self.n_ops = self.net.num_ops
22 | self.alpha = nn.Parameter(
23 | 1e-3*torch.randn(self.net.all_edges, self.n_ops))
24 | self.criterion = criterion
25 |
26 | # Setup alphas list
27 | self._alphas = []
28 | for n, p in self.named_parameters():
29 | if 'alpha' in n:
30 | self._alphas.append((n, p))
31 |
32 | def forward(self, x):
33 | weights_ = F.softmax(self.alpha, dim=-1)
34 | if len(self.device_ids) == 1:
35 | return self.net(x, weights_)
36 | else:
37 | raise NotImplementedError
38 |
39 | def genotype(self):
40 | """return genotype of DARTS CNN"""
41 | return self.net.genotype(self.alpha.cpu().detach().numpy())
42 |
43 | def weights(self):
44 | return self.net.parameters()
45 |
46 | def named_weights(self):
47 | return self.net.named_parameters()
48 |
49 | def alphas(self):
50 | for n, p in self._alphas:
51 | yield p
52 |
53 | def named_alphas(self):
54 | for n, p in self._alphas:
55 | yield n, p
56 |
57 | def print_alphas(self, logger):
58 | logger.info("####### ALPHA #######")
59 | for alpha in self.alpha:
60 | logger.info(F.softmax(alpha, dim=-1).cpu().detach().numpy())
61 | logger.info("#####################")
62 |
63 | def loss(self, X, y):
64 | logits = self.forward(X)
65 | return self.criterion(logits, y)
66 |
67 |
68 | class Architect():
69 | """ Compute gradients of alphas """
70 |
71 | def __init__(self, net, w_momentum, w_weight_decay):
72 | """
73 | Args:
74 | net
75 | w_momentum: weights momentum
76 | """
77 | self.net = net
78 | self.v_net = copy.deepcopy(net)
79 | self.w_momentum = w_momentum
80 | self.w_weight_decay = w_weight_decay
81 |
82 | def virtual_step(self, trn_X, trn_y, xi, w_optim):
83 | """
84 | Compute unrolled weight w' (virtual step)
85 | Step process:
86 | 1) forward
87 | 2) calc loss
88 | 3) compute gradient (by backprop)
89 | 4) update gradient
90 | Args:
91 | xi: learning rate for virtual gradient step (same as weights lr)
92 | w_optim: weights optimizer
93 | """
94 | # forward & calc loss
95 | loss = self.net.loss(trn_X, trn_y) # L_trn(w)
96 |
97 | # compute gradient
98 | gradients = torch.autograd.grad(loss, self.net.weights())
99 |
100 | # virtual step (update gradient)
101 | # operations below do not need gradient tracking
102 | with torch.no_grad():
103 | # dict key is not the value, but the pointer,
104 | # So original network weight have to be iterated also.
105 | for w, vw, g in zip(self.net.weights(), self.v_net.weights(), gradients):
106 | m = w_optim.state[w].get(
107 | 'momentum_buffer', 0.) * self.w_momentum
108 | vw.copy_(w - xi * (m + g + self.w_weight_decay*w))
109 |
110 | # synchronize alphas
111 | for a, va in zip(self.net.alphas(), self.v_net.alphas()):
112 | va.copy_(a)
113 |
114 | def unrolled_backward(self, trn_X, trn_y, val_X, val_y, xi, w_optim, unrolled=True):
115 | """ Compute unrolled loss and backward its gradients
116 | Args:
117 | xi: learning rate for virtual gradient step (same as net lr)
118 | w_optim: weights optimizer - for virtual step
119 | """
120 | # do virtual step (calc w`)
121 | if unrolled:
122 | self.virtual_step(trn_X, trn_y, xi, w_optim)
123 |
124 | # calc unrolled loss
125 | loss = self.v_net.loss(val_X, val_y) # L_val(w`)
126 |
127 | # compute gradient
128 | v_alphas = tuple(self.v_net.alphas())
129 | v_weights = tuple(self.v_net.weights())
130 | v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
131 | dalpha = v_grads[:len(v_alphas)]
132 | dw = v_grads[len(v_alphas):]
133 |
134 | hessian = self.compute_hessian(dw, trn_X, trn_y)
135 |
136 | # update final gradient = dalpha - xi*hessian
137 | with torch.no_grad():
138 | for alpha, da, h in zip(self.net.alphas(), dalpha, hessian):
139 | alpha.grad = da - xi * h
140 |
141 | else:
142 |             loss = self.net.loss(val_X, val_y) # L_val(w)
143 |             loss.backward()
144 |
145 |
146 | def compute_hessian(self, dw, trn_X, trn_y):
147 | """
148 | dw = dw` { L_val(w`, alpha) }
149 | w+ = w + eps * dw
150 | w- = w - eps * dw
151 | hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
152 | eps = 0.01 / ||dw||
153 | """
154 |
155 | norm = torch.cat([w.view(-1) for w in dw]).norm()
156 | eps = 0.01 / norm
157 |
158 | # w+ = w + eps*dw`
159 | with torch.no_grad():
160 | for p, d in zip(self.net.weights(), dw):
161 | p += eps * d
162 | loss = self.net.loss(trn_X, trn_y)
163 | dalpha_pos = torch.autograd.grad(
164 | loss, self.net.alphas()) # dalpha { L_trn(w+) }
165 |
166 | # w- = w - eps*dw`
167 | with torch.no_grad():
168 | for p, d in zip(self.net.weights(), dw):
169 | p -= 2. * eps * d
170 | loss = self.net.loss(trn_X, trn_y)
171 | dalpha_neg = torch.autograd.grad(
172 | loss, self.net.alphas()) # dalpha { L_trn(w-) }
173 |
174 | # recover w
175 | with torch.no_grad():
176 | for p, d in zip(self.net.weights(), dw):
177 | p += eps * d
178 |
179 | hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
180 | return hessian
181 |
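For orientation, here is a minimal sketch (not part of the source) of how this `Architect` drives one bi-level DARTS step. `controller`, `train_loader`, `valid_loader`, and the fixed `lr` are illustrative stand-ins; in the actual scripts the learning rate comes from the scheduler and gradients are clipped:

```python
import torch

w_optim = torch.optim.SGD(controller.weights(), lr=0.025,
                          momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(controller.alphas(), lr=3e-4, betas=(0.5, 0.999))
architect = Architect(controller, w_momentum=0.9, w_weight_decay=3e-4)
lr = 0.025  # current weight lr; normally read from the scheduler

for (trn_X, trn_y), (val_X, val_y) in zip(train_loader, valid_loader):
    # architecture step: backprop the validation loss through the virtual step
    alpha_optim.zero_grad()
    architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
    alpha_optim.step()
    # weight step: plain SGD on the training loss
    w_optim.zero_grad()
    controller.loss(trn_X, trn_y).backward()
    w_optim.step()
```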
--------------------------------------------------------------------------------
/xnas/algorithms/DrNAS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 |
5 |
6 | def _concat(xs):
7 | return torch.cat([x.view(-1) for x in xs])
8 |
9 |
10 | class Architect(object):
11 | def __init__(self, net, cfg):
12 | self.network_momentum = cfg.OPTIM.MOMENTUM
13 | self.network_weight_decay = cfg.OPTIM.WEIGHT_DECAY
14 | self.net = net
15 | if cfg.DRNAS.REG_TYPE == "l2":
16 | weight_decay = cfg.DRNAS.REG_SCALE
17 | elif cfg.DRNAS.REG_TYPE == "kl":
18 | weight_decay = 0
19 | self.optimizer = torch.optim.Adam(
20 | self.net.arch_parameters(),
21 | lr=cfg.DARTS.ALPHA_LR,
22 | betas=(0.5, 0.999),
23 | weight_decay=weight_decay,
24 | )
25 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 |
27 | def _compute_unrolled_model(self, input, target, eta, network_optimizer):
28 | loss = self.net._loss(input, target)
29 | theta = _concat(self.net.parameters()).data
30 | try:
31 | moment = _concat(
32 | network_optimizer.state[v]["momentum_buffer"]
33 | for v in self.net.parameters()
34 | ).mul_(self.network_momentum)
35 | except KeyError:  # no momentum buffer before the first optimizer step
36 | moment = torch.zeros_like(theta)
37 | dtheta = (
38 | _concat(torch.autograd.grad(loss, self.net.parameters())).data
39 | + self.network_weight_decay * theta
40 | )
41 | unrolled_model = self._construct_model_from_theta(
42 | theta.sub(moment + dtheta, alpha=eta)
43 | ).to(self.device)
44 | return unrolled_model
45 |
46 | def unrolled_backward(
47 | self,
48 | input_train,
49 | target_train,
50 | input_valid,
51 | target_valid,
52 | eta,
53 | network_optimizer,
54 | unrolled,
55 | ):
56 | self.optimizer.zero_grad()
57 | if unrolled:
58 | self._backward_step_unrolled(
59 | input_train,
60 | target_train,
61 | input_valid,
62 | target_valid,
63 | eta,
64 | network_optimizer,
65 | )
66 | else:
67 | self._backward_step(input_valid, target_valid)
68 | self.optimizer.step()
69 |
70 | # def pruning(self, masks):
71 | # for i, p in enumerate(self.optimizer.param_groups[0]['params']):
72 | # if masks[i] is None:
73 | # continue
74 | # state = self.optimizer.state[p]
75 | # mask = masks[i]
76 | # state['exp_avg'][~mask] = 0.0
77 | # state['exp_avg_sq'][~mask] = 0.0
78 |
79 | def _backward_step(self, input_valid, target_valid):
80 | loss = self.net._loss(input_valid, target_valid)
81 | loss.backward()
82 |
83 | def _backward_step_unrolled(
84 | self,
85 | input_train,
86 | target_train,
87 | input_valid,
88 | target_valid,
89 | eta,
90 | network_optimizer,
91 | ):
92 | unrolled_model = self._compute_unrolled_model(
93 | input_train, target_train, eta, network_optimizer
94 | )
95 | unrolled_loss = unrolled_model._loss(input_valid, target_valid)
96 |
97 | unrolled_loss.backward()
98 | dalpha = [v.grad for v in unrolled_model.arch_parameters()]
99 | vector = [v.grad.data for v in unrolled_model.parameters()]
100 | implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
101 |
102 | for g, ig in zip(dalpha, implicit_grads):
103 | g.data.sub_(ig.data, alpha=eta)
104 |
105 | for v, g in zip(self.net.arch_parameters(), dalpha):
106 | if v.grad is None:
107 | v.grad = Variable(g.data)
108 | else:
109 | v.grad.data.copy_(g.data)
110 |
111 | def _construct_model_from_theta(self, theta):
112 | model_new = self.net.new()
113 | model_dict = self.net.state_dict()
114 |
115 | params, offset = {}, 0
116 | for k, v in self.net.named_parameters():
117 | v_length = np.prod(v.size())
118 | params[k] = theta[offset : offset + v_length].view(v.size())
119 | offset += v_length
120 |
121 | assert offset == len(theta)
122 | model_dict.update(params)
123 | model_new.load_state_dict(model_dict)
124 | return model_new
125 |
126 | def _hessian_vector_product(self, vector, input, target, r=1e-2):
127 | R = r / _concat(vector).norm()
128 | for p, v in zip(self.net.parameters(), vector):
129 | p.data.add_(v, alpha=R)
130 | loss = self.net._loss(input, target)
131 | grads_p = torch.autograd.grad(loss, self.net.arch_parameters())
132 |
133 | for p, v in zip(self.net.parameters(), vector):
134 | p.data.sub_(v, alpha=2 * R)
135 | loss = self.net._loss(input, target)
136 | grads_n = torch.autograd.grad(loss, self.net.arch_parameters())
137 |
138 | for p, v in zip(self.net.parameters(), vector):
139 | p.data.add_(v, alpha=R)
140 |
141 | return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
142 |
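For reference, `_backward_step_unrolled` together with `_hessian_vector_product` implements the standard second-order DARTS hypergradient. In this file's notation (`w` weights, `alpha` architecture parameters, `eta` the weight learning rate, `R` the perturbation):

```latex
\nabla_\alpha \mathcal{L}_{val}\bigl(w - \eta\,\nabla_w \mathcal{L}_{train}(w,\alpha),\ \alpha\bigr)
  \approx \nabla_\alpha \mathcal{L}_{val}(w',\alpha)
  - \eta\,\nabla^2_{\alpha,w}\mathcal{L}_{train}(w,\alpha)\,\nabla_{w'}\mathcal{L}_{val}(w',\alpha)
```

with the Hessian-vector product approximated by the central difference

```latex
\nabla^2_{\alpha,w}\mathcal{L}_{train}(w,\alpha)\,v
  \approx \frac{\nabla_\alpha \mathcal{L}_{train}(w^+,\alpha)
              - \nabla_\alpha \mathcal{L}_{train}(w^-,\alpha)}{2\varepsilon},
  \qquad w^\pm = w \pm \varepsilon v .
```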
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | Code for paper: **Neural Architecture Search with Representation Mutual Information**
4 |
5 | RMI-NAS is an efficient architecture search method built on Representation Mutual Information (RMI). It speeds up performance evaluation by ranking architectures with RMI, an accurate and effective indicator for NAS, needs only one batch of data to complete training, and generalizes well across search spaces. For more details, please refer to our paper.
6 |
7 | ## Results
8 |
9 | ### Results on NAS-Bench-201
10 |
11 | | Method | Search Cost (seconds) | CIFAR-10 Test Acc. (%) | CIFAR-100 Test Acc. (%) | ImageNet16-120 Test Acc. (%) |
12 | | ----------- | -------------------------- | --------------------------- | ---------------------------- | --------------------------------- |
13 | | RL | 27870.7 | 93.85±0.37 | 71.71±1.09 | 45.24±1.18 |
14 | | DARTS-V2 | 35781.8 | 54.30±0.00 | 15.61±0.00 | 16.32±0.00 |
15 | | GDAS | 31609.8 | 93.61±0.09 | 70.70±0.30 | 41.71±0.98 |
16 | | FairNAS | 9845.0 | 93.23±0.18 | 71.00±1.46 | 42.19±0.31 |
17 | | **RMI-NAS** | **1258.2** | **94.28±0.10** | **73.36±0.19** | **46.34±0.00** |
18 |
19 | Our method shows significant efficiency and accuracy improvements.
20 |
21 | ### Results on DARTS
22 |
23 | | Method | Search Cost (GPU days) | CIFAR-10 Test Err. (%) (paper) | CIFAR-10 Test Err. (%) (retrain) |
24 | | ----------- | -------------------------- | ---------------------------------------- | ------------------------------------------ |
25 | | AmoebaNet-B | 3150 | 2.55±0.05 | - |
26 | | NASNet-A | 1800 | 2.65 | - |
27 | | DARTS (1st) | 0.4 | 3.00±0.14 | 2.75 |
28 | | DARTS (2nd) | 1 | 2.76±0.09 | 2.60 |
29 | | SNAS | 1.5 | 2.85±0.02 | 2.68 |
30 | | PC-DARTS | 1 | 2.57±0.07 | 2.71±0.11 |
31 | | FairDARTS-D | 0.4 | 2.54±0.05 | 2.71 |
32 | | **RMI-NAS** | **0.08** | - | 2.64±0.04 |
33 |
34 | Comparison with other methods in the DARTS search space. We also report retrained results under exactly the same settings to ensure a fair comparison. Our method delivers comparable accuracy with a substantial improvement in time consumption.
35 |
36 |
37 |
38 | ## Usage
39 |
40 | #### Install RMI-NAS
41 |
42 | Our code uses functions from the XNAS repository, which must be installed first.
43 |
44 | ```bash
45 | # install XNAS
46 | git clone https://github.com/MAC-AutoML/XNAS.git
47 | export PYTHONPATH=$PYTHONPATH:/PATH/to/XNAS
48 |
49 | # prepare environment for RMI-NAS (conda)
50 | conda env create --file environment.yaml
51 |
52 | # download weight files for teacher models
53 | chmod +x xnas/algorithms/RMINAS/download_weight.sh
54 | bash xnas/algorithms/RMINAS/download_weight.sh
55 | ```
56 |
57 | File [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) is required by the earlier version of `NAS-Bench-201` that we use. Download it and place it in the `utils` directory.
58 |
59 | #### Search
60 |
61 | ```bash
62 | # NAS-Bench-201 + CIFAR-10
63 | python search/RMINAS/RMINAS_nb201.py --cfg configs/search/RMINAS/nb201_cifar10.yaml
64 |
65 | # DARTS + CIFAR-100 + specific exp path
66 | python search/RMINAS/RMINAS_darts.py --cfg configs/search/RMINAS/darts_cifar100.yaml OUT_DIR experiments/
67 | ```
68 |
69 |
70 | ## Related work
71 |
72 | [NAS-Bench-201](https://github.com/D-X-Y/NAS-Bench-201)
73 |
74 | [XNAS](https://github.com/MAC-AutoML/XNAS)
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/download_weight.sh:
--------------------------------------------------------------------------------
1 | # the echo-and-backticks wrapping added nothing; `wget -P` saves directly into the target directory
2 | wget -P xnas/algorithms/RMINAS/teacher_model/resnet20_cifar10 http://cdn.thrase.cn/rmi/resnet20.th
3 | wget -P xnas/algorithms/RMINAS/teacher_model/nb201model_imagenet16120 http://cdn.thrase.cn/rmi/009930-FULL.pth
4 | wget -P xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet http://cdn.thrase.cn/rmi/fbresnet152.pth
5 | wget -P xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100 http://cdn.thrase.cn/rmi/resnet101.pth
6 | echo "Finished downloading weight files."
6 |
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/sampler/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | # true_list = []
5 | # with open('xnas/search_algorithm/RMINAS/sampler/available_archs.txt', 'r') as f:
6 | # true_list = eval(f.readline())
7 |
8 | # def random_sampling(times):
9 | # sample_list = []
10 | # if times > sum(true_list):
11 | # print('can only sample {} times.'.format(sum(true_list)))
12 | # times = sum(true_list)
13 | # for _ in range(times):
14 | # i = random.randint(0, 15624)
15 | # while (not true_list[i]) or (i in sample_list):
16 | # i = random.randint(0, 15624)
17 | # sample_list.append(i)
18 | # return sample_list
19 |
20 | def darts_sug2alpha(suggest_sample):
21 | b = np.c_[suggest_sample, np.zeros(14)]
22 | return torch.from_numpy(np.r_[b,b])
23 |
24 | def nb201genostr2array(geno_str):
25 | # |none~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|skip_connect~1|nor_conv_3x3~2|
26 | OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
27 | _tmp = geno_str.split('|')
28 | _tmp2 = []
29 | for i in range(len(_tmp)):
30 | if i in [1,3,4,6,7,8]:
31 | _tmp2.append(_tmp[i][:-2])
32 | _tmp_np = np.array([0]*6)
33 | for i in range(6):
34 | _tmp_np[i] = OPS.index(_tmp2[i])
35 | _tmp_oh = np.zeros((_tmp_np.size, 5))
36 | _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1
37 | return _tmp_oh
38 |
39 | # def array2genostr(arr):
40 | # OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
41 | # """[[1. 0. 0. 0. 0.]
42 | # [0. 0. 1. 0. 0.]
43 | # [1. 0. 0. 0. 0.]
44 | # [0. 0. 0. 0. 1.]
45 | # [0. 1. 0. 0. 0.]
46 | # [0. 0. 0. 1. 0.]]"""
47 | # idx = [list(i).index(1.) for i in arr]
48 | # op = [OPS[x] for x in idx]
49 | # mixed = '|' + op[0] + '~0|+|' + op[1] + '~0|' + op[2] + '~1|+|' + op[3] + '~0|' + op[4] + '~1|' + op[5] + '~2|'
50 | # return mixed
51 |
52 | # def base_transform(n, x):
53 | # a=[0,1,2,3,4,5,6,7,8,9,'A','b','C','D','E','F']
54 | # b=[]
55 | # while True:
56 | # s=n//x
57 | # y=n%x
58 | # b=b+[y]
59 | # if s==0:
60 | # break
61 | # n=s
62 | # b.reverse()
63 | # zero_arr = [0]*(6-len(b))
64 | # return zero_arr+b
65 |
66 | # def array_morearch(arr, distance):
67 | # """[[1. 0. 0. 0. 0.]
68 | # [0. 0. 1. 0. 0.]
69 | # [1. 0. 0. 0. 0.]
70 | # [0. 0. 0. 0. 1.]
71 | # [0. 1. 0. 0. 0.]
72 | # [0. 0. 0. 1. 0.]]"""
73 | # am = list(arr.argmax(axis=1)) # [0,2,0,4,1,3]
74 | # morearch = []
75 | # if distance == 1:
76 | # for i in range(len(am)):
77 | # for j in range(5):
78 | # if am[i]!=j:
79 | # _tmp = am[:]
80 | # _tmp[i] = j
81 | # _tmp_np = np.array(_tmp)
82 | # _tmp_oh = np.zeros((_tmp_np.size, 5))
83 | # _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1
84 | # morearch.append(_tmp_oh)
85 | # else:
86 | # for i in range(15625):
87 | # arr = base_transform(i, 5)
88 | # if distance == 6-sum([arr[i]==am[i] for i in range(6)]):
89 | # _tmp_np = np.array(arr)
90 | # _tmp_oh = np.zeros((_tmp_np.size, 5))
91 | # _tmp_oh[np.arange(_tmp_np.size),_tmp_np] = 1
92 | # morearch.append(_tmp_oh)
93 | # # morearch.append(arr)
94 | # return morearch
95 |
96 |
97 |
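A quick worked example of `nb201genostr2array`, traced by hand from the genotype string quoted in the comment above:

```python
geno = "|none~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|skip_connect~1|nor_conv_3x3~2|"
onehot = nb201genostr2array(geno)
# edge ops in order: none, nor_conv_1x1, none, avg_pool_3x3, skip_connect, nor_conv_3x3
# -> column indices [0, 2, 0, 4, 1, 3], i.e. the (6, 5) one-hot matrix
# [[1. 0. 0. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [1. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 1.]
#  [0. 1. 0. 0. 0.]
#  [0. 0. 0. 1. 0.]]
```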
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | """
6 |
7 | import torch.nn as nn
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, in_channels, out_channels, stride=1):
13 | super().__init__()
14 |
15 | self.residual_function = nn.Sequential(
16 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
17 | nn.BatchNorm2d(out_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
20 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
21 | )
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
27 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
28 | )
29 |
30 | def forward(self, x):
31 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
32 |
33 | class BottleNeck(nn.Module):
34 | expansion = 4
35 | def __init__(self, in_channels, out_channels, stride=1):
36 | super().__init__()
37 | self.residual_function = nn.Sequential(
38 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
39 | nn.BatchNorm2d(out_channels),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
42 | nn.BatchNorm2d(out_channels),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
45 | nn.BatchNorm2d(out_channels * BottleNeck.expansion),
46 | )
47 |
48 | self.shortcut = nn.Sequential()
49 |
50 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
53 | nn.BatchNorm2d(out_channels * BottleNeck.expansion)
54 | )
55 |
56 | def forward(self, x):
57 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
58 |
59 | class ResNet(nn.Module):
60 |
61 | def __init__(self, block, num_block, num_classes=100):
62 | super().__init__()
63 |
64 | self.in_channels = 64
65 |
66 | self.conv1 = nn.Sequential(
67 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
68 | nn.BatchNorm2d(64),
69 | nn.ReLU(inplace=True))
70 | # we use a different input size than the original paper,
71 | # so conv2_x's stride is 1
72 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
73 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
74 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
75 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
76 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
77 | self.fc = nn.Linear(512 * block.expansion, num_classes)
78 |
79 | def _make_layer(self, block, out_channels, num_blocks, stride):
80 | strides = [stride] + [1] * (num_blocks - 1)
81 | layers = []
82 | for stride in strides:
83 | layers.append(block(self.in_channels, out_channels, stride))
84 | self.in_channels = out_channels * block.expansion
85 |
86 | return nn.Sequential(*layers)
87 |
88 | def forward(self, x):
89 | output = self.conv1(x)
90 | output = self.conv2_x(output)
91 | output = self.conv3_x(output)
92 | output = self.conv4_x(output)
93 | output = self.conv5_x(output)
94 | output = self.avg_pool(output)
95 | output = output.view(output.size(0), -1)
96 | output = self.fc(output)
97 |
98 | return output
99 |
100 | def feature_extractor(self, x):
101 | features = []
102 | output = self.conv1(x)
103 | output = self.conv2_x(output)
104 | features.append(output)
105 | output = self.conv3_x(output)
106 | output = self.conv4_x(output)
107 | features.append(output)
108 | output = self.conv5_x(output)
109 | features.append(output)
110 | # output = self.avg_pool(output)
111 | # output = output.view(output.size(0), -1)
112 | # output = self.fc(output)
113 |
114 | return features
115 |
116 | def resnet18():
117 | return ResNet(BasicBlock, [2, 2, 2, 2])
118 |
119 | def resnet34():
120 | return ResNet(BasicBlock, [3, 4, 6, 3])
121 |
122 | def resnet50():
123 | return ResNet(BottleNeck, [3, 4, 6, 3])
124 |
125 | def resnet101():
126 | return ResNet(BottleNeck, [3, 4, 23, 3])
127 |
128 | def resnet152():
129 | return ResNet(BottleNeck, [3, 8, 36, 3])
130 |
131 |
132 |
133 |
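A minimal sketch of how RMI-NAS consumes this teacher network. The checkpoint path matches `download_weight.sh`, and we assume the file stores a plain `state_dict` (the actual checkpoint layout may differ):

```python
import torch

model = resnet101()
state = torch.load(
    "xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet101.pth",
    map_location="cpu",
)
model.load_state_dict(state)
model.eval()

x = torch.randn(1, 3, 32, 32)            # one CIFAR-100-sized image
f1, f2, f3 = model.feature_extractor(x)  # stage outputs after conv2_x, conv4_x, conv5_x
print(f1.shape, f2.shape, f3.shape)      # (1,256,32,32) (1,1024,8,8) (1,2048,4,4)
```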
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/teacher_model/resnet20_cifar10/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
3 |
4 | The implementation and structure of this file is hugely influenced by [2]
5 | which is implemented for ImageNet and doesn't have option A for identity.
6 | Moreover, most of the implementations on the web are copy-pasted from
7 | torchvision's resnet and have the wrong number of params.
8 |
9 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
10 | number of layers and parameters:
11 |
12 | name | layers | params
13 | ResNet20 | 20 | 0.27M
14 | ResNet32 | 32 | 0.46M
15 | ResNet44 | 44 | 0.66M
16 | ResNet56 | 56 | 0.85M
17 | ResNet110 | 110 | 1.7M
18 | ResNet1202| 1202 | 19.4M
19 |
20 | which this implementation indeed has.
21 |
22 | Reference:
23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
26 |
27 | If you use this implementation in your work, please don't forget to mention the
28 | author, Yerlan Idelbayev.
29 | '''
30 |
31 | import torch.nn as nn
32 | import torch.nn.functional as F
33 | import torch.nn.init as init
34 |
35 |
36 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
37 |
38 | def _weights_init(m):
39 | classname = m.__class__.__name__
40 | #print(classname)
41 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
42 | init.kaiming_normal_(m.weight)
43 |
44 | class LambdaLayer(nn.Module):
45 | def __init__(self, lambd):
46 | super(LambdaLayer, self).__init__()
47 | self.lambd = lambd
48 |
49 | def forward(self, x):
50 | return self.lambd(x)
51 |
52 |
53 | class BasicBlock(nn.Module):
54 | expansion = 1
55 |
56 | def __init__(self, in_planes, planes, stride=1, option='A'):
57 | super(BasicBlock, self).__init__()
58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(planes)
60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
61 | self.bn2 = nn.BatchNorm2d(planes)
62 |
63 | self.shortcut = nn.Sequential()
64 | if stride != 1 or in_planes != planes:
65 | if option == 'A':
66 | """
67 | For CIFAR10 ResNet paper uses option A.
68 | """
69 | self.shortcut = LambdaLayer(lambda x:
70 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
71 | elif option == 'B':
72 | self.shortcut = nn.Sequential(
73 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
74 | nn.BatchNorm2d(self.expansion * planes)
75 | )
76 |
77 | def forward(self, x):
78 | out = F.relu(self.bn1(self.conv1(x)))
79 | out = self.bn2(self.conv2(out))
80 | out += self.shortcut(x)
81 | out = F.relu(out)
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 | def __init__(self, block, num_blocks, num_classes=10):
87 | super(ResNet, self).__init__()
88 | self.in_planes = 16
89 |
90 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
91 | self.bn1 = nn.BatchNorm2d(16)
92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
95 | self.linear = nn.Linear(64, num_classes)
96 |
97 | self.apply(_weights_init)
98 |
99 | def _make_layer(self, block, planes, num_blocks, stride):
100 | strides = [stride] + [1]*(num_blocks-1)
101 | layers = []
102 | for stride in strides:
103 | layers.append(block(self.in_planes, planes, stride))
104 | self.in_planes = planes * block.expansion
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, x):
109 | out = F.relu(self.bn1(self.conv1(x)))
110 | out = self.layer1(out)
111 | out = self.layer2(out)
112 | out = self.layer3(out)
113 | out = F.avg_pool2d(out, out.size()[3])
114 | out = out.view(out.size(0), -1)
115 | out = self.linear(out)
116 | return out
117 |
118 | def feature_extractor(self, x):
119 | features = []
120 | out = F.relu(self.bn1(self.conv1(x)))
121 | out = self.layer1(out)
122 | features.append(out)
123 | out = self.layer2(out)
124 | features.append(out)
125 | out = self.layer3(out)
126 | features.append(out)
127 | # out = F.avg_pool2d(out, out.size()[3])
128 | # out = out.view(out.size(0), -1)
129 | # out = self.linear(out)
130 | return features
131 |
132 |
133 | def resnet20():
134 | return ResNet(BasicBlock, [3, 3, 3])
135 |
136 |
137 | def resnet32():
138 | return ResNet(BasicBlock, [5, 5, 5])
139 |
140 |
141 | def resnet44():
142 | return ResNet(BasicBlock, [7, 7, 7])
143 |
144 |
145 | def resnet56():
146 | return ResNet(BasicBlock, [9, 9, 9])
147 |
148 |
149 | def resnet110():
150 | return ResNet(BasicBlock, [18, 18, 18])
151 |
152 |
153 | def resnet1202():
154 | return ResNet(BasicBlock, [200, 200, 200])
155 |
156 |
157 | def test(net):
158 | import numpy as np
159 | total_params = 0
160 |
161 | for x in filter(lambda p: p.requires_grad, net.parameters()):
162 | total_params += np.prod(x.data.numpy().shape)
163 | print("Total number of params", total_params)
164 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
165 |
166 |
167 | if __name__ == "__main__":
168 | for net_name in __all__:
169 | if net_name.startswith('resnet'):
170 | print(net_name)
171 | test(globals()[net_name]())
172 | print()
--------------------------------------------------------------------------------
/xnas/algorithms/RMINAS/utils/random_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from xnas.datasets.loader import get_normal_dataloader
5 | from xnas.datasets.imagenet import ImageFolder
6 |
7 |
8 | def get_random_data(batchsize, name):
9 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10 | if name == 'imagenet':
11 | train_loader, _ = ImageFolder(
12 | datapath="./data/imagenet/ILSVRC2012_img_train/",
13 | batch_size=batchsize*16,
14 | split=[0.5, 0.5],
15 | ).generate_data_loader()
16 | else:
17 | train_loader, _ = get_normal_dataloader(name, batchsize*16)
18 |
19 | random_idxs = np.random.randint(0, len(train_loader.dataset), size=train_loader.batch_size)
20 | (more_data_X, more_data_y) = zip(*[train_loader.dataset[idx] for idx in random_idxs])
21 | more_data_X = torch.stack(more_data_X, dim=0).to(device)
22 | more_data_y = torch.Tensor(more_data_y).long().to(device)
23 | return more_data_X, more_data_y
24 |
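Example usage (batch size and dataset name are illustrative; the dataset must be available under `./data` as elsewhere in XNAS):

```python
X, y = get_random_data(64, 'cifar10')
# the loader is built with batch_size = 64 * 16, and one loader-batch worth of
# randomly indexed samples is returned, already moved to the device
print(X.shape, y.shape)  # torch.Size([1024, 3, 32, 32]) torch.Size([1024])
```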
--------------------------------------------------------------------------------
/xnas/algorithms/SNG/DDPNAS.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from copy import deepcopy
4 |
5 | from xnas.algorithms.SNG.categorical import Categorical
6 | from xnas.core.utils import index_to_one_hot, softmax
7 |
8 |
9 | class CategoricalDDPNAS:
10 | def __init__(self, category, steps, theta_lr, gamma):
11 | self.p_model = Categorical(categories=category)
12 | # number of sweeps between prunings of the distribution
13 | self.steps = steps
14 | self.current_step = 1
15 | self.ignore_index = []
16 | self.sample_index = []
17 | self.pruned_index = []
18 | self.val_performance = []
19 | self.sample = []
20 | self.init_record()
21 | self.score_decay = 0.5
22 | self.learning_rate = 0.2
23 | self.training_finish = False
24 | self.training_epoch = self.get_training_epoch()
25 | self.non_param_index = [0, 1, 2, 7]
26 | self.param_index = [3, 4, 5, 6]
27 | self.non_param_index_num = len(self.non_param_index)
28 | self.pruned_index_num = len(self.param_index)
29 | self.non_param_index_count = [0] * self.p_model.d
30 | self.param_index_count = [0] * self.p_model.d
31 | self.gamma = gamma
32 | self.theta_lr = theta_lr
33 | self.velocity = np.zeros(self.p_model.theta.shape)
34 |
35 | def init_record(self):
36 | for i in range(self.p_model.d):
37 | self.ignore_index.append([])
38 | self.sample_index.append(list(range(self.p_model.Cmax)))
39 | self.pruned_index.append([])
40 | self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax]))
41 |
42 | def get_training_epoch(self):
43 | return self.steps * sum(list(range(self.p_model.Cmax)))
44 |
45 | def sampling(self):
46 | # return self.sampling_index()
47 | self.sample = self.sampling_index()
48 | return index_to_one_hot(self.sample, self.p_model.Cmax)
49 |
50 | def sampling_index(self):
51 | sample = []
52 | for i in range(self.p_model.d):
53 | sample.append(random.choice(self.sample_index[i]))
54 | if len(self.sample_index[i]) > 0:
55 | self.sample_index[i].remove(sample[i])
56 | return np.array(sample)
57 |
58 | def sample_with_constrains(self):
59 | # pass
60 | raise NotImplementedError
61 |
62 | def record_information(self, sample, performance):
63 | # self.sample = sample
64 | for i in range(self.p_model.d):
65 | self.val_performance[-1][i, self.sample[i]] = performance
66 |
67 | def update_sample_index(self):
68 | for i in range(self.p_model.d):
69 | self.sample_index[i] = list(set(range(self.p_model.Cmax)) - set(self.pruned_index[i]))
70 |
71 | def update(self):
72 | # a full sweep over the remaining operations has just finished
73 | if len(self.sample_index[0]) == 0:
74 | if self.current_step < self.steps:
75 | self.current_step += 1
76 | # append new val performance
77 | self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax]))
78 | # when the scheduled number of sweeps is done, prune
79 | else:
80 | self.current_step += 1
81 | expectation = np.zeros([self.p_model.d, self.p_model.Cmax])
82 | for i in range(self.steps):
83 | # multiply by 100 so softmax strongly suppresses unvisited (zero) entries
84 | expectation += softmax(self.val_performance[i] * 100, axis=1)
85 | expectation = expectation / float(self.steps)
86 | # self.p_model.theta = expectation + self.score_decay * self.p_model.theta
87 | self.velocity = self.gamma * self.velocity + (1 - self.gamma) * expectation
88 | # NOTE: THETA_LR not applied.
89 | self.p_model.theta += self.velocity
90 | # self.p_model.theta = self.p_model.theta + self.theta_lr * expectation
91 | # prune the index
92 | pruned_weight = deepcopy(self.p_model.theta)
93 | for index in range(self.p_model.d):
94 | if not len(self.pruned_index[index]) == 0:
95 | pruned_weight[index, self.pruned_index[index]] = np.nan
96 | pruned_index = np.nanargmin(pruned_weight[index, :])
97 | # if self.non_param_index_count[index] == 3 and pruned_index in self.non_param_index:
98 | # pruned_weight[index, pruned_index] = np.nan
99 | # pruned_index = np.nanargmin(pruned_weight[index, :])
100 | # if self.param_index_count[index] == 3 and pruned_index in self.param_index:
101 | # pruned_weight[index, pruned_index] = np.nan
102 | # pruned_index = np.nanargmin(pruned_weight[index, :])
103 | # if pruned_index in self.param_index:
104 | # self.param_index_count[index] += 1
105 | # if pruned_index in self.non_param_index:
106 | # self.non_param_index_count[index] += 1
107 | self.pruned_index[index].append(pruned_index)
108 | # self.p_model.theta[index, pruned_index] = 0
109 | self.p_model.theta /= np.sum(self.p_model.theta, axis=1)[:, np.newaxis]
110 | # if self.param_index_count[0] == 3 and self.non_param_index_count[0] == 3:
111 | # self.training_finish = True
112 | self.current_step = 1
113 | # init val_performance
114 | self.val_performance = []
115 | self.val_performance.append(np.zeros([self.p_model.d, self.p_model.Cmax]))
116 | self.update_sample_index()
117 |
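A minimal driving loop, under stated assumptions: 8 candidate operations on each of 14 edges, and a hypothetical `evaluate(one_hot)` standing in for one step of supernet training plus validation:

```python
sampler = CategoricalDDPNAS(category=[8] * 14, steps=4, theta_lr=0.1, gamma=0.9)
for _ in range(sampler.training_epoch):
    one_hot = sampler.sampling()             # visits each remaining op once per sweep
    acc = evaluate(one_hot)                  # hypothetical objective in [0, 1]
    sampler.record_information(one_hot, acc)
    sampler.update()                         # prunes one op per edge every `steps` sweeps
print(sampler.p_model.theta.argmax(axis=1))  # most likely operation per edge
```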
--------------------------------------------------------------------------------
/xnas/algorithms/SNG/GridSearch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from xnas.algorithms.SNG.categorical import Categorical
4 |
5 |
6 | class GridSearch:
7 | def __init__(self, categories, fresh_size=4, init_theta=None, max_mize=True):
8 |
9 | self.p_model = Categorical(categories)
10 | self.p_model.C = np.array(self.p_model.C)
11 | self.valid_d = len(self.p_model.C[self.p_model.C > 1])
12 |
13 | # Refresh theta
14 | for k in range(self.p_model.d):
15 | self.p_model.theta[k, 0] = 1
16 | self.p_model.theta[k, 1:self.p_model.C[k]] = 0
17 |
18 | if init_theta is not None:
19 | self.p_model.theta = init_theta
20 |
21 | self.fresh_size = fresh_size
22 | self.sample = []
23 | self.objective = []
24 | self.maxmize = -1 if max_mize else 1
25 | self.obj_optim = float('inf')
26 | self.training_finish = False
27 |
28 | # Record point to move
29 | self.sample_point = self.p_model.theta
30 | self.point = [self.p_model.d-1, 0]
31 |
32 | def sampling(self):
33 | return self.sample_point
34 |
35 | def record_information(self, sample, objective):
36 | self.sample.append(sample)
37 | self.objective.append(objective * self.maxmize)
38 |
39 | def update(self):
40 | """
41 | Update sampling by grid search
42 | e.g.
43 | categories = [3, 2, 4]
44 | sample point as = [1, 1, 1]
45 | [0, 0, 0]
46 | [0, , 0]
47 | [ , , 0]
48 | point as now searching = [2, 0]
49 | """
50 | if len(self.sample) == self.fresh_size:
51 | # update sample
52 | if self.point[1] == self.p_model.C[self.point[0]] - 1:
53 | for i in range(self.point[0] + 1, self.p_model.d):
54 | self.sample_point[i] = np.zeros(self.p_model.Cmax)
55 | self.sample_point[i][0] = 1
56 | for j in range(self.point[0], -1, -1):
57 | k = np.argmax(self.sample_point[j])
58 | if k < self.p_model.C[j] - 1:
59 | self.sample_point[j][k] = 0
60 | self.sample_point[j][k+1] = 1
61 | break
62 | else:
63 | self.sample_point[j][k] = 0
64 | self.sample_point[j][0] = 1
65 | if j == 0:
66 | self.training_finish = True
67 | break
68 | self.point = [self.p_model.d-1, 0]
69 | else:
70 | self.sample_point[self.point[0], self.point[1]] = 0
71 | self.sample_point[self.point[0], self.point[1]+1] = 1
72 | self.point[1] += 1
73 |
74 | # update optim and theta
75 | if min(self.objective) < self.obj_optim:
76 | self.obj_optim = min(self.objective)
77 | seq = np.argmin(np.array(self.objective))
78 | self.p_model.theta = self.sample[seq]
79 |
80 | # update record
81 | self.sample = []
82 | self.objective = []
--------------------------------------------------------------------------------
/xnas/algorithms/SNG/MDENAS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from xnas.algorithms.SNG.categorical import Categorical
4 |
5 |
6 | class CategoricalMDENAS:
7 | def __init__(self, category, learning_rate):
8 | self.p_model = Categorical(categories=category)
9 | self.learning_rate = learning_rate
10 | self.information_recorder = {'epoch': np.zeros((self.p_model.d, self.p_model.Cmax)),
11 | 'performance': np.zeros((self.p_model.d, self.p_model.Cmax))}
12 |
13 | def sampling(self):
14 | return self.p_model.sampling()
15 |
16 | def sampling_index(self):
17 | return self.p_model.sampling_index()
18 |
19 | def record_information(self, sample, performance):
20 | for i in range(len(sample)):
21 | self.information_recorder['epoch'][i, sample[i]] += 1
22 | self.information_recorder['performance'][i, sample[i]] = performance
23 |
24 | def update(self):
25 | # update the probability
26 | for edges_index in range(self.p_model.d):
27 | for i in range(self.p_model.Cmax):
28 | for j in range(i+1, self.p_model.Cmax):
29 | if (self.information_recorder['epoch'][edges_index, i] >= self.information_recorder['epoch'][edges_index, j])\
30 | and (self.information_recorder['performance'][edges_index, i] < self.information_recorder['performance'][edges_index, j]):
31 | if self.p_model.theta[edges_index, i] > self.learning_rate:
32 | self.p_model.theta[edges_index, i] -= self.learning_rate
33 | self.p_model.theta[edges_index, j] += self.learning_rate
34 | else:
35 | self.p_model.theta[edges_index, j] += self.p_model.theta[edges_index, i]
36 | self.p_model.theta[edges_index, i] = 0
37 |
38 | if (self.information_recorder['epoch'][edges_index, i] <= self.information_recorder['epoch'][edges_index, j]) \
39 | and (self.information_recorder['performance'][edges_index, i] > self.information_recorder['performance'][edges_index, j]):
40 | if self.p_model.theta[edges_index, j] > self.learning_rate:
41 | self.p_model.theta[edges_index, j] -= self.learning_rate
42 | self.p_model.theta[edges_index, i] += self.learning_rate
43 | else:
44 | self.p_model.theta[edges_index, i] += self.p_model.theta[edges_index, j]
45 | self.p_model.theta[edges_index, j] = 0
46 |
--------------------------------------------------------------------------------
/xnas/algorithms/SNG/RAND.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from xnas.algorithms.SNG.categorical import Categorical
4 | from xnas.core.utils import one_hot_to_index
5 |
6 |
7 | class RandomSample:
8 | """
9 | Random Sample for Categorical Distribution
10 | """
11 |
12 | def __init__(self, categories, delta_init=1., opt_type="best", init_theta=None, max_mize=True):
13 | # Categorical distribution
14 | self.p_model = Categorical(categories)
15 | # valid dimension size
16 | self.p_model.C = np.array(self.p_model.C)
17 |
18 | if init_theta is not None:
19 | self.p_model.theta = init_theta
20 |
21 | self.sample_list = []
22 | self.obj_list = []
23 | self.max_mize = -1 if max_mize else 1
24 |
25 | self.select = opt_type
26 | self.best_object = float('inf')  # objectives are recorded sign-adjusted, so smaller is always better
27 |
28 | def record_information(self, sample, objective):
29 | self.sample_list.append(sample)
30 | self.obj_list.append(objective*self.max_mize)
31 |
32 | def sampling(self):
33 | """
34 | Draw a sample from the categorical distribution (one-hot)
35 | Sample one archi at once
36 | """
37 | c = np.zeros(self.p_model.theta.shape, dtype=bool)  # np.bool was removed from NumPy
38 | for i, upper in enumerate(self.p_model.C):
39 | j = np.random.randint(upper)
40 | c[i, j] = True
41 | return c
42 |
43 | def sampling_index(self):
44 | return one_hot_to_index(np.array(self.sampling()))
45 |
46 | def mle(self):
47 | """
48 | Get most likely categorical variables (one-hot)
49 | """
50 | m = self.p_model.theta.argmax(axis=1)
51 | x = np.zeros((self.p_model.d, self.p_model.Cmax))
52 | for i, c in enumerate(m):
53 | x[i, c] = 1
54 | return x
55 |
56 | def update(self):
57 | objective = np.array(self.obj_list[-1])
58 | sample = np.array(self.sample_list[-1])
59 | if (self.select == 'best'):
60 | if (objective < self.best_object):  # keep the best (smallest) sign-adjusted objective
61 | self.best_object = objective
62 | # refresh theta to best one
63 | self.p_model.theta = np.array(sample)
64 | else:
65 | raise NotImplementedError
66 | self.sample_list = []
67 | self.obj_list = []
68 |
--------------------------------------------------------------------------------
/xnas/algorithms/SNG/categorical.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 |
4 |
5 | class Categorical(object):
6 | """
7 | Categorical distribution for categorical variables parametrized by :math:`\\{ \\theta \\}_{i=1}^{(d \\times K)}`.
8 |
9 | :param categories: the numbers of categories
10 | :type categories: array_like, shape(d), dtype=int
11 | """
12 |
13 | def __init__(self, categories):
14 | self.d = len(categories)
15 | self.C = categories
16 | self.Cmax = np.max(categories)
17 | self.theta = np.zeros((self.d, self.Cmax))
18 | # initialize theta by 1/C for each dimensions
19 | for i in range(self.d):
20 | self.theta[i, :self.C[i]] = 1./self.C[i]
21 | # pad zeros to unused elements
22 | for i in range(self.d):
23 | self.theta[i, self.C[i]:] = 0.
24 | # number of valid parameters
25 | # self.valid_param_num = int(np.sum(self.C - 1))
26 | # valid dimension size
27 | # self.valid_d = len(self.C[self.C > 1])
28 |
29 | def sampling_lam(self, lam):
30 | """
31 | Draw :math:`\\lambda` samples from the categorical distribution.
32 | :param int lam: sample size :math:`\\lambda`
33 | :return: sampled variables from the categorical distribution (one-hot representation)
34 | :rtype: array_like, shape=(lam, d, Cmax), dtype=bool
35 | """
36 | rand = np.random.rand(lam, self.d, 1) # range of random number is [0, 1)
37 | cum_theta = self.theta.cumsum(axis=1) # (d, Cmax)
38 | X = (cum_theta - self.theta <= rand) & (rand < cum_theta)
39 | return X
40 |
41 | def sampling(self):
42 | """
43 | Draw a sample from the categorical distribution.
44 | :return: sampled variables from the categorical distribution (one-hot representation)
45 | :rtype: array_like, shape=(d, Cmax), dtype=bool
46 | """
47 | rand = np.random.rand(self.d, 1) # range of random number is [0, 1)
48 | cum_theta = self.theta.cumsum(axis=1) # (d, Cmax)
49 |
50 | # x[i, j] becomes 1 iff cum_theta[i, j] - theta[i, j] <= rand[i] < cum_theta[i, j]
51 | x = (cum_theta - self.theta <= rand) & (rand < cum_theta)
52 | return x
53 |
54 | def sampling_index(self):
55 | """
56 |
57 | :return: a numpy index list
58 | """
59 | index_list = []
60 | for prob in self.theta:
61 | index_list.append(np.random.choice(a=list(range(prob.shape[0])), p=prob))
62 | return np.array(index_list)
63 |
64 | def mle(self):
65 | """
66 | Return the most likely categories.
67 | :return: categorical variables (one-hot representation)
68 | :rtype: array_like, shape=(d, Cmax), dtype=bool
69 | """
70 | m = self.theta.argmax(axis=1)
71 | x = np.zeros((self.d, self.Cmax))
72 | for i, c in enumerate(m):
73 | x[i, c] = 1
74 | return x
75 |
76 | def loglikelihood(self, X):
77 | """
78 | Calculate log likelihood.
79 |
80 | :param X: samples (one-hot representation)
81 | :type X: array_like, shape=(lam, d, maxK), dtype=bool
82 | :return: log likelihoods
83 | :rtype: array_like, shape=(lam), dtype=float
84 | """
85 | return (X * np.log(self.theta)).sum(axis=2).sum(axis=1)
86 |
87 | def log_header(self):
88 | header_list = []
89 | for i in range(self.d):
90 | header_list += ['theta%d_%d' % (i, j) for j in range(self.C[i])]
91 | return header_list
92 |
93 | def log(self):
94 | theta_list = []
95 | for i in range(self.d):
96 | theta_list += ['%f' % self.theta[i, j] for j in range(self.C[i])]
97 | return theta_list
98 |
99 | def load_theta_from_log(self, theta):
100 | self.theta = np.zeros((self.d, self.Cmax))
101 | k = 0
102 | for i in range(self.d):
103 | for j in range(self.C[i]):
104 | self.theta[i, j] = theta[k]
105 | k += 1
106 |
107 | def print_theta(self, logger):
108 | # remove formats
109 | org_formatters = []
110 | for handler in logger.handlers:
111 | org_formatters.append(handler.formatter)
112 | handler.setFormatter(logging.Formatter("%(message)s"))
113 |
114 | logger.info("####### theta #######")
115 | logger.info("# Theta value")
116 | for alpha in self.theta:
117 | logger.info(alpha)
118 | logger.info("#####################")
119 |
120 | # restore formats
121 | for handler, formatter in zip(logger.handlers, org_formatters):
122 | handler.setFormatter(formatter)
123 |
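A quick demonstration, assuming six 5-way decisions (a NAS-Bench-201-like cell):

```python
import numpy as np

dist = Categorical([5] * 6)
x = dist.sampling()                      # (6, 5) boolean one-hot sample
idx = dist.sampling_index()              # or one index per dimension
print(dist.mle())                        # uniform theta -> argmax picks column 0
print(dist.loglikelihood(x[None, ...]))  # ~ 6 * log(1/5) = -9.66
```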
--------------------------------------------------------------------------------
/xnas/algorithms/SPOS.py:
--------------------------------------------------------------------------------
1 | """Samplers for Single Path One-Shot Search Space."""
2 |
3 | import numpy as np
4 | from copy import deepcopy
5 | from collections import deque
6 |
7 |
8 | class RAND():
9 | """Random choice"""
10 | def __init__(self, num_choice, layers):
11 | self.num_choice = num_choice
12 | self.child_len = layers
13 | self.history = []
14 |
15 | def record(self, child, value):
16 | self.history.append({"child":child, "value":value})
17 |
18 | def suggest(self):
19 | return list(np.random.randint(self.num_choice, size=self.child_len))
20 |
21 | def final_best(self):
22 | best_child = min(self.history, key=lambda i:i["value"])
23 | return best_child['child'], best_child['value']
24 |
25 |
26 | class REA():
27 | """Regularized Evolution Algorithm"""
28 | def __init__(self, num_choice, layers, population_size=20, better=min):
29 | self.num_choice = num_choice
30 | self.population_size = population_size
31 | self.child_len = layers
32 | self.better = better
33 | self.population = deque()
34 | self.history = []
35 | # init population
36 | self.init_pop = np.random.randint(
37 | self.num_choice, size=(self.population_size, self.child_len)
38 | )
39 |
40 | def _get_mutated_parent(self):
41 | parent = self.better(self.population, key=lambda i:i["value"]) # default: min(error)
42 | return self._mutate(parent['child'])
43 |
44 | def _mutate(self, parent):
45 | parent = deepcopy(parent)
46 | idx = np.random.randint(0, len(parent))
47 | prev_value, new_value = parent[idx], parent[idx]
48 | while new_value == prev_value:
49 | new_value = np.random.randint(self.num_choice)
50 | parent[idx] = new_value
51 | return parent
52 |
53 | def record(self, child, value):
54 | self.history.append({"child":child, "value":value})
55 | self.population.append({"child":child, "value":value})
56 | if len(self.population) > self.population_size:
57 | self.population.popleft()
58 |
59 | def suggest(self):
60 | if len(self.history) < self.population_size:
61 | return list(self.init_pop[len(self.history)])
62 | else:
63 | return self._get_mutated_parent()
64 |
65 | def final_best(self):
66 | best_child = self.better(self.history, key=lambda i:i["value"])
67 | return best_child['child'], best_child['value']
68 |
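A toy run of `REA` with a hypothetical objective (`fake_error`; lower is better, matching the default `better=min`):

```python
import numpy as np

def fake_error(child):           # stand-in for supernet validation error
    return float(np.sum(child))  # the optimum is the all-zeros child

rea = REA(num_choice=4, layers=6, population_size=10)
for _ in range(100):
    child = rea.suggest()        # random during the init phase, then mutate the best
    rea.record(child, fake_error(child))
print(rea.final_best())          # e.g. ([0, 0, 0, 0, 0, 0], 0.0)
```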
--------------------------------------------------------------------------------
/xnas/core/config.py:
--------------------------------------------------------------------------------
1 | """Configuration file (powered by YACS)."""
2 |
3 | import os
4 | import sys
5 | import argparse
6 | from yacs.config import CfgNode
7 |
8 |
9 | # Global config object
10 | _C = CfgNode(new_allowed=True)
11 | cfg = _C
12 |
13 |
14 | # -------------------------------------------------------- #
15 | # Data Loader options
16 | # -------------------------------------------------------- #
17 | _C.LOADER = CfgNode(new_allowed=True)
18 |
19 | _C.LOADER.DATASET = "cifar10"
20 |
21 | # leave empty to use "./data/$dataset" as default
22 | _C.LOADER.DATAPATH = ""
23 |
24 | _C.LOADER.SPLIT = [0.8, 0.2]
25 |
26 | # whether using val dataset (imagenet only)
27 | _C.LOADER.USE_VAL = False
28 |
29 | _C.LOADER.NUM_CLASSES = 10
30 |
31 | _C.LOADER.NUM_WORKERS = 8
32 |
33 | _C.LOADER.PIN_MEMORY = True
34 |
35 | # batch size for training and validation
36 | # type: int, or a list to use a different size during validation
37 | # _C.LOADER.BATCH_SIZE = [256, 128]
38 | _C.LOADER.BATCH_SIZE = 256
39 |
40 | # augment type using by ImageNet only
41 | # chosen from ['default', 'auto_augment_tf']
42 | _C.LOADER.TRANSFORM = "default"
43 |
44 |
45 | # ------------------------------------------------------------------------------------ #
46 | # Search Space options
47 | # ------------------------------------------------------------------------------------ #
48 | _C.SPACE = CfgNode(new_allowed=True)
49 |
50 | _C.SPACE.NAME = 'darts'
51 |
52 | # first layer's channels, not channels for input image.
53 | _C.SPACE.CHANNELS = 16
54 |
55 | _C.SPACE.LAYERS = 8
56 |
57 | _C.SPACE.NODES = 4
58 |
59 | _C.SPACE.BASIC_OP = []
60 |
61 |
62 |
63 | # ------------------------------------------------------------------------------------ #
64 | # Optimizer options in network
65 | # ------------------------------------------------------------------------------------ #
66 | _C.OPTIM = CfgNode(new_allowed=True)
67 |
68 | # Base learning rate, init_lr = OPTIM.BASE_LR * NUM_GPUS
69 | _C.OPTIM.BASE_LR = 0.1
70 |
71 | _C.OPTIM.MIN_LR = 1.e-3
72 |
73 |
74 | # Learning rate policy select from {'cos', 'exp', 'steps'}
75 | _C.OPTIM.LR_POLICY = "cos"
76 | # Steps for 'steps' policy (in epochs)
77 | _C.OPTIM.STEPS = [30, 60, 90]
78 | # Learning rate multiplier for 'steps' policy
79 | _C.OPTIM.LR_MULT = 0.1
80 |
81 |
82 | # Momentum
83 | _C.OPTIM.MOMENTUM = 0.9
84 | # Momentum dampening
85 | _C.OPTIM.DAMPENING = 0.0
86 | # Nesterov momentum
87 | _C.OPTIM.NESTEROV = False
88 |
89 |
90 | _C.OPTIM.WEIGHT_DECAY = 5e-4
91 |
92 | _C.OPTIM.GRAD_CLIP = 5.0
93 |
94 | _C.OPTIM.MAX_EPOCH = 200
95 | # Warm up epochs
96 | _C.OPTIM.WARMUP_EPOCH = 0
97 | # Start the warm up from init_lr * OPTIM.WARMUP_FACTOR
98 | _C.OPTIM.WARMUP_FACTOR = 0.1
99 | # Ending epochs
100 | _C.OPTIM.FINAL_EPOCH = 0
101 |
102 |
103 |
104 | # -------------------------------------------------------- #
105 | # Searching options
106 | # -------------------------------------------------------- #
107 | _C.SEARCH = CfgNode(new_allowed=True)
108 |
109 | _C.SEARCH.IM_SIZE = 32
110 |
111 | # Multi-sized Crop
112 | # NOTE: IM_SIZE for ImageNet will be overridden if this is set.
113 | _C.SEARCH.MULTI_SIZES = []
114 |
115 | # channels of input images, 3 for rgb
116 | _C.SEARCH.INPUT_CHANNELS = 3
117 |
118 | _C.SEARCH.LOSS_FUN = 'cross_entropy'
119 | # label smoothing for cross entropy loss
120 | _C.SEARCH.LABEL_SMOOTH = 0.
121 |
122 | # resume and path of checkpoints
123 | _C.SEARCH.AUTO_RESUME = True
124 |
125 | _C.SEARCH.WEIGHTS = ""
126 |
127 | _C.SEARCH.EVALUATION = ""
128 |
129 |
130 | # ------------------------------------------------------------------------------------ #
131 | # Options for model training
132 | # ------------------------------------------------------------------------------------ #
133 | _C.TRAIN = CfgNode(new_allowed=True)
134 |
135 | _C.TRAIN.IM_SIZE = 32
136 |
137 | # channels of input images, 3 for rgb
138 | _C.TRAIN.INPUT_CHANNELS = 3
139 |
140 | _C.TRAIN.DROP_PATH_PROB = 0.2
141 |
142 | _C.TRAIN.LAYERS = 20
143 |
144 | _C.TRAIN.CHANNELS = 36
145 |
146 | _C.TRAIN.GENOTYPE = ""
147 |
148 |
149 | # -------------------------------------------------------- #
150 | # Model testing options
151 | # -------------------------------------------------------- #
152 | _C.TEST = CfgNode(new_allowed=True)
153 |
154 | _C.TEST.IM_SIZE = 224
155 |
156 | # use a specific batch size for testing;
157 | # the SEARCH batch size is used while this value stays -1
158 | _C.TEST.BATCH_SIZE = -1
159 |
160 |
161 |
162 | # -------------------------------------------------------- #
163 | # Benchmarks options
164 | # -------------------------------------------------------- #
165 | _C.BENCHMARK = CfgNode(new_allowed=True)
166 |
167 | # Path to NAS-Bench-201 weights file
168 | _C.BENCHMARK.NB201PATH = "./data/NAS-Bench-201-v1_1-096897.pth"
169 |
170 | # path to NAS-Bench-301 folder
171 | _C.BENCHMARK.NB301PATH = "./data/nb301models/"
172 |
173 |
174 | # -------------------------------------------------------- #
175 | # Misc options
176 | # -------------------------------------------------------- #
177 |
178 | _C.CUDNN_BENCH = True
179 |
180 | _C.LOG_PERIOD = 10
181 |
182 | _C.EVAL_PERIOD = 1
183 |
184 | _C.SAVE_PERIOD = 1
185 |
186 | _C.NUM_GPUS = 1
187 |
188 | _C.OUT_DIR = "exp/"
189 |
190 | _C.DETERMINSTIC = True
191 |
192 | _C.RNG_SEED = 1
193 |
194 |
195 |
196 | # -------------------------------------------------------- #
197 |
198 | def dump_cfgfile(cfg_dest="config.yaml"):
199 | """Dumps the config to the output directory."""
200 | cfg_file = os.path.join(_C.OUT_DIR, cfg_dest)
201 | with open(cfg_file, "w") as f:
202 | _C.dump(stream=f)
203 |
204 |
205 | def load_cfgfile(out_dir, cfg_dest="config.yaml"):
206 | """Loads config from specified output directory."""
207 | cfg_file = os.path.join(out_dir, cfg_dest)
208 | _C.merge_from_file(cfg_file)
209 |
210 |
211 | def load_configs():
212 | """Load config from command line arguments and set any specified options.
213 | How to use: python xx.py --cfg path_to_your_config.yaml test1 0 test2 True
214 | opts returns ['test1', '0', 'test2', 'True']; yacs converts each string to the corresponding value type.
215 | """
216 | parser = argparse.ArgumentParser(description="Config file options.")
217 | parser.add_argument("--cfg", required=True, type=str)
218 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
219 | if len(sys.argv) == 1:
220 | parser.print_help()
221 | sys.exit(1)
222 | args = parser.parse_args()
223 | _C.merge_from_file(args.cfg)
224 | _C.merge_from_list(args.opts)
225 |
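A minimal sketch of the intended entry-point flow (the config path exists under `configs/`; `exp/demo/` is illustrative):

```python
# invoked as:
#   python your_script.py --cfg configs/search/darts_darts_cifar10.yaml OUT_DIR exp/demo/
from xnas.core.config import cfg, load_configs, dump_cfgfile

load_configs()                      # merge the YAML file, then the CLI overrides
print(cfg.SPACE.NAME, cfg.OUT_DIR)  # values are now visible through the global cfg
dump_cfgfile()                      # snapshot the merged config into OUT_DIR
```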
--------------------------------------------------------------------------------
/xnas/core/utils.py:
--------------------------------------------------------------------------------
1 | import decimal
2 | import random
3 | import string
4 | import time
5 | import numpy as np
6 |
7 |
8 | def random_time_string(stringLength=8):
9 | letters = string.ascii_lowercase
10 | return str(time.time()) + ''.join(random.choice(letters) for _ in range(stringLength))  # timestamp followed by random letters; the original str.join interleaved the timestamp between the letters
11 |
12 |
13 | def one_hot_to_index(one_hot_matrix):
14 | return np.array([np.where(r == 1)[0][0] for r in one_hot_matrix])
15 |
16 |
17 | def index_to_one_hot(index_vector, C):
18 | return np.eye(C)[index_vector.reshape(-1)]
19 |
20 |
21 | def float_to_decimal(data, prec=4):
22 | """Convert floats to decimals which allows for fixed width json."""
23 | if isinstance(data, dict):
24 | return {k: float_to_decimal(v, prec) for k, v in data.items()}
25 | if isinstance(data, float):
26 | return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
27 | else:
28 | return data
29 |
30 |
31 | def softmax(x, axis=None):
32 | x = x - x.max(axis=axis, keepdims=True)
33 | y = np.exp(x)
34 | return y / y.sum(axis=axis, keepdims=True)
35 |
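A round-trip example for the one-hot helpers, plus the row-wise softmax (values checked by hand):

```python
import numpy as np

idx = np.array([1, 0, 2])
oh = index_to_one_hot(idx, 3)               # rows of np.eye(3) selected by idx
assert (one_hot_to_index(oh) == idx).all()  # inverse mapping
print(softmax(np.array([[1., 2., 3.]]), axis=1))  # ~[[0.0900 0.2447 0.6652]]
```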
--------------------------------------------------------------------------------
/xnas/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/datasets/__init__.py
--------------------------------------------------------------------------------
/xnas/datasets/imagenet16.py:
--------------------------------------------------------------------------------
1 | import os, sys, hashlib
2 | import numpy as np
3 | from PIL import Image
4 | import torch.utils.data as data
5 | import pickle
6 |
7 |
8 | # def ImageNet16Loader():
9 | # mean = [x / 255 for x in [122.68, 116.66, 104.01]]
10 | # std = [x / 255 for x in [63.22, 61.26, 65.09]]
11 | # lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
12 | # train_transform = transforms.Compose(lists)
13 | # train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
14 | # assert len(train_data) == 151700
15 |
16 | def calculate_md5(fpath, chunk_size=1024 * 1024):
17 | md5 = hashlib.md5()
18 | with open(fpath, "rb") as f:
19 | for chunk in iter(lambda: f.read(chunk_size), b""):
20 | md5.update(chunk)
21 | return md5.hexdigest()
22 |
23 |
24 | def check_md5(fpath, md5, **kwargs):
25 | return md5 == calculate_md5(fpath, **kwargs)
26 |
27 |
28 | def check_integrity(fpath, md5=None):
29 | if not os.path.isfile(fpath):
30 | return False
31 | if md5 is None:
32 | return True
33 | else:
34 | return check_md5(fpath, md5)
35 |
36 |
37 | class ImageNet16(data.Dataset):
38 | # http://image-net.org/download-images
39 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
40 | # https://arxiv.org/pdf/1707.08819.pdf
41 |
42 | train_list = [
43 | ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"],
44 | ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"],
45 | ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"],
46 | ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"],
47 | ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"],
48 | ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"],
49 | ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"],
50 | ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"],
51 | ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"],
52 | ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"],
53 | ]
54 | valid_list = [
55 | ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"],
56 | ]
57 |
58 | def __init__(self, root, train, transform, use_num_of_class_only=None):
59 | self.root = root
60 | self.transform = transform
61 | self.train = train # training set or valid set
62 | if not self._check_integrity():
63 | raise RuntimeError("Dataset not found or corrupted.")
64 |
65 | if self.train:
66 | downloaded_list = self.train_list
67 | else:
68 | downloaded_list = self.valid_list
69 | self.data = []
70 | self.targets = []
71 |
72 | # now load the picked numpy arrays
73 | for i, (file_name, checksum) in enumerate(downloaded_list):
74 | file_path = os.path.join(self.root, file_name)
75 | # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
76 | with open(file_path, "rb") as f:
77 | if sys.version_info[0] == 2:
78 | entry = pickle.load(f)
79 | else:
80 | entry = pickle.load(f, encoding="latin1")
81 | self.data.append(entry["data"])
82 | self.targets.extend(entry["labels"])
83 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
84 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
85 | if use_num_of_class_only is not None:
86 | assert (
87 | isinstance(use_num_of_class_only, int)
88 | and use_num_of_class_only > 0
89 | and use_num_of_class_only < 1000
90 | ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only)
91 | new_data, new_targets = [], []
92 | for I, L in zip(self.data, self.targets):
93 | if 1 <= L <= use_num_of_class_only:
94 | new_data.append(I)
95 | new_targets.append(L)
96 | self.data = new_data
97 | self.targets = new_targets
98 | # self.mean.append(entry['mean'])
99 | # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
100 | # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
101 | # print ('Mean : {:}'.format(self.mean))
102 | # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
103 | # std_data = np.std(temp, axis=0)
104 | # std_data = np.mean(np.mean(std_data, axis=0), axis=0)
105 | # print ('Std : {:}'.format(std_data))
106 |
107 | def __getitem__(self, index):
108 | img, target = self.data[index], self.targets[index] - 1
109 | img = Image.fromarray(img)
110 | if self.transform is not None:
111 | img = self.transform(img)
112 | return img, target
113 |
114 | def __len__(self):
115 | return len(self.data)
116 |
117 | def _check_integrity(self):
118 | root = self.root
119 | for fentry in self.train_list + self.valid_list:
120 | filename, md5 = fentry[0], fentry[1]
121 | fpath = os.path.join(root, filename)
122 | if not check_integrity(fpath, md5):
123 | return False
124 | return True
125 |
--------------------------------------------------------------------------------
/xnas/datasets/transforms_imagenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 |
6 | from xnas.core.config import cfg
7 | from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment
8 |
9 |
10 | IMAGENET_RGB_MEAN = [0.485, 0.456, 0.406]
11 | IMAGENET_RGB_STD = [0.229, 0.224, 0.225]
12 |
13 |
14 | def get_data_transform(augment, **kwargs):
15 |     if len(cfg.SEARCH.MULTI_SIZES) == 0:
16 | # using single image_size for training
17 | train_crop_size = cfg.SEARCH.IM_SIZE
18 | else:
19 | # using MultiSize_RandomCrop
20 | train_crop_size = cfg.SEARCH.MULTI_SIZES
21 | min_train_scale = 0.08
22 | test_scale = math.ceil(cfg.TEST.IM_SIZE / 0.875) # 224 / 0.875 = 256
23 |     test_crop_size = cfg.TEST.IM_SIZE  # center-crop size at test time; 224 by default.
24 |
25 | interpolation = transforms.InterpolationMode.BICUBIC
26 | if 'interpolation' in kwargs.keys() and kwargs['interpolation'] == 'bilinear':
27 | interpolation = transforms.InterpolationMode.BILINEAR
28 |
29 | da_args = {
30 | 'train_crop_size': train_crop_size,
31 | 'train_min_scale': min_train_scale,
32 | 'test_scale': test_scale,
33 | 'test_crop_size': test_crop_size,
34 | 'interpolation': interpolation,
35 | }
36 |
37 | if augment == 'default':
38 | return build_default_transform(**da_args)
39 | elif augment == 'auto_augment_tf':
40 | policy = 'v0' if 'policy' not in kwargs.keys() else kwargs['policy']
41 | return build_imagenet_auto_augment_tf_transform(policy=policy, **da_args)
42 | else:
43 | raise ValueError(augment)
44 |
45 |
46 | def get_normalize():
47 | return transforms.Normalize(
48 | mean=torch.Tensor(IMAGENET_RGB_MEAN),
49 | std=torch.Tensor(IMAGENET_RGB_STD),
50 | )
51 |
52 |
53 | def get_randomResizedCrop(train_crop_size=224, train_min_scale=0.08, interpolation=transforms.InterpolationMode.BICUBIC):
54 | if isinstance(train_crop_size, int):
55 | return transforms.RandomResizedCrop(train_crop_size, scale=(train_min_scale, 1.0), interpolation=interpolation)
56 | elif isinstance(train_crop_size, list):
57 | from xnas.datasets.transforms import MultiSizeRandomCrop
58 | msrc = MultiSizeRandomCrop(train_crop_size)
59 | return msrc
60 | else:
61 | raise TypeError(train_crop_size)
62 |
63 |
64 | def build_default_transform(
65 | train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
66 | ):
67 | normalize = get_normalize()
68 | train_crop_transform = get_randomResizedCrop(
69 | train_crop_size, train_min_scale, interpolation
70 | )
71 | train_transform = transforms.Compose(
72 | [
73 | # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
74 | train_crop_transform,
75 | transforms.RandomHorizontalFlip(),
76 | transforms.ToTensor(),
77 | normalize,
78 | ]
79 | )
80 | test_transform = transforms.Compose(
81 | [
82 | transforms.Resize(test_scale, interpolation=interpolation),
83 | transforms.CenterCrop(test_crop_size),
84 | transforms.ToTensor(),
85 | normalize,
86 | ]
87 | )
88 | return train_transform, test_transform
89 |
90 |
91 | def build_imagenet_auto_augment_tf_transform(
92 | policy='v0', train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
93 | ):
94 |
95 | normalize = get_normalize()
96 | img_size = train_crop_size
97 | aa_params = {
98 | "translate_const": int(img_size * 0.45),
99 | "img_mean": tuple(round(x) for x in IMAGENET_RGB_MEAN),
100 | }
101 |
102 | aa_policy = AutoAugment(auto_augment_policy(policy, aa_params))
103 | train_crop_transform = get_randomResizedCrop(
104 | train_crop_size, train_min_scale, interpolation
105 | )
106 | train_transform = transforms.Compose(
107 | [
108 | # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
109 | train_crop_transform,
110 | transforms.RandomHorizontalFlip(),
111 | aa_policy,
112 | transforms.ToTensor(),
113 | normalize,
114 | ]
115 | )
116 | test_transform = transforms.Compose(
117 | [
118 | transforms.Resize(test_scale, interpolation=interpolation),
119 | transforms.CenterCrop(test_crop_size),
120 | transforms.ToTensor(),
121 | normalize,
122 | ]
123 | )
124 | return train_transform, test_transform
125 |
--------------------------------------------------------------------------------
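As a sketch of how the builders above compose, the call below constructs the standard single-size pipeline directly, bypassing cfg and get_data_transform:

from PIL import Image
from xnas.datasets.transforms_imagenet import build_default_transform

# RandomResizedCrop(224) + flip for training; Resize(256) + CenterCrop(224) for eval.
train_tf, test_tf = build_default_transform(
    train_crop_size=224, train_min_scale=0.08,
    test_scale=256, test_crop_size=224,
)
img = Image.new("RGB", (500, 375))   # stand-in for a real ImageNet image
x = train_tf(img)                    # normalized tensor of shape [3, 224, 224]
--------------------------------------------------------------------------------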
/xnas/evaluations/NASBench201.py:
--------------------------------------------------------------------------------
1 | """Evaluate model by NAS-Bench-201"""
2 |
3 | from xnas.core.config import cfg
4 | import xnas.logger.logging as logging
5 |
6 | logger = logging.get_logger(__name__)
7 |
8 |
9 | try:
10 | from nas_201_api import NASBench201API as API
11 | api = API(cfg.BENCHMARK.NB201PATH)
12 | except ImportError:
13 | print('Could not import NASBench201.')
14 | exit(1)
15 |
16 |
17 | def index_to_genotype(index):
18 | return api.arch(index)
19 |
20 | def evaluate(genotype, epoch=12, **kwargs):
21 | """Require info from NAS-Bench-201 API.
22 |
23 | Implemented following the source code of DrNAS.
24 | """
25 | result = api.query_by_arch(genotype, str(epoch))
26 | (
27 | cifar10_train,
28 | cifar10_test,
29 | cifar100_train,
30 | cifar100_valid,
31 | cifar100_test,
32 | imagenet16_train,
33 | imagenet16_valid,
34 | imagenet16_test,
35 | ) = distill(result)
36 |
37 | logger.info("Evaluate with NAS-Bench-201 (Bench epoch:{})".format(epoch))
38 | logger.info("cifar10 train %f test %f", cifar10_train, cifar10_test)
39 | logger.info("cifar100 train %f valid %f test %f", cifar100_train, cifar100_valid, cifar100_test)
40 | logger.info("imagenet16 train %f valid %f test %f", imagenet16_train, imagenet16_valid, imagenet16_test)
41 |
42 | if "writer" in kwargs.keys():
43 | writer = kwargs["writer"]
44 | cur_epoch = kwargs["cur_epoch"]
45 | writer.add_scalars("nasbench201/cifar10", {"train": cifar10_train, "test": cifar10_test}, cur_epoch)
46 | writer.add_scalars("nasbench201/cifar100", {"train": cifar100_train, "valid": cifar100_valid, "test": cifar100_test}, cur_epoch)
47 | writer.add_scalars("nasbench201/imagenet16", {"train": imagenet16_train, "valid": imagenet16_valid, "test": imagenet16_test}, cur_epoch)
48 | return result
49 |
50 | # parse ("distill") accuracies from the result string returned by the 201 API
51 | def distill(result):
52 | result = result.split("\n")
53 | cifar10 = result[5].replace(" ", "").split(":")
54 | cifar100 = result[7].replace(" ", "").split(":")
55 | imagenet16 = result[9].replace(" ", "").split(":")
56 |
57 | cifar10_train = float(cifar10[1].strip(",test")[-7:-2].strip("="))
58 | cifar10_test = float(cifar10[2][-7:-2].strip("="))
59 | cifar100_train = float(cifar100[1].strip(",valid")[-7:-2].strip("="))
60 | cifar100_valid = float(cifar100[2].strip(",test")[-7:-2].strip("="))
61 | cifar100_test = float(cifar100[3][-7:-2].strip("="))
62 | imagenet16_train = float(imagenet16[1].strip(",valid")[-7:-2].strip("="))
63 | imagenet16_valid = float(imagenet16[2].strip(",test")[-7:-2].strip("="))
64 | imagenet16_test = float(imagenet16[3][-7:-2].strip("="))
65 |
66 | return (
67 | cifar10_train,
68 | cifar10_test,
69 | cifar100_train,
70 | cifar100_valid,
71 | cifar100_test,
72 | imagenet16_train,
73 | imagenet16_valid,
74 | imagenet16_test,
75 | )
--------------------------------------------------------------------------------
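A usage sketch, assuming nas_201_api is installed and cfg.BENCHMARK.NB201PATH points at the benchmark file (otherwise the module exits at import time):

from xnas.evaluations import NASBench201 as nb201

arch = nb201.index_to_genotype(0)  # architecture string for bench index 0
# Logs train/valid/test accuracy on CIFAR-10, CIFAR-100 and ImageNet16-120;
# the bench provides results for 12 and 200 training epochs.
nb201.evaluate(arch, epoch=200)
--------------------------------------------------------------------------------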
/xnas/evaluations/NASBench301.py:
--------------------------------------------------------------------------------
1 | """Evaluate model by NAS-Bench-301"""
2 |
3 | import os
4 | from collections import namedtuple
5 | from xnas.core.config import cfg
6 | import xnas.logger.logging as logging
7 |
8 | try:
9 | import nasbench301 as nb
10 | except ImportError:
11 | print('Could not import NASBench301.')
12 | exit(1)
13 |
14 |
15 | __all__ = ['evaluate']
16 |
17 |
18 | logger = logging.get_logger(__name__)
19 |
20 | download_dir = cfg.BENCHMARK.NB301PATH
21 | version = '0.9'
22 |
23 |
24 | def init_model(version='0.9', download_dir=download_dir):  # version is compared as a string below
25 |
26 | # Note: Uses 0.9 as the default models, switch to 1.0 to use 1.0 models
27 | models_0_9_dir = os.path.join(download_dir, 'nb_models_0.9')
28 | model_paths_0_9 = {
29 | model_name: os.path.join(models_0_9_dir, '{}_v0.9'.format(model_name))
30 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
31 | }
32 | models_1_0_dir = os.path.join(download_dir, 'nb_models_1.0')
33 | model_paths_1_0 = {
34 | model_name: os.path.join(models_1_0_dir, '{}_v1.0'.format(model_name))
35 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
36 | }
37 | model_paths = model_paths_0_9 if version == '0.9' else model_paths_1_0
38 |
39 | # If the models are not available at the paths, automatically download
40 | # the models
41 | # Note: If you would like to provide your own model locations, comment this out
42 | if not all(os.path.exists(model) for model in model_paths.values()):
43 | nb.download_models(version=version, delete_zip=True,
44 | download_dir=download_dir)
45 |
46 | # Load the performance surrogate model
47 | # NOTE: Loading the ensemble will set the seed to the same as used during training (logged in the model_configs.json)
48 | # NOTE: Defaults to using the default model download path
49 | ensemble_dir_performance = model_paths['xgb']
50 | performance_model = nb.load_ensemble(ensemble_dir_performance)
51 |
52 | # Load the runtime surrogate model
53 | # NOTE: Defaults to using the default model download path
54 | ensemble_dir_runtime = model_paths['lgb_runtime']
55 | runtime_model = nb.load_ensemble(ensemble_dir_runtime)
56 |
57 | return performance_model, runtime_model
58 |
59 | def evaluate(genotype):
60 | """
61 | Evaluate with nasbench301, space=DARTS/nasbench301
62 |
63 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
64 | genotype = Genotype(
65 | normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)],
66 | normal_concat=[2, 3, 4, 5],
67 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)],
68 | reduce_concat=[2, 3, 4, 5]
69 | )
70 | """
71 |
72 | assert cfg.SPACE.NAME == "darts", "NAS-Bench-301 supports DARTS space only."
73 |
74 | performance_model, runtime_model = init_model(version, download_dir)
75 |
76 | # reformat the output of DartsCNN.genotype()
77 | genotype = reformat_DARTS(genotype)
78 |
79 | prediction_genotype = performance_model.predict(
80 | config=genotype, representation="genotype", with_noise=True)
81 | runtime_genotype = runtime_model.predict(
82 | config=genotype, representation="genotype")
83 |
84 | logger.info("Genotype architecture performance: %f, runtime %f" %
85 | (prediction_genotype, runtime_genotype))
86 |
87 | """
88 |     The commented-out code below samples architectures from a ConfigSpace
89 |     instead; adapt it before use.
90 | """
91 | # configspace_path = os.path.join(download_dir, 'configspace.json')
92 | # with open(configspace_path, "r") as f:
93 | # json_string = f.read()
94 | # configspace = cs_json.read(json_string)
95 | # configspace_config = configspace.sample_configuration()
96 | # prediction_configspace = performance_model.predict(config=configspace_config, representation="configspace", with_noise=True)
97 | # runtime_configspace = runtime_model.predict(config=configspace_config, representation="configspace")
98 | # print("Configspace architecture performance: %f, runtime %f" %(prediction_configspace, runtime_configspace))
99 |
100 | def reformat_DARTS(genotype):
101 | """
102 |     Flatten a nested DARTS-style genotype
103 | from:
104 | Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], [('sep_conv_3x3', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], [('dil_conv_5x5', 4), ('dil_conv_5x5', 3)]], normal_concat=range(2, 6), reduce=[[('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('max_pool_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('dil_conv_5x5', 4), ('max_pool_3x3', 0)]], reduce_concat=range(2, 6))
105 | to:
106 | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
107 | """
108 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
109 | if len(genotype.normal) == 1:
110 | return genotype
111 |
112 | _normal = []
113 | _reduce = []
114 | for i in genotype.normal:
115 | for j in i:
116 | _normal.append(j)
117 | for i in genotype.reduce:
118 | for j in i:
119 | _reduce.append(j)
120 | _normal_concat = [i for i in genotype.normal_concat]
121 | _reduce_concat = [i for i in genotype.reduce_concat]
122 | r_genotype = Genotype(
123 | normal=_normal,
124 | normal_concat=_normal_concat,
125 | reduce=_reduce,
126 | reduce_concat=_reduce_concat
127 | )
128 | return r_genotype
129 |
--------------------------------------------------------------------------------
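The flattening that reformat_DARTS performs can be checked in isolation (importing the module still requires nasbench301 to be installed and cfg to be loaded); the nested genotype below is the one from its docstring:

from collections import namedtuple
from xnas.evaluations.NASBench301 import reformat_DARTS

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
nested = Genotype(
    normal=[[('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], [('sep_conv_3x3', 2), ('max_pool_3x3', 1)],
            [('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], [('dil_conv_5x5', 4), ('dil_conv_5x5', 3)]],
    normal_concat=range(2, 6),
    reduce=[[('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('max_pool_3x3', 0), ('dil_conv_5x5', 2)],
            [('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('dil_conv_5x5', 4), ('max_pool_3x3', 0)]],
    reduce_concat=range(2, 6),
)
flat = reformat_DARTS(nested)  # per-node pairs flattened into one list per cell
print(flat.normal_concat)      # [2, 3, 4, 5]
--------------------------------------------------------------------------------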
/xnas/evaluations/NASBenchMacro/evaluate.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | with open('xnas/evaluations/NASBenchmacro/nas-bench-macro_cifar10.json', 'r', encoding='utf-8') as data_file:
4 | data = json.load(data_file)
5 |
6 |
7 | def evaluate(arch):
8 | # print(data[arch])
9 | print("{} accuracy for three test on CIFAR10: {}".format(arch, data[arch]['test_acc']))
10 | print("mean accuracy : {}, std : {}, params : {}, flops : {}".format(data[arch]['mean_acc'], data[arch]['std'],
11 | data[arch]['params'], data[arch]['flops']))
12 |
--------------------------------------------------------------------------------
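A call sketch. NAS-Bench-Macro identifies an architecture by a string of per-layer op choices; the example string here is arbitrary, and the import follows the header above (note the hard-coded JSON path uses lowercase "NASBenchmacro").

from xnas.evaluations.NASBenchMacro.evaluate import evaluate

# Requires nas-bench-macro_cifar10.json at the path hard-coded above.
evaluate("21102011")  # arbitrary example: one digit per searchable layer
--------------------------------------------------------------------------------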
/xnas/logger/checkpoint.py:
--------------------------------------------------------------------------------
1 | """Functions that handle saving and loading of checkpoints"""
2 |
3 | import os
4 | import torch
5 | from xnas.core.config import cfg
6 |
7 | # Checkpoints directory name
8 | _DIR_NAME = "checkpoints"
9 | # Common prefix for checkpoint file names
10 | _NAME_PREFIX = "model_epoch_"
11 |
12 |
13 | def get_checkpoint_dir(out_dir=None):
14 | """Retrieves the location for storing checkpoints."""
15 | if out_dir is None:
16 | return os.path.join(cfg.OUT_DIR, _DIR_NAME)
17 | else:
18 | return os.path.join(out_dir, _DIR_NAME)
19 |
20 |
21 | def get_checkpoint_name(epoch, checkpoint_dir=None, best=False):
22 | """Retrieves the path to a checkpoint file."""
23 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
24 | name = "best_" + name if best else name
25 | if checkpoint_dir is None:
26 | return os.path.join(get_checkpoint_dir(), name)
27 | else:
28 | return os.path.join(checkpoint_dir, name)
29 |
30 |
31 | def get_last_checkpoint(checkpoint_dir=None, best=False):
32 | """Retrieves the most recent checkpoint (highest epoch number)."""
33 | if checkpoint_dir is None:
34 | checkpoint_dir = get_checkpoint_dir()
35 | # Checkpoint file names are in lexicographic order
36 | filename = "best_" + _NAME_PREFIX if best else _NAME_PREFIX
37 | checkpoints = [f for f in os.listdir(checkpoint_dir) if filename in f]
38 | last_checkpoint_name = sorted(checkpoints)[-1]
39 | return os.path.join(checkpoint_dir, last_checkpoint_name)
40 |
41 |
42 | def has_checkpoint(checkpoint_dir=None):
43 | """Determines if there are checkpoints available."""
44 | if checkpoint_dir is None:
45 | checkpoint_dir = get_checkpoint_dir()
46 | if not os.path.exists(checkpoint_dir):
47 | return False
48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
49 |
50 |
51 | def save_checkpoint(model, epoch, checkpoint_dir=None, best=False, **kwargs):
52 | """Saves a checkpoint."""
53 | if checkpoint_dir is None:
54 | checkpoint_dir = get_checkpoint_dir()
55 | os.makedirs(checkpoint_dir, exist_ok=True)
56 |
57 | ms = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
58 | sd = ms.state_dict()
59 | checkpoint = {
60 | "epoch": epoch,
61 | "model_state": sd,
62 | }
63 | for k,v in kwargs.items():
64 | vsd = v.state_dict()
65 | checkpoint[k] = vsd
66 | # Write the checkpoint
67 | checkpoint_file = get_checkpoint_name(epoch + 1, checkpoint_dir, best=best)
68 | torch.save(checkpoint, checkpoint_file)
69 | return checkpoint_file
70 |
71 |
72 | def load_checkpoint(checkpoint_file, model):
73 | """Loads the checkpoint from the given file."""
74 | err_str = "Checkpoint '{}' not found"
75 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
76 | # Load the checkpoint on CPU to avoid GPU mem spike
77 | checkpoint = torch.load(checkpoint_file, map_location="cpu")
78 | # Account for the DDP wrapper in the multi-gpu setting
79 | ms = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
80 | ms.load_state_dict(checkpoint["model_state"])
81 | # Load the optimizer state (commonly not done when fine-tuning)
82 | others = {}
83 | for k,v in checkpoint.items():
84 | if k not in ["epoch", "model_state"]:
85 | others[k] = v
86 | return checkpoint["epoch"], others
87 |
--------------------------------------------------------------------------------
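A save/resume sketch for the helpers above; any extra object with a state_dict (optimizer, scheduler, ...) rides along through **kwargs and comes back in the `others` dict:

import torch
import torch.nn as nn
from xnas.logger import checkpoint

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# File name uses epoch + 1; the stored "epoch" field keeps the value passed in.
path = checkpoint.save_checkpoint(model, epoch=0, checkpoint_dir="./ckpt",
                                  optimizer=optimizer)

epoch, others = checkpoint.load_checkpoint(path, model)
optimizer.load_state_dict(others["optimizer"])
--------------------------------------------------------------------------------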
/xnas/logger/logging.py:
--------------------------------------------------------------------------------
1 | """Logging."""
2 |
3 | import os
4 | import logging
5 | import simplejson
6 | from xnas.core.config import cfg
7 | from xnas.core.utils import float_to_decimal
8 |
9 |
10 | # Show filename and line number in logs
11 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
12 |
13 | # Log file name
14 | _LOG_FILE = "stdout.log"
15 |
16 | # Data output with dump_log_data(data, data_type) will be tagged w/ this
17 | _TAG = "json_stats: "
18 |
19 | # Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
20 | _TYPE = "_type"
21 |
22 |
23 | def setup_logging():
24 | """Sets up the logging."""
25 | # Clear the root logger to prevent any existing logging config
26 | # (e.g. set by another module) from messing with our setup
27 | logging.root.handlers = []
28 | # Construct logging configuration
29 | logging_config = {"level": logging.INFO, "format": _FORMAT}
30 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
31 | # Configure logging
32 | logging.basicConfig(**logging_config)
33 |
34 |
35 | def get_logger(name):
36 | """Retrieves the logger."""
37 | return logging.getLogger(name)
38 |
39 |
40 | def dump_log_data(data, data_type, prec=4):
41 | """Covert data (a dictionary) into tagged json string for logging."""
42 | data[_TYPE] = data_type
43 | data = float_to_decimal(data, prec)
44 | data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
45 | return "{:s}{:s}".format(_TAG, data_json)
46 |
--------------------------------------------------------------------------------
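A sketch of the intended logging pattern; setup_logging() writes to cfg.OUT_DIR/stdout.log, so cfg must point at an existing output directory first:

import xnas.logger.logging as logging

logging.setup_logging()
logger = logging.get_logger(__name__)

stats = {"epoch": 1, "top1_err": 24.5631, "lr": 0.025}
# Emits a tagged JSON line such as: json_stats: {"_type": "train_epoch", ...}
logger.info(logging.dump_log_data(stats, "train_epoch"))
--------------------------------------------------------------------------------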
/xnas/logger/timer.py:
--------------------------------------------------------------------------------
1 | """Timer."""
2 |
3 | import time
4 |
5 |
6 | class Timer(object):
7 | """A simple timer (adapted from Detectron)."""
8 |
9 | def __init__(self):
10 | self.total_time = None
11 | self.calls = None
12 | self.start_time = None
13 | self.diff = None
14 | self.average_time = None
15 | self.reset()
16 |
17 | def tic(self):
18 |         # use time.time, as time.clock does not normalize for multithreading
19 | self.start_time = time.time()
20 |
21 | def toc(self):
22 | self.diff = time.time() - self.start_time
23 | self.total_time += self.diff
24 | self.calls += 1
25 | self.average_time = self.total_time / self.calls
26 |
27 | def reset(self):
28 | self.total_time = 0.0
29 | self.calls = 0
30 | self.start_time = 0.0
31 | self.diff = 0.0
32 | self.average_time = 0.0
33 |
--------------------------------------------------------------------------------
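Typical tic/toc accounting with the Timer above:

from xnas.logger.timer import Timer

timer = Timer()
for _ in range(10):
    timer.tic()
    # ... timed work goes here ...
    timer.toc()
print(timer.average_time)  # mean seconds per tic/toc pair since the last reset()
--------------------------------------------------------------------------------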
/xnas/runner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/runner/__init__.py
--------------------------------------------------------------------------------
/xnas/runner/criterion.py:
--------------------------------------------------------------------------------
1 | """Loss functions."""
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from xnas.core.config import cfg
7 |
8 |
9 | __all__ = ['criterion_builder']
10 |
11 |
12 | def _label_smooth(target, n_classes: int, label_smoothing):
13 | # convert to one-hot
14 | batch_size = target.size(0)
15 | target = torch.unsqueeze(target, 1)
16 | soft_target = torch.zeros((batch_size, n_classes), device=target.device)
17 | soft_target.scatter_(1, target, 1)
18 | # label smoothing
19 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
20 | return soft_target
21 |
22 |
23 | def CrossEntropyLoss_soft_target(pred, soft_target):
24 | """CELoss with soft target, mainly used during KD"""
25 | logsoftmax = nn.LogSoftmax(dim=1)
26 | return torch.mean(torch.sum(-soft_target * logsoftmax(pred), dim=1))
27 |
28 |
29 | def CrossEntropyLoss_label_smoothed(pred, target, label_smoothing=0.):
30 | label_smoothing = cfg.SEARCH.LABEL_SMOOTH if label_smoothing == 0. else label_smoothing
31 | soft_target = _label_smooth(target, pred.size(1), label_smoothing)
32 | return CrossEntropyLoss_soft_target(pred, soft_target)
33 |
34 |
35 | class KLLossSoft(torch.nn.modules.loss._Loss):
36 | """ inplace distillation for image classification
37 | output: output logits of the student network
38 | target: output logits of the teacher network
39 | T: temperature
40 |     KL(p||q) = E_p[log p] - E_p[log q]
41 | """
42 | def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9):
43 | output, soft_logits = output / temperature, soft_logits / temperature
44 | soft_target_prob = F.softmax(soft_logits, dim=1)
45 | output_log_prob = F.log_softmax(output, dim=1)
46 | kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1)
47 | if target is not None:
48 | n_class = output.size(1)
49 | target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)
50 | target = target.unsqueeze(1)
51 | output_log_prob = output_log_prob.unsqueeze(2)
52 | ce_loss = -torch.bmm(target, output_log_prob).squeeze()
53 | loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss
54 | else:
55 | loss = kd_loss
56 |
57 | if self.reduction == 'mean':
58 | return loss.mean()
59 | elif self.reduction == 'sum':
60 | return loss.sum()
61 | return loss
62 |
63 |
64 | class MultiHeadCrossEntropyLoss(nn.Module):
65 | def forward(self, preds, targets):
66 | assert preds.dim() == 3, preds
67 | assert targets.dim() == 2, targets
68 |
69 | assert preds.size(1) == targets.size(1), (preds, targets)
70 | num_heads = targets.size(1)
71 |
72 | loss = 0
73 | for k in range(num_heads):
74 | loss += F.cross_entropy(preds[:, k, :], targets[:, k]) / num_heads
75 | return loss
76 |
77 |
78 | # ----------
79 |
80 | SUPPORTED_CRITERIONS = {
81 | "cross_entropy": torch.nn.CrossEntropyLoss(),
82 | "cross_entropy_soft": CrossEntropyLoss_soft_target,
83 | "cross_entropy_smooth": CrossEntropyLoss_label_smoothed,
84 | "cross_entropy_multihead": MultiHeadCrossEntropyLoss(),
85 | "kl_soft": KLLossSoft(),
86 | }
87 |
88 |
89 | def criterion_builder(criterion=None):
90 | err_str = "Loss function type '{}' not supported"
91 | loss_fun = cfg.SEARCH.LOSS_FUN if criterion is None else criterion
92 | assert loss_fun in SUPPORTED_CRITERIONS.keys(), err_str.format(loss_fun)
93 | return SUPPORTED_CRITERIONS[loss_fun]
94 |
--------------------------------------------------------------------------------
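A sketch exercising two of the criteria above; passing the name explicitly to criterion_builder sidesteps cfg.SEARCH.LOSS_FUN:

import torch
from xnas.runner.criterion import criterion_builder, KLLossSoft

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

criterion = criterion_builder("cross_entropy")
loss = criterion(logits, target)

# Inplace distillation: student logits vs. teacher logits, plus optional hard labels.
teacher_logits = torch.randn(4, 10)
kd_loss = KLLossSoft()(logits, teacher_logits, target=target,
                       temperature=1.0, alpha=0.9)
--------------------------------------------------------------------------------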
/xnas/runner/optimizer.py:
--------------------------------------------------------------------------------
1 | """Optimizers."""
2 |
3 | import torch
4 | from xnas.core.config import cfg
5 |
6 |
7 | __all__ = [
8 | 'optimizer_builder',
9 | 'darts_alpha_optimizer',
10 | ]
11 |
12 |
13 | SUPPORTED_OPTIMIZERS = {
14 | "SGD",
15 | "Adam",
16 | }
17 |
18 |
19 | def optimizer_builder(name, param):
20 | """optimizer builder
21 |
22 | Args:
23 | name (str): name of optimizer
24 | param (dict): parameters to optimize
25 |
26 | Returns:
27 | optimizer: optimizer
28 | """
29 | assert name in SUPPORTED_OPTIMIZERS, "optimizer not supported."
30 | if name == "SGD":
31 | return torch.optim.SGD(
32 | param,
33 | cfg.OPTIM.BASE_LR,
34 | cfg.OPTIM.MOMENTUM,
35 | cfg.OPTIM.DAMPENING, # 0.0 following default
36 | cfg.OPTIM.WEIGHT_DECAY,
37 | cfg.OPTIM.NESTEROV, # False following default
38 | )
39 | elif name == "Adam":
40 | return torch.optim.Adam(
41 | param,
42 | cfg.OPTIM.BASE_LR,
43 | betas=(0.5, 0.999),
44 | weight_decay=cfg.OPTIM.WEIGHT_DECAY,
45 | )
46 |
47 |
48 | def darts_alpha_optimizer(name, param):
49 | """alpha optimizer for DARTS-like methods.
50 | Make sure cfg.DARTS has been initialized.
51 |
52 | Args:
53 | name (str): name of optimizer
54 | param (dict): parameters to optimize
55 |
56 | Returns:
57 | optimizer: optimizer
58 | """
59 |     assert name == "Adam", "alpha optimizer supports Adam only."
60 | if name == "Adam":
61 | return torch.optim.Adam(
62 | param,
63 | cfg.DARTS.ALPHA_LR,
64 | betas=(0.5, 0.999),
65 | weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY,
66 | )
67 |
--------------------------------------------------------------------------------
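Usage sketch; both builders read their hyperparameters from the global cfg, which must be initialized before the call:

import torch.nn as nn
from xnas.runner.optimizer import optimizer_builder

model = nn.Linear(10, 2)
w_optim = optimizer_builder("SGD", model.parameters())  # lr/momentum/wd from cfg.OPTIM
--------------------------------------------------------------------------------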
/xnas/runner/scheduler.py:
--------------------------------------------------------------------------------
1 | """Learning rate schedulers."""
2 |
3 | import math
4 | import torch
5 | from xnas.core.config import cfg
6 |
7 | from torch.optim.lr_scheduler import _LRScheduler
8 |
9 |
10 | __all__ = ['lr_scheduler_builder', 'adjust_learning_rate_per_batch']
11 |
12 |
13 | def lr_scheduler_builder(optimizer, last_epoch=-1, **kwargs):
14 | """Learning rate scheduler, now support warmup_epoch."""
15 | actual_scheduler = None
16 | if cfg.OPTIM.LR_POLICY == "cos":
17 | actual_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
18 | optimizer,
19 | T_max=kwargs['T_max'] if 'T_max' in kwargs.keys() else cfg.OPTIM.MAX_EPOCH,
20 | eta_min=cfg.OPTIM.MIN_LR,
21 | last_epoch=last_epoch)
22 | elif cfg.OPTIM.LR_POLICY == "step":
23 | actual_scheduler = torch.optim.lr_scheduler.MultiStepLR(
24 | optimizer,
25 | cfg.OPTIM.STEPS,
26 | gamma=cfg.OPTIM.LR_MULT,
27 | last_epoch=last_epoch)
28 | else:
29 | raise NotImplementedError
30 |
31 | if cfg.OPTIM.WARMUP_EPOCH > 0:
32 | return GradualWarmupScheduler(
33 | optimizer,
34 | actual_scheduler,
35 | cfg.OPTIM.WARMUP_EPOCH,
36 | cfg.OPTIM.WARMUP_FACTOR,
37 | last_epoch)
38 | else:
39 | return actual_scheduler
40 |
41 |
42 | class GradualWarmupScheduler(_LRScheduler):
43 | """
44 | Implementation Reference:
45 | https://github.com/ildoonet/pytorch-gradual-warmup-lr
46 | Args:
47 | optimizer (Optimizer): Wrapped optimizer.
48 | factor: start the warm up from init_lr * factor
49 | actual_scheduler: after warmup_epochs, use this scheduler
50 | warmup_epochs: init_lr is reached at warmup_epochs, linearly
51 | """
52 |
53 | def __init__(self,
54 | optimizer: torch.optim.Optimizer,
55 | actual_scheduler: _LRScheduler,
56 | warmup_epochs: int,
57 | factor: float,
58 | last_epoch=-1):
59 |         # Set attributes before super().__init__(): its initial step()
60 |         # triggers get_lr(), which already reads them.
61 |         self.actual_scheduler = actual_scheduler
62 |         self.warmup_epochs = warmup_epochs
63 |         self.factor = factor
64 |         super().__init__(optimizer, last_epoch)
65 |
66 |
67 | def get_lr(self):
68 | if self.last_epoch > self.warmup_epochs:
69 | return self.actual_scheduler.get_lr()
70 | else:
71 | return [base_lr * ((1. - self.factor) * self.last_epoch / self.warmup_epochs + self.factor) for base_lr in self.base_lrs]
72 |
73 |
74 | def step(self, epoch=None):
75 | if self.last_epoch > self.warmup_epochs:
76 | if epoch is None:
77 | self.actual_scheduler.step(None)
78 | else:
79 | self.actual_scheduler.step(epoch - self.warmup_epochs)
80 | self._last_lr = self.actual_scheduler.get_last_lr()
81 | else:
82 | return super(GradualWarmupScheduler, self).step(epoch)
83 |
84 |
85 | def _calc_learning_rate(
86 | init_lr, n_epochs, epoch, n_iter=None, iter=0,
87 | ):
88 | if cfg.OPTIM.LR_POLICY == "cos":
89 | t_total = n_epochs * n_iter
90 | t_cur = epoch * n_iter + iter
91 | lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
92 | else:
93 | raise ValueError("do not support: {}".format(cfg.OPTIM.LR_POLICY))
94 | return lr
95 |
96 |
97 | def _warmup_adjust_learning_rate(
98 | init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0
99 | ):
100 | """adjust lr during warming-up. Changes linearly from `warmup_lr` to `init_lr`."""
101 | T_cur = epoch * n_iter + iter + 1
102 | t_total = n_epochs * n_iter
103 | new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr
104 | return new_lr
105 |
106 |
107 | def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False):
108 | """adjust learning of a given optimizer and return the new learning rate"""
109 |
110 | init_lr = cfg.OPTIM.BASE_LR * cfg.NUM_GPUS
111 | n_epochs = cfg.OPTIM.MAX_EPOCH
112 | n_warmup_epochs = cfg.OPTIM.WARMUP_EPOCH
113 | warmup_lr = init_lr * cfg.OPTIM.WARMUP_FACTOR
114 |
115 | if warmup:
116 | new_lr = _warmup_adjust_learning_rate(
117 | init_lr, n_warmup_epochs, epoch, n_iter, iter, warmup_lr
118 | )
119 | else:
120 | new_lr = _calc_learning_rate(
121 | init_lr, n_epochs, epoch, n_iter, iter
122 | )
123 | return new_lr
124 |
--------------------------------------------------------------------------------
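An epoch-loop sketch; with cfg.OPTIM.WARMUP_EPOCH > 0 the builder wraps the cosine/step schedule in GradualWarmupScheduler:

import torch
import torch.nn as nn
from xnas.runner.scheduler import lr_scheduler_builder

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
scheduler = lr_scheduler_builder(optimizer)  # policy, epochs, warmup all from cfg.OPTIM

for epoch in range(10):
    # ... train one epoch ...
    scheduler.step()
--------------------------------------------------------------------------------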
/xnas/spaces/BigNAS/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from collections import OrderedDict
5 | from xnas.spaces.OFA.ops import SEModule, build_activation
6 | from xnas.spaces.OFA.utils import (
7 | get_same_padding,
8 | )
9 |
10 | class MBConvLayer(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size=3,
16 | stride=1,
17 | expand_ratio=6,
18 | mid_channels=None,
19 | act_func="relu6",
20 | use_se=False,
21 | channels_per_group=1,
22 | ):
23 | super(MBConvLayer, self).__init__()
24 |
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 |
28 | self.kernel_size = kernel_size
29 | self.stride = stride
30 | self.expand_ratio = expand_ratio
31 | self.mid_channels = mid_channels
32 | self.act_func = act_func
33 | self.use_se = use_se
34 | self.channels_per_group = channels_per_group
35 |
36 | if self.mid_channels is None:
37 | feature_dim = round(self.in_channels * self.expand_ratio)
38 | else:
39 | feature_dim = self.mid_channels
40 |
41 | if self.expand_ratio == 1:
42 | self.inverted_bottleneck = None
43 | else:
44 | self.inverted_bottleneck = nn.Sequential(OrderedDict([
45 | ("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
46 | ("bn", nn.BatchNorm2d(feature_dim)),
47 | ("act", build_activation(self.act_func, inplace=True)),
48 | ]))
49 |
50 | assert feature_dim % self.channels_per_group == 0
51 | active_groups = feature_dim // self.channels_per_group
52 | pad = get_same_padding(self.kernel_size)
53 |
54 | # assert feature_dim % self.groups == 0
55 | # active_groups = feature_dim // self.groups
56 | depth_conv_modules = [
57 | (
58 | "conv",
59 | nn.Conv2d(
60 | feature_dim,
61 | feature_dim,
62 | kernel_size,
63 | stride,
64 | pad,
65 | groups=active_groups,
66 | bias=False,
67 | ),
68 | ),
69 | ("bn", nn.BatchNorm2d(feature_dim)),
70 | ("act", build_activation(self.act_func, inplace=True)),
71 | ]
72 | if self.use_se:
73 | depth_conv_modules.append(("se", SEModule(feature_dim)))
74 | self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))
75 |
76 | self.point_linear = nn.Sequential(
77 | OrderedDict(
78 | [
79 | ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
80 | ("bn", nn.BatchNorm2d(out_channels)),
81 | ]
82 | )
83 | )
84 |
85 | def forward(self, x):
86 | if self.inverted_bottleneck:
87 | x = self.inverted_bottleneck(x)
88 | x = self.depth_conv(x)
89 | x = self.point_linear(x)
90 | return x
91 |
92 | @property
93 | def module_str(self):
94 | if self.mid_channels is None:
95 | expand_ratio = self.expand_ratio
96 | else:
97 | expand_ratio = self.mid_channels // self.in_channels
98 | layer_str = "%dx%d_MBConv%d_%s" % (
99 | self.kernel_size,
100 | self.kernel_size,
101 | expand_ratio,
102 | self.act_func.upper(),
103 | )
104 | if self.use_se:
105 | layer_str = "SE_" + layer_str
106 | layer_str += "_O%d" % self.out_channels
107 | if self.channels_per_group is not None:
108 | layer_str += "_G%d" % self.channels_per_group
109 | if isinstance(self.point_linear.bn, nn.GroupNorm):
110 | layer_str += "_GN%d" % self.point_linear.bn.num_groups
111 | elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
112 | layer_str += "_BN"
113 |
114 | return layer_str
115 |
116 | @property
117 | def config(self):
118 | return {
119 | "name": MBConvLayer.__name__,
120 | "in_channels": self.in_channels,
121 | "out_channels": self.out_channels,
122 | "kernel_size": self.kernel_size,
123 | "stride": self.stride,
124 | "expand_ratio": self.expand_ratio,
125 | "mid_channels": self.mid_channels,
126 | "act_func": self.act_func,
127 | "use_se": self.use_se,
128 | "channeles_per_group": self.channels_per_group,
129 | }
130 |
131 | @staticmethod
132 | def build_from_config(config):
133 | return MBConvLayer(**config)
134 |
--------------------------------------------------------------------------------
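A shape-check sketch for MBConvLayer: 1x1 expansion, kxk depthwise conv (channels_per_group=1), optional SE, then 1x1 projection:

import torch
from xnas.spaces.BigNAS.ops import MBConvLayer

layer = MBConvLayer(in_channels=16, out_channels=24, kernel_size=3,
                    stride=2, expand_ratio=6, use_se=True)
x = torch.randn(1, 16, 32, 32)
y = layer(x)             # -> torch.Size([1, 24, 16, 16])
print(layer.module_str)  # e.g. SE_3x3_MBConv6_RELU6_O24_G1_BN
--------------------------------------------------------------------------------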
/xnas/spaces/BigNAS/utils.py:
--------------------------------------------------------------------------------
1 | # Implementation adapted from attentiveNAS - https://github.com/facebookresearch/AttentiveNAS
2 |
3 | import torch
4 | import torch.nn as nn
5 | import copy
6 | import math
7 |
8 | multiply_adds = 1
9 |
10 |
11 | def count_convNd(m, _, y):
12 | cin = m.in_channels
13 |
14 | kernel_ops = m.weight.size()[2] * m.weight.size()[3]
15 | ops_per_element = kernel_ops
16 | output_elements = y.nelement()
17 |
18 | # cout x oW x oH
19 | total_ops = cin * output_elements * ops_per_element // m.groups
20 | m.total_ops = torch.Tensor([int(total_ops)])
21 |
22 |
23 | def count_linear(m, _, __):
24 | total_ops = m.in_features * m.out_features
25 |
26 | m.total_ops = torch.Tensor([int(total_ops)])
27 |
28 |
29 | register_hooks = {
30 | nn.Conv1d: count_convNd,
31 | nn.Conv2d: count_convNd,
32 | nn.Conv3d: count_convNd,
33 | ######################################
34 | nn.Linear: count_linear,
35 | ######################################
36 | nn.Dropout: None,
37 | nn.Dropout2d: None,
38 | nn.Dropout3d: None,
39 | nn.BatchNorm2d: None,
40 | }
41 |
42 |
43 | def profile(model, input_size=(1, 3, 224, 224), custom_ops=None):
44 | handler_collection = []
45 | custom_ops = {} if custom_ops is None else custom_ops
46 |
47 | def add_hooks(m_):
48 | if len(list(m_.children())) > 0:
49 | return
50 |
51 | m_.register_buffer('total_ops', torch.zeros(1))
52 | m_.register_buffer('total_params', torch.zeros(1))
53 |
54 | for p in m_.parameters():
55 | m_.total_params += torch.Tensor([p.numel()])
56 |
57 | m_type = type(m_)
58 | fn = None
59 |
60 | if m_type in custom_ops:
61 | fn = custom_ops[m_type]
62 | elif m_type in register_hooks:
63 | fn = register_hooks[m_type]
64 | else:
65 | # print("Not implemented for ", m_)
66 | pass
67 |
68 | if fn is not None:
69 | # print("Register FLOP counter for module %s" % str(m_))
70 | _handler = m_.register_forward_hook(fn)
71 | handler_collection.append(_handler)
72 |
73 | original_device = model.parameters().__next__().device
74 | training = model.training
75 |
76 | model.eval()
77 | model.apply(add_hooks)
78 |
79 | x = torch.zeros(input_size).to(original_device)
80 | with torch.no_grad():
81 | model(x)
82 |
83 | total_ops = 0
84 | total_params = 0
85 | for m in model.modules():
86 | if len(list(m.children())) > 0: # skip for non-leaf module
87 | continue
88 | total_ops += m.total_ops
89 | total_params += m.total_params
90 |
91 | total_ops = total_ops.item()
92 | total_params = total_params.item()
93 |
94 | model.train(training)
95 | model.to(original_device)
96 |
97 | for handler in handler_collection:
98 | handler.remove()
99 |
100 | return total_ops, total_params
101 |
102 |
103 | def count_net_flops_and_params(net, data_shape=(1, 3, 224, 224)):
104 | if isinstance(net, nn.DataParallel):
105 | net = net.module
106 |
107 | net = copy.deepcopy(net)
108 | flop, nparams = profile(net, data_shape)
109 | return flop /1e6, nparams /1e6
110 |
111 |
112 | def init_model(self, model_init="he_fout"):
113 | """ Conv2d, BatchNorm2d, BatchNorm1d, Linear, """
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | if model_init == 'he_fout':
117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118 | m.weight.data.normal_(0, math.sqrt(2. / n))
119 | elif model_init == 'he_fin':
120 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
121 | m.weight.data.normal_(0, math.sqrt(2. / n))
122 | else:
123 | raise NotImplementedError
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
127 | if m.affine:
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 | elif isinstance(m, nn.Linear):
131 | stdv = 1. / math.sqrt(m.weight.size(1))
132 | m.weight.data.uniform_(-stdv, stdv)
133 | if m.bias is not None:
134 | m.bias.data.zero_()
--------------------------------------------------------------------------------
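A sketch of the FLOPs/params counter on a toy network; leaf modules without a registered hook (ReLU, pooling, ...) simply contribute zero ops:

import torch.nn as nn
from xnas.spaces.BigNAS.utils import count_net_flops_and_params

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
mflops, mparams = count_net_flops_and_params(net, data_shape=(1, 3, 32, 32))
print(mflops, mparams)  # both reported in millions
--------------------------------------------------------------------------------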
/xnas/spaces/DARTS/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/DARTS/__init__.py
--------------------------------------------------------------------------------
/xnas/spaces/DARTS/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from .ops import *
3 |
4 |
5 | basic_op_list = ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', 'none']
6 |
7 |
8 | def geno_from_alpha(theta):
9 | Genotype = namedtuple(
10 | 'Genotype', 'normal normal_concat reduce reduce_concat')
11 | theta_norm = darts_weight_unpack(
12 | theta[0:14], 4)
13 | theta_reduce = darts_weight_unpack(
14 | theta[14:], 4)
15 | gene_normal = parse_from_numpy(
16 | theta_norm, k=2, basic_op_list=basic_op_list)
17 | gene_reduce = parse_from_numpy(
18 | theta_reduce, k=2, basic_op_list=basic_op_list)
19 | concat = range(2, 6) # concat all intermediate nodes
20 | return Genotype(normal=gene_normal, normal_concat=concat,
21 | reduce=gene_reduce, reduce_concat=concat)
22 |
23 | def reformat_DARTS(genotype):
24 | """
25 | format genotype for DARTS-like
26 | from:
27 | Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], [('sep_conv_3x3', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], [('dil_conv_5x5', 4), ('dil_conv_5x5', 3)]], normal_concat=range(2, 6), reduce=[[('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('max_pool_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('sep_conv_5x5', 1)], [('dil_conv_5x5', 4), ('max_pool_3x3', 0)]], reduce_concat=range(2, 6))
28 | to:
29 | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
30 | """
31 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
32 | _normal = []
33 | _reduce = []
34 | for i in genotype.normal:
35 | for j in i:
36 | _normal.append(j)
37 | for i in genotype.reduce:
38 | for j in i:
39 | _reduce.append(j)
40 | _normal_concat = [i for i in genotype.normal_concat]
41 | _reduce_concat = [i for i in genotype.reduce_concat]
42 | r_genotype = Genotype(
43 | normal=_normal,
44 | normal_concat=_normal_concat,
45 | reduce=_reduce,
46 | reduce_concat=_reduce_concat
47 | )
48 | return r_genotype
49 |
--------------------------------------------------------------------------------
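A decoding sketch, assuming geno_from_alpha accepts a (28, 8) array of edge-by-op weights (14 normal-cell edges plus 14 reduction-cell edges over the 8 primitives above); real weights would come from a trained supernet's softmaxed alphas:

import numpy as np
from xnas.spaces.DARTS.utils import geno_from_alpha

alpha = np.random.rand(28, 8)     # placeholder architecture weights
genotype = geno_from_alpha(alpha)
print(genotype.normal_concat)     # range(2, 6)
--------------------------------------------------------------------------------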
/xnas/spaces/DrNAS/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def optimizer_transfer(optimizer_old, optimizer_new):
6 | for i, p in enumerate(optimizer_new.param_groups[0]['params']):
7 | if not hasattr(p, 'raw_id'):
8 | optimizer_new.state[p] = optimizer_old.state[p]
9 | continue
10 | state_old = optimizer_old.state_dict()['state'][p.raw_id]
11 | state_new = optimizer_new.state[p]
12 |
13 | state_new['momentum_buffer'] = state_old['momentum_buffer']
14 | if p.t == 'bn':
15 | # BN layer
16 | state_new['momentum_buffer'] = torch.cat(
17 | [state_new['momentum_buffer'], state_new['momentum_buffer'][p.out_index].clone()], dim=0)
18 | # clean to enable multiple call
19 | del p.t, p.raw_id, p.out_index
20 |
21 | elif p.t == 'conv':
22 | # conv layer
23 | if hasattr(p, 'in_index'):
24 | state_new['momentum_buffer'] = torch.cat(
25 | [state_new['momentum_buffer'], state_new['momentum_buffer'][:, p.in_index, :, :].clone()], dim=1)
26 | if hasattr(p, 'out_index'):
27 | state_new['momentum_buffer'] = torch.cat(
28 | [state_new['momentum_buffer'], state_new['momentum_buffer'][p.out_index, :, :, :].clone()], dim=0)
29 | # clean to enable multiple call
30 | del p.t, p.raw_id
31 | if hasattr(p, 'in_index'):
32 | del p.in_index
33 | if hasattr(p, 'out_index'):
34 | del p.out_index
35 |     print('%d momentum buffers loaded' % (i+1))
36 | return optimizer_new
37 |
38 | def scheduler_transfer(scheduler_old, scheduler_new):
39 | scheduler_new.load_state_dict(scheduler_old.state_dict())
40 | print('scheduler loaded')
41 | return scheduler_new
42 |
43 |
44 | def process_step_vector(x, method, mask, tau=None):
45 | if method == "softmax":
46 | output = F.softmax(x, dim=-1)
47 | elif method == "dirichlet":
48 | output = torch.distributions.dirichlet.Dirichlet(F.elu(x) + 1).rsample()
49 | elif method == "gumbel":
50 | output = F.gumbel_softmax(x, tau=tau, hard=False, dim=-1)
51 |
52 | if mask is None:
53 | return output
54 | else:
55 | output_pruned = torch.zeros_like(output)
56 | output_pruned[mask] = output[mask]
57 | output_pruned /= output_pruned.sum()
58 | assert (output_pruned[~mask] == 0.0).all()
59 | return output_pruned
60 |
61 |
62 | def process_step_matrix(x, method, mask, tau=None):
63 | weights = []
64 | if mask is None:
65 | for line in x:
66 | weights.append(process_step_vector(line, method, None, tau))
67 | else:
68 | for i, line in enumerate(x):
69 | weights.append(process_step_vector(line, method, mask[i], tau))
70 | return torch.stack(weights)
71 |
72 |
73 | def prune(x, num_keep, mask, reset=False):
74 |     if mask is not None:
75 | x.data[~mask] -= 1000000
76 | src, index = x.topk(k=num_keep, dim=-1)
77 | if not reset:
78 | x.data.copy_(torch.zeros_like(x).scatter(dim=1, index=index, src=src))
79 | else:
80 | x.data.copy_(
81 | torch.zeros_like(x).scatter(
82 | dim=1, index=index, src=1e-3 * torch.randn_like(src)
83 | )
84 | )
85 | mask = torch.zeros_like(x, dtype=torch.bool).scatter(
86 | dim=1, index=index, src=torch.ones_like(src, dtype=torch.bool)
87 | )
88 | return mask
89 |
--------------------------------------------------------------------------------
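A sketch of the prune/renormalize cycle on a NAS-Bench-201-sized alpha (6 edges x 5 ops); prune keeps the top-k entries per row and returns the boolean mask that process_step_matrix then respects:

import torch
from xnas.spaces.DrNAS.utils import process_step_matrix, prune

alpha = torch.randn(6, 5, requires_grad=True)          # 6 edges x 5 candidate ops
weights = process_step_matrix(alpha, "softmax", None)  # each row sums to 1

mask = prune(alpha, num_keep=2, mask=None)             # keep top-2 ops per edge
weights = process_step_matrix(alpha, "softmax", mask)  # pruned ops get weight 0
--------------------------------------------------------------------------------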
/xnas/spaces/DropNAS/cnn.py:
--------------------------------------------------------------------------------
1 | from xnas.spaces.DARTS.ops import *
2 | import xnas.spaces.DARTS.genos as gt
3 |
4 |
5 | class Drop_MixedOp(nn.Module):
6 | """ Mixed operation """
7 |
8 | def __init__(self, C_in, C_out, stride):
9 | super().__init__()
10 | self._ops = nn.ModuleList()
11 | for primitive in gt.PRIMITIVES:
12 | op = OPS[primitive](C_in, C_out, stride, affine=False)
13 | self._ops.append(op)
14 |
15 | def forward(self, x, weights, masks):
16 | """
17 | Args:
18 | x: input
19 | weights: weight for each operation
20 | masks: list of boolean
21 | """
22 | return sum(w * op(x) for w, op, mask in zip(weights, self._ops, masks) if mask)
23 | # return sum(w * op(x) for w, op in zip(weights, self._ops))
24 |
25 |
26 | class DropNASCell(nn.Module):
27 | """ Cell for search
28 | Each edge is mixed and continuous relaxed.
29 | """
30 |
31 | def __init__(self, n_nodes, C_pp, C_p, C, reduction_p, reduction):
32 | """
33 | Args:
34 | n_nodes: # of intermediate n_nodes
35 | C_pp: C_out[k-2]
36 | C_p : C_out[k-1]
37 | C : C_in[k] (current)
38 | reduction_p: flag for whether the previous cell is reduction cell or not
39 | reduction: flag for whether the current cell is reduction cell or not
40 | """
41 | super().__init__()
42 | self.reduction = reduction
43 | self.n_nodes = n_nodes
44 |
45 | # If previous cell is reduction cell, current input size does not match with
46 | # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
47 | if reduction_p:
48 | self.preproc0 = FactorizedReduce(C_pp, C, affine=False)
49 | else:
50 | self.preproc0 = ReluConvBn(C_pp, C, 1, 1, 0, affine=False)
51 | self.preproc1 = ReluConvBn(C_p, C, 1, 1, 0, affine=False)
52 |
53 | # generate dag
54 | self.dag = nn.ModuleList()
55 | for i in range(self.n_nodes):
56 | self.dag.append(nn.ModuleList())
57 | for j in range(2 + i): # include 2 input nodes
58 | # reduction should be used only for input node
59 | stride = 2 if reduction and j < 2 else 1
60 | op = Drop_MixedOp(C, C, stride)
61 | self.dag[i].append(op)
62 |
63 | def forward(self, s0, s1, w_dag, masks):
64 | s0 = self.preproc0(s0)
65 | s1 = self.preproc1(s1)
66 |
67 | states = [s0, s1]
68 | for edges, w_list, m_list in zip(self.dag, w_dag, masks):
69 | s_cur = sum(edges[i](s, w, m) for i, (s, w, m) in enumerate(zip(states, w_list, m_list)))
70 | states.append(s_cur)
71 |
72 | s_out = torch.cat(states[2:], dim=1)
73 | return s_out
74 |
75 |
76 | class DropNASCNN(nn.Module):
77 | """ Search CNN model """
78 |
79 | def __init__(self, C_in, C, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
80 | """
81 | Args:
82 | C_in: # of input channels
83 | C: # of starting model channels
84 | n_classes: # of classes
85 | n_layers: # of layers
86 | n_nodes: # of intermediate nodes in Cell
87 | """
88 | super().__init__()
89 | self.C_in = C_in
90 | self.C = C
91 | self.n_classes = n_classes
92 | self.n_layers = n_layers
93 |
94 | C_cur = stem_multiplier * C
95 | self.stem = nn.Sequential(
96 | nn.Conv2d(C_in, C_cur, 3, 1, 1, bias=False),
97 | nn.BatchNorm2d(C_cur)
98 | )
99 |
100 | # for the first cell, stem is used for both s0 and s1
101 | # [!] C_pp and C_p is output channel size, but C_cur is input channel size.
102 | C_pp, C_p, C_cur = C_cur, C_cur, C
103 |
104 | self.cells = nn.ModuleList()
105 | reduction_p = False
106 | for i in range(n_layers):
107 | # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
108 | if i in [n_layers // 3, 2 * n_layers // 3]:
109 | C_cur *= 2
110 | reduction = True
111 | else:
112 | reduction = False
113 |
114 | cell = DropNASCell(
115 | n_nodes=n_nodes,
116 | C_pp=C_pp,
117 | C_p=C_p,
118 | C=C_cur,
119 | reduction_p=reduction_p,
120 | reduction=reduction
121 | )
122 |
123 | reduction_p = reduction
124 | self.cells.append(cell)
125 | C_cur_out = C_cur * n_nodes
126 | C_pp, C_p = C_p, C_cur_out
127 |
128 | self.gap = nn.AdaptiveAvgPool2d(1)
129 | self.linear = nn.Linear(C_p, n_classes)
130 |
131 | def forward(self, x, weights_normal, weights_reduce, masks_normal, masks_reduce):
132 | """
133 | Args:
134 | weights_xxx: probability contribution of each operation
135 | masks_xxx: decide whether to drop an operation
136 | """
137 | s0 = s1 = self.stem(x)
138 |
139 | for i, cell in enumerate(self.cells):
140 | weights = weights_reduce if cell.reduction else weights_normal
141 | masks = masks_reduce if cell.reduction else masks_normal
142 | s0, s1 = s1, cell(s0, s1, weights, masks)
143 |
144 | out = self.gap(s1)
145 | out = out.view(out.size(0), -1) # flatten
146 | logits = self.linear(out)
147 | return logits
148 |
149 |
150 | # build API
151 | def _DropNASCNN():
152 | from xnas.core.config import cfg
153 | return DropNASCNN(
154 | C_in=cfg.SEARCH.INPUT_CHANNELS,
155 | C=cfg.SPACE.CHANNELS,
156 | n_classes=cfg.LOADER.NUM_CLASSES,
157 | n_layers=cfg.SPACE.LAYERS,
158 | n_nodes=cfg.SPACE.NODES,
159 | )
160 |
--------------------------------------------------------------------------------
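A forward-pass sketch with uniform random weights and keep-everything masks; node i of a cell has 2 + i incoming edges, each mixing the 8 DARTS primitives:

import torch
from xnas.spaces.DropNAS.cnn import DropNASCNN

model = DropNASCNN(C_in=3, C=16, n_classes=10, n_layers=8)

n_ops = 8
weights = [torch.softmax(torch.randn(2 + i, n_ops), dim=-1) for i in range(4)]
masks = [torch.ones(2 + i, n_ops, dtype=torch.bool) for i in range(4)]

logits = model(torch.randn(2, 3, 32, 32), weights, weights, masks, masks)
print(logits.shape)  # torch.Size([2, 10])
--------------------------------------------------------------------------------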
/xnas/spaces/NASBench1Shot1/cnn.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/NASBench1Shot1/cnn.py
--------------------------------------------------------------------------------
/xnas/spaces/NASBench1Shot1/ops.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/NASBench1Shot1/ops.py
--------------------------------------------------------------------------------
/xnas/spaces/NASBench201/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | from .genos import Structure as CellStructure
4 | from .cnn import TinyNetwork
5 |
6 |
7 | def dict2config(xdict, logger):
8 | assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
9 | Arguments = namedtuple("Configure", " ".join(xdict.keys()))
10 | content = Arguments(**xdict)
11 | if hasattr(logger, "log"):
12 | logger.log("{:}".format(content))
13 | return content
14 |
15 | def get_cell_based_tiny_net(config):
16 | if hasattr(config, "genotype"):
17 | genotype = config.genotype
18 | elif hasattr(config, "arch_str"):
19 | genotype = CellStructure.str2structure(config.arch_str)
20 | else:
21 | raise ValueError(
22 | "Can not find genotype from this config : {:}".format(config)
23 | )
24 | return TinyNetwork(config.C, config.N, genotype, config.num_classes)
25 |
--------------------------------------------------------------------------------
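A construction sketch; the arch_str below is an arbitrary but well-formed NAS-Bench-201 encoding, and C/N are the channel width and blocks per stage:

from xnas.spaces.NASBench201.utils import dict2config, get_cell_based_tiny_net

config = dict2config({
    "C": 16, "N": 5, "num_classes": 10,
    "arch_str": "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|"
                "+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|",
}, logger=None)
net = get_cell_based_tiny_net(config)
--------------------------------------------------------------------------------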
/xnas/spaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAC-AutoML/XNAS/ea8f8ab31f67155482f5b9a9ad2a0b54c45f45d1/xnas/spaces/__init__.py
--------------------------------------------------------------------------------