├── .github
│   └── workflows
│       ├── environment.yml
│       └── python-package-conda.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_cn.md
├── conda_env_gpu.yaml
├── distribute.sh
├── docs
│   ├── release.md
│   └── usage.md
├── examples
│   ├── DARTS_CIFAR10.ipynb
│   ├── RandomNAS_MNIST.ipynb
│   └── ResNet_CIFAR10.ipynb
├── hyperbox
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   └── wandb_callbacks.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── callbacks
│   │   │   ├── default.yaml
│   │   │   ├── none.yaml
│   │   │   └── wandb.yaml
│   │   ├── config.yaml
│   │   ├── datamodule
│   │   │   ├── cifar100_datamodule.yaml
│   │   │   ├── cifar10_datamodule.yaml
│   │   │   ├── fakedata_datamodule.yaml
│   │   │   ├── imagenet_dali_datamodule.yaml
│   │   │   ├── imagenet_datamodule.yaml
│   │   │   ├── medmnist_datamodule.yaml
│   │   │   ├── mnist_datamodule.yaml
│   │   │   └── transforms
│   │   │       └── cifar.yaml
│   │   ├── engine
│   │   │   └── none.yaml
│   │   ├── experiment
│   │   │   ├── example_bnnas.yaml
│   │   │   ├── example_classify.yaml
│   │   │   ├── example_darts_nas.yaml
│   │   │   ├── example_full.yaml
│   │   │   ├── example_nasbench.yaml
│   │   │   ├── example_ofa_nas.yaml
│   │   │   ├── example_random_nas.yaml
│   │   │   ├── example_repnas.yaml
│   │   │   └── example_simple.yaml
│   │   ├── hparams_search
│   │   │   └── mnist_optuna.yaml
│   │   ├── hydra
│   │   │   └── default.yaml
│   │   ├── lite.yaml
│   │   ├── logger
│   │   │   ├── comet.yaml
│   │   │   ├── csv.yaml
│   │   │   ├── many_loggers.yaml
│   │   │   ├── mlflow.yaml
│   │   │   ├── neptune.yaml
│   │   │   ├── tensorboard.yaml
│   │   │   └── wandb.yaml
│   │   ├── model
│   │   │   ├── classify_model.yaml
│   │   │   ├── darts_model.yaml
│   │   │   ├── loss_cfg
│   │   │   │   ├── cross_entropy.yaml
│   │   │   │   └── cross_entropy_labelsmooth.yaml
│   │   │   ├── metric_cfg
│   │   │   │   └── accuracy.yaml
│   │   │   ├── mnist_model.yaml
│   │   │   ├── mutator_cfg
│   │   │   │   ├── darts_multiple_mutator.yaml
│   │   │   │   ├── darts_mutator.yaml
│   │   │   │   ├── ea_mutator.yaml
│   │   │   │   ├── enas_mutator.yaml
│   │   │   │   ├── evolution_mutator.yaml
│   │   │   │   ├── fairdarts_mutator.yaml
│   │   │   │   ├── fairnas_mutator.yaml
│   │   │   │   ├── fewshot_mutator.yaml
│   │   │   │   ├── onehot_mutator.yaml
│   │   │   │   ├── random_multiple_mutator.yaml
│   │   │   │   ├── random_mutator.yaml
│   │   │   │   └── repnas_mutator.yaml
│   │   │   ├── nasbench_model.yaml
│   │   │   ├── network_cfg
│   │   │   │   ├── bn_nas.yaml
│   │   │   │   ├── darts_network.yaml
│   │   │   │   ├── enas_macro.yaml
│   │   │   │   ├── enas_micro.yaml
│   │   │   │   ├── finegrained_resnet.yaml
│   │   │   │   ├── mobilenet2d_nas.yaml
│   │   │   │   ├── mobilenet3d_nas.yaml
│   │   │   │   ├── nasbench201.yaml
│   │   │   │   ├── nasbench_mbnet.yaml
│   │   │   │   ├── ofa_mbv3.yaml
│   │   │   │   ├── proxylessnas.yaml
│   │   │   │   ├── repnas_network.yaml
│   │   │   │   └── torch_network.yaml
│   │   │   ├── ofa_model.yaml
│   │   │   ├── optimizer_cfg
│   │   │   │   ├── adadelta.yaml
│   │   │   │   ├── adam.yaml
│   │   │   │   ├── adamw.yaml
│   │   │   │   ├── asam.yaml
│   │   │   │   ├── lamb.yaml
│   │   │   │   ├── rmsprop.yaml
│   │   │   │   ├── sam.yaml
│   │   │   │   └── sgd.yaml
│   │   │   ├── random_model.yaml
│   │   │   ├── repnas_model.yaml
│   │   │   ├── resnet18.yaml
│   │   │   └── scheduler_cfg
│   │   │       ├── CosineAnnealingLR.yaml
│   │   │       ├── ExponentialLR.yaml
│   │   │       ├── MultiStepLR.yaml
│   │   │       ├── ReducedLRonPlateau.yaml
│   │   │       └── warmup_scheduler.yaml
│   │   ├── paths
│   │   │   └── default.yaml
│   │   └── trainer
│   │       ├── ddp.yaml
│   │       └── default.yaml
│   ├── datamodules
│   │   ├── __init__.py
│   │   ├── cifar_datamodule.py
│   │   ├── datasets
│   │   │   └── __init__.py
│   │   ├── distributed_sampler_wrapper.py
│   │   ├── fakedata_datamodule.py
│   │   ├── imagenet_dali_datamodule.py
│   │   ├── imagenet_datamodule.py
│   │   ├── medmnist_datamodule.py
│   │   ├── mnist_datamodule.py
│   │   └── transforms
│   │       ├── __init__.py
│   │       ├── albumentation_transforms.py
│   │       ├── autoaugment.py
│   │       ├── base_transforms.py
│   │       ├── cutout.py
│   │       └── torch_transforms.py
│   ├── engine
│   │   ├── __init__.py
│   │   └── base_engine.py
│   ├── lites
│   │   └── base_lite.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── ce_labelsmooth_loss.py
│   │   ├── focal_loss.py
│   │   └── kd_loss.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── classify_model.py
│   │   ├── darts_model.py
│   │   ├── mnist_model.py
│   │   ├── nasbench_model.py
│   │   ├── ofa_model.py
│   │   └── random_model.py
│   ├── mutables
│   │   ├── __init__.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   └── layers2d.py
│   │   ├── masker.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── base_module.py
│   │   │   ├── batchnorm.py
│   │   │   ├── conv.py
│   │   │   ├── embedding.py
│   │   │   ├── groupnorm.py
│   │   │   ├── layernorm.py
│   │   │   ├── linear.py
│   │   │   ├── multihead_attention.py
│   │   │   └── utils.py
│   │   └── spaces.py
│   ├── mutator
│   │   ├── __init__.py
│   │   ├── base_mutator.py
│   │   ├── darts_multiple_mutator.py
│   │   ├── darts_mutator.py
│   │   ├── default_mutator.py
│   │   ├── enas_mutator.py
│   │   ├── evolution_mutator.py
│   │   ├── fairnas_mutator.py
│   │   ├── fewshot_mutator.py
│   │   ├── fixed_mutator.py
│   │   ├── onehot_mutator.py
│   │   ├── proxyless_mutator.py
│   │   ├── random_multiple_mutator.py
│   │   ├── random_mutator.py
│   │   ├── repnas_mutator.py
│   │   ├── sequential_mutator.py
│   │   └── utils.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── base_nas_network.py
│   │   ├── bnnas
│   │   │   ├── __init__.py
│   │   │   ├── bn_blocks.py
│   │   │   ├── bn_net.py
│   │   │   └── ea_search.py
│   │   ├── darts
│   │   │   ├── __init__.py
│   │   │   ├── darts_mask.json
│   │   │   ├── darts_network.py
│   │   │   └── darts_ops.py
│   │   ├── enas
│   │   │   ├── __init__.py
│   │   │   ├── enas_network.py
│   │   │   └── enas_ops.py
│   │   ├── gpt
│   │   │   ├── __init__.py
│   │   │   └── gpt2.py
│   │   ├── mobilenet
│   │   │   ├── __init__.py
│   │   │   ├── mobile3d_net.py
│   │   │   ├── mobile3d_ops.py
│   │   │   ├── mobile_net.py
│   │   │   ├── mobile_ops.py
│   │   │   └── mobile_utils.py
│   │   ├── nasbench101
│   │   │   ├── __init__.py
│   │   │   ├── base_ops.py
│   │   │   ├── db_gen
│   │   │   │   ├── __init__.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── db_gen.py
│   │   │   │   ├── graph_util.py
│   │   │   │   ├── model.py
│   │   │   │   └── query.py
│   │   │   ├── graph_util.py
│   │   │   ├── model_spec.py
│   │   │   ├── nasbench101.py
│   │   │   └── readme.md
│   │   ├── nasbench201
│   │   │   ├── __init__.py
│   │   │   ├── db_gen
│   │   │   │   ├── __init__.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── db_gen.py
│   │   │   │   ├── model.py
│   │   │   │   └── query.py
│   │   │   ├── gen_nasbench201_data.sh
│   │   │   └── nasbench201.py
│   │   ├── nasbench301
│   │   │   ├── __init__.py
│   │   │   ├── install
│   │   │   │   ├── readme.md
│   │   │   │   ├── requirements_cu102.txt
│   │   │   │   └── requirements_cu111.txt
│   │   │   ├── nasbench301_network.py
│   │   │   └── utils.py
│   │   ├── nasbench_mbnet
│   │   │   ├── __init__.py
│   │   │   ├── nasbench_mbnet_cifar10.json
│   │   │   └── network.py
│   │   ├── nasbenchasr
│   │   │   ├── __init__.py
│   │   │   ├── dataset.py
│   │   │   ├── download_nasbenchasr.sh
│   │   │   ├── graph_utils.py
│   │   │   ├── model.py
│   │   │   ├── ops.py
│   │   │   ├── readme.md
│   │   │   ├── search_space.py
│   │   │   └── utils.py
│   │   ├── network_ema.py
│   │   ├── ofa
│   │   │   ├── __init__.py
│   │   │   ├── ofa_mbv3.py
│   │   │   └── ofa_mbv3_searchspace.json
│   │   ├── proxylessnas
│   │   │   ├── __init__.py
│   │   │   ├── network.py
│   │   │   ├── ops.py
│   │   │   └── putils.py
│   │   ├── pytorch_modules.py
│   │   ├── repnas
│   │   │   ├── __init__.py
│   │   │   ├── rep_ops.py
│   │   │   ├── repnas_spos.py
│   │   │   └── utils.py
│   │   ├── resnet
│   │   │   ├── __init__.py
│   │   │   └── resnet.py
│   │   ├── spos
│   │   │   ├── __init__.py
│   │   │   ├── shuffle_blocks.py
│   │   │   └── spos_net.py
│   │   ├── utils.py
│   │   └── vit
│   │       ├── __init__.py
│   │       └── vit.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   ├── lamb.py
│   │   └── sam.py
│   ├── run.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── warmup_scheduler.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── average_meter.py
│       ├── calc_model_size.py
│       ├── logger.py
│       ├── metrics.py
│       ├── utils.py
│       ├── visualize_darts_cell.py
│       └── visualize_mbconv_net.py
├── requirements.txt
├── run.sh
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── datamodules
    │   └── test_transforms.py
    ├── helpers
    │   ├── __init__.py
    │   ├── module_available.py
    │   ├── run_command.py
    │   └── runif.py
    ├── models
    │   └── test_random_nas_model.py
    ├── mutables
    │   ├── test_all_mutables.py
    │   ├── test_duplicate_mutables.py
    │   ├── test_op_bn.py
    │   ├── test_op_conv.py
    │   └── test_op_linear.py
    ├── networks
    │   ├── test_build_subnet.py
    │   ├── test_darts_net.py
    │   ├── test_enas_net.py
    │   ├── test_mobile_net.py
    │   └── test_resnet.py
    ├── smoke
    │   ├── __init__.py
    │   ├── test_commands.py
    │   ├── test_mixed_precision.py
    │   ├── test_sweeps.py
    │   └── test_wandb.py
    ├── unit
    │   ├── __init__.py
    │   └── test_sth.py
    └── utils
        ├── test_calc_model_size.py
        ├── test_hparams.py
        └── test_logger.py
/.github/workflows/environment.yml:
--------------------------------------------------------------------------------
1 | name: hyperbox
2 |
3 | channels:
4 |   - pytorch
5 |   - conda-forge
6 |
7 | dependencies:
8 |   - python=3.10
9 |   - pip
10 |   - cudatoolkit
11 |   - pytorch=1.12.1
12 |   - torchvision=0.13.1
13 |   - pip:
14 |     - -r requirements.txt
15 |
--------------------------------------------------------------------------------
/.github/workflows/python-package-conda.yml:
--------------------------------------------------------------------------------
1 | name: Python Package using Conda
2 |
3 | on: [push]
4 |
5 | jobs:
6 |   build-linux:
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       max-parallel: 5
10 |
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python 3.10
14 |         uses: actions/setup-python@v3
15 |         with:
16 |           python-version: '3.10'
17 |       - name: Add conda to system path
18 |         run: |
19 |           # $CONDA is an environment variable pointing to the root of the miniconda directory
20 |           echo $CONDA/bin >> $GITHUB_PATH
21 |       - name: Install dependencies
22 |         run: |
23 |           conda env update --file .github/workflows/environment.yml --name hyperbox
24 |       - name: Lint with flake8
25 |         run: |
26 |           conda install flake8
27 |           # stop the build if there are Python syntax errors or undefined names
28 |           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29 |           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30 |           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
31 |       - name: Test with pytest
32 |         run: |
33 |           conda install pytest
34 |           pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | ### VisualStudioCode
131 | .vscode/*
132 | !.vscode/settings.json
133 | !.vscode/tasks.json
134 | !.vscode/launch.json
135 | !.vscode/extensions.json
136 | *.code-workspace
137 | **/.vscode
138 |
139 | ### Macbook
140 | .DS_Store
141 |
142 | # JetBrains
143 | .idea/
144 |
145 | # Lightning-Hydra-Template
146 | data/
147 | logs/
148 | wandb/
149 | .env
150 | .autoenv
151 |
152 | outputs
153 | logs
154 | hyperbox_app/*
155 | !hyperbox_app/empty_app_template
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 |   python: python3.8
3 |
4 | repos:
5 |   - repo: https://github.com/pre-commit/pre-commit-hooks
6 |     rev: v3.4.0
7 |     hooks:
8 |       # list of supported hooks: https://pre-commit.com/hooks.html
9 |       - id: trailing-whitespace
10 |       - id: end-of-file-fixer
11 |       - id: check-yaml
12 |       - id: check-added-large-files
13 |       - id: debug-statements
14 |       - id: detect-private-key
15 |
16 |   # python code formatting
17 |   - repo: https://github.com/psf/black
18 |     rev: 20.8b1
19 |     hooks:
20 |       - id: black
21 |         args: [--line-length, "99"]
22 |
23 |   # python import sorting
24 |   - repo: https://github.com/PyCQA/isort
25 |     rev: 5.8.0
26 |     hooks:
27 |       - id: isort
28 |
29 |   # yaml formatting
30 |   - repo: https://github.com/pre-commit/mirrors-prettier
31 |     rev: v2.3.0
32 |     hooks:
33 |       - id: prettier
34 |         types: [yaml]
35 |
36 |   # python code analysis
37 |   - repo: https://github.com/PyCQA/flake8
38 |     rev: 3.9.2
39 |     hooks:
40 |       - id: flake8
41 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Build: docker build -t project_name .
2 | # Run: docker run --gpus all -it --rm project_name
3 |
4 | # Build from official Nvidia PyTorch image
5 | # GPU-ready with Apex for mixed-precision support
6 | # https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
7 | # https://docs.nvidia.com/deeplearning/frameworks/support-matrix/
8 | FROM nvcr.io/nvidia/pytorch:22.07-py3
9 |
10 |
11 | # Copy all files
12 | ADD . /workspace/hyperbox
13 | WORKDIR /workspace/hyperbox
14 |
15 |
16 | # Create hyperbox
17 | RUN conda env create -f conda_env_gpu.yaml -n hyperbox
18 | RUN conda init bash
19 |
20 |
21 | # Set hyperbox to default virtual environment
22 | RUN echo "source activate hyperbox" >> ~/.bashrc
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 KINoAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include hyperbox/configs *.yaml
--------------------------------------------------------------------------------
/conda_env_gpu.yaml:
--------------------------------------------------------------------------------
1 | name: hyperbox
2 |
3 | channels:
4 |   - pytorch
5 |   - conda-forge
6 |
7 | dependencies:
8 |   - python=3.8
9 |   - pip
10 |   - cudatoolkit
11 |   - pytorch=1.8.1
12 |   - torchvision=0.9.1
13 |   - pip:
14 |     - -r requirements.txt
15 |
--------------------------------------------------------------------------------
/distribute.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | rm -rf build dist hyperbox.egg-info
4 | echo Building...
5 | python setup.py sdist bdist_wheel
6 | if (( $1 == 1 )); then
7 |     echo Uploading to PyPI...
8 |     python -m twine upload dist/*
9 | elif (( $1 == 2 )); then
10 |     echo Uploading to TestPyPI...
11 |     python -m twine upload --repository-url https://test.pypi.org/legacy/ dist/*
12 | else
13 |     echo "Wrong argument: pass 1 for PyPI or 2 for TestPyPI"
14 | fi
15 | echo Done.
--------------------------------------------------------------------------------
/hyperbox/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/__init__.py
--------------------------------------------------------------------------------
/hyperbox/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/callbacks/__init__.py
--------------------------------------------------------------------------------
/hyperbox/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/__init__.py
--------------------------------------------------------------------------------
/hyperbox/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 |   _target_: pytorch_lightning.callbacks.ModelCheckpoint
3 |   monitor: "val/acc" # name of the logged metric which determines when model is improving
4 |   save_top_k: 1 # save k best models (determined by above metric)
5 |   save_last: True # additionally always save model from last epoch
6 |   mode: "max" # can be "max" or "min"
7 |   verbose: False
8 |   dirpath: "checkpoints/"
9 |   filename: "{epoch:02d}_{val/acc:.4f}"
10 |
11 | early_stopping:
12 |   _target_: pytorch_lightning.callbacks.EarlyStopping
13 |   monitor: "val/acc" # name of the logged metric which determines when model is improving
14 |   patience: 1000 # how many epochs of not improving until training stops
15 |   mode: "max" # can be "max" or "min"
16 |   min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
17 |   check_on_train_epoch_end: False
18 |   strict: False
--------------------------------------------------------------------------------
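
Note: a minimal sketch of how a callbacks group like the one above is typically turned into PyTorch Lightning callback objects. The helper name `build_callbacks` is hypothetical; the real wiring lives in hyperbox/train.py, which is not inlined in this dump.

    # Hypothetical helper (assumed pattern for Hydra-configured callbacks).
    from hydra.utils import instantiate

    def build_callbacks(callbacks_cfg):
        # Each top-level key (model_checkpoint, early_stopping, ...) holds a _target_,
        # so Hydra can instantiate the callback class directly from the config node.
        return [instantiate(cb_cfg) for cb_cfg in callbacks_cfg.values()]
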
/hyperbox/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/hyperbox/configs/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - default.yaml
3 |
4 | watch_model:
5 |   _target_: hyperbox.callbacks.wandb_callbacks.WatchModel
6 |   log: "all"
7 |   log_freq: 100
8 |
9 | upload_code_as_artifact:
10 |   _target_: hyperbox.callbacks.wandb_callbacks.UploadCodeAsArtifact
11 |   code_dir: ${work_dir}/src
12 |
13 | upload_ckpts_as_artifact:
14 |   _target_: hyperbox.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
15 |   ckpt_dir: "checkpoints/"
16 |   upload_best_only: True
17 |
18 | log_f1_precision_recall_heatmap:
19 |   _target_: hyperbox.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
20 |
21 | log_confusion_matrix:
22 |   _target_: hyperbox.callbacks.wandb_callbacks.LogConfusionMatrix
23 |
24 | log_image_predictions:
25 |   _target_: hyperbox.callbacks.wandb_callbacks.LogImagePredictions
26 |   num_samples: 8
27 |
--------------------------------------------------------------------------------
/hyperbox/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 |   - trainer: default.yaml
6 |   - model: mnist_model.yaml
7 |   - datamodule: mnist_datamodule.yaml
8 |   - callbacks: default.yaml # set this to null if you don't want to use callbacks
9 |   - logger: wandb # set logger here or use command line (e.g. `python run.py logger=wandb`)
10 |   - engine: none.yaml
11 |   - hparams_search: null
12 |
13 |   - paths: default.yaml
14 |   - hydra: default.yaml
15 |
16 |   - experiment: null
17 |
18 | task_name: "hyperbox_project"
19 |
20 | # use `python run.py debug=true` for easy debugging!
21 | # this will run 1 train, val and test loop with only 1 batch
22 | # equivalent to running `python run.py trainer.fast_dev_run=true`
23 | # (this is placed here just for easier access from command line)
24 | debug: False
25 |
26 | # pretty print config at the start of the run using Rich library
27 | print_config: True
28 |
29 | # disable python warnings if they annoy you
30 | ignore_warnings: True
31 |
32 | # check performance on test set, using the best model achieved during training
33 | # lightning chooses best model based on metric specified in checkpoint callback
34 | test_after_training: True
35 |
36 | # checkpoint for testing. If specified, only test will be performed
37 | only_test: False
38 |
39 | # only load pretrained weights for network (exclude optimizer and scheduler)
40 | # e.g., 'logs/run/exp_name/checkpoint/epoch=66/acc=66.66.ckpt'
41 | pretrained_weight: null
42 | ipdb_debug: False
--------------------------------------------------------------------------------
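
Note: a minimal sketch of how this root config is consumed, assuming the standard Hydra entry-point pattern; the actual logic lives in hyperbox/run.py and hyperbox/train.py, which are not inlined in this dump.

    # Assumed entry-point shape; every config group above carries a _target_,
    # so Hydra can build the corresponding object directly.
    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    @hydra.main(config_path="hyperbox/configs", config_name="config")
    def main(cfg: DictConfig) -> None:
        datamodule = instantiate(cfg.datamodule)  # e.g. MNISTDataModule
        model = instantiate(cfg.model)            # e.g. MNISTLitModel
        trainer = instantiate(cfg.trainer)        # pytorch_lightning.Trainer
        trainer.fit(model, datamodule=datamodule)

    if __name__ == "__main__":
        main()
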
/hyperbox/configs/datamodule/cifar100_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.cifar_datamodule.CIFAR100DataModule
2 |
3 | defaults:
4 |   - transforms: cifar
5 |
6 | data_dir: ~/datasets/cifar100
7 | val_split: 0.1
8 | num_workers: 4
9 | normalize: True
10 | batch_size: 64
11 | seed: 666
12 | shuffle: True
13 | pin_memory: False
14 | drop_last: False
15 | is_customized: False
16 |
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/cifar10_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.cifar_datamodule.CIFAR10DataModule
2 |
3 | defaults:
4 |   - transforms: cifar
5 |
6 | data_dir: ~/datasets/cifar10
7 | val_split: 0.5
8 | num_workers: 4
9 | normalize: True
10 | batch_size: 64
11 | seed: 666
12 | shuffle: True
13 | pin_memory: False
14 | drop_last: False
15 | is_customized: False
16 |
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/fakedata_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.fakedata_datamodule.FakeDataModule
2 |
3 | train_size: 4000
4 | test_size: 500
5 | image_size: [3, 64, 64]
6 | num_classes: 10
7 | data_dir: /path/to/data # data_dir is specified in config.yaml
8 | batch_size: 8
9 | train_val_test_split: [0.6, 0.2, 0.2]
10 | num_workers: 0
11 | pin_memory: False
12 |
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/imagenet_dali_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.imagenet_datamodule.ImagenetDALIDataModule
2 | dataset_path: /datasets/imagenet100
3 | batch_size: 4
4 | num_workers: 8
5 | dali_cpu: False
6 | val_size: 256
7 | crop_size: 224
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/imagenet_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.imagenet_datamodule.ImagenetDataModule
2 | data_dir: ~/datasets/imagenet2012
3 | classes: 1000
4 | autoaugment: True
5 | image_size: 224
6 | batch_size: 64
7 | shuffle: True
8 | pin_memory: True
9 | drop_last: False
10 | num_workers: 8
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/medmnist_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.medmnist_datamodule.MedMNISTDataModule
2 | data_dir: '~/datasets/medmnist/'
3 | data_flag: 'synapsemnist3d'
4 | batch_size: 128
5 | num_workers: 8
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/mnist_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.datamodules.mnist_datamodule.MNISTDataModule
2 |
3 | data_dir: /path/to/data # data_dir is specified in config.yaml
4 | batch_size: 32
5 | num_workers: 0
6 | pin_memory: False
7 |
--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/transforms/cifar.yaml:
--------------------------------------------------------------------------------
1 | input_size: [32, 32]
2 | random_crop:
3 |   enable: 1
4 |   padding: 4
5 |   size: 32
6 | random_horizontal_flip:
7 |   enable: 1
8 |   p: 0.5
9 | cutout:
10 |   enable: 1
11 |   n_holes: 1
12 |   length: 16
13 | normalize:
14 |   enable: 1
15 |   mean: [0.4914, 0.4822, 0.4465]
16 |   std: [0.2023, 0.1994, 0.2010]
--------------------------------------------------------------------------------
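
Note: one plausible torchvision reading of the flags above, for orientation only; the real mapping lives in hyperbox/datamodules/transforms/base_transforms.py, which is not inlined here, so treat every choice below as an assumption.

    import torchvision.transforms as T

    train_transform = T.Compose([
        T.RandomCrop(32, padding=4),          # random_crop: size=32, padding=4
        T.RandomHorizontalFlip(p=0.5),        # random_horizontal_flip: p=0.5
        T.ToTensor(),
        T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                    std=[0.2023, 0.1994, 0.2010]),  # normalize
        # cutout (n_holes=1, length=16) would be applied by the repo's own
        # Cutout op (hyperbox/datamodules/transforms/cutout.py), omitted here.
    ])
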
/hyperbox/configs/engine/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/engine/none.yaml
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_bnnas.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_bnnas.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: random_model.yaml
9 |   - override /model/network_cfg: bn_nas.yaml
10 |   - override /datamodule: cifar10_datamodule.yaml
11 |   - override /callbacks: default.yaml
12 |   - override /logger: wandb.yaml
13 |   - override /model/scheduler_cfg: CosineAnnealingLR.yaml
14 |
15 | # all parameters below will be merged with parameters from default configurations set above
16 | # this allows you to overwrite only specified parameters
17 |
18 | seed: 12345
19 |
20 | trainer:
21 |   max_epochs: 200
22 |   strategy: horovod
23 |   gpus: 1
24 | debug: False
25 | # logger.wandb.name: test
26 | # logger.wandb.offline: True
27 | # hydra:
28 | # job:
29 | # name: test
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_classify.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_classify.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: classify_model.yaml
9 |   - override /datamodule: cifar10_datamodule.yaml
10 |   - override /callbacks: default.yaml
11 |   - override /logger: wandb.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 12345
17 |
18 | trainer:
19 |   min_epochs: 1
20 |   max_epochs: 600
21 |
22 | logger:
23 |   wandb:
24 |     offline: True
25 |     project: Classification
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_darts_nas.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_darts_nas.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: darts_model.yaml
9 |   - override /datamodule: fakedata_datamodule.yaml
10 |   - override /callbacks: default.yaml
11 |   - override /logger: wandb.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 12345
17 |
18 | trainer:
19 |   min_epochs: 1
20 |   max_epochs: 200
21 |
22 | datamodule:
23 |   is_customized: True
24 |
25 | # model:
26 | #   mutator_cfg:
27 | #     _target_: hyperbox.mutator.OnehotMutator
28 |
29 | logger:
30 |   wandb:
31 |     offline: True
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_full.yaml
5 |
6 | defaults:
7 |   - override /trainer: null # override trainer to null so it's not loaded from main config defaults...
8 |   - override /model: null
9 |   - override /datamodule: null
10 |   - override /callbacks: null
11 |   - override /logger: null
12 |
13 | # we override default configurations with nulls to prevent them from loading at all
14 | # instead we define all modules and their paths directly in this config,
15 | # so everything is stored in one place
16 |
17 | seed: 12345
18 |
19 | trainer:
20 |   _target_: pytorch_lightning.Trainer
21 |   min_epochs: 1
22 |   max_epochs: 10
23 |
24 | model:
25 |   _target_: hyperbox.models.mnist_model.MNISTLitModel
26 |   lr: 0.001
27 |   weight_decay: 0.00005
28 |   architecture: SimpleDenseNet
29 |   input_size: 784
30 |   lin1_size: 256
31 |   lin2_size: 256
32 |   lin3_size: 128
33 |   output_size: 10
34 |
35 | datamodule:
36 |   _target_: hyperbox.datamodules.mnist_datamodule.MNISTDataModule
37 |   data_dir: /path/to/data
38 |   batch_size: 64
39 |   train_val_test_split: [55_000, 5_000, 10_000]
40 |   num_workers: 0
41 |   pin_memory: False
42 |
43 | callbacks:
44 |   model_checkpoint:
45 |     _target_: pytorch_lightning.callbacks.ModelCheckpoint
46 |     monitor: "val/acc"
47 |     save_top_k: 2
48 |     save_last: True
49 |     mode: "max"
50 |     dirpath: "checkpoints/"
51 |     filename: "sample-mnist-{epoch:02d}"
52 |   early_stopping:
53 |     _target_: pytorch_lightning.callbacks.EarlyStopping
54 |     monitor: "val/acc"
55 |     patience: 10
56 |     mode: "max"
57 |
58 | logger:
59 |   wandb:
60 |     tags: ["best_model", "uwu"]
61 |     notes: "Description of this model."
62 |   neptune:
63 |     tags: ["best_model"]
64 |   csv_logger:
65 |     save_dir: "."
66 |
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_nasbench.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_nasbench.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: darts_model.yaml
9 |   - override /datamodule: cifar10_datamodule.yaml
10 |   - override /callbacks: default.yaml
11 |   - override /logger: wandb.yaml
12 |   - override /model/scheduler_cfg: CosineAnnealingLR.yaml
13 |
14 |
15 | # all parameters below will be merged with parameters from default configurations set above
16 | # this allows you to overwrite only specified parameters
17 |
18 | seed: 12345
19 |
20 | trainer:
21 |   min_epochs: 1
22 |   max_epochs: 200
23 |   gpus: 1
24 |
25 | datamodule:
26 |   is_customized: True
27 |   # is_customized: False
28 |
29 | model:
30 |   mutator_cfg:
31 |     _target_: hyperbox.mutator.DartsMultipleMutator
32 |     # _target_: hyperbox.mutator.RandomMultipleMutator
33 |   optimizer_cfg:
34 |     lr: 0.001
35 |
36 | logger:
37 |   wandb:
38 |     offline: True
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_ofa_nas.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_simple.yaml
5 |
6 | defaults:
7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 | - override /model: ofa_model.yaml
9 | - override /datamodule: cifar10_datamodule.yaml
10 | - override /callbacks: default.yaml
11 | - override /logger: wandb.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 12345
17 |
18 | trainer:
19 | min_epochs: 1
20 | max_epochs: 100
21 | logger:
22 | wandb:
23 | project: "OFA"
24 | offline: False
25 | callbacks:
26 | model_checkpoint:
27 | monitor: "val/acc" # name of the logged metric which determines when model is improving
28 | save_top_k: 2 # save k best models (determined by above metric)
29 |
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_random_nas.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_random_nas.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: random_model.yaml
9 |   - override /datamodule: fakedata_datamodule.yaml
10 |   - override /callbacks: default.yaml
11 |   - override /logger: wandb.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 12345
17 |
18 |
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_repnas.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_repnas.yaml
5 |
6 | # finetune repnas_model.yaml -> classify_model
7 |
8 | defaults:
9 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
10 |   - override /model: repnas_model.yaml
11 |   - override /datamodule: cifar10_datamodule.yaml
12 |   - override /callbacks: default.yaml
13 |   - override /logger: wandb.yaml
14 |   - override /model/scheduler_cfg: CosineAnnealingLR.yaml
15 |
16 | # all parameters below will be merged with parameters from default configurations set above
17 | # this allows you to overwrite only specified parameters
18 |
19 | seed: 12345
20 |
21 | trainer:
22 |   min_epochs: 1
23 |   max_epochs: 200
24 |   gpus: 1
25 |
26 | datamodule:
27 |   is_customized: True # for darts mutator
28 |   # is_customized: False # for other mutators
29 |
30 |
31 | model:
32 |   mutator_cfg:
33 |     _target_: hyperbox.mutator.DartsMultipleMutator
34 |     # _target_: hyperbox.mutator.RandomMultipleMutator
35 |   optimizer_cfg:
36 |     lr: 0.001
37 |
38 | logger:
39 |   wandb:
40 |     offline: True
41 |
--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_simple.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python run.py experiment=example_simple.yaml
5 |
6 | defaults:
7 |   - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8 |   - override /model: mnist_model.yaml
9 |   - override /datamodule: fakedata_datamodule.yaml
10 |   - override /callbacks: default.yaml
11 |   - override /logger: null
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 12345
17 |
18 | trainer:
19 |   min_epochs: 1
20 |   max_epochs: 1
21 |   gradient_clip_val: 0.5
22 |
23 | model:
24 |   lin1_size: 128
25 |   lin2_size: 256
26 |   lin3_size: 64
27 |   lr: 0.002
28 |
29 | datamodule:
30 |   batch_size: 64
31 |   image_size: [1,28,28]
32 |
--------------------------------------------------------------------------------
/hyperbox/configs/hparams_search/mnist_optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple
5 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30
6 | # python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb
7 |
8 | defaults:
9 |   - override /hydra/sweeper: optuna
10 |
11 | # choose metric which will be optimized by Optuna
12 | optimized_metric: "val/acc"
13 |
14 | hydra:
15 |   # here we define Optuna hyperparameter search
16 |   # it optimizes for value returned from function with @hydra.main decorator
17 |   # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
18 |   sweeper:
19 |     _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
20 |     storage: null
21 |     study_name: null
22 |     n_jobs: 1
23 |
24 |     # 'minimize' or 'maximize' the objective
25 |     direction: maximize
26 |
27 |     # number of experiments that will be executed
28 |     n_trials: 20
29 |
30 |     # choose Optuna hyperparameter sampler
31 |     # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
32 |     sampler:
33 |       _target_: optuna.samplers.TPESampler
34 |       seed: 12345
35 |       consider_prior: true
36 |       prior_weight: 1.0
37 |       consider_magic_clip: true
38 |       consider_endpoints: false
39 |       n_startup_trials: 10
40 |       n_ei_candidates: 24
41 |       multivariate: false
42 |       warn_independent_sampling: true
43 |
44 |     # define range of hyperparameters
45 |     search_space:
46 |       datamodule.batch_size:
47 |         type: categorical
48 |         choices: [32, 64, 128]
49 |       model.lr:
50 |         type: float
51 |         low: 0.0001
52 |         high: 0.2
53 |       model.lin1_size:
54 |         type: categorical
55 |         choices: [32, 64, 128, 256, 512]
56 |       model.lin2_size:
57 |         type: categorical
58 |         choices: [32, 64, 128, 256, 512]
59 |       model.lin3_size:
60 |         type: categorical
61 |         choices: [32, 64, 128, 256, 512]
62 |
--------------------------------------------------------------------------------
/hyperbox/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # output paths for hydra logs
2 |
3 | defaults:
4 |   - override hydra_logging: colorlog
5 |   - override job_logging: colorlog
6 |
7 | run:
8 |   dir: ${paths.log_dir}/runs/${task_name}/${now:%Y-%m-%d_%H-%M-%S}
9 | sweep:
10 |   dir: ${paths.log_dir}/multiruns/${task_name}/${now:%Y-%m-%d_%H-%M-%S}
11 |   subdir: ${hydra.job.num}
12 |
13 | # you can set here environment variables that are universal for all users
14 | # for system specific variables (like data paths) it's better to use .env file!
15 | job:
16 |   env_set:
17 |     EXAMPLE_VAR: "example_value"
18 |   name: "exp"
19 |   chdir: True # change the output/working directory to {hydra.job.name}; see the link below for details
20 |   # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory/#disable-changing-current-working-dir-to-jobs-output-dir
21 |
--------------------------------------------------------------------------------
/hyperbox/configs/lite.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 |   - model: darts_model.yaml
5 |   - datamodule: cifar10_datamodule.yaml
6 |   - logger: wandb.yaml
7 |   - hydra: default.yaml
8 |
9 | lite:
10 |   accelerator: "auto" # gpu, cpu, ipu, tpu, auto
11 |   strategy: null # "dp", "ddp", "ddp_spawn", "tpu_spawn", "deepspeed", "ddp_sharded", or "ddp_sharded_spawn"
12 |   devices: null
13 |   num_nodes: 1
14 |   precision: 32
15 |   plugins: null
16 |   gpus: null
17 |   tpu_cores: null
18 |
19 | ipdb_debug: False
20 | logger:
21 |   wandb:
22 |     name: 'lite'
23 | hydra:
24 |   job:
25 |     name: 'lite'
26 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 |   _target_: pytorch_lightning.loggers.comet.CometLogger
5 |   api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6 |   project_name: "template-tests"
7 |   experiment_name: null
8 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # csv logger built in lightning
2 |
3 | csv:
4 |   _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 |   save_dir: "."
6 |   name: "csv/"
7 |   version: null
8 |   prefix: ""
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 |   # - aim.yaml
5 |   # - comet.yaml
6 |   # - csv.yaml
7 |   # - mlflow.yaml
8 |   # - neptune.yaml
9 |   - tensorboard.yaml
10 |   - wandb.yaml
11 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
1 | # https://mlflow.org
2 |
3 | mlflow:
4 |   _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
5 |   experiment_name: default
6 |   tracking_uri: null
7 |   tags: null
8 |   save_dir: ./mlruns
9 |   prefix: ""
10 |   artifact_location: null
11 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 |   _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5 |   api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6 |   project_name: your_name/template-tests
7 |   close_after_fit: True
8 |   offline_mode: False
9 |   experiment_name: null
10 |   experiment_id: null
11 |   prefix: ""
12 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # https://www.tensorflow.org/tensorboard/
2 |
3 | tensorboard:
4 |   _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 |   save_dir: "tensorboard/"
6 |   name: "default"
7 |   version: null
8 |   log_graph: False
9 |   default_hp_metric: True
10 |   prefix: ""
11 |
--------------------------------------------------------------------------------
/hyperbox/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 |   _target_: pytorch_lightning.loggers.wandb.WandbLogger
5 |   project: "template-tests"
6 |   name: null
7 |   save_dir: "."
8 |   offline: True # set True to store all logs only locally
9 |   id: null # pass correct id to resume experiment!
10 |   # entity: "" # set to name of your wandb team or just remove it
11 |   log_model: False
12 |   prefix: ""
13 |   job_type: "train"
14 |   group: ""
15 |   tags: []
16 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/classify_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.classify_model.ClassifyModel
2 |
3 | defaults:
4 |   - metric_cfg: accuracy
5 |   - loss_cfg: cross_entropy
6 |   - optimizer_cfg: sgd
7 |   - network_cfg: darts_network
8 |   - scheduler_cfg: null
9 |
10 | optimizer_cfg:
11 |   lr: 0.025
12 |   weight_decay: 0.0003
13 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/darts_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.darts_model.DARTSModel
2 |
3 | defaults:
4 |   - network_cfg: finegrained_resnet
5 |   - mutator_cfg: darts_mutator
6 |   - optimizer_cfg: adam
7 |   - metric_cfg: accuracy
8 |   - loss_cfg: cross_entropy
9 |   - scheduler_cfg: null
--------------------------------------------------------------------------------
/hyperbox/configs/model/loss_cfg/cross_entropy.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.nn.CrossEntropyLoss
--------------------------------------------------------------------------------
/hyperbox/configs/model/loss_cfg/cross_entropy_labelsmooth.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.losses.ce_labelsmooth_loss.CrossEntropyLabelSmooth
2 | label_smoothing: 0.1
--------------------------------------------------------------------------------
/hyperbox/configs/model/metric_cfg/accuracy.yaml:
--------------------------------------------------------------------------------
1 | _target_: torchmetrics.classification.accuracy.Accuracy
2 | # _target_: hyperbox.utils.metrics.Accuracy
--------------------------------------------------------------------------------
/hyperbox/configs/model/mnist_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.mnist_model.MNISTLitModel
2 |
3 | input_size: 784
4 | lin1_size: 256
5 | lin2_size: 256
6 | lin3_size: 256
7 | output_size: 10
8 | lr: 0.001
9 | weight_decay: 0.0005
10 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/darts_multiple_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.DartsMultipleMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/darts_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.DartsMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/ea_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.EAMutator
2 | init_population_mode: 'warmup'
3 | warmup_epochs: 5
4 | num_population: 10
5 | prob_crossover: 0.3
6 | prob_mutation: 0.2
7 | object_keys: [1] # 0: flops & 1:size
8 | target_keys: [0] # 0: meters & 1:speed
9 | algorithm: 'cars' # evolutionary algorithm
10 | offspring_ratio: 1
11 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/enas_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.EnasMutator
2 | lstm_size: 61
3 | lstm_num_layers: 1
4 | tanh_constant: 1.5
5 | cell_exit_extra_step: False
6 | skip_target: 0.4
7 | branch_bias: 0.25
8 | arch_loss_weight: 0.02 # 0.002:small 0.02:medium 0.2:big
9 | reward_weight: 50
10 | temperature: 2
11 | entropy_reduction: 'sum' # 'sum', 'mean',
12 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/evolution_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.evolution_mutator.EvolutionMutator
2 | warmup_epochs: 0
3 | evolution_epochs: 100
4 | population_num: 50
5 | selection_alg: 'best'
6 | selection_num: 0.2
7 | crossover_num: 0.5
8 | crossover_prob: 0.1
9 | mutation_num: 0.5
10 | mutation_prob: 0.1
11 | flops_limit: 5000 # MFLOPs
12 | size_limit: 800 # MB
13 | log_dir: 'evolution_logs'
14 | topk: 10
15 | resume_from_checkpoint: null
16 | to_save_checkpoint: True
17 | to_plot_pareto: True
18 | figname: 'evolution_pareto.pdf'
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fairdarts_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.fairdarts_mutator.FairDartsMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fairnas_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.fairnas_mutator.FairNASMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fewshot_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.fewshot_mutator.FewshotMutator
2 | training_epochs: 100
3 | split_mutable_indices: 2
4 | save_dir: null
5 | to_save_sub_supernets: True
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/onehot_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.OnehotMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/random_multiple_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.RandomMultipleMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/random_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.random_mutator.RandomMutator
--------------------------------------------------------------------------------
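
Note: the mutator configs in this group are one-liners because a mutator is constructed around a network at runtime. A minimal sketch of the assumed NNI-style usage (the constructor and reset() API are inferred from the file layout, and the keyword arguments mirror darts_network.yaml, so treat both as assumptions):

    from hyperbox.mutator import RandomMutator
    from hyperbox.networks.darts.darts_network import DartsNetwork

    net = DartsNetwork(in_channels=3, channels=16, n_layers=8, n_classes=10)
    mutator = RandomMutator(net)
    mutator.reset()  # samples a new random subnet for the next forward pass
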
/hyperbox/configs/model/mutator_cfg/repnas_mutator.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.mutator.RepnasMutator
--------------------------------------------------------------------------------
/hyperbox/configs/model/nasbench_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.nasbench_model.NASBench201Model
2 | defaults:
3 |   - network_cfg: nasbench201
4 |   - mutator_cfg: random_mutator
5 |   - optimizer_cfg: adam
6 |   - metric_cfg: accuracy
7 |   - loss_cfg: cross_entropy
8 |   - scheduler_cfg: null
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/bn_nas.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.bnnas.bn_net.BNNet
2 | first_stride: 1
3 | first_channels: 24
4 | width_mult: 1
5 | channels_list: [32,40,80,96,192,320]
6 | num_blocks: [2,2,4,4,4,1]
7 | strides_list: [2,2,2,1,2,1]
8 | num_classes: 10
9 | search_depth: False
10 | is_only_train_bn: False
11 | mask: null
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/darts_network.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.darts.darts_network.DartsNetwork
2 | in_channels: 3
3 | channels: 16 # 36 for imagenet
4 | n_layers: 8 # 20 for imagenet
5 | stem_multiplier: 3
6 | n_nodes: 4
7 | n_classes: 10
8 | # mask: /path/to/hyperbox/hyperbox/networks/darts/darts_mask.json
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/enas_macro.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.enas.ENASMacroGeneralModel
2 | num_layers: 12
3 | out_filters: 24
4 | in_channels: 3
5 | num_classes: 10
6 | dropout_rate: 0.5
7 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/enas_micro.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.enas.ENASMicroNetwork
2 | num_layers: 2
3 | num_nodes: 5
4 | out_channels: 24
5 | in_channels: 3
6 | dropout_rate: 0.5
7 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/finegrained_resnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.resnet.resnet20
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/mobilenet2d_nas.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.mobilenet.MobileNet
2 | width_stages: [24,40,80,96,192,320]
3 | n_cell_stages: [4,4,4,4,4,1]
4 | stride_stages: [2,1,1,1,2,1]
5 | width_mult: 1
6 | classes: 1000
7 | dropout_rate: 0
8 | bn_param: [0.1, 1e-3]
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/mobilenet3d_nas.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.mobilenet.Mobile3DNet
2 | width_stages: [24,40,80,96,192,320]
3 | n_cell_stages: [4,4,4,4,4,1]
4 | stride_stages: [2,2,2,1,2,1]
5 | width_mult: 1
6 | classes: 1000
7 | dropout_rate: 0
8 | bn_param: [0.1, 1e-3]
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/nasbench201.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.nasbench201.nasbench201.NASBench201Network
2 | stem_out_channels: 16
3 | num_modules_per_stack: 5
4 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/nasbench_mbnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.nasbench_mbnet.network.NASBenchMBNet
2 | arch_list: null
3 | num_classes: 10
4 | stages: [2,3,3]
5 | init_channels: 32
6 | mask: null
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/ofa_mbv3.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.ofa.OFAMobileNetV3
2 | kernel_size_list: [3,5,7]
3 | expand_ratio_list: [3,4,6]
4 | depth_list: [2,3,4]
5 | base_stage_width: [16,16,24,40,80,112,160,960,1280]
6 | stride_stages: [1,2,2,2,1,2]
7 | act_stages: ['relu','relu','relu','h_swish','h_swish','h_swish']
8 | se_stages: [False,False,True,False,True,True]
9 | width_mult: 1.0
10 | num_classes: 1000
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/proxylessnas.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.proxylessnas.network.ProxylessNAS
2 | width_stages: [24,40,80,96,192,320]
3 | n_cell_stages: [4,4,4,4,4,1]
4 | stride_stages: [2,2,2,1,2,1]
5 | width_mult: 1
6 | num_classes: 1000
7 | dropout_rate: 0
8 | bn_param: [0.1,1e-3]
9 | mask: null
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/repnas_network.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.networks.repnas.RepNAS
2 | # in_channels: 3
3 | # channels: 16
4 | # n_layers: 8
5 | # stem_multiplier: 3
6 | # n_nodes: 4
7 | # n_classes: 10
8 | # mask: /path/to/hyperbox/hyperbox/networks/darts/darts_mask.json
9 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/torch_network.yaml:
--------------------------------------------------------------------------------
1 | _target_: torchvision.models.mobilenet_v3_small
2 | num_classes: 10
--------------------------------------------------------------------------------
/hyperbox/configs/model/ofa_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.ofa_model.OFAModel
2 |
3 | defaults:
4 | - network_cfg: finegrained_resnet
5 | - mutator_cfg: random_mutator
6 | - optimizer_cfg: adam
7 | - metric_cfg: accuracy
8 | - loss_cfg: cross_entropy
9 | - scheduler_cfg: null
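10 | 
11 | # usage sketch (assuming the `train.py` entrypoint mentioned in configs/paths/default.yaml):
12 | #   python train.py model=ofa_model model/network_cfg=ofa_mbv3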
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adadelta.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.Adadelta
2 | lr: 0.01
3 | rho: 0.9
4 | eps: 1e-08
5 | weight_decay: 0.005
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adam.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.Adam
2 | lr: 0.001
3 | weight_decay: 0.0005
4 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adamw.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.AdamW
2 | lr: 0.001
3 | weight_decay: 0.001
4 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/asam.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.optimizers.ASAM
2 | rho: 0.5
3 | eta: 0.01
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/lamb.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.optimizers.lamb.Lamb
2 | lr: 1e-3
3 | betas: [0.9,0.999]
4 | eps: 1e-6
5 | weight_decay: 5e-4
6 | adam: False
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/rmsprop.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.RMSprop
2 | lr: 0.01
3 | alpha: 0.99
4 | eps: 1e-08
5 | weight_decay: 0.005
6 | momentum: 0.9
7 | centered: False
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.optimizers.SAM
2 | rho: 0.5
3 | eta: 0.01
--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/sgd.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.SGD
2 | lr: 0.01
3 | weight_decay: 0.0005
4 | momentum: 0.9
5 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/random_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.random_model.RandomModel
2 |
3 | defaults:
4 | - network_cfg: darts_network
5 | - mutator_cfg: random_mutator
6 | - optimizer_cfg: adam
7 | - metric_cfg: accuracy
8 | - loss_cfg: cross_entropy
9 | - scheduler_cfg: null
10 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/repnas_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.darts_model.DARTSModel
2 | # _target_: hyperbox.models.classify_model.ClassifyModel
3 | # _target_: hyperbox.models.random_model.RandomModel
4 |
5 | defaults:
6 |   - network_cfg: repnas_network
7 |   - mutator_cfg: repnas_mutator
8 | - optimizer_cfg: sgd
9 | - metric_cfg: accuracy
10 | - scheduler_cfg: warmup_scheduler
11 | - loss_cfg: cross_entropy
12 |
13 | # network_cfg:
14 | # mask: '/home/pdluser/mask_json/mask_epoch_60.json'
15 |
16 |
--------------------------------------------------------------------------------
/hyperbox/configs/model/resnet18.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.models.classify_model.ClassifyModel
2 |
3 | defaults:
4 | - metric_cfg: accuracy
5 | - loss_cfg: cross_entropy
6 | - optimizer_cfg: sgd
7 | - scheduler_cfg: null
8 |
9 | network_cfg:
10 | _target_: torchvision.models.resnet18
11 | optimizer_cfg:
12 | lr: 0.001
13 | weight_decay: 0.0005
--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/CosineAnnealingLR.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR
2 | T_max: 180
3 | eta_min: 1e-4
4 | last_epoch: -1
--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/ExponentialLR.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.ExponentialLR
2 | gamma: 0.99
3 | last_epoch: -1
--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/MultiStepLR.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.MultiStepLR
2 | milestones: [80, 150, 220, 300]
3 | gamma: 0.8
4 | last_epoch: -1
--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/ReducedLRonPlateau.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
2 | mode: 'max' # 'max' for accuracy-like metrics, 'min' for loss-like metrics
3 | factor: 0.5
4 | patience: 5
5 | threshold: 0.0001
6 | threshold_mode: 'rel' # or 'abs'
7 | cooldown: 0
8 | min_lr: 1e-6
9 | verbose: False
--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/warmup_scheduler.yaml:
--------------------------------------------------------------------------------
1 | _target_: hyperbox.schedulers.warmup_scheduler.GradualWarmupScheduler
2 | multiplier: 1
3 | warmup_epoch: 10
4 | after_scheduler:
5 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR
6 | T_max: 180
7 | eta_min: 1e-4
8 | # after_scheduler: null
9 | # after_scheduler:
10 | # _target_: torch.optim.lr_scheduler.StepLR
11 | # step_size: 80
12 | # gamma: 0.1
--------------------------------------------------------------------------------
/hyperbox/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to root directory
2 | # this requires PROJECT_ROOT environment variable to exist
3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
4 | root_dir: .
5 |
6 | # path to data directory
7 | data_dir: ${paths.root_dir}/data/
8 |
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs/
11 |
12 | # path to output directory, created dynamically by hydra
13 | # path generation pattern is specified in `configs/hydra/default.yaml`
14 | # use it to store all files generated during the run, like ckpts and metrics
15 | output_dir: ${hydra:runtime.output_dir}
16 |
17 | # path to working directory
18 | work_dir: ${hydra:runtime.cwd}
--------------------------------------------------------------------------------
/hyperbox/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # use "ddp_spawn" instead of "ddp",
5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra
6 | # https://github.com/facebookresearch/hydra/issues/2070
7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
8 | strategy: ddp_spawn
9 |
10 | accelerator: gpu
11 | devices: 4
12 | num_nodes: 1
13 | sync_batchnorm: True
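14 | 
15 | # usage sketch (assuming the `train.py` entrypoint mentioned in configs/paths/default.yaml):
16 | #   python train.py trainer=ddp trainer.devices=8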
--------------------------------------------------------------------------------
/hyperbox/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | min_epochs: 1 # prevents early stopping
4 | max_epochs: 10
5 |
6 | accelerator: gpu
7 | devices: 1
8 |
9 | # mixed precision for extra speed-up
10 | # precision: 16
11 |
12 | # set True to ensure deterministic results
13 | # makes training slower but gives more reproducibility than just setting seeds
14 | deterministic: False
--------------------------------------------------------------------------------
/hyperbox/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from hyperbox.utils.utils import _module_available
2 | from .cifar_datamodule import *
3 | from .fakedata_datamodule import *
4 | from .mnist_datamodule import *
5 |
6 | if _module_available("medmnist"):
7 | from .medmnist_datamodule import *
8 | if _module_available("nvidia.dali"):
9 |     from .imagenet_dali_datamodule import *
10 |
--------------------------------------------------------------------------------
/hyperbox/datamodules/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/datamodules/datasets/__init__.py
--------------------------------------------------------------------------------
/hyperbox/datamodules/distributed_sampler_wrapper.py:
--------------------------------------------------------------------------------
1 | from operator import itemgetter
2 | from typing import Iterator, List, Optional, Union
3 |
4 | from torch.utils.data import Dataset, DistributedSampler, Sampler
5 |
6 | __all__ = [
7 | 'DatasetFromSampler',
8 | 'DistributedSamplerWrapper'
9 | ]
10 |
11 | class DatasetFromSampler(Dataset):
12 | """Dataset to create indexes from `Sampler`.
13 | Args:
14 | sampler: PyTorch sampler
15 | """
16 |
17 | def __init__(self, sampler: Sampler):
18 | """Initialisation for DatasetFromSampler."""
19 | self.sampler = sampler
20 | self.sampler_list = None
21 |
22 | def __getitem__(self, index: int):
23 | """Gets element of the dataset.
24 | Args:
25 | index: index of the element in the dataset
26 | Returns:
27 | Single element by index
28 | """
29 | if self.sampler_list is None:
30 | self.sampler_list = list(self.sampler)
31 | return self.sampler_list[index]
32 |
33 | def __len__(self) -> int:
34 | """
35 | Returns:
36 | int: length of the dataset
37 | """
38 | return len(self.sampler)
39 |
40 |
41 | class DistributedSamplerWrapper(DistributedSampler):
42 | """
43 | Wrapper over `Sampler` for distributed training.
44 | Allows you to use any sampler in distributed mode.
45 | It is especially useful in conjunction with
46 | `torch.nn.parallel.DistributedDataParallel`. In such case, each
47 | process can pass a DistributedSamplerWrapper instance as a DataLoader
48 | sampler, and load a subset of subsampled data of the original dataset
49 | that is exclusive to it.
50 | .. note::
51 | Sampler is assumed to be of constant size.
52 | """
53 |
54 | def __init__(
55 | self,
56 | sampler,
57 | num_replicas: Optional[int] = None,
58 | rank: Optional[int] = None,
59 | shuffle: bool = True,
60 | ):
61 | """
62 | Args:
63 | sampler: Sampler used for subsampling
64 | num_replicas (int, optional): Number of processes participating in
65 | distributed training
66 | rank (int, optional): Rank of the current process
67 | within ``num_replicas``
68 | shuffle (bool, optional): If true (default),
69 | sampler will shuffle the indices
70 | """
71 | super(DistributedSamplerWrapper, self).__init__(
72 | DatasetFromSampler(sampler),
73 | num_replicas=num_replicas,
74 | rank=rank,
75 | shuffle=shuffle,
76 | )
77 | self.sampler = sampler
78 |
79 | def __iter__(self):
80 | """@TODO: Docs. Contribution is welcome."""
81 | self.dataset = DatasetFromSampler(self.sampler)
82 | indexes_of_indexes = super().__iter__()
83 | subsampler_indexes = self.dataset
84 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
85 |
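86 | 
87 | if __name__ == "__main__":
88 |     # minimal usage sketch: wrap a WeightedRandomSampler so each DDP process
89 |     # draws from its own shard; num_replicas/rank are passed explicitly here
90 |     # so the sketch runs without initializing torch.distributed (values are illustrative)
91 |     import torch
92 |     from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
93 | 
94 |     dataset = TensorDataset(torch.arange(10).float())
95 |     weights = [0.1] * 5 + [1.0] * 5  # oversample the last five items
96 |     sampler = WeightedRandomSampler(weights, num_samples=10)
97 |     dist_sampler = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0)
98 |     for batch in DataLoader(dataset, sampler=dist_sampler, batch_size=2):
99 |         print(batch)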
--------------------------------------------------------------------------------
/hyperbox/datamodules/mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
5 | from torchvision.datasets import MNIST
6 | from torchvision.transforms import transforms
7 |
8 | __all__ = ['MNISTDataModule']
9 |
10 |
11 | class MNISTDataModule(LightningDataModule):
12 | """
13 | Example of LightningDataModule for MNIST dataset.
14 | A DataModule implements 5 key methods:
15 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
16 | - setup (things to do on every accelerator in distributed mode)
17 | - train_dataloader (the training dataloader)
18 | - val_dataloader (the validation dataloader(s))
19 | - test_dataloader (the test dataloader(s))
20 | This allows you to share a full dataset without explaining how to download,
21 | split, transform and process the data
22 | Read the docs:
23 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
24 | """
25 |
26 | def __init__(
27 | self,
28 | data_dir: str = "data/",
29 | train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
30 | batch_size: int = 64,
31 | num_workers: int = 0,
32 | pin_memory: bool = False,
33 | **kwargs,
34 | ):
35 | super().__init__()
36 |
37 | self.data_dir = data_dir
38 | self.train_val_test_split = train_val_test_split
39 | self.batch_size = batch_size
40 | self.num_workers = num_workers
41 | self.pin_memory = pin_memory
42 |
43 | self.transforms = transforms.Compose(
44 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
45 | )
46 |
47 | # self.dims is returned when you call datamodule.size()
48 | self.dims = (1, 28, 28)
49 |
50 | self.data_train: Optional[Dataset] = None
51 | self.data_val: Optional[Dataset] = None
52 | self.data_test: Optional[Dataset] = None
53 |
54 | @property
55 | def num_classes(self) -> int:
56 | return 10
57 |
58 | def prepare_data(self):
59 | """Download data if needed. This method is called only from a single GPU.
60 | Do not use it to assign state (self.x = y)."""
61 | MNIST(self.data_dir, train=True, download=True)
62 | MNIST(self.data_dir, train=False, download=True)
63 |
64 | def setup(self, stage: Optional[str] = None):
65 | """Load data. Set variables: self.data_train, self.data_val, self.data_test."""
66 | trainset = MNIST(self.data_dir, train=True, transform=self.transforms)
67 | testset = MNIST(self.data_dir, train=False, transform=self.transforms)
68 | dataset = ConcatDataset(datasets=[trainset, testset])
69 | self.data_train, self.data_val, self.data_test = random_split(
70 | dataset, self.train_val_test_split
71 | )
72 |
73 | def train_dataloader(self):
74 | return DataLoader(
75 | dataset=self.data_train,
76 | batch_size=self.batch_size,
77 | num_workers=self.num_workers,
78 | pin_memory=self.pin_memory,
79 | shuffle=True,
80 | )
81 |
82 | def val_dataloader(self):
83 | return DataLoader(
84 | dataset=self.data_val,
85 | batch_size=self.batch_size,
86 | num_workers=self.num_workers,
87 | pin_memory=self.pin_memory,
88 | shuffle=False,
89 | )
90 |
91 | def test_dataloader(self):
92 | return DataLoader(
93 | dataset=self.data_test,
94 | batch_size=self.batch_size,
95 | num_workers=self.num_workers,
96 | pin_memory=self.pin_memory,
97 | shuffle=False,
98 | )
99 |
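100 | 
101 | if __name__ == "__main__":
102 |     # minimal usage sketch: download MNIST (if absent) and fetch one training batch
103 |     dm = MNISTDataModule(data_dir="data/", batch_size=32)
104 |     dm.prepare_data()
105 |     dm.setup()
106 |     x, y = next(iter(dm.train_dataloader()))
107 |     print(x.shape, y.shape)  # expected: torch.Size([32, 1, 28, 28]) torch.Size([32])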
--------------------------------------------------------------------------------
/hyperbox/datamodules/transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .base_transforms import BaseTransforms
3 | from .torch_transforms import TorchTransforms
4 | from .cutout import Cutout
5 |
6 | import importlib
7 | alb = importlib.util.find_spec('albumentations')
8 | if alb is not None:
9 | from .albumentation_transforms import AlbumentationsTransforms
10 |
11 | def get_transforms(name, kwargs: dict):
12 |     if name == 'torch':
13 |         from .torch_transforms import TorchTransforms
14 |         T = TorchTransforms(**kwargs)
15 |     elif name == 'albumentation':
16 |         from .albumentation_transforms import AlbumentationsTransforms
17 |         T = AlbumentationsTransforms(**kwargs)
18 |     else:
19 |         raise ValueError(f"Unknown transforms name: {name} (expected 'torch' or 'albumentation')")
20 |     return T
21 | 
--------------------------------------------------------------------------------
/hyperbox/datamodules/transforms/base_transforms.py:
--------------------------------------------------------------------------------
1 |
2 | from omegaconf import DictConfig
3 |
4 | from hyperbox.utils.utils import hparams_wrapper
5 | from hyperbox.utils.logger import get_logger
6 | _logger = get_logger(__name__)
7 |
8 |
9 | @hparams_wrapper
10 | class BaseTransforms(object):
11 | def __init__(self, *args, **kwargs):
12 | for key, value in self.hparams.items():
13 | if isinstance(value, dict):
14 | value = DictConfig(value)
15 | setattr(self, key, value)
16 |
17 | def get_transform(self, is_train:bool = True):
18 | if not is_train:
19 | _logger.info('Generating validation transform ...')
20 |             transform = self.transform_valid
21 | _logger.info(f'Valid transform={transform}')
22 | else:
23 | _logger.info('Generating training transform ...')
24 |             transform = self.transform_train
25 | _logger.info(f'Train transform={transform}')
26 | return transform
27 |
28 | @property
29 | def transform_valid(self):
30 | raise NotImplementedError
31 |
32 | @property
33 | def transform_train(self):
34 | raise NotImplementedError
--------------------------------------------------------------------------------
/hyperbox/datamodules/transforms/cutout.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 |
5 |
6 | class Cutout(object):
7 | """Randomly mask out one or more patches from an image.
8 | https://github.com/uoguelph-mlrg/Cutout/blob/287f934ea5fa00d4345c2cccecf3552e2b1c33e3/train.py#L45
9 | Args:
10 | n_holes (int): Number of patches to cut out of each image.
11 | length (int): The length (in pixels) of each square patch.
12 | """
13 | def __init__(self, n_holes=1, length=16):
14 | self.n_holes = n_holes
15 | self.length = length
16 |
17 | def __call__(self, img):
18 | """
19 | Args:
20 | img (Tensor): Tensor image of size (C, H, W).
21 | Returns:
22 | Tensor: Image with n_holes of dimension length x length cut out of it.
23 | """
24 | h, w = img.shape[1:3]
25 |
26 | mask = np.ones((h, w), np.float32)
27 |
28 | for n in range(self.n_holes):
29 | y = np.random.randint(h)
30 | x = np.random.randint(w)
31 |
32 | y1 = np.clip(y - self.length // 2, 0, h)
33 | y2 = np.clip(y + self.length // 2, 0, h)
34 | x1 = np.clip(x - self.length // 2, 0, w)
35 | x2 = np.clip(x + self.length // 2, 0, w)
36 |
37 | mask[y1: y2, x1: x2] = 0.
38 |
39 | mask = torch.from_numpy(mask)
40 | mask = mask.expand_as(img)
41 | img = img * mask
42 |
43 | return img
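44 | 
45 | 
46 | if __name__ == "__main__":
47 |     # minimal usage sketch: cut one 8x8 hole out of a random CHW image tensor
48 |     img = torch.rand(3, 16, 16)
49 |     out = Cutout(n_holes=1, length=8)(img)
50 |     print((out == 0).float().mean())  # fraction of masked pixels (at most 64/256)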
--------------------------------------------------------------------------------
/hyperbox/datamodules/transforms/torch_transforms.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Union, Optional
3 |
4 | from omegaconf import DictConfig
5 | from torchvision import transforms
6 |
7 | from .cutout import Cutout
8 | from .base_transforms import BaseTransforms
9 |
10 |
11 | class TorchTransforms(BaseTransforms):
12 | def __init__(
13 | self,
14 |         input_size: Union[list, tuple] = [32],
15 | random_resized_crop: Union[dict, DictConfig] = {'enable': 0, 'padding': 0},
16 | resize: Union[dict, DictConfig] = {'enable': 0},
17 | random_crop: Union[dict, DictConfig] = {'enable': 0},
18 | center_crop: Union[dict, DictConfig] = {'enable': 0},
19 | color_jitter: Union[dict, DictConfig] = {'enable': 0},
20 | random_horizontal_flip: Union[dict, DictConfig] = {'enable': 0, 'p': 0.5},
21 | random_vertical_flip: Union[dict, DictConfig] = {'enable': 0, 'p': 0.5},
22 | random_rotation: Union[dict, DictConfig] = {'enable': 0, 'degrees': 20},
23 | cutout: Union[dict, DictConfig] = {'enable': 0, 'n_holes': 8, 'length': 4},
24 | to_tensor: Union[dict, DictConfig] = {'enable': 1},
25 | normalize: Union[dict, DictConfig] = {
26 | 'enable': 1, 'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]},
27 | ):
28 | super(TorchTransforms, self).__init__()
29 | self.min_edge_size = min(input_size)
30 | self._transform_train = self.parse_transforms()
31 |
32 | self._transform_valid = [
33 | transforms.Resize(self.input_size)
34 | ]
35 | if self.to_tensor.enable:
36 | self._transform_valid.append(transforms.ToTensor())
37 | if self.normalize.enable:
38 | mean = self.normalize.mean
39 | std = self.normalize.std
40 | self.normalize = transforms.Normalize(mean, std)
41 | self._transform_valid.append(self.normalize)
42 | self._transform_valid = transforms.Compose(self._transform_valid)
43 |
44 | def parse_transforms(self):
45 | img_size = self.input_size
46 | transform_list = []
47 |
48 |         # resize and crop operations
49 | if self.random_resized_crop.enable:
50 | transform_list.append(transforms.RandomResizedCrop(img_size))
51 | elif self.resize.enable:
52 | transform_list.append(transforms.Resize(img_size))
53 | if self.random_crop.enable:
54 | padding = self.random_crop.padding
55 | size = getattr(self.random_crop, 'size', self.min_edge_size)
56 | transform_list.append(transforms.RandomCrop(size, padding=padding))
57 | elif self.center_crop.enable:
58 | transform_list.append(transforms.CenterCrop(self.min_edge_size))
59 |
60 | # ColorJitter
61 | if self.color_jitter.enable:
62 | params = {key: self.color_jitter[key] for key in self.color_jitter
63 | if key != 'enable'}
64 | transform_list.append(transforms.ColorJitter(**params))
65 |
66 | # horizontal flip
67 | if self.random_horizontal_flip.enable:
68 | p = self.random_horizontal_flip.p
69 | transform_list.append(transforms.RandomHorizontalFlip(p))
70 |
71 | # vertical flip
72 | if self.random_vertical_flip.enable:
73 | p = self.random_vertical_flip.p
74 | transform_list.append(transforms.RandomVerticalFlip(p))
75 |
76 | # rotation
77 | if self.random_rotation.enable:
78 | degrees = self.random_rotation.degrees
79 | transform_list.append(transforms.RandomRotation(degrees))
80 |
81 | # cutout
82 | if self.to_tensor.enable:
83 | transform_list.append(transforms.ToTensor())
84 | if self.cutout.enable:
85 | n_holes = self.cutout.n_holes
86 | length = self.cutout.length
87 | transform_list.append(Cutout(n_holes, length))
88 | if self.normalize.enable:
89 | mean = self.normalize.mean
90 | std = self.normalize.std
91 | transform_list.append(transforms.Normalize(mean, std))
92 | transform_list = transforms.Compose(transform_list)
93 |         assert len(transform_list.transforms) > 0, "The length of the transform list must be larger than 0."
94 | return transform_list
95 |
96 | @property
97 | def transform_train(self):
98 |         return self._transform_train
99 |
100 | @property
101 | def transform_valid(self):
102 |         return self._transform_valid
103 |
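104 | 
105 | if __name__ == "__main__":
106 |     # minimal usage sketch: build CIFAR-style training transforms
107 |     # (the enabled options and their values below are illustrative)
108 |     t = TorchTransforms(
109 |         input_size=[32],
110 |         random_crop={'enable': 1, 'padding': 4},
111 |         random_horizontal_flip={'enable': 1, 'p': 0.5},
112 |         cutout={'enable': 1, 'n_holes': 1, 'length': 8},
113 |     )
114 |     print(t.transform_train)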
--------------------------------------------------------------------------------
/hyperbox/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/engine/__init__.py
--------------------------------------------------------------------------------
/hyperbox/engine/base_engine.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class BaseEngine:
4 | """
5 | Base class for all engines.
6 | """
7 |
8 | def __init__(
9 | self,
10 | trainer,
11 | model,
12 | datamodule,
13 | cfg
14 | ):
15 | self.trainer = trainer
16 | self.model = model
17 | self.datamodule = datamodule
18 | self.cfg = cfg
19 |
20 | def run(self):
21 | raise NotImplementedError
22 |
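23 | 
24 | # A minimal concrete engine, sketched for illustration (the name `FitEngine`
25 | # is assumed): it simply launches the Lightning fit loop with the objects
26 | # passed to the constructor.
27 | class FitEngine(BaseEngine):
28 |     def run(self):
29 |         return self.trainer.fit(self.model, datamodule=self.datamodule)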
--------------------------------------------------------------------------------
/hyperbox/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/losses/__init__.py
--------------------------------------------------------------------------------
/hyperbox/losses/ce_labelsmooth_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | __all__ = [
5 | 'CrossEntropyLabelSmooth',
6 | ]
7 |
8 |
9 | class CrossEntropyLabelSmooth(torch.nn.Module):
10 | def __init__(self, label_smoothing, weight=None):
11 | super(CrossEntropyLabelSmooth, self).__init__()
12 | self.label_smoothing = label_smoothing
13 | if weight is not None:
14 | self.weight = torch.tensor(weight)
15 | else:
16 | self.weight = None
17 |
18 | def forward(self, pred, target):
19 | if self.weight is not None:
20 | self.weight = self.weight.to(pred.device)
21 | logsoftmax = torch.nn.LogSoftmax(dim=1)
22 | n_classes = pred.size(1)
23 | # convert to one-hot
24 | target = torch.unsqueeze(target, 1)
25 | soft_target = torch.zeros_like(pred)
26 | soft_target.scatter_(1, target, 1)
27 | # label smoothing
28 | soft_target = soft_target * (1 - self.label_smoothing) + self.label_smoothing / n_classes
29 | if self.weight is not None:
30 | soft_target *= self.weight
31 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
32 |
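33 | 
34 | if __name__ == "__main__":
35 |     # minimal usage sketch with random logits (illustrative values)
36 |     criterion = CrossEntropyLabelSmooth(label_smoothing=0.1)
37 |     pred = torch.randn(4, 10)
38 |     target = torch.randint(0, 10, (4,))
39 |     print(criterion(pred, target))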
--------------------------------------------------------------------------------
/hyperbox/losses/focal_loss.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | __all__ = [
7 | 'FocalLoss',
8 | ]
9 | 
10 |
11 |
12 | class FocalLoss(torch.nn.Module):
13 | def __init__(self, alpha, gamma, num_classes, size_average=True):
14 | """focal_loss function: -α(1-yi)**γ *ce_loss(xi,yi)
15 | Args:
16 | alpha: class weight (default 0.25).
17 | When α is a 'list', it indicates the class-wise weights;
18 | When α is a constant,
19 | if in detection task, it indicates that the class-wise weights are[α, 1-α, 1-α, ...],
20 | the first class indicates the background
21 | if in classification task, it indicates that the class-wise weights are the same
22 | gamma: γ (default 2), focusing paramter smoothly adjusts the rate at which easy examples are down-weighted.
23 | num_classes: the number of classes
24 | size_average: (default 'mean'/'sum') specify the way to compute the loss value
25 | """
26 |
27 |         super(FocalLoss, self).__init__()
28 | self.size_average = size_average
29 | if isinstance(alpha,list):
30 | assert len(alpha)==num_classes
31 | alpha /= np.sum(alpha) # setting the value in range of [0, 1]
32 | # print("Focal loss alpha = {}, assign specific weights for each class".format(alpha))
33 | self.alpha = torch.Tensor(alpha)
34 | else:
35 | assert alpha<=1
36 | self.alpha = torch.zeros(num_classes)+0.00001
37 |
38 | # classification task
39 | # print("Focal loss alpha={}, the weight for each class is the same".format(alpha))
40 | self.alpha += alpha
41 |
42 |         # detection task: if α is a constant, down-weight the first class (the background class in detection)
43 |         # print(" --- Focal_loss alpha = {}, the background class will be down-weighted; use this in detection tasks --- ".format(alpha))
44 |         # self.alpha[0] += alpha
45 |         # self.alpha[1:] += (1-alpha)  # α finally becomes [α, 1-α, 1-α, ...], size: [num_classes]
46 | self.gamma = gamma
47 |
48 | def forward(self, predictions, labels):
49 | """
50 |         Compute the focal loss.
51 |         Args:
52 |             predictions: predicted logits, size [B,N,C] (detection) or [B,C] (classification); B: batch size, N: number of boxes, C: number of classes
53 |             labels: ground-truth class indices, size [B,N] or [B]
54 | return:
55 | loss
56 | """
57 | assert predictions.dim()==2 and labels.dim()==1
58 | preds = predictions.view(-1,predictions.size(-1)) # num*classes
59 | alpha = self.alpha.to(labels.device)
60 |
61 |         # softmax (rather than log_softmax) is used here because the softmax output is reused below (log_softmax followed by exp would also work)
62 | preds_softmax = F.softmax(preds, dim=1)
63 | preds_logsoft = torch.log(preds_softmax)
64 |
65 |         # implement nll_loss (cross entropy = log_softmax + nll)
66 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # num*1
67 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) # num*1
68 | alpha = alpha.gather(0,labels.view(-1)) # num
69 |
70 | # calc loss
71 |         # torch.pow((1-preds_softmax), self.gamma) is the (1-pt)**γ term of the focal loss
72 | # shape: num*1
73 | loss = -torch.mul( torch.pow((1-preds_softmax), self.gamma), preds_logsoft )
74 | # α * (1-pt)**γ * ce_loss
75 |         # shape: [1, num]
76 | loss = torch.mul(alpha, loss.t())
77 | del preds
78 | del alpha
79 | if self.size_average:
80 | loss = loss.mean()
81 | else:
82 | loss = loss.sum()
83 | return loss
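84 | 
85 | 
86 | if __name__ == "__main__":
87 |     # minimal usage sketch of the classification path ([B, C] logits); values are illustrative
88 |     criterion = FocalLoss(alpha=0.25, gamma=2, num_classes=10)
89 |     preds = torch.randn(8, 10)
90 |     labels = torch.randint(0, 10, (8,))
91 |     print(criterion(preds, labels))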
--------------------------------------------------------------------------------
/hyperbox/losses/kd_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 |
4 | def KDLoss(outputs, teacher_outputs, alpha=0.4, temperature=2):
5 | """
6 | Compute the knowledge-distillation (KD) loss given outputs, labels.
7 | "Hyperparameters": temperature and alpha
8 |     NOTE: PyTorch's KL divergence between the softmax outputs of teacher
9 |     and student expects the student input to be log probabilities! See Issue #2
10 | """
11 | T = temperature
12 | student = F.log_softmax(outputs/T, dim=-1)
13 | teacher = F.softmax(teacher_outputs/T, dim=-1)
14 |     KD_loss = F.kl_div(student, teacher, reduction='batchmean') * (alpha * T * T)  # 'batchmean' matches the KL definition
15 |
16 | return KD_loss
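17 | 
18 | 
19 | if __name__ == "__main__":
20 |     # minimal usage sketch: distill random 10-class teacher logits into a student (illustrative values)
21 |     import torch
22 |     student_logits = torch.randn(4, 10)
23 |     teacher_logits = torch.randn(4, 10)
24 |     print(KDLoss(student_logits, teacher_logits, alpha=0.4, temperature=2))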
--------------------------------------------------------------------------------
/hyperbox/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/models/__init__.py
--------------------------------------------------------------------------------
/hyperbox/mutables/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | 'spaces':
4 | ['Mutable', 'OperationSpace', 'InputSpace', 'MutableScope', 'ValueSpace']
5 | ||
6 | ||
7 | \/
8 | 'ops'
9 |     ['Conv1d', 'Conv2d', 'Conv3d', 'Linear', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
10 | ||
11 | ||
12 | \/
13 | 'layers'
14 | ['MBLayer', ...]
15 | '''
--------------------------------------------------------------------------------
/hyperbox/mutables/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers2d import *
--------------------------------------------------------------------------------
/hyperbox/mutables/masker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = [
6 | '__MASKERS__',
7 | 'BaseMasker',
8 | 'L1Masker'
9 | ]
10 |
11 |
12 | class BaseMasker:
13 | def __init__(self,):
14 | pass
15 |
16 |     def get_channel_sortedIdx(self, module, is_in_feat=True):
17 |         raise NotImplementedError
18 |
19 | def __call__(self, module, is_in_feat=True):
20 | return self.get_channel_sortedIdx(module, is_in_feat)
21 |
22 |
23 | class L1Masker(BaseMasker):
24 | def get_channel_sortedIdx(self, module, is_in_feat=True):
25 | num_dim = len(module.weight.shape)
26 | dim = list(range(num_dim))
27 | if num_dim > 1: # conv or linear
28 | if is_in_feat:
29 | del dim[1] # [0, 2, 3] for conv2d, [0, 2, 3, 4] for conv3, [0] for linear
30 | else:
31 | del dim[0] # [1, 2, 3]
32 | importance = torch.sum(torch.abs(module.weight.data), dim=dim)
33 | else: # BN
34 | importance = torch.abs(module.weight.data)
35 | sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
36 | return sorted_idx
37 |
38 |
39 | __MASKERS__ = {
40 | 'l1': L1Masker,
41 | }
42 |
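43 | 
44 | if __name__ == "__main__":
45 |     # minimal usage sketch: rank a Conv2d's output channels by the L1 norm of their filters
46 |     conv = nn.Conv2d(3, 8, 3)
47 |     masker = L1Masker()
48 |     print(masker(conv, is_in_feat=False))  # output-channel indices, most important first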
--------------------------------------------------------------------------------
/hyperbox/mutables/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_module import *
2 | from .conv import *
3 | from .linear import *
4 | from .batchnorm import *
5 | from .multihead_attention import *
6 | from .embedding import *
7 | from .layernorm import *
8 | from .groupnorm import *
9 |
--------------------------------------------------------------------------------
/hyperbox/mutables/ops/base_module.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from hyperbox.mutables.spaces import ValueSpace
7 | from hyperbox.utils.utils import hparams_wrapper
8 |
9 | __all__ = [
10 | 'FinegrainedModule'
11 | ]
12 |
13 |
14 | @hparams_wrapper
15 | class FinegrainedModule(nn.Module):
16 | def __init__(self, *args, **kwargs):
17 | super(FinegrainedModule, self).__init__()
18 | # The decorator @hparams_wrapper can automatically save all input arguments to
19 | # ``hparams`` attribute
20 | self.value_spaces = self.getValueSpaces(self.hparams)
21 |
22 | def getValueSpaces(self, kwargs):
23 | value_spaces = nn.ModuleDict()
24 | for key, value in kwargs.items():
25 | if key in ['weight', 'bias']:
26 | if hasattr(self, key): delattr(self, key)
27 | key = '_' + key
28 | if isinstance(value, ValueSpace):
29 | value_spaces[key] = value
30 | if value.index is not None:
31 | _v = value.candidates_original[value.index]
32 | elif len(value.mask) != 0:
33 | if isinstance(value.mask, torch.Tensor):
34 | index = value.mask.clone().detach().cpu().numpy().argmax()
35 | else:
36 | index = np.array(value.mask).argmax()
37 | _v = value.candidates_original[index]
38 | else:
39 | _v = value.max_value
40 | setattr(self, key, _v)
41 | else:
42 | setattr(self, key, value)
43 | return value_spaces
44 |
45 | def __deepcopy__(self, memo):
46 | try:
47 | new_instance = self.__class__(**self.hparams)
48 | device = next(self.parameters()).device
49 | new_instance.load_state_dict(self.state_dict())
50 | return new_instance.to(device)
51 | except Exception as e:
52 | print(str(e))
53 |
--------------------------------------------------------------------------------
/hyperbox/mutables/ops/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from hyperbox.mutables.ops.base_module import FinegrainedModule
9 | from hyperbox.mutables.ops.utils import is_searchable
10 | from hyperbox.mutables.spaces import ValueSpace
11 |
12 |
13 | class Embedding(nn.Embedding, FinegrainedModule):
14 | def __init__(
15 | self,
16 | num_embeddings: int,
17 | embedding_dim: Union[int, ValueSpace],
18 | device=None,
19 | dtype=None,
20 | *args, **kwargs,
21 | ) -> None:
22 | if isinstance(embedding_dim, ValueSpace):
23 | _embedding_dim = embedding_dim.max_value
24 | else:
25 | _embedding_dim = embedding_dim
26 | super(Embedding, self).__init__(
27 | num_embeddings, _embedding_dim, *args, **kwargs)
28 | self.is_search = self.isSearchEmbedding()
29 |
30 | def forward(self, input: Tensor) -> Tensor:
31 | if not self.is_search:
32 | x = F.embedding(input, self.weight, self.padding_idx, self.max_norm,
33 | self.norm_type, self.scale_grad_by_freq, self.sparse)
34 | else:
35 | embedding_dim = self.value_spaces["embedding_dim"].value
36 | weight = self.weight[:, :embedding_dim]
37 | x = F.embedding(input, weight, self.padding_idx, self.max_norm,
38 | self.norm_type, self.scale_grad_by_freq, self.sparse)
39 | return x
40 |
41 | def isSearchEmbedding(self):
42 |         '''Set the search flags for this module:
43 |             returns True when any value space is searchable;
44 |             search_embedding_dim records whether embedding_dim itself is searched
45 |         '''
46 | self.search_embedding_dim = False
47 | if all([not vs.is_search for vs in self.value_spaces.values()]):
48 | return False
49 | if is_searchable(getattr(self.value_spaces, 'embedding_dim', None)):
50 | self.search_embedding_dim = True
51 | return True
52 |
53 | @property
54 | def params(self):
55 | weight = self.weight
56 | if self.search_embedding_dim:
57 | weight = weight[:, :self.value_spaces['embedding_dim'].value]
58 | size = weight.numel()
59 | return size
60 |
61 |
62 | if __name__ == "__main__":
63 | from hyperbox.mutator import RandomMutator
64 | for i in range(10):
65 | embed_dim = ValueSpace([10, 20])
66 | embedding = Embedding(
67 | 10, embed_dim, max_norm=1.0, norm_type=2.0,
68 | scale_grad_by_freq=True, sparse=True)
69 | rm = RandomMutator(embedding)
70 | print(embedding)
71 | for j in range(4):
72 | rm.reset()
73 | x = torch.randint(0, 10, (2, 30))
74 | y = embedding(x)
75 | print(y.shape)
76 |
--------------------------------------------------------------------------------
/hyperbox/mutables/ops/groupnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from hyperbox.mutables.ops.base_module import FinegrainedModule
6 | from hyperbox.mutables.ops.utils import is_searchable
7 | from hyperbox.mutables.spaces import ValueSpace
8 |
9 |
10 | class GroupNorm(nn.GroupNorm, FinegrainedModule):
11 |
12 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
13 | _num_groups = num_groups.max_value if isinstance(num_groups, ValueSpace) else num_groups
14 | _num_channels = num_channels.max_value if isinstance(num_channels, ValueSpace) else num_channels
15 | super(GroupNorm, self).__init__(_num_groups, _num_channels, eps, affine)
16 | self.is_search = self.isSearchGroupNorm()
17 |
18 | def isSearchGroupNorm(self):
19 | if all([not vs.is_search for vs in self.value_spaces.values()]):
20 | return False
21 |
22 | if is_searchable(getattr(self.value_spaces, 'num_groups', None)):
23 | self.search_num_groups = True
24 | else:
25 | self.search_num_groups = False
26 | if is_searchable(getattr(self.value_spaces, 'num_channels', None)):
27 | self.search_num_channels = True
28 | else:
29 | self.search_num_channels = False
30 | return True
31 |
32 | def forward(self, input):
33 | if not self.is_search:
34 |             return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
35 | else:
36 | n_groups = self.value_spaces['num_groups'].value if self.search_num_groups else self.num_groups
37 | n_channels = self.value_spaces['num_channels'].value if self.search_num_channels else self.num_channels
38 | return F.group_norm(input, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)
39 |
40 | @property
41 | def params(self):
42 | weight = self.weight
43 | bias = self.bias
44 | if self.search_num_channels:
45 | weight = weight[:self.value_spaces['num_channels'].value]
46 | bias = bias[:self.value_spaces['num_channels'].value] if bias is not None else None
47 | parameters = [weight, bias]
48 | params = sum([p.numel() for p in parameters if p is not None])
49 | return params
50 |
51 |
52 | if __name__ == '__main__':
53 | from hyperbox.mutator import RandomMutator
54 | from hyperbox.mutables.ops import Conv2d
55 | input = torch.randn(20, 6, 10, 10)
56 |
57 | #### no search
58 | ##### Separate 6 channels into 3 groups
59 | print("no search")
60 | m = nn.GroupNorm(3, 6)
61 | output = m(input)
62 | print(output.shape)
63 | ##### Separate 6 channels into 6 groups (equivalent with InstanceNorm)
64 | m = nn.GroupNorm(6, 6)
65 | output = m(input)
66 | print(output.shape)
67 | ##### Put all 6 channels into a single group (equivalent with LayerNorm)
68 | m = nn.GroupNorm(1, 6)
69 | output = m(input)
70 | print(output.shape)
71 |
72 |
73 | #### search only n_groups
74 | print("search only n_groups")
75 | n_groups = ValueSpace([2, 3, 6, 1], key='num_groups')
76 | m = GroupNorm(n_groups, 6)
77 | rm = RandomMutator(m)
78 | for i in range(5):
79 | rm.reset()
80 | output = m(input)
81 | print(output.shape)
82 |
83 | # search only n_channels
84 | print("search only n_channels")
85 | n_channels = ValueSpace([12, 18, 24], key='num_channels')
86 | m = nn.Sequential(
87 | Conv2d(6, n_channels, 3, 1, 1),
88 | GroupNorm(3, n_channels)
89 | )
90 | rm = RandomMutator(m)
91 | for i in range(10):
92 | rm.reset()
93 | output = m(input)
94 | print(output.shape)
95 |
96 | # search n_groups and n_channels
97 | print("search n_groups and n_channels")
98 | n_groups = ValueSpace([2, 3, 6, 1], key='num_groups2')
99 | n_channels = ValueSpace([12, 18, 24], key='num_channels2')
100 | m = nn.Sequential(
101 | Conv2d(6, n_channels, 3, 1, 1),
102 | GroupNorm(n_groups, n_channels)
103 | )
104 | rm = RandomMutator(m)
105 | for i in range(10):
106 | rm.reset()
107 | output = m(input)
108 | print(output.shape)
--------------------------------------------------------------------------------
/hyperbox/mutables/ops/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..spaces import Mutable
7 |
8 | def is_searchable(
9 |     obj: Optional[Mutable]
10 | ):
11 | '''Check whether the Space obj is searchable'''
12 | if (obj is None) or (not obj.is_search):
13 | return False
14 | return True
15 |
16 | def sub_filter_start_end(kernel_size, sub_kernel_size):
17 | if isinstance(kernel_size, (list, tuple)):
18 | kernel_size = kernel_size[0]
19 | center = kernel_size // 2
20 | dev = sub_kernel_size // 2
21 | start, end = center - dev, center + dev + 1
22 | assert end - start == sub_kernel_size
23 | return start, end
24 |
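25 | # example: sub_filter_start_end(7, 3) == (2, 5), i.e. a 3x3 sub-kernel is the
26 | # centred slice weight[..., 2:5, 2:5] of a 7x7 kernel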
--------------------------------------------------------------------------------
/hyperbox/mutator/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_mutator import BaseMutator
2 | from .default_mutator import Mutator
3 | from .darts_mutator import DartsMutator
4 | from .evolution_mutator import EvolutionMutator
5 | from .enas_mutator import EnasMutator
6 | from .onehot_mutator import OnehotMutator
7 | from .random_mutator import RandomMutator
8 | from .sequential_mutator import SequentialMutator
9 | from .proxyless_mutator import ProxylessMutator
10 | from .fixed_mutator import apply_fixed_architecture, FixedArchitecture
11 | from .random_multiple_mutator import RandomMultipleMutator
12 | from .darts_multiple_mutator import DartsMultipleMutator
13 | from .repnas_mutator import RepnasMutator
14 |
--------------------------------------------------------------------------------
/hyperbox/mutator/fixed_mutator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 |
6 | import torch
7 |
8 | from hyperbox.mutables.spaces import MutableScope
9 |
10 | from hyperbox.mutator.default_mutator import Mutator
11 |
12 |
13 | class FixedArchitecture(Mutator):
14 |
15 | def __init__(self, model, fixed_arc, strict=True):
16 | """
17 | Initialize a fixed architecture mutator.
18 |
19 | Parameters
20 | ----------
21 | model : nn.Module
22 | A mutable network.
23 | fixed_arc : str or dict
24 | Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
25 | strict : bool
26 | Force everything that appears in `fixed_arc` to be used at least once.
27 | """
28 | super().__init__(model)
29 | self._fixed_arc = fixed_arc
30 |
31 | mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
32 | fixed_arc_keys = set(self._fixed_arc.keys())
33 | if fixed_arc_keys - mutable_keys:
34 | raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
35 | if mutable_keys - fixed_arc_keys:
36 | raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
37 |
38 | def sample_search(self):
39 | result = self._fixed_arc # for OperationSpace and InputSpace
40 | for mutable in self.mutables:
41 | mutable.mask = result[mutable.key] # for ValueSpace
42 | return result
43 |
44 | def sample_final(self):
45 | return self.sample_search()
46 |
47 |
48 | def _encode_tensor(data):
49 | if isinstance(data, list):
50 | if all(map(lambda o: isinstance(o, bool), data)):
51 | return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
52 | else:
53 | return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
54 | if isinstance(data, dict):
55 | return {k: _encode_tensor(v) for k, v in data.items()}
56 | return data
57 |
58 |
59 | def apply_fixed_architecture(model, fixed_arc_path):
60 | """
61 | Load architecture from `fixed_arc_path` and apply to model.
62 |
63 | Parameters
64 | ----------
65 | model : torch.nn.Module
66 | Model with mutables.
67 | fixed_arc_path : str
68 | Path to the JSON that stores the architecture.
69 |
70 | Returns
71 | -------
72 | FixedArchitecture
73 | """
74 |
75 | if isinstance(fixed_arc_path, str):
76 | with open(fixed_arc_path, "r") as f:
77 | fixed_arc = json.load(f)
78 | fixed_arc = _encode_tensor(fixed_arc)
79 | architecture = FixedArchitecture(model, fixed_arc)
80 | architecture.reset()
81 | return architecture
82 |
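83 | 
84 | if __name__ == "__main__":
85 |     # minimal usage sketch: fix a searchable channel width from a preloaded dict
86 |     # (the network, the key name 'width' and the mask below are illustrative)
87 |     import torch.nn as nn
88 |     from hyperbox.mutables.ops import Conv2d
89 |     from hyperbox.mutables.spaces import ValueSpace
90 | 
91 |     width = ValueSpace([8, 16, 32], key='width')
92 |     net = nn.Sequential(Conv2d(3, width, 3, 1, 1))
93 |     mutator = FixedArchitecture(net, {'width': torch.tensor([False, True, False])})
94 |     mutator.reset()
95 |     print(net(torch.randn(1, 3, 8, 8)).shape)  # expected: torch.Size([1, 16, 8, 8])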
--------------------------------------------------------------------------------
/hyperbox/mutator/onehot_mutator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
10 |
11 | from hyperbox.mutator.default_mutator import Mutator
12 | from .darts_mutator import DartsMutator
13 |
14 | __all__ = [
15 | 'OnehotMutator',
16 | ]
17 |
18 |
19 | class OnehotMutator(DartsMutator):
20 | def __init__(self, model, *args, **kwargs):
21 | super().__init__(model, *args, **kwargs)
22 |
23 | def sample_search(self):
24 | result = dict()
25 | for mutable in self.mutables:
26 | if isinstance(mutable, OperationSpace):
27 | result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
28 | mutable.mask = torch.zeros_like(result[mutable.key])
29 | mutable.mask[result[mutable.key].cpu().detach().numpy().argmax()] = 1
30 | elif isinstance(mutable, ValueSpace):
31 | result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
32 |                 mutable.mask.data = result[mutable.key].data  # reuse the sample drawn above instead of drawing a new one
33 | elif isinstance(mutable, InputSpace):
34 | result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
35 | mutable.mask = torch.zeros_like(result[mutable.key])
36 | mutable.mask[result[mutable.key].cpu().detach().numpy().argmax()] = 1
37 | return result
38 |
39 | def sample_final(self):
40 | return super().sample_final()
41 |
--------------------------------------------------------------------------------
/hyperbox/mutator/random_multiple_mutator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
6 | from hyperbox.mutator.default_mutator import Mutator
7 | from hyperbox.mutator.random_mutator import RandomMutator
8 |
9 |
10 | __all__ = [
11 | 'RandomMultipleMutator',
12 | ]
13 |
14 |
15 | class RandomMultipleMutator(RandomMutator):
16 | def __init__(self, model, single_path_prob=0, *args, **kwargs):
17 | super(RandomMultipleMutator, self).__init__(model)
18 | self.single_path_prob = single_path_prob
19 |
20 | def sample_search(self):
21 | if np.random.rand() < self.single_path_prob:
22 | return super(RandomMultipleMutator, self).sample_search()
23 | result = dict()
24 | for mutable in self.mutables:
25 | if isinstance(mutable, OperationSpace):
26 | crt_mask = torch.randint(high=2, size=(mutable.length,)).view(-1).bool()
27 | if crt_mask.sum() == 0:
28 | crt_mask[-1] = 1
29 | result[mutable.key] = crt_mask.view(-1).bool()
30 | mutable.mask = result[mutable.key].detach()
31 | elif isinstance(mutable, InputSpace):
32 | if mutable.n_chosen is None:
33 | result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
34 | else:
35 | perm = torch.randperm(mutable.n_candidates)
36 | mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
37 | result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable
38 | mutable.mask = result[mutable.key].detach()
39 | elif isinstance(mutable, ValueSpace):
40 | gen_index = torch.randint(high=mutable.length, size=(1, ))
41 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
42 | mutable.mask = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
43 | return result
44 |
45 | def sample_final(self):
46 | return self.sample_search()
47 |
--------------------------------------------------------------------------------
/hyperbox/mutator/random_mutator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
6 |
7 | from hyperbox.mutator.default_mutator import Mutator
8 |
9 | __all__ = [
10 | 'RandomMutator',
11 | ]
12 |
13 |
14 | class RandomMutator(Mutator):
15 | def __init__(self, model, *args, **kwargs):
16 | super().__init__(model)
17 |
18 | def sample_search(self):
19 | result = dict()
20 | for mutable in self.mutables:
21 | if isinstance(mutable, OperationSpace):
22 | gen_index = torch.randint(high=mutable.length, size=(1, ))
23 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
24 | mutable.mask = result[mutable.key].detach()
25 | elif isinstance(mutable, InputSpace):
26 | if mutable.n_chosen is None:
27 | result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
28 | else:
29 | perm = torch.randperm(mutable.n_candidates)
30 | mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
31 | result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable
32 | mutable.mask = result[mutable.key].detach()
33 | elif isinstance(mutable, ValueSpace):
34 | gen_index = torch.randint(high=mutable.length, size=(1, ))
35 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
36 | mutable.mask = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
37 | return result
38 |
39 | def sample_final(self):
40 | return self.sample_search()
41 |
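42 | 
43 | if __name__ == "__main__":
44 |     # minimal usage sketch: randomly resample a searchable channel width (illustrative network)
45 |     import torch.nn as nn
46 |     from hyperbox.mutables.ops import Conv2d
47 | 
48 |     width = ValueSpace([4, 8, 16], key='width')
49 |     net = nn.Sequential(Conv2d(3, width, 3, 1, 1))
50 |     rm = RandomMutator(net)
51 |     for _ in range(3):
52 |         rm.reset()
53 |         print(net(torch.randn(1, 3, 8, 8)).shape)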
--------------------------------------------------------------------------------
/hyperbox/mutator/sequential_mutator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
7 | from hyperbox.mutator.default_mutator import Mutator
8 |
9 |
10 | __all__ = [
11 | 'SequentialMutator',
12 | ]
13 |
14 | class SequentialMutator(Mutator):
15 | def __init__(self, model, start_idx: int):
16 | super().__init__(model)
17 | with open('./mutator/Track1_final_archs.json', 'r') as f:
18 | self.masks = json.load(f)
19 | self.crt_index = start_idx
20 | self.max_num = len(self.masks)
21 | assert self.crt_index > 0, 'Index should start at 1'
22 |
23 | def sample_search(self):
24 | result = dict()
25 | for mutable in self.mutables:
26 | if isinstance(mutable, OperationSpace):
27 | gen_index = torch.randint(high=mutable.length, size=(1, ))
28 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
29 | mutable.mask = torch.zeros_like(result[mutable.key])
30 | mutable.mask[result[mutable.key].detach().numpy().argmax()] = 1
31 | elif isinstance(mutable, InputSpace):
32 | if mutable.n_chosen is None:
33 | result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
34 | else:
35 | perm = torch.randperm(mutable.n_candidates)
36 | mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
37 | result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable
38 | mutable.mask = torch.zeros_like(result[mutable.key])
39 | mutable.mask[result[mutable.key].detach().numpy().argmax()] = 1
40 | elif isinstance(mutable, ValueSpace):
41 | index_choice = int(mutable.key.split('ValueSpace')[-1]) - 1
42 | value = self.masks[f'arch{self.crt_index}']['arch'].split('-')[index_choice]
43 | gen_index = np.argwhere(np.array(mutable.candidates)==int(value))[0][0]
44 | gen_index = torch.tensor(gen_index)
45 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
46 | mutable.mask = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
47 | if self.crt_index >= self.max_num:
48 | self.crt_index = 1
49 | else:
50 | self.crt_index += 1
51 | return result
52 |
53 | def sample_final(self):
54 | return self.sample_search()
55 |
--------------------------------------------------------------------------------
/hyperbox/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/__init__.py
--------------------------------------------------------------------------------
/hyperbox/networks/bnnas/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn_net import *
2 | from .bn_blocks import *
--------------------------------------------------------------------------------
/hyperbox/networks/bnnas/bn_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = [
5 | 'blocks_dict',
6 | 'InvertedResidual',
7 | 'conv_1x1_bn',
8 | 'conv_bn'
9 | ]
10 |
11 |
12 | blocks_dict = {
13 | 'k3r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 3, 1, stride, 3),
14 | 'k3r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 3, 1, stride, 6),
15 | 'k5r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 5, 2, stride, 3),
16 | 'k5r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 5, 2, stride, 6),
17 | 'k7r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 7, 3, stride, 3),
18 | 'k7r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 7, 3, stride, 6),
19 | }
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, ksize, padding, stride, expand_ratio):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | self.use_res_connect = self.stride == 1 and inp == oup
27 | self.expand_ratio = expand_ratio
28 |
29 | if expand_ratio == 1:
30 | self.conv = nn.Sequential(
31 | # dw
32 | nn.Conv2d(inp, inp, ksize, stride, padding, groups=inp, bias=False),
33 | nn.BatchNorm2d(inp),
34 | nn.ReLU6(inplace=True),
35 | # pw-linear
36 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
37 | nn.BatchNorm2d(oup),
38 | )
39 | else:
40 | self.conv = nn.Sequential(
41 | # pw
42 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
43 | nn.BatchNorm2d(inp * expand_ratio),
44 | nn.ReLU6(inplace=True),
45 | # dw
46 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, ksize, stride, padding, groups=inp * expand_ratio, bias=False),
47 | nn.BatchNorm2d(inp * expand_ratio),
48 | nn.ReLU6(inplace=True),
49 | # pw-linear
50 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
51 | nn.BatchNorm2d(oup),
52 | )
53 |
54 | def forward(self, x):
55 | if self.use_res_connect:
56 | return x + self.conv(x)
57 | else:
58 | return self.conv(x)
59 |
60 |
61 | def conv_bn(inp, oup, stride):
62 | return nn.Sequential(
63 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
64 | nn.BatchNorm2d(oup),
65 | nn.ReLU6(inplace=True)
66 | )
67 |
68 | def conv_1x1_bn(inp, oup):
69 | return nn.Sequential(
70 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
71 | nn.BatchNorm2d(oup),
72 | nn.ReLU6(inplace=True)
73 | )
74 |
75 |
--------------------------------------------------------------------------------
/hyperbox/networks/bnnas/ea_search.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from hyperbox.utils.utils import TorchTensorEncoder, save_arch_to_json
8 | from hyperbox.networks.bnnas import BNNet
9 | from hyperbox.mutator.ea_mutator import EAMutator
10 | from hyperbox.mutator.utils import NonDominatedSorting
11 |
12 |
13 | if __name__ == '__main__':
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | net = BNNet(num_classes=10, search_depth=False)
16 | # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_depth_adam0.001_sync_hete/2021-10-02_23-05-51/checkpoints/epoch=390_val/acc=27.8100.ckpt'
17 | # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_depth_adam0.001_sync_hete/2021-10-02_23-05-43/checkpoints/epoch=339_val/acc=43.1300.ckpt'
18 | # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_adam0.001_sync_hete/2021-10-06_06-29-41/checkpoints/epoch=392_val/acc=28.8200.ckpt'
19 | # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_adam0.001_sync_hete/2021-10-06_06-31-00/checkpoints/epoch=302_val/acc=44.0300.ckpt'
20 | # ckpt = '/datasets/xinhe/xinhe/hyperbox/logs/runs/bnnas_c10_all_bn/2021-10-22_04-03-29/checkpoints/epoch=14_val/acc=25.1800.ckpt'
21 | # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all20_bn20/2021-10-27_05-33-45/checkpoints/epoch=39_val/acc=39.3800.ckpt'
22 | # ckpt = torch.load(ckpt, map_location='cpu')
23 | # weights = {}
24 | # for key in ckpt['state_dict']:
25 | # weights[key.replace('network.', '')] = ckpt['state_dict'][key]
26 |     ckpt = None  # set to one of the checkpoint paths above to load trained weights
27 | # net.load_state_dict(weights)
28 | net = net.to(device)
29 |
30 | # method 1
31 | mode = 'all20_bn20'
32 | search_algorithm = 'cars'
33 | ea = EAMutator(net, num_population=50, algorithm=search_algorithm)
34 | # ea.load_ckpt('epoch2.pth')
35 | eval_func = lambda arch, network: net.bn_metrics().item()
36 | # eval_func = lambda arch, network: np.random.rand()
37 | # ea.search(20, eval_func, verbose=True, filling_history=True)
38 | ea.search(20, eval_func, eval_kwargs={}, verbose=True, filling_history=True)
39 | size = np.array([pool['size'] for pool in ea.history.values()])
40 | metric = np.array([pool['metric'] for pool in ea.history.values()])
41 | indices = np.argsort(size)
42 | size, metric = size[indices], metric[indices]
43 | epoch = ea.crt_epoch
44 | pareto_lists = NonDominatedSorting(np.vstack( (size.reshape(-1), 1/metric.reshape(-1)) ))
45 | pareto_indices = pareto_lists[0] # e.g., [75, 87, 113, 201, 205]
46 | ea.plot_pareto_fronts(
47 | size, metric, pareto_indices, 'model size (MB)', 'BN-based metric',
48 | figname=f'{mode}_pareto_searchepoch{epoch}_{search_algorithm}.pdf'
49 | )
50 |
51 |     # save pareto-front architectures as mask json files
52 |     path = ckpt.split('checkpoints')[0] if ckpt else './'
53 |     path = os.path.join(path, 'mask_json')
54 |     os.makedirs(path, exist_ok=True)
55 | for key, arch_info in ea.pareto_fronts.items():
56 | arch = arch_info['arch']
57 | arch_path = os.path.join(path, f'arch_{key}.json')
58 | save_arch_to_json(arch, arch_path)
59 | # method 2
60 | # ea = EAMutator(net, num_population=50, algorithm='top')
61 |
62 | # ea.start_evolve = True
63 | # ea.init_population(ea.init_population_mode)
64 | # for i in range(30):
65 | # print(f"\n\nsearch epoch {i}")
66 | # if i>0:
67 | # ea.evolve()
68 | # for j, arch in enumerate(ea.population.values()):
69 | # # ea.reset()
70 | # ea.reset_cache_mask(arch['arch'])
71 | # arch['metric'] = net.bn_metrics().item()
72 | # metrics = [pool['metric'] for pool in ea.population.values()]
73 | # metrics.sort()
74 | # print('pop',metrics)
75 | # metrics = [pool['metric'] for pool in ea.pareto_fronts.values()]
76 | # metrics.sort()
77 | # print('pare',metrics)
78 | # visited = []
79 | # for arch in ea.pareto_fronts.values():
80 | # if arch['arch_code'] not in visited: visited.append(arch['arch_code'])
81 | # print(len(visited))
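Without a trained checkpoint, the search loop can still be exercised end to end; the sketch below reuses only the calls already made in this script and substitutes a random metric for `net.bn_metrics()`:

```python
import numpy as np

from hyperbox.mutator.ea_mutator import EAMutator
from hyperbox.networks.bnnas import BNNet

net = BNNet(num_classes=10, search_depth=False)
ea = EAMutator(net, num_population=10, algorithm='cars')
eval_func = lambda arch, network: float(np.random.rand())  # stand-in for net.bn_metrics()
ea.search(2, eval_func, eval_kwargs={}, verbose=False, filling_history=True)
# every evaluated architecture lands in ea.history with its 'size' and 'metric'
best = max(ea.history.values(), key=lambda pool: pool['metric'])
print(best['metric'])
```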
--------------------------------------------------------------------------------
/hyperbox/networks/darts/__init__.py:
--------------------------------------------------------------------------------
1 | from .darts_network import *
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/enas/__init__.py:
--------------------------------------------------------------------------------
1 | from .enas_network import *
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/gpt/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt2 import *
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/mobilenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .mobile_net import MobileNet
2 | from .mobile3d_net import Mobile3DNet
--------------------------------------------------------------------------------
/hyperbox/networks/mobilenet/mobile_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def get_parameters(model, keys=None, mode='include'):
5 | if keys is None:
6 | for name, param in model.named_parameters():
7 | yield param
8 | elif mode == 'include':
9 | for name, param in model.named_parameters():
10 | flag = False
11 | for key in keys:
12 | if key in name:
13 | flag = True
14 | break
15 | if flag:
16 | yield param
17 | elif mode == 'exclude':
18 | for name, param in model.named_parameters():
19 | flag = True
20 | for key in keys:
21 | if key in name:
22 | flag = False
23 | break
24 | if flag:
25 | yield param
26 | else:
27 |         raise ValueError('unsupported mode: %s' % mode)
28 |
29 |
30 | def get_same_padding(kernel_size):
31 | if isinstance(kernel_size, tuple):
32 | assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
33 | p1 = get_same_padding(kernel_size[0])
34 | p2 = get_same_padding(kernel_size[1])
35 | return p1, p2
36 | assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
37 | assert kernel_size % 2 > 0, 'kernel size should be odd number'
38 | return kernel_size // 2
39 |
40 | def build_activation(act_func, inplace=True):
41 | if act_func == 'relu':
42 | return nn.ReLU(inplace=inplace)
43 | elif act_func == 'relu6':
44 | return nn.ReLU6(inplace=inplace)
45 | elif act_func == 'tanh':
46 | return nn.Tanh()
47 | elif act_func == 'sigmoid':
48 | return nn.Sigmoid()
49 | elif act_func is None:
50 | return None
51 | else:
52 |         raise ValueError('unsupported act_func: %s' % act_func)
53 |
54 |
55 | def make_divisible(v, divisor, min_val=None):
56 | """
57 | This function is taken from the original tf repo.
58 |     It ensures that all layers have a channel number that is divisible by ``divisor`` (8 in MobileNet)
59 | It can be seen here:
60 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
61 | """
62 | if min_val is None:
63 | min_val = divisor
64 | new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
65 | # Make sure that round down does not go down by more than 10%.
66 | if new_v < 0.9 * v:
67 | new_v += divisor
68 | return new_v
69 |
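Quick sanity checks for the helpers above; note that the `keys` filter of `get_parameters` matches by substring of the parameter name:

```python
import torch.nn as nn

from hyperbox.networks.mobilenet.mobile_utils import (
    build_activation, get_parameters, get_same_padding, make_divisible)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
bn_params = list(get_parameters(model, keys=['1.'], mode='include'))
print(len(bn_params))         # 2: the BatchNorm weight and bias
print(get_same_padding(5))    # 2: keeps spatial size for stride-1 convs
print(make_divisible(37, 8))  # 40: nearest multiple of 8, never >10% below the input
print(build_activation('relu6'))  # ReLU6(inplace=True)
```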
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/__init__.py:
--------------------------------------------------------------------------------
1 | from .nasbench101 import NASBench101Network
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/base_ops.py:
--------------------------------------------------------------------------------
1 | """Base operations used by the modules in this search space."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | class ConvBnRelu(nn.Module):
12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
13 | super(ConvBnRelu, self).__init__()
14 |
15 | self.conv_bn_relu = nn.Sequential(
16 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
17 | nn.BatchNorm2d(out_channels),
18 | nn.ReLU(inplace=True)
19 | )
20 |
21 | def forward(self, x):
22 | return self.conv_bn_relu(x)
23 |
24 | class Conv3x3BnRelu(nn.Module):
25 | """3x3 convolution with batch norm and ReLU activation."""
26 | def __init__(self, in_channels, out_channels):
27 | super(Conv3x3BnRelu, self).__init__()
28 |
29 | self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
30 |
31 | def forward(self, x):
32 | x = self.conv3x3(x)
33 | return x
34 |
35 | class Conv1x1BnRelu(nn.Module):
36 | """1x1 convolution with batch norm and ReLU activation."""
37 | def __init__(self, in_channels, out_channels):
38 | super(Conv1x1BnRelu, self).__init__()
39 |
40 | self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0)
41 |
42 | def forward(self, x):
43 | x = self.conv1x1(x)
44 | return x
45 |
46 | class MaxPool3x3(nn.Module):
47 | """3x3 max pool with no subsampling."""
48 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
49 | super(MaxPool3x3, self).__init__()
50 |
51 | self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)
52 |
53 | def forward(self, x):
54 | x = self.maxpool(x)
55 | return x
56 |
57 | # Commas should not be used in op names.
58 | OP_MAP = {
59 |     'conv3x3-bn-relu': lambda cin, cout: Conv3x3BnRelu(cin, cout),
60 |     'conv1x1-bn-relu': lambda cin, cout: Conv1x1BnRelu(cin, cout),
61 |     'maxpool3x3': lambda cin, cout: MaxPool3x3(cin, cout)
62 | }
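All three entries share the `(cin, cout)` call signature and preserve spatial resolution, which is what lets cell-building code treat them uniformly:

```python
import torch

from hyperbox.networks.nasbench101.base_ops import OP_MAP

x = torch.rand(1, 16, 8, 8)
for name, builder in OP_MAP.items():
    op = builder(16, 32)      # MaxPool3x3 ignores the channel arguments
    print(name, op(x).shape)  # convs map 16 -> 32 channels; pooling keeps 16
```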
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
2 | from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
3 | from .query import query_nb101_trial_stats
4 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/constants.py:
--------------------------------------------------------------------------------
1 | INPUT = 'input'
2 | OUTPUT = 'output'
3 | CONV3X3_BN_RELU = 'conv3x3-bn-relu'
4 | CONV1X1_BN_RELU = 'conv1x1-bn-relu'
5 | MAXPOOL3X3 = 'maxpool3x3'
6 |
7 |
8 | LABEL2ID = {
9 | INPUT: -1,
10 | OUTPUT: -2,
11 | CONV3X3_BN_RELU: 0,
12 | CONV1X1_BN_RELU: 1,
13 | MAXPOOL3X3: 2
14 | }
15 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/db_gen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from tqdm import tqdm
4 | from nasbench import api # pylint: disable=import-error
5 |
6 | from hyperbox.networks.nasbench101.db_gen.model import db, Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
7 | from hyperbox.networks.nasbench101.db_gen.graph_util import nasbench_format_to_architecture_repr, hash_module
8 |
9 |
10 | def main(args):
11 | nasbench = api.NASBench(args.inputFile)
12 | with db:
13 | db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats])
14 | for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'):
15 | metadata, metrics = nasbench.get_metrics_from_hash(hashval)
16 | num_vertices, architecture = nasbench_format_to_architecture_repr(
17 | metadata['module_adjacency'], metadata['module_operations'])
18 | assert hashval == hash_module(architecture, num_vertices)
19 | for epochs in [4, 12, 36, 108]:
20 | trial_config = Nb101TrialConfig.create(
21 | arch=architecture,
22 | num_vertices=num_vertices,
23 | hash=hashval,
24 | num_epochs=epochs
25 | )
26 |
27 | for seed in range(3):
28 | cur = metrics[epochs][seed]
29 | trial = Nb101TrialStats.create(
30 | config=trial_config,
31 | train_acc=cur['final_train_accuracy'] * 100,
32 | valid_acc=cur['final_validation_accuracy'] * 100,
33 | test_acc=cur['final_test_accuracy'] * 100,
34 | parameters=metadata['trainable_parameters'] / 1e6,
35 | training_time=cur['final_training_time'] * 60
36 | )
37 | for t in ['halfway', 'final']:
38 | Nb101IntermediateStats.create(
39 | trial=trial,
40 | current_epoch=epochs // 2 if t == 'halfway' else epochs,
41 | training_time=cur[t + '_training_time'],
42 | train_acc=cur[t + '_train_accuracy'] * 100,
43 | valid_acc=cur[t + '_validation_accuracy'] * 100,
44 | test_acc=cur[t + '_test_accuracy'] * 100
45 | )
46 |
47 |
48 | if __name__ == '__main__':
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument('--inputFile',
51 | help='Path to the file to be converted, e.g., nasbench_full.tfrecord')
52 | args = parser.parse_args()
53 | main(args)
54 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/graph_util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 |
3 | import numpy as np
4 |
5 | from hyperbox.networks.nasbench101.db_gen.constants import INPUT, LABEL2ID, OUTPUT
6 |
7 |
8 | def _labeling_from_architecture(architecture, vertices):
9 | return [INPUT] + [architecture['op{}'.format(i)] for i in range(1, vertices - 1)] + [OUTPUT]
10 |
11 |
12 | def _adjacency_matrix_from_architecture(architecture, vertices):
13 |     matrix = np.zeros((vertices, vertices), dtype=bool)
14 | for i in range(1, vertices):
15 | for k in architecture['input{}'.format(i)]:
16 | matrix[k, i] = 1
17 | return matrix
18 |
19 |
20 | def nasbench_format_to_architecture_repr(adjacency_matrix, labeling):
21 |     """
22 |     Convert the NAS-Bench-101 format (adjacency matrix plus labeling) into the
23 |     architecture dict representation used in this repo. Imported from the NAS-Bench-101 repo.
24 |
25 | Parameters
26 | ----------
27 | adjacency_matrix : np.ndarray
28 | A 2D array of shape NxN, where N is the number of vertices.
29 | ``matrix[u][v]`` is 1 if there is a direct edge from `u` to `v`,
30 | otherwise it will be 0.
31 | labeling : list of str
32 | A list of str that starts with input and ends with output. The intermediate
33 | nodes are chosen from candidate operators.
34 |
35 | Returns
36 | -------
37 |     tuple of (int, dict)
38 |         Number of vertices and the converted architecture.
39 | """
40 | num_vertices = adjacency_matrix.shape[0]
41 | assert len(labeling) == num_vertices
42 | architecture = {}
43 | for i in range(1, num_vertices - 1):
44 | architecture['op{}'.format(i)] = labeling[i]
45 | assert labeling[i] not in [INPUT, OUTPUT]
46 | for i in range(1, num_vertices):
47 | architecture['input{}'.format(i)] = [k for k in range(i) if adjacency_matrix[k, i]]
48 | return num_vertices, architecture
49 |
50 |
51 | def infer_num_vertices(architecture):
52 | """
53 | Infer number of vertices from an architecture dict.
54 |
55 | Parameters
56 | ----------
57 | architecture : dict
58 | Architecture in NNI format.
59 |
60 | Returns
61 | -------
62 | int
63 | Number of vertices.
64 | """
65 | op_keys = set([k for k in architecture.keys() if k.startswith('op')])
66 | intermediate_vertices = len(op_keys)
67 | assert op_keys == {'op{}'.format(i) for i in range(1, intermediate_vertices + 1)}
68 | return intermediate_vertices + 2
69 |
70 |
71 | def hash_module(architecture, vertices):
72 | """
73 | Computes a graph-invariance MD5 hash of the matrix and label pair.
74 | This snippet is modified from code in NAS-Bench-101 repo.
75 |
76 | Parameters
77 | ----------
78 |     architecture : dict
79 |         Architecture in the dict format described in this module.
80 |     vertices : int
81 |         Number of vertices in the graph.
82 |
83 | Returns
84 | -------
85 | str
86 | MD5 hash of the matrix and labeling.
87 | """
88 | labeling = _labeling_from_architecture(architecture, vertices)
89 | labeling = [LABEL2ID[t] for t in labeling]
90 |     matrix = _adjacency_matrix_from_architecture(architecture, vertices)
91 | in_edges = np.sum(matrix, axis=0).tolist()
92 | out_edges = np.sum(matrix, axis=1).tolist()
93 |
94 | assert len(in_edges) == len(out_edges) == len(labeling)
95 | hashes = list(zip(out_edges, in_edges, labeling))
96 | hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
97 | # Computing this up to the diameter is probably sufficient but since the
98 | # operation is fast, it is okay to repeat more times.
99 | for _ in range(vertices):
100 | new_hashes = []
101 | for v in range(vertices):
102 | in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
103 | out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
104 | new_hashes.append(hashlib.md5(
105 | (''.join(sorted(in_neighbors)) + '|' +
106 | ''.join(sorted(out_neighbors)) + '|' +
107 | hashes[v]).encode('utf-8')).hexdigest())
108 | hashes = new_hashes
109 | fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()
110 |
111 | return fingerprint
112 |
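A small worked example in the dict format consumed by `hash_module`; the hash is graph-invariant, so isomorphic cells collide by construction:

```python
from hyperbox.networks.nasbench101.db_gen.constants import CONV3X3_BN_RELU, MAXPOOL3X3
from hyperbox.networks.nasbench101.db_gen.graph_util import hash_module, infer_num_vertices

arch = {
    'op1': CONV3X3_BN_RELU,
    'op2': MAXPOOL3X3,
    'input1': [0],     # node 1 reads from the input node
    'input2': [0, 1],  # node 2 reads from the input node and node 1
    'input3': [2],     # the output node reads from node 2
}
num_vertices = infer_num_vertices(arch)  # 4: input + 2 ops + output
print(hash_module(arch, num_vertices))   # 32-character graph-invariant MD5 fingerprint
```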
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import functools
3 | import json
4 | from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
5 | from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
6 |
7 |
8 |
9 | json_dumps = functools.partial(json.dumps, sort_keys=True)
10 | DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.hyperbox/nasbenchmark"))
11 | db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
12 |
13 |
14 | class Nb101TrialConfig(Model):
15 | """
16 | Trial config for NAS-Bench-101.
17 |
18 | Attributes
19 | ----------
20 | arch : dict
21 | A dict with keys ``op1``, ``op2``, ... and ``input1``, ``input2``, ... Vertices are
22 |         enumerated from 0. Since node 0 is the input node, it is skipped in this dict. Each ``op``
23 | is one of :const:`nni.nas.benchmark.nasbench101.CONV3X3_BN_RELU`,
24 | :const:`nni.nas.benchmark.nasbench101.CONV1X1_BN_RELU`, and :const:`nni.nas.benchmark.nasbench101.MAXPOOL3X3`.
25 | Each ``input`` is a list of previous nodes. For example ``input5`` can be ``[0, 1, 3]``.
26 | num_vertices : int
27 | Number of vertices (nodes) in one cell. Should be less than or equal to 7 in default setup.
28 | hash : str
29 | Graph-invariant MD5 string for this architecture.
30 | num_epochs : int
31 | Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
32 | """
33 |
34 | arch = JSONField(index=True)
35 | num_vertices = IntegerField(index=True)
36 | hash = CharField(max_length=64, index=True)
37 | num_epochs = IntegerField(index=True)
38 |
39 | class Meta:
40 | database = db
41 |
42 |
43 | class Nb101TrialStats(Model):
44 | """
45 | Computation statistics for NAS-Bench-101. Each corresponds to one trial.
46 | Each config has multiple trials with different random seeds, but unfortunately seed for each trial is unavailable.
47 | NAS-Bench-101 trains and evaluates on CIFAR-10 by default. The original training set is divided into
48 | 40k training images and 10k validation images, and the original validation set is used for test only.
49 |
50 | Attributes
51 | ----------
52 | config : Nb101TrialConfig
53 | Setup for this trial data.
54 | train_acc : float
55 | Final accuracy on training data, ranging from 0 to 100.
56 | valid_acc : float
57 | Final accuracy on validation data, ranging from 0 to 100.
58 | test_acc : float
59 | Final accuracy on test data, ranging from 0 to 100.
60 | parameters : float
61 | Number of trainable parameters in million.
62 | training_time : float
63 | Duration of training in seconds.
64 | """
65 | config = ForeignKeyField(Nb101TrialConfig, backref='trial_stats', index=True)
66 | train_acc = FloatField()
67 | valid_acc = FloatField()
68 | test_acc = FloatField()
69 | parameters = FloatField()
70 | training_time = FloatField()
71 |
72 | class Meta:
73 | database = db
74 |
75 |
76 | class Nb101IntermediateStats(Model):
77 | """
78 | Intermediate statistics for NAS-Bench-101.
79 |
80 | Attributes
81 | ----------
82 | trial : Nb101TrialStats
83 | The exact trial where the intermediate result is produced.
84 | current_epoch : int
85 | Elapsed epochs when evaluation is done.
86 | train_acc : float
87 | Intermediate accuracy on training data, ranging from 0 to 100.
88 | valid_acc : float
89 | Intermediate accuracy on validation data, ranging from 0 to 100.
90 | test_acc : float
91 | Intermediate accuracy on test data, ranging from 0 to 100.
92 | training_time : float
93 | Time elapsed in seconds.
94 | """
95 |
96 | trial = ForeignKeyField(Nb101TrialStats, backref='intermediates', index=True)
97 | current_epoch = IntegerField(index=True)
98 | train_acc = FloatField()
99 | valid_acc = FloatField()
100 | test_acc = FloatField()
101 | training_time = FloatField()
102 |
103 | class Meta:
104 | database = db
105 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/db_gen/query.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from peewee import fn
4 | from playhouse.shortcuts import model_to_dict
5 | from hyperbox.networks.nasbench101.db_gen.model import Nb101TrialStats, Nb101TrialConfig
6 | from hyperbox.networks.nasbench101.db_gen.graph_util import hash_module, infer_num_vertices
7 |
8 |
9 | def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
10 | """
11 | Query trial stats of NAS-Bench-101 given conditions.
12 |
13 | Parameters
14 | ----------
15 | arch : dict or None
16 | If a dict, it is in the format that is described in
17 | :class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
18 | matched will be returned. If none, architecture will be a wildcard.
19 | num_epochs : int or None
20 | If int, matching results will be returned. Otherwise a wildcard.
21 | isomorphism : boolean
22 | Whether to match essentially-same architecture, i.e., architecture with the
23 | same graph-invariant hash value.
24 | reduction : str or None
25 | If 'none' or None, all trial stats will be returned directly.
26 | If 'mean', fields in trial stats will be averaged given the same trial config.
27 |
28 | Returns
29 | -------
30 | generator of dict
31 | A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
32 | where each of them has been converted into a dict.
33 | """
34 | fields = []
35 | if reduction == 'none':
36 | reduction = None
37 | if reduction == 'mean':
38 | for field_name in Nb101TrialStats._meta.sorted_field_names:
39 | if field_name not in ['id', 'config']:
40 | fields.append(fn.AVG(getattr(Nb101TrialStats, field_name)).alias(field_name))
41 | elif reduction is None:
42 | fields.append(Nb101TrialStats)
43 | else:
44 | raise ValueError('Unsupported reduction: \'%s\'' % reduction)
45 | query = Nb101TrialStats.select(*fields, Nb101TrialConfig).join(Nb101TrialConfig)
46 | conditions = []
47 | if arch is not None:
48 | if isomorphism:
49 | num_vertices = infer_num_vertices(arch)
50 | conditions.append(Nb101TrialConfig.hash == hash_module(arch, num_vertices))
51 | else:
52 | conditions.append(Nb101TrialConfig.arch == arch)
53 | if num_epochs is not None:
54 | conditions.append(Nb101TrialConfig.num_epochs == num_epochs)
55 | if conditions:
56 | query = query.where(functools.reduce(lambda a, b: a & b, conditions))
57 | if reduction is not None:
58 | query = query.group_by(Nb101TrialStats.config)
59 | for k in query:
60 | yield model_to_dict(k)
61 |
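A minimal query sketch, assuming `nasbench101.db` has already been generated under `NASBENCHMARK_DIR`; `reduction='mean'` averages the three seeds of each config:

```python
from hyperbox.networks.nasbench101.db_gen import query_nb101_trial_stats

# wildcard architecture, 108-epoch trials, seed-averaged
for stats in query_nb101_trial_stats(arch=None, num_epochs=108, reduction='mean'):
    print(stats['config']['hash'], stats['valid_acc'])
    break  # only show the first record
```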
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench101/readme.md:
--------------------------------------------------------------------------------
1 | > NASBench101 does not support weight-sharing training
2 |
3 | 1. Install the `nasbench` package
4 |
5 | - for tensorflow 1.x
6 | ```bash
7 | pip install git+https://github.com/google-research/nasbench.git
8 | ```
9 |
10 | - for tensorflow 2.x
11 | ```bash
12 | pip install git+https://github.com/gabikadlecova/nasbench.git
13 | pip install protobuf==3.20.0
14 | ```
15 |
16 | 2. Download the tfrecord files
17 | 
18 | ```bash
19 | wget https://storage.googleapis.com/nasbench/nasbench_full.tfrecord
20 | wget https://storage.googleapis.com/nasbench/nasbench_only_108.tfrecord
21 | ```
22 |
23 | - `nasbench_full.tfrecord` includes results for epochs 4, 12, 36, and 108
24 | - `nasbench_only_108.tfrecord` includes results for epoch 108 only
25 |
26 | 3. Convert the tfrecord into a database
27 | 
28 | ```bash
29 | python db_gen.py --inputFile /path/to/nasbench_full.tfrecord
30 | ```
31 |
32 | # Thanks
33 | - https://github.com/romulus0914/NASBench-PyTorch/blob/master/main.py
34 | - https://nni.readthedocs.io/en/v1.7/NAS/BenchmarksExample.html
35 | - https://github.com/microsoft/nni/blob/v1.7/src/sdk/pynni/nni/nas/benchmarks/nasbench101/__init__.py
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench201/__init__.py:
--------------------------------------------------------------------------------
1 | from .nasbench201 import NASBench201Network
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench201/db_gen/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3, PRIMITIVES
2 | from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
3 | from .query import query_nb201_trial_stats
4 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench201/db_gen/constants.py:
--------------------------------------------------------------------------------
1 | NONE = "none"
2 | SKIP_CONNECT = "skip_connect"
3 | CONV_1X1 = "conv_1x1"
4 | CONV_3X3 = "conv_3x3"
5 | AVG_POOL_3X3 = "avg_pool_3x3"
6 |
7 | PRIMITIVES = [NONE, AVG_POOL_3X3, CONV_3X3, CONV_1X1, SKIP_CONNECT]
8 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench201/db_gen/query.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from peewee import fn
4 | from playhouse.shortcuts import model_to_dict
5 | from hyperbox.networks.nasbench201.db_gen.model import Nb201TrialStats, Nb201TrialConfig
6 |
7 |
8 | def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
9 | """
10 | Query trial stats of NAS-Bench-201 given conditions.
11 | Parameters
12 | ----------
13 | arch : dict or None
14 | If a dict, it is in the format that is described in
15 | :class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
16 | matched will be returned. If none, all architectures in the database will be matched.
17 | num_epochs : int or None
18 | If int, matching results will be returned. Otherwise a wildcard.
19 | dataset : str or None
20 | If specified, can be one of the dataset available in :class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`.
21 | Otherwise a wildcard.
22 | reduction : str or None
23 | If 'none' or None, all trial stats will be returned directly.
24 | If 'mean', fields in trial stats will be averaged given the same trial config.
25 | include_intermediates : boolean
26 | If true, intermediate results will be returned.
27 | Returns
28 | -------
29 | generator of dict
30 | A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
31 | where each of them has been converted into a dict.
32 | """
33 | fields = []
34 | if reduction == 'none':
35 | reduction = None
36 | if reduction == 'mean':
37 | for field_name in Nb201TrialStats._meta.sorted_field_names:
38 | if field_name not in ['id', 'config', 'seed']:
39 | fields.append(fn.AVG(getattr(Nb201TrialStats, field_name)).alias(field_name))
40 | elif reduction is None:
41 | fields.append(Nb201TrialStats)
42 | else:
43 | raise ValueError('Unsupported reduction: \'%s\'' % reduction)
44 | query = Nb201TrialStats.select(*fields, Nb201TrialConfig).join(Nb201TrialConfig)
45 | conditions = []
46 | if arch is not None:
47 | conditions.append(Nb201TrialConfig.arch == arch)
48 | if num_epochs is not None:
49 | conditions.append(Nb201TrialConfig.num_epochs == num_epochs)
50 | if dataset is not None:
51 | conditions.append(Nb201TrialConfig.dataset == dataset)
52 | if conditions:
53 | query = query.where(functools.reduce(lambda a, b: a & b, conditions))
54 | if reduction is not None:
55 | query = query.group_by(Nb201TrialStats.config)
56 | for trial in query:
57 | if include_intermediates:
58 | data = model_to_dict(trial)
59 | # exclude 'trial' from intermediates as it is already available in data
60 | data['intermediates'] = [
61 | {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
62 | ]
63 | yield data
64 | else:
65 | yield model_to_dict(trial)
66 |
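Usage mirrors the NAS-Bench-101 query; a sketch assuming the database generated by `gen_nasbench201_data.sh` is in place (epoch and dataset values follow the NB201 schema, e.g. 12 or 200 epochs and 'cifar100'):

```python
from hyperbox.networks.nasbench201.db_gen import query_nb201_trial_stats

# all architectures, 200-epoch trials on cifar100, averaged over seeds
for stats in query_nb201_trial_stats(None, 200, 'cifar100', reduction='mean'):
    print(stats['config']['arch'], stats['valid_acc'])  # field names per Nb201TrialStats
    break
```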
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench201/gen_nasbench201_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | pip install peewee
5 | pip install gdown
6 |
7 | if [ -z "${NASBENCHMARK_DIR}" ]; then
8 | NASBENCHMARK_DIR=~/.hyperbox/nasbench201
9 | fi
10 |
11 | echo "Downloading NAS-Bench-201..."
12 | if [ -f "nb201.pth" ]; then
13 | echo "nb201.pth found. Skip download."
14 | else
15 | gdown https://drive.google.com/uc\?id\=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_ -O nb201.pth
16 | fi
17 |
18 | echo "Generating database..."
19 | rm -f ${NASBENCHMARK_DIR}/nasbench201.db ${NASBENCHMARK_DIR}/nasbench201.db-journal
20 | mkdir -p ${NASBENCHMARK_DIR}
21 | python hyperbox/networks/nasbench201/db_gen/db_gen.py nb201.pth
22 | # python hyperbox.networks.nasbench201.db_gen.db_gen nb201.pth
23 | # rm -f nb201.pth
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench301/__init__.py:
--------------------------------------------------------------------------------
1 | from .nasbench301 import NASBench301Network
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench301/install/readme.md:
--------------------------------------------------------------------------------
1 | ## NASBench301 Installation Guidance
2 |
3 | ## Environments
4 |
5 | - ubuntu16.04
6 | - cuda11.1
7 | - torch=1.8+cu111
8 | - python=3.7.0
9 |
10 | For CUDA 10.2, run:
11 | 
12 | ```bash
13 | cat requirements_cu102.txt | xargs -n 1 -L 1 pip install
14 | ```
15 |
16 | For CUDA 11.1, run:
17 | ```bash
18 | cat requirements_cu111.txt | xargs -n 1 -L 1 pip install
19 | ```
20 |
21 | ## Install NB301
22 |
23 | ```
24 | $ git clone https://github.com/automl/nasbench301
25 | $ cd nasbench301
26 | $ pip install .
27 | ```
28 |
29 | ## Download weights
30 |
31 | v0.9: [nasbench301_models_v0.9_zip](https://figshare.com/articles/software/nasbench301_models_v0_9_zip/12962432)
32 |
33 | v1.0: [nasbench301_models_v1_0_zip](https://figshare.com/articles/software/nasbench301_models_v1_0_zip/13061510)
34 |
35 | Run the official demo:
36 | 
37 | ```bash
38 | cd nasbench301/nasbench301
39 | unzip nasbench301_models_v1.0.zip
40 | mv nb_models nb_models_1.0
41 | python example.py
42 | ```
43 |
44 | Output:
45 |
46 | ```
47 | ==> Loading performance surrogate model...
48 | /home/pdluser/project/nasbench301/nasbench301/nb_models_1.0/xgb_v1.0
49 | ==> Loading runtime surrogate model...
50 | ==> Creating test configs...
51 | ==> Predict runtime and performance...
52 | Genotype architecture performance: 94.167947, runtime 4781.153014
53 | Configspace architecture performance: 91.834275, runtime 4285.378243
54 | ```
55 |
56 | ## Using NB301 in Hyperbox
57 |
58 | Replace `default_path` with your download path, e.g. `/path/to/your/downloads/nb_models_1.0`.
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench301/install/requirements_cu102.txt:
--------------------------------------------------------------------------------
1 | # for cuda10.2
2 | xgboost
3 | lightgbm
4 | pathvalidate
5 | ConfigSpace
6 | git+https://github.com/automl/nasbench301
7 | torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.1.html
8 | torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1.html
9 | torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1.html
10 | torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1.html
11 | torch-geometric -f https://pytorch-geometric.com/whl/torch-1.8.1.html
12 |
13 | # for cuda10.2
14 | git+https://github.com/automl/Auto-PyTorch.git@nb301
15 | autograd>=1.3
16 | click
17 | Cython
18 | ConfigSpace==0.4.12
19 | fasteners
20 | ipython
21 | imblearn
22 | lightgbm>=2.3.1
23 | matplotlib
24 | netifaces
25 | numpy
26 | pandas
27 | pathvalidate
28 | Pillow>=7.1.2
29 | psutil
30 | pynisher
31 | Pyro4
32 | scikit-image
33 | scipy
34 | scikit-learn==0.23.0
35 | seaborn
36 | setuptools
37 | serpent
38 | statsmodels
39 | tensorboard==1.14.0
40 | tensorflow-estimator
41 | tensorflow-gpu
42 | tensorboard_logger
43 | torch>=1.5.0
44 | torchvision>=0.6.0
45 | timm
46 | tqdm
47 | xgboost
48 | torch-scatter==2.0.4+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
49 | torch-sparse==0.6.3+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
50 | torch-cluster==1.5.5+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
51 | torch-spline-conv==1.2.0+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
52 | torch-geometric
53 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench301/utils.py:
--------------------------------------------------------------------------------
1 | # author: pprp
2 | # date: 2021-10-31
3 | # function: default modelv1.0
4 | # install:
5 | '''
6 | git clone https://github.com/automl/nasbench301.git
7 | cd nasbench301
8 | cat requirements.txt | xargs -n 1 -L 1 pip install
9 | pip install .
10 | '''
11 |
12 | import os
13 | from collections import namedtuple
14 |
15 | import nasbench301 as nb
16 | from ConfigSpace.read_and_write import json as cs_json
17 | from nasbench301.surrogate_models import ensemble
18 |
19 | default_version = "1.0"
20 |
21 | default_path = "~/.hyperbox/nb_models"
22 |
23 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
24 |
25 |
26 | def load_model():
27 | model_paths = {
28 |         model_name: os.path.join(os.path.expanduser(default_path), '{}_v1.0'.format(model_name))
29 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
30 | }
31 | print("==> Loading performance surrogate model...")
32 | ensemble_dir_performance = model_paths['xgb']
33 | performance_model = nb.load_ensemble(ensemble_dir_performance)
34 |
35 | print("==> Loading runtime surrogate model...")
36 | ensemble_dir_runtime = model_paths['lgb_runtime']
37 | runtime_model = nb.load_ensemble(ensemble_dir_runtime)
38 | return performance_model, runtime_model
39 |
40 |
41 | def generate_results(genotype_config):
42 | # pmodel: prediction model
43 | # rmodel: runtime model
44 | pmodel, rmodel = load_model()
45 | print("==> Creating test configs...")
46 | genotype_config = Genotype(
47 | normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
48 | ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)],
49 | normal_concat=[2, 3, 4, 5],
50 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1),
51 | ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)],
52 | reduce_concat=[2, 3, 4, 5]
53 | ) if genotype_config is None else genotype_config
54 |
55 | prediction_genotype = pmodel.predict(
56 | config=genotype_config, representation="genotype", with_noise=True)
57 | runtime_genotype = rmodel.predict(
58 | config=genotype_config, representation="genotype")
59 |
60 | return prediction_genotype, runtime_genotype
61 |
62 |
63 | def sample_config():
64 | configspace_path = './hyperbox/networks/nasbench301/configspace.json'
65 | with open(configspace_path, "r") as f:
66 | json_string = f.read()
67 | configspace = cs_json.read(json_string)
68 | configspace_config = configspace.sample_configuration()
69 | return configspace_config
70 |
71 |
72 | def test1(genotype_config=None):
73 | p, r = generate_results(genotype_config)
74 | print("Genotype architecture performance: %f, runtime %f" % (p, r))
75 |
76 |
77 | def test2():
78 | # pmodel: prediction model
79 | # rmodel: runtime model
80 | pmodel, rmodel = load_model()
81 | configspace_config = sample_config()
82 | prediction_configspace = pmodel.predict(
83 | config=configspace_config, representation="configspace", with_noise=True)
84 | runtime_configspace = rmodel.predict(
85 | config=configspace_config, representation="configspace")
86 | print("Configspace architecture performance: %f, runtime %f" %
87 | (prediction_configspace, runtime_configspace))
88 |
89 |
90 | if __name__ == "__main__":
91 | test1()
92 | test2()
93 |
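An end-to-end sketch; it assumes the v1.0 surrogate models have been downloaded to `default_path` as described in `install/readme.md`:

```python
from hyperbox.networks.nasbench301.utils import generate_results

# passing None falls back to the built-in DARTS-style genotype defined above
performance, runtime = generate_results(None)
print('predicted accuracy: %f, predicted runtime: %f' % (performance, runtime))
```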
--------------------------------------------------------------------------------
/hyperbox/networks/nasbench_mbnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/nasbench_mbnet/__init__.py
--------------------------------------------------------------------------------
/hyperbox/networks/nasbenchasr/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import NASBenchASR
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/nasbenchasr/download_nasbenchasr.sh:
--------------------------------------------------------------------------------
1 | if [ -z "${NASBENCHMARK_DIR}" ]; then
2 | NASBENCHMARK_DIR=~/.hyperbox/nasbenchasr/
3 | fi
4 |
5 | echo "Downloading NAS-Bench-ASR..."
6 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-gtx-1080ti-fp32.pickle
7 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-jetson-nano-fp32.pickle
8 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e10-1234.pickle
9 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1234.pickle
10 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1235.pickle
11 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1236.pickle
12 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e5-1234.pickle
13 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-info.pickle
14 |
15 |
16 | mkdir -p ${NASBENCHMARK_DIR}
17 | mv nb-asr*.pickle ${NASBENCHMARK_DIR}
--------------------------------------------------------------------------------
/hyperbox/networks/nasbenchasr/ops.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PadConvRelu(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size, dilation, strides, groups=1, dropout_rate=0, context=4, name='PadConvRelu'):
9 | super().__init__()
10 | self.name = name
11 |
12 | if int(context / strides) >= (kernel_size*dilation-strides):
13 | rpad = kernel_size*dilation-strides
14 | lpad = 0
15 | else:
16 | rpad = int(context / strides)
17 | lpad = int((kernel_size - 1)*dilation - rpad)
18 |
19 | self.pad = nn.ZeroPad2d((lpad, rpad, 0, 0))
20 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=strides, dilation=dilation, groups=groups)
21 | self.relu = nn.ReLU(inplace=False)
22 | self.dropout = nn.Dropout(p=dropout_rate)
23 |
24 | def forward(self, x):
25 | x = self.pad(x)
26 | x = self.conv(x)
27 | x = self.relu(x)
28 | x = torch.clamp_max_(x, 20)
29 | x = self.dropout(x)
30 | return x
31 |
32 |
33 | class Linear(nn.Module):
34 | def __init__(self, in_features, out_features, dropout_rate=0, name='Linear'):
35 | super().__init__()
36 | self.name = name
37 |
38 | self.linear = nn.Linear(in_features, out_features)
39 | self.relu = nn.ReLU(inplace=False)
40 | self.dropout = nn.Dropout(p=dropout_rate)
41 |
42 | def forward(self, x):
43 | shape = x.shape
44 | x = x.permute(0,2,1)
45 | x = self.linear(x)
46 | x = self.relu(x)
47 | x = torch.clamp_max_(x, 20)
48 | x = self.dropout(x)
49 | x = x.permute(0,2,1)
50 | return x
51 |
52 |
53 | class Identity(nn.Module):
54 | def __init__(self, name='Identity'):
55 | super().__init__()
56 | self.name = name
57 |
58 | def forward(self, x):
59 | return x
60 |
61 |
62 | class Zero(nn.Module):
63 | def __init__(self, name='Zero'):
64 | super(Zero, self).__init__()
65 | self.name = name
66 |
67 | def forward(self, x):
68 | return torch.zeros_like(x)
69 |
70 |
71 | _ops = {
72 | 'linear': Linear,
73 | 'conv5': functools.partial(PadConvRelu, kernel_size=5, dilation=1, strides=1, groups=100, name='conv5'),
74 |     'conv5d2': functools.partial(PadConvRelu, kernel_size=5, dilation=2, strides=1, groups=100, name='conv5d2'),
75 | 'conv7': functools.partial(PadConvRelu, kernel_size=7, dilation=1, strides=1, groups=100, name='conv7'),
76 |     'conv7d2': functools.partial(PadConvRelu, kernel_size=7, dilation=2, strides=1, groups=100, name='conv7d2'),
77 | 'zero': lambda *args, **kwargs: Zero(name='zero')
78 | }
79 |
80 | _branch_ops = {
81 | 0: Zero, # branch not present
82 | 1: Identity # branch present
83 | }
84 |
85 |
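The padding arithmetic in `PadConvRelu` keeps the time dimension intact for stride 1; a quick shape check using the grouped defaults from `_ops`:

```python
import torch

from hyperbox.networks.nasbenchasr.ops import _ops

op = _ops['conv5'](in_channels=100, out_channels=100, dropout_rate=0.0)
x = torch.rand(2, 100, 30)  # (batch, features, time)
print(op(x).shape)          # torch.Size([2, 100, 30]): the padding preserves length
```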
--------------------------------------------------------------------------------
/hyperbox/networks/nasbenchasr/readme.md:
--------------------------------------------------------------------------------
1 | # NAS-Bench-ASR
2 |
3 | ## 1. Download datasets
4 |
5 | ```bash
6 | if [ -z "${NASBENCHMARK_DIR}" ]; then
7 | NASBENCHMARK_DIR=~/.hyperbox/nasbenchasr/
8 | fi
9 |
10 | echo "Downloading NAS-Bench-ASR..."
11 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-gtx-1080ti-fp32.pickle
12 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-jetson-nano-fp32.pickle
13 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e10-1234.pickle
14 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1234.pickle
15 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1235.pickle
16 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e5-1234.pickle
17 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-info.pickle
18 |
19 | mkdir -p ${NASBENCHMARK_DIR}
20 | mv nb-asr*.pickle ${NASBENCHMARK_DIR}
21 | ```
22 |
23 | ## 2. Usage
24 |
25 | ```python
26 | import torch
27 | 
28 | from hyperbox.mutator import RandomMutator
29 | from hyperbox.networks.nasbenchasr import NASBenchASR
30 | 
31 | model = NASBenchASR()
32 | print(sum([p.numel() for p in model.parameters()]))
33 | rm = RandomMutator(model)
34 | rm.reset()
35 | 
36 | B, F, T = 2, 80, 30
37 | for T in [16]:
38 |     x = torch.rand(B, F, T)
39 |     rm.reset()
40 |     y = model(x)
41 |     print(y.shape)
42 |     # print(rm._cache, len(rm._cache))
43 | 
44 | 
45 | list_desc = [
46 |     ['linear', 1],
47 |     ['conv5', 1, 0],
48 |     ['conv7d2', 1, 0, 1],
49 | ]
50 | mask = NASBenchASR.list_desc_to_dict_mask(list_desc)
51 | # print(mask)
52 | model2 = NASBenchASR(mask=mask)
53 | print(sum([p.numel() for p in model2.parameters()]))
54 | print(model2.arch_size((B, F, T)))
55 | y = model2(x)
56 | print(NASBenchASR.dict_mask_to_list_desc(mask))
57 | 
58 | print(model2.query_full_info())
59 | print(model2.query_flops())
60 | print(model2.query_latency())
61 | print(model2.query_params())
62 | print(model2.query_test_acc())
63 | print(model2.query_val_acc())
64 | ```
65 | 
66 | Output (results vary across runs):
67 | ```bash
68 | 84867649
69 | torch.Size([2, 4, 49])
70 | 39733249
71 | (181.789408, 151.5703125)
72 | [['linear', 1], ['conv5', 1, 0], ['conv7d2', 1, 0, 1]]
73 | {'val_per': [0.9687853, 0.87188685, 0.7872086, 0.6666002, 0.5817228, 0.50474864, 0.4540081, 0.4089128, 0.3726506, 0.35265988, 0.33154014, 0.30796307, 0.29800093, 0.28405392, 0.27661553, 0.270439, 0.26074252, 0.2551637, 0.25124526, 0.2481238, 0.24918643, 0.24540082, 0.24041975, 0.2369662, 0.2374311, 0.2337119, 0.2337119, 0.23391114, 0.2327821, 0.22939497, 0.22893007, 0.22879724, 0.22753537, 0.22992627, 0.22985986, 0.22839876, 0.22461313, 0.22381617, 0.22275354, 0.22328486], 'test_per': 0.24767844378948212, 'arch_vec': [(0, 1), (1, 1, 0), (4, 1, 0, 1)], 'model_hash': 'adb47992d93622245376905cc956a149', 'seed': 1236, 'jetson-nano-fp32': {'latency': 0.578345775604248}, 'gtx-1080ti-fp32': {'latency': 0.04792499542236328}}
74 | 3845877266
75 | {'jetson-nano-fp32': {'latency': 0.578345775604248}, 'gtx-1080ti-fp32': {'latency': 0.04792499542236328}}
76 | 43100448
77 | 0.24906444549560547
78 | 0.22275354
79 | ```
80 | 
81 | # Acknowledgement
82 | The code is based on https://github.com/SamsungLabs/nb-asr
--------------------------------------------------------------------------------
/hyperbox/networks/nasbenchasr/search_space.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from hyperbox.networks.nasbenchasr.utils import recursive_iter, flatten, copy_structure
4 |
5 |
6 | all_ops = ['linear', 'conv5', 'conv5d2', 'conv7', 'conv7d2', 'zero']
7 | ops_no_zero = all_ops[:-1]
8 | default_nodes = 3
9 |
10 |
11 | def get_search_space(ops=None, nodes=None):
12 | ''' Return boundaries of the search space for the given list
13 | of available operations and number of nodes.
14 | '''
15 | ops = ops if ops is not None else all_ops
16 | nodes = nodes if nodes is not None else default_nodes
17 | search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)]
18 | return search_space
19 |
20 |
21 | def get_model_hash(arch_vec, ops=None, minimize=True):
22 | ''' Get hash of the architecture specified by arch_vec.
23 | Architecture hash can be used to determine if two
24 | configurations from the search space are in fact the
25 | same (graph isomorphism).
26 | '''
27 | from .graph_utils import get_model_graph, graph_hash
28 | g, _ = get_model_graph(arch_vec, ops=ops, minimize=minimize)
29 | return graph_hash(g)
30 |
31 |
32 | def get_all_architectures(ops=None, nodes=None):
33 | ''' Yields all architecture configurations in the search space
34 | '''
35 | search_space = get_search_space(ops, nodes)
36 | flat = flatten(search_space)
37 | cfg = [0 for _ in range(len(flat))]
38 | end = False
39 | while not end:
40 | yield copy_structure(cfg, search_space)
41 | for dim in range(len(flat)):
42 | cfg[dim] += 1
43 | if cfg[dim] != flat[dim]:
44 | break
45 | cfg[dim] = 0
46 | if dim+1 >= len(flat):
47 | end = True
48 |
49 |
50 | def get_random_architectures(num, ops=None, nodes=None, seed=None):
51 | ''' Get random architecture configurations from the search space
52 | '''
53 | ops = ops if ops is not None else all_ops
54 | nodes = nodes if nodes is not None else default_nodes
55 | if seed is not None:
56 | random.seed(seed)
57 | search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)]
58 | flat = flatten(search_space)
59 | models = []
60 | while len(models) < num:
61 | m = [random.randrange(opts) for opts in flat]
62 | m = copy_structure(m, search_space)
63 | models.append(m)
64 | return models
65 |
66 |
67 | def get_archs_with_zero():
68 | models_with_zero = {}
69 | for m in get_all_architectures(all_ops, default_nodes):
70 | if 5 in flatten(m):
71 | h = get_model_hash(m)
72 | models_with_zero[h] = m
73 | new_model_archs = [models_with_zero[k] for k in sorted(models_with_zero.keys())]
74 | return new_model_archs
75 |
76 |
77 | def arch_vec_to_names(arch_vec, ops=None):
78 | ''' Translates identifiers of operations in ``arch_vec`` to their names.
79 | ``ops`` can be provided externally to avoid relying on the current definition
80 | of available ops. Otherwise canonical ``all_ops`` will be used.
81 | '''
82 |
83 | if ops is None:
84 | ops = all_ops
85 |
86 | # current approach is to have an arch vector contain sub-vectors for node in a cell,
87 | # each subvector has a form of:
88 | # [op_idx, branch_op_idx...]
89 | # where op_idx points to an operation from ``all_ops`` and ``branch_op_idx`` is
90 | # either 0 (no skip connection) or 1 (identity skip connection)
91 | # since skip connects are already quite self-explanatory we leave them as they are
92 | # and only change numbers of the main operations to their respective names
93 | return [[all_ops[op_idx]] + branches for op_idx, *branches in arch_vec]
94 |
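The space is small enough to enumerate exhaustively; a sketch that checks its size and inspects one seeded sample:

```python
from hyperbox.networks.nasbenchasr.search_space import (
    arch_vec_to_names, get_all_architectures, get_random_architectures, get_search_space)

print(get_search_space())  # [[6, 2], [6, 2, 2], [6, 2, 2, 2]]
total = sum(1 for _ in get_all_architectures())
print(total)               # 6**3 * 2**6 = 13824 configurations
arch = get_random_architectures(1, seed=0)[0]
print(arch_vec_to_names(arch))  # op names plus 0/1 skip-connection flags per node
```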
--------------------------------------------------------------------------------
/hyperbox/networks/network_ema.py:
--------------------------------------------------------------------------------
1 | """ Exponential Moving Average (EMA) of model updates
2 | Hacked together by / Copyright 2020 Ross Wightman
3 |
4 | source code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
5 | """
6 | from collections import OrderedDict
7 | from copy import deepcopy
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | from hyperbox.networks.base_nas_network import BaseNASNetwork
13 |
14 |
15 | class ModelEma(nn.Module):
16 | """ Model Exponential Moving Average V2
17 | Keep a moving average of everything in the model state_dict (parameters and buffers).
18 | V2 of this module is simpler, it does not match params/buffers based on name but simply
19 | iterates in order. It works with torchscript (JIT of full model).
20 | This is intended to allow functionality like
21 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
22 | A smoothed version of the weights is necessary for some training schemes to perform well.
23 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
24 |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
25 | smoothing of weights to match results. Pay attention to the decay constant you are using
26 | relative to your update count per epoch.
27 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
28 | disable validation of the EMA weights. Validation will have to be done manually in a separate
29 | process, or after the training stops converging.
30 |     This class is sensitive to where it is initialized in the sequence of model init,
31 | GPU assignment and distributed training wrappers.
32 | """
33 | def __init__(self, model, decay=0.7, final_decay=0.999, device=None):
34 | super(ModelEma, self).__init__()
35 | # make a copy of the model for accumulating moving average of weights
36 | if isinstance(model, BaseNASNetwork):
37 | self.module = model.copy()
38 | else:
39 | self.module = deepcopy(model)
40 | self.module.eval()
41 | self.decay = decay
42 | self.init_decay = decay
43 | self.final_decay = final_decay
44 | self.device = device # perform ema on different device from model if set
45 | if self.device is not None:
46 | self.module.to(device=device)
47 |
48 | def _update(self, model, update_fn):
49 | with torch.no_grad():
50 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
51 | if self.device is not None:
52 | model_v = model_v.to(device=self.device)
53 | ema_v.copy_(update_fn(ema_v, model_v))
54 |
55 | def update(self, model):
56 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
57 |
58 | def set(self, model):
59 | self._update(model, update_fn=lambda e, m: m)
60 |
61 | def update_decay(self, epoch, all_epoch=100):
62 | self.decay = self.init_decay + (epoch/all_epoch)*(self.final_decay-self.init_decay)
63 |
64 | def forward(self, *args, **kwargs):
65 | return self.module(*args, **kwargs)
66 |
67 |
68 | if __name__ == '__main__':
69 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
70 | from hyperbox.networks.ofa import OFAMobileNetV3
71 | from hyperbox.mutator import RandomMutator
72 |
73 | supernet = OFAMobileNetV3(num_classes=10, width_mult=1.0).to(device).eval()
74 | mutator = RandomMutator(supernet)
75 | # ema = ModelEma(supernet, 0.9)
76 | ema = ModelEma(supernet, 0.9).eval()
77 | print('init\nsupernet', supernet.classifier.weight.data[:2,:8])
78 | print('ema', ema.module.classifier.weight.data[:2,:8])
79 | for i in range(10):
80 | mutator.reset()
81 | supernet.init_weights()
82 | ema.update(supernet)
83 | x = torch.rand(2,3,64,64).to(device)
84 | y1 = supernet(x)
85 |         subnet = ema.module.build_subnet(mutator._cache).to(device).eval()
86 | y2 = subnet(x)
87 | print('supernet', supernet.classifier.weight.data[:2,:8])
88 | print('ema', ema.module.classifier.weight.data[:2,:8])
89 |
--------------------------------------------------------------------------------
/hyperbox/networks/ofa/__init__.py:
--------------------------------------------------------------------------------
1 | from .ofa_mbv3 import OFAMobileNetV3
2 |
--------------------------------------------------------------------------------
/hyperbox/networks/proxylessnas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/proxylessnas/__init__.py
--------------------------------------------------------------------------------
/hyperbox/networks/proxylessnas/putils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def get_parameters(model, keys=None, mode='include'):
4 | if keys is None:
5 | for name, param in model.named_parameters():
6 | yield param
7 | elif mode == 'include':
8 | for name, param in model.named_parameters():
9 | flag = False
10 | for key in keys:
11 | if key in name:
12 | flag = True
13 | break
14 | if flag:
15 | yield param
16 | elif mode == 'exclude':
17 | for name, param in model.named_parameters():
18 | flag = True
19 | for key in keys:
20 | if key in name:
21 | flag = False
22 | break
23 | if flag:
24 | yield param
25 | else:
26 |         raise ValueError('unsupported mode: %s' % mode)
27 |
28 |
29 | def get_same_padding(kernel_size):
30 | if isinstance(kernel_size, tuple):
31 | assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
32 | p1 = get_same_padding(kernel_size[0])
33 | p2 = get_same_padding(kernel_size[1])
34 | return p1, p2
35 | assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
36 | assert kernel_size % 2 > 0, 'kernel size should be odd number'
37 | return kernel_size // 2
38 |
39 | def build_activation(act_func, inplace=True):
40 | if act_func == 'relu':
41 | return nn.ReLU(inplace=inplace)
42 | elif act_func == 'relu6':
43 | return nn.ReLU6(inplace=inplace)
44 | elif act_func == 'tanh':
45 | return nn.Tanh()
46 | elif act_func == 'sigmoid':
47 | return nn.Sigmoid()
48 | elif act_func is None:
49 | return None
50 | else:
51 |         raise ValueError('unsupported act_func: %s' % act_func)
52 |
53 |
54 | def make_divisible(v, divisor, min_val=None):
55 | """
56 | This function is taken from the original tf repo.
57 |     It ensures that all layers have a channel number that is divisible by ``divisor`` (8 in MobileNet)
58 | It can be seen here:
59 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
60 | """
61 | if min_val is None:
62 | min_val = divisor
63 | new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
64 | # Make sure that round down does not go down by more than 10%.
65 | if new_v < 0.9 * v:
66 | new_v += divisor
67 | return new_v
68 |
--------------------------------------------------------------------------------
/hyperbox/networks/repnas/__init__.py:
--------------------------------------------------------------------------------
1 | from .repnas_spos import RepNAS
--------------------------------------------------------------------------------
/hyperbox/networks/repnas/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from hyperbox.mutables.spaces import OperationSpace
5 | # from .rep_ops import *
6 | from hyperbox.networks.repnas.rep_ops import *
7 |
8 |
9 | def fuse(candidates, weights, kernel_size=3):
10 | k_list = []
11 | b_list = []
12 |
13 | for i in range(len(candidates)):
14 | op = candidates[i]
15 | weight = weights[i].float()
16 | if op.__class__.__name__ == "DBB1x1kxk":
17 | if hasattr(op.dbb_1x1_kxk, 'idconv1'):
18 | k1 = op.dbb_1x1_kxk.idconv1.get_actual_kernel()
19 | else:
20 | k1 = op.dbb_1x1_kxk.conv1.weight
21 |
22 | k1, b1 = transI_fusebn(k1, op.dbb_1x1_kxk.bn1)
23 | k2, b2 = transI_fusebn(op.dbb_1x1_kxk.conv2.weight, op.dbb_1x1_kxk.bn2)
24 |
25 | k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=op.groups)
26 | elif op.__class__.__name__ == "DBB1x1":
27 | k, b = transI_fusebn(op.conv.weight, op.bn)
28 | k = transVI_multiscale(k, kernel_size)
29 | elif op.__class__.__name__ == "DBBORIGIN":
30 | k, b = transI_fusebn(op.conv.weight, op.bn)
31 | elif op.__class__.__name__ == "DBBAVG":
32 | ka = transV_avg(op.out_channels, op.kernel_size, op.groups)
33 | k2, b2 = transI_fusebn(ka.to(op.dbb_avg.avgbn.weight.device), op.dbb_avg.avgbn)
34 |
35 | if hasattr(op.dbb_avg, 'conv'):
36 | k1, b1 = transI_fusebn(op.dbb_avg.conv.weight, op.dbb_avg.bn)
37 | k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=op.groups)
38 | else:
39 | k, b = k2, b2
40 | else:
41 |             raise TypeError("op must be one of DBBORIGIN, DBBAVG, DBB1x1, DBB1x1kxk.")
42 | k_list.append(k.detach() * weight)
43 | b_list.append(b.detach() * weight)
44 |
45 | return transII_addbranch(k_list, b_list)
46 |
47 |
48 | def replace(net):
49 | for name, module in net.named_modules():
50 | if isinstance(module, OperationSpace):
51 | candidates = []
52 | weights = []
53 | for idx, weight in enumerate(module.mask):
54 | if weight:
55 | candidates.append(module.candidates_original[idx])
56 | weights.append(weight)
57 | ks = max([c_.kernel_size for c_ in candidates])
58 | k, b = fuse(candidates, weights, ks)
59 | first = module.candidates_original[0]
60 | inc = first.in_channels
61 | ouc = first.out_channels
62 | s = first.stride
63 | p = ks//2
64 | g = first.groups
65 | reparam = nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=ks,
66 | stride=s, padding=p, dilation=1, groups=g)
67 | reparam.weight.data = k
68 | reparam.bias.data = b
69 |
70 | module.candidates_original = [reparam]
71 | module.candidates = torch.nn.ModuleList([reparam])
72 | module.mask = torch.tensor([True])
73 |
--------------------------------------------------------------------------------
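
How these helpers fit together, as a hedged sketch (the `RepNAS` constructor call and the random "search" are placeholders; only `replace` and `fuse` come from this file): once every `OperationSpace` has a fixed mask, `fuse` folds the surviving DBB-style branches into one kernel/bias pair via the trans* rules, and `replace` swaps each space for a single plain `nn.Conv2d`:

    from hyperbox.mutator import RandomMutator
    from hyperbox.networks.repnas import RepNAS
    from hyperbox.networks.repnas.utils import replace

    net = RepNAS()               # supernet whose OperationSpaces hold DBB-style branches
    RandomMutator(net).reset()   # fixes a mask for every OperationSpace (stand-in for a real search)
    replace(net)                 # each OperationSpace now holds one re-parameterized nn.Conv2d
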
/hyperbox/networks/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
--------------------------------------------------------------------------------
/hyperbox/networks/spos/__init__.py:
--------------------------------------------------------------------------------
1 | from .shuffle_blocks import *
2 | from .spos_net import *
3 |
--------------------------------------------------------------------------------
/hyperbox/networks/vit/__init__.py:
--------------------------------------------------------------------------------
1 | from .vit import *
--------------------------------------------------------------------------------
/hyperbox/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/optimizers/__init__.py
--------------------------------------------------------------------------------
/hyperbox/optimizers/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 |
4 | __all__ = ["ASAM", "SAM"]
5 |
6 | class ASAM:
7 | def __init__(self, optimizer, model, rho=0.5, eta=0.01):
8 | self.optimizer = optimizer
9 | self.model = model
10 | self.rho = rho
11 | self.eta = eta
12 | self.state = defaultdict(dict)
13 |
14 | @torch.no_grad()
15 | def ascent_step(self):
16 | wgrads = []
17 | for n, p in self.model.named_parameters():
18 | if p.grad is None:
19 | continue
20 | t_w = self.state[p].get("eps")
21 | if t_w is None:
22 | t_w = torch.clone(p).detach()
23 | self.state[p]["eps"] = t_w
24 | if 'weight' in n:
25 | t_w[...] = p[...]
26 | t_w.abs_().add_(self.eta)
27 | p.grad.mul_(t_w)
28 | wgrads.append(torch.norm(p.grad, p=2))
29 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
30 | for n, p in self.model.named_parameters():
31 | if p.grad is None:
32 | continue
33 | t_w = self.state[p].get("eps")
34 | if 'weight' in n:
35 | p.grad.mul_(t_w)
36 | eps = t_w
37 | eps[...] = p.grad[...]
38 | eps.mul_(self.rho / wgrad_norm)
39 | p.add_(eps)
40 | self.optimizer.zero_grad()
41 |
42 | @torch.no_grad()
43 | def descent_step(self):
44 | for n, p in self.model.named_parameters():
45 | if p.grad is None:
46 | continue
47 | p.sub_(self.state[p]["eps"])
48 | self.optimizer.step()
49 | self.optimizer.zero_grad()
50 |
51 |
52 | class SAM(ASAM):
53 | @torch.no_grad()
54 | def ascent_step(self):
55 | grads = []
56 | for n, p in self.model.named_parameters():
57 | if p.grad is None:
58 | continue
59 | grads.append(torch.norm(p.grad, p=2))
60 | grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16
61 | for n, p in self.model.named_parameters():
62 | if p.grad is None:
63 | continue
64 | eps = self.state[p].get("eps")
65 | if eps is None:
66 | eps = torch.clone(p).detach()
67 | self.state[p]["eps"] = eps
68 | eps[...] = p.grad[...]
69 | eps.mul_(self.rho / grad_norm)
70 | p.add_(eps)
71 | self.optimizer.zero_grad()
--------------------------------------------------------------------------------
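
A minimal sketch of the two-pass update these minimizers expect (the model, data, and hyperparameters below are placeholders): gradients are computed once at the current weights before `ascent_step`, and once more at the perturbed weights before `descent_step`, which restores the weights and applies the wrapped optimizer:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    minimizer = ASAM(optimizer, model, rho=0.5, eta=0.01)  # or SAM(optimizer, model)
    criterion = nn.CrossEntropyLoss()

    for x, y in loader:  # `loader` yields (inputs, labels) and is assumed
        criterion(model(x), y).backward()   # gradients at w
        minimizer.ascent_step()             # w -> w + eps(w)
        criterion(model(x), y).backward()   # gradients at w + eps(w)
        minimizer.descent_step()            # restore w, then optimizer.step()
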
/hyperbox/run.py:
--------------------------------------------------------------------------------
1 | # https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py
2 | import os
3 | import sys
4 |
5 | import hydra
6 | from omegaconf import DictConfig
7 | import pyrootutils
8 |
9 | # project root setup
10 | # searches for root indicators in parent dirs, like ".git", "pyproject.toml", etc.
11 | # sets PROJECT_ROOT environment variable (used in `configs/paths/default.yaml`)
12 | # loads environment variables from ".env" if exists
13 | # adds root dir to the PYTHONPATH (so this file can be run from any place)
14 | # https://github.com/ashleve/pyrootutils
15 | root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True) # config.paths.root_dir
16 |
17 |
18 | @hydra.main(version_base="1.2", config_path="configs/", config_name="config.yaml")
19 | def main(config: DictConfig):
20 |
21 | # Imports should be nested inside @hydra.main to optimize tab completion
22 | # Read more here: https://github.com/facebookresearch/hydra/issues/934
23 | from hyperbox.train import train
24 | from hyperbox.utils import utils
25 |
26 | # A couple of optional utilities:
27 | # - disabling python warnings
28 | # - easier access to debug mode
29 | # - forcing debug friendly configuration
30 | # - forcing multi-gpu friendly configuration
31 | # You can safely get rid of this line if you don't want those
32 | utils.extras(config)
33 |
34 | # Pretty print config using Rich library
35 | if config.get("print_config"):
36 | utils.print_config(config, resolve=True)
37 |
38 | # Train model
39 | if config.ipdb_debug:
40 | from ipdb import set_trace
41 | set_trace()
42 | return train(config)
43 |
44 |
45 | if __name__ == "__main__":
46 | main()
47 |
--------------------------------------------------------------------------------
/hyperbox/schedulers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/schedulers/__init__.py
--------------------------------------------------------------------------------
/hyperbox/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/utils/__init__.py
--------------------------------------------------------------------------------
/hyperbox/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | __all__ = [
4 | 'AverageMeterGroup',
5 | 'AverageMeter'
6 | ]
7 |
8 | class AverageMeterGroup:
9 | """
10 | Average meter group for multiple average meters.
11 | """
12 |
13 | def __init__(self, verbose_type='avg'):
14 | self.meters = OrderedDict()
15 | self.verbose_type = verbose_type
16 |
17 | def update(self, data):
18 | """
19 | Update the meter group with a dict of metrics.
20 |         Average meters that do not exist yet are created automatically.
21 | """
22 | for k, v in data.items():
23 | if k not in self.meters:
24 | self.meters[k] = AverageMeter(k, ":.4f", self.verbose_type)
25 | self.meters[k].update(v)
26 |
27 | # def __getattr__(self, item):
28 | # return self.meters[item]
29 |
30 | def __getitem__(self, item):
31 | return self.meters[item]
32 |
33 | def __str__(self):
34 | return " ".join(f"{v}" for v in self.meters.values())
35 |
36 | def summary(self):
37 | """
38 | Return a summary string of group data.
39 | """
40 | return " ".join(v.summary() for v in self.meters.values())
41 |
42 |
43 | class AverageMeter:
44 | """
45 | Computes and stores the average and current value.
46 | Parameters
47 | ----------
48 | name : str
49 | Name to display.
50 | fmt : str
51 | Format string to print the values.
52 | verbose_type : str
53 | 'all': value(avg)
54 | 'avg': avg
55 | """
56 |
57 | def __init__(self, name, fmt=':f', verbose_type='avg'):
58 | self.name = name
59 | self.fmt = fmt
60 | if verbose_type not in ['all', 'avg']:
61 |             print('Unsupported verbose_type; falling back to "avg"')
62 | verbose_type = 'avg'
63 | self.verbose_type = verbose_type
64 | self.reset()
65 |
66 | def reset(self):
67 | """
68 | Reset the meter.
69 | """
70 | self.val = 0
71 | self.avg = 0
72 | self.sum = 0
73 | self.count = 0
74 |
75 | def update(self, val, n=1):
76 | """
77 | Update with value and weight.
78 | Parameters
79 | ----------
80 | val : float or int
81 | The new value to be accounted in.
82 | n : int
83 | The weight of the new value.
84 | """
85 | self.val = val
86 | self.sum += val * n
87 | self.count += n
88 | self.avg = self.sum / self.count
89 |
90 | def __str__(self):
91 | if self.verbose_type=='all':
92 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
93 | elif self.verbose_type=='avg':
94 | fmtstr = '{name} {avg' + self.fmt + '}'
95 | else:
96 | fmtstr = '{name} {avg' + self.fmt + '}'
97 | return fmtstr.format(**self.__dict__)
98 |
99 | def summary(self):
100 | fmtstr = '{name}: {avg' + self.fmt + '}'
101 | return fmtstr.format(**self.__dict__)
--------------------------------------------------------------------------------
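
For reference, a short usage sketch of the meters above (metric names and values are made up):

    meters = AverageMeterGroup(verbose_type='all')
    meters.update({'loss': 0.9, 'acc': 0.50})
    meters.update({'loss': 0.7, 'acc': 0.60})
    print(meters)            # loss 0.7000 (0.8000) acc 0.6000 (0.5500)
    print(meters.summary())  # loss: 0.8000 acc: 0.5500
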
/hyperbox/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | from loguru import logger
5 |
6 |
7 | def custom_format(record):
8 | path = os.path.abspath(record["file"].path)
9 | record["extra"]["abspath"] = path
10 | fmt = "[{time:YYYY-MM-DD HH:mm:ss}] [{level}] [{extra[abspath]}:{line} ({name})] {message}\n{exception}"
11 | return fmt
12 |
13 |
14 | LOGGER_DICT = {}
15 |
16 | def get_logger(name=None, level=logging.INFO, is_rank_zero=True, log2file=False):
17 | if is_rank_zero or name is None:
18 | name = 'exp'
19 | if name in LOGGER_DICT:
20 | return LOGGER_DICT[name]
21 | fmt = custom_format
22 | logger.remove()
23 | kwargs = {'sink': sys.stderr, 'format': fmt, 'level': level, 'colorize': True, 'backtrace': True}
24 | handlers = [kwargs]
25 |     logger = logger.opt(exception=True)  # opt() returns a new logger, so keep the result
26 | if log2file:
27 | snd_kwargs = {k: v for k, v in kwargs.items() if k != 'sink'}
28 | snd_kwargs['sink'] = os.path.join(os.getcwd(), name + '.log')
29 | handlers.append(snd_kwargs)
30 | config = {"handlers": handlers}
31 | logger.configure(**config)
32 | LOGGER_DICT[name] = logger
33 | return logger
34 |
35 |
36 | if __name__ == '__main__':
37 | log1 = get_logger(None, level=logging.DEBUG, log2file=False)
38 | log1.debug('test') # showing
39 | log2 = get_logger('test2', level=logging.INFO, is_rank_zero=False, log2file=False)
40 | log2.debug('test2') # not showing
41 | log1.debug('test') # not showing
42 | print(log1 is log2) # True
--------------------------------------------------------------------------------
/hyperbox/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.stats as stats
3 |
4 | class Accuracy:
5 | def __call__(self, output, target, topk=(1,)):
6 | return accuracy(output, target, topk)
7 |
8 |
9 | def accuracy(output, target, topk=(1,)):
10 | """ Computes the precision@k for the specified values of k """
11 | maxk = max(topk)
12 | batch_size = target.size(0)
13 |
14 | _, pred = output.topk(maxk, 1, True, True)
15 | pred = pred.t()
16 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
17 |
18 | res = []
19 | for k in topk:
20 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
21 |         res.append(correct_k.mul_(100.0 / batch_size))  # correct_k is already a tensor
22 |     if len(res) == 1:
23 |         return res[-1]
24 | return res
25 |
26 | class Pearson:
27 | def __call__(self, output, target):
28 | return pearson(output, target)
29 |
30 | def pearson(vector1, vector2):
31 | n = len(vector1)
32 | #simple sums
33 | sum1 = sum(float(vector1[i]) for i in range(n))
34 | sum2 = sum(float(vector2[i]) for i in range(n))
35 | #sum up the squares
36 | sum1_pow = sum([pow(v, 2.0) for v in vector1])
37 | sum2_pow = sum([pow(v, 2.0) for v in vector2])
38 | #sum up the products
39 | p_sum = sum([vector1[i]*vector2[i] for i in range(n)])
40 |     # numerator (num) and denominator (den)
41 |     num = p_sum - (sum1*sum2/n)
42 |     den = ((sum1_pow - pow(sum1, 2)/n) * (sum2_pow - pow(sum2, 2)/n)) ** 0.5  # ** 0.5, since `math` is not imported
43 | if den == 0:
44 | return 0.0
45 | return num/den
46 |
47 | class KendallTau:
48 | def __call__(self, output, target):
49 |         return stats.kendalltau(output, target)  # (correlation, p-value)
--------------------------------------------------------------------------------
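
A small sanity check for `accuracy` (the logits and targets are illustrative): with two classes, top-2 accuracy is trivially 100%, while top-1 counts exact argmax matches:

    import torch

    logits = torch.tensor([[0.1, 0.9],
                           [0.8, 0.2],
                           [0.3, 0.7],
                           [0.6, 0.4]])
    target = torch.tensor([1, 0, 0, 1])
    top1, top2 = accuracy(logits, target, topk=(1, 2))
    print(top1.item(), top2.item())  # 50.0 100.0
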
/hyperbox/utils/visualize_darts_cell.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import sys
3 | import os
4 | import json
5 | import numpy as np
6 | from graphviz import Digraph
7 | from argparse import ArgumentParser
8 |
9 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
10 |
11 | PRIMITIVES = [
12 | 'max_pool_3x3',
13 | 'avg_pool_3x3',
14 | 'skip_connect',
15 | 'sep_conv_3x3',
16 | 'sep_conv_5x5',
17 | 'dil_conv_3x3',
18 | 'dil_conv_5x5'
19 | ]
20 |
21 | def convert_genotypes(arch_json_file):
22 | with open(arch_json_file, 'r') as f:
23 | arch = json.load(f)
24 | normal = []
25 | reduce = []
26 | ops = []
27 | indices = []
28 | for key in arch:
29 | if 'switch' not in key:
30 | if sum(arch[key]) == 0:
31 | op = 'None'
32 | else:
33 | op = np.array(PRIMITIVES)[arch[key]][0]
34 | ops.append(op)
35 | else:
36 | _ops = []
37 | for i in range(len(arch[key])):
38 | if arch[key][i]:
39 | indices.append(i)
40 | _ops.append(ops[i])
41 | for op, idx in zip(_ops, indices):
42 | if 'norm' in key:
43 | normal.append((op, idx))
44 | else:
45 | reduce.append((op, idx))
46 | ops = []
47 | indices = []
48 | geno = Genotype(
49 | normal=normal, normal_concat=list(range(2, 2+len(normal)//2)),
50 | reduce=reduce, reduce_concat=list(range(2, 2+len(reduce)//2)),
51 | )
52 | print(geno)
53 | return geno
54 |
55 | def plot(genotype, filename):
56 | g = Digraph(
57 | format='pdf',
58 | edge_attr=dict(fontsize='20', fontname="times"),
59 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
60 | engine='dot')
61 | g.body.extend(['rankdir=LR'])
62 |
63 | g.node("c_{k-2}", fillcolor='darkseagreen2')
64 | g.node("c_{k-1}", fillcolor='darkseagreen2')
65 | assert len(genotype) % 2 == 0
66 | steps = len(genotype) // 2
67 |
68 | for i in range(steps):
69 | g.node(str(i), fillcolor='lightblue')
70 |
71 | for i in range(steps):
72 | for k in [2*i, 2*i + 1]:
73 | op, j = genotype[k]
74 | if j == 0:
75 | u = "c_{k-2}"
76 | elif j == 1:
77 | u = "c_{k-1}"
78 | else:
79 | u = str(j-2)
80 | v = str(i)
81 | if op != 'None':
82 | g.edge(u, v, label=op, fillcolor="gray")
83 |
84 | g.node("c_{k}", fillcolor='palegoldenrod')
85 | for i in range(steps):
86 | g.edge(str(i), "c_{k}", fillcolor="gray")
87 |
88 | g.render(filename, view=True)
89 |
90 |
91 | if __name__ == '__main__':
92 | parser = ArgumentParser('Visualize_DARTS')
93 | parser.add_argument('--file', type=str, help="the path of mask (json) file")
94 | args = parser.parse_args()
95 |
96 | arch_file = args.file
97 | assert os.path.exists(arch_file), f"{arch_file} not found."
98 | genotype = convert_genotypes(arch_file)
99 | plot(genotype.normal, "normal")
100 | plot(genotype.reduce, "reduction")
101 |
--------------------------------------------------------------------------------
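
`plot` can also be driven directly with a hand-written genotype, as in this sketch (the ops and input indices are made up; graphviz must be installed, and `view=True` will open the rendered PDF):

    normal = [('sep_conv_3x3', 0), ('max_pool_3x3', 1),
              ('skip_connect', 0), ('dil_conv_5x5', 2)]
    plot(normal, 'normal_demo')   # writes and opens normal_demo.pdf
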
/hyperbox/utils/visualize_mbconv_net.py:
--------------------------------------------------------------------------------
1 | import graphviz
2 |
3 |
4 | class ColorSet:
5 | hex_colors1 = [
6 | '#2878b5',
7 | '#9ac9db',
8 | '#f8ac8c',
9 | '#c82423',
10 | '#ff8884',
11 | ]
12 | hex_colors2 = [
13 | '#BEB8DC',
14 | '#E7DAD2',
15 | '#8ECFC9',
16 | '#FFBE7A',
17 | '#FA7F6F',
18 | '#82B0D2',
19 | ]
20 |
21 | def __init__(self, color_set='hex_colors1'):
22 | if color_set == 'hex_colors1':
23 | self._colors = self.hex_colors1
24 | elif color_set == 'hex_colors2':
25 | self._colors = self.hex_colors2
26 | else:
27 | raise ValueError('color_set must be one of hex_colors1 or hex_colors2')
28 |
29 | @property
30 | def colors(self):
31 | return self._colors
32 |
33 | @colors.setter
34 | def colors(self, value):
35 | self._colors = value
36 |
37 |
38 | def draw_arch(arch, index2op, filename='arch', color_set='hex_colors1'):
39 | colorset = ColorSet(color_set)
40 | index2color = {
41 | i: colorset.colors[i] for i in range(len(set(arch)))
42 | }
43 |
44 | dot = graphviz.Graph(comment='The Round Table')
45 | # dot.graph_attr['rankdir'] = 'LR'
46 | dot.graph_attr['rotate'] = '90'
47 | prev_index = None
48 | for idx, op_idx in enumerate(arch):
49 | op_name = index2op[op_idx]
50 | op_color = index2color[op_idx]
51 | op_index = f"{idx}"
52 | dot.node(op_index, label=op_name, style='filled', fillcolor=op_color, shape='box', fontsize='10')
53 | if idx == 0:
54 | prev_index = op_index
55 | else:
56 | dot.edge(op_index, prev_index)
57 | prev_index = op_index
58 | dot.render(directory='visual_output', view=False, filename=filename)
59 |
60 | if __name__ == '__main__':
61 | arch = [
62 | 0,1,2,1,2,3,1,1,2,2
63 | ]
64 | index2op = {
65 | 0: 'MBConv3',
66 | 1: 'MBConv5',
67 | 2: 'MBConv7',
68 | 3: 'Identity'
69 | }
70 | draw_arch(arch, index2op)
71 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # --------- pytorch --------- #
2 | torch>=1.8.1
3 | torchvision
4 | pytorch-lightning==1.8.6
5 | torchmetrics>=0.3.2
6 | lightning-bolts==0.3.3
7 |
8 | # --------- hydra --------- #
9 | hydra-core==1.2.0
10 | hydra-colorlog==1.1.0.dev1
11 | # hydra-optuna-sweeper==1.1.0.dev2 # for optuna-based HPO
12 | # hydra-ax-sweeper==1.1.0
13 | # hydra-ray-launcher==0.1.2
14 | # hydra-submitit-launcher==1.1.0
15 |
16 | # --------- loggers --------- #
17 | wandb>=0.12.21
18 | # neptune-client
19 | # mlflow
20 | # comet-ml
21 | # torch_tb_profiler
22 |
23 | # --------- linters --------- #
24 | pre-commit
25 | black
26 | isort
27 | flake8
28 |
29 | # --------- cores --------- #
30 | loguru>=0.6.0
31 | einops
32 | colorlog
33 | scikit-learn
34 | #numpy>=1.22.0
35 |
36 | # --------- visualization --------- #
37 | graphviz>=0.20.1
38 |
39 | # --------- for building nasbench-201 --------- #
40 | peewee
41 |
42 | # --------- others --------- #
43 | ipdb
44 | pyrootutils
45 | rich
46 | pytest
47 | sh
48 | seaborn
49 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num=$1
4 | remark=$2
5 | others=$3
6 |
7 | BS=256
8 | exp=example_darts_nas
9 |
10 |
11 | if [ $num -le 1 ]
12 | then
13 | {
14 | echo "python run.py experiment=$exp.yaml logger.wandb.name=g${num}_bs${BS}_${remark} trainer.gpus=${num} datamodule.batch_size=${BS} ${others}"
15 | python -m ipdb run.py \
16 | experiment=$exp.yaml \
17 | logger.wandb.name=g${num}_bs${BS}_${remark} \
18 | trainer.gpus=${num} \
19 | datamodule.batch_size=${BS} \
20 | logger.wandb.offline=True \
21 | $others
22 | }
23 | else
24 | {
25 | echo "python run.py experiment=$exp.yaml logger.wandb.name=ddp_g${num}_bs${BS}_${remark} trainer.gpus=${num} trainer.accelerator=ddp datamodule.batch_size=${BS} ${others}"
26 | python run.py \
27 | experiment=$exp.yaml \
28 | logger.wandb.name=ddp_g${num}_bs${BS}_${remark} \
29 | trainer.gpus=${num} \
30 | trainer.accelerator=ddp \
31 | datamodule.batch_size=${BS} \
32 | logger.wandb.offline=True \
33 | $others
34 | }
35 | # mpirun -np ${num} python run.py experiment=$exp.yaml logger.wandb.name=ddp_g${num} trainer.gpus=1 trainer.accelerator=horovod
36 | fi
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | project_name = hyperbox
3 | author = marsggbo
4 | contact = marsggbo@foxmail.com
5 | license_file = LICENSE
6 | description_file = README.md
7 | project_template = https://github.com/marsggbo/hyperbox
8 |
9 |
10 | [isort]
11 | line_length = 99
12 | profile = black
13 | filter_files = True
14 |
15 |
16 | [flake8]
17 | max_line_length = 99
18 | show_source = True
19 | format = pylint
20 | ignore =
21 | F401 # Module imported but unused
22 | W504 # Line break occurred after a binary operator
23 | F841 # Local variable name is assigned to but never used
24 | exclude =
25 | .git
26 | __pycache__
27 | data/*
28 | tests/*
29 | notebooks/*
30 | logs/*
31 |
32 |
33 | [tool:pytest]
34 | python_files = tests/*
35 | log_cli = True
36 | markers =
37 | slow
38 | addopts =
39 | --durations=0
40 | --strict-markers
41 | --doctest-modules
42 | filterwarnings =
43 | ignore::DeprecationWarning
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | required_modules = []
4 | with open('requirements.txt') as f:
5 | required = f.read().splitlines()
6 | for module in required:
7 | if not module.startswith('#'):
8 | required_modules.append(module)
9 |
10 | setup(
11 |     name="hyperbox",
12 | version="1.4.4",
13 | description="Hyperbox: An easy-to-use NAS framework.",
14 | author="marsggbo",
15 | url="https://github.com/marsggbo/hyperbox",
16 | # replace with your own github project link
17 | install_requires=required_modules,
18 | packages=find_packages(),
19 | include_package_data=True,
20 | )
21 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/__init__.py
--------------------------------------------------------------------------------
/tests/datamodules/test_transforms.py:
--------------------------------------------------------------------------------
1 | from hyperbox.datamodules.transforms import get_transforms, TorchTransforms
2 |
3 | if __name__ == '__main__':
4 | kwargs = {
5 | 'input_size': [32, 32],
6 | 'random_crop': {'enable': 1, 'padding': 4, 'size': 28},
7 | 'random_horizontal_flip': {'enable': 1, 'p': 0.8}
8 | }
9 | T = TorchTransforms(**kwargs)
10 | pass
--------------------------------------------------------------------------------
/tests/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/helpers/__init__.py
--------------------------------------------------------------------------------
/tests/helpers/module_available.py:
--------------------------------------------------------------------------------
1 | import platform
2 | from importlib.util import find_spec
3 |
4 | """
5 | Adapted from:
6 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py
7 | """
8 |
9 |
10 | def _module_available(module_path: str) -> bool:
11 | """Check if a path is available in your environment.
12 |
13 | >>> _module_available('os')
14 | True
15 | >>> _module_available('bla.bla')
16 | False
17 |
18 | """
19 | try:
20 | return find_spec(module_path) is not None
21 | except AttributeError:
22 | # Python 3.6
23 | return False
24 | except ModuleNotFoundError:
25 | # Python 3.7+
26 | return False
27 |
28 |
29 | _IS_WINDOWS = platform.system() == "Windows"
30 | _APEX_AVAILABLE = _module_available("apex.amp")
31 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available("deepspeed")
32 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
33 | _RPC_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.rpc")
34 |
--------------------------------------------------------------------------------
/tests/helpers/run_command.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pytest
4 | import sh
5 |
6 |
7 | def run_command(command: List[str]):
8 | """Default method for executing shell commands with pytest."""
9 | msg = None
10 | try:
11 | sh.python(command)
12 | except sh.ErrorReturnCode as e:
13 | msg = e.stderr.decode()
14 | if msg:
15 | pytest.fail(msg=msg)
16 |
--------------------------------------------------------------------------------
/tests/helpers/runif.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Optional
3 |
4 | import pytest
5 | import torch
6 | from packaging.version import Version
7 | from pkg_resources import get_distribution
8 |
9 | """
10 | Adapted from:
11 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
12 | """
13 |
14 | from tests.helpers.module_available import (
15 | _APEX_AVAILABLE,
16 | _DEEPSPEED_AVAILABLE,
17 | _FAIRSCALE_AVAILABLE,
18 | _IS_WINDOWS,
19 | _RPC_AVAILABLE,
20 | )
21 |
22 |
23 | class RunIf:
24 | """
25 | RunIf wrapper for conditional skipping of tests.
26 | Fully compatible with `@pytest.mark`.
27 |
28 | Example:
29 |
30 | @RunIf(min_torch="1.8")
31 | @pytest.mark.parametrize("arg1", [1.0, 2.0])
32 | def test_wrapper(arg1):
33 | assert arg1 > 0
34 |
35 | """
36 |
37 | def __new__(
38 | self,
39 | min_gpus: int = 0,
40 | min_torch: Optional[str] = None,
41 | max_torch: Optional[str] = None,
42 | min_python: Optional[str] = None,
43 | amp_apex: bool = False,
44 | skip_windows: bool = False,
45 | rpc: bool = False,
46 | fairscale: bool = False,
47 | deepspeed: bool = False,
48 | **kwargs,
49 | ):
50 | """
51 | Args:
52 | min_gpus: min number of gpus required to run test
53 | min_torch: minimum pytorch version to run test
54 | max_torch: maximum pytorch version to run test
55 | min_python: minimum python version required to run test
56 | amp_apex: NVIDIA Apex is installed
57 | skip_windows: skip test for Windows platform
58 | rpc: requires Remote Procedure Call (RPC)
59 | fairscale: if `fairscale` module is required to run the test
60 | deepspeed: if `deepspeed` module is required to run the test
61 | kwargs: native pytest.mark.skipif keyword arguments
62 | """
63 | conditions = []
64 | reasons = []
65 |
66 | if min_gpus:
67 | conditions.append(torch.cuda.device_count() < min_gpus)
68 | reasons.append(f"GPUs>={min_gpus}")
69 |
70 | if min_torch:
71 | torch_version = get_distribution("torch").version
72 | conditions.append(Version(torch_version) < Version(min_torch))
73 | reasons.append(f"torch>={min_torch}")
74 |
75 | if max_torch:
76 | torch_version = get_distribution("torch").version
77 | conditions.append(Version(torch_version) >= Version(max_torch))
78 | reasons.append(f"torch<{max_torch}")
79 |
80 | if min_python:
81 | py_version = (
82 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
83 | )
84 | conditions.append(Version(py_version) < Version(min_python))
85 | reasons.append(f"python>={min_python}")
86 |
87 | if amp_apex:
88 | conditions.append(not _APEX_AVAILABLE)
89 | reasons.append("NVIDIA Apex")
90 |
91 | if skip_windows:
92 | conditions.append(_IS_WINDOWS)
93 | reasons.append("does not run on Windows")
94 |
95 | if rpc:
96 | conditions.append(not _RPC_AVAILABLE)
97 | reasons.append("RPC")
98 |
99 | if fairscale:
100 | conditions.append(not _FAIRSCALE_AVAILABLE)
101 | reasons.append("Fairscale")
102 |
103 | if deepspeed:
104 | conditions.append(not _DEEPSPEED_AVAILABLE)
105 | reasons.append("Deepspeed")
106 |
107 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
108 | return pytest.mark.skipif(
109 | condition=any(conditions),
110 | reason=f"Requires: [{' + '.join(reasons)}]",
111 | **kwargs,
112 | )
113 |
--------------------------------------------------------------------------------
/tests/models/test_random_nas_model.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from omegaconf import OmegaConf
3 |
4 | if __name__ == '__main__':
5 | cfg = OmegaConf.load('../../configs/model/random_model.yaml')
6 | nas_net = hydra.utils.instantiate(cfg, _recursive_=False)
7 |
--------------------------------------------------------------------------------
/tests/mutables/test_all_mutables.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from hyperbox.mutator import RandomMutator
5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
6 | from hyperbox.mutables.ops import Conv2d
7 |
8 |
9 | class Net(nn.Module):
10 | def __init__(self,):
11 | super().__init__()
12 | ops = [
13 | nn.Conv2d(3,4,kernel_size=3,stride=1,padding=1),
14 | nn.Conv2d(3,4,kernel_size=5,stride=1,padding=2),
15 | nn.Conv2d(3,4,kernel_size=7,stride=1,padding=3),
16 | ]
17 | self.candidate_op1 = OperationSpace(ops, key='candidate1')
18 | self.candidate_op2 = OperationSpace(ops, key='candidate2')
19 |
20 | v1 = ValueSpace([4,8,16])
21 | v2 = ValueSpace([2])
22 | v3 = ValueSpace([3,5,7])
23 | self.fop1 = Conv2d(4, v1, kernel_size=v3,stride=v2,padding=1,auto_padding=True)
24 | self.fop2 = Conv2d(v1, 8, kernel_size=v3,stride=v2,padding=1,auto_padding=True)
25 |
26 | self.input_op = InputSpace(n_candidates=2, n_chosen=1, key='input1')
27 |
28 | def forward(self, x):
29 | out1 = self.candidate_op1(x)
30 | out2 = self.candidate_op2(x)
31 |
32 | out = self.input_op([out1, out2])
33 | print(out.shape)
34 | out = self.fop1(out)
35 | out = self.fop2(out)
36 | return out
37 |
38 | if __name__ == '__main__':
39 | net = Net()
40 | random = RandomMutator(net)
41 | random.reset()
42 | x = torch.rand(2,3,64,64)
43 | y = net(x)
44 | print(y.shape)
45 |
46 |
47 | # test OperationSpace
48 | x = torch.rand(2,10)
49 | ops = [
50 | nn.Linear(10,10,bias=False),
51 | nn.Linear(10,100,bias=False),
52 | nn.Linear(10,100,bias=False),
53 | nn.Identity()
54 | ]
55 | mixop = OperationSpace(
56 | ops,
57 | mask=[0,1,0,0]
58 | )
59 | y = mixop(x)
60 | print(y.shape)
61 |
62 | mixop = OperationSpace(
63 | ops, return_mask=True
64 | )
65 | m = RandomMutator(mixop)
66 | m.reset()
67 | y, mask = mixop(x)
68 | print(y.shape, mask)
69 |
70 | # test InputSpace
71 | input1 = torch.rand(1,3)
72 | input2 = torch.rand(1,2)
73 | input3 = torch.rand(2,1)
74 | inputs = [input1, input2, input3]
75 | ic1 = InputSpace(n_candidates=3, n_chosen=1, return_mask=True)
76 | m = RandomMutator(ic1)
77 | m.reset()
78 | out, mask = ic1(inputs)
79 | print(out.shape, mask)
80 |
81 | inputs = {'key1':input1, 'key2':input2, 'key3':input3}
82 | ic2 = InputSpace(choose_from=['key1', 'key2', 'key3'], n_chosen=1, return_mask=True)
83 | m = RandomMutator(ic2)
84 | m.reset()
85 | out, mask = ic2(inputs)
86 | print(out.shape, mask)
87 |
88 | vc1 = ValueSpace([8,16,24], index=1)
89 | vc2 = ValueSpace([1,2,3,4,5,6,7,8,9])
90 | vc = nn.ModuleList([vc1,vc2])
91 | m = RandomMutator(vc)
92 | m.reset()
93 | print(vc1, vc2)
94 |
--------------------------------------------------------------------------------
/tests/mutables/test_duplicate_mutables.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from hyperbox.mutables.spaces import OperationSpace
5 | from hyperbox.networks.base_nas_network import BaseNASNetwork
6 | from hyperbox.mutator import RandomMutator
7 |
8 |
9 | class ToyDuplicateNet(BaseNASNetwork):
10 | def __init__(self):
11 | super().__init__()
12 | self.op1 = OperationSpace(candidates=[nn.ReLU(), nn.Sigmoid()], key='op')
13 | self.op2 = OperationSpace(candidates=[nn.ReLU(), nn.Sigmoid()], key='op')
14 |
15 | def forward(self, x):
16 | out1 = self.op1(x)
17 | out2 = self.op2(x)
18 | out = out1 - out2
19 | return out
20 |
21 | @property
22 | def arch(self):
23 | return f"op1:{self.op1.mask}-op2:{self.op2.mask}"
24 |
25 |
26 | if __name__ == '__main__':
27 | model = ToyDuplicateNet()
28 | mutator = RandomMutator(model)
29 | x = torch.rand(10)
30 | for i in range(3):
31 | mutator.reset()
32 | print(model.arch)
33 | y = model(x)
34 | print(y)
35 | print('='*20)
--------------------------------------------------------------------------------
/tests/mutables/test_op_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from hyperbox.mutables import ops
3 | from hyperbox.mutables.spaces import ValueSpace
4 | from hyperbox.mutator import RandomMutator
5 |
6 | if __name__ == '__main__':
7 | x = torch.rand(4,3)
8 | print('Normal case')
9 | linear = ops.Linear(3,10)
10 | bn = ops.BatchNorm1d(10)
11 | y = linear(x)
12 | print(y.shape)
13 | y = bn(y)
14 | print(y.shape)
15 |
16 | print('Search case')
17 | vs1 = ValueSpace([1,2,3])
18 | linear = ops.Linear(3, vs1, bias=False)
19 | bn = ops.BatchNorm1d(vs1)
20 | m = RandomMutator(linear)
21 | m.reset()
22 | print(m._cache)
23 | print(linear.weight.shape)
24 | y = linear(x)
25 | print(y.shape)
26 | y = bn(y)
27 | print(y.shape)
28 |
--------------------------------------------------------------------------------
/tests/mutables/test_op_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from hyperbox.mutables import ops
3 | from hyperbox.mutables.spaces import ValueSpace
4 | from hyperbox.mutator import RandomMutator
5 |
6 | from hyperbox.utils.calc_model_size import flops_size_counter
7 |
8 | def test_groups(cin, cout, group_list):
9 | print(f"test_groups: cin={cin}, cout={cout}, group_list={group_list}\n")
10 | groups = ValueSpace(candidates=group_list)
11 | x = torch.rand(2, cin, 64, 64)
12 | conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups)
13 | m = RandomMutator(conv)
14 | for i in range(10):
15 | m.reset()
16 | print(f'\n*******step{i}********\n', conv)
17 | y = conv(x)
18 |
19 | def test_cin_groups(cin_list, cout, group_list):
20 | print(f"test_cin_groups: cin_list={cin_list}, cout={cout}, group_list={group_list}\n")
21 | cin = ValueSpace(candidates=cin_list)
22 | groups = ValueSpace(candidates=group_list)
23 | x = torch.rand(2, 3, 64, 64)
24 | conv = torch.nn.Sequential(
25 | ops.Conv2d(3, cin, 3, 1, 1),
26 | ops.Conv2d(cin, cout, 3, 1, 1, groups=cin)
27 | )
28 | m = RandomMutator(conv)
29 | for i in range(10):
30 | m.reset()
31 | print(f'\n*******step{i}********\n', conv)
32 | y = conv(x)
33 |
34 | def test_cout_groups(cin, cout_list, group_list):
35 | print(f"test_cout_groups: cin={cin}, cout_list={cout_list}, group_list={group_list}\n")
36 | cout = ValueSpace(candidates=cout_list)
37 | groups = ValueSpace(candidates=group_list)
38 | x = torch.rand(2, cin, 64, 64)
39 | conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups)
40 | m = RandomMutator(conv)
41 | for i in range(10):
42 | m.reset()
43 | print(f'\n*******step{i}********\n', conv)
44 | y = conv(x)
45 |
46 | def test_cin_cout_groups(cin_list, cout_list, group_list):
47 | print(f"test_cin_cout_groups: cin_list={cin_list}, cout_list={cout_list}, group_list={group_list}\n")
48 | cin = ValueSpace(candidates=cin_list)
49 | cout = ValueSpace(candidates=cout_list)
50 | groups = ValueSpace(candidates=group_list)
51 | x = torch.rand(2, 3, 64, 64)
52 | conv = torch.nn.Sequential(
53 | ops.Conv2d(3, cin, 3, 1, 1),
54 | ops.Conv2d(cin, cout, 3, 1, 1, groups=cin)
55 | )
56 | m = RandomMutator(conv)
57 | for i in range(10):
58 | m.reset()
59 | print(f'\n*******step{i}********\n', conv)
60 | y = conv(x)
61 |
62 | def test_conv():
63 | print('testing conv flops and sizes')
64 | x = torch.rand(1,3,64,64)
65 | vs1 = ValueSpace([10,2])
66 | vs2 = ValueSpace([3,5,7])
67 | conv = ops.Conv2d(3, vs1, vs2,bias=False)
68 | bn = ops.BatchNorm2d(vs1)
69 | op = torch.nn.Sequential(
70 | # torch.nn.Conv2d(3,3,3,1),
71 | conv, bn
72 | )
73 | m = RandomMutator(op)
74 | m.reset()
75 | # print(op)
76 | # print(m._cache)
77 | # print(conv.weight.shape)
78 | # print(conv(x).shape)
79 | r = flops_size_counter(op, (1,3,8,8), False, False)
80 | # print(r)
81 | op = torch.nn.Sequential(
82 | ops.Conv2d(3,8,3,1),
83 | ops.BatchNorm2d(8)
84 | )
85 | r = flops_size_counter(op, (1,3,8,8), False, False)
86 | # print(conv(x).shape)
87 |
88 | if __name__ == '__main__':
89 | test_conv()
90 | test_groups(32, 64, group_list=[1, 2, 4, 8, 16, 32])
91 | test_cin_groups([32, 64], 128, group_list=[32, 64])
92 | test_cout_groups(8, [16, 32], group_list=[1, 2, 4, 8])
93 | test_cin_cout_groups([8, 16], [32, 64], group_list=[1, 8, 16])
94 |
--------------------------------------------------------------------------------
/tests/mutables/test_op_linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from hyperbox.mutables import ops
3 | from hyperbox.mutables.spaces import ValueSpace
4 | from hyperbox.mutator import RandomMutator
5 |
6 | from hyperbox.utils.calc_model_size import flops_size_counter
7 |
8 | if __name__ == '__main__':
9 | x = torch.rand(1,3)
10 | vs1 = ValueSpace([1,2,3])
11 | linear = ops.Linear(3, vs1, bias=False)
12 | m = RandomMutator(linear)
13 | m.reset()
14 | print(m._cache)
15 | print(linear.weight.shape)
16 | print(linear(x).shape)
17 | r = flops_size_counter(linear, (2,3), False, True)
18 |
19 | linear = ops.Linear(3,10)
20 | print(linear(x).shape)
21 | r = flops_size_counter(linear, (2,3), False, True)
--------------------------------------------------------------------------------
/tests/networks/test_darts_net.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from hyperbox.networks.darts import DartsNetwork, DartsCell
5 | from hyperbox.mutator.random_mutator import RandomMutator
6 |
7 |
8 | if __name__ == '__main__':
9 | net = DartsNetwork(3,16,10,4,mask='./hyperbox/networks/darts/darts_mask.json')
10 | pass
11 |
12 | net = DartsNetwork(
13 | in_channels=3,
14 | channels=16,
15 | n_classes=10,
16 | n_layers=3,
17 | factory_func=DartsCell,
18 | ).cuda()
19 | print(net.arch)
20 | m = RandomMutator(net)
21 | m.reset()
22 | print(net.arch)
23 | m.reset()
24 | print(net.arch)
25 | m.reset()
26 | print(net.arch)
27 | x = torch.rand(2,3,64,64).cuda()
28 | output = net(x)
29 | print(output.shape)
30 | pass
--------------------------------------------------------------------------------
/tests/networks/test_enas_net.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from hyperbox.networks.enas import ENASMicroNetwork
5 | from hyperbox.mutator.random_mutator import RandomMutator
6 |
7 | if __name__ == '__main__':
8 | net = ENASMicroNetwork(
9 | num_layers=2,
10 | num_nodes=3,
11 | ).cuda()
12 | m = RandomMutator(net)
13 | m.reset()
14 | x = torch.rand(2,3,64,64).cuda()
15 | output = net(x)
16 | print(output.shape)
17 | pass
--------------------------------------------------------------------------------
/tests/networks/test_mobile_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from hyperbox.networks.mobilenet.mobile_net import MobileNet
5 | from hyperbox.networks.mobilenet.mobile3d_net import Mobile3DNet
6 |
7 | if __name__ == '__main__':
8 | from hyperbox.mutator.random_mutator import RandomMutator
9 | net = MobileNet()
10 | m = RandomMutator(net)
11 | m.reset()
12 | x = torch.rand(2,3,64,64)
13 | output = net(x)
14 | print(output.shape)
15 |
16 |
17 | net = Mobile3DNet()
18 | m = RandomMutator(net)
19 | m.reset()
20 | x = torch.rand(2,3,64,64,64)
21 | output = net(x)
22 | print(output.shape)
--------------------------------------------------------------------------------
/tests/networks/test_resnet.py:
--------------------------------------------------------------------------------
1 |
2 | from hyperbox.networks.resnet import resnet20
3 |
4 | if __name__ == '__main__':
5 | size = lambda net: sum([p.numel() for name, p in net.named_parameters() if 'value' not in name])
6 | net = resnet20()
7 | params_num = size(net)
8 | mb = params_num * 4 / 1024**2
9 |     print(f"#params of {net.__class__.__name__}={params_num} ({mb:.2f} MB)")
10 |
--------------------------------------------------------------------------------
/tests/smoke/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/smoke/__init__.py
--------------------------------------------------------------------------------
/tests/smoke/test_commands.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.run_command import run_command
4 | from tests.helpers.runif import RunIf
5 |
6 |
7 | def test_fast_dev_run():
8 | """Run 1 train, val, test batch."""
9 | command = ["run.py", "trainer=default", "trainer.fast_dev_run=true"]
10 | run_command(command)
11 |
12 |
13 | def test_default_cpu():
14 | """Test default configuration on CPU."""
15 | command = ["run.py", "trainer.max_epochs=1", "trainer.gpus=0"]
16 | run_command(command)
17 |
18 |
19 | @RunIf(min_gpus=1)
20 | def test_default_gpu():
21 | """Test default configuration on GPU."""
22 | command = [
23 | "run.py",
24 | "trainer.max_epochs=1",
25 | "trainer.gpus=1",
26 | "datamodule.pin_memory=True",
27 | ]
28 | run_command(command)
29 |
30 |
31 | @pytest.mark.slow
32 | def test_experiments():
33 | """Train 1 epoch with all experiment configs."""
34 | command = ["run.py", "-m", "experiment=glob(*)", "trainer.max_epochs=1"]
35 | run_command(command)
36 |
37 |
38 | def test_limit_batches():
39 | """Train 1 epoch on 25% of data."""
40 | command = [
41 | "run.py",
42 | "trainer=default",
43 | "trainer.max_epochs=1",
44 | "trainer.limit_train_batches=0.25",
45 | "trainer.limit_val_batches=0.25",
46 | "trainer.limit_test_batches=0.25",
47 | ]
48 | run_command(command)
49 |
50 |
51 | def test_gradient_accumulation():
52 | """Train 1 epoch with gradient accumulation."""
53 | command = [
54 | "run.py",
55 | "trainer=default",
56 | "trainer.max_epochs=1",
57 | "trainer.accumulate_grad_batches=10",
58 | ]
59 | run_command(command)
60 |
61 |
62 | def test_double_validation_loop():
63 | """Train 1 epoch with validation loop twice per epoch."""
64 | command = [
65 | "run.py",
66 | "trainer=default",
67 | "trainer.max_epochs=1",
68 | "trainer.val_check_interval=0.5",
69 | ]
70 | run_command(command)
71 |
72 |
73 | def test_csv_logger():
74 | """Train 5 epochs with 5 batches with CSVLogger."""
75 | command = [
76 | "run.py",
77 | "trainer=default",
78 | "trainer.max_epochs=5",
79 | "trainer.limit_train_batches=5",
80 | "logger=csv",
81 | ]
82 | run_command(command)
83 |
84 |
85 | def test_tensorboard_logger():
86 | """Train 5 epochs with 5 batches with TensorboardLogger."""
87 | command = [
88 | "run.py",
89 | "trainer=default",
90 | "trainer.max_epochs=5",
91 | "trainer.limit_train_batches=5",
92 | "logger=tensorboard",
93 | ]
94 | run_command(command)
95 |
96 |
97 | def test_overfit_batches():
98 | """Overfit to 10 batches over 10 epochs."""
99 | command = [
100 | "run.py",
101 | "trainer=default",
102 | "trainer.min_epochs=10",
103 | "trainer.max_epochs=10",
104 | "trainer.overfit_batches=10",
105 | ]
106 | run_command(command)
107 |
--------------------------------------------------------------------------------
/tests/smoke/test_mixed_precision.py:
--------------------------------------------------------------------------------
1 | from tests.helpers.run_command import run_command
2 | from tests.helpers.runif import RunIf
3 |
4 |
5 | @RunIf(amp_apex=True)
6 | def test_apex_01():
7 | """Test mixed-precision level 01."""
8 | command = [
9 | "run.py",
10 | "trainer=default",
11 | "trainer.max_epochs=1",
12 | "trainer.gpus=1",
13 | "trainer.amp_backend=apex",
14 | "trainer.amp_level=O1",
15 | "trainer.precision=16",
16 | ]
17 | run_command(command)
18 |
19 |
20 | @RunIf(amp_apex=True)
21 | def test_apex_02():
22 | """Test mixed-precision level 02."""
23 | command = [
24 | "run.py",
25 | "trainer=default",
26 | "trainer.max_epochs=1",
27 | "trainer.gpus=1",
28 | "trainer.amp_backend=apex",
29 | "trainer.amp_level=O2",
30 | "trainer.precision=16",
31 | ]
32 | run_command(command)
33 |
34 |
35 | @RunIf(amp_apex=True)
36 | def test_apex_03():
37 | """Test mixed-precision level 03."""
38 | command = [
39 | "run.py",
40 | "trainer=default",
41 | "trainer.max_epochs=1",
42 | "trainer.gpus=1",
43 | "trainer.amp_backend=apex",
44 | "trainer.amp_level=O3",
45 | "trainer.precision=16",
46 | ]
47 | run_command(command)
48 |
--------------------------------------------------------------------------------
/tests/smoke/test_sweeps.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.run_command import run_command
4 |
5 | """
6 | Use the following command to skip slow tests:
7 | pytest -k "not slow"
8 | """
9 |
10 |
11 | @pytest.mark.slow
12 | def test_default_sweep():
13 | """Test default Hydra sweeper."""
14 | command = [
15 | "run.py",
16 | "-m",
17 | "datamodule.batch_size=64,128",
18 | "model.lr=0.01,0.02",
19 | "trainer=default",
20 | "trainer.fast_dev_run=true",
21 | ]
22 | run_command(command)
23 |
24 |
25 | @pytest.mark.slow
26 | def test_optuna_sweep():
27 | """Test Optuna sweeper."""
28 | command = [
29 | "run.py",
30 | "-m",
31 | "hparams_search=mnist_optuna",
32 | "trainer=default",
33 | "trainer.fast_dev_run=true",
34 | ]
35 | run_command(command)
36 |
37 |
38 | @pytest.mark.skip(reason="TODO: Add Ax sweep config.")
39 | @pytest.mark.slow
40 | def test_ax_sweep():
41 | """Test Ax sweeper."""
42 | command = ["run.py", "-m", "hparams_search=mnist_ax", "trainer.fast_dev_run=true"]
43 | run_command(command)
44 |
--------------------------------------------------------------------------------
/tests/smoke/test_wandb.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.run_command import run_command
4 |
5 | """
6 | Use the following command to skip slow tests:
7 | pytest -k "not slow"
8 | """
9 |
10 |
11 | # @pytest.mark.slow
12 | # def test_wandb_optuna_sweep():
13 | # """Test wandb logging with Optuna sweep."""
14 | # command = [
15 | # "run.py",
16 | # "-m",
17 | # "hparams_search=mnist_optuna",
18 | # "trainer=default",
19 | # "trainer.max_epochs=10",
20 | # "trainer.limit_train_batches=20",
21 | # "logger=wandb",
22 | # "logger.wandb.project=template-tests",
23 | # "logger.wandb.group=Optuna_SimpleDenseNet_MNIST",
24 | # "hydra.sweeper.n_trials=5",
25 | # ]
26 | # run_command(command)
27 |
28 |
29 | # @pytest.mark.slow
30 | # def test_wandb_callbacks():
31 | # """Test wandb callbacks."""
32 | # command = [
33 | # "run.py",
34 | # "trainer=default",
35 | # "trainer.max_epochs=3",
36 | # "logger=wandb",
37 | # "logger.wandb.project=template-tests",
38 | # "callbacks=wandb",
39 | # ]
40 | # run_command(command)
41 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/test_sth.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.runif import RunIf
4 |
5 |
6 | def test_something1():
7 | """Some test description."""
8 | assert True is True
9 |
10 |
11 | def test_something2():
12 | """Some test description."""
13 | assert 1 + 1 == 2
14 |
15 |
16 | @pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
17 | def test_something3(arg1: float):
18 | """Some test description."""
19 | assert arg1 > 0
20 |
21 |
22 | # use RunIf to skip execution of some tests, e.g. when not on windows or when no gpus are available
23 | @RunIf(skip_windows=True, min_gpus=1)
24 | def test_something4():
25 | """Some test description."""
26 | assert True is True
27 |
--------------------------------------------------------------------------------
/tests/utils/test_calc_model_size.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from hyperbox.mutator import RandomMutator
5 | from hyperbox.mutables import ops, spaces
6 | from hyperbox.networks.mobilenet.mobile_net import MobileNet
7 | from hyperbox.utils.calc_model_size import flops_size_counter
8 |
9 |
10 | class HybridNet(torch.nn.Module):
11 | def __init__(self, mask=None):
12 | super().__init__()
13 | self.op_space1 = spaces.OperationSpace(
14 | candidates=[
15 | torch.nn.Conv2d(3,20,3,padding=1),
16 | torch.nn.Conv2d(3,20,5,padding=2)
17 | ],
18 | mask=mask
19 | )
20 | self.op_space2 = spaces.OperationSpace(
21 | candidates=[
22 | torch.nn.Conv2d(3,20,7,padding=3),
23 | torch.nn.Conv2d(3,20,3,padding=1)
24 | ],
25 | mask=mask
26 | )
27 | self.input_space = spaces.InputSpace(n_candidates=2, n_chosen=1, mask=mask)
28 |
29 | vs_channel = spaces.ValueSpace([24,16,18], mask=mask)
30 | vs_kernel = spaces.ValueSpace([3,5,7], mask=mask)
31 | vs_stride = spaces.ValueSpace([1,2], mask=mask)
32 | self.finegrained_conv = ops.Conv2d(20, vs_channel, vs_kernel, stride=vs_stride, auto_padding=True)
33 | self.finegrained_bn = ops.BatchNorm2d(vs_channel)
34 | self.finegrained_linear = ops.Linear(vs_channel, 1)
35 |
36 | def forward(self, x):
37 | bs = x.size(0)
38 | out1 = self.op_space1(x)
39 | out2 = self.op_space2(x)
40 | out = self.input_space([out1, out2])
41 | out = self.finegrained_conv(out)
42 | out = self.finegrained_bn(out)
43 | out = torch.nn.AdaptiveAvgPool2d(1)(out)
44 | out = out.view(bs, -1)
45 | out = self.finegrained_linear(out)
46 | return out
47 |
48 |
49 | if __name__ == '__main__':
50 |
51 | net = MobileNet()
52 | mutator = RandomMutator(net)
53 | mutator.reset()
54 |
55 | x = torch.rand(1,3,128,128)
56 | result = flops_size_counter(net, (x,))
57 | flops, params = [result[k] for k in result]
58 | print(f"{flops} MFLOPS, {params:.4f} MB")
59 |
60 | x = torch.rand(1,3,64,64)
61 | result = flops_size_counter(net, (x,))
62 | flops, params = [result[k] for k in result]
63 | print(f"{flops} MFLOPS, {params:.4f} MB")
64 |
65 | mask = {
66 | "OperationSpace1": [True, False],
67 | "OperationSpace2": [False, True],
68 | "InputSpace3": [True, False],
69 | "ValueSpace6": [True, False],
70 | "ValueSpace4": [False, True, False],
71 | "ValueSpace5": [False, False, True]
72 | }
73 | net = HybridNet(mask)
74 | x = torch.rand(1,3,64,64)
75 | result = flops_size_counter(net, (x,))
76 | flops, params = [result[k] for k in result]
77 | print(f"{flops} MFLOPS, {params:.4f} MB")
78 |
79 | net = HybridNet()
80 | mutator = RandomMutator(net)
81 | mutator.reset()
82 | x = torch.rand(1,3,32,32)
83 | result = flops_size_counter(net, (x,))
84 | flops, params = [result[k] for k in result]
85 | print(f"{flops} MFLOPS, {params:.4f} MB")
86 |
--------------------------------------------------------------------------------
/tests/utils/test_hparams.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from hyperbox.utils.utils import hparams_wrapper
4 |
5 | @hparams_wrapper
6 | class MyClass1(torch.nn.Module):
7 | def __init__(self, arg1):
8 | super(MyClass1, self).__init__()
9 | self.arg1 = arg1
10 | self.linear = torch.nn.Linear(3, arg1)
11 |
12 | class MyClass2(MyClass1):
13 | def __init__(self, arg1, arg2, arg3=False):
14 | super(MyClass2, self).__init__(arg1)
15 | self.arg1 = arg1
16 | self.arg2 = arg2
17 | self.arg3 = arg3
18 | self.conv = torch.nn.Conv2d(3, arg1, arg2, bias=arg3)
19 |
20 | def __setattr__(self, name, value):
21 | if name != 'key':
22 | super(MyClass2, self).__setattr__(name, value)
23 |
24 |
25 | if __name__ == '__main__':
26 |     print(MyClass1.__mro__) # (<class '__main__.MyClass1'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
27 |     print(MyClass2.__mro__) # (<class '__main__.MyClass2'>, <class '__main__.MyClass1'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
28 | # Test 1
29 | obj1 = MyClass1(1)
30 | print(obj1.hparams) # Output: {'arg1': 1}
31 | print(obj1) # Output:
32 | # MyClass1(
33 | # (linear): Linear(in_features=3, out_features=1, bias=True)
34 | # )
35 |
36 | # Test 2
37 | obj2 = MyClass2(3,4)
38 | print(obj2.hparams) # Output: {'arg3': False, 'arg1': 3, 'arg2': 4}
39 | print(obj2) # Output:
40 | # MyClass2(
41 | # (conv): Conv2d(3, 3, kernel_size=(4, 4), stride=(1, 1), bias=False)
42 | # )
43 |
44 | # Test 3
45 | from hyperbox.mutables.spaces import OperationSpace
46 | obj3 = OperationSpace([torch.nn.Linear(2,3),torch.nn.Linear(3,2)])
47 | print(obj3.hparams) # Output: {'mask': None, 'index': None, 'reduction': 'sum', 'return_mask': False, 'key': None, 'candidates': [Linear(in_features=2, out_features=3, bias=True), Linear(in_features=3, out_features=2, bias=True)]}
48 | print(obj3) # Output:
49 | # OperationSpace(
50 | # (candidates): ModuleList(
51 | # (0): Linear(in_features=2, out_features=3, bias=True)
52 | # (1): Linear(in_features=3, out_features=2, bias=True)
53 | # )
54 | # )
55 |
--------------------------------------------------------------------------------
/tests/utils/test_logger.py:
--------------------------------------------------------------------------------
1 | from hyperbox.utils.logger import get_logger
2 |
3 | if __name__ == '__main__':
4 | logger = get_logger(__name__)
5 | logger.info('This is info')
6 | logger.debug('This is debug')
7 | logger.warning('This is warning')
8 | logger.error('This is error')
--------------------------------------------------------------------------------