├── examples ├── .gitignore ├── 01-nn-training │ ├── .gitignore │ ├── model.py │ ├── dataset.py │ └── train.py ├── 03-hpcli │ ├── .gitignore │ ├── src.py │ └── Makefile ├── 00-basic │ ├── config_modified.yaml │ ├── lib.py │ └── main.py └── 02-brief │ └── main.py ├── hpargparse ├── pkginfo.py ├── __init__.py ├── config.py └── hputils.py ├── tests ├── test_files │ ├── basic │ │ ├── config.yaml │ │ ├── str.yaml │ │ ├── config.failure.yaml │ │ ├── str.py │ │ └── lib.py │ └── dict_and_list │ │ ├── config.yaml │ │ └── lib_dict_and_list.py └── test_hputils.py ├── docs ├── requirements.txt ├── gendoc.sh ├── index.rst ├── Makefile ├── make.bat └── conf.py ├── requirements.txt ├── requirements.dev.txt ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── .gitignore ├── .readthedocs.yml ├── .circleci └── config.yml ├── Makefile ├── bin └── hpcli ├── LICENSE ├── .gitlab-ci.yml ├── .git-commit-template.txt ├── setup.py └── README.md /examples/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | -------------------------------------------------------------------------------- /examples/01-nn-training/.gitignore: -------------------------------------------------------------------------------- 1 | model.pt 2 | -------------------------------------------------------------------------------- /examples/03-hpcli/.gitignore: -------------------------------------------------------------------------------- 1 | config.yaml 2 | -------------------------------------------------------------------------------- /hpargparse/pkginfo.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.14.0" 2 | -------------------------------------------------------------------------------- /tests/test_files/basic/config.yaml: -------------------------------------------------------------------------------- 1 | a: 2 2 | b: 3 3 | -------------------------------------------------------------------------------- 
/tests/test_files/basic/str.yaml: -------------------------------------------------------------------------------- 1 | str_from: 'yaml' 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme 3 | m2r 4 | -------------------------------------------------------------------------------- /examples/00-basic/config_modified.yaml: -------------------------------------------------------------------------------- 1 | a: 123 2 | b: 456 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich 2 | dill 3 | hpman>=0.0.6 4 | PyYAML 5 | -------------------------------------------------------------------------------- /tests/test_files/basic/config.failure.yaml: -------------------------------------------------------------------------------- 1 | a: 2 | key: 1 3 | b: 3 4 | -------------------------------------------------------------------------------- /hpargparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .hputils import bind 2 | from .pkginfo import * 3 | -------------------------------------------------------------------------------- /tests/test_files/dict_and_list/config.yaml: -------------------------------------------------------------------------------- 1 | a: 2 | key: 3 3 | b: [2, 3, 4] 4 | c: 45 5 | -------------------------------------------------------------------------------- /tests/test_files/basic/str.py: -------------------------------------------------------------------------------- 1 | from hpman.m import _ 2 | 3 | 4 | _("str_from", "source_code") 5 | -------------------------------------------------------------------------------- /examples/03-hpcli/src.py: 
-------------------------------------------------------------------------------- 1 | from hpman.m import _ 2 | 3 | _("num_channels", 128) 4 | _("num_layers", 50) 5 | -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | coverage 3 | pytest 4 | pytest-cov 5 | pre-commit 6 | sphinx 7 | sphinx-rtd-theme 8 | m2r 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/python/black 3 | rev: 22.10.0 4 | hooks: 5 | - id: black 6 | -------------------------------------------------------------------------------- /hpargparse/config.py: -------------------------------------------------------------------------------- 1 | HP_SERIAL_FORMAT_DEFAULT = "auto" 2 | HP_SERIAL_FORMAT_CHOICES = ["auto", "yaml", "pickle"] 3 | 4 | HP_ACTION_PREFIX_DEFAULT = "hp" 5 | -------------------------------------------------------------------------------- /tests/test_files/dict_and_list/lib_dict_and_list.py: -------------------------------------------------------------------------------- 1 | from hpman.m import _ 2 | 3 | a = _("a", {"key": 1}) 4 | b = _("b", [1, 2, 3]) 5 | c = _("c", 23) 6 | -------------------------------------------------------------------------------- /tests/test_files/basic/lib.py: -------------------------------------------------------------------------------- 1 | from hpman.m import _ 2 | 3 | 4 | def add(): 5 | return _("a", 1) + _("b", 2) 6 | 7 | 8 | def mult(): 9 | return _("a") * _("b") 10 | -------------------------------------------------------------------------------- /docs/gendoc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # use local version instead of installed version 3 | export 
PYTHONPATH="$(pwd)/..:$PYTHONPATH" 4 | 5 | set -x 6 | sphinx-apidoc -f -o api ../hpargparse 7 | make SPHINXOPTS="-n" html 8 | 9 | -------------------------------------------------------------------------------- /examples/03-hpcli/Makefile: -------------------------------------------------------------------------------- 1 | all: run 2 | 3 | run: 4 | hpcli src.py 5 | hpcli src.py --num-layers 101 6 | hpcli src.py --num-layers 101 --hp-save config.yaml 7 | hpcli src.py --num-layers 101 --hp-save config.yaml --hp-list detail 8 | hpcli src.py -h 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.12.0 - 2020-09-26 4 | ### Fixed 5 | - Strings of class `StringAsDefault` were set on hpman as-is, so they were dumped to YAML as `hpargparse.hputils.StringAsDefault` objects instead of plain strings. This has been fixed. 6 | 7 | ## v0.11.0 - 2020-09-22 8 | ### Added 9 | - Support subparsers 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pyc 3 | __pycache__/ 4 | 5 | # vim swap file 6 | *.swp 7 | 8 | # coverage report 9 | htmlcov/ 10 | test-results/ 11 | .coverage 12 | coverage.xml 13 | 14 | # Distribution / packaging 15 | build/ 16 | dist/ 17 | eggs/ 18 | .eggs/ 19 | *.egg-info/ 20 | 21 | # sphinx 22 | /docs/api 23 | /docs/_build 24 | -------------------------------------------------------------------------------- /examples/00-basic/lib.py: -------------------------------------------------------------------------------- 1 | from hpman.m import _ 2 | 3 | # for more usecases, please refer to hpman's document 4 | 5 | 6 | def add(): 7 | # define a hyperparameter on-the-fly with defaults 8 | return _("a", 1) + _("b", 2) 9 | 10 | 11 | def mult(): 12 | # reuse a pre-defined hyperparameters 13 | return
_("a") * _("b") 14 | -------------------------------------------------------------------------------- /examples/02-brief/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from hpman.m import _ 4 | import hpargparse 5 | 6 | import argparse 7 | 8 | 9 | def func(): 10 | weight_decay = _("weight_decay", 1e-5) 11 | print("weight decay is {}".format(weight_decay)) 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | _.parse_file(__file__) 17 | hpargparse.bind(parser, _) 18 | parser.parse_args() 19 | 20 | func() 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. hpargparse documentation master file, created by 2 | sphinx-quickstart on Sat Jul 27 14:03:30 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to hpargparse's documentation! 7 | ====================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | ..
mdinclude:: ../README.md 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Build documentation with MkDocs 13 | #mkdocs: 14 | # configuration: mkdocs.yml 15 | 16 | # Optionally build your docs in additional formats such as PDF and ePub 17 | formats: all 18 | 19 | # Optionally set the version of Python and requirements required to build your docs 20 | python: 21 | # version: 3.7 22 | install: 23 | - requirements: docs/requirements.txt 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | orbs: 3 | python: circleci/python@0.2.1 4 | jobs: 5 | build: 6 | docker: 7 | - image: circleci/python:3.6.7-jessie 8 | steps: 9 | - checkout 10 | - python/load-cache 11 | - run: pip install --user -r requirements.txt -r requirements.dev.txt 12 | - python/save-cache 13 | - run: make test 14 | - store_test_results: 15 | path: test-results 16 | - store_artifacts: 17 | path: test-results 18 | - run: bash <(curl -s https://codecov.io/bash) 19 | workflows: 20 | version: 2 21 | 22 | main: 23 | jobs: 24 | - build 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | 3 | test: 4 | mkdir -p test-results 5 | python3 -m pytest \ 6 | --cov=hpargparse \ 7 | --no-cov-on-fail \ 8 | --cov-report=html:test-results/htmlcov \ 9 | --cov-report term \ 10 | --doctest-modules \ 11 | --junitxml=test-results/junit.xml \ 12 | hpargparse tests 13 | python3 -m coverage xml -o test-results/coverage.xml 14 | 15 | style-check: 16 | black --diff --check . 17 | 18 | serve-coverage-report: 19 | cd test-results/htmlcov && python3 -m http.server 20 | 21 | wheel: 22 | python3 setup.py sdist bdist_wheel 23 | 24 | doc: 25 | cd docs && ./gendoc.sh 26 | 27 | install: 28 | # install prerequisites 29 | # TODO: 30 | # 1. install requirements 31 | # 2.
install pre-commit hook 32 | 33 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /bin/hpcli: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from hpman import HyperParameterManager 5 | import hpargparse 6 | import os 7 | import sys 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(add_help=False) 12 | parser.add_argument(dest="files_and_directories", nargs="+") 13 | parser.add_argument( 14 | "--placeholder", default="_", help="placeholder of hpman used in given files" 15 | ) 16 | 17 | args, remain_args = parser.parse_known_args() 18 | 19 | parser = argparse.ArgumentParser() 20 | hp_mgr = HyperParameterManager(args.placeholder) 21 | hp_mgr.parse_file(args.files_and_directories) 22 | hpargparse.bind(parser, hp_mgr) 23 | 24 | # switch --hp-list on by default 25 | 
for action in parser._actions: 26 | if action.dest == "hp_list": 27 | action.default = "yaml" 28 | 29 | args = parser.parse_args(remain_args) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /examples/00-basic/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import hpargparse 4 | from hpman.m import _ 5 | import os 6 | 7 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | 13 | # ... do whatever you want 14 | parser.add_argument(dest="predefined_arg") 15 | 16 | # analyze everything in this directory 17 | _.parse_file(BASE_DIR) # <-- IMPORTANT 18 | 19 | # bind will monkey_patch parser.parse_args to do its job 20 | hpargparse.bind(parser, _) # <-- IMPORTANT 21 | 22 | # parse args and set the values 23 | args = parser.parse_args() 24 | 25 | # ... 
do whatever you want next 26 | import lib 27 | 28 | print("a = {}".format(_.get_value("a"))) 29 | print("b = {}".format(_.get_value("b"))) 30 | print("lib.add() = {}".format(lib.add())) 31 | print("lib.mult() = {}".format(lib.mult())) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Megvii Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | style: 2 | script: 3 | - python3 -V 4 | - pip3 install -r requirements.dev.txt --user 5 | - make style-check 6 | 7 | unittest: 8 | script: 9 | - pip3 install -r requirements.dev.txt -r requirements.txt --user 10 | - make test 11 | artifacts: 12 | paths: 13 | - test-results/htmlcov 14 | 15 | doc: 16 | script: 17 | - pip3 install -r requirements.dev.txt -r requirements.txt --user 18 | - make doc 19 | artifacts: 20 | paths: 21 | - docs/_build/html 22 | only: 23 | - master 24 | 25 | pages: 26 | stage: deploy 27 | dependencies: 28 | - unittest 29 | - doc 30 | script: 31 | - mv docs/_build/html public 32 | - mv test-results/htmlcov public/cov 33 | artifacts: 34 | paths: 35 | - public 36 | only: 37 | - master 38 | 39 | release: 40 | stage: deploy 41 | dependencies: 42 | - unittest 43 | - doc 44 | script: 45 | - make wheel 46 | - devpi use "${DEVPI_URL}" 47 | - devpi login "${DEVPI_LOGIN}" --password="${DEVPI_PASSWORD}" && devpi upload dist/* 48 | only: 49 | - tags 50 | -------------------------------------------------------------------------------- /.git-commit-template.txt: -------------------------------------------------------------------------------- 1 | # : subject 2 | # |<---- Using a Maximum Of 50 Characters ---->| 3 | # subject describes "If applied, this commit will..." 4 | 5 | # Explain why this change is being made and how it addresses the issue 6 | # |<---- Try To Limit Each Line to a Maximum Of 72 Characters ---->| 7 | 8 | # Any relevant tickets, articles or other resources?
9 | # example: fix issue #23 10 | 11 | # --- COMMIT END --- 12 | # Type can be 13 | # feat (new feature) 14 | # fix (bug fix) 15 | # refactor (refactoring production code) 16 | # style (formatting, missing semi colons, etc; no code change) 17 | # docs (changes to documentation) 18 | # test (adding or refactoring tests; no production code change) 19 | # chore (updating grunt tasks etc; no production code change) 20 | # -------------------- 21 | # Remember to 22 | # Capitalize the subject line 23 | # Use the imperative mood in the subject line 24 | # Do not end the subject line with a period 25 | # Separate subject from body with a blank line 26 | # Use the body to explain what and why vs. how 27 | # Can use multiple lines with "-" for bullet points in body 28 | # -------------------- 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from pathlib import Path 3 | 4 | BASE_DIR = Path(__file__).parent 5 | 6 | with open("README.md", "r") as fh: 7 | long_description = fh.read() 8 | 9 | with open("requirements.txt") as f: 10 | requirements = [line.strip() for line in f] 11 | 12 | pkginfo = {} 13 | exec((BASE_DIR / "hpargparse" / "pkginfo.py").read_text(), None, pkginfo) 14 | 15 | setuptools.setup( 16 | name="hpargparse", 17 | version=pkginfo["__version__"], 18 | author="Xinyu Zhou", 19 | author_email="zxy@megvii.com", 20 | description="An argparse extension for hpman", 21 | long_description=long_description, 22 | long_description_content_type="text/markdown", 23 | url="https://github.com/megvii-research/hpargparse", 24 | packages=setuptools.find_packages(), 25 | install_requires=requirements, 26 | scripts=["bin/hpcli"], 27 | classifiers=[ 28 | "Programming Language :: Python :: 3", 29 | "License :: OSI Approved :: MIT License", 30 | "Operating System :: OS Independent", 31 | "Intended Audience :: Developers", 32 | 
"Intended Audience :: Science/Research", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 34 | "Topic :: Software Development :: Libraries :: Python Modules", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /examples/01-nn-training/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from hpman.m import _ 5 | 6 | 7 | class EnsureFloat(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, x): 12 | return x.float() 13 | 14 | 15 | class GlobalAveragePooling(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, x): 20 | assert len(x.shape) == 4, x.shape 21 | h, w = x.shape[2], x.shape[3] 22 | v = F.avg_pool2d(x, (h, w)) 23 | return v.reshape(v.shape[:2]) 24 | 25 | 26 | def ConvBNReLU(in_channels, out_channels, *args, **kwargs): 27 | return nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels, *args, **kwargs), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True), 31 | ) 32 | 33 | 34 | def get_model(): 35 | base_channel = _("base_channel", 32) # <-- hyperparameter 36 | in_channels = 1 # _('input_channels', 1) 37 | 38 | return nn.Sequential( 39 | EnsureFloat(), 40 | ConvBNReLU(in_channels, base_channel, 3, stride=2, padding=1), 41 | ConvBNReLU(base_channel, base_channel * 2, 3, stride=2, padding=1), 42 | ConvBNReLU(base_channel * 2, base_channel * 4, 3, stride=2, padding=1), 43 | GlobalAveragePooling(), 44 | nn.Linear(base_channel * 4, 10), 45 | ) 46 | 47 | 48 | def compute_loss(y_pred, y_gt): 49 | return F.cross_entropy(y_pred, y_gt) 50 | 51 | 52 | def compute_metrics(y_pred, y_gt): 53 | return { 54 | "misclassify": (y_pred.argmax(dim=1) != y_gt) 55 | .detach() 56 | .float() 57 | .mean() 58 | .cpu() 59 | .numpy() 60 | } 61 | 
-------------------------------------------------------------------------------- /examples/01-nn-training/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import MNIST 2 | import torch 3 | import functools 4 | import numpy as np 5 | 6 | 7 | @functools.lru_cache(maxsize=2) 8 | def get_mnist_dataset(train): 9 | return MNIST("../data", train=train, download=True) 10 | 11 | 12 | def get_data_and_labels(dataset_type): 13 | """ 14 | :param dataset_type: "train" or "test"; any value other than "train" selects the test split 15 | :return: a dict of 16 | { 17 | 'data': X, 18 | 'labels': Y, 19 | } 20 | """ 21 | mnist = get_mnist_dataset(dataset_type == "train") # False selects the test split 22 | 23 | X = mnist.data 24 | Y = mnist.targets 25 | 26 | X = X.reshape(-1, 1, 28, 28) 27 | return {"data": X, "labels": Y} 28 | 29 | 30 | def iter_dataset_batch(rng, dct, batch_size, loop=False, cuda=False): 31 | """ 32 | :param rng: random number generator, a np.random.RandomState object 33 | :param dct: data dict.
values must be torch tensors 34 | :param batch_size: batch size 35 | :param loop: whether to loop forever 36 | :param cuda: whether to use cuda tensor 37 | """ 38 | 39 | # sanity check 40 | assert len(dct) > 0, len(dct) 41 | 42 | n = None 43 | for k, v in dct.items(): 44 | assert isinstance(v, torch.Tensor), type(v) 45 | if n is None: 46 | n = v.shape[0] 47 | else: 48 | assert v.shape[0] == n, (n, v.shape) 49 | 50 | idx = np.arange(n) 51 | 52 | def run(): 53 | rng.shuffle(idx) 54 | 55 | for i in range(0, n, batch_size): 56 | j = min(n, i + batch_size) 57 | mb = {k: v[idx[i:j]] for k, v in dct.items()} 58 | if cuda: 59 | mb = {k: v.cuda() for k, v in mb.items()} 60 | yield mb 61 | 62 | if loop: 63 | while True: 64 | yield from run() 65 | else: 66 | yield from run() 67 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "hpargparse" 21 | copyright = "2019, EMTF" 22 | author = "EMTF" 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = "0.0.1" 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | master_doc = "index" 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.todo", 37 | "sphinx.ext.coverage", 38 | "sphinx.ext.mathjax", 39 | "sphinx.ext.viewcode", 40 | "m2r", 41 | ] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ["_templates"] 45 | 46 | # List of patterns, relative to source directory, that match files and 47 | # directories to ignore when looking for source files. 48 | # This pattern also affects html_static_path and html_extra_path. 49 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 50 | 51 | pygments_style = "sphinx" 52 | todo_include_todos = True 53 | autoclass_content = "both" 54 | 55 | # -- Options for HTML output ------------------------------------------------- 56 | 57 | # The theme to use for HTML and HTML Help pages. See the documentation for 58 | # a list of builtin themes. 59 | # 60 | html_theme = "sphinx_rtd_theme" 61 | 62 | # Add any paths that contain custom static files (such as style sheets) here, 63 | # relative to this directory. They are copied after the builtin static files, 64 | # so a file named "default.css" will overwrite the builtin "default.css". 
65 | html_static_path = ["_static"] 66 | -------------------------------------------------------------------------------- /examples/01-nn-training/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from tqdm import tqdm 3 | import functools 4 | import numpy as np 5 | import argparse 6 | import torch 7 | import yaml 8 | import os 9 | from torch import optim 10 | 11 | from hpman.m import _ 12 | import hpargparse 13 | 14 | 15 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | _.parse_file(BASE_DIR) 21 | hpargparse.bind(parser, _) 22 | parser.parse_args() # we need not to use args 23 | 24 | # print all hyperparameters 25 | print("-" * 10 + " Hyperparameters " + "-" * 10) 26 | print(yaml.dump(_.get_values())) 27 | 28 | optimizer_cls = { 29 | "adam": optim.Adam, 30 | "sgd": functools.partial(optim.SGD, momentum=0.9), 31 | }[ 32 | _("optimizer", "adam") # <-- hyperparameter 33 | ] 34 | 35 | import model 36 | 37 | net = model.get_model() 38 | if torch.cuda.is_available(): 39 | net.cuda() 40 | 41 | optimizer = optimizer_cls( 42 | net.parameters(), 43 | lr=_("learning_rate", 1e-3), # <-- hyperparameter 44 | weight_decay=_("weight_decay", 1e-5), # <-- hyperparameter 45 | ) 46 | 47 | import dataset 48 | 49 | train_ds = dataset.get_data_and_labels("train") 50 | test_ds = dataset.get_data_and_labels("test") 51 | if torch.cuda.is_available(): 52 | # since mnist is a small dataset, we store the test dataset all in the 53 | # gpu memory 54 | test_ds = {k: v.cuda() for k, v in test_ds.items()} 55 | 56 | rng = np.random.RandomState(_("seed", 42)) # <-- hyperparameter 57 | 58 | for epoch in range(_("num_epochs", 30)): # <-- hyperparameter 59 | net.train() 60 | tq = tqdm( 61 | enumerate( 62 | dataset.iter_dataset_batch( 63 | rng, 64 | train_ds, 65 | _("batch_size", 256), # <-- hyperparameter 66 | cuda=torch.cuda.is_available(), 67 | ) 68 | 
) 69 | ) 70 | for step, minibatch in tq: 71 | optimizer.zero_grad() 72 | 73 | Y_pred = net(minibatch["data"]) 74 | loss = model.compute_loss(Y_pred, minibatch["labels"]) 75 | 76 | loss.backward() 77 | optimizer.step() 78 | 79 | metrics = model.compute_metrics(Y_pred, minibatch["labels"]) 80 | metrics["loss"] = loss.detach().cpu().numpy() 81 | tq.desc = "e:{} s:{} {}".format( 82 | epoch, 83 | step, 84 | " ".join(["{}:{}".format(k, v) for k, v in sorted(metrics.items())]), 85 | ) 86 | 87 | net.eval() 88 | 89 | # since mnist is a small dataset, we predict all values at once. 90 | Y_pred = net(test_ds["data"]) 91 | metrics = model.compute_metrics(Y_pred, test_ds["labels"]) 92 | print( 93 | "eval: {}".format( 94 | " ".join(["{}:{}".format(k, v) for k, v in sorted(metrics.items())]) 95 | ) 96 | ) 97 | 98 | # Save the model. We intentionally not saving the model here for 99 | # tidiness of the example 100 | # torch.save(net, "model.pt") 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /tests/test_hputils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import hpargparse 3 | import argparse 4 | import pickle 5 | import hpman 6 | from pathlib import Path 7 | import contextlib 8 | 9 | import os 10 | import shutil 11 | import tempfile 12 | 13 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 14 | test_file_dir = Path(BASE_DIR) / "test_files" 15 | 16 | 17 | @contextlib.contextmanager 18 | def auto_cleanup_temp_dir(): 19 | """ 20 | :return: a `class`:`pathlib.Path` object 21 | """ 22 | try: 23 | tmpdir = Path(tempfile.mkdtemp(prefix="hpargparse-test")) 24 | yield tmpdir 25 | finally: 26 | shutil.rmtree(str(tmpdir)) 27 | 28 | 29 | class TestAll(unittest.TestCase): 30 | def _make(self, fpath): 31 | fpath = str(fpath) 32 | hp_mgr = hpman.HyperParameterManager("_") 33 | parser = argparse.ArgumentParser() 34 | 
parser.add_argument(dest="predefined_arg") 35 | hp_mgr.parse_file(fpath) 36 | hpargparse.bind(parser, hp_mgr) 37 | return parser, hp_mgr 38 | 39 | def _make_basic(self): 40 | return self._make(test_file_dir / "basic" / "lib.py") 41 | 42 | def _make_str(self): 43 | return self._make(test_file_dir / "basic" / "str.py") 44 | 45 | def _make_dict_and_list(self): 46 | return self._make(test_file_dir / "dict_and_list" / "lib_dict_and_list.py") 47 | 48 | def test_basic(self): 49 | parser, hp_mgr = self._make_basic() 50 | parser.parse_args(["an_arg_value"]) 51 | 52 | def test_set_value(self): 53 | parser, hp_mgr = self._make_basic() 54 | args = parser.parse_args(["an_arg_value", "--a", "3"]) 55 | self.assertEqual(args.a, 3) 56 | self.assertEqual(hp_mgr.get_value("a"), 3) 57 | 58 | def test_set_value_failure(self): 59 | parser, hp_mgr = self._make_basic() 60 | self.assertRaises(SystemExit, parser.parse_args, ["an_arg_value", "--a", "abc"]) 61 | 62 | def test_hp_list(self): 63 | parser, hp_mgr = self._make_basic() 64 | self.assertRaises(SystemExit, parser.parse_args, ["an_arg_value", "--hp-list"]) 65 | self.assertRaises( 66 | SystemExit, parser.parse_args, ["an_arg_value", "--hp-list", "detail"] 67 | ) 68 | self.assertRaises( 69 | SystemExit, parser.parse_args, ["an_arg_value", "--hp-list", "json"] 70 | ) 71 | 72 | def test_hp_detail(self): 73 | parser, hp_mgr = self._make_basic() 74 | self.assertRaises( 75 | SystemExit, parser.parse_args, ["an_arg_value", "--hp-detail"] 76 | ) 77 | 78 | def test_hp_save(self): 79 | for name in ["config.yaml", "config.yml", "config.pkl", "config.pickle"]: 80 | with auto_cleanup_temp_dir() as d: 81 | parser, hp_mgr = self._make_basic() 82 | path = str(d / name) 83 | parser.parse_args(["an_arg_value", "--hp-save", path]) 84 | 85 | def test_hp_save_failure(self): 86 | with auto_cleanup_temp_dir() as d: 87 | parser, hp_mgr = self._make_basic() 88 | path = str(d / "a.bcd") 89 | self.assertRaisesRegex( 90 | ValueError, 91 | "Unsupported file 
extension: .bcd", 92 | parser.parse_args, 93 | ["an_arg_value", "--hp-save", path], 94 | ) 95 | 96 | def test_hp_load(self): 97 | # test load yaml 98 | parser, hp_mgr = self._make_basic() 99 | parser.parse_args( 100 | ["an_arg_value", "--hp-load", str(test_file_dir / "basic" / "config.yaml")] 101 | ) 102 | 103 | # test load pickle 104 | with auto_cleanup_temp_dir() as d: 105 | parser, hp_mgr = self._make_basic() 106 | path = str(d / "a.pkl") 107 | with open(path, "wb") as f: 108 | pickle.dump({"a": 10, "b": 20}, f) 109 | parser.parse_args(["an_arg_value", "--hp-load", str(path)]) 110 | 111 | def test_hp_load_failure(self): 112 | parser, hp_mgr = self._make_basic() 113 | self.assertRaisesRegex( 114 | TypeError, 115 | """\\('Error parsing hyperparameter `a`', "int\\(\\) argument must be a string, a bytes-like object or a real number, not 'dict'"\\)""", 116 | parser.parse_args, 117 | [ 118 | "an_arg_value", 119 | "--hp-load", 120 | str(test_file_dir / "basic" / "config.failure.yaml"), 121 | ], 122 | ) 123 | 124 | def test_hp_load_str(self): 125 | parser, hp_mgr = self._make_str() 126 | self.assertEqual(hp_mgr.get_value("str_from"), "source_code") 127 | parser.parse_args( 128 | ["an_arg_value", "--hp-load", str(test_file_dir / "basic" / "str.yaml")] 129 | ) 130 | self.assertEqual(hp_mgr.get_value("str_from"), "yaml") 131 | 132 | def test_hp_load_str_with_cmd_set(self): 133 | parser, hp_mgr = self._make_str() 134 | parser.parse_args( 135 | [ 136 | "an_arg_value", 137 | "--hp-load", 138 | str(test_file_dir / "basic" / "str.yaml"), 139 | "--str-from", 140 | "cmd", 141 | ] 142 | ) 143 | self.assertEqual(hp_mgr.get_value("str_from"), "cmd") 144 | 145 | def test_set_dict_and_list(self): 146 | parser, hp_mgr = self._make_dict_and_list() 147 | args = parser.parse_args(["an_arg_value", "--a", '{"key": 3}']) 148 | self.assertDictEqual(args.a, {"key": 3}) 149 | self.assertDictEqual(hp_mgr.get_value("a"), {"key": 3}) 150 | 151 | def test_set_dict_and_list_failure(self): 152 | 
parser, hp_mgr = self._make_dict_and_list() 153 | self.assertRaises(SystemExit, parser.parse_args, ["an_arg_value", "--a", "1"]) 154 | 155 | def test_hp_load_dict_and_list(self): 156 | parser, hp_mgr = self._make_dict_and_list() 157 | parser.parse_args( 158 | [ 159 | "an_arg_value", 160 | "--hp-load", 161 | str(test_file_dir / "dict_and_list" / "config.yaml"), 162 | ] 163 | ) 164 | self.assertDictEqual(hp_mgr.get_value("a"), {"key": 3}) 165 | self.assertListEqual(hp_mgr.get_value("b"), [2, 3, 4]) 166 | 167 | def test_hp_load_dict_and_list_with_cmd_set(self): 168 | parser, hp_mgr = self._make_dict_and_list() 169 | parser.parse_args( 170 | [ 171 | "an_arg_value", 172 | "--hp-load", 173 | str(test_file_dir / "dict_and_list" / "config.yaml"), 174 | "--a", 175 | '{"a": 1}', 176 | "--b", 177 | "[1, 2, 3]", 178 | "--c", 179 | "42", 180 | ] 181 | ) 182 | self.assertDictEqual(hp_mgr.get_value("a"), {"a": 1}) 183 | self.assertListEqual(hp_mgr.get_value("b"), [1, 2, 3]) 184 | self.assertEqual(hp_mgr.get_value("c"), 42) 185 | 186 | def _make_pair(self): 187 | hp_mgr = hpman.HyperParameterManager("_") 188 | parser = argparse.ArgumentParser( 189 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 190 | ) 191 | hp_mgr.parse_source('_("a", 1)') 192 | return hp_mgr, parser 193 | 194 | def test_bind_actions(self): 195 | def do_test(assert_type, regex, **bind_kwargs): 196 | hp_mgr, parser = self._make_pair() 197 | hpargparse.bind(parser, hp_mgr, **bind_kwargs) 198 | 199 | assertion = {True: self.assertRegex, False: self.assertNotRegex}[ 200 | assert_type 201 | ] 202 | assertion(parser.format_help(), regex) 203 | 204 | all_keys = ["save", "load", "list", "detail", "serial-format", "exit"] 205 | 206 | def test_exist(keywords, **bind_kwargs): 207 | 208 | not_exists = all_keys.copy() 209 | for k in keywords: 210 | assert k in all_keys, (k, all_keys) 211 | not_exists.remove(k) 212 | 213 | exist_regex = "|".join("-{}".format(k) for k in keywords) 214 | not_exist_regex = 
"|".join("-{}".format(k) for k in not_exists) 215 | 216 | assert (set(keywords) | set(not_exists)) == set(all_keys) 217 | 218 | if exist_regex: 219 | do_test(True, exist_regex, **bind_kwargs) 220 | if not_exist_regex: 221 | do_test(False, not_exist_regex, **bind_kwargs) 222 | 223 | test_exist(all_keys) 224 | test_exist([], inject_actions=False) 225 | test_exist(["save", "serial-format", "exit"], inject_actions=["save"]) 226 | test_exist(["load", "serial-format", "exit"], inject_actions=["load"]) 227 | test_exist( 228 | ["save", "load", "serial-format", "exit"], inject_actions=["save", "load"] 229 | ) 230 | test_exist(["list", "exit"], inject_actions=["list"]) 231 | test_exist( 232 | ["save", "serial-format", "list", "exit"], inject_actions=["save", "list"] 233 | ) 234 | test_exist( 235 | ["load", "serial-format", "list", "exit"], inject_actions=["load", "list"] 236 | ) 237 | test_exist( 238 | ["save", "load", "serial-format", "list", "exit"], 239 | inject_actions=["save", "load", "list"], 240 | ) 241 | 242 | def test_bind_action_prefix(self): 243 | hp_mgr, parser = self._make_pair() 244 | hpargparse.bind(parser, hp_mgr, action_prefix="hahaha") 245 | self.assertRegex(parser.format_help(), r"--hahaha-") 246 | 247 | def test_bind_serial_format(self): 248 | hp_mgr, parser = self._make_pair() 249 | hpargparse.bind(parser, hp_mgr, serial_format="pickle") 250 | 251 | with auto_cleanup_temp_dir() as d: 252 | path = d / "a.pkl" 253 | parser.parse_args(["--hp-save", str(path)]) 254 | 255 | with open(str(path), "rb") as f: 256 | x = pickle.load(f) 257 | 258 | self.assertEqual(x["a"], 1) 259 | 260 | def test_show_default_value_in_help_message(self): 261 | hp_mgr, parser = self._make_pair() 262 | hp_mgr.parse_source("_('b', True)") 263 | hp_mgr.parse_source("_('c', 'deadbeef')") 264 | hpargparse.bind(parser, hp_mgr) 265 | 266 | h = parser.format_help() 267 | 268 | self.assertRegex(h, "A int hyper-parameter named `a`\\. 
\\(default: 1\\)") 269 | self.assertRegex(h, "--b {True,False}") 270 | self.assertRegex(h, "A bool hyper-parameter named `b`\\. \\(default: True\\)") 271 | self.assertRegex( 272 | h, "A str hyper-parameter named `c`\\. \\(default: deadbeef\\)" 273 | ) 274 | 275 | def test_show_use_defined_arguments(self): 276 | hp_mgr, parser = self._make_pair() 277 | # help message 278 | hp_mgr.parse_source('_("b", True, help="a help message")') 279 | # choices 280 | hp_mgr.parse_source('_("c", "deadbeef", choices=["deadbeef", "beefdead"])') 281 | # required 282 | hp_mgr.parse_source('_("d", "somevalue", required=True)') 283 | 284 | hpargparse.bind(parser, hp_mgr) 285 | 286 | h = parser.format_help() 287 | 288 | self.assertRegex(h, "a help message") 289 | self.assertRegex(h, "--c {deadbeef,beefdead}") 290 | self.assertNotRegex(h, "\\[--d D\\]") 291 | 292 | def test_subparser(self): 293 | hp_mgr = hpman.HyperParameterManager("_") 294 | hp_mgr.parse_source('_("a", 1)') 295 | 296 | parser = argparse.ArgumentParser( 297 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 298 | ) 299 | subparsers = parser.add_subparsers() 300 | p = subparsers.add_parser("sub") 301 | hpargparse.bind(p, hp_mgr) 302 | parser.parse_args(["sub", "--a", str(2)]) 303 | 304 | self.assertEqual(2, hp_mgr.get_values()["a"]) 305 | 306 | def test_string_as_default_should_not_be_saved(self): 307 | hp_mgr = hpman.HyperParameterManager("_") 308 | hp_mgr.parse_source('_("a", "hello")') 309 | 310 | parser = argparse.ArgumentParser( 311 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 312 | ) 313 | hpargparse.bind(parser, hp_mgr) 314 | args = parser.parse_args(["--a", "world"]) 315 | self.assertNotIsInstance( 316 | hp_mgr.get_value("a"), hpargparse.hputils.StringAsDefault 317 | ) 318 | 319 | def test_bool_argument(self): 320 | hp_mgr, parser = self._make_pair() 321 | hp_mgr.parse_source("_('b', True)") 322 | hpargparse.bind(parser, hp_mgr) 323 | 324 | args = parser.parse_args(["--b", "False"]) 325 | 
self.assertEqual(args.b, False) 326 | 327 | # test the other way around 328 | hp_mgr, parser = self._make_pair() 329 | hp_mgr.parse_source("_('b', False)") 330 | hpargparse.bind(parser, hp_mgr) 331 | 332 | args = parser.parse_args(["--b", "True"]) 333 | self.assertEqual(args.b, True) 334 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hpargparse 2 | [![CircleCI](https://img.shields.io/circleci/build/github/megvii-research/hpargparse/master)](https://app.circleci.com/pipelines/github/megvii-research/hpargparse?branch=master) 3 | [![codecov](https://img.shields.io/codecov/c/github/megvii-research/hpargparse)](https://codecov.io/gh/megvii-research/hpargparse) 4 | 5 | An [argparse](https://docs.python.org/3/library/argparse.html) extension for [hpman](https://github.com/megvii-research/hpman) 6 | 7 | # Installation 8 | ```bash 9 | python3 -m pip install hpargparse 10 | ``` 11 | 12 | # Brief Introduction 13 | 14 | The following example lies in [examples/02-brief](./examples/02-brief). 
15 | 16 | `main.py` 17 | ```python 18 | #!/usr/bin/env python3 19 | 20 | from hpman.m import _ 21 | import hpargparse 22 | 23 | import argparse 24 | 25 | 26 | def func(): 27 | weight_decay = _("weight_decay", 1e-5) 28 | print("weight decay is {}".format(weight_decay)) 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | _.parse_file(__file__) 34 | hpargparse.bind(parser, _) 35 | parser.parse_args() 36 | 37 | func() 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | ``` 43 | 44 | results in: 45 | ```bash 46 | $ ./main.py 47 | weight decay is 1e-05 48 | $ ./main.py --weight-decay 1e-4 49 | weight decay is 0.0001 50 | $ ./main.py --weight-decay 1e-4 --hp-list 51 | weight_decay: 0.0001 52 | $ ./main.py --weight-decay 1e-4 --hp-list detail 53 | All hyperparameters: 54 | ['weight_decay'] 55 | Details 56 | ╔══════════════╦═══════╦════════╦═════════════════════════════════════════════════════════════════╗ 57 | ║ name ║ type ║ value ║ details ║ 58 | ╠══════════════╬═══════╬════════╬═════════════════════════════════════════════════════════════════╣ 59 | ║ weight_decay ║ float ║ 0.0001 ║ occurrence[0]: ║ 60 | ║ ║ ║ ║ ./main.py:10 ║ 61 | ║ ║ ║ ║ 5: ║ 62 | ║ ║ ║ ║ 6: import argparse ║ 63 | ║ ║ ║ ║ 7: ║ 64 | ║ ║ ║ ║ 8: ║ 65 | ║ ║ ║ ║ 9: def func(): ║ 66 | ║ ║ ║ ║ ==> 10: weight_decay = _("weight_decay", 1e-5) ║ 67 | ║ ║ ║ ║ 11: print("weight decay is {}".format(weight_decay)) ║ 68 | ║ ║ ║ ║ 12: ║ 69 | ║ ║ ║ ║ 13: ║ 70 | ║ ║ ║ ║ 14: def main(): ║ 71 | ║ ║ ║ ║ 15: parser = argparse.ArgumentParser() ║ 72 | ╚══════════════╩═══════╩════════╩═════════════════════════════════════════════════════════════════╝ 73 | $ ./main.py -h 74 | usage: main.py [-h] [--weight-decay WEIGHT_DECAY] [--hp-save HP_SAVE] 75 | [--hp-load HP_LOAD] [--hp-list [{detail,yaml,json}]] 76 | [--hp-detail] [--hp-serial-format {auto,yaml,pickle}] 77 | [--hp-exit] 78 | 79 | optional arguments: 80 | -h, --help show this help message and exit 81 | --weight-decay WEIGHT_DECAY 82 | (default: 
1e-05) 83 | --hp-save HP_SAVE Save hyperparameters to a file. The hyperparameters 84 | are saved after processing of all other options 85 | (default: None) 86 | --hp-load HP_LOAD Load hyperparameters from a file. The hyperparameters 87 | are loaded before any other options are processed 88 | (default: None) 89 | --hp-list [{detail,yaml,json}] 90 | List all available hyperparameters. If `--hp-list 91 | detail` is specified, a verbose table will be print 92 | (default: None) 93 | --hp-detail Shorthand for --hp-list detail (default: False) 94 | --hp-serial-format {auto,yaml,pickle} 95 | Format of the saved config file. Defaults to auto. It 96 | can be set to override auto file type deduction. 97 | (default: auto) 98 | --hp-exit process all hpargparse actions and quit (default: 99 | False) 100 | ``` 101 | 102 | # hpcli: The Command Line Tool 103 | Besides using `hpargparse.bind` in your code, we also provide a command line 104 | tool `hpcli` that offers the same functionality for any existing file that uses hpman.
105 | 106 | `src.py` 107 | ```python 108 | from hpman.m import _ 109 | 110 | _('num_channels', 128) 111 | _('num_layers', 50) 112 | ``` 113 | 114 | In shell: 115 | ```bash 116 | $ hpcli src.py 117 | num_channels: 128 118 | num_layers: 50 119 | $ hpcli src.py --num-layers 101 120 | num_channels: 128 121 | num_layers: 101 122 | $ hpcli src.py --num-layers 101 --hp-save config.yaml 123 | num_channels: 128 124 | num_layers: 101 125 | $ hpcli src.py --num-layers 101 --hp-save config.yaml --hp-list detail 126 | All hyperparameters: 127 | ['num_channels', 'num_layers'] 128 | Details 129 | ╔══════════════╦══════╦═══════╦═══════════════════════════════════════╗ 130 | ║ name ║ type ║ value ║ details ║ 131 | ╠══════════════╬══════╬═══════╬═══════════════════════════════════════╣ 132 | ║ num_channels ║ int ║ 128 ║ occurrence[0]: ║ 133 | ║ ║ ║ ║ src.py:3 ║ 134 | ║ ║ ║ ║ 1: from hpman.m import _ ║ 135 | ║ ║ ║ ║ 2: ║ 136 | ║ ║ ║ ║ ==> 3: _("num_channels", 128) ║ 137 | ║ ║ ║ ║ 4: _("num_layers", 50) ║ 138 | ║ ║ ║ ║ 5: ║ 139 | ╠══════════════╬══════╬═══════╬═══════════════════════════════════════╣ 140 | ║ num_layers ║ int ║ 101 ║ occurrence[0]: ║ 141 | ║ ║ ║ ║ src.py:4 ║ 142 | ║ ║ ║ ║ 1: from hpman.m import _ ║ 143 | ║ ║ ║ ║ 2: ║ 144 | ║ ║ ║ ║ 3: _("num_channels", 128) ║ 145 | ║ ║ ║ ║ ==> 4: _("num_layers", 50) ║ 146 | ║ ║ ║ ║ 5: ║ 147 | ╚══════════════╩══════╩═══════╩═══════════════════════════════════════╝ 148 | $ hpcli src.py -h 149 | usage: hpcli [-h] [--num-channels NUM_CHANNELS] [--num-layers NUM_LAYERS] 150 | [--hp-save HP_SAVE] [--hp-load HP_LOAD] 151 | [--hp-list [{detail,yaml,json}]] [--hp-detail] 152 | [--hp-serial-format {auto,yaml,pickle}] [--hp-exit] 153 | 154 | optional arguments: 155 | -h, --help show this help message and exit 156 | --num-channels NUM_CHANNELS 157 | (default: 128) 158 | --num-layers NUM_LAYERS 159 | (default: 50) 160 | --hp-save HP_SAVE Save hyperparameters to a file. 
The hyperparameters 161 | are saved after processing of all other options 162 | (default: None) 163 | --hp-load HP_LOAD Load hyperparameters from a file. The hyperparameters 164 | are loaded before any other options are processed 165 | (default: None) 166 | --hp-list [{detail,yaml,json}] 167 | List all available hyperparameters. If `--hp-list 168 | detail` is specified, a verbose table will be print 169 | (default: yaml) 170 | --hp-detail Shorthand for --hp-list detail (default: False) 171 | --hp-serial-format {auto,yaml,pickle} 172 | Format of the saved config file. Defaults to auto. It 173 | can be set to override auto file type deduction. 174 | (default: auto) 175 | --hp-exit process all hpargparse actions and quit (default: 176 | False) 177 | ``` 178 | 179 | This could be a handy tool to inspect the hyperparameters in your code. 180 | 181 | # Example: Deep Learning Experiment 182 | This example lies in [examples/01-nn-training](./examples/01-nn-training). 183 | 184 | It is a fully functional example of training a LeNet on MNIST using 185 | `hpargparse` and `hpman` collaboratively to manage hyperparameters. 186 | 187 | We **highly recommend** playing around with this example. 188 | 189 | 190 | # Example: Basics Walkthrough 191 | Now we break down the functions one-by-one. 192 | 193 | The following example lies in [examples/00-basic](./examples/00-basic). 194 | 195 | `lib.py`: 196 | ```python 197 | from hpman.m import _ 198 | 199 | 200 | def add(): 201 | return _("a", 1) + _("b", 2) 202 | 203 | 204 | def mult(): 205 | return _("a") * _("b") 206 | ``` 207 | 208 | `main.py` 209 | ```python 210 | #!/usr/bin/env python3 211 | import argparse 212 | import hpargparse 213 | from hpman.m import _ 214 | import os 215 | 216 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 217 | 218 | 219 | def main(): 220 | parser = argparse.ArgumentParser() 221 | 222 | # ...
do whatever you want 223 | parser.add_argument(dest="predefined_arg") 224 | 225 | # analyze everything in this directory 226 | _.parse_file(BASE_DIR) # <-- IMPORTANT 227 | 228 | # bind will monkey_patch parser.parse_args to do its job 229 | hpargparse.bind(parser, _) # <-- IMPORTANT 230 | 231 | # parse args and set the values 232 | args = parser.parse_args() 233 | 234 | # ... do whatever you want next 235 | import lib 236 | 237 | print("a = {}".format(_.get_value("a"))) 238 | print("b = {}".format(_.get_value("b"))) 239 | print("lib.add() = {}".format(lib.add())) 240 | print("lib.mult() = {}".format(lib.mult())) 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | ``` 246 | 247 | ## Help 248 | ```bash 249 | $ ./main.py -h 250 | usage: main.py [-h] [--a A] [--b B] [--hp-save HP_SAVE] [--hp-load HP_LOAD] 251 | [--hp-list [{detail,yaml,json}]] [--hp-detail] 252 | [--hp-serial-format {auto,yaml,pickle}] [--hp-exit] 253 | predefined_arg 254 | 255 | positional arguments: 256 | predefined_arg 257 | 258 | optional arguments: 259 | -h, --help show this help message and exit 260 | --a A (default: 1) 261 | --b B (default: 2) 262 | --hp-save HP_SAVE Save hyperparameters to a file. The hyperparameters 263 | are saved after processing of all other options 264 | (default: None) 265 | --hp-load HP_LOAD Load hyperparameters from a file. The hyperparameters 266 | are loaded before any other options are processed 267 | (default: None) 268 | --hp-list [{detail,yaml,json}] 269 | List all available hyperparameters. If `--hp-list 270 | detail` is specified, a verbose table will be print 271 | (default: None) 272 | --hp-detail Shorthand for --hp-list detail (default: False) 273 | --hp-serial-format {auto,yaml,pickle} 274 | Format of the saved config file. Defaults to auto. It 275 | can be set to override auto file type deduction. 
276 | (default: auto) 277 | --hp-exit process all hpargparse actions and quit (default: 278 | False) 279 | ``` 280 | 281 | 282 | ## Set Hyperparameters from Command Line Arguments 283 | ```bash 284 | $ ./main.py some_thing --a 3 --b 5 285 | a = 3 286 | b = 5 287 | lib.add() = 8 288 | lib.mult() = 15 289 | ``` 290 | 291 | 292 | ## List All Hyperparameters 293 | ```bash 294 | $ ./main.py some_arg --hp-list 295 | a: 1 296 | b: 2 297 | ``` 298 | ... and details: 299 | ```bash 300 | $ ./main.py some_arg --hp-list detail 301 | All hyperparameters: 302 | ['a', 'b'] 303 | Details 304 | ╔══════╦══════╦═══════╦═══════════════════════════════════════════════════════════════════════════╗ 305 | ║ name ║ type ║ value ║ details ║ 306 | ╠══════╬══════╬═══════╬═══════════════════════════════════════════════════════════════════════════╣ 307 | ║ a ║ int ║ 1 ║ occurrence[0]: ║ 308 | ║ ║ ║ ║ /data/project/hpargparse/examples/00-basic/lib.py:8 ║ 309 | ║ ║ ║ ║ 3: # for more usecases, please refer to hpman's document ║ 310 | ║ ║ ║ ║ 4: ║ 311 | ║ ║ ║ ║ 5: ║ 312 | ║ ║ ║ ║ 6: def add(): ║ 313 | ║ ║ ║ ║ 7: # define a hyperparameter on-the-fly with defaults ║ 314 | ║ ║ ║ ║ ==> 8: return _("a", 1) + _("b", 2) ║ 315 | ║ ║ ║ ║ 9: ║ 316 | ║ ║ ║ ║ 10: ║ 317 | ║ ║ ║ ║ 11: def mult(): ║ 318 | ║ ║ ║ ║ 12: # reuse a pre-defined hyperparameters ║ 319 | ║ ║ ║ ║ 13: return _("a") * _("b") ║ 320 | ║ ║ ║ ║ occurrence[1]: ║ 321 | ║ ║ ║ ║ /data/project/hpargparse/examples/00-basic/lib.py:13 ║ 322 | ║ ║ ║ ║ 8: return _("a", 1) + _("b", 2) ║ 323 | ║ ║ ║ ║ 9: ║ 324 | ║ ║ ║ ║ 10: ║ 325 | ║ ║ ║ ║ 11: def mult(): ║ 326 | ║ ║ ║ ║ 12: # reuse a pre-defined hyperparameters ║ 327 | ║ ║ ║ ║ ==> 13: return _("a") * _("b") ║ 328 | ║ ║ ║ ║ 14: ║ 329 | ╠══════╬══════╬═══════╬═══════════════════════════════════════════════════════════════════════════╣ 330 | ║ b ║ int ║ 2 ║ occurrence[0]: ║ 331 | ║ ║ ║ ║ /data/project/hpargparse/examples/00-basic/lib.py:8 ║ 332 | ║ ║ ║ ║ 3: # for more usecases, please refer to hpman's 
document ║ 333 | ║ ║ ║ ║ 4: ║ 334 | ║ ║ ║ ║ 5: ║ 335 | ║ ║ ║ ║ 6: def add(): ║ 336 | ║ ║ ║ ║ 7: # define a hyperparameter on-the-fly with defaults ║ 337 | ║ ║ ║ ║ ==> 8: return _("a", 1) + _("b", 2) ║ 338 | ║ ║ ║ ║ 9: ║ 339 | ║ ║ ║ ║ 10: ║ 340 | ║ ║ ║ ║ 11: def mult(): ║ 341 | ║ ║ ║ ║ 12: # reuse a pre-defined hyperparameters ║ 342 | ║ ║ ║ ║ 13: return _("a") * _("b") ║ 343 | ║ ║ ║ ║ occurrence[1]: ║ 344 | ║ ║ ║ ║ /data/project/hpargparse/examples/00-basic/lib.py:13 ║ 345 | ║ ║ ║ ║ 8: return _("a", 1) + _("b", 2) ║ 346 | ║ ║ ║ ║ 9: ║ 347 | ║ ║ ║ ║ 10: ║ 348 | ║ ║ ║ ║ 11: def mult(): ║ 349 | ║ ║ ║ ║ 12: # reuse a pre-defined hyperparameters ║ 350 | ║ ║ ║ ║ ==> 13: return _("a") * _("b") ║ 351 | ║ ║ ║ ║ 14: ║ 352 | ╚══════╩══════╩═══════╩═══════════════════════════════════════════════════════════════════════════╝ 353 | ``` 354 | 355 | ## Save/Load from/to YAML file 356 | ```bash 357 | # save to yaml file 358 | $ ./main.py some_arg --hp-save /tmp/config.yaml --hp-exit 359 | 360 | $ cat /tmp/config.yaml 361 | a: 1 362 | b: 2 363 | 364 | # load from yaml file 365 | $ cat config_modified.yaml 366 | a: 123 367 | b: 456 368 | 369 | $ ./main.py some_arg --hp-load config_modified.yaml --hp-list 370 | a: 123 371 | b: 456 372 | ``` 373 | 374 | # Development 375 | 1. Install requirements: 376 | ```bash 377 | pip install -r requirements.dev.txt -r requirements.txt 378 | ``` 379 | 380 | 2. Activate git commit template 381 | ```bash 382 | git config commit.template .git-commit-template.txt 383 | ``` 384 | 385 | 3. 
Install pre-commit hook 386 | ```bash 387 | pre-commit install 388 | ``` 389 | 390 | -------------------------------------------------------------------------------- /hpargparse/hputils.py: -------------------------------------------------------------------------------- 1 | # logistics 2 | import subprocess 3 | import sys 4 | import ast 5 | 6 | import argparse 7 | import functools 8 | import dill 9 | import yaml 10 | import json 11 | import collections 12 | import os 13 | from copy import deepcopy 14 | 15 | from types import MethodType 16 | 17 | import hpman 18 | from hpman import ( 19 | HyperParameterManager, 20 | HyperParameterOccurrence, 21 | SourceHelper, 22 | EmptyValue, 23 | ) 24 | 25 | from . import config 26 | 27 | from typing import Union, List 28 | 29 | 30 | from rich.console import Console 31 | from rich.table import Column, Table 32 | from rich.syntax import Syntax 33 | from rich.style import Style 34 | from rich import box 35 | 36 | 37 | class StringAsDefault(str): 38 | """If a string is used as parser.add_argument(default=string), 39 | this string should be mark as StringAsDefault 40 | """ 41 | 42 | pass 43 | 44 | 45 | def list_of_dict2tab(list_of_dict, headers): 46 | """Convert "a list of dict" to "a list of list" that suitable 47 | for table processing libraries (such as tabulate) 48 | 49 | :params list_of_dict: input data 50 | :params headers: a list of str, header items with order 51 | 52 | :return: list of list of objects 53 | """ 54 | rows = [[dct[h] for h in headers] for dct in list_of_dict] 55 | return rows 56 | 57 | 58 | def make_detail_str(details): 59 | """ 60 | :param details: List of details. 
A detail is either a string 61 | or a list of strings, each one comprises a line 62 | 63 | :return: a string, the formatted detail 64 | """ 65 | strs = [] 66 | for d in details: 67 | if isinstance(d["detail"], str): 68 | ds = [d["detail"]] 69 | else: 70 | assert isinstance(d["detail"], (list, tuple)), d["detail"] 71 | ds = d["detail"] 72 | 73 | s = "\n".join(["{}:".format(d["name"])] + [" " + line for line in ds]) 74 | strs.append(s) 75 | strs.append("\n") 76 | 77 | return "".join(strs) 78 | 79 | 80 | def make_value_illu(v): 81 | """Mute non-literal-evaluable values 82 | 83 | :return: None if v is :class:`hpman.NotLiteralEvaluable`, 84 | otherwise the original input. 85 | """ 86 | if isinstance(v, hpman.NotLiteralEvaluable): 87 | return None 88 | return v 89 | 90 | 91 | def hp_list(mgr): 92 | """Print hyperparameter settings to stdout""" 93 | 94 | syntax = Syntax( 95 | "All hyperparameters:\n" + " {}".format(sorted(mgr.get_values().keys())), 96 | "python", 97 | theme="monokai", 98 | ) 99 | console = Console() 100 | console.print(syntax) 101 | 102 | # construct a table 103 | console = Console() 104 | table = Table( 105 | title="Details", 106 | title_style=Style(color="bright_cyan", bgcolor="grey15", bold=True), 107 | style=Style(bgcolor="grey15"), 108 | header_style="bold magenta", 109 | box=box.DOUBLE, 110 | border_style="bright_cyan", 111 | show_lines=True, 112 | ) 113 | table.add_column("name", style="green_yellow") 114 | table.add_column("type", style="light_steel_blue") 115 | table.add_column("value", style="light_cyan1") 116 | table.add_column("details") 117 | 118 | """ 119 | for k, d in sorted(mgr.db.group_by("name").items()): 120 | details = [] 121 | for i, oc in enumerate( 122 | d.select(L.exist_attr("filename")).sorted(L.order_by("filename")) 123 | ): 124 | # make context detail 125 | details.append( 126 | { 127 | "name": "occurrence[{}]".format(i), 128 | "detail": SourceHelper.format_given_filepath_and_lineno( 129 | oc.filename, oc.lineno 130 | ), 131 | } 
132 | ) 133 | 134 | # combine details 135 | detail_str = make_detail_str(details) 136 | oc = d.sorted(L.value_priority)[0] 137 | detail_syntax = Syntax(detail_str, "python", theme="monokai") 138 | table.add_row( 139 | k, 140 | str(type(oc.value).__name__), 141 | str(make_value_illu(oc.value)), 142 | detail_syntax, 143 | ) 144 | """ 145 | 146 | for node in sorted(mgr.tree.flatten(), key=lambda x: x.name): 147 | details = [] 148 | for i, oc in enumerate(sorted(node.db, key=lambda x: x.filename)): 149 | # make context detail 150 | details.append( 151 | { 152 | "name": "occurrence[{}]".format(i), 153 | "detail": SourceHelper.format_given_filepath_and_lineno( 154 | oc.filename, oc.lineno 155 | ), 156 | } 157 | ) 158 | 159 | # combine details 160 | detail_str = make_detail_str(details) 161 | detail_syntax = Syntax(detail_str, "python", theme="monokai") 162 | table.add_row( 163 | node.name, 164 | str(type(node.value).__name__), 165 | str(make_value_illu(node.value)), 166 | detail_syntax, 167 | ) 168 | 169 | console.print(table) 170 | 171 | 172 | def parse_action_list(inject_actions: Union[bool, List[str]]) -> List[str]: 173 | """Parse inputs to inject actions. 
174 | 175 | :param inject_actions: see :func:`.bind` for detail 176 | :return: a list of action names 177 | """ 178 | if isinstance(inject_actions, bool): 179 | inject_actions = {True: ["save", "load", "list", "detail"], False: []}[ 180 | inject_actions 181 | ] 182 | return inject_actions 183 | 184 | 185 | def _get_argument_type_by_value(value): 186 | typ = type(value) 187 | if isinstance(value, (list, dict)): 188 | 189 | def type_func(s): 190 | if isinstance(s, typ): 191 | eval_val = s 192 | else: 193 | assert isinstance(s, str), type(s) 194 | eval_val = ast.literal_eval(s) 195 | 196 | if not isinstance(eval_val, typ): 197 | raise TypeError("value `{}` is not of type {}".format(eval_val, typ)) 198 | return eval_val 199 | 200 | return type_func 201 | return typ 202 | 203 | 204 | def str2bool(v): 205 | """Parsing a string into a bool type. 206 | 207 | :param v: A string that needs to be parsed. 208 | 209 | :return: True or False 210 | """ 211 | if v.lower() in ["yes", "true", "t", "y", "1"]: 212 | return True 213 | elif v.lower() in ["no", "false", "f", "n", "0"]: 214 | return False 215 | else: 216 | raise argparse.ArgumentTypeError("Unsupported value encountered.") 217 | 218 | 219 | def inject_args( 220 | parser: argparse.ArgumentParser, 221 | hp_mgr: hpman.HyperParameterManager, 222 | *, 223 | inject_actions: List[str], 224 | action_prefix: str, 225 | serial_format: str, 226 | show_defaults: bool, 227 | ) -> argparse.ArgumentParser: 228 | """Inject hpman parsed hyperparameter settings into argparse arguments. 229 | Only a limited set of format are supported. See code for details. 230 | 231 | :param parser: Use given parser object of `class`:`argparse.ArgumentParser`. 232 | :param hp_mgr: A `class`:`hpman.HyperParameterManager` object. 
233 | 234 | :param inject_actions: A list of actions names to inject 235 | :param action_prefix: Prefix for hpargparse related options 236 | :param serial_format: One of 'yaml' and 'pickle' 237 | :param show_defaults: Show default values 238 | 239 | :return: The injected parser. 240 | """ 241 | 242 | if show_defaults: 243 | parser.formatter_class = argparse.ArgumentDefaultsHelpFormatter 244 | 245 | # Default value will be shown when using argparse.ArgumentDefaultsHelpFormatter 246 | # only if a help message is present. This is the behavior of argparse. 247 | 248 | value_names_been_set = set() 249 | 250 | def _make_value_names_been_set_injection(name, func): 251 | @functools.wraps(func) 252 | def wrapper(string): 253 | # when isinstance(default, string), 254 | # the `parser.parse_args()` will run type(default) automaticly. 255 | # value_names_been_set should ignore these names. 256 | if not isinstance(string, StringAsDefault): 257 | value_names_been_set.add(name) 258 | return func(string) 259 | 260 | return wrapper 261 | 262 | def get_node_attr(node, attr_name): 263 | for oc in node.db: 264 | if hasattr(oc, attr_name): 265 | return getattr(oc, attr_name) 266 | 267 | if attr_name in oc.hints: 268 | return oc.hints[attr_name] 269 | 270 | return None 271 | 272 | # add options for collected hyper-parameters 273 | for node in hp_mgr.get_nodes(): 274 | k = get_node_attr(node, "name") 275 | v = node.get().value 276 | 277 | # this is just a simple hack 278 | option_name = "--{}".format(k.replace("_", "-")) 279 | 280 | value_type = _get_argument_type_by_value(v) 281 | type_str = value_type.__name__ 282 | help = ( 283 | get_node_attr(node, "help") or f"A {type_str} hyper-parameter named `{k}`." 
284 | ) 285 | choices = get_node_attr(node, "choices") 286 | required = get_node_attr(node, "required") 287 | other_kwargs = { 288 | "choices": choices, 289 | "required": required, 290 | "help": help, 291 | } 292 | 293 | if value_type == bool: 294 | # argparse does not directly support bool types. 295 | other_kwargs.update(choices=[True, False]) 296 | parser.add_argument( 297 | option_name, 298 | type=_make_value_names_been_set_injection(k, str2bool), 299 | default=v, 300 | **other_kwargs, 301 | ) 302 | else: 303 | if isinstance(v, str): 304 | # if isinstance(v, str), mark as StringAsDefault 305 | v = StringAsDefault(v) 306 | 307 | parser.add_argument( 308 | option_name, 309 | type=_make_value_names_been_set_injection( 310 | k, _get_argument_type_by_value(v) 311 | ), 312 | default=v, 313 | **other_kwargs, 314 | ) 315 | 316 | make_option = lambda name: "--{}-{}".format(action_prefix, name) 317 | 318 | for action in inject_actions: 319 | if action == "list": 320 | parser.add_argument( 321 | make_option("list"), 322 | action="store", 323 | default=None, 324 | const="yaml", 325 | nargs="?", 326 | choices=["detail", "yaml", "json"], 327 | help=( 328 | "List all available hyperparameters. If `{} detail` is" 329 | " specified, a verbose table will be print" 330 | ).format(make_option("list")), 331 | ) 332 | elif action == "detail": 333 | parser.add_argument( 334 | make_option("detail"), 335 | action="store_true", 336 | help="Shorthand for --hp-list detail", 337 | ) 338 | elif action == "save": 339 | parser.add_argument( 340 | make_option("save"), 341 | help=( 342 | "Save hyperparameters to a file. The hyperparameters" 343 | " are saved after processing of all other options" 344 | ), 345 | ) 346 | 347 | elif action == "load": 348 | parser.add_argument( 349 | make_option("load"), 350 | help=( 351 | "Load hyperparameters from a file. 
The hyperparameters" 352 | " are loaded before any other options are processed" 353 | ), 354 | ) 355 | 356 | if "load" in inject_actions or "save" in inject_actions: 357 | parser.add_argument( 358 | make_option("serial-format"), 359 | default=serial_format, 360 | choices=config.HP_SERIAL_FORMAT_CHOICES, 361 | help=( 362 | "Format of the saved config file. Defaults to {}." 363 | " It can be set to override auto file type deduction." 364 | ).format(serial_format), 365 | ) 366 | 367 | if inject_actions: 368 | parser.add_argument( 369 | make_option("exit"), 370 | action="store_true", 371 | help="process all hpargparse actions and quit", 372 | ) 373 | 374 | def __hpargparse_value_names_been_set(self): 375 | return value_names_been_set 376 | 377 | parser.__hpargparse_value_names_been_set = MethodType( 378 | __hpargparse_value_names_been_set, parser 379 | ) 380 | 381 | return parser 382 | 383 | 384 | def _infer_file_format(path): 385 | name, ext = os.path.splitext(path) 386 | supported_exts = { 387 | ".yaml": "yaml", 388 | ".yml": "yaml", 389 | ".pickle": "pickle", 390 | ".pkl": "pickle", 391 | } 392 | 393 | if ext in supported_exts: 394 | return supported_exts[ext] 395 | raise ValueError( 396 | "Unsupported file extension: {} of path {}".format(ext, path), 397 | "Supported file extensions: {}".format( 398 | ", ".join("`{}`".format(i) for i in sorted(supported_exts)) 399 | ), 400 | ) 401 | 402 | 403 | def hp_save(path: str, hp_mgr: hpman.HyperParameterManager, serial_format: str): 404 | """Save(serialize) hyperparamters. 405 | 406 | :param path: Where to save 407 | :param hp_mgr: The HyperParameterManager to be saved. 408 | :param serial_format: The saving format. 409 | 410 | :see: :func:`.bind` for more detail. 
411 | """ 412 | values = hp_mgr.get_values() 413 | 414 | if serial_format == "auto": 415 | serial_format = _infer_file_format(path) 416 | 417 | if serial_format == "yaml": 418 | with open(path, "w") as f: 419 | yaml.dump(values, f) 420 | else: 421 | assert serial_format == "pickle", serial_format 422 | with open(path, "wb") as f: 423 | dill.dump(values, f) 424 | 425 | 426 | def hp_load(path, hp_mgr, serial_format): 427 | """Load(deserialize) hyperparamters. 428 | 429 | :param path: Where to load 430 | :param hp_mgr: The HyperParameterManager to be set. 431 | :param serial_format: The saving format. 432 | 433 | :see: :func:`.bind` for more detail. 434 | """ 435 | if serial_format == "auto": 436 | serial_format = _infer_file_format(path) 437 | 438 | if serial_format == "yaml": 439 | with open(path, "r") as f: 440 | values = yaml.safe_load(f) 441 | else: 442 | assert serial_format == "pickle", serial_format 443 | with open(path, "rb") as f: 444 | values = dill.load(f) 445 | 446 | old_values = hp_mgr.get_values() 447 | new_values = {} 448 | for k, v in values.items(): 449 | if k in old_values: 450 | old_v = old_values[k] 451 | try: 452 | new_values[k] = _get_argument_type_by_value(old_v)(v) 453 | except TypeError as e: 454 | e.args = ("Error parsing hyperparameter `{}`".format(k),) + e.args 455 | raise 456 | 457 | hp_mgr.set_values(new_values) 458 | 459 | 460 | def bind( 461 | parser: argparse.ArgumentParser, 462 | hp_mgr: hpman.HyperParameterManager, 463 | *, 464 | inject_actions: Union[bool, List[str]] = True, 465 | action_prefix: str = config.HP_ACTION_PREFIX_DEFAULT, 466 | serial_format: str = config.HP_SERIAL_FORMAT_DEFAULT, 467 | show_defaults: bool = True, 468 | ): 469 | """Bridging the gap between argparse and hpman. This is 470 | the most important method. Once bounded, hpargparse 471 | will do the rest for you. 472 | 473 | :param parser: A `class`:`argparse.ArgumentParser` object 474 | :param hp_mgr: The hyperparameter manager from `hpman`. 
It is 475 | usually an 'underscore' variable obtained by `from hpman.m import _` 476 | :param inject_actions: A list of actions names to inject, or True, to 477 | inject all available actions. Available actions are 'save', 'load', 478 | 'detail' and 'list' 479 | :param action_prefix: Prefix for options of hpargparse injected additional 480 | actions. e.g., the default action_prefix is 'hp'. Therefore, the 481 | command line options added by :func:`.bind` will be '--hp-save', 482 | '--hp-load', '--hp-list', etc. 483 | :param serial_format: One of 'auto', 'yaml' and 'pickle'. Defaults to 484 | 'auto'. In most cases you need not to alter this argument as long as 485 | you give the right file extension when using save and load action. To 486 | be specific, '.yaml' and '.yml' would be deemed as yaml format, and 487 | '.pickle' and '.pkl' would be seen as pickle format. 488 | :param show_defaults: Show the default value in help messages. 489 | 490 | :note: pickle is done by `dill` to support pickling of more types. 
491 | """ 492 | 493 | # make action list to be injected 494 | inject_actions = parse_action_list(inject_actions) 495 | 496 | args_set_getter = inject_args( 497 | parser, 498 | hp_mgr, 499 | inject_actions=inject_actions, 500 | action_prefix=action_prefix, 501 | serial_format=serial_format, 502 | show_defaults=show_defaults, 503 | ) 504 | 505 | # hook parser.parse_known_args 506 | parser._original_parse_known_args = parser.parse_known_args 507 | 508 | def new_parse_known_args(self, *args, **kwargs): 509 | args, extras = self._original_parse_known_args(*args, **kwargs) 510 | 511 | get_action_value = lambda name: getattr( 512 | args, "{}_{}".format(action_prefix, name) 513 | ) 514 | 515 | # load saved hyperparameter instance 516 | load_value = get_action_value("load") 517 | if "load" in inject_actions and load_value is not None: 518 | hp_load(load_value, hp_mgr, serial_format) 519 | 520 | # set hyperparameters set from command lines 521 | old_values = hp_mgr.get_values() 522 | for k in self.__hpargparse_value_names_been_set(): 523 | v = old_values[k] 524 | assert hasattr(args, k) 525 | t = getattr(args, k) 526 | if isinstance(t, StringAsDefault): 527 | t = str(t) 528 | hp_mgr.set_value(k, t) 529 | 530 | save_value = get_action_value("save") 531 | if "save" in inject_actions and save_value is not None: 532 | hp_save(save_value, hp_mgr, serial_format) 533 | 534 | # `--hp-detail`` need to preceed `--hp-list`` because `--hp-list detail` 535 | # will be set by default. 
536 | if "detail" in inject_actions and get_action_value("detail"): 537 | hp_list(hp_mgr) 538 | sys.exit(0) 539 | 540 | hp_list_value = get_action_value("list") 541 | if "list" in inject_actions and hp_list_value is not None: 542 | if hp_list_value == "yaml": 543 | syntax = Syntax( 544 | yaml.dump(hp_mgr.get_values()).replace("\n\n", "\n"), 545 | "yaml", 546 | theme="monokai", 547 | ) 548 | console = Console() 549 | console.print(syntax) 550 | elif hp_list_value == "json": 551 | syntax = Syntax( 552 | json.dumps(hp_mgr.get_values()), "json", theme="monokai" 553 | ) 554 | console = Console() 555 | console.print(syntax) 556 | else: 557 | assert hp_list_value == "detail", hp_list_value 558 | hp_list(hp_mgr) 559 | 560 | sys.exit(0) 561 | 562 | if inject_actions and get_action_value("exit"): 563 | sys.exit(0) 564 | 565 | return args, extras 566 | 567 | parser.parse_known_args = MethodType(new_parse_known_args, parser) 568 | --------------------------------------------------------------------------------