├── .github
    └── workflows
    │   ├── environment.yml
    │   └── python-package-conda.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_cn.md
├── conda_env_gpu.yaml
├── distribute.sh
├── docs
    ├── release.md
    └── usage.md
├── examples
    ├── DARTS_CIFAR10.ipynb
    ├── RandomNAS_MNIST.ipynb
    └── ResNet_CIFAR10.ipynb
├── hyperbox
    ├── __init__.py
    ├── callbacks
    │   ├── __init__.py
    │   └── wandb_callbacks.py
    ├── configs
    │   ├── __init__.py
    │   ├── callbacks
    │   │   ├── default.yaml
    │   │   ├── none.yaml
    │   │   └── wandb.yaml
    │   ├── config.yaml
    │   ├── datamodule
    │   │   ├── cifar100_datamodule.yaml
    │   │   ├── cifar10_datamodule.yaml
    │   │   ├── fakedata_datamodule.yaml
    │   │   ├── imagenet_dali_datamodule.yaml
    │   │   ├── imagenet_datamodule.yaml
    │   │   ├── medmnist_datamodule.yaml
    │   │   ├── mnist_datamodule.yaml
    │   │   └── transforms
    │   │   │   └── cifar.yaml
    │   ├── engine
    │   │   └── none.yaml
    │   ├── experiment
    │   │   ├── example_bnnas.yaml
    │   │   ├── example_classify.yaml
    │   │   ├── example_darts_nas.yaml
    │   │   ├── example_full.yaml
    │   │   ├── example_nasbench.yaml
    │   │   ├── example_ofa_nas.yaml
    │   │   ├── example_random_nas.yaml
    │   │   ├── example_repnas.yaml
    │   │   └── example_simple.yaml
    │   ├── hparams_search
    │   │   └── mnist_optuna.yaml
    │   ├── hydra
    │   │   └── default.yaml
    │   ├── lite.yaml
    │   ├── logger
    │   │   ├── comet.yaml
    │   │   ├── csv.yaml
    │   │   ├── many_loggers.yaml
    │   │   ├── mlflow.yaml
    │   │   ├── neptune.yaml
    │   │   ├── tensorboard.yaml
    │   │   └── wandb.yaml
    │   ├── model
    │   │   ├── classify_model.yaml
    │   │   ├── darts_model.yaml
    │   │   ├── loss_cfg
    │   │   │   ├── cross_entropy.yaml
    │   │   │   └── cross_entropy_labelsmooth.yaml
    │   │   ├── metric_cfg
    │   │   │   └── accuracy.yaml
    │   │   ├── mnist_model.yaml
    │   │   ├── mutator_cfg
    │   │   │   ├── darts_multiple_mutator.yaml
    │   │   │   ├── darts_mutator.yaml
    │   │   │   ├── ea_mutator.yaml
    │   │   │   ├── enas_mutator.yaml
    │   │   │   ├── evolution_mutator.yaml
    │   │   │   ├── fairdarts_mutator.yaml
    │   │   │   ├── fairnas_mutator.yaml
    │   │   │   ├── fewshot_mutator.yaml
    │   │   │   ├── onehot_mutator.yaml
    │   │   │   ├── random_multiple_mutator.yaml
    │   │   │   ├── random_mutator.yaml
    │   │   │   └── repnas_mutator.yaml
    │   │   ├── nasbench_model.yaml
    │   │   ├── network_cfg
    │   │   │   ├── bn_nas.yaml
    │   │   │   ├── darts_network.yaml
    │   │   │   ├── enas_macro.yaml
    │   │   │   ├── enas_micro.yaml
    │   │   │   ├── finegrained_resnet.yaml
    │   │   │   ├── mobilenet2d_nas.yaml
    │   │   │   ├── mobilenet3d_nas.yaml
    │   │   │   ├── nasbench201.yaml
    │   │   │   ├── nasbench_mbnet.yaml
    │   │   │   ├── ofa_mbv3.yaml
    │   │   │   ├── proxylessnas.yaml
    │   │   │   ├── repnas_network.yaml
    │   │   │   └── torch_network.yaml
    │   │   ├── ofa_model.yaml
    │   │   ├── optimizer_cfg
    │   │   │   ├── adadelta.yaml
    │   │   │   ├── adam.yaml
    │   │   │   ├── adamw.yaml
    │   │   │   ├── asam.yaml
    │   │   │   ├── lamb.yaml
    │   │   │   ├── rmsprop.yaml
    │   │   │   ├── sam.yaml
    │   │   │   └── sgd.yaml
    │   │   ├── random_model.yaml
    │   │   ├── repnas_model.yaml
    │   │   ├── resnet18.yaml
    │   │   └── scheduler_cfg
    │   │   │   ├── CosineAnnealingLR.yaml
    │   │   │   ├── ExponentialLR.yaml
    │   │   │   ├── MultiStepLR.yaml
    │   │   │   ├── ReducedLRonPlateau.yaml
    │   │   │   └── warmup_scheduler.yaml
    │   ├── paths
    │   │   └── default.yaml
    │   └── trainer
    │   │   ├── ddp.yaml
    │   │   └── default.yaml
    ├── datamodules
    │   ├── __init__.py
    │   ├── cifar_datamodule.py
    │   ├── datasets
    │   │   └── __init__.py
    │   ├── distributed_sampler_wrapper.py
    │   ├── fakedata_datamodule.py
    │   ├── imagenet_dali_datamodule.py
    │   ├── imagenet_datamodule.py
    │   ├── medmnist_datamodule.py
    │   ├── mnist_datamodule.py
    │   └── transforms
    │   │   ├── __init__.py
    │   │   ├── albumentation_transforms.py
    │   │   ├── autoaugment.py
    │   │   ├── base_transforms.py
    │   │   ├── cutout.py
    │   │   └── torch_transforms.py
    ├── engine
    │   ├── __init__.py
    │   └── base_engine.py
    ├── lites
    │   └── base_lite.py
    ├── losses
    │   ├── __init__.py
    │   ├── ce_labelsmooth_loss.py
    │   ├── focal_loss.py
    │   └── kd_loss.py
    ├── models
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── classify_model.py
    │   ├── darts_model.py
    │   ├── mnist_model.py
    │   ├── nasbench_model.py
    │   ├── ofa_model.py
    │   └── random_model.py
    ├── mutables
    │   ├── __init__.py
    │   ├── layers
    │   │   ├── __init__.py
    │   │   └── layers2d.py
    │   ├── masker.py
    │   ├── ops
    │   │   ├── __init__.py
    │   │   ├── base_module.py
    │   │   ├── batchnorm.py
    │   │   ├── conv.py
    │   │   ├── embedding.py
    │   │   ├── groupnorm.py
    │   │   ├── layernorm.py
    │   │   ├── linear.py
    │   │   ├── multihead_attention.py
    │   │   └── utils.py
    │   └── spaces.py
    ├── mutator
    │   ├── __init__.py
    │   ├── base_mutator.py
    │   ├── darts_multiple_mutator.py
    │   ├── darts_mutator.py
    │   ├── default_mutator.py
    │   ├── enas_mutator.py
    │   ├── evolution_mutator.py
    │   ├── fairnas_mutator.py
    │   ├── fewshot_mutator.py
    │   ├── fixed_mutator.py
    │   ├── onehot_mutator.py
    │   ├── proxyless_mutator.py
    │   ├── random_multiple_mutator.py
    │   ├── random_mutator.py
    │   ├── repnas_mutator.py
    │   ├── sequential_mutator.py
    │   └── utils.py
    ├── networks
    │   ├── __init__.py
    │   ├── base_nas_network.py
    │   ├── bnnas
    │   │   ├── __init__.py
    │   │   ├── bn_blocks.py
    │   │   ├── bn_net.py
    │   │   └── ea_search.py
    │   ├── darts
    │   │   ├── __init__.py
    │   │   ├── darts_mask.json
    │   │   ├── darts_network.py
    │   │   └── darts_ops.py
    │   ├── enas
    │   │   ├── __init__.py
    │   │   ├── enas_network.py
    │   │   └── enas_ops.py
    │   ├── gpt
    │   │   ├── __init__.py
    │   │   └── gpt2.py
    │   ├── mobilenet
    │   │   ├── __init__.py
    │   │   ├── mobile3d_net.py
    │   │   ├── mobile3d_ops.py
    │   │   ├── mobile_net.py
    │   │   ├── mobile_ops.py
    │   │   └── mobile_utils.py
    │   ├── nasbench101
    │   │   ├── __init__.py
    │   │   ├── base_ops.py
    │   │   ├── db_gen
    │   │   │   ├── __init__.py
    │   │   │   ├── constants.py
    │   │   │   ├── db_gen.py
    │   │   │   ├── graph_util.py
    │   │   │   ├── model.py
    │   │   │   └── query.py
    │   │   ├── graph_util.py
    │   │   ├── model_spec.py
    │   │   ├── nasbench101.py
    │   │   └── readme.md
    │   ├── nasbench201
    │   │   ├── __init__.py
    │   │   ├── db_gen
    │   │   │   ├── __init__.py
    │   │   │   ├── constants.py
    │   │   │   ├── db_gen.py
    │   │   │   ├── model.py
    │   │   │   └── query.py
    │   │   ├── gen_nasbench201_data.sh
    │   │   └── nasbench201.py
    │   ├── nasbench301
    │   │   ├── __init__.py
    │   │   ├── install
    │   │   │   ├── readme.md
    │   │   │   ├── requirements_cu102.txt
    │   │   │   └── requirements_cu111.txt
    │   │   ├── nasbench301_network.py
    │   │   └── utils.py
    │   ├── nasbench_mbnet
    │   │   ├── __init__.py
    │   │   ├── nasbench_mbnet_cifar10.json
    │   │   └── network.py
    │   ├── nasbenchasr
    │   │   ├── __init__.py
    │   │   ├── dataset.py
    │   │   ├── download_nasbenchasr.sh
    │   │   ├── graph_utils.py
    │   │   ├── model.py
    │   │   ├── ops.py
    │   │   ├── readme.md
    │   │   ├── search_space.py
    │   │   └── utils.py
    │   ├── network_ema.py
    │   ├── ofa
    │   │   ├── __init__.py
    │   │   ├── ofa_mbv3.py
    │   │   └── ofa_mbv3_searchspace.json
    │   ├── proxylessnas
    │   │   ├── __init__.py
    │   │   ├── network.py
    │   │   ├── ops.py
    │   │   └── putils.py
    │   ├── pytorch_modules.py
    │   ├── repnas
    │   │   ├── __init__.py
    │   │   ├── rep_ops.py
    │   │   ├── repnas_spos.py
    │   │   └── utils.py
    │   ├── resnet
    │   │   ├── __init__.py
    │   │   └── resnet.py
    │   ├── spos
    │   │   ├── __init__.py
    │   │   ├── shuffle_blocks.py
    │   │   └── spos_net.py
    │   ├── utils.py
    │   └── vit
    │   │   ├── __init__.py
    │   │   └── vit.py
    ├── optimizers
    │   ├── __init__.py
    │   ├── lamb.py
    │   └── sam.py
    ├── run.py
    ├── schedulers
    │   ├── __init__.py
    │   └── warmup_scheduler.py
    ├── train.py
    └── utils
    │   ├── __init__.py
    │   ├── average_meter.py
    │   ├── calc_model_size.py
    │   ├── logger.py
    │   ├── metrics.py
    │   ├── utils.py
    │   ├── visualize_darts_cell.py
    │   └── visualize_mbconv_net.py
├── requirements.txt
├── run.sh
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── datamodules
        └── test_transforms.py
    ├── helpers
        ├── __init__.py
        ├── module_available.py
        ├── run_command.py
        └── runif.py
    ├── models
        └── test_random_nas_model.py
    ├── mutables
        ├── test_all_mutables.py
        ├── test_duplicate_mutables.py
        ├── test_op_bn.py
        ├── test_op_conv.py
        └── test_op_linear.py
    ├── networks
        ├── test_build_subnet.py
        ├── test_darts_net.py
        ├── test_enas_net.py
        ├── test_mobile_net.py
        └── test_resnet.py
    ├── smoke
        ├── __init__.py
        ├── test_commands.py
        ├── test_mixed_precision.py
        ├── test_sweeps.py
        └── test_wandb.py
    ├── unit
        ├── __init__.py
        └── test_sth.py
    └── utils
        ├── test_calc_model_size.py
        ├── test_hparams.py
        └── test_logger.py

/.github/workflows/environment.yml:
--------------------------------------------------------------------------------
name: hyperbox

channels:
  - pytorch
  - conda-forge

dependencies:
  - python=3.10
  - pip
  - cudatoolkit
  - pytorch=1.12.1
  - torchvision=0.13.1
  - pip:
    - -r requirements.txt

--------------------------------------------------------------------------------
/.github/workflows/python-package-conda.yml:
--------------------------------------------------------------------------------
name: Python Package using Conda

on: [push]

jobs:
  build-linux:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v3
        with:
          python-version: '3.10'
      - name: Add conda to system path
        run: |
          # $CONDA is an environment variable pointing to the root of the miniconda directory
          echo $CONDA/bin >> $GITHUB_PATH
      - name: Install dependencies
        run: |
          conda env update --file .github/workflows/environment.yml --name hyperbox
      - name: Lint with flake8
        run: |
          conda install flake8
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Test with pytest
        run: |
          conda install pytest
          pytest

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode

### Macbook
.DS_Store

# JetBrains
.idea/

# Lightning-Hydra-Template
data/
logs/
wandb/
.env
.autoenv

outputs
logs
hyperbox_app/*
!hyperbox_app/empty_app_template

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
default_language_version:
  python: python3.8

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.4.0
    hooks:
      # list of supported hooks: https://pre-commit.com/hooks.html
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: debug-statements
      - id: detect-private-key

  # python code formatting
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
        args: [--line-length, "99"]

  # python import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 5.8.0
    hooks:
      - id: isort

  # yaml formatting
  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v2.3.0
    hooks:
      - id: prettier
        types: [yaml]

  # python code analysis
  - repo: https://github.com/PyCQA/flake8
    rev: 3.9.2
    hooks:
      - id: flake8

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Build: docker build -t project_name .
# Run: docker run --gpus all -it --rm project_name

# Build from official Nvidia PyTorch image
# GPU-ready with Apex for mixed-precision support
# https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
# https://docs.nvidia.com/deeplearning/frameworks/support-matrix/
FROM nvcr.io/nvidia/pytorch:22.07-py3


# Copy all files
ADD . /workspace/hyperbox
WORKDIR /workspace/hyperbox


# Create the hyperbox conda environment
RUN conda env create -f conda_env_gpu.yaml -n hyperbox
RUN conda init bash


# Set hyperbox as the default virtual environment
RUN echo "source activate hyperbox" >> ~/.bashrc

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 KINoAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
recursive-include hyperbox/configs *.yaml

--------------------------------------------------------------------------------
/conda_env_gpu.yaml:
--------------------------------------------------------------------------------
name: hyperbox

channels:
  - pytorch
  - conda-forge

dependencies:
  - python=3.8
  - pip
  - cudatoolkit
  - pytorch=1.8.1
  - torchvision=0.9.1
  - pip:
    - -r requirements.txt

--------------------------------------------------------------------------------
/distribute.sh:
--------------------------------------------------------------------------------
#!/bin/bash

rm -rf build dist hyperbox.egg-info
echo Building...
python setup.py sdist bdist_wheel
if (($1==1)); then
    echo Uploading to PyPI...
    python -m twine upload dist/*
elif (($1==2)); then
    echo Uploading to TestPyPI...
    python -m twine upload --repository-url https://test.pypi.org/legacy/ dist/*
else
    echo "Wrong argument: only 1 (PyPI) or 2 (TestPyPI) is supported"
fi
echo Done.
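# Usage (inferred from the argument checks above):
#   bash distribute.sh 1   # build and upload to PyPI
#   bash distribute.sh 2   # build and upload to TestPyPI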
--------------------------------------------------------------------------------
/hyperbox/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/__init__.py

--------------------------------------------------------------------------------
/hyperbox/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/callbacks/__init__.py

--------------------------------------------------------------------------------
/hyperbox/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/__init__.py

--------------------------------------------------------------------------------
/hyperbox/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: True # additionally always save model from last epoch
  mode: "max" # can be "max" or "min"
  verbose: False
  dirpath: "checkpoints/"
  filename: "{epoch:02d}_{val/acc:.4f}"

early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  patience: 1000 # how many epochs of not improving until training stops
  mode: "max" # can be "max" or "min"
  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
  check_on_train_epoch_end: False
  strict: False

--------------------------------------------------------------------------------
/hyperbox/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/callbacks/none.yaml

--------------------------------------------------------------------------------
/hyperbox/configs/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

watch_model:
  _target_: hyperbox.callbacks.wandb_callbacks.WatchModel
  log: "all"
  log_freq: 100

upload_code_as_artifact:
  _target_: hyperbox.callbacks.wandb_callbacks.UploadCodeAsArtifact
  code_dir: ${work_dir}/src

upload_ckpts_as_artifact:
  _target_: hyperbox.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
  ckpt_dir: "checkpoints/"
  upload_best_only: True

log_f1_precision_recall_heatmap:
  _target_: hyperbox.callbacks.wandb_callbacks.LogF1PrecRecHeatmap

log_confusion_matrix:
  _target_: hyperbox.callbacks.wandb_callbacks.LogConfusionMatrix

log_image_predictions:
  _target_: hyperbox.callbacks.wandb_callbacks.LogImagePredictions
  num_samples: 8

--------------------------------------------------------------------------------
/hyperbox/configs/config.yaml:
--------------------------------------------------------------------------------
# @package _global_

# specify here default training configuration
defaults:
  - trainer: default.yaml
  - model: mnist_model.yaml
  - datamodule: mnist_datamodule.yaml
  - callbacks: default.yaml # set this to null if you don't want to use callbacks
  - logger: wandb # set logger here or use command line (e.g. `python run.py logger=wandb`)
  - engine: none.yaml
  - hparams_search: null

  - paths: default.yaml
  - hydra: default.yaml

  - experiment: null

task_name: "hyperbox_project"

# use `python run.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
# equivalent to running `python run.py trainer.fast_dev_run=true`
# (this is placed here just for easier access from command line)
debug: False

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: True

# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: True

# checkpoint for testing. If specified, only test will be performed
only_test: False

# only load pretrained weights for network (exclude optimizer and scheduler)
# e.g., 'logs/run/exp_name/checkpoint/epoch=66/acc=66.66.ckpt'
pretrained_weight: null
ipdb_debug: False

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/cifar100_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.cifar_datamodule.CIFAR100DataModule

defaults:
  - transforms: cifar

data_dir: ~/datasets/cifar100
val_split: 0.1
num_workers: 4
normalize: True
batch_size: 64
seed: 666
shuffle: True
pin_memory: False
drop_last: False
is_customized: False

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/cifar10_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.cifar_datamodule.CIFAR10DataModule

defaults:
  - transforms: cifar

data_dir: ~/datasets/cifar10
val_split: 0.5
num_workers: 4
normalize: True
batch_size: 64
seed: 666
shuffle: True
pin_memory: False
drop_last: False
is_customized: False

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/fakedata_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.fakedata_datamodule.FakeDataModule

train_size: 4000
test_size: 500
image_size: [3, 64, 64]
num_classes: 10
data_dir: /path/to/data # data_dir is specified in config.yaml
batch_size: 8
train_val_test_split: [0.6, 0.2, 0.2]
num_workers: 0
pin_memory: False

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/imagenet_dali_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.imagenet_datamodule.ImagenetDALIDataModule
dataset_path: /datasets/imagenet100
batch_size: 4
num_workers: 8
dali_cpu: False
val_size: 256
crop_size: 224

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/imagenet_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.imagenet_datamodule.ImagenetDataModule
data_dir: ~/datasets/imagenet2012
classes: 1000
autoaugment: True
image_size: 224
batch_size: 64
shuffle: True
pin_memory: True
drop_last: False
num_workers: 8

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/medmnist_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.medmnist_datamodule.MedMNISTDataModule
data_dir: '~/datasets/medmnist/'
data_flag: 'synapsemnist3d'
batch_size: 128
num_workers: 8

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/mnist_datamodule.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.datamodules.mnist_datamodule.MNISTDataModule

data_dir: /path/to/data # data_dir is specified in config.yaml
batch_size: 32
num_workers: 0
pin_memory: False

--------------------------------------------------------------------------------
/hyperbox/configs/datamodule/transforms/cifar.yaml:
--------------------------------------------------------------------------------
input_size: [32, 32]
random_crop:
  enable: 1
  padding: 4
  size: 32
random_horizontal_flip:
  enable: 1
  p: 0.5
cutout:
  enable: 1
  n_holes: 1
  length: 16
normalize:
  enable: 1
  mean: [0.4914, 0.4822, 0.4465]
  std: [0.2023, 0.1994, 0.2010]

--------------------------------------------------------------------------------
/hyperbox/configs/engine/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/configs/engine/none.yaml

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_bnnas.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_bnnas.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: random_model.yaml
  - override /model/network_cfg: bn_nas.yaml
  - override /datamodule: cifar10_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml
  - override /model/scheduler_cfg: CosineAnnealingLR.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  max_epochs: 200
  strategy: horovod
  gpus: 1
debug: False
# logger.wandb.name: test
# logger.wandb.offline: True
# hydra:
#   job:
#     name: test

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_classify.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_classify.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: classify_model.yaml
  - override /datamodule: cifar10_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 600

logger:
  wandb:
    offline: True
    project: Classification

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_darts_nas.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_darts_nas.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: darts_model.yaml
  - override /datamodule: fakedata_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 200

datamodule:
  is_customized: True

# model:
#   mutator_cfg:
#     _target_: hyperbox.mutator.OnehotMutator

logger:
  wandb:
    offline: True

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_full.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_full.yaml

defaults:
  - override /trainer: null # override trainer to null so it's not loaded from main config defaults...
  - override /model: null
  - override /datamodule: null
  - override /callbacks: null
  - override /logger: null

# we override default configurations with nulls to prevent them from loading at all
# instead we define all modules and their paths directly in this config,
# so everything is stored in one place

seed: 12345

trainer:
  _target_: pytorch_lightning.Trainer
  min_epochs: 1
  max_epochs: 10

model:
  _target_: hyperbox.models.mnist_model.MNISTLitModel
  lr: 0.001
  weight_decay: 0.00005
  architecture: SimpleDenseNet
  input_size: 784
  lin1_size: 256
  lin2_size: 256
  lin3_size: 128
  output_size: 10

datamodule:
  _target_: hyperbox.datamodules.mnist_datamodule.MNISTDataModule
  data_dir: /path/to/data
  batch_size: 64
  train_val_test_split: [55_000, 5_000, 10_000]
  num_workers: 0
  pin_memory: False

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/acc"
    save_top_k: 2
    save_last: True
    mode: "max"
    dirpath: "checkpoints/"
    filename: "sample-mnist-{epoch:02d}"
  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val/acc"
    patience: 10
    mode: "max"

logger:
  wandb:
    tags: ["best_model", "uwu"]
    notes: "Description of this model."
  neptune:
    tags: ["best_model"]
  csv_logger:
    save_dir: "."

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_nasbench.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_nasbench.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: darts_model.yaml
  - override /datamodule: cifar10_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml
  - override /model/scheduler_cfg: CosineAnnealingLR.yaml


# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 200
  gpus: 1

datamodule:
  is_customized: True
  # is_customized: False

model:
  mutator_cfg:
    _target_: hyperbox.mutator.DartsMultipleMutator
    # _target_: hyperbox.mutator.RandomMultipleMutator
  optimizer_cfg:
    lr: 0.001

logger:
  wandb:
    offline: True

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_ofa_nas.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_ofa_nas.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: ofa_model.yaml
  - override /datamodule: cifar10_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 100
logger:
  wandb:
    project: "OFA"
    offline: False
callbacks:
  model_checkpoint:
    monitor: "val/acc" # name of the logged metric which determines when model is improving
    save_top_k: 2 # save k best models (determined by above metric)

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_random_nas.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_random_nas.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: random_model.yaml
  - override /datamodule: fakedata_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_repnas.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_repnas.yaml

# finetune repnas_model.yaml -> classify_model

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: repnas_model.yaml
  - override /datamodule: cifar10_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml
  - override /model/scheduler_cfg: CosineAnnealingLR.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 200
  gpus: 1

datamodule:
  is_customized: True # for darts mutator
  # is_customized: False # for other mutators


model:
  mutator_cfg:
    _target_: hyperbox.mutator.DartsMultipleMutator
    # _target_: hyperbox.mutator.RandomMultipleMutator
  optimizer_cfg:
    lr: 0.001

logger:
  wandb:
    offline: True

--------------------------------------------------------------------------------
/hyperbox/configs/experiment/example_simple.yaml:
--------------------------------------------------------------------------------
# @package _global_

# to execute this experiment run:
# python run.py experiment=example_simple.yaml

defaults:
  - override /trainer: default.yaml # choose trainer from 'configs/trainer/'
  - override /model: mnist_model.yaml
  - override /datamodule: fakedata_datamodule.yaml
  - override /callbacks: default.yaml
  - override /logger: null

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 1
  gradient_clip_val: 0.5

model:
  lin1_size: 128
  lin2_size: 256
  lin3_size: 64
  lr: 0.002

datamodule:
  batch_size: 64
  image_size: [1,28,28]

--------------------------------------------------------------------------------
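Every config in this dump is consumed through Hydra's `_target_` convention. The snippet below is a minimal, hypothetical sketch of an entry point showing how these files compose; the project's real entry points are hyperbox/run.py and hyperbox/train.py (not included in this dump) and may differ in detail:

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

@hydra.main(config_path="hyperbox/configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # Hydra resolves each `_target_` to a class and calls it with the remaining keys
    datamodule = instantiate(cfg.datamodule)  # e.g. MNISTDataModule from mnist_datamodule.yaml
    model = instantiate(cfg.model)            # e.g. MNISTLitModel from mnist_model.yaml
    trainer = instantiate(cfg.trainer)        # e.g. pytorch_lightning.Trainer, as in example_full.yaml
    trainer.fit(model, datamodule=datamodule)

if __name__ == "__main__":
    main()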
/hyperbox/configs/hparams_search/mnist_optuna.yaml:
--------------------------------------------------------------------------------
# @package _global_

# example hyperparameter optimization of some experiment with Optuna:
# python run.py -m hparams_search=mnist_optuna experiment=example_simple
# python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30
# python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb

defaults:
  - override /hydra/sweeper: optuna

# choose metric which will be optimized by Optuna
optimized_metric: "val/acc"

hydra:
  # here we define Optuna hyperparameter search
  # it optimizes for value returned from function with @hydra.main decorator
  # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
    storage: null
    study_name: null
    n_jobs: 1

    # 'minimize' or 'maximize' the objective
    direction: maximize

    # number of experiments that will be executed
    n_trials: 20

    # choose Optuna hyperparameter sampler
    # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 12345
      consider_prior: true
      prior_weight: 1.0
      consider_magic_clip: true
      consider_endpoints: false
      n_startup_trials: 10
      n_ei_candidates: 24
      multivariate: false
      warn_independent_sampling: true

    # define range of hyperparameters
    search_space:
      datamodule.batch_size:
        type: categorical
        choices: [32, 64, 128]
      model.lr:
        type: float
        low: 0.0001
        high: 0.2
      model.lin1_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]
      model.lin2_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]
      model.lin3_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]

--------------------------------------------------------------------------------
/hyperbox/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
# output paths for hydra logs

defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

run:
  dir: ${paths.log_dir}/runs/${task_name}/${now:%Y-%m-%d_%H-%M-%S}
sweep:
  dir: ${paths.log_dir}/multiruns/${task_name}/${now:%Y-%m-%d_%H-%M-%S}
  subdir: ${hydra.job.num}

# you can set here environment variables that are universal for all users
# for system specific variables (like data paths) it's better to use .env file!
job:
  env_set:
    EXAMPLE_VAR: "example_value"
  name: "exp"
  chdir: True # the output/working directory will be changed to {hydra.job.name}; see the URL below for details
  # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory/#disable-changing-current-working-dir-to-jobs-output-dir

--------------------------------------------------------------------------------
/hyperbox/configs/lite.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - model: darts_model.yaml
  - datamodule: cifar10_datamodule.yaml
  - logger: wandb.yaml
  - hydra: default.yaml

lite:
  accelerator: "auto" # gpu, cpu, ipu, tpu, auto
  strategy: null # "dp", "ddp", "ddp_spawn", "tpu_spawn", "deepspeed", "ddp_sharded", or "ddp_sharded_spawn"
  devices: null
  num_nodes: 1
  precision: 32
  plugins: null
  gpus: null
  tpu_cores: null

ipdb_debug: False
logger:
  wandb:
    name: 'lite'
hydra:
  job:
    name: 'lite'

--------------------------------------------------------------------------------
/hyperbox/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
# https://www.comet.ml

comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
  project_name: "template-tests"
  experiment_name: null

--------------------------------------------------------------------------------
/hyperbox/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
# csv logger built in lightning

csv:
  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
  save_dir: "."
  name: "csv/"
  version: null
  prefix: ""

--------------------------------------------------------------------------------
/hyperbox/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
# train with many loggers at once

defaults:
  # - aim.yaml
  # - comet.yaml
  # - csv.yaml
  # - mlflow.yaml
  # - neptune.yaml
  - tensorboard.yaml
  - wandb.yaml

--------------------------------------------------------------------------------
/hyperbox/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
# https://mlflow.org

mlflow:
  _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
  experiment_name: default
  tracking_uri: null
  tags: null
  save_dir: ./mlruns
  prefix: ""
  artifact_location: null

--------------------------------------------------------------------------------
/hyperbox/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
# https://neptune.ai

neptune:
  _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
  project_name: your_name/template-tests
  close_after_fit: True
  offline_mode: False
  experiment_name: null
  experiment_id: null
  prefix: ""

--------------------------------------------------------------------------------
/hyperbox/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
# https://www.tensorflow.org/tensorboard/

tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: "default"
  version: null
  log_graph: False
  default_hp_metric: True
  prefix: ""

--------------------------------------------------------------------------------
/hyperbox/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
# https://wandb.ai

wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "template-tests"
  name: null
  save_dir: "."
  offline: True # set True to store all logs only locally
  id: null # pass correct id to resume experiment!
  # entity: "" # set to name of your wandb team or just remove it
  log_model: False
  prefix: ""
  job_type: "train"
  group: ""
  tags: []

--------------------------------------------------------------------------------
/hyperbox/configs/model/classify_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.classify_model.ClassifyModel

defaults:
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - optimizer_cfg: sgd
  - network_cfg: darts_network
  - scheduler_cfg: null

optimizer_cfg:
  lr: 0.025
  weight_decay: 0.0003

--------------------------------------------------------------------------------
/hyperbox/configs/model/darts_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.darts_model.DARTSModel

defaults:
  - network_cfg: finegrained_resnet
  - mutator_cfg: darts_mutator
  - optimizer_cfg: adam
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - scheduler_cfg: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/loss_cfg/cross_entropy.yaml:
--------------------------------------------------------------------------------
_target_: torch.nn.CrossEntropyLoss

--------------------------------------------------------------------------------
/hyperbox/configs/model/loss_cfg/cross_entropy_labelsmooth.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.losses.ce_labelsmooth_loss.CrossEntropyLabelSmooth
label_smoothing: 0.1

--------------------------------------------------------------------------------
/hyperbox/configs/model/metric_cfg/accuracy.yaml:
--------------------------------------------------------------------------------
_target_: torchmetrics.classification.accuracy.Accuracy
# _target_: hyperbox.utils.metrics.Accuracy

--------------------------------------------------------------------------------
/hyperbox/configs/model/mnist_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.mnist_model.MNISTLitModel

input_size: 784
lin1_size: 256
lin2_size: 256
lin3_size: 256
output_size: 10
lr: 0.001
weight_decay: 0.0005

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/darts_multiple_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.DartsMultipleMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/darts_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.DartsMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/ea_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.EAMutator
init_population_mode: 'warmup'
warmup_epochs: 5
num_population: 10
prob_crossover: 0.3
prob_mutation: 0.2
object_keys: [1] # 0: flops & 1: size
target_keys: [0] # 0: meters & 1: speed
algorithm: 'cars' # evolutionary algorithm
offspring_ratio: 1
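The mutator_cfg entries above and below all point at mutator classes that wrap a searchable network. A minimal sketch of that pairing outside of Hydra, under two assumptions: the module paths come straight from the `_target_` fields in these files, and reset() sampling a new architecture follows hyperbox's NNI-style mutator API (the project's actual call signatures may differ):

import torch
from hyperbox.mutator.random_mutator import RandomMutator
from hyperbox.networks.darts.darts_network import DartsNetwork

# arguments mirror network_cfg/darts_network.yaml
net = DartsNetwork(in_channels=3, channels=16, n_layers=8, n_classes=10)
mutator = RandomMutator(net)
mutator.reset()  # sample a new sub-network from the search space
out = net(torch.randn(2, 3, 32, 32))  # forward pass runs the sampled sub-network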
--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/enas_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.EnasMutator
lstm_size: 61
lstm_num_layers: 1
tanh_constant: 1.5
cell_exit_extra_step: False
skip_target: 0.4
branch_bias: 0.25
arch_loss_weight: 0.02 # 0.002: small, 0.02: medium, 0.2: big
reward_weight: 50
temperature: 2
entropy_reduction: 'sum' # 'sum', 'mean'

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/evolution_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.evolution_mutator.EvolutionMutator
warmup_epochs: 0
evolution_epochs: 100
population_num: 50
selection_alg: 'best'
selection_num: 0.2
crossover_num: 0.5
crossover_prob: 0.1
mutation_num: 0.5
mutation_prob: 0.1
flops_limit: 5000 # MFLOPs
size_limit: 800 # MB
log_dir: 'evolution_logs'
topk: 10
resume_from_checkpoint: null
to_save_checkpoint: True
to_plot_pareto: True
figname: 'evolution_pareto.pdf'

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fairdarts_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.fairdarts_mutator.FairDartsMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fairnas_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.fairnas_mutator.FairNASMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/fewshot_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.fewshot_mutator.FewshotMutator
training_epochs: 100
split_mutable_indices: 2
save_dir: null
to_save_sub_supernets: True

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/onehot_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.OnehotMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/random_multiple_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.RandomMultipleMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/random_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.random_mutator.RandomMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/mutator_cfg/repnas_mutator.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.mutator.RepnasMutator

--------------------------------------------------------------------------------
/hyperbox/configs/model/nasbench_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.nasbench_model.NASBench201Model
defaults:
  - network_cfg: nasbench201
  - mutator_cfg: random_mutator
  - optimizer_cfg: adam
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - scheduler_cfg: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/bn_nas.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.bnnas.bn_net.BNNet
first_stride: 1
first_channels: 24
width_mult: 1
channels_list: [32,40,80,96,192,320]
num_blocks: [2,2,4,4,4,1]
strides_list: [2,2,2,1,2,1]
num_classes: 10
search_depth: False
is_only_train_bn: False
mask: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/darts_network.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.darts.darts_network.DartsNetwork
in_channels: 3
channels: 16 # 36 for imagenet
n_layers: 8 # 20 for imagenet
stem_multiplier: 3
n_nodes: 4
n_classes: 10
# mask: /path/to/hyperbox/hyperbox/networks/darts/darts_mask.json

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/enas_macro.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.enas.ENASMacroGeneralModel
num_layers: 12
out_filters: 24
in_channels: 3
num_classes: 10
dropout_rate: 0.5

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/enas_micro.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.enas.ENASMicroNetwork
num_layers: 2
num_nodes: 5
out_channels: 24
in_channels: 3
dropout_rate: 0.5

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/finegrained_resnet.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.resnet.resnet20

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/mobilenet2d_nas.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.mobilenet.MobileNet
width_stages: [24,40,80,96,192,320]
n_cell_stages: [4,4,4,4,4,1]
stride_stages: [2,1,1,1,2,1]
width_mult: 1
classes: 1000
dropout_rate: 0
bn_param: [0.1, 1e-3]

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/mobilenet3d_nas.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.mobilenet.Mobile3DNet
width_stages: [24,40,80,96,192,320]
n_cell_stages: [4,4,4,4,4,1]
stride_stages: [2,2,2,1,2,1]
width_mult: 1
classes: 1000
dropout_rate: 0
bn_param: [0.1, 1e-3]

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/nasbench201.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.nasbench201.nasbench201.NASBench201Network
stem_out_channels: 16
num_modules_per_stack: 5
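Any of these network_cfg files can also be built standalone, independent of the training pipeline. A small sketch using Hydra's instantiate and OmegaConf (both real APIs; the file path simply mirrors the tree above):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("hyperbox/configs/model/network_cfg/nasbench201.yaml")
supernet = instantiate(cfg)  # builds NASBench201Network(stem_out_channels=16, num_modules_per_stack=5)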
--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/nasbench_mbnet.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.nasbench_mbnet.network.NASBenchMBNet
arch_list: null
num_classes: 10
stages: [2,3,3]
init_channels: 32
mask: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/ofa_mbv3.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.ofa.OFAMobileNetV3
kernel_size_list: [3,5,7]
expand_ratio_list: [3,4,6]
depth_list: [2,3,4]
base_stage_width: [16,16,24,40,80,112,160,960,1280]
stride_stages: [1,2,2,2,1,2]
act_stages: ['relu','relu','relu','h_swish','h_swish','h_swish']
se_stages: [False,False,True,False,True,True]
width_mult: 1.0
num_classes: 1000

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/proxylessnas.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.proxylessnas.network.ProxylessNAS
width_stages: [24,40,80,96,192,320]
n_cell_stages: [4,4,4,4,4,1]
stride_stages: [2,2,2,1,2,1]
width_mult: 1
num_classes: 1000
dropout_rate: 0
bn_param: [0.1,1e-3]
mask: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/repnas_network.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.networks.repnas.RepNAS
# in_channels: 3
# channels: 16
# n_layers: 8
# stem_multiplier: 3
# n_nodes: 4
# n_classes: 10
# mask: /home/comp/18481086/code/hyperbox/hyperbox/networks/darts/darts_mask.json

--------------------------------------------------------------------------------
/hyperbox/configs/model/network_cfg/torch_network.yaml:
--------------------------------------------------------------------------------
_target_: torchvision.models.mobilenet_v3_small
num_classes: 10

--------------------------------------------------------------------------------
/hyperbox/configs/model/ofa_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.ofa_model.OFAModel

defaults:
  - network_cfg: finegrained_resnet
  - mutator_cfg: random_mutator
  - optimizer_cfg: adam
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - scheduler_cfg: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adadelta.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.Adadelta
lr: 0.01
rho: 0.9
eps: 1e-08
weight_decay: 0.005

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adam.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.Adam
lr: 0.001
weight_decay: 0.0005

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/adamw.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.AdamW
lr: 0.001
weight_decay: 0.001

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/asam.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.optimizers.ASAM
rho: 0.5
eta: 0.01

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/lamb.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.optimizers.lamb.Lamb
lr: 1e-3
betas: [0.9,0.999]
eps: 1e-6
weight_decay: 5e-4
adam: False

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/rmsprop.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.RMSprop
lr: 0.01
alpha: 0.99
eps: 1e-08
weight_decay: 0.005
momentum: 0.9
centered: False

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/sam.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.optimizers.SAM
rho: 0.5
eta: 0.01

--------------------------------------------------------------------------------
/hyperbox/configs/model/optimizer_cfg/sgd.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.SGD
lr: 0.01
weight_decay: 0.0005
momentum: 0.9

--------------------------------------------------------------------------------
/hyperbox/configs/model/random_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.random_model.RandomModel

defaults:
  - network_cfg: darts_network
  - mutator_cfg: random_mutator
  - optimizer_cfg: adam
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - scheduler_cfg: null

--------------------------------------------------------------------------------
/hyperbox/configs/model/repnas_model.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.darts_model.DARTSModel
# _target_: hyperbox.models.classify_model.ClassifyModel
# _target_: hyperbox.models.random_model.RandomModel

defaults:
  - network_cfg: repnas_network.yaml
  - mutator_cfg: repnas_mutator.yaml
  - optimizer_cfg: sgd
  - metric_cfg: accuracy
  - scheduler_cfg: warmup_scheduler
  - loss_cfg: cross_entropy

# network_cfg:
#   mask: '/home/pdluser/mask_json/mask_epoch_60.json'

--------------------------------------------------------------------------------
/hyperbox/configs/model/resnet18.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.models.classify_model.ClassifyModel

defaults:
  - metric_cfg: accuracy
  - loss_cfg: cross_entropy
  - optimizer_cfg: sgd
  - scheduler_cfg: null

network_cfg:
  _target_: torchvision.models.resnet18
optimizer_cfg:
  lr: 0.001
  weight_decay: 0.0005

--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/CosineAnnealingLR.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 180
eta_min: 1e-4
last_epoch: -1

--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/ExponentialLR.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.lr_scheduler.ExponentialLR
gamma: 0.99
last_epoch: -1

--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/MultiStepLR.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.lr_scheduler.MultiStepLR
milestones: [80, 150, 220, 300]
gamma: 0.8
last_epoch: -1

--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/ReducedLRonPlateau.yaml:
--------------------------------------------------------------------------------
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: 'max' # 'max' for accuracy, 'min' for loss
factor: 0.5
patience: 5
threshold: 0.0001
threshold_mode: 'rel' # or 'abs'
cooldown: 0
min_lr: 1e-6
verbose: False

--------------------------------------------------------------------------------
/hyperbox/configs/model/scheduler_cfg/warmup_scheduler.yaml:
--------------------------------------------------------------------------------
_target_: hyperbox.schedulers.warmup_scheduler.GradualWarmupScheduler
multiplier: 1
warmup_epoch: 10
after_scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: 180
  eta_min: 1e-4
# after_scheduler: null
# after_scheduler:
#   _target_: torch.optim.lr_scheduler.StepLR
#   step_size: 80
#   gamma: 0.1

--------------------------------------------------------------------------------
/hyperbox/configs/paths/default.yaml:
--------------------------------------------------------------------------------
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
root_dir: .
5 | 
6 | # path to data directory
7 | data_dir: ${paths.root_dir}/data/
8 | 
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs/
11 | 
12 | # path to output directory, created dynamically by hydra
13 | # path generation pattern is specified in `configs/hydra/default.yaml`
14 | # use it to store all files generated during the run, like ckpts and metrics
15 | output_dir: ${hydra:runtime.output_dir}
16 | 
17 | # path to working directory
18 | work_dir: ${hydra:runtime.cwd}
-------------------------------------------------------------------------------- /hyperbox/configs/trainer/ddp.yaml: --------------------------------------------------------------------------------
1 | defaults:
2 |   - default.yaml
3 | 
4 | # use "ddp_spawn" instead of "ddp",
5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra
6 | # https://github.com/facebookresearch/hydra/issues/2070
7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
8 | strategy: ddp_spawn
9 | 
10 | accelerator: gpu
11 | devices: 4
12 | num_nodes: 1
13 | sync_batchnorm: True
-------------------------------------------------------------------------------- /hyperbox/configs/trainer/default.yaml: --------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 | 
3 | min_epochs: 1 # prevents early stopping
4 | max_epochs: 10
5 | 
6 | accelerator: gpu
7 | devices: 1
8 | 
9 | # mixed precision for extra speed-up
10 | # precision: 16
11 | 
12 | # set True to ensure deterministic results
13 | # makes training slower but gives more reproducibility than just setting seeds
14 | deterministic: False
-------------------------------------------------------------------------------- /hyperbox/datamodules/__init__.py: --------------------------------------------------------------------------------
1 | from hyperbox.utils.utils import _module_available
2 | from .cifar_datamodule import *
3 | from .fakedata_datamodule import *
4 | from .mnist_datamodule import *
5 | 
6 | if _module_available("medmnist"):
7 |     from .medmnist_datamodule import *
8 | if _module_available("nvidia.dali"):
9 |     from .imagenet_dali_datamodule import *  # the DALI-backed ImageNet datamodule requires nvidia.dali
10 | 
-------------------------------------------------------------------------------- /hyperbox/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/datamodules/datasets/__init__.py
-------------------------------------------------------------------------------- /hyperbox/datamodules/distributed_sampler_wrapper.py: --------------------------------------------------------------------------------
1 | from operator import itemgetter
2 | from typing import Iterator, List, Optional, Union
3 | 
4 | from torch.utils.data import Dataset, DistributedSampler, Sampler
5 | 
6 | __all__ = [
7 |     'DatasetFromSampler',
8 |     'DistributedSamplerWrapper'
9 | ]
10 | 
11 | class DatasetFromSampler(Dataset):
12 |     """Dataset to create indexes from `Sampler`.
13 |     Args:
14 |         sampler: PyTorch sampler
15 |     """
16 | 
17 |     def __init__(self, sampler: Sampler):
18 |         """Initialisation for DatasetFromSampler."""
19 |         self.sampler = sampler
20 |         self.sampler_list = None
21 | 
22 |     def __getitem__(self, index: int):
23 |         """Gets element of the dataset.
24 | Args: 25 | index: index of the element in the dataset 26 | Returns: 27 | Single element by index 28 | """ 29 | if self.sampler_list is None: 30 | self.sampler_list = list(self.sampler) 31 | return self.sampler_list[index] 32 | 33 | def __len__(self) -> int: 34 | """ 35 | Returns: 36 | int: length of the dataset 37 | """ 38 | return len(self.sampler) 39 | 40 | 41 | class DistributedSamplerWrapper(DistributedSampler): 42 | """ 43 | Wrapper over `Sampler` for distributed training. 44 | Allows you to use any sampler in distributed mode. 45 | It is especially useful in conjunction with 46 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 47 | process can pass a DistributedSamplerWrapper instance as a DataLoader 48 | sampler, and load a subset of subsampled data of the original dataset 49 | that is exclusive to it. 50 | .. note:: 51 | Sampler is assumed to be of constant size. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | sampler, 57 | num_replicas: Optional[int] = None, 58 | rank: Optional[int] = None, 59 | shuffle: bool = True, 60 | ): 61 | """ 62 | Args: 63 | sampler: Sampler used for subsampling 64 | num_replicas (int, optional): Number of processes participating in 65 | distributed training 66 | rank (int, optional): Rank of the current process 67 | within ``num_replicas`` 68 | shuffle (bool, optional): If true (default), 69 | sampler will shuffle the indices 70 | """ 71 | super(DistributedSamplerWrapper, self).__init__( 72 | DatasetFromSampler(sampler), 73 | num_replicas=num_replicas, 74 | rank=rank, 75 | shuffle=shuffle, 76 | ) 77 | self.sampler = sampler 78 | 79 | def __iter__(self): 80 | """@TODO: Docs. Contribution is welcome.""" 81 | self.dataset = DatasetFromSampler(self.sampler) 82 | indexes_of_indexes = super().__iter__() 83 | subsampler_indexes = self.dataset 84 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 85 | -------------------------------------------------------------------------------- /hyperbox/datamodules/mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split 5 | from torchvision.datasets import MNIST 6 | from torchvision.transforms import transforms 7 | 8 | __all__ = ['MNISTDataModule'] 9 | 10 | 11 | class MNISTDataModule(LightningDataModule): 12 | """ 13 | Example of LightningDataModule for MNIST dataset. 
14 | A DataModule implements 5 key methods: 15 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 16 | - setup (things to do on every accelerator in distributed mode) 17 | - train_dataloader (the training dataloader) 18 | - val_dataloader (the validation dataloader(s)) 19 | - test_dataloader (the test dataloader(s)) 20 | This allows you to share a full dataset without explaining how to download, 21 | split, transform and process the data 22 | Read the docs: 23 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 24 | """ 25 | 26 | def __init__( 27 | self, 28 | data_dir: str = "data/", 29 | train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000), 30 | batch_size: int = 64, 31 | num_workers: int = 0, 32 | pin_memory: bool = False, 33 | **kwargs, 34 | ): 35 | super().__init__() 36 | 37 | self.data_dir = data_dir 38 | self.train_val_test_split = train_val_test_split 39 | self.batch_size = batch_size 40 | self.num_workers = num_workers 41 | self.pin_memory = pin_memory 42 | 43 | self.transforms = transforms.Compose( 44 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 45 | ) 46 | 47 | # self.dims is returned when you call datamodule.size() 48 | self.dims = (1, 28, 28) 49 | 50 | self.data_train: Optional[Dataset] = None 51 | self.data_val: Optional[Dataset] = None 52 | self.data_test: Optional[Dataset] = None 53 | 54 | @property 55 | def num_classes(self) -> int: 56 | return 10 57 | 58 | def prepare_data(self): 59 | """Download data if needed. This method is called only from a single GPU. 60 | Do not use it to assign state (self.x = y).""" 61 | MNIST(self.data_dir, train=True, download=True) 62 | MNIST(self.data_dir, train=False, download=True) 63 | 64 | def setup(self, stage: Optional[str] = None): 65 | """Load data. 
Set variables: self.data_train, self.data_val, self.data_test."""
66 |         trainset = MNIST(self.data_dir, train=True, transform=self.transforms)
67 |         testset = MNIST(self.data_dir, train=False, transform=self.transforms)
68 |         dataset = ConcatDataset(datasets=[trainset, testset])
69 |         self.data_train, self.data_val, self.data_test = random_split(
70 |             dataset, self.train_val_test_split
71 |         )
72 | 
73 |     def train_dataloader(self):
74 |         return DataLoader(
75 |             dataset=self.data_train,
76 |             batch_size=self.batch_size,
77 |             num_workers=self.num_workers,
78 |             pin_memory=self.pin_memory,
79 |             shuffle=True,
80 |         )
81 | 
82 |     def val_dataloader(self):
83 |         return DataLoader(
84 |             dataset=self.data_val,
85 |             batch_size=self.batch_size,
86 |             num_workers=self.num_workers,
87 |             pin_memory=self.pin_memory,
88 |             shuffle=False,
89 |         )
90 | 
91 |     def test_dataloader(self):
92 |         return DataLoader(
93 |             dataset=self.data_test,
94 |             batch_size=self.batch_size,
95 |             num_workers=self.num_workers,
96 |             pin_memory=self.pin_memory,
97 |             shuffle=False,
98 |         )
99 | 
-------------------------------------------------------------------------------- /hyperbox/datamodules/transforms/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | from .base_transforms import BaseTransforms
3 | from .torch_transforms import TorchTransforms
4 | from .cutout import Cutout
5 | 
6 | import importlib
7 | alb = importlib.util.find_spec('albumentations')
8 | if alb is not None:
9 |     from .albumentation_transforms import AlbumentationsTransforms
10 | 
11 | def get_transforms(name, kwargs: dict):
12 |     if name == 'torch':
13 |         from .torch_transforms import TorchTransforms
14 |         T = TorchTransforms(**kwargs)
15 |     elif name == 'albumentation':
16 |         from .albumentation_transforms import AlbumentationsTransforms
17 |         T = AlbumentationsTransforms(**kwargs)
18 |     else:
19 |         raise ValueError(f"Unsupported transforms name: {name}")
20 |     return T
21 | 
-------------------------------------------------------------------------------- /hyperbox/datamodules/transforms/base_transforms.py: --------------------------------------------------------------------------------
1 | 
2 | from omegaconf import DictConfig
3 | 
4 | from hyperbox.utils.utils import hparams_wrapper
5 | from hyperbox.utils.logger import get_logger
6 | _logger = get_logger(__name__)
7 | 
8 | 
9 | @hparams_wrapper
10 | class BaseTransforms(object):
11 |     def __init__(self, *args, **kwargs):
12 |         for key, value in self.hparams.items():
13 |             if isinstance(value, dict):
14 |                 value = DictConfig(value)
15 |             setattr(self, key, value)
16 | 
17 |     def get_transform(self, is_train: bool = True):
18 |         if not is_train:
19 |             _logger.info('Generating validation transform ...')
20 |             transform = self.transform_valid
21 |             _logger.info(f'Valid transform={transform}')
22 |         else:
23 |             _logger.info('Generating training transform ...')
24 |             transform = self.transform_train
25 |             _logger.info(f'Train transform={transform}')
26 |         return transform
27 | 
28 |     @property
29 |     def transform_valid(self):
30 |         raise NotImplementedError
31 | 
32 |     @property
33 |     def transform_train(self):
34 |         raise NotImplementedError
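The `get_transforms` factory above and the `BaseTransforms` contract are exercised by `TorchTransforms` (torch_transforms.py below). A minimal usage sketch, assuming the `hparams_wrapper` decorator records the constructor arguments onto `self.hparams` as the comment in base_module.py describes:

    from hyperbox.datamodules.transforms import get_transforms

    # keys mirror TorchTransforms' constructor arguments; unspecified ones keep their defaults
    t = get_transforms('torch', {'input_size': [32], 'random_horizontal_flip': {'enable': 1, 'p': 0.5}})
    train_tf = t.get_transform(is_train=True)   # torchvision Compose for training
    valid_tf = t.get_transform(is_train=False)  # resize / to-tensor / normalize only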
-------------------------------------------------------------------------------- /hyperbox/datamodules/transforms/cutout.py: --------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | class Cutout(object):
7 |     """Randomly mask out one or more patches from an image.
8 |     https://github.com/uoguelph-mlrg/Cutout/blob/287f934ea5fa00d4345c2cccecf3552e2b1c33e3/train.py#L45
9 |     Args:
10 |         n_holes (int): Number of patches to cut out of each image.
11 |         length (int): The length (in pixels) of each square patch.
12 |     """
13 |     def __init__(self, n_holes=1, length=16):
14 |         self.n_holes = n_holes
15 |         self.length = length
16 | 
17 |     def __call__(self, img):
18 |         """
19 |         Args:
20 |             img (Tensor): Tensor image of size (C, H, W).
21 |         Returns:
22 |             Tensor: Image with n_holes of dimension length x length cut out of it.
23 |         """
24 |         h, w = img.shape[1:3]
25 | 
26 |         mask = np.ones((h, w), np.float32)
27 | 
28 |         for n in range(self.n_holes):
29 |             y = np.random.randint(h)
30 |             x = np.random.randint(w)
31 | 
32 |             y1 = np.clip(y - self.length // 2, 0, h)
33 |             y2 = np.clip(y + self.length // 2, 0, h)
34 |             x1 = np.clip(x - self.length // 2, 0, w)
35 |             x2 = np.clip(x + self.length // 2, 0, w)
36 | 
37 |             mask[y1: y2, x1: x2] = 0.
38 | 
39 |         mask = torch.from_numpy(mask)
40 |         mask = mask.expand_as(img)
41 |         img = img * mask
42 | 
43 |         return img
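A quick sanity check for `Cutout` above: it zeroes a square patch (clipped at the image borders) of a (C, H, W) tensor image:

    import torch
    from hyperbox.datamodules.transforms.cutout import Cutout

    img = torch.ones(3, 32, 32)
    out = Cutout(n_holes=1, length=8)(img)
    print((out == 0).sum())  # up to 8*8 pixels zeroed in each of the 3 channels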
-------------------------------------------------------------------------------- /hyperbox/datamodules/transforms/torch_transforms.py: --------------------------------------------------------------------------------
1 | 
2 | from typing import Union, Optional
3 | 
4 | from omegaconf import DictConfig
5 | from torchvision import transforms
6 | 
7 | from .cutout import Cutout
8 | from .base_transforms import BaseTransforms
9 | 
10 | 
11 | class TorchTransforms(BaseTransforms):
12 |     def __init__(
13 |         self,
14 |         input_size: list = [32],
15 |         random_resized_crop: Union[dict, DictConfig] = {'enable': 0, 'padding': 0},
16 |         resize: Union[dict, DictConfig] = {'enable': 0},
17 |         random_crop: Union[dict, DictConfig] = {'enable': 0},
18 |         center_crop: Union[dict, DictConfig] = {'enable': 0},
19 |         color_jitter: Union[dict, DictConfig] = {'enable': 0},
20 |         random_horizontal_flip: Union[dict, DictConfig] = {'enable': 0, 'p': 0.5},
21 |         random_vertical_flip: Union[dict, DictConfig] = {'enable': 0, 'p': 0.5},
22 |         random_rotation: Union[dict, DictConfig] = {'enable': 0, 'degrees': 20},
23 |         cutout: Union[dict, DictConfig] = {'enable': 0, 'n_holes': 8, 'length': 4},
24 |         to_tensor: Union[dict, DictConfig] = {'enable': 1},
25 |         normalize: Union[dict, DictConfig] = {
26 |             'enable': 1, 'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]},
27 |     ):
28 |         super(TorchTransforms, self).__init__()
29 |         self.min_edge_size = min(input_size)
30 |         self._transform_train = self.parse_transforms()
31 | 
32 |         self._transform_valid = [
33 |             transforms.Resize(self.input_size)
34 |         ]
35 |         if self.to_tensor.enable:
36 |             self._transform_valid.append(transforms.ToTensor())
37 |         if self.normalize.enable:
38 |             mean = self.normalize.mean
39 |             std = self.normalize.std
40 |             self.normalize = transforms.Normalize(mean, std)
41 |             self._transform_valid.append(self.normalize)
42 |         self._transform_valid = transforms.Compose(self._transform_valid)
43 | 
44 |     def parse_transforms(self):
45 |         img_size = self.input_size
46 |         transform_list = []
47 | 
48 |         # resize and crop operations
49 |         if self.random_resized_crop.enable:
50 |             transform_list.append(transforms.RandomResizedCrop(img_size))
51 |         elif self.resize.enable:
52 |             transform_list.append(transforms.Resize(img_size))
53 |         if self.random_crop.enable:
54 |             padding = self.random_crop.padding
55 |             size = getattr(self.random_crop, 'size', self.min_edge_size)
56 |             transform_list.append(transforms.RandomCrop(size, padding=padding))
57 |         elif self.center_crop.enable:
58 |             transform_list.append(transforms.CenterCrop(self.min_edge_size))
59 | 
60 |         # ColorJitter
61 |         if self.color_jitter.enable:
62 |             params = {key: self.color_jitter[key] for key in self.color_jitter
63 |                       if key != 'enable'}
64 |             transform_list.append(transforms.ColorJitter(**params))
65 | 
66 |         # horizontal flip
67 |         if self.random_horizontal_flip.enable:
68 |             p = self.random_horizontal_flip.p
69 |             transform_list.append(transforms.RandomHorizontalFlip(p))
70 | 
71 |         # vertical flip
72 |         if self.random_vertical_flip.enable:
73 |             p = self.random_vertical_flip.p
74 |             transform_list.append(transforms.RandomVerticalFlip(p))
75 | 
76 |         # rotation
77 |         if self.random_rotation.enable:
78 |             degrees = self.random_rotation.degrees
79 |             transform_list.append(transforms.RandomRotation(degrees))
80 | 
81 |         # cutout (applied after ToTensor, since Cutout expects a tensor image)
82 |         if self.to_tensor.enable:
83 |             transform_list.append(transforms.ToTensor())
84 |         if self.cutout.enable:
85 |             n_holes = self.cutout.n_holes
86 |             length = self.cutout.length
87 |             transform_list.append(Cutout(n_holes, length))
88 |         if self.normalize.enable:
89 |             mean = self.normalize.mean
90 |             std = self.normalize.std
91 |             transform_list.append(transforms.Normalize(mean, std))
92 |         assert len(transform_list) > 0, "The length of the transform list must be larger than 0."
93 |         transform_list = transforms.Compose(transform_list)
94 |         return transform_list
95 | 
96 |     @property
97 |     def transform_train(self):
98 |         return self._transform_train
99 | 
100 |     @property
101 |     def transform_valid(self):
102 |         return self._transform_valid
103 | 
-------------------------------------------------------------------------------- /hyperbox/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/engine/__init__.py
-------------------------------------------------------------------------------- /hyperbox/engine/base_engine.py: --------------------------------------------------------------------------------
1 | 
2 | 
3 | class BaseEngine:
4 |     """
5 |     Base class for all engines.
6 |     """
7 | 
8 |     def __init__(
9 |         self,
10 |         trainer,
11 |         model,
12 |         datamodule,
13 |         cfg
14 |     ):
15 |         self.trainer = trainer
16 |         self.model = model
17 |         self.datamodule = datamodule
18 |         self.cfg = cfg
19 | 
20 |     def run(self):
21 |         raise NotImplementedError
22 | 
-------------------------------------------------------------------------------- /hyperbox/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/losses/__init__.py
-------------------------------------------------------------------------------- /hyperbox/losses/ce_labelsmooth_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | __all__ = [
5 |     'CrossEntropyLabelSmooth',
6 | ]
7 | 
8 | 
9 | class CrossEntropyLabelSmooth(torch.nn.Module):
10 |     def __init__(self, label_smoothing, weight=None):
11 |         super(CrossEntropyLabelSmooth, self).__init__()
12 |         self.label_smoothing = label_smoothing
13 |         if weight is not None:
14 |             self.weight = torch.tensor(weight)
15 |         else:
16 |             self.weight = None
17 | 
18 |     def forward(self, pred, target):
19 |         if self.weight is not None:
20 |             self.weight = self.weight.to(pred.device)
21 |         logsoftmax = torch.nn.LogSoftmax(dim=1)
22 |         n_classes = pred.size(1)
23 |         # convert to one-hot
24 |         target = torch.unsqueeze(target, 1)
25 |         soft_target = torch.zeros_like(pred)
26 |         soft_target.scatter_(1, target, 1)
27 |         # label smoothing
28 |         soft_target = soft_target * (1 - self.label_smoothing) + self.label_smoothing / n_classes
29 |         if self.weight is not None:
30 |             soft_target *= self.weight
31 |         return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
32 | 
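A minimal sanity-check sketch for `CrossEntropyLabelSmooth` above (random logits, so the printed value is illustrative only):

    import torch
    from hyperbox.losses.ce_labelsmooth_loss import CrossEntropyLabelSmooth

    criterion = CrossEntropyLabelSmooth(label_smoothing=0.1)
    logits = torch.randn(4, 10)            # batch of 4, 10 classes
    target = torch.randint(0, 10, (4,))
    print(criterion(logits, target))       # scalar loss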
-------------------------------------------------------------------------------- /hyperbox/losses/focal_loss.py: --------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | __all__ = [
7 |     'FocalLoss',
8 | ]
9 | 
10 | 
11 | class FocalLoss(torch.nn.Module):
12 |     def __init__(self, alpha, gamma, num_classes, size_average=True):
13 |         """focal_loss function: -α(1-yi)**γ *ce_loss(xi,yi)
14 |         Args:
15 |             alpha: class weight (default 0.25).
16 |                 When α is a list, it gives the class-wise weights directly;
17 |                 when α is a constant:
18 |                     in a detection task, the class-wise weights are [α, 1-α, 1-α, ...],
19 |                         where the first class is the background;
20 |                     in a classification task, all classes share the same weight
21 |             gamma: γ (default 2), focusing parameter that smoothly adjusts the rate at which easy examples are down-weighted
22 |             num_classes: the number of classes
23 |             size_average: if True (default), return the mean loss over samples; otherwise return the sum
24 |         """
25 | 
26 |         super(FocalLoss, self).__init__()
27 |         self.size_average = size_average
28 |         if isinstance(alpha, list):
29 |             assert len(alpha) == num_classes
30 |             alpha = np.array(alpha) / np.sum(alpha)  # normalize the weights into [0, 1]
31 |             # print("Focal loss alpha = {}, assign specific weights for each class".format(alpha))
32 |             self.alpha = torch.Tensor(alpha)
33 |         else:
34 |             assert alpha <= 1
35 |             self.alpha = torch.zeros(num_classes) + 0.00001
36 | 
37 |             # classification task
38 |             # print("Focal loss alpha={}, the weight for each class is the same".format(alpha))
39 |             self.alpha += alpha
40 | 
41 |             # detection task: if α is a constant, down-weight the first class (the background class in detection)
42 |             # print(" --- Focal_loss alpha = {}, decaying the background class; intended for detection tasks --- ".format(alpha))
43 |             # self.alpha[0] += alpha
44 |             # self.alpha[1:] += (1-alpha)  # α finally becomes [α, 1-α, 1-α, ...], size: [num_classes]
45 |         self.gamma = gamma
46 | 
47 |     def forward(self, predictions, labels):
48 |         """
49 |         Compute the focal loss.
50 |         Args:
51 |             predictions: predicted logits. size: [B,N,C] for detection or [B,C] for classification; B batch size, N number of boxes, C number of classes
52 |             labels: ground-truth classes. size: [B,N] or [B]
53 |         return:
54 |             loss
55 |         """
56 |         assert predictions.dim() == 2 and labels.dim() == 1
57 |         preds = predictions.view(-1, predictions.size(-1))  # num*classes
58 |         alpha = self.alpha.to(labels.device)
59 | 
60 |         # log_softmax is not applied directly because the softmax output itself is needed below
61 |         # (alternatively, apply log_softmax and recover the probabilities with exp)
62 |         preds_softmax = F.softmax(preds, dim=1)
63 |         preds_logsoft = torch.log(preds_softmax)
64 | 
65 |         # implement nll_loss (cross_entropy = log_softmax + nll)
66 |         preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))  # num*1
67 |         preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))  # num*1
68 |         alpha = alpha.gather(0, labels.view(-1))  # num
69 | 
70 |         # calc loss
71 |         # torch.pow((1 - preds_softmax), self.gamma) is the (1-pt)**γ term of the focal loss
72 |         # shape: num*1
73 |         loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)
74 |         # α * (1-pt)**γ * ce_loss
75 |         # shape: [1, num] after broadcasting
76 |         loss = torch.mul(alpha, loss.t())
77 |         del preds
78 |         del alpha
79 |         if self.size_average:
80 |             loss = loss.mean()
81 |         else:
82 |             loss = loss.sum()
83 |         return loss
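And likewise for `FocalLoss` above; with a scalar `alpha` every class receives the same weight (the classification-task branch of `__init__`):

    import torch
    from hyperbox.losses.focal_loss import FocalLoss

    criterion = FocalLoss(alpha=0.25, gamma=2, num_classes=10)
    logits = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4,))
    print(criterion(logits, labels))  # scalar loss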
-------------------------------------------------------------------------------- /hyperbox/losses/kd_loss.py: --------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | 
3 | 
4 | def KDLoss(outputs, teacher_outputs, alpha=0.4, temperature=2):
5 |     """
6 |     Compute the knowledge-distillation (KD) loss given outputs, labels.
7 |     "Hyperparameters": temperature and alpha
8 |     NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
9 |     and student expects the input tensor to be log probabilities! See Issue #2
10 |     """
11 |     T = temperature
12 |     student = F.log_softmax(outputs/T, dim=-1)
13 |     teacher = F.softmax(teacher_outputs/T, dim=-1)
14 |     # reduction='batchmean' matches the mathematical definition of the KL divergence;
15 |     # the default 'mean' would divide by the total number of elements instead
16 |     KD_loss = F.kl_div(student, teacher, reduction='batchmean') * (alpha * T * T)
17 | 
18 |     return KD_loss
-------------------------------------------------------------------------------- /hyperbox/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/models/__init__.py
-------------------------------------------------------------------------------- /hyperbox/mutables/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | '''
3 | 'spaces':
4 |     ['Mutable', 'OperationSpace', 'InputSpace', 'MutableScope', 'ValueSpace']
5 |         ||
6 |         ||
7 |         \/
8 | 'ops'
9 |     ['Conv1d', 'Conv2d', 'Conv3d', 'Linear', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
10 |         ||
11 |         ||
12 |         \/
13 | 'layers'
14 |     ['MBLayer', ...]
15 | '''
-------------------------------------------------------------------------------- /hyperbox/mutables/layers/__init__.py: --------------------------------------------------------------------------------
1 | from .layers2d import *
-------------------------------------------------------------------------------- /hyperbox/mutables/masker.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | __all__ = [
6 |     '__MASKERS__',
7 |     'BaseMasker',
8 |     'L1Masker'
9 | ]
10 | 
11 | 
12 | class BaseMasker:
13 |     def __init__(self,):
14 |         pass
15 | 
16 |     def get_channel_sortedIdx(self, module, is_in_feat=True):
17 |         raise NotImplementedError
18 | 
19 |     def __call__(self, module, is_in_feat=True):
20 |         return self.get_channel_sortedIdx(module, is_in_feat)
21 | 
22 | 
23 | class L1Masker(BaseMasker):
24 |     def get_channel_sortedIdx(self, module, is_in_feat=True):
25 |         num_dim = len(module.weight.shape)
26 |         dim = list(range(num_dim))
27 |         if num_dim > 1:  # conv or linear
28 |             if is_in_feat:
29 |                 del dim[1]  # [0, 2, 3] for conv2d, [0, 2, 3, 4] for conv3d, [0] for linear
30 |             else:
31 |                 del dim[0]  # [1, 2, 3]
32 |             importance = torch.sum(torch.abs(module.weight.data), dim=dim)
33 |         else:  # BN
34 |             importance = torch.abs(module.weight.data)
35 |         sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
36 |         return sorted_idx
37 | 
38 | 
39 | __MASKERS__ = {
40 |     'l1': L1Masker,
41 | }
42 | 
-------------------------------------------------------------------------------- /hyperbox/mutables/ops/__init__.py: --------------------------------------------------------------------------------
1 | from .base_module import *
2 | from .conv import *
3 | from .linear import *
4 | from .batchnorm import *
5 | from .multihead_attention import *
6 | from .embedding import *
7 | from .layernorm import *
8 | from .groupnorm import *
9 | 
-------------------------------------------------------------------------------- /hyperbox/mutables/ops/base_module.py: --------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from hyperbox.mutables.spaces import ValueSpace
7 | from hyperbox.utils.utils import hparams_wrapper
8 | 
9 | __all__ = [
10 |     'FinegrainedModule'
11 | ]
12 | 
13 | 
14 | @hparams_wrapper
15 | class FinegrainedModule(nn.Module):
16 |     def __init__(self, *args, **kwargs):
17 |         super(FinegrainedModule, self).__init__()
18 |         # The decorator
@hparams_wrapper can automatically save all input arguments to 19 | # ``hparams`` attribute 20 | self.value_spaces = self.getValueSpaces(self.hparams) 21 | 22 | def getValueSpaces(self, kwargs): 23 | value_spaces = nn.ModuleDict() 24 | for key, value in kwargs.items(): 25 | if key in ['weight', 'bias']: 26 | if hasattr(self, key): delattr(self, key) 27 | key = '_' + key 28 | if isinstance(value, ValueSpace): 29 | value_spaces[key] = value 30 | if value.index is not None: 31 | _v = value.candidates_original[value.index] 32 | elif len(value.mask) != 0: 33 | if isinstance(value.mask, torch.Tensor): 34 | index = value.mask.clone().detach().cpu().numpy().argmax() 35 | else: 36 | index = np.array(value.mask).argmax() 37 | _v = value.candidates_original[index] 38 | else: 39 | _v = value.max_value 40 | setattr(self, key, _v) 41 | else: 42 | setattr(self, key, value) 43 | return value_spaces 44 | 45 | def __deepcopy__(self, memo): 46 | try: 47 | new_instance = self.__class__(**self.hparams) 48 | device = next(self.parameters()).device 49 | new_instance.load_state_dict(self.state_dict()) 50 | return new_instance.to(device) 51 | except Exception as e: 52 | print(str(e)) 53 | -------------------------------------------------------------------------------- /hyperbox/mutables/ops/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from hyperbox.mutables.ops.base_module import FinegrainedModule 9 | from hyperbox.mutables.ops.utils import is_searchable 10 | from hyperbox.mutables.spaces import ValueSpace 11 | 12 | 13 | class Embedding(nn.Embedding, FinegrainedModule): 14 | def __init__( 15 | self, 16 | num_embeddings: int, 17 | embedding_dim: Union[int, ValueSpace], 18 | device=None, 19 | dtype=None, 20 | *args, **kwargs, 21 | ) -> None: 22 | if isinstance(embedding_dim, ValueSpace): 23 | _embedding_dim = embedding_dim.max_value 24 | else: 25 | _embedding_dim = embedding_dim 26 | super(Embedding, self).__init__( 27 | num_embeddings, _embedding_dim, *args, **kwargs) 28 | self.is_search = self.isSearchEmbedding() 29 | 30 | def forward(self, input: Tensor) -> Tensor: 31 | if not self.is_search: 32 | x = F.embedding(input, self.weight, self.padding_idx, self.max_norm, 33 | self.norm_type, self.scale_grad_by_freq, self.sparse) 34 | else: 35 | embedding_dim = self.value_spaces["embedding_dim"].value 36 | weight = self.weight[:, :embedding_dim] 37 | x = F.embedding(input, weight, self.padding_idx, self.max_norm, 38 | self.norm_type, self.scale_grad_by_freq, self.sparse) 39 | return x 40 | 41 | def isSearchEmbedding(self): 42 | '''search flag 43 | search 44 | search_embedding_dim 45 | ''' 46 | self.search_embedding_dim = False 47 | if all([not vs.is_search for vs in self.value_spaces.values()]): 48 | return False 49 | if is_searchable(getattr(self.value_spaces, 'embedding_dim', None)): 50 | self.search_embedding_dim = True 51 | return True 52 | 53 | @property 54 | def params(self): 55 | weight = self.weight 56 | if self.search_embedding_dim: 57 | weight = weight[:, :self.value_spaces['embedding_dim'].value] 58 | size = weight.numel() 59 | return size 60 | 61 | 62 | if __name__ == "__main__": 63 | from hyperbox.mutator import RandomMutator 64 | for i in range(10): 65 | embed_dim = ValueSpace([10, 20]) 66 | embedding = Embedding( 67 | 10, embed_dim, max_norm=1.0, norm_type=2.0, 68 | scale_grad_by_freq=True, 
sparse=True)
69 |         rm = RandomMutator(embedding)
70 |         print(embedding)
71 |         for j in range(4):
72 |             rm.reset()
73 |             x = torch.randint(0, 10, (2, 30))
74 |             y = embedding(x)
75 |             print(y.shape)
76 | 
-------------------------------------------------------------------------------- /hyperbox/mutables/ops/groupnorm.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from hyperbox.mutables.ops.base_module import FinegrainedModule
6 | from hyperbox.mutables.ops.utils import is_searchable
7 | from hyperbox.mutables.spaces import ValueSpace
8 | 
9 | 
10 | class GroupNorm(nn.GroupNorm, FinegrainedModule):
11 | 
12 |     def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
13 |         _num_groups = num_groups.max_value if isinstance(num_groups, ValueSpace) else num_groups
14 |         _num_channels = num_channels.max_value if isinstance(num_channels, ValueSpace) else num_channels
15 |         super(GroupNorm, self).__init__(_num_groups, _num_channels, eps, affine)
16 |         self.is_search = self.isSearchGroupNorm()
17 | 
18 |     def isSearchGroupNorm(self):
19 |         if all([not vs.is_search for vs in self.value_spaces.values()]):
20 |             return False
21 | 
22 |         if is_searchable(getattr(self.value_spaces, 'num_groups', None)):
23 |             self.search_num_groups = True
24 |         else:
25 |             self.search_num_groups = False
26 |         if is_searchable(getattr(self.value_spaces, 'num_channels', None)):
27 |             self.search_num_channels = True
28 |         else:
29 |             self.search_num_channels = False
30 |         return True
31 | 
32 |     def forward(self, input):
33 |         if not self.is_search:
34 |             return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
35 |         else:
36 |             n_groups = self.value_spaces['num_groups'].value if self.search_num_groups else self.num_groups
37 |             n_channels = self.value_spaces['num_channels'].value if self.search_num_channels else self.num_channels
38 |             return F.group_norm(input, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)
39 | 
40 |     @property
41 |     def params(self):
42 |         weight = self.weight
43 |         bias = self.bias
44 |         if self.search_num_channels:
45 |             weight = weight[:self.value_spaces['num_channels'].value]
46 |             bias = bias[:self.value_spaces['num_channels'].value] if bias is not None else None
47 |         parameters = [weight, bias]
48 |         params = sum([p.numel() for p in parameters if p is not None])
49 |         return params
50 | 
51 | 
52 | if __name__ == '__main__':
53 |     from hyperbox.mutator import RandomMutator
54 |     from hyperbox.mutables.ops import Conv2d
55 |     input = torch.randn(20, 6, 10, 10)
56 | 
57 |     #### no search
58 |     ##### Separate 6 channels into 3 groups
59 |     print("no search")
60 |     m = nn.GroupNorm(3, 6)
61 |     output = m(input)
62 |     print(output.shape)
63 |     ##### Separate 6 channels into 6 groups (equivalent with InstanceNorm)
64 |     m = nn.GroupNorm(6, 6)
65 |     output = m(input)
66 |     print(output.shape)
67 |     ##### Put all 6 channels into a single group (equivalent with LayerNorm)
68 |     m = nn.GroupNorm(1, 6)
69 |     output = m(input)
70 |     print(output.shape)
71 | 
72 | 
73 |     #### search only n_groups
74 |     print("search only n_groups")
75 |     n_groups = ValueSpace([2, 3, 6, 1], key='num_groups')
76 |     m = GroupNorm(n_groups, 6)
77 |     rm = RandomMutator(m)
78 |     for i in range(5):
79 |         rm.reset()
80 |         output = m(input)
81 |         print(output.shape)
82 | 
83 |     # search only n_channels
84 |     print("search only n_channels")
85 |     n_channels = ValueSpace([12, 18, 24], key='num_channels')
86 |     m = nn.Sequential(
87 |         Conv2d(6,
n_channels, 3, 1, 1), 88 | GroupNorm(3, n_channels) 89 | ) 90 | rm = RandomMutator(m) 91 | for i in range(10): 92 | rm.reset() 93 | output = m(input) 94 | print(output.shape) 95 | 96 | # search n_groups and n_channels 97 | print("search n_groups and n_channels") 98 | n_groups = ValueSpace([2, 3, 6, 1], key='num_groups2') 99 | n_channels = ValueSpace([12, 18, 24], key='num_channels2') 100 | m = nn.Sequential( 101 | Conv2d(6, n_channels, 3, 1, 1), 102 | GroupNorm(n_groups, n_channels) 103 | ) 104 | rm = RandomMutator(m) 105 | for i in range(10): 106 | rm.reset() 107 | output = m(input) 108 | print(output.shape) -------------------------------------------------------------------------------- /hyperbox/mutables/ops/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..spaces import Mutable 7 | 8 | def is_searchable( 9 | obj: Optional[Union[None, Mutable]] 10 | ): 11 | '''Check whether the Space obj is searchable''' 12 | if (obj is None) or (not obj.is_search): 13 | return False 14 | return True 15 | 16 | def sub_filter_start_end(kernel_size, sub_kernel_size): 17 | if isinstance(kernel_size, (list, tuple)): 18 | kernel_size = kernel_size[0] 19 | center = kernel_size // 2 20 | dev = sub_kernel_size // 2 21 | start, end = center - dev, center + dev + 1 22 | assert end - start == sub_kernel_size 23 | return start, end 24 | -------------------------------------------------------------------------------- /hyperbox/mutator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_mutator import BaseMutator 2 | from .default_mutator import Mutator 3 | from .darts_mutator import DartsMutator 4 | from .evolution_mutator import EvolutionMutator 5 | from .enas_mutator import EnasMutator 6 | from .onehot_mutator import OnehotMutator 7 | from .random_mutator import RandomMutator 8 | from .sequential_mutator import SequentialMutator 9 | from .proxyless_mutator import ProxylessMutator 10 | from .fixed_mutator import apply_fixed_architecture, FixedArchitecture 11 | from .random_multiple_mutator import RandomMultipleMutator 12 | from .darts_multiple_mutator import DartsMultipleMutator 13 | from .repnas_mutator import RepnasMutator 14 | -------------------------------------------------------------------------------- /hyperbox/mutator/fixed_mutator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import json 5 | 6 | import torch 7 | 8 | from hyperbox.mutables.spaces import MutableScope 9 | 10 | from hyperbox.mutator.default_mutator import Mutator 11 | 12 | 13 | class FixedArchitecture(Mutator): 14 | 15 | def __init__(self, model, fixed_arc, strict=True): 16 | """ 17 | Initialize a fixed architecture mutator. 18 | 19 | Parameters 20 | ---------- 21 | model : nn.Module 22 | A mutable network. 23 | fixed_arc : str or dict 24 | Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). 25 | strict : bool 26 | Force everything that appears in `fixed_arc` to be used at least once. 
27 |         """
28 |         super().__init__(model)
29 |         self._fixed_arc = fixed_arc
30 | 
31 |         mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
32 |         fixed_arc_keys = set(self._fixed_arc.keys())
33 |         if fixed_arc_keys - mutable_keys:
34 |             raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
35 |         if mutable_keys - fixed_arc_keys:
36 |             raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
37 | 
38 |     def sample_search(self):
39 |         result = self._fixed_arc  # for OperationSpace and InputSpace
40 |         for mutable in self.mutables:
41 |             mutable.mask = result[mutable.key]  # for ValueSpace
42 |         return result
43 | 
44 |     def sample_final(self):
45 |         return self.sample_search()
46 | 
47 | 
48 | def _encode_tensor(data):
49 |     if isinstance(data, list):
50 |         if all(map(lambda o: isinstance(o, bool), data)):
51 |             return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
52 |         else:
53 |             return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
54 |     if isinstance(data, dict):
55 |         return {k: _encode_tensor(v) for k, v in data.items()}
56 |     return data
57 | 
58 | 
59 | def apply_fixed_architecture(model, fixed_arc_path):
60 |     """
61 |     Load architecture from `fixed_arc_path` and apply to model.
62 | 
63 |     Parameters
64 |     ----------
65 |     model : torch.nn.Module
66 |         Model with mutables.
67 |     fixed_arc_path : str or dict
68 |         Path to the JSON that stores the architecture, or a preloaded architecture dict.
69 | 
70 |     Returns
71 |     -------
72 |     FixedArchitecture
73 |     """
74 | 
75 |     if isinstance(fixed_arc_path, str):
76 |         with open(fixed_arc_path, "r") as f:
77 |             fixed_arc = json.load(f)
78 |     else:
79 |         fixed_arc = fixed_arc_path  # a preloaded architecture dict can be passed directly
80 |     fixed_arc = _encode_tensor(fixed_arc)
81 |     architecture = FixedArchitecture(model, fixed_arc)
82 |     architecture.reset()
83 |     return architecture
84 | 
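A usage sketch for `apply_fixed_architecture` above. The constructor arguments mirror darts_network.yaml, and `darts_mask.json` is the mask file shipped with the DARTS network in this repo; this assumes that mask was generated for the same search-space configuration:

    from hyperbox.mutator import apply_fixed_architecture
    from hyperbox.networks.darts import DartsNetwork

    net = DartsNetwork(in_channels=3, channels=16, n_layers=8,
                       stem_multiplier=3, n_nodes=4, n_classes=10)
    # freezes `net` to the architecture stored in the JSON
    apply_fixed_architecture(net, "hyperbox/networks/darts/darts_mask.json")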
-------------------------------------------------------------------------------- /hyperbox/mutator/onehot_mutator.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
10 | 
11 | from hyperbox.mutator.default_mutator import Mutator
12 | from .darts_mutator import DartsMutator
13 | 
14 | __all__ = [
15 |     'OnehotMutator',
16 | ]
17 | 
18 | 
19 | class OnehotMutator(DartsMutator):
20 |     def __init__(self, model, *args, **kwargs):
21 |         super().__init__(model, *args, **kwargs)
22 | 
23 |     def sample_search(self):
24 |         result = dict()
25 |         for mutable in self.mutables:
26 |             if isinstance(mutable, OperationSpace):
27 |                 result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
28 |                 mutable.mask = torch.zeros_like(result[mutable.key])
29 |                 mutable.mask[result[mutable.key].cpu().detach().numpy().argmax()] = 1
30 |             elif isinstance(mutable, ValueSpace):
31 |                 result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
32 |                 mutable.mask.data = result[mutable.key].data  # reuse the sample above so mask and result stay consistent
33 |             elif isinstance(mutable, InputSpace):
34 |                 result[mutable.key] = F.gumbel_softmax(self.choices[mutable.key], hard=True, dim=-1)
35 |                 mutable.mask = torch.zeros_like(result[mutable.key])
36 |                 mutable.mask[result[mutable.key].cpu().detach().numpy().argmax()] = 1
37 |         return result
38 | 
39 |     def sample_final(self):
40 |         return super().sample_final()
41 | 
-------------------------------------------------------------------------------- /hyperbox/mutator/random_multiple_mutator.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace
6 | from hyperbox.mutator.default_mutator import Mutator
7 | from hyperbox.mutator.random_mutator import RandomMutator
8 | 
9 | 
10 | __all__ = [
11 |     'RandomMultipleMutator',
12 | ]
13 | 
14 | 
15 | class RandomMultipleMutator(RandomMutator):
16 |     def __init__(self, model, single_path_prob=0, *args, **kwargs):
17 |         super(RandomMultipleMutator, self).__init__(model)
18 |         self.single_path_prob = single_path_prob
19 | 
20 |     def sample_search(self):
21 |         if np.random.rand() < self.single_path_prob:
22 |             return super(RandomMultipleMutator, self).sample_search()
23 |         result = dict()
24 |         for mutable in self.mutables:
25 |             if isinstance(mutable, OperationSpace):
26 |                 crt_mask = torch.randint(high=2, size=(mutable.length,)).view(-1).bool()
27 |                 if crt_mask.sum() == 0:
28 |                     crt_mask[-1] = 1
29 |                 result[mutable.key] = crt_mask.view(-1).bool()
30 |                 mutable.mask = result[mutable.key].detach()
31 |             elif isinstance(mutable, InputSpace):
32 |                 if mutable.n_chosen is None:
33 |                     result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
34 |                 else:
35 |                     perm = torch.randperm(mutable.n_candidates)
36 |                     mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
37 |                     result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
38 |                 mutable.mask = result[mutable.key].detach()
39 |             elif isinstance(mutable, ValueSpace):
40 |                 gen_index = torch.randint(high=mutable.length, size=(1, ))
41 |                 result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
42 |                 mutable.mask = result[mutable.key]
43 |         return result
44 | 
45 |     def sample_final(self):
46 |         return self.sample_search()
47 | 
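The mutators above all share the same contract: `sample_search()` returns a dict of masks keyed by each mutable's key and writes the sampled mask onto the mutable, while `reset()` (as used by `rm.reset()` in the `__main__` demos earlier) draws and applies a fresh sample. A minimal driving loop, where `MyNASNetwork` is a hypothetical supernet built from hyperbox mutables:

    import torch
    from hyperbox.mutator import RandomMultipleMutator

    net = MyNASNetwork()  # hypothetical: any nn.Module containing hyperbox mutables
    mutator = RandomMultipleMutator(net, single_path_prob=0.5)
    for _ in range(4):
        mutator.reset()  # samples a new sub-architecture and applies its masks
        out = net(torch.randn(2, 3, 32, 32))
        print(out.shape)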
-------------------------------------------------------------------------------- /hyperbox/mutator/random_mutator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace 6 | 7 | from hyperbox.mutator.default_mutator import Mutator 8 | 9 | __all__ = [ 10 | 'RandomMutator', 11 | ] 12 | 13 | 14 | class RandomMutator(Mutator): 15 | def __init__(self, model, *args, **kwargs): 16 | super().__init__(model) 17 | 18 | def sample_search(self): 19 | result = dict() 20 | for mutable in self.mutables: 21 | if isinstance(mutable, OperationSpace): 22 | gen_index = torch.randint(high=mutable.length, size=(1, )) 23 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 24 | mutable.mask = result[mutable.key].detach() 25 | elif isinstance(mutable, InputSpace): 26 | if mutable.n_chosen is None: 27 | result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() 28 | else: 29 | perm = torch.randperm(mutable.n_candidates) 30 | mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] 31 | result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable 32 | mutable.mask = result[mutable.key].detach() 33 | elif isinstance(mutable, ValueSpace): 34 | gen_index = torch.randint(high=mutable.length, size=(1, )) 35 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 36 | mutable.mask = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 37 | return result 38 | 39 | def sample_final(self): 40 | return self.sample_search() 41 | -------------------------------------------------------------------------------- /hyperbox/mutator/sequential_mutator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace 7 | from hyperbox.mutator.default_mutator import Mutator 8 | 9 | 10 | __all__ = [ 11 | 'SequentialMutator', 12 | ] 13 | 14 | class SequentialMutator(Mutator): 15 | def __init__(self, model, start_idx: int): 16 | super().__init__(model) 17 | with open('./mutator/Track1_final_archs.json', 'r') as f: 18 | self.masks = json.load(f) 19 | self.crt_index = start_idx 20 | self.max_num = len(self.masks) 21 | assert self.crt_index > 0, 'Index should start at 1' 22 | 23 | def sample_search(self): 24 | result = dict() 25 | for mutable in self.mutables: 26 | if isinstance(mutable, OperationSpace): 27 | gen_index = torch.randint(high=mutable.length, size=(1, )) 28 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 29 | mutable.mask = torch.zeros_like(result[mutable.key]) 30 | mutable.mask[result[mutable.key].detach().numpy().argmax()] = 1 31 | elif isinstance(mutable, InputSpace): 32 | if mutable.n_chosen is None: 33 | result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() 34 | else: 35 | perm = torch.randperm(mutable.n_candidates) 36 | mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] 37 | result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable 38 | mutable.mask = torch.zeros_like(result[mutable.key]) 39 | mutable.mask[result[mutable.key].detach().numpy().argmax()] = 1 40 
| elif isinstance(mutable, ValueSpace): 41 | index_choice = int(mutable.key.split('ValueSpace')[-1]) - 1 42 | value = self.masks[f'arch{self.crt_index}']['arch'].split('-')[index_choice] 43 | gen_index = np.argwhere(np.array(mutable.candidates)==int(value))[0][0] 44 | gen_index = torch.tensor(gen_index) 45 | result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 46 | mutable.mask = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool() 47 | if self.crt_index >= self.max_num: 48 | self.crt_index = 1 49 | else: 50 | self.crt_index += 1 51 | return result 52 | 53 | def sample_final(self): 54 | return self.sample_search() 55 | -------------------------------------------------------------------------------- /hyperbox/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/__init__.py -------------------------------------------------------------------------------- /hyperbox/networks/bnnas/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn_net import * 2 | from .bn_blocks import * -------------------------------------------------------------------------------- /hyperbox/networks/bnnas/bn_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = [ 5 | 'blocks_dict', 6 | 'InvertedResidual', 7 | 'conv_1x1_bn', 8 | 'conv_bn' 9 | ] 10 | 11 | 12 | blocks_dict = { 13 | 'k3r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 3, 1, stride, 3), 14 | 'k3r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 3, 1, stride, 6), 15 | 'k5r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 5, 2, stride, 3), 16 | 'k5r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 5, 2, stride, 6), 17 | 'k7r3':lambda inp, oup, stride : InvertedResidual(inp, oup, 7, 3, stride, 3), 18 | 'k7r6':lambda inp, oup, stride : InvertedResidual(inp, oup, 7, 3, stride, 6), 19 | } 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, ksize, padding, stride, expand_ratio): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | self.use_res_connect = self.stride == 1 and inp == oup 27 | self.expand_ratio = expand_ratio 28 | 29 | if expand_ratio == 1: 30 | self.conv = nn.Sequential( 31 | # dw 32 | nn.Conv2d(inp, inp, ksize, stride, padding, groups=inp, bias=False), 33 | nn.BatchNorm2d(inp), 34 | nn.ReLU6(inplace=True), 35 | # pw-linear 36 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 37 | nn.BatchNorm2d(oup), 38 | ) 39 | else: 40 | self.conv = nn.Sequential( 41 | # pw 42 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(inp * expand_ratio), 44 | nn.ReLU6(inplace=True), 45 | # dw 46 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, ksize, stride, padding, groups=inp * expand_ratio, bias=False), 47 | nn.BatchNorm2d(inp * expand_ratio), 48 | nn.ReLU6(inplace=True), 49 | # pw-linear 50 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 51 | nn.BatchNorm2d(oup), 52 | ) 53 | 54 | def forward(self, x): 55 | if self.use_res_connect: 56 | return x + self.conv(x) 57 | else: 58 | return self.conv(x) 59 | 60 | 61 | def conv_bn(inp, oup, stride): 62 | return nn.Sequential( 63 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 64 | nn.BatchNorm2d(oup), 65 | nn.ReLU6(inplace=True) 66 | ) 67 | 68 | def 
conv_1x1_bn(inp, oup):
69 |     return nn.Sequential(
70 |         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
71 |         nn.BatchNorm2d(oup),
72 |         nn.ReLU6(inplace=True)
73 |     )
74 | 
75 | 
-------------------------------------------------------------------------------- /hyperbox/networks/bnnas/ea_search.py: --------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from hyperbox.utils.utils import TorchTensorEncoder, save_arch_to_json
8 | from hyperbox.networks.bnnas import BNNet
9 | from hyperbox.mutator.ea_mutator import EAMutator
10 | from hyperbox.mutator.utils import NonDominatedSorting
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 |     net = BNNet(num_classes=10, search_depth=False)
16 |     # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_depth_adam0.001_sync_hete/2021-10-02_23-05-51/checkpoints/epoch=390_val/acc=27.8100.ckpt'
17 |     # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_depth_adam0.001_sync_hete/2021-10-02_23-05-43/checkpoints/epoch=339_val/acc=43.1300.ckpt'
18 |     # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_adam0.001_sync_hete/2021-10-06_06-29-41/checkpoints/epoch=392_val/acc=28.8200.ckpt'
19 |     # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_adam0.001_sync_hete/2021-10-06_06-31-00/checkpoints/epoch=302_val/acc=44.0300.ckpt'
20 |     # ckpt = '/datasets/xinhe/xinhe/hyperbox/logs/runs/bnnas_c10_all_bn/2021-10-22_04-03-29/checkpoints/epoch=14_val/acc=25.1800.ckpt'
21 |     # ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all20_bn20/2021-10-27_05-33-45/checkpoints/epoch=39_val/acc=39.3800.ckpt'
22 |     # ckpt = torch.load(ckpt, map_location='cpu')
23 |     # weights = {}
24 |     # for key in ckpt['state_dict']:
25 |     #     weights[key.replace('network.', '')] = ckpt['state_dict'][key]
26 | 
27 |     # net.load_state_dict(weights)
28 |     net = net.to(device)
29 | 
30 |     # method 1
31 |     mode = 'all20_bn20'
32 |     search_algorithm = 'cars'
33 |     ea = EAMutator(net, num_population=50, algorithm=search_algorithm)
34 |     # ea.load_ckpt('epoch2.pth')
35 |     eval_func = lambda arch, network: net.bn_metrics().item()
36 |     # eval_func = lambda arch, network: np.random.rand()
37 |     # ea.search(20, eval_func, verbose=True, filling_history=True)
38 |     ea.search(20, eval_func, eval_kwargs={}, verbose=True, filling_history=True)
39 |     size = np.array([pool['size'] for pool in ea.history.values()])
40 |     metric = np.array([pool['metric'] for pool in ea.history.values()])
41 |     indices = np.argsort(size)
42 |     size, metric = size[indices], metric[indices]
43 |     epoch = ea.crt_epoch
44 |     pareto_lists = NonDominatedSorting(np.vstack((size.reshape(-1), 1/metric.reshape(-1))))
45 |     pareto_indices = pareto_lists[0]  # e.g., [75, 87, 113, 201, 205]
46 |     ea.plot_pareto_fronts(
47 |         size, metric, pareto_indices, 'model size (MB)', 'BN-based metric',
48 |         figname=f'{mode}_pareto_searchepoch{epoch}_{search_algorithm}.pdf'
49 |     )
50 |     # NOTE: `ckpt` must hold the checkpoint path string (uncomment one of the `ckpt = '...'` lines above) for the save path below
51 |     path = ckpt.split('checkpoints')[0]
52 |     path = os.path.join(path, 'mask_json')
53 |     if not os.path.exists(path):
54 |         os.makedirs(path)
55 |     for key, arch_info in ea.pareto_fronts.items():
56 |         arch = arch_info['arch']
57 |         arch_path = os.path.join(path, f'arch_{key}.json')
58 |         save_arch_to_json(arch, arch_path)
59 |     # method 2
60 |     # ea = EAMutator(net, num_population=50, algorithm='top')
61 | 
62 |     # ea.start_evolve = True
63 |     # ea.init_population(ea.init_population_mode)
64 |     # for i in range(30):
65 |     #     
print(f"\n\nsearch epoch {i}") 66 | # if i>0: 67 | # ea.evolve() 68 | # for j, arch in enumerate(ea.population.values()): 69 | # # ea.reset() 70 | # ea.reset_cache_mask(arch['arch']) 71 | # arch['metric'] = net.bn_metrics().item() 72 | # metrics = [pool['metric'] for pool in ea.population.values()] 73 | # metrics.sort() 74 | # print('pop',metrics) 75 | # metrics = [pool['metric'] for pool in ea.pareto_fronts.values()] 76 | # metrics.sort() 77 | # print('pare',metrics) 78 | # visited = [] 79 | # for arch in ea.pareto_fronts.values(): 80 | # if arch['arch_code'] not in visited: visited.append(arch['arch_code']) 81 | # print(len(visited)) -------------------------------------------------------------------------------- /hyperbox/networks/darts/__init__.py: -------------------------------------------------------------------------------- 1 | from .darts_network import * 2 | -------------------------------------------------------------------------------- /hyperbox/networks/enas/__init__.py: -------------------------------------------------------------------------------- 1 | from .enas_network import * 2 | -------------------------------------------------------------------------------- /hyperbox/networks/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt2 import * 2 | -------------------------------------------------------------------------------- /hyperbox/networks/mobilenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobile_net import MobileNet 2 | from .mobile3d_net import Mobile3DNet -------------------------------------------------------------------------------- /hyperbox/networks/mobilenet/mobile_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_parameters(model, keys=None, mode='include'): 5 | if keys is None: 6 | for name, param in model.named_parameters(): 7 | yield param 8 | elif mode == 'include': 9 | for name, param in model.named_parameters(): 10 | flag = False 11 | for key in keys: 12 | if key in name: 13 | flag = True 14 | break 15 | if flag: 16 | yield param 17 | elif mode == 'exclude': 18 | for name, param in model.named_parameters(): 19 | flag = True 20 | for key in keys: 21 | if key in name: 22 | flag = False 23 | break 24 | if flag: 25 | yield param 26 | else: 27 | raise ValueError('do not support: %s' % mode) 28 | 29 | 30 | def get_same_padding(kernel_size): 31 | if isinstance(kernel_size, tuple): 32 | assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size 33 | p1 = get_same_padding(kernel_size[0]) 34 | p2 = get_same_padding(kernel_size[1]) 35 | return p1, p2 36 | assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`' 37 | assert kernel_size % 2 > 0, 'kernel size should be odd number' 38 | return kernel_size // 2 39 | 40 | def build_activation(act_func, inplace=True): 41 | if act_func == 'relu': 42 | return nn.ReLU(inplace=inplace) 43 | elif act_func == 'relu6': 44 | return nn.ReLU6(inplace=inplace) 45 | elif act_func == 'tanh': 46 | return nn.Tanh() 47 | elif act_func == 'sigmoid': 48 | return nn.Sigmoid() 49 | elif act_func is None: 50 | return None 51 | else: 52 | raise ValueError('do not support: %s' % act_func) 53 | 54 | 55 | def make_divisible(v, divisor, min_val=None): 56 | """ 57 | This function is taken from the original tf repo. 
58 | It ensures that all layers have a channel number that is divisible by 8 59 | It can be seen here: 60 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 61 | """ 62 | if min_val is None: 63 | min_val = divisor 64 | new_v = max(min_val, int(v + divisor / 2) // divisor * divisor) 65 | # Make sure that round down does not go down by more than 10%. 66 | if new_v < 0.9 * v: 67 | new_v += divisor 68 | return new_v 69 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/__init__.py: -------------------------------------------------------------------------------- 1 | from .nasbench101 import NASBench101Network 2 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/base_ops.py: -------------------------------------------------------------------------------- 1 | """Base operations used by the modules in this search space.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class ConvBnRelu(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0): 13 | super(ConvBnRelu, self).__init__() 14 | 15 | self.conv_bn_relu = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.conv_bn_relu(x) 23 | 24 | class Conv3x3BnRelu(nn.Module): 25 | """3x3 convolution with batch norm and ReLU activation.""" 26 | def __init__(self, in_channels, out_channels): 27 | super(Conv3x3BnRelu, self).__init__() 28 | 29 | self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1) 30 | 31 | def forward(self, x): 32 | x = self.conv3x3(x) 33 | return x 34 | 35 | class Conv1x1BnRelu(nn.Module): 36 | """1x1 convolution with batch norm and ReLU activation.""" 37 | def __init__(self, in_channels, out_channels): 38 | super(Conv1x1BnRelu, self).__init__() 39 | 40 | self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0) 41 | 42 | def forward(self, x): 43 | x = self.conv1x1(x) 44 | return x 45 | 46 | class MaxPool3x3(nn.Module): 47 | """3x3 max pool with no subsampling.""" 48 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 49 | super(MaxPool3x3, self).__init__() 50 | 51 | self.maxpool = nn.MaxPool2d(kernel_size, stride, padding) 52 | 53 | def forward(self, x): 54 | x = self.maxpool(x) 55 | return x 56 | 57 | # Commas should not be used in op names 58 | # OP_MAP = { 59 | # 'conv3x3-bn-relu': Conv3x3BnRelu, 60 | # 'conv1x1-bn-relu': Conv1x1BnRelu, 61 | # 'maxpool3x3': MaxPool3x3 62 | # } 63 | 64 | OP_MAP = { 65 | 'conv3x3-bn-relu': lambda cin,cout: Conv3x3BnRelu(cin,cout), 66 | 'conv1x1-bn-relu': lambda cin,cout: Conv1x1BnRelu(cin,cout), 67 | 'maxpool3x3': lambda cin,cout: MaxPool3x3(cin,cout) 68 | } -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3 2 | from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig 3 | from .query import query_nb101_trial_stats 4 | 
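A minimal usage sketch for the query helper re-exported above (not part of the repo; it assumes the NAS-Bench-101 database has already been generated with `db_gen.py` as described in the readme below, and that the hand-written cell was actually evaluated in the benchmark):

```python
from hyperbox.networks.nasbench101.db_gen import (
    CONV3X3_BN_RELU, MAXPOOL3X3, query_nb101_trial_stats)

# A 4-vertex cell: input -> conv3x3 -> maxpool3x3 -> output.
# Keys follow the format documented in Nb101TrialConfig below.
arch = {
    'op1': CONV3X3_BN_RELU,
    'op2': MAXPOOL3X3,
    'input1': [0],   # node 1 reads from the input node
    'input2': [1],   # node 2 reads from node 1
    'input3': [2],   # node 3 (the output node) reads from node 2
}
# Average the three seeds of the 108-epoch runs for this cell.
for stats in query_nb101_trial_stats(arch, num_epochs=108, reduction='mean'):
    print(stats['test_acc'], stats['parameters'])
```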
-------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/constants.py: -------------------------------------------------------------------------------- 1 | INPUT = 'input' 2 | OUTPUT = 'output' 3 | CONV3X3_BN_RELU = 'conv3x3-bn-relu' 4 | CONV1X1_BN_RELU = 'conv1x1-bn-relu' 5 | MAXPOOL3X3 = 'maxpool3x3' 6 | 7 | 8 | LABEL2ID = { 9 | INPUT: -1, 10 | OUTPUT: -2, 11 | CONV3X3_BN_RELU: 0, 12 | CONV1X1_BN_RELU: 1, 13 | MAXPOOL3X3: 2 14 | } 15 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/db_gen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tqdm import tqdm 4 | from nasbench import api # pylint: disable=import-error 5 | 6 | from hyperbox.networks.nasbench101.db_gen.model import db, Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats 7 | from hyperbox.networks.nasbench101.db_gen.graph_util import nasbench_format_to_architecture_repr, hash_module 8 | 9 | 10 | def main(args): 11 | nasbench = api.NASBench(args.inputFile) 12 | with db: 13 | db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats]) 14 | for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'): 15 | metadata, metrics = nasbench.get_metrics_from_hash(hashval) 16 | num_vertices, architecture = nasbench_format_to_architecture_repr( 17 | metadata['module_adjacency'], metadata['module_operations']) 18 | assert hashval == hash_module(architecture, num_vertices) 19 | for epochs in [4, 12, 36, 108]: 20 | trial_config = Nb101TrialConfig.create( 21 | arch=architecture, 22 | num_vertices=num_vertices, 23 | hash=hashval, 24 | num_epochs=epochs 25 | ) 26 | 27 | for seed in range(3): 28 | cur = metrics[epochs][seed] 29 | trial = Nb101TrialStats.create( 30 | config=trial_config, 31 | train_acc=cur['final_train_accuracy'] * 100, 32 | valid_acc=cur['final_validation_accuracy'] * 100, 33 | test_acc=cur['final_test_accuracy'] * 100, 34 | parameters=metadata['trainable_parameters'] / 1e6, 35 | training_time=cur['final_training_time'] * 60 36 | ) 37 | for t in ['halfway', 'final']: 38 | Nb101IntermediateStats.create( 39 | trial=trial, 40 | current_epoch=epochs // 2 if t == 'halfway' else epochs, 41 | training_time=cur[t + '_training_time'], 42 | train_acc=cur[t + '_train_accuracy'] * 100, 43 | valid_acc=cur[t + '_validation_accuracy'] * 100, 44 | test_acc=cur[t + '_test_accuracy'] * 100 45 | ) 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--inputFile', 51 | help='Path to the file to be converted, e.g., nasbench_full.tfrecord') 52 | args = parser.parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/graph_util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | import numpy as np 4 | 5 | from hyperbox.networks.nasbench101.db_gen.constants import INPUT, LABEL2ID, OUTPUT 6 | 7 | 8 | def _labeling_from_architecture(architecture, vertices): 9 | return [INPUT] + [architecture['op{}'.format(i)] for i in range(1, vertices - 1)] + [OUTPUT] 10 | 11 | 12 | def _adjacency_matrix_from_architecture(architecture, vertices): 13 | matrix = np.zeros((vertices, vertices), dtype=bool) # plain `bool`: np.bool was deprecated and later removed from NumPy 14 | for i in range(1, vertices): 15 | for k in architecture['input{}'.format(i)]: 16 | matrix[k, i] = 1 17 | return 
matrix 18 | 19 | 20 | def nasbench_format_to_architecture_repr(adjacency_matrix, labeling): 21 | """ 22 | Convert the NAS-Bench-101 (adjacency matrix, labeling) format into the 23 | architecture dict used in this repo. Adapted from the NAS-Bench-101 repo. 24 | 25 | Parameters 26 | ---------- 27 | adjacency_matrix : np.ndarray 28 | A 2D array of shape NxN, where N is the number of vertices. 29 | ``matrix[u][v]`` is 1 if there is a direct edge from `u` to `v`, 30 | otherwise it will be 0. 31 | labeling : list of str 32 | A list of str that starts with input and ends with output. The intermediate 33 | nodes are chosen from candidate operators. 34 | 35 | Returns 36 | ------- 37 | tuple of (int, dict) 38 | Number of vertices and the converted architecture. 39 | """ 40 | num_vertices = adjacency_matrix.shape[0] 41 | assert len(labeling) == num_vertices 42 | architecture = {} 43 | for i in range(1, num_vertices - 1): 44 | architecture['op{}'.format(i)] = labeling[i] 45 | assert labeling[i] not in [INPUT, OUTPUT] 46 | for i in range(1, num_vertices): 47 | architecture['input{}'.format(i)] = [k for k in range(i) if adjacency_matrix[k, i]] 48 | return num_vertices, architecture 49 | 50 | 51 | def infer_num_vertices(architecture): 52 | """ 53 | Infer number of vertices from an architecture dict. 54 | 55 | Parameters 56 | ---------- 57 | architecture : dict 58 | Architecture in NNI format. 59 | 60 | Returns 61 | ------- 62 | int 63 | Number of vertices. 64 | """ 65 | op_keys = set([k for k in architecture.keys() if k.startswith('op')]) 66 | intermediate_vertices = len(op_keys) 67 | assert op_keys == {'op{}'.format(i) for i in range(1, intermediate_vertices + 1)} 68 | return intermediate_vertices + 2 69 | 70 | 71 | def hash_module(architecture, vertices): 72 | """ 73 | Computes a graph-invariance MD5 hash of the matrix and label pair. 74 | This snippet is modified from code in NAS-Bench-101 repo. 75 | 76 | Parameters 77 | ---------- 78 | architecture : dict 79 | Architecture in the dict format described above. 80 | vertices : int 81 | Number of vertices in the cell. 82 | 83 | Returns 84 | ------- 85 | str 86 | MD5 hash of the matrix and labeling. 87 | """ 88 | labeling = _labeling_from_architecture(architecture, vertices) 89 | labeling = [LABEL2ID[t] for t in labeling] 90 | matrix = _adjacency_matrix_from_architecture(architecture, vertices) 91 | in_edges = np.sum(matrix, axis=0).tolist() 92 | out_edges = np.sum(matrix, axis=1).tolist() 93 | 94 | assert len(in_edges) == len(out_edges) == len(labeling) 95 | hashes = list(zip(out_edges, in_edges, labeling)) 96 | hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes] 97 | # Computing this up to the diameter is probably sufficient but since the 98 | # operation is fast, it is okay to repeat more times. 
99 | for _ in range(vertices): 100 | new_hashes = [] 101 | for v in range(vertices): 102 | in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]] 103 | out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]] 104 | new_hashes.append(hashlib.md5( 105 | (''.join(sorted(in_neighbors)) + '|' + 106 | ''.join(sorted(out_neighbors)) + '|' + 107 | hashes[v]).encode('utf-8')).hexdigest()) 108 | hashes = new_hashes 109 | fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest() 110 | 111 | return fingerprint 112 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | import json 4 | from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model 5 | from playhouse.sqlite_ext import JSONField, SqliteExtDatabase 6 | 7 | 8 | 9 | json_dumps = functools.partial(json.dumps, sort_keys=True) 10 | DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.hyperbox/nasbenchmark")) 11 | db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True) 12 | 13 | 14 | class Nb101TrialConfig(Model): 15 | """ 16 | Trial config for NAS-Bench-101. 17 | 18 | Attributes 19 | ---------- 20 | arch : dict 21 | A dict with keys ``op1``, ``op2``, ... and ``input1``, ``input2``, ... Vertices are 22 | enumerated from 0. Since node 0 is the input node, it is skipped in this dict. Each ``op`` 23 | is one of :const:`nni.nas.benchmark.nasbench101.CONV3X3_BN_RELU`, 24 | :const:`nni.nas.benchmark.nasbench101.CONV1X1_BN_RELU`, and :const:`nni.nas.benchmark.nasbench101.MAXPOOL3X3`. 25 | Each ``input`` is a list of previous nodes. For example ``input5`` can be ``[0, 1, 3]``. 26 | num_vertices : int 27 | Number of vertices (nodes) in one cell. Should be less than or equal to 7 in the default setup. 28 | hash : str 29 | Graph-invariant MD5 string for this architecture. 30 | num_epochs : int 31 | Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in the default setup. 32 | """ 33 | 34 | arch = JSONField(index=True) 35 | num_vertices = IntegerField(index=True) 36 | hash = CharField(max_length=64, index=True) 37 | num_epochs = IntegerField(index=True) 38 | 39 | class Meta: 40 | database = db 41 | 42 | 43 | class Nb101TrialStats(Model): 44 | """ 45 | Computation statistics for NAS-Bench-101. Each corresponds to one trial. 46 | Each config has multiple trials with different random seeds, but unfortunately the seed for each trial is unavailable. 47 | NAS-Bench-101 trains and evaluates on CIFAR-10 by default. The original training set is divided into 48 | 40k training images and 10k validation images, and the original validation set is used for test only. 49 | 50 | Attributes 51 | ---------- 52 | config : Nb101TrialConfig 53 | Setup for this trial data. 54 | train_acc : float 55 | Final accuracy on training data, ranging from 0 to 100. 56 | valid_acc : float 57 | Final accuracy on validation data, ranging from 0 to 100. 58 | test_acc : float 59 | Final accuracy on test data, ranging from 0 to 100. 60 | parameters : float 61 | Number of trainable parameters in millions. 62 | training_time : float 63 | Duration of training in seconds. 
64 | """ 65 | config = ForeignKeyField(Nb101TrialConfig, backref='trial_stats', index=True) 66 | train_acc = FloatField() 67 | valid_acc = FloatField() 68 | test_acc = FloatField() 69 | parameters = FloatField() 70 | training_time = FloatField() 71 | 72 | class Meta: 73 | database = db 74 | 75 | 76 | class Nb101IntermediateStats(Model): 77 | """ 78 | Intermediate statistics for NAS-Bench-101. 79 | 80 | Attributes 81 | ---------- 82 | trial : Nb101TrialStats 83 | The exact trial where the intermediate result is produced. 84 | current_epoch : int 85 | Elapsed epochs when evaluation is done. 86 | train_acc : float 87 | Intermediate accuracy on training data, ranging from 0 to 100. 88 | valid_acc : float 89 | Intermediate accuracy on validation data, ranging from 0 to 100. 90 | test_acc : float 91 | Intermediate accuracy on test data, ranging from 0 to 100. 92 | training_time : float 93 | Time elapsed in seconds. 94 | """ 95 | 96 | trial = ForeignKeyField(Nb101TrialStats, backref='intermediates', index=True) 97 | current_epoch = IntegerField(index=True) 98 | train_acc = FloatField() 99 | valid_acc = FloatField() 100 | test_acc = FloatField() 101 | training_time = FloatField() 102 | 103 | class Meta: 104 | database = db 105 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/db_gen/query.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from peewee import fn 4 | from playhouse.shortcuts import model_to_dict 5 | from hyperbox.networks.nasbench101.db_gen.model import Nb101TrialStats, Nb101TrialConfig 6 | from hyperbox.networks.nasbench101.db_gen.graph_util import hash_module, infer_num_vertices 7 | 8 | 9 | def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): 10 | """ 11 | Query trial stats of NAS-Bench-101 given conditions. 12 | 13 | Parameters 14 | ---------- 15 | arch : dict or None 16 | If a dict, it is in the format that is described in 17 | :class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats 18 | matched will be returned. If none, architecture will be a wildcard. 19 | num_epochs : int or None 20 | If int, matching results will be returned. Otherwise a wildcard. 21 | isomorphism : boolean 22 | Whether to match essentially-same architecture, i.e., architecture with the 23 | same graph-invariant hash value. 24 | reduction : str or None 25 | If 'none' or None, all trial stats will be returned directly. 26 | If 'mean', fields in trial stats will be averaged given the same trial config. 27 | 28 | Returns 29 | ------- 30 | generator of dict 31 | A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects, 32 | where each of them has been converted into a dict. 
33 | """ 34 | fields = [] 35 | if reduction == 'none': 36 | reduction = None 37 | if reduction == 'mean': 38 | for field_name in Nb101TrialStats._meta.sorted_field_names: 39 | if field_name not in ['id', 'config']: 40 | fields.append(fn.AVG(getattr(Nb101TrialStats, field_name)).alias(field_name)) 41 | elif reduction is None: 42 | fields.append(Nb101TrialStats) 43 | else: 44 | raise ValueError('Unsupported reduction: \'%s\'' % reduction) 45 | query = Nb101TrialStats.select(*fields, Nb101TrialConfig).join(Nb101TrialConfig) 46 | conditions = [] 47 | if arch is not None: 48 | if isomorphism: 49 | num_vertices = infer_num_vertices(arch) 50 | conditions.append(Nb101TrialConfig.hash == hash_module(arch, num_vertices)) 51 | else: 52 | conditions.append(Nb101TrialConfig.arch == arch) 53 | if num_epochs is not None: 54 | conditions.append(Nb101TrialConfig.num_epochs == num_epochs) 55 | if conditions: 56 | query = query.where(functools.reduce(lambda a, b: a & b, conditions)) 57 | if reduction is not None: 58 | query = query.group_by(Nb101TrialStats.config) 59 | for k in query: 60 | yield model_to_dict(k) 61 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench101/readme.md: -------------------------------------------------------------------------------- 1 | > NASBench101 does not support weight-sharing training 2 | 3 | 1. install `nasbench-101` 4 | 5 | - for tensorflow 1.x 6 | ```bash 7 | pip install git+https://github.com/google-research/nasbench.git 8 | ``` 9 | 10 | - for tensorflow 2.x 11 | ```bash 12 | pip install git+https://github.com/gabikadlecova/nasbench.git 13 | pip install protobuf==3.20.0 14 | ``` 15 | 16 | 2. downlowd tfrecord 17 | 18 | ``` 19 | wget https://storage.googleapis.com/nasbench/nasbench_full.tfrecord 20 | wget https://storage.googleapis.com/nasbench/nasbench_only_108.tfrecord 21 | ``` 22 | 23 | - nasbench_full.tfrecord includes the results on epoch of 4, 12, 36, 108 24 | - nasbench_only_108.tfrecord only includes the result on epoch of 108 25 | 26 | 3. 
convert tfrecord to database 27 | 28 | ``` 29 | python db_gen.py --inputFile /path/to/nasbench_full.tfrecord 30 | ``` 31 | 32 | # Thanks 33 | - https://github.com/romulus0914/NASBench-PyTorch/blob/master/main.py 34 | - https://nni.readthedocs.io/en/v1.7/NAS/BenchmarksExample.html 35 | - https://github.com/microsoft/nni/blob/v1.7/src/sdk/pynni/nni/nas/benchmarks/nasbench101/__init__.py -------------------------------------------------------------------------------- /hyperbox/networks/nasbench201/__init__.py: -------------------------------------------------------------------------------- 1 | from .nasbench201 import NASBench201Network 2 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench201/db_gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3, PRIMITIVES 2 | from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig 3 | from .query import query_nb201_trial_stats 4 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench201/db_gen/constants.py: -------------------------------------------------------------------------------- 1 | NONE = "none" 2 | SKIP_CONNECT = "skip_connect" 3 | CONV_1X1 = "conv_1x1" 4 | CONV_3X3 = "conv_3x3" 5 | AVG_POOL_3X3 = "avg_pool_3x3" 6 | 7 | PRIMITIVES = [NONE, AVG_POOL_3X3, CONV_3X3, CONV_1X1, SKIP_CONNECT] 8 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench201/db_gen/query.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from peewee import fn 4 | from playhouse.shortcuts import model_to_dict 5 | from hyperbox.networks.nasbench201.db_gen.model import Nb201TrialStats, Nb201TrialConfig 6 | 7 | 8 | def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False): 9 | """ 10 | Query trial stats of NAS-Bench-201 given conditions. 11 | Parameters 12 | ---------- 13 | arch : dict or None 14 | If a dict, it is in the format that is described in 15 | :class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats 16 | matched will be returned. If none, all architectures in the database will be matched. 17 | num_epochs : int or None 18 | If int, matching results will be returned. Otherwise a wildcard. 19 | dataset : str or None 20 | If specified, can be one of the dataset available in :class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. 21 | Otherwise a wildcard. 22 | reduction : str or None 23 | If 'none' or None, all trial stats will be returned directly. 24 | If 'mean', fields in trial stats will be averaged given the same trial config. 25 | include_intermediates : boolean 26 | If true, intermediate results will be returned. 27 | Returns 28 | ------- 29 | generator of dict 30 | A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects, 31 | where each of them has been converted into a dict. 
32 | """ 33 | fields = [] 34 | if reduction == 'none': 35 | reduction = None 36 | if reduction == 'mean': 37 | for field_name in Nb201TrialStats._meta.sorted_field_names: 38 | if field_name not in ['id', 'config', 'seed']: 39 | fields.append(fn.AVG(getattr(Nb201TrialStats, field_name)).alias(field_name)) 40 | elif reduction is None: 41 | fields.append(Nb201TrialStats) 42 | else: 43 | raise ValueError('Unsupported reduction: \'%s\'' % reduction) 44 | query = Nb201TrialStats.select(*fields, Nb201TrialConfig).join(Nb201TrialConfig) 45 | conditions = [] 46 | if arch is not None: 47 | conditions.append(Nb201TrialConfig.arch == arch) 48 | if num_epochs is not None: 49 | conditions.append(Nb201TrialConfig.num_epochs == num_epochs) 50 | if dataset is not None: 51 | conditions.append(Nb201TrialConfig.dataset == dataset) 52 | if conditions: 53 | query = query.where(functools.reduce(lambda a, b: a & b, conditions)) 54 | if reduction is not None: 55 | query = query.group_by(Nb201TrialStats.config) 56 | for trial in query: 57 | if include_intermediates: 58 | data = model_to_dict(trial) 59 | # exclude 'trial' from intermediates as it is already available in data 60 | data['intermediates'] = [ 61 | {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates 62 | ] 63 | yield data 64 | else: 65 | yield model_to_dict(trial) 66 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench201/gen_nasbench201_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | pip install peewee 5 | pip install gdown 6 | 7 | if [ -z "${NASBENCHMARK_DIR}" ]; then 8 | NASBENCHMARK_DIR=~/.hyperbox/nasbench201 9 | fi 10 | 11 | echo "Downloading NAS-Bench-201..." 12 | if [ -f "nb201.pth" ]; then 13 | echo "nb201.pth found. Skip download." 14 | else 15 | gdown https://drive.google.com/uc\?id\=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_ -O nb201.pth 16 | fi 17 | 18 | echo "Generating database..." 19 | rm -f ${NASBENCHMARK_DIR}/nasbench201.db ${NASBENCHMARK_DIR}/nasbench201.db-journal 20 | mkdir -p ${NASBENCHMARK_DIR} 21 | python hyperbox/networks/nasbench201/db_gen/db_gen.py nb201.pth 22 | # python hyperbox.networks.nasbench201.db_gen.db_gen nb201.pth 23 | # rm -f nb201.pth -------------------------------------------------------------------------------- /hyperbox/networks/nasbench301/__init__.py: -------------------------------------------------------------------------------- 1 | from .nasbench301 import NASBench301Network 2 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench301/install/readme.md: -------------------------------------------------------------------------------- 1 | ## NASBench301 Installation Guidence 2 | 3 | ## Environments 4 | 5 | ubuntu16.04 6 | cuda11.1 7 | torch=1.8+cu111 8 | python=3.7.0 9 | 10 | 如果是cuda 10.2,执行以下命令: 11 | 12 | ``` 13 | cat requirements_cuda102.txt | xargs -n 1 -L 1 pip install 14 | ``` 15 | 16 | 如果是cuda11.1,执行以下命令: 17 | ``` 18 | cat requirements_cuda111.txt | xargs -n 1 -L 1 pip install 19 | ``` 20 | 21 | ## Install NB301 22 | 23 | ``` 24 | $ git clone https://github.com/automl/nasbench301 25 | $ cd nasbench301 26 | $ pip install . 
27 | ``` 28 | 29 | ## Download weights 30 | 31 | v0.9: [nasbench301_models_v0.9_zip](https://figshare.com/articles/software/nasbench301_models_v0_9_zip/12962432) 32 | 33 | v1.0: [nasbench301_models_v1_0_zip](https://figshare.com/articles/software/nasbench301_models_v1_0_zip/13061510) 34 | 35 | Run the official demo: 36 | 37 | ``` 38 | cd nasbench301/nasbench301 39 | unzip nasbench301_models_v1.0.zip 40 | mv nb_models nb_models_1.0 41 | python example.py 42 | ``` 43 | 44 | Expected output: 45 | 46 | ``` 47 | ==> Loading performance surrogate model... 48 | /home/pdluser/project/nasbench301/nasbench301/nb_models_1.0/xgb_v1.0 49 | ==> Loading runtime surrogate model... 50 | ==> Creating test configs... 51 | ==> Predict runtime and performance... 52 | Genotype architecture performance: 94.167947, runtime 4781.153014 53 | Configspace architecture performance: 91.834275, runtime 4285.378243 54 | ``` 55 | 56 | ## Using NB301 in Hyperbox 57 | 58 | Set `default_path` to your download path, e.g., `/path/to/your/downloads/nb_models_1.0` -------------------------------------------------------------------------------- /hyperbox/networks/nasbench301/install/requirements_cu102.txt: -------------------------------------------------------------------------------- 1 | # for cuda10.2 2 | xgboost 3 | lightgbm 4 | pathvalidate 5 | ConfigSpace 6 | git+https://github.com/automl/nasbench301 7 | torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.1.html 8 | torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1.html 9 | torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1.html 10 | torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1.html 11 | torch-geometric -f https://pytorch-geometric.com/whl/torch-1.8.1.html 12 | 13 | # for cuda10.2 14 | git+https://github.com/automl/Auto-PyTorch.git@nb301 15 | autograd>=1.3 16 | click 17 | Cython 18 | ConfigSpace==0.4.12 19 | fasteners 20 | ipython 21 | imblearn 22 | lightgbm>=2.3.1 23 | matplotlib 24 | netifaces 25 | numpy 26 | pandas 27 | pathvalidate 28 | Pillow>=7.1.2 29 | psutil 30 | pynisher 31 | Pyro4 32 | scikit-image 33 | scipy 34 | scikit-learn==0.23.0 35 | seaborn 36 | setuptools 37 | serpent 38 | statsmodels 39 | tensorboard==1.14.0 40 | tensorflow-estimator 41 | tensorflow-gpu 42 | tensorboard_logger 43 | torch>=1.5.0 44 | torchvision>=0.6.0 45 | timm 46 | tqdm 47 | xgboost 48 | torch-scatter==2.0.4+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html 49 | torch-sparse==0.6.3+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html 50 | torch-cluster==1.5.5+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html 51 | torch-spline-conv==1.2.0+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html 52 | torch-geometric 53 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench301/utils.py: -------------------------------------------------------------------------------- 1 | # author: pprp 2 | # date: 2021-10-31 3 | # function: default modelv1.0 4 | # install: 5 | ''' 6 | git clone https://github.com/automl/nasbench301.git 7 | cd nasbench301 8 | cat requirements.txt | xargs -n 1 -L 1 pip install 9 | pip install . 
10 | ''' 11 | 12 | import os 13 | from collections import namedtuple 14 | 15 | import nasbench301 as nb 16 | from ConfigSpace.read_and_write import json as cs_json 17 | 18 | 19 | default_version = "1.0" 20 | 21 | default_path = os.path.expanduser("~/.hyperbox/nb_models") # expanduser so the '~' resolves when the path is opened 22 | 23 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 24 | 25 | 26 | def load_model(): 27 | model_paths = { 28 | model_name: os.path.join(default_path, '{}_v1.0'.format(model_name)) 29 | for model_name in ['xgb', 'gnn_gin', 'lgb_runtime'] 30 | } 31 | print("==> Loading performance surrogate model...") 32 | ensemble_dir_performance = model_paths['xgb'] 33 | performance_model = nb.load_ensemble(ensemble_dir_performance) 34 | 35 | print("==> Loading runtime surrogate model...") 36 | ensemble_dir_runtime = model_paths['lgb_runtime'] 37 | runtime_model = nb.load_ensemble(ensemble_dir_runtime) 38 | return performance_model, runtime_model 39 | 40 | 41 | def generate_results(genotype_config): 42 | # pmodel: prediction model 43 | # rmodel: runtime model 44 | pmodel, rmodel = load_model() 45 | print("==> Creating test configs...") 46 | genotype_config = Genotype( 47 | normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), 48 | ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], 49 | normal_concat=[2, 3, 4, 5], 50 | reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), 51 | ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], 52 | reduce_concat=[2, 3, 4, 5] 53 | ) if genotype_config is None else genotype_config 54 | 55 | prediction_genotype = pmodel.predict( 56 | config=genotype_config, representation="genotype", with_noise=True) 57 | runtime_genotype = rmodel.predict( 58 | config=genotype_config, representation="genotype") 59 | 60 | return prediction_genotype, runtime_genotype 61 | 62 | 63 | def sample_config(): 64 | configspace_path = './hyperbox/networks/nasbench301/configspace.json' 65 | with open(configspace_path, "r") as f: 66 | json_string = f.read() 67 | configspace = cs_json.read(json_string) 68 | configspace_config = configspace.sample_configuration() 69 | return configspace_config 70 | 71 | 72 | def test1(genotype_config=None): 73 | p, r = generate_results(genotype_config) 74 | print("Genotype architecture performance: %f, runtime %f" % (p, r)) 75 | 76 | 77 | def test2(): 78 | # pmodel: prediction model 79 | # rmodel: runtime model 80 | pmodel, rmodel = load_model() 81 | configspace_config = sample_config() 82 | prediction_configspace = pmodel.predict( 83 | config=configspace_config, representation="configspace", with_noise=True) 84 | runtime_configspace = rmodel.predict( 85 | config=configspace_config, representation="configspace") 86 | print("Configspace architecture performance: %f, runtime %f" % 87 | (prediction_configspace, runtime_configspace)) 88 | 89 | 90 | if __name__ == "__main__": 91 | test1() 92 | test2() 93 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbench_mbnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/nasbench_mbnet/__init__.py -------------------------------------------------------------------------------- /hyperbox/networks/nasbenchasr/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .model import NASBenchASR 2 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbenchasr/download_nasbenchasr.sh: -------------------------------------------------------------------------------- 1 | if [ -z "${NASBENCHMARK_DIR}" ]; then 2 | NASBENCHMARK_DIR=~/.hyperbox/nasbenchasr/ 3 | fi 4 | 5 | echo "Downloading NAS-Bench-ASR..." 6 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-gtx-1080ti-fp32.pickle 7 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-jetson-nano-fp32.pickle 8 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e10-1234.pickle 9 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1234.pickle 10 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1235.pickle 11 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1236.pickle 12 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e5-1234.pickle 13 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-info.pickle 14 | 15 | 16 | mkdir -p ${NASBENCHMARK_DIR} 17 | mv nb-asr*.pickle ${NASBENCHMARK_DIR} -------------------------------------------------------------------------------- /hyperbox/networks/nasbenchasr/ops.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PadConvRelu(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, dilation, strides, groups=1, dropout_rate=0, context=4, name='PadConvRelu'): 9 | super().__init__() 10 | self.name = name 11 | 12 | if int(context / strides) >= (kernel_size*dilation-strides): 13 | rpad = kernel_size*dilation-strides 14 | lpad = 0 15 | else: 16 | rpad = int(context / strides) 17 | lpad = int((kernel_size - 1)*dilation - rpad) 18 | 19 | self.pad = nn.ZeroPad2d((lpad, rpad, 0, 0)) 20 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=strides, dilation=dilation, groups=groups) 21 | self.relu = nn.ReLU(inplace=False) 22 | self.dropout = nn.Dropout(p=dropout_rate) 23 | 24 | def forward(self, x): 25 | x = self.pad(x) 26 | x = self.conv(x) 27 | x = self.relu(x) 28 | x = torch.clamp_max_(x, 20) 29 | x = self.dropout(x) 30 | return x 31 | 32 | 33 | class Linear(nn.Module): 34 | def __init__(self, in_features, out_features, dropout_rate=0, name='Linear'): 35 | super().__init__() 36 | self.name = name 37 | 38 | self.linear = nn.Linear(in_features, out_features) 39 | self.relu = nn.ReLU(inplace=False) 40 | self.dropout = nn.Dropout(p=dropout_rate) 41 | 42 | def forward(self, x): 43 | shape = x.shape 44 | x = x.permute(0,2,1) 45 | x = self.linear(x) 46 | x = self.relu(x) 47 | x = torch.clamp_max_(x, 20) 48 | x = self.dropout(x) 49 | x = x.permute(0,2,1) 50 | return x 51 | 52 | 53 | class Identity(nn.Module): 54 | def __init__(self, name='Identity'): 55 | super().__init__() 56 | self.name = name 57 | 58 | def forward(self, x): 59 | return x 60 | 61 | 62 | class Zero(nn.Module): 63 | def __init__(self, name='Zero'): 64 | super(Zero, self).__init__() 65 | self.name = name 66 | 67 | def forward(self, x): 68 | return torch.zeros_like(x) 69 | 70 | 71 | _ops = { 72 | 'linear': Linear, 73 | 'conv5': functools.partial(PadConvRelu, kernel_size=5, dilation=1, strides=1, groups=100, 
name='conv5'), 74 | 'conv5d2': functools.partial(PadConvRelu, kernel_size=5, dilation=2, strides=1, groups=100, name='conv5d2'), 75 | 'conv7': functools.partial(PadConvRelu, kernel_size=7, dilation=1, strides=1, groups=100, name='conv7'), 76 | 'conv7d2': functools.partial(PadConvRelu, kernel_size=7, dilation=2, strides=1, groups=100, name='conv7d2'), 77 | 'zero': lambda *args, **kwargs: Zero(name='zero') 78 | } 79 | 80 | _branch_ops = { 81 | 0: Zero, # branch not present 82 | 1: Identity # branch present 83 | } 84 | 85 | -------------------------------------------------------------------------------- /hyperbox/networks/nasbenchasr/readme.md: -------------------------------------------------------------------------------- 1 | # NAS-Bench-ASR 2 | 3 | ## 1. Download datasets 4 | 5 | ```bash 6 | if [ -z "${NASBENCHMARK_DIR}" ]; then 7 | NASBENCHMARK_DIR=~/.hyperbox/nasbenchasr/ 8 | fi 9 | 10 | echo "Downloading NAS-Bench-ASR..." 11 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-gtx-1080ti-fp32.pickle 12 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-bench-jetson-nano-fp32.pickle 13 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e10-1234.pickle 14 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1234.pickle 15 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e40-1235.pickle 16 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-e5-1234.pickle 17 | wget https://github.com/SamsungLabs/nb-asr/releases/download/v1.1.0/nb-asr-info.pickle 18 | 19 | mkdir -p ${NASBENCHMARK_DIR} 20 | mv nb-asr*.pickle ${NASBENCHMARK_DIR} 21 | ``` 22 | 23 | ## 2. Usage 24 | 25 | ```python 26 | import torch; from hyperbox.networks.nasbenchasr import NASBenchASR; from hyperbox.mutator import RandomMutator 27 | model = NASBenchASR() 28 | print(sum([p.numel() for p in model.parameters()])) 29 | rm = RandomMutator(model) 30 | rm.reset() 31 | 32 | B, F, T = 2, 80, 30 33 | for T in [16]: 34 | x = torch.rand(B, F, T) 35 | rm.reset() 36 | y = model(x) 37 | print(y.shape) 38 | # print(rm._cache, len(rm._cache)) 39 | 40 | 41 | list_desc = [ 42 | ['linear', 1], 43 | ['conv5', 1, 0], 44 | ['conv7d2', 1, 0, 1], 45 | ] 46 | mask = NASBenchASR.list_desc_to_dict_mask(list_desc) 47 | # print(mask) 48 | model2 = NASBenchASR(mask=mask) 49 | print(sum([p.numel() for p in model2.parameters()])) 50 | print(model2.arch_size((B, F, T))) 51 | y = model(x) 52 | print(NASBenchASR.dict_mask_to_list_desc(mask)) 53 | 54 | print(model2.query_full_info()) 55 | print(model2.query_flops()) 56 | print(model2.query_latency()) 57 | print(model2.query_params()) 58 | print(model2.query_test_acc()) 59 | print(model2.query_val_acc()) 60 | ``` 61 | 62 | Output (results vary due to randomness): 63 | ```bash 64 | 84867649 65 | torch.Size([2, 4, 49]) 66 | 39733249 67 | (181.789408, 151.5703125) 68 | [['linear', 1], ['conv5', 1, 0], ['conv7d2', 1, 0, 1]] 69 | {'val_per': [0.9687853, 0.87188685, 0.7872086, 0.6666002, 0.5817228, 0.50474864, 0.4540081, 0.4089128, 0.3726506, 0.35265988, 0.33154014, 0.30796307, 0.29800093, 0.28405392, 0.27661553, 0.270439, 0.26074252, 0.2551637, 0.25124526, 0.2481238, 0.24918643, 0.24540082, 0.24041975, 0.2369662, 0.2374311, 0.2337119, 0.2337119, 0.23391114, 0.2327821, 0.22939497, 0.22893007, 0.22879724, 0.22753537, 0.22992627, 0.22985986, 0.22839876, 0.22461313, 0.22381617, 0.22275354, 0.22328486], 'test_per': 0.24767844378948212, 'arch_vec': [(0, 1), (1, 1, 0), (4, 1, 0, 1)], 'model_hash': 'adb47992d93622245376905cc956a149', 'seed': 1236, 
'jetson-nano-fp32': {'latency': 0.578345775604248}, 'gtx-1080ti-fp32': {'latency': 0.04792499542236328}} 70 | 3845877266 71 | {'jetson-nano-fp32': {'latency': 0.578345775604248}, 'gtx-1080ti-fp32': {'latency': 0.04792499542236328}} 72 | 43100448 73 | 0.24906444549560547 74 | 0.22275354 75 | ``` 76 | 77 | # Acknowledgements 78 | The code is based on https://github.com/SamsungLabs/nb-asr -------------------------------------------------------------------------------- /hyperbox/networks/nasbenchasr/search_space.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from hyperbox.networks.nasbenchasr.utils import recursive_iter, flatten, copy_structure 4 | 5 | 6 | all_ops = ['linear', 'conv5', 'conv5d2', 'conv7', 'conv7d2', 'zero'] 7 | ops_no_zero = all_ops[:-1] 8 | default_nodes = 3 9 | 10 | 11 | def get_search_space(ops=None, nodes=None): 12 | ''' Return boundaries of the search space for the given list 13 | of available operations and number of nodes. 14 | ''' 15 | ops = ops if ops is not None else all_ops 16 | nodes = nodes if nodes is not None else default_nodes 17 | search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)] 18 | return search_space 19 | 20 | 21 | def get_model_hash(arch_vec, ops=None, minimize=True): 22 | ''' Get hash of the architecture specified by arch_vec. 23 | Architecture hash can be used to determine if two 24 | configurations from the search space are in fact the 25 | same (graph isomorphism). 26 | ''' 27 | from .graph_utils import get_model_graph, graph_hash 28 | g, _ = get_model_graph(arch_vec, ops=ops, minimize=minimize) 29 | return graph_hash(g) 30 | 31 | 32 | def get_all_architectures(ops=None, nodes=None): 33 | ''' Yields all architecture configurations in the search space. 34 | ''' 35 | search_space = get_search_space(ops, nodes) 36 | flat = flatten(search_space) 37 | cfg = [0 for _ in range(len(flat))] 38 | end = False 39 | while not end: 40 | yield copy_structure(cfg, search_space) 41 | for dim in range(len(flat)): 42 | cfg[dim] += 1 43 | if cfg[dim] != flat[dim]: 44 | break 45 | cfg[dim] = 0 46 | if dim+1 >= len(flat): 47 | end = True 48 | 49 | 50 | def get_random_architectures(num, ops=None, nodes=None, seed=None): 51 | ''' Get random architecture configurations from the search space. 52 | ''' 53 | ops = ops if ops is not None else all_ops 54 | nodes = nodes if nodes is not None else default_nodes 55 | if seed is not None: 56 | random.seed(seed) 57 | search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)] 58 | flat = flatten(search_space) 59 | models = [] 60 | while len(models) < num: 61 | m = [random.randrange(opts) for opts in flat] 62 | m = copy_structure(m, search_space) 63 | models.append(m) 64 | return models 65 | 66 | 67 | def get_archs_with_zero(): 68 | models_with_zero = {} 69 | for m in get_all_architectures(all_ops, default_nodes): 70 | if 5 in flatten(m): # 5 == all_ops.index('zero') 71 | h = get_model_hash(m) 72 | models_with_zero[h] = m 73 | new_model_archs = [models_with_zero[k] for k in sorted(models_with_zero.keys())] 74 | return new_model_archs 75 | 76 | 77 | def arch_vec_to_names(arch_vec, ops=None): 78 | ''' Translates identifiers of operations in ``arch_vec`` to their names. 79 | ``ops`` can be provided externally to avoid relying on the current definition 80 | of available ops. Otherwise canonical ``all_ops`` will be used. 
81 | ''' 82 | 83 | if ops is None: 84 | ops = all_ops 85 | 86 | # current approach is to have an arch vector contain sub-vectors for each node in a cell, 87 | # each subvector has the form: 88 | # [op_idx, branch_op_idx...] 89 | # where op_idx points to an operation from ``ops`` and ``branch_op_idx`` is 90 | # either 0 (no skip connection) or 1 (identity skip connection) 91 | # since skip connects are already quite self-explanatory we leave them as they are 92 | # and only change numbers of the main operations to their respective names 93 | return [[ops[op_idx]] + branches for op_idx, *branches in arch_vec] 94 | -------------------------------------------------------------------------------- /hyperbox/networks/network_ema.py: -------------------------------------------------------------------------------- 1 | """ Exponential Moving Average (EMA) of model updates 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | 4 | source code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py 5 | """ 6 | from collections import OrderedDict 7 | from copy import deepcopy 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hyperbox.networks.base_nas_network import BaseNASNetwork 13 | 14 | 15 | class ModelEma(nn.Module): 16 | """ Model Exponential Moving Average V2 17 | Keep a moving average of everything in the model state_dict (parameters and buffers). 18 | V2 of this module is simpler, it does not match params/buffers based on name but simply 19 | iterates in order. It works with torchscript (JIT of full model). 20 | This is intended to allow functionality like 21 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 22 | A smoothed version of the weights is necessary for some training schemes to perform well. 23 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 24 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 25 | smoothing of weights to match results. Pay attention to the decay constant you are using 26 | relative to your update count per epoch. 27 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 28 | disable validation of the EMA weights. Validation will have to be done manually in a separate 29 | process, or after the training stops converging. 30 | This class is sensitive where it is initialized in the sequence of model init, 31 | GPU assignment and distributed training wrappers. 32 | """ 33 | def __init__(self, model, decay=0.7, final_decay=0.999, device=None): 34 | super(ModelEma, self).__init__() 35 | # make a copy of the model for accumulating moving average of weights 36 | if isinstance(model, BaseNASNetwork): 37 | self.module = model.copy() 38 | else: 39 | self.module = deepcopy(model) 40 | self.module.eval() 41 | self.decay = decay 42 | self.init_decay = decay 43 | self.final_decay = final_decay 44 | self.device = device # perform ema on different device from model if set 45 | if self.device is not None: 46 | self.module.to(device=device) 47 | 48 | def _update(self, model, update_fn): 49 | with torch.no_grad(): 50 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 51 | if self.device is not None: 52 | model_v = model_v.to(device=self.device) 53 | ema_v.copy_(update_fn(ema_v, model_v)) 54 | 55 | def update(self, model): 56 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. 
- self.decay) * m) 57 | 58 | def set(self, model): 59 | self._update(model, update_fn=lambda e, m: m) 60 | 61 | def update_decay(self, epoch, all_epoch=100): 62 | self.decay = self.init_decay + (epoch/all_epoch)*(self.final_decay-self.init_decay) 63 | 64 | def forward(self, *args, **kwargs): 65 | return self.module(*args, **kwargs) 66 | 67 | 68 | if __name__ == '__main__': 69 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 70 | from hyperbox.networks.ofa import OFAMobileNetV3 71 | from hyperbox.mutator import RandomMutator 72 | 73 | supernet = OFAMobileNetV3(num_classes=10, width_mult=1.0).to(device).eval() 74 | mutator = RandomMutator(supernet) 75 | # ema = ModelEma(supernet, 0.9) 76 | ema = ModelEma(supernet, 0.9).eval() 77 | print('init\nsupernet', supernet.classifier.weight.data[:2,:8]) 78 | print('ema', ema.module.classifier.weight.data[:2,:8]) 79 | for i in range(10): 80 | mutator.reset() 81 | supernet.init_weights() 82 | ema.update(supernet) 83 | x = torch.rand(2,3,64,64).to(device) 84 | y1 = supernet(x) 85 | subnet = ema.module.build_subnet(mutator._cache).cuda().eval() 86 | y2 = subnet(x) 87 | print('supernet', supernet.classifier.weight.data[:2,:8]) 88 | print('ema', ema.module.classifier.weight.data[:2,:8]) 89 | -------------------------------------------------------------------------------- /hyperbox/networks/ofa/__init__.py: -------------------------------------------------------------------------------- 1 | from .ofa_mbv3 import OFAMobileNetV3 2 | -------------------------------------------------------------------------------- /hyperbox/networks/proxylessnas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/networks/proxylessnas/__init__.py -------------------------------------------------------------------------------- /hyperbox/networks/proxylessnas/putils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def get_parameters(model, keys=None, mode='include'): 4 | if keys is None: 5 | for name, param in model.named_parameters(): 6 | yield param 7 | elif mode == 'include': 8 | for name, param in model.named_parameters(): 9 | flag = False 10 | for key in keys: 11 | if key in name: 12 | flag = True 13 | break 14 | if flag: 15 | yield param 16 | elif mode == 'exclude': 17 | for name, param in model.named_parameters(): 18 | flag = True 19 | for key in keys: 20 | if key in name: 21 | flag = False 22 | break 23 | if flag: 24 | yield param 25 | else: 26 | raise ValueError('do not support: %s' % mode) 27 | 28 | 29 | def get_same_padding(kernel_size): 30 | if isinstance(kernel_size, tuple): 31 | assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size 32 | p1 = get_same_padding(kernel_size[0]) 33 | p2 = get_same_padding(kernel_size[1]) 34 | return p1, p2 35 | assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`' 36 | assert kernel_size % 2 > 0, 'kernel size should be odd number' 37 | return kernel_size // 2 38 | 39 | def build_activation(act_func, inplace=True): 40 | if act_func == 'relu': 41 | return nn.ReLU(inplace=inplace) 42 | elif act_func == 'relu6': 43 | return nn.ReLU6(inplace=inplace) 44 | elif act_func == 'tanh': 45 | return nn.Tanh() 46 | elif act_func == 'sigmoid': 47 | return nn.Sigmoid() 48 | elif act_func is None: 49 | return None 50 | else: 51 | raise ValueError('do not support: %s' % act_func) 52 | 53 | 
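# Illustrative worked example for make_divisible below (not part of the original file):
#   make_divisible(30, 8) -> max(8, int(30 + 4) // 8 * 8) = 32, and since
#   32 >= 0.9 * 30 no correction step is applied, so channels are rounded 30 -> 32.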
54 | def make_divisible(v, divisor, min_val=None): 55 | """ 56 | This function is taken from the original tf repo. 57 | It ensures that all layers have a channel number that is divisible by 8 58 | It can be seen here: 59 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 60 | """ 61 | if min_val is None: 62 | min_val = divisor 63 | new_v = max(min_val, int(v + divisor / 2) // divisor * divisor) 64 | # Make sure that round down does not go down by more than 10%. 65 | if new_v < 0.9 * v: 66 | new_v += divisor 67 | return new_v 68 | -------------------------------------------------------------------------------- /hyperbox/networks/repnas/__init__.py: -------------------------------------------------------------------------------- 1 | from .repnas_spos import RepNAS -------------------------------------------------------------------------------- /hyperbox/networks/repnas/utils.py: -------------------------------------------------------------------------------- 1 | from hyperbox.mutables.spaces import OperationSpace 2 | 3 | import torch 4 | import torch.nn as nn 5 | # from .rep_ops import * 6 | from hyperbox.networks.repnas.rep_ops import * 7 | 8 | 9 | def fuse(candidates, weights, kernel_size=3): 10 | k_list = [] 11 | b_list = [] 12 | 13 | for i in range(len(candidates)): 14 | op = candidates[i] 15 | weight = weights[i].float() 16 | if op.__class__.__name__ == "DBB1x1kxk": 17 | if hasattr(op.dbb_1x1_kxk, 'idconv1'): 18 | k1 = op.dbb_1x1_kxk.idconv1.get_actual_kernel() 19 | else: 20 | k1 = op.dbb_1x1_kxk.conv1.weight 21 | 22 | k1, b1 = transI_fusebn(k1, op.dbb_1x1_kxk.bn1) 23 | k2, b2 = transI_fusebn(op.dbb_1x1_kxk.conv2.weight, op.dbb_1x1_kxk.bn2) 24 | 25 | k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=op.groups) 26 | elif op.__class__.__name__ == "DBB1x1": 27 | k, b = transI_fusebn(op.conv.weight, op.bn) 28 | k = transVI_multiscale(k, kernel_size) 29 | elif op.__class__.__name__ == "DBBORIGIN": 30 | k, b = transI_fusebn(op.conv.weight, op.bn) 31 | elif op.__class__.__name__ == "DBBAVG": 32 | ka = transV_avg(op.out_channels, op.kernel_size, op.groups) 33 | k2, b2 = transI_fusebn(ka.to(op.dbb_avg.avgbn.weight.device), op.dbb_avg.avgbn) 34 | 35 | if hasattr(op.dbb_avg, 'conv'): 36 | k1, b1 = transI_fusebn(op.dbb_avg.conv.weight, op.dbb_avg.bn) 37 | k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=op.groups) 38 | else: 39 | k, b = k2, b2 40 | else: 41 | raise TypeError("op must be one of DBBORIGIN, DBBAVG, DBB1x1, or DBB1x1kxk.") # raising a plain string is invalid in Python 3 
42 | k_list.append(k.detach() * weight) 43 | b_list.append(b.detach() * weight) 44 | 45 | return transII_addbranch(k_list, b_list) 46 | 47 | 48 | def replace(net): 49 | for name, module in net.named_modules(): 50 | if isinstance(module, OperationSpace): 51 | candidates = [] 52 | weights = [] 53 | for idx, weight in enumerate(module.mask): 54 | if weight: 55 | candidates.append(module.candidates_original[idx]) 56 | weights.append(weight) 57 | ks = max([c_.kernel_size for c_ in candidates]) 58 | k, b = fuse(candidates, weights, ks) 59 | first = module.candidates_original[0] 60 | inc = first.in_channels 61 | ouc = first.out_channels 62 | s = first.stride 63 | p = ks//2 64 | g = first.groups 65 | reparam = nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=ks, 66 | stride=s, padding=p, dilation=1, groups=g) 67 | reparam.weight.data = k 68 | reparam.bias.data = b 69 | 70 | module.candidates_original = [reparam] 71 | module.candidates = torch.nn.ModuleList([reparam]) 72 | module.mask = torch.tensor([True]) 73 | -------------------------------------------------------------------------------- /hyperbox/networks/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * -------------------------------------------------------------------------------- /hyperbox/networks/spos/__init__.py: -------------------------------------------------------------------------------- 1 | from .shuffle_blocks import * 2 | from .spos_net import * 3 | -------------------------------------------------------------------------------- /hyperbox/networks/vit/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit import * -------------------------------------------------------------------------------- /hyperbox/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/optimizers/__init__.py -------------------------------------------------------------------------------- /hyperbox/optimizers/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | 4 | __all__ = ["ASAM", "SAM"] 5 | 6 | class ASAM: 7 | def __init__(self, optimizer, model, rho=0.5, eta=0.01): 8 | self.optimizer = optimizer 9 | self.model = model 10 | self.rho = rho 11 | self.eta = eta 12 | self.state = defaultdict(dict) 13 | 14 | @torch.no_grad() 15 | def ascent_step(self): 16 | wgrads = [] 17 | for n, p in self.model.named_parameters(): 18 | if p.grad is None: 19 | continue 20 | t_w = self.state[p].get("eps") 21 | if t_w is None: 22 | t_w = torch.clone(p).detach() 23 | self.state[p]["eps"] = t_w 24 | if 'weight' in n: 25 | t_w[...] = p[...] 26 | t_w.abs_().add_(self.eta) 27 | p.grad.mul_(t_w) 28 | wgrads.append(torch.norm(p.grad, p=2)) 29 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16 30 | for n, p in self.model.named_parameters(): 31 | if p.grad is None: 32 | continue 33 | t_w = self.state[p].get("eps") 34 | if 'weight' in n: 35 | p.grad.mul_(t_w) 36 | eps = t_w 37 | eps[...] = p.grad[...] 
38 | eps.mul_(self.rho / wgrad_norm) 39 | p.add_(eps) 40 | self.optimizer.zero_grad() 41 | 42 | @torch.no_grad() 43 | def descent_step(self): 44 | for n, p in self.model.named_parameters(): 45 | if p.grad is None: 46 | continue 47 | p.sub_(self.state[p]["eps"]) 48 | self.optimizer.step() 49 | self.optimizer.zero_grad() 50 | 51 | 52 | class SAM(ASAM): 53 | @torch.no_grad() 54 | def ascent_step(self): 55 | grads = [] 56 | for n, p in self.model.named_parameters(): 57 | if p.grad is None: 58 | continue 59 | grads.append(torch.norm(p.grad, p=2)) 60 | grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16 61 | for n, p in self.model.named_parameters(): 62 | if p.grad is None: 63 | continue 64 | eps = self.state[p].get("eps") 65 | if eps is None: 66 | eps = torch.clone(p).detach() 67 | self.state[p]["eps"] = eps 68 | eps[...] = p.grad[...] 69 | eps.mul_(self.rho / grad_norm) 70 | p.add_(eps) 71 | self.optimizer.zero_grad() -------------------------------------------------------------------------------- /hyperbox/run.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py 2 | import os 3 | import sys 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | import pyrootutils 8 | 9 | # project root setup 10 | # searches for root indicators in parent dirs, like ".git", "pyproject.toml", etc. 11 | # sets PROJECT_ROOT environment variable (used in `configs/paths/default.yaml`) 12 | # loads environment variables from ".env" if exists 13 | # adds root dir to the PYTHONPATH (so this file can be run from any place) 14 | # https://github.com/ashleve/pyrootutils 15 | root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True) # config.paths.root_dir 16 | 17 | 18 | @hydra.main(version_base="1.2", config_path="configs/", config_name="config.yaml") 19 | def main(config: DictConfig): 20 | 21 | # Imports should be nested inside @hydra.main to optimize tab completion 22 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 23 | from hyperbox.train import train 24 | from hyperbox.utils import utils 25 | 26 | # A couple of optional utilities: 27 | # - disabling python warnings 28 | # - easier access to debug mode 29 | # - forcing debug friendly configuration 30 | # - forcing multi-gpu friendly configuration 31 | # You can safely get rid of this line if you don't want those 32 | utils.extras(config) 33 | 34 | # Pretty print config using Rich library 35 | if config.get("print_config"): 36 | utils.print_config(config, resolve=True) 37 | 38 | # Train model 39 | if config.ipdb_debug: 40 | from ipdb import set_trace 41 | set_trace() 42 | return train(config) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /hyperbox/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/schedulers/__init__.py -------------------------------------------------------------------------------- /hyperbox/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/hyperbox/utils/__init__.py -------------------------------------------------------------------------------- /hyperbox/utils/average_meter.py: 
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | __all__ = [
4 |     'AverageMeterGroup',
5 |     'AverageMeter'
6 | ]
7 | 
8 | class AverageMeterGroup:
9 |     """
10 |     Average meter group for multiple average meters.
11 |     """
12 | 
13 |     def __init__(self, verbose_type='avg'):
14 |         self.meters = OrderedDict()
15 |         self.verbose_type = verbose_type
16 | 
17 |     def update(self, data):
18 |         """
19 |         Update the meter group with a dict of metrics.
20 |         Average meters that do not exist yet are created automatically.
21 |         """
22 |         for k, v in data.items():
23 |             if k not in self.meters:
24 |                 self.meters[k] = AverageMeter(k, ":.4f", self.verbose_type)
25 |             self.meters[k].update(v)
26 | 
27 |     # def __getattr__(self, item):
28 |     #     return self.meters[item]
29 | 
30 |     def __getitem__(self, item):
31 |         return self.meters[item]
32 | 
33 |     def __str__(self):
34 |         return " ".join(f"{v}" for v in self.meters.values())
35 | 
36 |     def summary(self):
37 |         """
38 |         Return a summary string of group data.
39 |         """
40 |         return " ".join(v.summary() for v in self.meters.values())
41 | 
42 | 
43 | class AverageMeter:
44 |     """
45 |     Computes and stores the average and current value.
46 |     Parameters
47 |     ----------
48 |     name : str
49 |         Name to display.
50 |     fmt : str
51 |         Format string to print the values.
52 |     verbose_type : str
53 |         'all': value(avg)
54 |         'avg': avg
55 |     """
56 | 
57 |     def __init__(self, name, fmt=':f', verbose_type='avg'):
58 |         self.name = name
59 |         self.fmt = fmt
60 |         if verbose_type not in ['all', 'avg']:
61 |             print('Unsupported verbose type; falling back to "avg"')
62 |             verbose_type = 'avg'
63 |         self.verbose_type = verbose_type
64 |         self.reset()
65 | 
66 |     def reset(self):
67 |         """
68 |         Reset the meter.
69 |         """
70 |         self.val = 0
71 |         self.avg = 0
72 |         self.sum = 0
73 |         self.count = 0
74 | 
75 |     def update(self, val, n=1):
76 |         """
77 |         Update with value and weight.
78 |         Parameters
79 |         ----------
80 |         val : float or int
81 |             The new value to account for.
82 |         n : int
83 |             The weight of the new value.
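        Example
        -------
        meter.update(0.75, n=32) folds in a batch of 32 samples whose mean
        value is 0.75; the running average is weighted by n.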
84 | """ 85 | self.val = val 86 | self.sum += val * n 87 | self.count += n 88 | self.avg = self.sum / self.count 89 | 90 | def __str__(self): 91 | if self.verbose_type=='all': 92 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 93 | elif self.verbose_type=='avg': 94 | fmtstr = '{name} {avg' + self.fmt + '}' 95 | else: 96 | fmtstr = '{name} {avg' + self.fmt + '}' 97 | return fmtstr.format(**self.__dict__) 98 | 99 | def summary(self): 100 | fmtstr = '{name}: {avg' + self.fmt + '}' 101 | return fmtstr.format(**self.__dict__) -------------------------------------------------------------------------------- /hyperbox/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from loguru import logger 5 | 6 | 7 | def custom_format(record): 8 | path = os.path.abspath(record["file"].path) 9 | record["extra"]["abspath"] = path 10 | fmt = "[{time:YYYY-MM-DD HH:mm:ss}] [{level}] [{extra[abspath]}:{line} ({name})] {message}\n{exception}" 11 | return fmt 12 | 13 | 14 | LOGGER_DICT = {} 15 | 16 | def get_logger(name=None, level=logging.INFO, is_rank_zero=True, log2file=False): 17 | if is_rank_zero or name is None: 18 | name = 'exp' 19 | if name in LOGGER_DICT: 20 | return LOGGER_DICT[name] 21 | fmt = custom_format 22 | logger.remove() 23 | kwargs = {'sink': sys.stderr, 'format': fmt, 'level': level, 'colorize': True, 'backtrace': True} 24 | handlers = [kwargs] 25 | logger.opt(exception=True) 26 | if log2file: 27 | snd_kwargs = {k: v for k, v in kwargs.items() if k != 'sink'} 28 | snd_kwargs['sink'] = os.path.join(os.getcwd(), name + '.log') 29 | handlers.append(snd_kwargs) 30 | config = {"handlers": handlers} 31 | logger.configure(**config) 32 | LOGGER_DICT[name] = logger 33 | return logger 34 | 35 | 36 | if __name__ == '__main__': 37 | log1 = get_logger(None, level=logging.DEBUG, log2file=False) 38 | log1.debug('test') # showing 39 | log2 = get_logger('test2', level=logging.INFO, is_rank_zero=False, log2file=False) 40 | log2.debug('test2') # not showing 41 | log1.debug('test') # not showing 42 | print(log1 is log2) # True -------------------------------------------------------------------------------- /hyperbox/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.stats as stats 3 | 4 | class Accuracy: 5 | def __call__(self, output, target, topk=(1,)): 6 | return accuracy(output, target, topk) 7 | 8 | 9 | def accuracy(output, target, topk=(1,)): 10 | """ Computes the precision@k for the specified values of k """ 11 | maxk = max(topk) 12 | batch_size = target.size(0) 13 | 14 | _, pred = output.topk(maxk, 1, True, True) 15 | pred = pred.t() 16 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 17 | 18 | res = [] 19 | for k in topk: 20 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 21 | res.append(torch.tensor(correct_k.mul_(100.0 / batch_size))) 22 | if len(res) == 1: 23 | return res[-1] 24 | return res 25 | 26 | class Pearson: 27 | def __call__(self, output, target): 28 | return pearson(output, target) 29 | 30 | def pearson(vector1, vector2): 31 | n = len(vector1) 32 | #simple sums 33 | sum1 = sum(float(vector1[i]) for i in range(n)) 34 | sum2 = sum(float(vector2[i]) for i in range(n)) 35 | #sum up the squares 36 | sum1_pow = sum([pow(v, 2.0) for v in vector1]) 37 | sum2_pow = sum([pow(v, 2.0) for v in vector2]) 38 | #sum up the products 39 | p_sum = sum([vector1[i]*vector2[i] for i in 
41 |     # numerator and denominator
42 |     num = p_sum - (sum1*sum2/n)
43 |     den = math.sqrt((sum1_pow-pow(sum1, 2)/n)*(sum2_pow-pow(sum2, 2)/n))
44 |     if den == 0:
45 |         return 0.0
46 |     return num/den
47 | 
48 | class KendallTau:
49 |     def __call__(self, output, target):
50 |         return stats.kendalltau(output, target)
-------------------------------------------------------------------------------- /hyperbox/utils/visualize_darts_cell.py: --------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import sys
3 | import os
4 | import json
5 | import numpy as np
6 | from graphviz import Digraph
7 | from argparse import ArgumentParser
8 | 
9 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
10 | 
11 | PRIMITIVES = [
12 |     'max_pool_3x3',
13 |     'avg_pool_3x3',
14 |     'skip_connect',
15 |     'sep_conv_3x3',
16 |     'sep_conv_5x5',
17 |     'dil_conv_3x3',
18 |     'dil_conv_5x5'
19 | ]
20 | 
21 | def convert_genotypes(arch_json_file):
22 |     with open(arch_json_file, 'r') as f:
23 |         arch = json.load(f)
24 |     normal = []
25 |     reduce = []
26 |     ops = []
27 |     indices = []
28 |     for key in arch:
29 |         if 'switch' not in key:
30 |             if sum(arch[key]) == 0:
31 |                 op = 'None'
32 |             else:
33 |                 op = np.array(PRIMITIVES)[arch[key]][0]
34 |             ops.append(op)
35 |         else:
36 |             _ops = []
37 |             for i in range(len(arch[key])):
38 |                 if arch[key][i]:
39 |                     indices.append(i)
40 |                     _ops.append(ops[i])
41 |             for op, idx in zip(_ops, indices):
42 |                 if 'norm' in key:
43 |                     normal.append((op, idx))
44 |                 else:
45 |                     reduce.append((op, idx))
46 |             ops = []
47 |             indices = []
48 |     geno = Genotype(
49 |         normal=normal, normal_concat=list(range(2, 2+len(normal)//2)),
50 |         reduce=reduce, reduce_concat=list(range(2, 2+len(reduce)//2)),
51 |     )
52 |     print(geno)
53 |     return geno
54 | 
55 | def plot(genotype, filename):
56 |     g = Digraph(
57 |         format='pdf',
58 |         edge_attr=dict(fontsize='20', fontname="times"),
59 |         node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
60 |         engine='dot')
61 |     g.body.extend(['rankdir=LR'])
62 | 
63 |     g.node("c_{k-2}", fillcolor='darkseagreen2')
64 |     g.node("c_{k-1}", fillcolor='darkseagreen2')
65 |     assert len(genotype) % 2 == 0
66 |     steps = len(genotype) // 2
67 | 
68 |     for i in range(steps):
69 |         g.node(str(i), fillcolor='lightblue')
70 | 
71 |     for i in range(steps):
72 |         for k in [2*i, 2*i + 1]:
73 |             op, j = genotype[k]
74 |             if j == 0:
75 |                 u = "c_{k-2}"
76 |             elif j == 1:
77 |                 u = "c_{k-1}"
78 |             else:
79 |                 u = str(j-2)
80 |             v = str(i)
81 |             if op != 'None':
82 |                 g.edge(u, v, label=op, fillcolor="gray")
83 | 
84 |     g.node("c_{k}", fillcolor='palegoldenrod')
85 |     for i in range(steps):
86 |         g.edge(str(i), "c_{k}", fillcolor="gray")
87 | 
88 |     g.render(filename, view=True)
89 | 
90 | 
91 | if __name__ == '__main__':
92 |     parser = ArgumentParser('Visualize_DARTS')
93 |     parser.add_argument('--file', type=str, help="the path of the mask (json) file")
94 |     args = parser.parse_args()
95 | 
96 |     arch_file = args.file
97 |     assert os.path.exists(arch_file), f"{arch_file} not found."
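    # Convert the exported architecture mask into a Genotype and render the
    # normal and reduction cells to "normal.pdf" and "reduction.pdf".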
98 | genotype = convert_genotypes(arch_file) 99 | plot(genotype.normal, "normal") 100 | plot(genotype.reduce, "reduction") 101 | -------------------------------------------------------------------------------- /hyperbox/utils/visualize_mbconv_net.py: -------------------------------------------------------------------------------- 1 | import graphviz 2 | 3 | 4 | class ColorSet: 5 | hex_colors1 = [ 6 | '#2878b5', 7 | '#9ac9db', 8 | '#f8ac8c', 9 | '#c82423', 10 | '#ff8884', 11 | ] 12 | hex_colors2 = [ 13 | '#BEB8DC', 14 | '#E7DAD2', 15 | '#8ECFC9', 16 | '#FFBE7A', 17 | '#FA7F6F', 18 | '#82B0D2', 19 | ] 20 | 21 | def __init__(self, color_set='hex_colors1'): 22 | if color_set == 'hex_colors1': 23 | self._colors = self.hex_colors1 24 | elif color_set == 'hex_colors2': 25 | self._colors = self.hex_colors2 26 | else: 27 | raise ValueError('color_set must be one of hex_colors1 or hex_colors2') 28 | 29 | @property 30 | def colors(self): 31 | return self._colors 32 | 33 | @colors.setter 34 | def colors(self, value): 35 | self._colors = value 36 | 37 | 38 | def draw_arch(arch, index2op, filename='arch', color_set='hex_colors1'): 39 | colorset = ColorSet(color_set) 40 | index2color = { 41 | i: colorset.colors[i] for i in range(len(set(arch))) 42 | } 43 | 44 | dot = graphviz.Graph(comment='The Round Table') 45 | # dot.graph_attr['rankdir'] = 'LR' 46 | dot.graph_attr['rotate'] = '90' 47 | prev_index = None 48 | for idx, op_idx in enumerate(arch): 49 | op_name = index2op[op_idx] 50 | op_color = index2color[op_idx] 51 | op_index = f"{idx}" 52 | dot.node(op_index, label=op_name, style='filled', fillcolor=op_color, shape='box', fontsize='10') 53 | if idx == 0: 54 | prev_index = op_index 55 | else: 56 | dot.edge(op_index, prev_index) 57 | prev_index = op_index 58 | dot.render(directory='visual_output', view=False, filename=filename) 59 | 60 | if __name__ == '__main__': 61 | arch = [ 62 | 0,1,2,1,2,3,1,1,2,2 63 | ] 64 | index2op = { 65 | 0: 'MBConv3', 66 | 1: 'MBConv5', 67 | 2: 'MBConv7', 68 | 3: 'Identity' 69 | } 70 | draw_arch(arch, index2op) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=1.8.1 3 | torchvision 4 | pytorch-lightning==1.8.6 5 | torchmetrics>=0.3.2 6 | lightning-bolts==0.3.3 7 | 8 | # --------- hydra --------- # 9 | hydra-core==1.2.0 10 | hydra-colorlog==1.1.0.dev1 11 | # hydra-optuna-sweeper==1.1.0.dev2 # for optuna-based HPO 12 | # hydra-ax-sweeper==1.1.0 13 | # hydra-ray-launcher==0.1.2 14 | # hydra-submitit-launcher==1.1.0 15 | 16 | # --------- loggers --------- # 17 | wandb>=0.12.21 18 | # neptune-client 19 | # mlflow 20 | # comet-ml 21 | # torch_tb_profiler 22 | 23 | # --------- linters --------- # 24 | pre-commit 25 | black 26 | isort 27 | flake8 28 | 29 | # --------- cores --------- # 30 | loguru>=0.6.0 31 | einops 32 | colorlog 33 | scikit-learn 34 | #numpy>=1.22.0 35 | 36 | # --------- visualization --------- # 37 | graphviz>=0.20.1 38 | 39 | # --------- for building nasbench-201 --------- # 40 | peewee 41 | 42 | # --------- others --------- # 43 | ipdb 44 | pyrootutils 45 | rich 46 | pytest 47 | sh 48 | seaborn 49 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num=$1 4 | remark=$2 5 | others=$3 6 | 7 | BS=256 8 | exp=example_darts_nas 9 | 10 | 11 | if [ 
$num -le 1 ]
12 | then
13 | {
14 | echo "python run.py experiment=$exp.yaml logger.wandb.name=g${num}_bs${BS}_${remark} trainer.gpus=${num} datamodule.batch_size=${BS} ${others}"
15 | python run.py \
16 |     experiment=$exp.yaml \
17 |     logger.wandb.name=g${num}_bs${BS}_${remark} \
18 |     trainer.gpus=${num} \
19 |     datamodule.batch_size=${BS} \
20 |     logger.wandb.offline=True \
21 |     $others
22 | }
23 | else
24 | {
25 | echo "python run.py experiment=$exp.yaml logger.wandb.name=ddp_g${num}_bs${BS}_${remark} trainer.gpus=${num} trainer.accelerator=ddp datamodule.batch_size=${BS} ${others}"
26 | python run.py \
27 |     experiment=$exp.yaml \
28 |     logger.wandb.name=ddp_g${num}_bs${BS}_${remark} \
29 |     trainer.gpus=${num} \
30 |     trainer.accelerator=ddp \
31 |     datamodule.batch_size=${BS} \
32 |     logger.wandb.offline=True \
33 |     $others
34 | }
35 | # mpirun -np ${num} python run.py experiment=$exp.yaml logger.wandb.name=ddp_g${num} trainer.gpus=1 trainer.accelerator=horovod
36 | fi
-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------
1 | [metadata]
2 | project_name = hyperbox
3 | author = marsggbo
4 | contact = marsggbo@foxmail.com
5 | license_file = LICENSE
6 | description_file = README.md
7 | project_template = https://github.com/marsggbo/hyperbox
8 | 
9 | 
10 | [isort]
11 | line_length = 99
12 | profile = black
13 | filter_files = True
14 | 
15 | 
16 | [flake8]
17 | max_line_length = 99
18 | show_source = True
19 | format = pylint
20 | ignore =
21 |     F401  # Module imported but unused
22 |     W504  # Line break occurred after a binary operator
23 |     F841  # Local variable name is assigned to but never used
24 | exclude =
25 |     .git
26 |     __pycache__
27 |     data/*
28 |     tests/*
29 |     notebooks/*
30 |     logs/*
31 | 
32 | 
33 | [tool:pytest]
34 | python_files = tests/*
35 | log_cli = True
36 | markers =
37 |     slow
38 | addopts =
39 |     --durations=0
40 |     --strict-markers
41 |     --doctest-modules
42 | filterwarnings =
43 |     ignore::DeprecationWarning
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | required_modules = []
4 | with open('requirements.txt') as f:
5 |     required = f.read().splitlines()
6 | for module in required:
7 |     if not module.startswith('#'):
8 |         required_modules.append(module)
9 | 
10 | setup(
11 |     name="hyperbox",
12 |     version="1.4.4",
13 |     description="Hyperbox: An easy-to-use NAS framework.",
14 |     author="marsggbo",
15 |     url="https://github.com/marsggbo/hyperbox",
16 |     # replace with your own github project link
17 |     install_requires=required_modules,
18 |     packages=find_packages(),
19 |     include_package_data=True,
20 | )
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/__init__.py -------------------------------------------------------------------------------- /tests/datamodules/test_transforms.py: --------------------------------------------------------------------------------
1 | from hyperbox.datamodules.transforms import get_transforms, TorchTransforms
2 | 
3 | if __name__ == '__main__':
4 |     kwargs = {
5 |         'input_size': [32, 32],
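        # config convention: each transform entry carries an 'enable' flag
        # alongside that transform's own keyword arguments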
6 | 'random_crop': {'enable': 1, 'padding': 4, 'size': 28}, 7 | 'random_horizontal_flip': {'enable': 1, 'p': 0.8} 8 | } 9 | T = TorchTransforms(**kwargs) 10 | pass -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/module_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from importlib.util import find_spec 3 | 4 | """ 5 | Adapted from: 6 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py 7 | """ 8 | 9 | 10 | def _module_available(module_path: str) -> bool: 11 | """Check if a path is available in your environment. 12 | 13 | >>> _module_available('os') 14 | True 15 | >>> _module_available('bla.bla') 16 | False 17 | 18 | """ 19 | try: 20 | return find_spec(module_path) is not None 21 | except AttributeError: 22 | # Python 3.6 23 | return False 24 | except ModuleNotFoundError: 25 | # Python 3.7+ 26 | return False 27 | 28 | 29 | _IS_WINDOWS = platform.system() == "Windows" 30 | _APEX_AVAILABLE = _module_available("apex.amp") 31 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available("deepspeed") 32 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") 33 | _RPC_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.rpc") 34 | -------------------------------------------------------------------------------- /tests/helpers/run_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import sh 5 | 6 | 7 | def run_command(command: List[str]): 8 | """Default method for executing shell commands with pytest.""" 9 | msg = None 10 | try: 11 | sh.python(command) 12 | except sh.ErrorReturnCode as e: 13 | msg = e.stderr.decode() 14 | if msg: 15 | pytest.fail(msg=msg) 16 | -------------------------------------------------------------------------------- /tests/helpers/runif.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | 4 | import pytest 5 | import torch 6 | from packaging.version import Version 7 | from pkg_resources import get_distribution 8 | 9 | """ 10 | Adapted from: 11 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 12 | """ 13 | 14 | from tests.helpers.module_available import ( 15 | _APEX_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _RPC_AVAILABLE, 20 | ) 21 | 22 | 23 | class RunIf: 24 | """ 25 | RunIf wrapper for conditional skipping of tests. 26 | Fully compatible with `@pytest.mark`. 
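    A test is skipped when any of the requested conditions is not satisfied.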
27 | 28 | Example: 29 | 30 | @RunIf(min_torch="1.8") 31 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 32 | def test_wrapper(arg1): 33 | assert arg1 > 0 34 | 35 | """ 36 | 37 | def __new__( 38 | self, 39 | min_gpus: int = 0, 40 | min_torch: Optional[str] = None, 41 | max_torch: Optional[str] = None, 42 | min_python: Optional[str] = None, 43 | amp_apex: bool = False, 44 | skip_windows: bool = False, 45 | rpc: bool = False, 46 | fairscale: bool = False, 47 | deepspeed: bool = False, 48 | **kwargs, 49 | ): 50 | """ 51 | Args: 52 | min_gpus: min number of gpus required to run test 53 | min_torch: minimum pytorch version to run test 54 | max_torch: maximum pytorch version to run test 55 | min_python: minimum python version required to run test 56 | amp_apex: NVIDIA Apex is installed 57 | skip_windows: skip test for Windows platform 58 | rpc: requires Remote Procedure Call (RPC) 59 | fairscale: if `fairscale` module is required to run the test 60 | deepspeed: if `deepspeed` module is required to run the test 61 | kwargs: native pytest.mark.skipif keyword arguments 62 | """ 63 | conditions = [] 64 | reasons = [] 65 | 66 | if min_gpus: 67 | conditions.append(torch.cuda.device_count() < min_gpus) 68 | reasons.append(f"GPUs>={min_gpus}") 69 | 70 | if min_torch: 71 | torch_version = get_distribution("torch").version 72 | conditions.append(Version(torch_version) < Version(min_torch)) 73 | reasons.append(f"torch>={min_torch}") 74 | 75 | if max_torch: 76 | torch_version = get_distribution("torch").version 77 | conditions.append(Version(torch_version) >= Version(max_torch)) 78 | reasons.append(f"torch<{max_torch}") 79 | 80 | if min_python: 81 | py_version = ( 82 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 83 | ) 84 | conditions.append(Version(py_version) < Version(min_python)) 85 | reasons.append(f"python>={min_python}") 86 | 87 | if amp_apex: 88 | conditions.append(not _APEX_AVAILABLE) 89 | reasons.append("NVIDIA Apex") 90 | 91 | if skip_windows: 92 | conditions.append(_IS_WINDOWS) 93 | reasons.append("does not run on Windows") 94 | 95 | if rpc: 96 | conditions.append(not _RPC_AVAILABLE) 97 | reasons.append("RPC") 98 | 99 | if fairscale: 100 | conditions.append(not _FAIRSCALE_AVAILABLE) 101 | reasons.append("Fairscale") 102 | 103 | if deepspeed: 104 | conditions.append(not _DEEPSPEED_AVAILABLE) 105 | reasons.append("Deepspeed") 106 | 107 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 108 | return pytest.mark.skipif( 109 | condition=any(conditions), 110 | reason=f"Requires: [{' + '.join(reasons)}]", 111 | **kwargs, 112 | ) 113 | -------------------------------------------------------------------------------- /tests/models/test_random_nas_model.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import OmegaConf 3 | 4 | if __name__ == '__main__': 5 | cfg = OmegaConf.load('../../configs/model/random_model.yaml') 6 | nas_net = hydra.utils.instantiate(cfg, _recursive_=False) 7 | -------------------------------------------------------------------------------- /tests/mutables/test_all_mutables.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from hyperbox.mutator import RandomMutator 5 | from hyperbox.mutables.spaces import InputSpace, OperationSpace, ValueSpace 6 | from hyperbox.mutables.ops import Conv2d 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self,): 11 | super().__init__() 12 | ops = [ 
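            # three candidate convs whose padding keeps the output shape
            # identical, so the OperationSpace can swap them freely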
13 | nn.Conv2d(3,4,kernel_size=3,stride=1,padding=1), 14 | nn.Conv2d(3,4,kernel_size=5,stride=1,padding=2), 15 | nn.Conv2d(3,4,kernel_size=7,stride=1,padding=3), 16 | ] 17 | self.candidate_op1 = OperationSpace(ops, key='candidate1') 18 | self.candidate_op2 = OperationSpace(ops, key='candidate2') 19 | 20 | v1 = ValueSpace([4,8,16]) 21 | v2 = ValueSpace([2]) 22 | v3 = ValueSpace([3,5,7]) 23 | self.fop1 = Conv2d(4, v1, kernel_size=v3,stride=v2,padding=1,auto_padding=True) 24 | self.fop2 = Conv2d(v1, 8, kernel_size=v3,stride=v2,padding=1,auto_padding=True) 25 | 26 | self.input_op = InputSpace(n_candidates=2, n_chosen=1, key='input1') 27 | 28 | def forward(self, x): 29 | out1 = self.candidate_op1(x) 30 | out2 = self.candidate_op2(x) 31 | 32 | out = self.input_op([out1, out2]) 33 | print(out.shape) 34 | out = self.fop1(out) 35 | out = self.fop2(out) 36 | return out 37 | 38 | if __name__ == '__main__': 39 | net = Net() 40 | random = RandomMutator(net) 41 | random.reset() 42 | x = torch.rand(2,3,64,64) 43 | y = net(x) 44 | print(y.shape) 45 | 46 | 47 | # test OperationSpace 48 | x = torch.rand(2,10) 49 | ops = [ 50 | nn.Linear(10,10,bias=False), 51 | nn.Linear(10,100,bias=False), 52 | nn.Linear(10,100,bias=False), 53 | nn.Identity() 54 | ] 55 | mixop = OperationSpace( 56 | ops, 57 | mask=[0,1,0,0] 58 | ) 59 | y = mixop(x) 60 | print(y.shape) 61 | 62 | mixop = OperationSpace( 63 | ops, return_mask=True 64 | ) 65 | m = RandomMutator(mixop) 66 | m.reset() 67 | y, mask = mixop(x) 68 | print(y.shape, mask) 69 | 70 | # test InputSpace 71 | input1 = torch.rand(1,3) 72 | input2 = torch.rand(1,2) 73 | input3 = torch.rand(2,1) 74 | inputs = [input1, input2, input3] 75 | ic1 = InputSpace(n_candidates=3, n_chosen=1, return_mask=True) 76 | m = RandomMutator(ic1) 77 | m.reset() 78 | out, mask = ic1(inputs) 79 | print(out.shape, mask) 80 | 81 | inputs = {'key1':input1, 'key2':input2, 'key3':input3} 82 | ic2 = InputSpace(choose_from=['key1', 'key2', 'key3'], n_chosen=1, return_mask=True) 83 | m = RandomMutator(ic2) 84 | m.reset() 85 | out, mask = ic2(inputs) 86 | print(out.shape, mask) 87 | 88 | vc1 = ValueSpace([8,16,24], index=1) 89 | vc2 = ValueSpace([1,2,3,4,5,6,7,8,9]) 90 | vc = nn.ModuleList([vc1,vc2]) 91 | m = RandomMutator(vc) 92 | m.reset() 93 | print(vc1, vc2) 94 | -------------------------------------------------------------------------------- /tests/mutables/test_duplicate_mutables.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from hyperbox.mutables.spaces import OperationSpace 5 | from hyperbox.networks.base_nas_network import BaseNASNetwork 6 | from hyperbox.mutator import RandomMutator 7 | 8 | 9 | class ToyDuplicateNet(BaseNASNetwork): 10 | def __init__(self): 11 | super().__init__() 12 | self.op1 = OperationSpace(candidates=[nn.ReLU(), nn.Sigmoid()], key='op') 13 | self.op2 = OperationSpace(candidates=[nn.ReLU(), nn.Sigmoid()], key='op') 14 | 15 | def forward(self, x): 16 | out1 = self.op1(x) 17 | out2 = self.op2(x) 18 | out = out1 - out2 19 | return out 20 | 21 | @property 22 | def arch(self): 23 | return f"op1:{self.op1.mask}-op2:{self.op2.mask}" 24 | 25 | 26 | if __name__ == '__main__': 27 | model = ToyDuplicateNet() 28 | mutator = RandomMutator(model) 29 | x = torch.rand(10) 30 | for i in range(3): 31 | mutator.reset() 32 | print(model.arch) 33 | y = model(x) 34 | print(y) 35 | print('='*20) -------------------------------------------------------------------------------- /tests/mutables/test_op_bn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from hyperbox.mutables import ops 3 | from hyperbox.mutables.spaces import ValueSpace 4 | from hyperbox.mutator import RandomMutator 5 | 6 | if __name__ == '__main__': 7 | x = torch.rand(4,3) 8 | print('Normal case') 9 | linear = ops.Linear(3,10) 10 | bn = ops.BatchNorm1d(10) 11 | y = linear(x) 12 | print(y.shape) 13 | y = bn(y) 14 | print(y.shape) 15 | 16 | print('Search case') 17 | vs1 = ValueSpace([1,2,3]) 18 | linear = ops.Linear(3, vs1, bias=False) 19 | bn = ops.BatchNorm1d(vs1) 20 | m = RandomMutator(linear) 21 | m.reset() 22 | print(m._cache) 23 | print(linear.weight.shape) 24 | y = linear(x) 25 | print(y.shape) 26 | y = bn(y) 27 | print(y.shape) 28 | -------------------------------------------------------------------------------- /tests/mutables/test_op_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from hyperbox.mutables import ops 3 | from hyperbox.mutables.spaces import ValueSpace 4 | from hyperbox.mutator import RandomMutator 5 | 6 | from hyperbox.utils.calc_model_size import flops_size_counter 7 | 8 | def test_groups(cin, cout, group_list): 9 | print(f"test_groups: cin={cin}, cout={cout}, group_list={group_list}\n") 10 | groups = ValueSpace(candidates=group_list) 11 | x = torch.rand(2, cin, 64, 64) 12 | conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups) 13 | m = RandomMutator(conv) 14 | for i in range(10): 15 | m.reset() 16 | print(f'\n*******step{i}********\n', conv) 17 | y = conv(x) 18 | 19 | def test_cin_groups(cin_list, cout, group_list): 20 | print(f"test_cin_groups: cin_list={cin_list}, cout={cout}, group_list={group_list}\n") 21 | cin = ValueSpace(candidates=cin_list) 22 | groups = ValueSpace(candidates=group_list) 23 | x = torch.rand(2, 3, 64, 64) 24 | conv = torch.nn.Sequential( 25 | ops.Conv2d(3, cin, 3, 1, 1), 26 | ops.Conv2d(cin, cout, 3, 1, 1, groups=cin) 27 | ) 28 | m = RandomMutator(conv) 29 | for i in range(10): 30 | m.reset() 31 | print(f'\n*******step{i}********\n', conv) 32 | y = conv(x) 33 | 34 | def test_cout_groups(cin, cout_list, group_list): 35 | print(f"test_cout_groups: cin={cin}, cout_list={cout_list}, group_list={group_list}\n") 36 | cout = ValueSpace(candidates=cout_list) 37 | groups = ValueSpace(candidates=group_list) 38 | x = torch.rand(2, cin, 64, 64) 39 | conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups) 40 | m = RandomMutator(conv) 41 | for i in range(10): 42 | m.reset() 43 | print(f'\n*******step{i}********\n', conv) 44 | y = conv(x) 45 | 46 | def test_cin_cout_groups(cin_list, cout_list, group_list): 47 | print(f"test_cin_cout_groups: cin_list={cin_list}, cout_list={cout_list}, group_list={group_list}\n") 48 | cin = ValueSpace(candidates=cin_list) 49 | cout = ValueSpace(candidates=cout_list) 50 | groups = ValueSpace(candidates=group_list) 51 | x = torch.rand(2, 3, 64, 64) 52 | conv = torch.nn.Sequential( 53 | ops.Conv2d(3, cin, 3, 1, 1), 54 | ops.Conv2d(cin, cout, 3, 1, 1, groups=cin) 55 | ) 56 | m = RandomMutator(conv) 57 | for i in range(10): 58 | m.reset() 59 | print(f'\n*******step{i}********\n', conv) 60 | y = conv(x) 61 | 62 | def test_conv(): 63 | print('testing conv flops and sizes') 64 | x = torch.rand(1,3,64,64) 65 | vs1 = ValueSpace([10,2]) 66 | vs2 = ValueSpace([3,5,7]) 67 | conv = ops.Conv2d(3, vs1, vs2,bias=False) 68 | bn = ops.BatchNorm2d(vs1) 69 | op = torch.nn.Sequential( 70 | # torch.nn.Conv2d(3,3,3,1), 71 | conv, bn 72 | ) 73 | m = RandomMutator(op) 74 | 
m.reset() 75 | # print(op) 76 | # print(m._cache) 77 | # print(conv.weight.shape) 78 | # print(conv(x).shape) 79 | r = flops_size_counter(op, (1,3,8,8), False, False) 80 | # print(r) 81 | op = torch.nn.Sequential( 82 | ops.Conv2d(3,8,3,1), 83 | ops.BatchNorm2d(8) 84 | ) 85 | r = flops_size_counter(op, (1,3,8,8), False, False) 86 | # print(conv(x).shape) 87 | 88 | if __name__ == '__main__': 89 | test_conv() 90 | test_groups(32, 64, group_list=[1, 2, 4, 8, 16, 32]) 91 | test_cin_groups([32, 64], 128, group_list=[32, 64]) 92 | test_cout_groups(8, [16, 32], group_list=[1, 2, 4, 8]) 93 | test_cin_cout_groups([8, 16], [32, 64], group_list=[1, 8, 16]) 94 | -------------------------------------------------------------------------------- /tests/mutables/test_op_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from hyperbox.mutables import ops 3 | from hyperbox.mutables.spaces import ValueSpace 4 | from hyperbox.mutator import RandomMutator 5 | 6 | from hyperbox.utils.calc_model_size import flops_size_counter 7 | 8 | if __name__ == '__main__': 9 | x = torch.rand(1,3) 10 | vs1 = ValueSpace([1,2,3]) 11 | linear = ops.Linear(3, vs1, bias=False) 12 | m = RandomMutator(linear) 13 | m.reset() 14 | print(m._cache) 15 | print(linear.weight.shape) 16 | print(linear(x).shape) 17 | r = flops_size_counter(linear, (2,3), False, True) 18 | 19 | linear = ops.Linear(3,10) 20 | print(linear(x).shape) 21 | r = flops_size_counter(linear, (2,3), False, True) -------------------------------------------------------------------------------- /tests/networks/test_darts_net.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from hyperbox.networks.darts import DartsNetwork, DartsCell 5 | from hyperbox.mutator.random_mutator import RandomMutator 6 | 7 | 8 | if __name__ == '__main__': 9 | net = DartsNetwork(3,16,10,4,mask='./hyperbox/networks/darts/darts_mask.json') 10 | pass 11 | 12 | net = DartsNetwork( 13 | in_channels=3, 14 | channels=16, 15 | n_classes=10, 16 | n_layers=3, 17 | factory_func=DartsCell, 18 | ).cuda() 19 | print(net.arch) 20 | m = RandomMutator(net) 21 | m.reset() 22 | print(net.arch) 23 | m.reset() 24 | print(net.arch) 25 | m.reset() 26 | print(net.arch) 27 | x = torch.rand(2,3,64,64).cuda() 28 | output = net(x) 29 | print(output.shape) 30 | pass -------------------------------------------------------------------------------- /tests/networks/test_enas_net.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from hyperbox.networks.enas import ENASMicroNetwork 5 | from hyperbox.mutator.random_mutator import RandomMutator 6 | 7 | if __name__ == '__main__': 8 | net = ENASMicroNetwork( 9 | num_layers=2, 10 | num_nodes=3, 11 | ).cuda() 12 | m = RandomMutator(net) 13 | m.reset() 14 | x = torch.rand(2,3,64,64).cuda() 15 | output = net(x) 16 | print(output.shape) 17 | pass -------------------------------------------------------------------------------- /tests/networks/test_mobile_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from hyperbox.networks.mobilenet.mobile_net import MobileNet 5 | from hyperbox.networks.mobilenet.mobile3d_net import Mobile3DNet 6 | 7 | if __name__ == '__main__': 8 | from hyperbox.mutator.random_mutator import RandomMutator 9 | net = MobileNet() 10 | m = RandomMutator(net) 11 | m.reset() 12 | x = 
torch.rand(2,3,64,64) 13 | output = net(x) 14 | print(output.shape) 15 | 16 | 17 | net = Mobile3DNet() 18 | m = RandomMutator(net) 19 | m.reset() 20 | x = torch.rand(2,3,64,64,64) 21 | output = net(x) 22 | print(output.shape) -------------------------------------------------------------------------------- /tests/networks/test_resnet.py: -------------------------------------------------------------------------------- 1 | 2 | from hyperbox.networks.resnet import resnet20 3 | 4 | if __name__ == '__main__': 5 | size = lambda net: sum([p.numel() for name, p in net.named_parameters() if 'value' not in name]) 6 | net = resnet20() 7 | params_num = size(net) 8 | mb = params_num * 4 / 1024**2 9 | print(f"#pamrams of {net.__class__.__name__}={size(net)} {mb}(MB)") 10 | -------------------------------------------------------------------------------- /tests/smoke/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/smoke/__init__.py -------------------------------------------------------------------------------- /tests/smoke/test_commands.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | from tests.helpers.runif import RunIf 5 | 6 | 7 | def test_fast_dev_run(): 8 | """Run 1 train, val, test batch.""" 9 | command = ["run.py", "trainer=default", "trainer.fast_dev_run=true"] 10 | run_command(command) 11 | 12 | 13 | def test_default_cpu(): 14 | """Test default configuration on CPU.""" 15 | command = ["run.py", "trainer.max_epochs=1", "trainer.gpus=0"] 16 | run_command(command) 17 | 18 | 19 | @RunIf(min_gpus=1) 20 | def test_default_gpu(): 21 | """Test default configuration on GPU.""" 22 | command = [ 23 | "run.py", 24 | "trainer.max_epochs=1", 25 | "trainer.gpus=1", 26 | "datamodule.pin_memory=True", 27 | ] 28 | run_command(command) 29 | 30 | 31 | @pytest.mark.slow 32 | def test_experiments(): 33 | """Train 1 epoch with all experiment configs.""" 34 | command = ["run.py", "-m", "experiment=glob(*)", "trainer.max_epochs=1"] 35 | run_command(command) 36 | 37 | 38 | def test_limit_batches(): 39 | """Train 1 epoch on 25% of data.""" 40 | command = [ 41 | "run.py", 42 | "trainer=default", 43 | "trainer.max_epochs=1", 44 | "trainer.limit_train_batches=0.25", 45 | "trainer.limit_val_batches=0.25", 46 | "trainer.limit_test_batches=0.25", 47 | ] 48 | run_command(command) 49 | 50 | 51 | def test_gradient_accumulation(): 52 | """Train 1 epoch with gradient accumulation.""" 53 | command = [ 54 | "run.py", 55 | "trainer=default", 56 | "trainer.max_epochs=1", 57 | "trainer.accumulate_grad_batches=10", 58 | ] 59 | run_command(command) 60 | 61 | 62 | def test_double_validation_loop(): 63 | """Train 1 epoch with validation loop twice per epoch.""" 64 | command = [ 65 | "run.py", 66 | "trainer=default", 67 | "trainer.max_epochs=1", 68 | "trainer.val_check_interval=0.5", 69 | ] 70 | run_command(command) 71 | 72 | 73 | def test_csv_logger(): 74 | """Train 5 epochs with 5 batches with CSVLogger.""" 75 | command = [ 76 | "run.py", 77 | "trainer=default", 78 | "trainer.max_epochs=5", 79 | "trainer.limit_train_batches=5", 80 | "logger=csv", 81 | ] 82 | run_command(command) 83 | 84 | 85 | def test_tensorboard_logger(): 86 | """Train 5 epochs with 5 batches with TensorboardLogger.""" 87 | command = [ 88 | "run.py", 89 | "trainer=default", 90 | "trainer.max_epochs=5", 91 | 
"trainer.limit_train_batches=5", 92 | "logger=tensorboard", 93 | ] 94 | run_command(command) 95 | 96 | 97 | def test_overfit_batches(): 98 | """Overfit to 10 batches over 10 epochs.""" 99 | command = [ 100 | "run.py", 101 | "trainer=default", 102 | "trainer.min_epochs=10", 103 | "trainer.max_epochs=10", 104 | "trainer.overfit_batches=10", 105 | ] 106 | run_command(command) 107 | -------------------------------------------------------------------------------- /tests/smoke/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | from tests.helpers.run_command import run_command 2 | from tests.helpers.runif import RunIf 3 | 4 | 5 | @RunIf(amp_apex=True) 6 | def test_apex_01(): 7 | """Test mixed-precision level 01.""" 8 | command = [ 9 | "run.py", 10 | "trainer=default", 11 | "trainer.max_epochs=1", 12 | "trainer.gpus=1", 13 | "trainer.amp_backend=apex", 14 | "trainer.amp_level=O1", 15 | "trainer.precision=16", 16 | ] 17 | run_command(command) 18 | 19 | 20 | @RunIf(amp_apex=True) 21 | def test_apex_02(): 22 | """Test mixed-precision level 02.""" 23 | command = [ 24 | "run.py", 25 | "trainer=default", 26 | "trainer.max_epochs=1", 27 | "trainer.gpus=1", 28 | "trainer.amp_backend=apex", 29 | "trainer.amp_level=O2", 30 | "trainer.precision=16", 31 | ] 32 | run_command(command) 33 | 34 | 35 | @RunIf(amp_apex=True) 36 | def test_apex_03(): 37 | """Test mixed-precision level 03.""" 38 | command = [ 39 | "run.py", 40 | "trainer=default", 41 | "trainer.max_epochs=1", 42 | "trainer.gpus=1", 43 | "trainer.amp_backend=apex", 44 | "trainer.amp_level=O3", 45 | "trainer.precision=16", 46 | ] 47 | run_command(command) 48 | -------------------------------------------------------------------------------- /tests/smoke/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | @pytest.mark.slow 12 | def test_default_sweep(): 13 | """Test default Hydra sweeper.""" 14 | command = [ 15 | "run.py", 16 | "-m", 17 | "datamodule.batch_size=64,128", 18 | "model.lr=0.01,0.02", 19 | "trainer=default", 20 | "trainer.fast_dev_run=true", 21 | ] 22 | run_command(command) 23 | 24 | 25 | @pytest.mark.slow 26 | def test_optuna_sweep(): 27 | """Test Optuna sweeper.""" 28 | command = [ 29 | "run.py", 30 | "-m", 31 | "hparams_search=mnist_optuna", 32 | "trainer=default", 33 | "trainer.fast_dev_run=true", 34 | ] 35 | run_command(command) 36 | 37 | 38 | @pytest.mark.skip(reason="TODO: Add Ax sweep config.") 39 | @pytest.mark.slow 40 | def test_ax_sweep(): 41 | """Test Ax sweeper.""" 42 | command = ["run.py", "-m", "hparams_search=mnist_ax", "trainer.fast_dev_run=true"] 43 | run_command(command) 44 | -------------------------------------------------------------------------------- /tests/smoke/test_wandb.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | # @pytest.mark.slow 12 | # def test_wandb_optuna_sweep(): 13 | # """Test wandb logging with Optuna sweep.""" 14 | # command = [ 15 | # "run.py", 16 | # "-m", 17 | # "hparams_search=mnist_optuna", 18 | # "trainer=default", 19 | # "trainer.max_epochs=10", 20 | # "trainer.limit_train_batches=20", 21 | # 
"logger=wandb", 22 | # "logger.wandb.project=template-tests", 23 | # "logger.wandb.group=Optuna_SimpleDenseNet_MNIST", 24 | # "hydra.sweeper.n_trials=5", 25 | # ] 26 | # run_command(command) 27 | 28 | 29 | # @pytest.mark.slow 30 | # def test_wandb_callbacks(): 31 | # """Test wandb callbacks.""" 32 | # command = [ 33 | # "run.py", 34 | # "trainer=default", 35 | # "trainer.max_epochs=3", 36 | # "logger=wandb", 37 | # "logger.wandb.project=template-tests", 38 | # "callbacks=wandb", 39 | # ] 40 | # run_command(command) 41 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsggbo/hyperbox/cf163b77c3b53b0df6e430d4a3cb6f60e36ae366/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_sth.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.runif import RunIf 4 | 5 | 6 | def test_something1(): 7 | """Some test description.""" 8 | assert True is True 9 | 10 | 11 | def test_something2(): 12 | """Some test description.""" 13 | assert 1 + 1 == 2 14 | 15 | 16 | @pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) 17 | def test_something3(arg1: float): 18 | """Some test description.""" 19 | assert arg1 > 0 20 | 21 | 22 | # use RunIf to skip execution of some tests, e.g. when not on windows or when no gpus are available 23 | @RunIf(skip_windows=True, min_gpus=1) 24 | def test_something4(): 25 | """Some test description.""" 26 | assert True is True 27 | -------------------------------------------------------------------------------- /tests/utils/test_calc_model_size.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from hyperbox.mutator import RandomMutator 5 | from hyperbox.mutables import ops, spaces 6 | from hyperbox.networks.mobilenet.mobile_net import MobileNet 7 | from hyperbox.utils.calc_model_size import flops_size_counter 8 | 9 | 10 | class HybridNet(torch.nn.Module): 11 | def __init__(self, mask=None): 12 | super().__init__() 13 | self.op_space1 = spaces.OperationSpace( 14 | candidates=[ 15 | torch.nn.Conv2d(3,20,3,padding=1), 16 | torch.nn.Conv2d(3,20,5,padding=2) 17 | ], 18 | mask=mask 19 | ) 20 | self.op_space2 = spaces.OperationSpace( 21 | candidates=[ 22 | torch.nn.Conv2d(3,20,7,padding=3), 23 | torch.nn.Conv2d(3,20,3,padding=1) 24 | ], 25 | mask=mask 26 | ) 27 | self.input_space = spaces.InputSpace(n_candidates=2, n_chosen=1, mask=mask) 28 | 29 | vs_channel = spaces.ValueSpace([24,16,18], mask=mask) 30 | vs_kernel = spaces.ValueSpace([3,5,7], mask=mask) 31 | vs_stride = spaces.ValueSpace([1,2], mask=mask) 32 | self.finegrained_conv = ops.Conv2d(20, vs_channel, vs_kernel, stride=vs_stride, auto_padding=True) 33 | self.finegrained_bn = ops.BatchNorm2d(vs_channel) 34 | self.finegrained_linear = ops.Linear(vs_channel, 1) 35 | 36 | def forward(self, x): 37 | bs = x.size(0) 38 | out1 = self.op_space1(x) 39 | out2 = self.op_space2(x) 40 | out = self.input_space([out1, out2]) 41 | out = self.finegrained_conv(out) 42 | out = self.finegrained_bn(out) 43 | out = torch.nn.AdaptiveAvgPool2d(1)(out) 44 | out = out.view(bs, -1) 45 | out = self.finegrained_linear(out) 46 | return out 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | net = MobileNet() 52 | mutator = RandomMutator(net) 53 | mutator.reset() 54 | 55 | x = 
torch.rand(1,3,128,128)
56 |     result = flops_size_counter(net, (x,))
57 |     flops, params = [result[k] for k in result]
58 |     print(f"{flops} MFLOPS, {params:.4f} MB")
59 | 
60 |     x = torch.rand(1,3,64,64)
61 |     result = flops_size_counter(net, (x,))
62 |     flops, params = [result[k] for k in result]
63 |     print(f"{flops} MFLOPS, {params:.4f} MB")
64 | 
65 |     mask = {
66 |         "OperationSpace1": [True, False],
67 |         "OperationSpace2": [False, True],
68 |         "InputSpace3": [True, False],
69 |         "ValueSpace6": [True, False],
70 |         "ValueSpace4": [False, True, False],
71 |         "ValueSpace5": [False, False, True]
72 |     }
73 |     net = HybridNet(mask)
74 |     x = torch.rand(1,3,64,64)
75 |     result = flops_size_counter(net, (x,))
76 |     flops, params = [result[k] for k in result]
77 |     print(f"{flops} MFLOPS, {params:.4f} MB")
78 | 
79 |     net = HybridNet()
80 |     mutator = RandomMutator(net)
81 |     mutator.reset()
82 |     x = torch.rand(1,3,32,32)
83 |     result = flops_size_counter(net, (x,))
84 |     flops, params = [result[k] for k in result]
85 |     print(f"{flops} MFLOPS, {params:.4f} MB")
86 | 
-------------------------------------------------------------------------------- /tests/utils/test_hparams.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from hyperbox.utils.utils import hparams_wrapper
4 | 
5 | @hparams_wrapper
6 | class MyClass1(torch.nn.Module):
7 |     def __init__(self, arg1):
8 |         super(MyClass1, self).__init__()
9 |         self.arg1 = arg1
10 |         self.linear = torch.nn.Linear(3, arg1)
11 | 
12 | class MyClass2(MyClass1):
13 |     def __init__(self, arg1, arg2, arg3=False):
14 |         super(MyClass2, self).__init__(arg1)
15 |         self.arg1 = arg1
16 |         self.arg2 = arg2
17 |         self.arg3 = arg3
18 |         self.conv = torch.nn.Conv2d(3, arg1, arg2, bias=arg3)
19 | 
20 |     def __setattr__(self, name, value):
21 |         if name != 'key':
22 |             super(MyClass2, self).__setattr__(name, value)
23 | 
24 | 
25 | if __name__ == '__main__':
26 |     print(MyClass1.__mro__)  # (<class '__main__.MyClass1'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
27 |     print(MyClass2.__mro__)  # (<class '__main__.MyClass2'>, <class '__main__.MyClass1'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
28 |     # Test 1
29 |     obj1 = MyClass1(1)
30 |     print(obj1.hparams)  # Output: {'arg1': 1}
31 |     print(obj1)  # Output:
32 |     # MyClass1(
33 |     #   (linear): Linear(in_features=3, out_features=1, bias=True)
34 |     # )
35 | 
36 |     # Test 2
37 |     obj2 = MyClass2(3,4)
38 |     print(obj2.hparams)  # Output: {'arg3': False, 'arg1': 3, 'arg2': 4}
39 |     print(obj2)  # Output:
40 |     # MyClass2(
41 |     #   (conv): Conv2d(3, 3, kernel_size=(4, 4), stride=(1, 1), bias=False)
42 |     # )
43 | 
44 |     # Test 3
45 |     from hyperbox.mutables.spaces import OperationSpace
46 |     obj3 = OperationSpace([torch.nn.Linear(2,3),torch.nn.Linear(3,2)])
47 |     print(obj3.hparams)  # Output: {'mask': None, 'index': None, 'reduction': 'sum', 'return_mask': False, 'key': None, 'candidates': [Linear(in_features=2, out_features=3, bias=True), Linear(in_features=3, out_features=2, bias=True)]}
48 |     print(obj3)  # Output:
49 |     # OperationSpace(
50 |     #   (candidates): ModuleList(
51 |     #     (0): Linear(in_features=2, out_features=3, bias=True)
52 |     #     (1): Linear(in_features=3, out_features=2, bias=True)
53 |     #   )
54 |     # )
55 | 
-------------------------------------------------------------------------------- /tests/utils/test_logger.py: --------------------------------------------------------------------------------
1 | from hyperbox.utils.logger import get_logger
2 | 
3 | if __name__ == '__main__':
4 |     logger = get_logger(__name__)
5 |     logger.info('This is info')
6 |     logger.debug('This is debug')
7 |     logger.warning('This is warning')
8 |     logger.error('This is error')
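    # A minimal usage sketch (not in the original file), based on get_logger in
    # hyperbox/utils/logger.py above: log2file=True adds a file sink writing to
    # ./<name>.log. Note that rank-zero calls all share the name 'exp' and
    # repeated calls return the cached logger, so the flag only takes effect on
    # the first call in a process (here it returns the cached 'exp' logger).
    file_logger = get_logger(__name__, log2file=True)
    file_logger.info('In a fresh process this would also be written to exp.log')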
--------------------------------------------------------------------------------