├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── benchmark.yml │ ├── build.yml │ ├── python-publish.yml │ └── upload.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── _config.yml ├── codecov.yml ├── docs ├── README.md └── Roadmap.md ├── examples ├── ML │ ├── README.md │ ├── ml_example_after.py │ ├── ml_example_before.py │ └── other_ml_example.py ├── README.md ├── __init__.py ├── aliases │ ├── README.md │ └── aliases_example.py ├── config_files │ ├── README.md │ ├── composition.py │ ├── composition_defaults.yaml │ ├── config_a.yaml │ ├── config_b.yaml │ ├── many_configs.py │ ├── many_configs.yaml │ ├── one_config.py │ └── one_config.yaml ├── container_types │ ├── README.md │ └── lists_example.py ├── custom_args │ ├── README.md │ └── custom_args_example.py ├── dataclasses │ ├── README.md │ ├── dataclass_example.py │ └── hyperparameters_example.py ├── demo.py ├── demo_simple.py ├── docstrings │ ├── README.md │ └── docstrings_example.py ├── enums │ ├── README.md │ └── enums_example.py ├── inheritance │ ├── README.md │ ├── inheritance_example.py │ ├── ml_inheritance.py │ └── ml_inheritance_2.py ├── merging │ ├── README.md │ ├── multiple_example.py │ └── multiple_lists_example.py ├── nesting │ ├── README.md │ └── nesting_example.py ├── partials │ ├── README.md │ └── partials_example.py ├── prefixing │ ├── README.md │ └── manual_prefix_example.py ├── serialization │ ├── README.md │ ├── bob.json │ ├── custom_types_example.py │ ├── serialization_example.ipynb │ └── serialization_example.py ├── simple │ ├── _before.py │ ├── basic.py │ ├── choice.py │ ├── flag.py │ ├── help.py │ ├── inheritance.py │ ├── option_strings.py │ ├── reuse.py │ └── to_json.py ├── subgroups │ ├── README.md │ └── subgroups_example.py ├── subparsers │ ├── README.md │ ├── optional_subparsers.py │ └── subparsers_example.py └── ugly │ ├── ugly_example_after.py │ └── ugly_example_before.py ├── pyproject.toml ├── 
requirements-test.txt ├── simple_parsing ├── __init__.py ├── annotation_utils │ ├── __init__.py │ └── get_field_annotations.py ├── conflicts.py ├── decorators.py ├── docstring.py ├── help_formatter.py ├── helpers │ ├── __init__.py │ ├── custom_actions.py │ ├── fields.py │ ├── flatten.py │ ├── hparams │ │ ├── __init__.py │ │ ├── hparam.py │ │ ├── hyperparameters.py │ │ ├── hyperparameters_test.py │ │ ├── priors.py │ │ ├── priors_test.py │ │ └── utils.py │ ├── nested_partial.py │ ├── partial.py │ ├── serialization │ │ ├── __init__.py │ │ ├── decoding.py │ │ ├── encoding.py │ │ ├── serializable.py │ │ └── yaml_serialization.py │ └── subgroups.py ├── parsing.py ├── py.typed ├── replace.py ├── utils.py └── wrappers │ ├── __init__.py │ ├── dataclass_wrapper.py │ ├── field_metavar.py │ ├── field_parsing.py │ ├── field_wrapper.py │ └── wrapper.py ├── test ├── __init__.py ├── conftest.py ├── foo.py ├── helpers │ ├── __init__.py │ ├── test_encoding.py │ ├── test_encoding │ │ ├── test_encoding_with_dc_types__json_obj0_.json │ │ ├── test_encoding_with_dc_types__json_obj1_.json │ │ ├── test_encoding_with_dc_types__yaml_obj0_.yaml │ │ └── test_encoding_with_dc_types__yaml_obj1_.yaml │ ├── test_enum_serialization.py │ ├── test_from_dict.py │ ├── test_partial.py │ ├── test_partial_postponed.py │ ├── test_save.py │ └── test_serializable.py ├── nesting │ ├── __init__.py │ ├── example_use_cases.py │ ├── test_default_factory_help_strings.py │ ├── test_nesting_auto.py │ ├── test_nesting_defaults.py │ ├── test_nesting_explicit.py │ ├── test_nesting_merge.py │ ├── test_nesting_simple.py │ └── test_weird_use_cases.py ├── postponed_annotations │ ├── __init__.py │ ├── a.py │ ├── b.py │ ├── multi_inherits.py │ ├── overwrite_attribute.py │ ├── overwrite_base.py │ ├── overwrite_subclass.py │ └── test_postponed_annotations.py ├── test_aliases.py ├── test_base.py ├── test_bools.py ├── test_choice.py ├── test_conf_path.py ├── test_conflicts.py ├── test_custom_args.py ├── test_decoding.py ├── 
test_decorator.py ├── test_default_args.py ├── test_docstrings.py ├── test_examples.py ├── test_fields.py ├── test_forward_ref.py ├── test_future_annotations.py ├── test_generation_mode.py ├── test_huggingface_compat.py ├── test_inheritance.py ├── test_initvar.py ├── test_issue64.py ├── test_issue_107.py ├── test_issue_132.py ├── test_issue_144.py ├── test_issue_46.py ├── test_issue_48.py ├── test_issue_96.py ├── test_lists.py ├── test_literal.py ├── test_multiple.py ├── test_optional.py ├── test_optional_subparsers.py ├── test_optional_union.py ├── test_performance.py ├── test_positional.py ├── test_replace.py ├── test_replace_subgroups.py ├── test_set_defaults.py ├── test_subgroups.py ├── test_subgroups │ ├── test_help[Config---help].md │ ├── test_help[Config---model=model_a --help].md │ ├── test_help[Config---model=model_b --help].md │ ├── test_help[ConfigWithFrozen---conf=even --a 100 --help].md │ ├── test_help[ConfigWithFrozen---conf=even --help].md │ ├── test_help[ConfigWithFrozen---conf=odd --a 123 --help].md │ ├── test_help[ConfigWithFrozen---conf=odd --help].md │ └── test_help[ConfigWithFrozen---help].md ├── test_subparsers.py ├── test_suppress.py ├── test_tuples.py ├── test_union.py ├── test_utils.py ├── testutils.py └── utils │ ├── __init__.py │ ├── test_flattened.py │ ├── test_mutable_field.py │ └── test_yaml.py └── uv.lock /.gitattributes: -------------------------------------------------------------------------------- 1 | simple_parsing/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | ```python 15 | from simple_parsing import ArgumentParser 16 | from dataclasses import dataclass 17 | 18 | @dataclass 19 | class Foo: 20 | bar: int = 123 21 | 22 | if __name__ == "__main__": 23 | parser = ArgumentParser() 24 | parser.add_arguments(Foo, "foo") 25 | args = parser.parse_args() 26 | foo: Foo = args.foo 27 | print(foo) 28 | ``` 29 | 30 | **Expected behavior** 31 | A clear and concise description of what you expected to happen. 32 | 33 | ```console 34 | $ python issue.py 35 | Foo(bar=123) 36 | $ python issue.py --bar 456 37 | Foo(bar=456) 38 | ``` 39 | 40 | **Actual behavior** 41 | A clear and concise description of what is happening. 42 | 43 | ```console 44 | $ python issue.py 45 | Foo(bar=123) 46 | $ python issue.py --bar 456 47 | Foo(bar=456) 48 | ``` 49 | 50 | **Desktop (please complete the following information):** 51 | - Version [e.g. 22] 52 | - Python version: ? 53 | 54 | **Additional context** 55 | Add any other context about the problem here. 56 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/benchmark.yml: -------------------------------------------------------------------------------- 1 | name: Benchmark 2 | 3 | # Do not run this workflow on pull request since this workflow has permission to modify contents. 4 | on: 5 | push: 6 | branches: 7 | - master 8 | workflow_dispatch: {} 9 | 10 | permissions: 11 | # deployments permission to deploy GitHub pages website 12 | contents: write 13 | # contents permission to update benchmark contents in gh-pages branch 14 | deployments: write 15 | 16 | jobs: 17 | benchmark: 18 | name: Run benchmark-action 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Install poetry 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install poetry 26 | - name: Set up Python 3.11 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: 3.11 30 | cache: poetry 31 | 32 | - name: Install dependencies 33 | run: poetry install --all-extras 34 | 35 | - name: Download previous benchmark data 36 | uses: actions/cache@v3 37 | with: 38 | path: ./cache 39 | key: ${{ runner.os }}-benchmark 40 | 41 | - name: Run benchmark 42 | run: | 43 | poetry run pytest --benchmark-only --cov=simple_parsing --cov-report=xml --cov-append --benchmark-json=.benchmark_output.json 44 | 45 | - name: Store benchmark result 46 | uses: benchmark-action/github-action-benchmark@v1 47 | with: 48 | name: Python Benchmark with pytest-benchmark 49 | tool: 'pytest' 50 | # Where the output from the benchmark tool is stored 51 | output-file-path: .benchmark_output.json 52 | # # Where the previous data file is stored 53 | external-data-json-path: ./cache/benchmark-master.json 54 | # Use personal access token instead of GITHUB_TOKEN due to https://github.community/t/github-action-not-triggering-gh-pages-upon-push/16096 55 | github-token: ${{ secrets.GITHUB_TOKEN }} 56 | # NOTE: auto-push must be false when 
external-data-json-path is set since this action 57 | # reads/writes the given JSON file and never pushes to remote 58 | auto-push: false 59 | # Show alert with commit comment on detecting possible performance regression 60 | alert-threshold: '150%' 61 | comment-on-alert: true 62 | # Enable Job Summary for PRs 63 | summary-always: true 64 | # Workflow will fail when an alert happens 65 | fail-on-alert: true 66 | alert-comment-cc-users: '@lebrice' 67 | 68 | - name: Upload coverage reports to Codecov 69 | uses: codecov/codecov-action@v3 70 | env: 71 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 72 | with: 73 | env_vars: PLATFORM,PYTHON 74 | name: codecov-umbrella 75 | fail_ci_if_error: false 76 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | 8 | jobs: 9 | linting: 10 | name: Run linting/pre-commit checks 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Install uv 16 | uses: astral-sh/setup-uv@v5 17 | with: 18 | version: "0.5.25" 19 | python-version: '3.11' 20 | - run: uvx pre-commit --version 21 | - run: uvx pre-commit run --all-files 22 | 23 | build: 24 | needs: [linting] 25 | runs-on: ubuntu-latest 26 | strategy: 27 | matrix: 28 | python-version: ["3.9", "3.10", "3.11", "3.12"] 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | 33 | - name: Install uv 34 | uses: astral-sh/setup-uv@v5 35 | with: 36 | version: "0.5.25" 37 | python-version: ${{ matrix.python-version }} 38 | 39 | 40 | - name: Install the project (no extras) 41 | run: uv sync 42 | 43 | - name: Unit tests with Pytest (no extras) 44 | timeout-minutes: 3 45 | run: | 46 | uv run pytest --benchmark-disable --cov=simple_parsing --cov-report=xml --cov-append 47 | 48 | - name: Install extra dependencies 49 | run: uv sync 
--all-extras 50 | 51 | - name: Unit tests with Pytest (with extra dependencies) 52 | timeout-minutes: 3 53 | run: | 54 | uv run pytest --benchmark-disable --cov=simple_parsing --cov-report=xml --cov-append 55 | 56 | - name: Upload coverage reports to Codecov 57 | uses: codecov/codecov-action@v4 58 | env: 59 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 60 | with: 61 | env_vars: PLATFORM,PYTHON 62 | name: codecov-umbrella 63 | fail_ci_if_error: false 64 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | workflow_dispatch: {} 10 | 11 | jobs: 12 | publish: 13 | strategy: 14 | matrix: 15 | python-version: [3.9] 16 | os: [ubuntu-latest] 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Install poetry 21 | run: pipx install poetry 22 | 23 | - name: Install uv 24 | uses: astral-sh/setup-uv@v5 25 | with: 26 | version: "0.5.25" 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Install dependencies 30 | run: | 31 | uv sync 32 | 33 | - name: Publish package 34 | env: 35 | UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }} 36 | run: | 37 | uv build 38 | uv publish 39 | -------------------------------------------------------------------------------- /.github/workflows/upload.yml: -------------------------------------------------------------------------------- 1 | name: Upload the benchmark results 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - temp_master 7 | workflow_dispatch: {} 8 | 9 | jobs: 10 | upload_benchmark_results: 11 | runs-on: 
ubuntu-latest 12 | name: Test out the action in this repository 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Install poetry 16 | run: | 17 | python -m pip install --upgrade pip 18 | pip install poetry 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.11 23 | cache: poetry 24 | 25 | - name: Install dependencies 26 | run: poetry install --all-extras 27 | 28 | - name: Run the benchmark 29 | run: poetry run pytest --benchmark-only --benchmark-json=benchmark_results.json 30 | 31 | - name: Upload the file 32 | uses: actions/upload-artifact@v2 33 | with: 34 | name: benchmark_results 35 | path: benchmark_results.json 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | default.json 2 | 3 | # Created by https://www.gitignore.io/api/python 4 | # Edit at https://www.gitignore.io/?templates=python 5 | 6 | ### Python ### 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # pipenv 76 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 77 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 78 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 79 | # install all needed dependencies. 80 | #Pipfile.lock 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # Mr Developer 96 | .mr.developer.cfg 97 | .project 98 | .pydevproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .dmypy.json 106 | dmypy.json 107 | 108 | # Pyre type checker 109 | .pyre/ 110 | 111 | # End of https://www.gitignore.io/api/python 112 | .vscode 113 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | require_serial: true 11 | - id: end-of-file-fixer 12 | require_serial: true 13 | # - id: 
check-docstring-first 14 | - id: check-yaml 15 | require_serial: true 16 | - id: debug-statements 17 | require_serial: true 18 | - id: detect-private-key 19 | require_serial: true 20 | - id: check-executables-have-shebangs 21 | require_serial: true 22 | - id: check-toml 23 | require_serial: true 24 | - id: check-case-conflict 25 | require_serial: true 26 | - id: check-added-large-files 27 | require_serial: true 28 | 29 | - repo: https://github.com/charliermarsh/ruff-pre-commit 30 | # Ruff version. 31 | rev: 'v0.1.14' 32 | hooks: 33 | # Run the linter. 34 | - id: ruff 35 | require_serial: true 36 | args: ["--fix"] 37 | # Run the formatter. 38 | - id: ruff-format 39 | require_serial: true 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 # Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed 44 | hooks: 45 | - id: docformatter 46 | exclude: ^test/test_docstrings.py 47 | require_serial: true 48 | additional_dependencies: [tomli] 49 | 50 | 51 | # NOTE: Disabling this, since I'm having the glib-c2.29 weird bug. 
52 | # # yaml formatting 53 | # - repo: https://github.com/pre-commit/mirrors-prettier 54 | # rev: v2.7.1 55 | # hooks: 56 | # - id: prettier 57 | # types: [yaml] 58 | 59 | # jupyter notebook cell output clearing 60 | - repo: https://github.com/kynan/nbstripout 61 | rev: 0.6.1 62 | hooks: 63 | - id: nbstripout 64 | require_serial: true 65 | 66 | 67 | # md formatting 68 | - repo: https://github.com/executablebooks/mdformat 69 | rev: 0.7.21 70 | hooks: 71 | - id: mdformat 72 | args: ["--number"] 73 | exclude: ^.github/ISSUE_TEMPLATE/.*\.md$ 74 | additional_dependencies: 75 | - mdformat-gfm 76 | - mdformat-tables 77 | - mdformat_frontmatter 78 | # - mdformat-toc 79 | # - mdformat-black 80 | require_serial: true 81 | 82 | # word spelling linter 83 | - repo: https://github.com/codespell-project/codespell 84 | rev: v2.2.2 85 | hooks: 86 | - id: codespell 87 | args: 88 | - --skip=logs/**,data/** 89 | # - --ignore-words-list=abc,def 90 | require_serial: true 91 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Fabrice Normandin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto # auto compares coverage to the previous base commit 6 | informational: true 7 | patch: 8 | default: 9 | target: 100% 10 | informational: true 11 | 12 | github_checks: 13 | annotations: true 14 | # When modifying this file, please validate using 15 | # curl -X POST --data-binary @codecov.yml https://codecov.io/validate 16 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Simple-Parsing API Documentation 2 | 3 | API documentation is still under construction. 4 | This package is quite simple, with only a single public class of interest: the [`simple_parsing.ArgumentParser` class](https://github.com/lebrice/SimpleParsing/blob/master/simple_parsing/parsing.py#L43). 5 | 6 | For the time being, please take a look at the [Examples section](https://github.com/lebrice/SimpleParsing/tree/master/examples), which provides a fairly decent overview of the current features of the package. 
7 | -------------------------------------------------------------------------------- /docs/Roadmap.md: -------------------------------------------------------------------------------- 1 | ## Currently supported features: 2 | 3 | - Parsing of attributes of built-in types: 4 | - `int`, `float`, `str` attributes 5 | - `bool` attributes (using either the `--` or the `-- ` syntax) 6 | - `list` attributes 7 | - `tuple` attributes 8 | - Parsing of multiple instances of a given dataclass, for the above-mentioned attribute types 9 | - Nested parsing of instances (dataclasses within dataclasses) 10 | 11 | ## Possible Future Enhancements: 12 | 13 | - Parsing two different dataclasses which share a base class (this currently would cause a conflict for the base class arguments. 14 | -------------------------------------------------------------------------------- /examples/ML/README.md: -------------------------------------------------------------------------------- 1 | ## Use-Case Example: ML Scripts 2 | 3 | Let's look at a great use-case for `simple-parsing`: ugly ML code: 4 | 5 | ### Before: 6 | 7 | ```python 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | # hyperparameters 13 | parser.add_argument("--learning_rate", type=float, default=0.05) 14 | parser.add_argument("--momentum", type=float, default=0.01) 15 | # (... other hyperparameters here) 16 | 17 | # args for training config 18 | parser.add_argument("--data_dir", type=str, default="/data") 19 | parser.add_argument("--log_dir", type=str, default="/logs") 20 | parser.add_argument("--checkpoint_dir", type=str, default="checkpoints") 21 | 22 | args = parser.parse_args() 23 | 24 | learning_rate = args.learning_rate 25 | momentum = args.momentum 26 | # (...) 
dereference all the variables here, without any typing 27 | data_dir = args.data_dir 28 | log_dir = args.log_dir 29 | checkpoint_dir = args.checkpoint_dir 30 | 31 | class MyModel(): 32 | def __init__(self, data_dir, log_dir, checkpoint_dir, learning_rate, momentum, *args): 33 | # config: 34 | self.data_dir = data_dir 35 | self.log_dir = log_dir 36 | self.checkpoint_dir = checkpoint_dir 37 | 38 | # hyperparameters: 39 | self.learning_rate = learning_rate 40 | self.momentum = momentum 41 | 42 | m = MyModel(data_dir, log_dir, checkpoint_dir, learning_rate, momentum) 43 | # Ok, what if we wanted to add a new hyperparameter?! 44 | ``` 45 | 46 | ### After: 47 | 48 | ```python 49 | from dataclasses import dataclass 50 | from simple_parsing import ArgumentParser 51 | 52 | # create a parser, as usual 53 | parser = ArgumentParser() 54 | 55 | @dataclass 56 | class MyModelHyperParameters: 57 | """Hyperparameters of MyModel""" 58 | # Learning rate of the Adam optimizer. 59 | learning_rate: float = 0.05 60 | # Momentum of the optimizer. 
61 | momentum: float = 0.01 62 | 63 | @dataclass 64 | class TrainingConfig: 65 | """Training configuration settings""" 66 | data_dir: str = "/data" 67 | log_dir: str = "/logs" 68 | checkpoint_dir: str = "checkpoints" 69 | 70 | 71 | # automatically add arguments for all the fields of the classes above: 72 | parser.add_arguments(MyModelHyperParameters, dest="hparams") 73 | parser.add_arguments(TrainingConfig, dest="config") 74 | 75 | args = parser.parse_args() 76 | 77 | # Create an instance of each class and populate its values from the command line arguments: 78 | hyperparameters: MyModelHyperParameters = args.hparams 79 | config: TrainingConfig = args.config 80 | 81 | class MyModel(): 82 | def __init__(self, hyperparameters: MyModelHyperParameters, config: TrainingConfig): 83 | # hyperparameters: 84 | self.hyperparameters = hyperparameters 85 | # config: 86 | self.config = config 87 | 88 | m = MyModel(hyperparameters, config) 89 | 90 | ``` 91 | -------------------------------------------------------------------------------- /examples/ML/ml_example_after.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | # create a parser, as usual 6 | parser = ArgumentParser() 7 | 8 | 9 | @dataclass 10 | class MyModelHyperParameters: 11 | """Hyperparameters of MyModel.""" 12 | 13 | # Learning rate of the Adam optimizer. 14 | learning_rate: float = 0.05 15 | # Momentum of the optimizer. 
16 | momentum: float = 0.01 17 | 18 | 19 | @dataclass 20 | class TrainingConfig: 21 | """Training configuration settings.""" 22 | 23 | data_dir: str = "/data" 24 | log_dir: str = "/logs" 25 | checkpoint_dir: str = "checkpoints" 26 | 27 | 28 | # automatically add arguments for all the fields of the classes above: 29 | parser.add_arguments(MyModelHyperParameters, dest="hparams") 30 | parser.add_arguments(TrainingConfig, dest="config") 31 | 32 | args = parser.parse_args() 33 | 34 | # Create an instance of each class and populate its values from the command line arguments: 35 | hyperparameters: MyModelHyperParameters = args.hparams 36 | config: TrainingConfig = args.config 37 | 38 | 39 | class MyModel: 40 | def __init__(self, hyperparameters: MyModelHyperParameters, config: TrainingConfig): 41 | # hyperparameters: 42 | self.hyperparameters = hyperparameters 43 | # config: 44 | self.config = config 45 | 46 | 47 | m = MyModel(hyperparameters, config) 48 | -------------------------------------------------------------------------------- /examples/ML/ml_example_before.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | parser = ArgumentParser() 4 | 5 | # hyperparameters 6 | parser.add_argument("--learning_rate", type=float, default=0.05) 7 | parser.add_argument("--momentum", type=float, default=0.01) 8 | # (... other hyperparameters here) 9 | 10 | # args for training config 11 | parser.add_argument("--data_dir", type=str, default="/data") 12 | parser.add_argument("--log_dir", type=str, default="/logs") 13 | parser.add_argument("--checkpoint_dir", type=str, default="checkpoints") 14 | 15 | args = parser.parse_args() 16 | 17 | learning_rate = args.learning_rate 18 | momentum = args.momentum 19 | # (...) 
dereference all the variables here, without any typing 20 | data_dir = args.data_dir 21 | log_dir = args.log_dir 22 | checkpoint_dir = args.checkpoint_dir 23 | 24 | 25 | class MyModel: 26 | def __init__(self, data_dir, log_dir, checkpoint_dir, learning_rate, momentum, *args): 27 | # config: 28 | self.data_dir = data_dir 29 | self.log_dir = log_dir 30 | self.checkpoint_dir = checkpoint_dir 31 | 32 | # hyperparameters: 33 | self.learning_rate = learning_rate 34 | self.momentum = momentum 35 | 36 | 37 | m = MyModel(data_dir, log_dir, checkpoint_dir, learning_rate, momentum) 38 | # Ok, what if we wanted to add a new hyperparameter?! 39 | -------------------------------------------------------------------------------- /examples/ML/other_ml_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import simple_parsing 4 | 5 | # create a parser, 6 | parser = simple_parsing.ArgumentParser() 7 | 8 | 9 | @dataclass 10 | class MyModelHyperParameters: 11 | """Hyperparameters of MyModel.""" 12 | 13 | # Batch size (per-GPU) 14 | batch_size: int = 32 15 | # Learning rate of the Adam optimizer. 16 | learning_rate: float = 0.05 17 | # Momentum of the optimizer. 
18 | momentum: float = 0.01 19 | 20 | 21 | @dataclass 22 | class TrainingConfig: 23 | """Settings related to Training.""" 24 | 25 | data_dir: str = "data" 26 | log_dir: str = "logs" 27 | checkpoint_dir: str = "checkpoints" 28 | 29 | 30 | @dataclass 31 | class EvalConfig: 32 | """Settings related to evaluation.""" 33 | 34 | eval_dir: str = "eval_data" 35 | 36 | 37 | # automatically add arguments for all the fields of the classes above: 38 | parser.add_arguments(MyModelHyperParameters, "hparams") 39 | parser.add_arguments(TrainingConfig, "train_config") 40 | parser.add_arguments(EvalConfig, "eval_config") 41 | 42 | # NOTE: `ArgumentParser` is just a subclass of `argparse.ArgumentParser`, 43 | # so we could add some other arguments as usual: 44 | # parser.add_argument(...) 45 | # parser.add_argument(...) 46 | # (...) 47 | # parser.add_argument(...) 48 | # parser.add_argument(...) 49 | 50 | args = parser.parse_args() 51 | 52 | # Retrieve the objects from the parsed args! 53 | hparams: MyModelHyperParameters = args.hparams 54 | train_config: TrainingConfig = args.train_config 55 | eval_config: EvalConfig = args.eval_config 56 | 57 | print(hparams, train_config, eval_config, sep="\n") 58 | expected = """ 59 | MyModelHyperParameters(batch_size=32, learning_rate=0.05, momentum=0.01) 60 | TrainingConfig(data_dir='data', log_dir='logs', checkpoint_dir='checkpoints') 61 | EvalConfig(eval_dir='eval_data') 62 | """ 63 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | - [dataclasses intro](dataclasses/README.md): Quick intro to Python's new [dataclasses](https://docs.python.org/3.7/library/dataclasses.html) module. 4 | 5 | - **[Simple example](simple/basic.py)**: Simple use-case example with a before/after comparison. 
6 | 7 | - [ML-related Examples](ML/README.md) 8 | 9 | - **NEW**: [Subgroups Example](subgroups/README.md) 10 | 11 | 12 | 13 | - [Serialization to `json`/`yaml`](serialization/README.md) 14 | 15 | - [Attribute Docstrings Example](docstrings/README.md) 16 | 17 | - [Parsing of lists and tuples](container_types/README.md) 18 | 19 | - [**Nesting**!!](nesting/README.md) 20 | 21 | - [Prefixing](prefixing/README.md) 22 | 23 | - [Enums Example](enums/README.md) 24 | 25 | - [Subparsers Example](subparsers/README.md) 26 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lebrice/SimpleParsing/2a69cebb38172b9317c0175aadde6d3b432371ce/examples/__init__.py -------------------------------------------------------------------------------- /examples/aliases/README.md: -------------------------------------------------------------------------------- 1 | # Using Aliases 2 | 3 | ## Notes about `option_strings`: 4 | 5 | Additional names for the same argument can be added via the `alias` argument 6 | of the `field` function (see [the custom_args Example](/examples/custom_args/README.md) for more info). 7 | 8 | The `simple_parsing.ArgumentParser` accepts an argument (currently called `add_option_string_dash_variants`, which defaults to False) which adds additional variants to allow using either dashes or underscores to refer to an argument: 9 | 10 | - Whenever the name of an attribute includes underscores ("\_"), the same 11 | argument can be passed by using dashes ("-") instead. This also includes 12 | aliases. 13 | - If an alias contained leading dashes, either single or double, the 14 | same number of dashes will be used, even in the case where a prefix is 15 | added. 16 | For instance, consider the following example. 17 | Here we have two prefixes: `"train"` and `"valid"`. 
18 | The corresponding option_strings for each argument will be 19 | `["--train.debug", "-train.d"]` and `["--valid.debug", "-valid.d"]`, 20 | respectively, as shown here: 21 | 22 | ```python 23 | from dataclasses import dataclass 24 | from simple_parsing import ArgumentParser, field 25 | 26 | @dataclass 27 | class RunSettings: 28 | ''' Parameters for a run. ''' 29 | # whether or not to execute in debug mode. 30 | debug: bool = field(alias=["-d"], default=False) 31 | some_value: int = field(alias=["-v"], default=123) 32 | 33 | parser = ArgumentParser(add_option_string_dash_variants=True) 34 | parser.add_arguments(RunSettings, dest="train") 35 | parser.add_arguments(RunSettings, dest="valid") 36 | parser.print_help() 37 | 38 | # This prints: 39 | ''' 40 | usage: test.py [-h] [--train.debug [bool]] [--train.some_value int] 41 | [--valid.debug [bool]] [--valid.some_value int] 42 | 43 | optional arguments: 44 | -h, --help show this help message and exit 45 | 46 | RunSettings ['train']: 47 | Parameters for a run. 48 | 49 | --train.debug [bool], --train.d [bool] 50 | whether or not to execute in debug mode. (default: 51 | False) 52 | --train.some_value int, --train.v int, ---train.some-value int 53 | 54 | RunSettings ['valid']: 55 | Parameters for a run. 56 | 57 | --valid.debug [bool], --valid.d [bool] 58 | whether or not to execute in debug mode. (default: 59 | False) 60 | --valid.some_value int, --valid.v int, ---valid.some-value int 61 | ''' 62 | ``` 63 | -------------------------------------------------------------------------------- /examples/aliases/aliases_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser, field 4 | 5 | 6 | @dataclass 7 | class RunSettings: 8 | """Parameters for a run.""" 9 | 10 | # whether or not to execute in debug mode. 
11 | debug: bool = field(alias=["-d"], default=False) 12 | # whether or not to add a lot of logging information. 13 | verbose: bool = field(alias=["-v"], action="store_true") 14 | 15 | 16 | parser = ArgumentParser(add_option_string_dash_variants=True) 17 | parser.add_arguments(RunSettings, dest="train") 18 | parser.add_arguments(RunSettings, dest="valid") 19 | args = parser.parse_args() 20 | print(args) 21 | # This prints: 22 | expected = """ 23 | Namespace(train=RunSettings(debug=False, verbose=False), valid=RunSettings(debug=False, verbose=False)) 24 | """ 25 | 26 | parser.print_help() 27 | expected += """\ 28 | usage: aliases_example.py [-h] [-train.d bool] [-train.v] [-valid.d bool] 29 | [-valid.v] 30 | 31 | optional arguments: 32 | -h, --help show this help message and exit 33 | 34 | RunSettings ['train']: 35 | Parameters for a run. 36 | 37 | -train.d bool, --train.debug bool, --train.nod bool, --train.nodebug bool 38 | whether or not to execute in debug mode. (default: 39 | False) 40 | -train.v, --train.verbose 41 | whether or not to add a lot of logging information. 42 | (default: False) 43 | 44 | RunSettings ['valid']: 45 | Parameters for a run. 46 | 47 | -valid.d bool, --valid.debug bool, --valid.nod bool, --valid.nodebug bool 48 | whether or not to execute in debug mode. (default: 49 | False) 50 | -valid.v, --valid.verbose 51 | whether or not to add a lot of logging information. 52 | (default: False) 53 | """ 54 | -------------------------------------------------------------------------------- /examples/config_files/README.md: -------------------------------------------------------------------------------- 1 | # Using config files 2 | 3 | Simple-Parsing can use default values from one or more configuration files. 4 | 5 | The `config_path` argument can be passed to the ArgumentParser constructor. The values read from 6 | that file will overwrite the default values from the dataclass definitions. 
7 | 8 | Additionally, when the `add_config_path_arg` argument of the `ArgumentParser` constructor is set, 9 | a `--config_path` argument will be added to the parser. This argument accepts one or more paths to configuration 10 | files, whose contents will be read, and used to update the defaults, in the same manner as with the 11 | `config_path` argument above. 12 | 13 | When using both options (the `config_path` parameter of `ArgumentParser.__init__`, as well as the `--config_path` command-line argument), the defaults are first updated using `ArgumentParser.config_path`, and then 14 | updated with the contents of the `--config_path` file(s). 15 | 16 | In other words, the default values are set like so, in increasing priority: 17 | 18 | 1. normal defaults (e.g. from the dataclass definitions) 19 | 2. updated with the contents of the `config_path` file(s) of `ArgumentParser.__init__` 20 | 3. updated with the contents of the `--config_path` file(s) from the command-line. 21 | 22 | ## [Single Config example](one_config.py) 23 | 24 | When using a single config dataclass, the `simple_parsing.parse` function can then be used to simplify the argument parsing setup a bit. 25 | 26 | ## [Multiple Configs](many_configs.py) 27 | 28 | Config files can also be used when defining multiple config dataclasses with the same parser. 29 | 30 | ## [Composition (WIP)](composition.py) 31 | 32 | (Coming soon): Multiple config files can be composed together à la Hydra!
33 | -------------------------------------------------------------------------------- /examples/config_files/composition.py: -------------------------------------------------------------------------------- 1 | """Example where we compose different configurations!""" 2 | 3 | import shlex 4 | from dataclasses import dataclass 5 | 6 | import simple_parsing 7 | 8 | 9 | @dataclass 10 | class Foo: 11 | a: str = "default value for `a` (from the dataclass definition of Foo)" 12 | 13 | 14 | @dataclass 15 | class Bar: 16 | b: str = "default value for `b` (from the dataclass definition of Bar)" 17 | 18 | 19 | @dataclass 20 | class Baz: 21 | c: str = "default value for `c` (from the dataclass definition of Baz)" 22 | 23 | 24 | def main(args=None) -> None: 25 | """Example using composition of different configurations.""" 26 | parser = simple_parsing.ArgumentParser( 27 | add_config_path_arg=True, config_path="composition_defaults.yaml" 28 | ) 29 | 30 | parser.add_arguments(Foo, dest="foo") 31 | parser.add_arguments(Bar, dest="bar") 32 | parser.add_arguments(Baz, dest="baz") 33 | 34 | if isinstance(args, str): 35 | args = shlex.split(args) 36 | args = parser.parse_args(args) 37 | 38 | foo: Foo = args.foo 39 | bar: Bar = args.bar 40 | baz: Baz = args.baz 41 | print(f"foo: {foo}") 42 | print(f"bar: {bar}") 43 | print(f"baz: {baz}") 44 | 45 | 46 | main() 47 | expected = """ 48 | foo: Foo(a="default value for `a` from the Parser's `config_path` (composition_defaults.yaml)") 49 | bar: Bar(b="default value for `b` from the Parser's `config_path` (composition_defaults.yaml)") 50 | baz: Baz(c="default value for `c` from the Parser's `config_path` (composition_defaults.yaml)") 51 | """ 52 | 53 | main("--a 'Value passed from the command-line.' 
--config_path config_b.yaml") 54 | expected += """\ 55 | foo: Foo(a='Value passed from the command-line.') 56 | bar: Bar(b='default value for `b` from the config_b.yaml file') 57 | baz: Baz(c="default value for `c` from the Parser's `config_path` (composition_defaults.yaml)") 58 | """ 59 | 60 | main("--a 'Value passed from the command-line.' --config_path config_a.yaml config_b.yaml") 61 | expected += """\ 62 | foo: Foo(a='Value passed from the command-line.') 63 | bar: Bar(b='default value for `b` from the config_b.yaml file') 64 | baz: Baz(c="default value for `c` from the Parser's `config_path` (composition_defaults.yaml)") 65 | """ 66 | -------------------------------------------------------------------------------- /examples/config_files/composition_defaults.yaml: -------------------------------------------------------------------------------- 1 | foo: 2 | a: "default value for `a` from the Parser's `config_path` (composition_defaults.yaml)" 3 | 4 | bar: 5 | b: "default value for `b` from the Parser's `config_path` (composition_defaults.yaml)" 6 | 7 | baz: 8 | c: "default value for `c` from the Parser's `config_path` (composition_defaults.yaml)" 9 | -------------------------------------------------------------------------------- /examples/config_files/config_a.yaml: -------------------------------------------------------------------------------- 1 | foo: 2 | a: "default value for `a` from the config_a.yaml file" 3 | -------------------------------------------------------------------------------- /examples/config_files/config_b.yaml: -------------------------------------------------------------------------------- 1 | bar: 2 | b: "default value for `b` from the config_b.yaml file" 3 | -------------------------------------------------------------------------------- /examples/config_files/many_configs.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | from dataclasses import dataclass 3 | 4 | import 
simple_parsing 5 | 6 | 7 | @dataclass 8 | class TrainConfig: 9 | """Training config for Machine Learning.""" 10 | 11 | workers: int = 8 # The number of workers for training 12 | exp_name: str = "default_exp" # The experiment name 13 | 14 | 15 | @dataclass 16 | class EvalConfig: 17 | """Evaluation config.""" 18 | 19 | n_batches: int = 8 # The number of batches for evaluation 20 | checkpoint: str = "best.pth" # The checkpoint to use 21 | 22 | 23 | def main(args=None) -> None: 24 | parser = simple_parsing.ArgumentParser(add_config_path_arg=True) 25 | 26 | parser.add_arguments(TrainConfig, dest="train") 27 | parser.add_arguments(EvalConfig, dest="eval") 28 | 29 | if isinstance(args, str): 30 | args = shlex.split(args) 31 | args = parser.parse_args(args) 32 | 33 | train_config: TrainConfig = args.train 34 | eval_config: EvalConfig = args.eval 35 | print(f"Training {train_config.exp_name} with {train_config.workers} workers...") 36 | print(f"Evaluating '{eval_config.checkpoint}' with {eval_config.n_batches} batches...") 37 | 38 | 39 | main() 40 | expected = """ 41 | Training default_exp with 8 workers... 42 | Evaluating 'best.pth' with 8 batches... 43 | """ 44 | 45 | 46 | main("") 47 | expected += """\ 48 | Training default_exp with 8 workers... 49 | Evaluating 'best.pth' with 8 batches... 50 | """ 51 | 52 | main("--config_path many_configs.yaml --exp_name my_first_exp") 53 | expected += """\ 54 | Training my_first_exp with 42 workers... 55 | Evaluating 'best.pth' with 100 batches... 
56 | """ 57 | -------------------------------------------------------------------------------- /examples/config_files/many_configs.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | exp_name: my_yaml_exp 3 | workers: 42 4 | eval: 5 | checkpoint: best.pth 6 | n_batches: 100 7 | -------------------------------------------------------------------------------- /examples/config_files/one_config.py: -------------------------------------------------------------------------------- 1 | """Example adapted from https://github.com/eladrich/pyrallis#my-first-pyrallis-example-""" 2 | 3 | from dataclasses import dataclass 4 | 5 | import simple_parsing 6 | 7 | 8 | @dataclass 9 | class TrainConfig: 10 | """Training configuration.""" 11 | 12 | workers: int = 8 # The number of workers for training 13 | exp_name: str = "default_exp" # The experiment name 14 | 15 | 16 | def main(args=None) -> None: 17 | cfg = simple_parsing.parse( 18 | config_class=TrainConfig, args=args, add_config_path_arg="config-file" 19 | ) 20 | print(f"Training {cfg.exp_name} with {cfg.workers} workers...") 21 | 22 | 23 | main() 24 | expected = """ 25 | Training default_exp with 8 workers... 26 | """ 27 | 28 | main("") 29 | expected += """\ 30 | Training default_exp with 8 workers... 31 | """ 32 | 33 | # NOTE: When running as in the readme: 34 | main("--config-file one_config.yaml --exp_name my_first_exp") 35 | expected += """\ 36 | Training my_first_exp with 42 workers... 
37 | """ 38 | -------------------------------------------------------------------------------- /examples/config_files/one_config.yaml: -------------------------------------------------------------------------------- 1 | exp_name: my_yaml_exp 2 | workers: 42 3 | -------------------------------------------------------------------------------- /examples/container_types/README.md: -------------------------------------------------------------------------------- 1 | # Parsing Container-type Arguments (list, tuple) 2 | 3 | In "vanilla" argparse, it is usually difficult to parse lists and tuples. 4 | 5 | `simple-parsing` makes it easier, by leveraging the type-annotations of the builtin `typing` module. Simply mark you attributes using the corresponding annotation, and the item types will be automatically converted for you: 6 | 7 | 8 | 9 | ```python 10 | 11 | ``` 12 | -------------------------------------------------------------------------------- /examples/container_types/lists_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from simple_parsing import ArgumentParser 4 | from simple_parsing.helpers import list_field 5 | 6 | 7 | @dataclass 8 | class Example: 9 | some_integers: list[int] = field( 10 | default_factory=list 11 | ) # This is a list of integers (empty by default) 12 | """This list is empty, by default. 13 | 14 | when passed some parameters, they are automatically converted to integers, since we annotated 15 | the attribute with a type (typing.List[]). 16 | """ 17 | 18 | # When using a list attribute, the dataclasses module requires us to use `dataclass.field()`, 19 | # so each instance of this class has a different list, rather than them sharing the same list. 20 | # To simplify this, you can use `MutableField(value)` which is just a shortcut for `field(default_factory=lambda: value)`. 
21 | some_floats: list[float] = list_field(3.14, 2.56) 22 | 23 | some_list_of_strings: list[str] = list_field("default_1", "default_2") 24 | """This list has a default value of ["default_1", "default_2"].""" 25 | 26 | 27 | parser = ArgumentParser() 28 | parser.add_arguments(Example, "example") 29 | args = parser.parse_args() 30 | 31 | example: Example = args.example 32 | print(example) 33 | expected = "Example(some_integers=[], some_floats=[3.14, 2.56], some_list_of_strings=['default_1', 'default_2'])" 34 | -------------------------------------------------------------------------------- /examples/custom_args/README.md: -------------------------------------------------------------------------------- 1 | # Custom Argparse Arguments 2 | 3 | The `dataclasses.field()` function is used to customize the declaration of 4 | fields on a dataclass. It accepts, among others, the `default`, 5 | `default_factory`, arguments used to set the default instance values to fields 6 | (please take a look at the [official documentation](https://docs.python.org/3/library/dataclasses.html#dataclasses.field) For more 7 | information). 8 | 9 | `simple-parsing` provides an overloaded version of this function: 10 | `simple_parsing.field()`, which, in addition to all the above-mentioned keyword 11 | arguments of the `field()` method, **can be passed any of the arguments 12 | of the usual `add_argument(*option_strings, **kwargs)` function!**. 13 | 14 | The values passed this way take precedence over those auto-generated by 15 | `simple-parsing`, allowing you to do pretty much anything you want. 
16 | 17 | ## Examples 18 | 19 | - ### List of choices 20 | 21 | For example, here is how you would create a list of choices, whereby any 22 | of the passed arguments can only be contained within the choices: 23 | 24 | ```python 25 | 26 | from dataclasses import dataclass 27 | from simple_parsing import ArgumentParser, field 28 | from typing import List 29 | 30 | @dataclass 31 | class Foo: 32 | """ Some class Foo """ 33 | 34 | # A sequence of tasks. 35 | task_sequence: List[str] = field(choices=["train", "test", "ood"]) 36 | 37 | parser = ArgumentParser() 38 | parser.add_arguments(Foo, "foo") 39 | 40 | args = parser.parse_args("--task_sequence train train ood".split()) 41 | foo: Foo = args.foo 42 | print(foo) 43 | assert foo.task_sequence == ["train", "train", "ood"] 44 | 45 | ``` 46 | 47 | - ### Adding additional aliases for an argument 48 | 49 | By passing the 50 | 51 | ```python 52 | @dataclass 53 | class Foo(TestSetup): 54 | """ Some example Foo. """ 55 | # The output directory. (can be passed using any of "-o" or --out or ) 56 | output_dir: str = field( 57 | default="/out", 58 | alias=["-o", "--out"], 59 | choices=["/out", "/bob"] 60 | ) 61 | 62 | foo = Foo.setup("--output_dir /bob") 63 | assert foo.output_dir == "/bob" 64 | 65 | with raises(): 66 | foo = Foo.setup("-o /cat") 67 | assert foo.output_dir == "/cat" 68 | 69 | foo = Foo.setup("--out /bob") 70 | assert foo.output_dir == "/bob" 71 | ``` 72 | 73 | - ### Adding Flags with "store-true" or "store-false" 74 | 75 | Additionally, 76 | -------------------------------------------------------------------------------- /examples/custom_args/custom_args_example.py: -------------------------------------------------------------------------------- 1 | """Example of overwriting auto-generated argparse options with custom ones.""" 2 | 3 | from dataclasses import dataclass 4 | 5 | from simple_parsing import ArgumentParser, field 6 | from simple_parsing.helpers import list_field 7 | 8 | 9 | def parse(cls, args: str = 
""): 10 | """Removes some boilerplate code from the examples.""" 11 | parser = ArgumentParser() # Create an argument parser 12 | parser.add_arguments(cls, "example") # add arguments for the dataclass 13 | ns = parser.parse_args(args.split()) # parse the given `args` 14 | return ns.example # return the dataclass instance 15 | 16 | 17 | # Example 1: List of Choices: 18 | 19 | 20 | @dataclass 21 | class Example1: 22 | # A list of animals to take on a walk. (can only be passed 'cat' or 'dog') 23 | pets_to_walk: list[str] = list_field(default=["dog"], choices=["cat", "dog"]) 24 | 25 | 26 | # passing no arguments uses the default values: 27 | assert parse(Example1, "") == Example1(pets_to_walk=["dog"]) 28 | assert parse(Example1, "--pets_to_walk") == Example1(pets_to_walk=[]) 29 | assert parse(Example1, "--pets_to_walk cat") == Example1(pets_to_walk=["cat"]) 30 | assert parse(Example1, "--pets_to_walk dog dog cat") == Example1( 31 | pets_to_walk=["dog", "dog", "cat"] 32 | ) 33 | 34 | 35 | # # Passing a value not in 'choices' produces an error: 36 | # with pytest.raises(SystemExit): 37 | # example = parse(Example1, "--pets_to_walk racoon") 38 | # expected = """ 39 | # usage: custom_args_example.py [-h] [--pets_to_walk [{cat,dog,horse} [{cat,dog} ...]]] 40 | # custom_args_example.py: error: argument --pets_to_walk: invalid choice: 'racoon' (choose from 'cat', 'dog') 41 | # """ 42 | 43 | 44 | # Example 2: Additional Option Strings 45 | 46 | 47 | @dataclass 48 | class Example2: 49 | # (This argument can be passed either as "-i" or "--input_dir") 50 | input_dir: str = field("./in", alias="-i") 51 | # (This argument can be passed either as "-o", "--out", or "--output_dir") 52 | output_dir: str = field("./out", alias=["-o", "--out"]) 53 | 54 | 55 | assert parse(Example2, "-i tmp/data") == Example2(input_dir="tmp/data") 56 | assert parse(Example2, "-o tmp/data") == Example2(output_dir="tmp/data") 57 | assert parse(Example2, "--out louise") == Example2(output_dir="louise") 58 | 
assert parse(Example2, "--input_dir louise") == Example2(input_dir="louise") 59 | assert parse(Example2, "--output_dir joe/annie") == Example2(output_dir="joe/annie") 60 | assert parse(Example2, "-i input -o output") == Example2(input_dir="input", output_dir="output") 61 | 62 | 63 | # Example 3: Using other actions (store_true, store_false, store_const, etc.) 64 | 65 | 66 | @dataclass 67 | class Example3: 68 | """Examples with other actions.""" 69 | 70 | b: bool = False 71 | debug: bool = field(alias="-d", action="store_true") 72 | verbose: bool = field(alias="-v", action="store_true") 73 | 74 | cache: bool = False 75 | # cache: bool = field(default=True, "--no_cache", "store_false") 76 | # no_cache: bool = field(dest=cache, action="store_false") 77 | 78 | 79 | parser = ArgumentParser() 80 | parser.add_arguments(Example3, "example") 81 | args = parser.parse_args() 82 | example = args.example 83 | print(example) 84 | delattr(args, "example") 85 | assert not vars(args) 86 | -------------------------------------------------------------------------------- /examples/dataclasses/README.md: -------------------------------------------------------------------------------- 1 | # Dataclasses 2 | 3 | These are simple examples showing how to use `@dataclass` to create argument classes. 4 | 5 | First, take a look at the official [dataclasses module documentation](https://docs.python.org/3.7/library/dataclasses.html). 6 | 7 | With `simple-parsing`, groups of attributes are defined directly in code as dataclasses, each holding a set of related parameters. Methods can also be added to these dataclasses, which helps to promote the "Separation of Concerns" principle by keeping all the logic related to argument parsing in the same place as the arguments themselves. 
8 | 9 | ## Examples: 10 | 11 | - [dataclass_example.py](dataclass_example.py): a simple toy example showing an example of a dataclass 12 | 13 | - [hyperparameters_example.py](hyperparameters_example.py): Shows an example of an argument dataclass which also defines serialization methods. 14 | 15 | 16 | 17 | NOTE: For attributes of a mutable type (a type besides `int`, `float`, `bool` or `str`, such as `list`, `tuple`, or `object` or any of their subclasses), it is highly recommended (and often required) to use the `field` function of the dataclasses module, and to specify either a default value or a default factory function. 18 | 19 | To simplify this, `simple-parsing` provides `MutableField`, a convenience function, which directly sets the passed argument as the return value of an anonymous factory function. 20 | -------------------------------------------------------------------------------- /examples/dataclasses/dataclass_example.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass, fields 3 | 4 | 5 | @dataclass 6 | class Point: 7 | """Simple class Point.""" 8 | 9 | x: float = 0.0 10 | y: float = 0.0 11 | z: float = 0.0 12 | 13 | def distance(self, other: "Point") -> float: 14 | return math.sqrt( 15 | (self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2 16 | ) 17 | 18 | 19 | p1 = Point(x=1, y=3) 20 | p2 = Point(x=1.0, y=3.0) 21 | 22 | assert p1 == p2 23 | 24 | for field in fields(p1): 25 | print(f"Field {field.name} has type {field.type} and a default value if {field.default}.") 26 | 27 | expected = """ 28 | Field x has type and a default value if 0.0. 29 | Field y has type and a default value if 0.0. 30 | Field z has type and a default value if 0.0. 
31 | """ 32 | -------------------------------------------------------------------------------- /examples/dataclasses/hyperparameters_example.py: -------------------------------------------------------------------------------- 1 | """ - Argument dataclasses can also have methods! """ 2 | import json 3 | import os 4 | from dataclasses import asdict, dataclass 5 | 6 | from simple_parsing import ArgumentParser 7 | 8 | parser = ArgumentParser() 9 | 10 | 11 | @dataclass 12 | class HyperParameters: 13 | batch_size: int = 32 14 | optimizer: str = "ADAM" 15 | learning_rate: float = 1e-4 16 | max_epochs: int = 100 17 | l1_reg: float = 1e-5 18 | l2_reg: float = 1e-5 19 | 20 | def save(self, path: str): 21 | with open(path, "w") as f: 22 | config_dict = asdict(self) 23 | json.dump(config_dict, f, indent=1) 24 | 25 | @classmethod 26 | def load(cls, path: str): 27 | with open(path) as f: 28 | config_dict = json.load(f) 29 | return cls(**config_dict) 30 | 31 | 32 | parser.add_arguments(HyperParameters, dest="hparams") 33 | 34 | args = parser.parse_args() 35 | 36 | hparams: HyperParameters = args.hparams 37 | print(hparams) 38 | expected = """ 39 | HyperParameters(batch_size=32, optimizer='ADAM', learning_rate=0.0001, max_epochs=100, l1_reg=1e-05, l2_reg=1e-05) 40 | """ 41 | 42 | # save and load from a json file: 43 | hparams.save("hyperparameters.json") 44 | _hparams = HyperParameters.load("hyperparameters.json") 45 | assert hparams == _hparams 46 | 47 | 48 | os.remove("hyperparameters.json") 49 | -------------------------------------------------------------------------------- /examples/demo.py: -------------------------------------------------------------------------------- 1 | # examples/demo.py 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing import ArgumentParser 5 | 6 | parser = ArgumentParser() 7 | parser.add_argument("--foo", type=int, default=123, help="foo help") 8 | 9 | 10 | @dataclass 11 | class Options: 12 | """Help string for this group of 
command-line arguments.""" 13 | 14 | log_dir: str # Help string for a required str argument 15 | learning_rate: float = 1e-4 # Help string for a float argument 16 | 17 | 18 | parser.add_arguments(Options, dest="options") 19 | 20 | args = parser.parse_args() 21 | print("foo:", args.foo) 22 | print("options:", args.options) 23 | -------------------------------------------------------------------------------- /examples/demo_simple.py: -------------------------------------------------------------------------------- 1 | # examples/demo_simple.py 2 | from dataclasses import dataclass 3 | 4 | import simple_parsing 5 | 6 | 7 | @dataclass 8 | class Options: 9 | """Help string for this group of command-line arguments.""" 10 | 11 | log_dir: str # Help string for a required str argument 12 | learning_rate: float = 1e-4 # Help string for a float argument 13 | 14 | 15 | options = simple_parsing.parse(Options) 16 | print(options) 17 | -------------------------------------------------------------------------------- /examples/docstrings/README.md: -------------------------------------------------------------------------------- 1 | # Docstrings 2 | 3 | A docstring can either be: 4 | 5 | - A comment on the same line as the attribute definition 6 | - A single or multi-line comment on the line(s) preceding the attribute definition 7 | - A single or multi-line docstring on the line(s) following the attribute 8 | definition, starting with either `"""` or `'''` and ending with the same token. 9 | 10 | When more than one docstring option is present, one of them is chosen to 11 | be used as the '--help' text of the attribute, according to the following ordering: 12 | 13 | 1. docstring below the attribute 14 | 2. comment above the attribute 15 | 3. inline comment 16 | 17 | NOTE: It is recommended to add blank lines between consecutive attribute 18 | assignments when using either the 'comment above' or 'docstring below' 19 | style, just for clarity.
This doesn't change anything about the output of 20 | the "--help" command. 21 | 22 | ```python 23 | """ 24 | A simple example to demonstrate the 'attribute docstrings' mechanism of simple-parsing. 25 | 26 | """ 27 | from dataclasses import dataclass, field 28 | from typing import List, Tuple 29 | 30 | from simple_parsing import ArgumentParser 31 | 32 | parser = ArgumentParser() 33 | 34 | @dataclass 35 | class DocStringsExample(): 36 | """NOTE: This block of text is the class docstring, and it will show up under 37 | the name of the class in the --help group for this set of parameters. 38 | """ 39 | 40 | attribute1: float = 1.0 41 | """docstring below, When used, this always shows up in the --help text for this attribute""" 42 | 43 | # Comment above only: this shows up in the help text, since there is no docstring below. 44 | attribute2: float = 1.0 45 | 46 | attribute3: float = 1.0 # inline comment only (this shows up in the help text, since none of the two other options are present.) 47 | 48 | # comment above 42 49 | attribute4: float = 1.0 # inline comment 50 | """docstring below (this appears in --help)""" 51 | 52 | # comment above (this appears in --help) 46 53 | attribute5: float = 1.0 # inline comment 54 | 55 | attribute6: float = 1.0 # inline comment (this appears in --help) 56 | 57 | attribute7: float = 1.0 # inline comment 58 | """docstring below (this appears in --help)""" 59 | 60 | 61 | parser.add_arguments(DocStringsExample, "example") 62 | args = parser.parse_args() 63 | ex = args.example 64 | print(ex) 65 | ``` 66 | -------------------------------------------------------------------------------- /examples/docstrings/docstrings_example.py: -------------------------------------------------------------------------------- 1 | """A simple example to demonstrate the 'attribute docstrings' mechanism of simple-parsing.""" 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing import ArgumentParser 5 | 6 | parser = ArgumentParser() 7 | 8 | 9 | 
@dataclass 10 | class DocStringsExample: 11 | """NOTE: This block of text is the class docstring, and it will show up under 12 | the name of the class in the --help group for this set of parameters. 13 | """ 14 | 15 | attribute1: float = 1.0 16 | """Docstring below, When used, this always shows up in the --help text for this attribute.""" 17 | 18 | # Comment above only: this shows up in the help text, since there is no docstring below. 19 | attribute2: float = 1.0 20 | 21 | attribute3: float = 1.0 # inline comment only (this shows up in the help text, since none of the two other options are present.) 22 | 23 | # comment above 42 24 | attribute4: float = 1.0 # inline comment 25 | """Docstring below (this appears in --help)""" 26 | 27 | # comment above (this appears in --help) 46 28 | attribute5: float = 1.0 # inline comment 29 | 30 | attribute6: float = 1.0 # inline comment (this appears in --help) 31 | 32 | attribute7: float = 1.0 # inline comment 33 | """Docstring below (this appears in --help)""" 34 | 35 | 36 | parser.add_arguments(DocStringsExample, "example") 37 | args = parser.parse_args() 38 | ex = args.example 39 | print(ex) 40 | expected = """ 41 | DocStringsExample(attribute1=1.0, attribute2=1.0, attribute3=1.0, attribute4=1.0, attribute5=1.0, attribute6=1.0, attribute7=1.0) 42 | """ 43 | -------------------------------------------------------------------------------- /examples/enums/README.md: -------------------------------------------------------------------------------- 1 | # Parsing Enums 2 | 3 | Parsing enums can be done quite simply, like so: 4 | 5 | ```python 6 | import enum 7 | from dataclasses import dataclass, field 8 | 9 | from simple_parsing import ArgumentParser 10 | 11 | parser = ArgumentParser() 12 | 13 | class Color(enum.Enum): 14 | RED = "RED" 15 | ORANGE = "ORANGE" 16 | BLUE = "BLUE" 17 | 18 | class Temperature(enum.Enum): 19 | HOT = 1 20 | WARM = 0 21 | COLD = -1 22 | MONTREAL = -35 23 | 24 | @dataclass 25 | class MyPreferences: 26 | 
"""You can use Enums""" 27 | color: Color = Color.BLUE # my favorite colour 28 | temp: Temperature = Temperature.WARM 29 | 30 | parser.add_arguments(MyPreferences, "my_preferences") 31 | args = parser.parse_args() 32 | prefs: MyPreferences = args.my_preferences 33 | print(prefs) 34 | 35 | ``` 36 | 37 | You parse most datatypes using `simple-parsing`, as the type annotation on an argument is called as a conversion function in case the type of the attribute is not a builtin type or a dataclass. 38 | -------------------------------------------------------------------------------- /examples/enums/enums_example.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing import ArgumentParser 5 | 6 | parser = ArgumentParser() 7 | 8 | 9 | class Color(enum.Enum): 10 | RED = "RED" 11 | ORANGE = "ORANGE" 12 | BLUE = "BLUE" 13 | 14 | 15 | class Temperature(enum.Enum): 16 | HOT = 1 17 | WARM = 0 18 | COLD = -1 19 | MONTREAL = -35 20 | 21 | 22 | @dataclass 23 | class MyPreferences: 24 | """You can use Enums.""" 25 | 26 | color: Color = Color.BLUE # my favorite colour 27 | temp: Temperature = Temperature.WARM 28 | 29 | 30 | parser.add_arguments(MyPreferences, "my_preferences") 31 | args = parser.parse_args() 32 | prefs: MyPreferences = args.my_preferences 33 | print(prefs) 34 | expected = """ 35 | MyPreferences(color=, temp=) 36 | """ 37 | -------------------------------------------------------------------------------- /examples/inheritance/README.md: -------------------------------------------------------------------------------- 1 | # Inheritance 2 | 3 | Say you are working on a new research project, building on top of some previous work. 4 | 5 | Let's suppose that the previous authors were smart enough to use `simple-parsing` to define their `HyperParameters` as a dataclass, potentially saving you and others a lot of work. 
All the model hyperparameters can therefore be provided directly as command-line arguments. 6 | 7 | You have a set of new hyperparameters or command-line arguments you want to add to your model. Rather than redefining the same HyperParameters over and over, wouldn't it be nice to be able to just add a few new arguments to an existing arguments dataclass? 8 | 9 | Behold, inheritance: 10 | 11 | ```python 12 | from simple_parsing import ArgumentParser 13 | from simple_parsing.helpers import JsonSerializable 14 | 15 | 16 | from dataclasses import dataclass 17 | from typing import Optional 18 | 19 | @dataclass 20 | class GANHyperParameters(JsonSerializable): 21 | batch_size: int = 32 # batch size 22 | d_steps: int = 1 # number of generator updates 23 | g_steps: int = 1 # number of discriminator updates 24 | learning_rate: float = 1e-4 25 | optimizer: str = "ADAM" 26 | 27 | 28 | @dataclass 29 | class WGANHyperParameters(GANHyperParameters): 30 | lambda_coeff: float = 10 # the lambda penalty coefficient. 31 | 32 | 33 | @dataclass 34 | class WGANGPHyperParameters(WGANHyperParameters): 35 | gp_penalty: float = 1e-6 # Gradient penalty coefficient 36 | 37 | 38 | parser = ArgumentParser() 39 | parser.add_argument( 40 | "--load_path", 41 | type=str, 42 | default=None, 43 | help="If given, the HyperParameters are read from the given file instead of from the command-line." 
44 | ) 45 | parser.add_arguments(WGANGPHyperParameters, dest="hparams") 46 | 47 | args = parser.parse_args() 48 | 49 | load_path: str = args.load_path 50 | if load_path is None: 51 | hparams: WGANGPHyperParameters = args.hparams 52 | else: 53 | hparams = WGANGPHyperParameters.load_json(load_path) 54 | print(hparams) 55 | ``` 56 | -------------------------------------------------------------------------------- /examples/inheritance/inheritance_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | from simple_parsing.helpers import Serializable 5 | 6 | 7 | @dataclass 8 | class GANHyperParameters(Serializable): 9 | batch_size: int = 32 # batch size 10 | d_steps: int = 1 # number of generator updates 11 | g_steps: int = 1 # number of discriminator updates 12 | learning_rate: float = 1e-4 13 | optimizer: str = "ADAM" 14 | 15 | 16 | @dataclass 17 | class WGANHyperParameters(GANHyperParameters): 18 | lambda_coeff: float = 10 # the lambda penalty coefficient. 
19 | 20 | 21 | @dataclass 22 | class WGANGPHyperParameters(WGANHyperParameters): 23 | gp_penalty: float = 1e-6 # Gradient penalty coefficient 24 | 25 | 26 | parser = ArgumentParser() 27 | parser.add_argument( 28 | "--load_path", 29 | type=str, 30 | default=None, 31 | help="If given, the HyperParameters are read from the given file instead of from the command-line.", 32 | ) 33 | parser.add_arguments(WGANGPHyperParameters, dest="hparams") 34 | 35 | args = parser.parse_args() 36 | 37 | load_path: str = args.load_path 38 | if load_path is None: 39 | hparams: WGANGPHyperParameters = args.hparams 40 | else: 41 | hparams = WGANGPHyperParameters.load_json(load_path) 42 | 43 | assert hparams == WGANGPHyperParameters( 44 | batch_size=32, 45 | d_steps=1, 46 | g_steps=1, 47 | learning_rate=0.0001, 48 | optimizer="ADAM", 49 | lambda_coeff=10, 50 | gp_penalty=1e-06, 51 | ) 52 | -------------------------------------------------------------------------------- /examples/inheritance/ml_inheritance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser, choice 4 | from simple_parsing.helpers import Serializable 5 | 6 | # import tensorflow as tf 7 | 8 | 9 | class GAN: 10 | @dataclass 11 | class HyperParameters(Serializable): 12 | """Hyperparameters of the Generator and Discriminator networks.""" 13 | 14 | learning_rate: float = 1e-4 15 | optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") 16 | n_disc_iters_per_g_iter: int = ( 17 | 1 # Number of Discriminator iterations per Generator iteration. 
18 | ) 19 | 20 | def __init__(self, hparams: HyperParameters): 21 | self.hparams = hparams 22 | 23 | 24 | class WGAN(GAN): 25 | """Wasserstein GAN.""" 26 | 27 | @dataclass 28 | class HyperParameters(GAN.HyperParameters): 29 | e_drift: float = 1e-4 30 | """Coefficient from the progan authors which penalizes critic outputs for having a large 31 | magnitude.""" 32 | 33 | def __init__(self, hparams: HyperParameters): 34 | self.hparams = hparams 35 | 36 | 37 | class WGANGP(WGAN): 38 | """Wasserstein GAN with Gradient Penalty.""" 39 | 40 | @dataclass 41 | class HyperParameters(WGAN.HyperParameters): 42 | e_drift: float = 1e-4 43 | """Coefficient from the progan authors which penalizes critic outputs for having a large 44 | magnitude.""" 45 | gp_coefficient: float = 10.0 46 | """Multiplying coefficient for the gradient penalty term of the loss equation. 47 | 48 | (10.0 is the default value, and was used by the PROGAN authors.) 49 | """ 50 | 51 | def __init__(self, hparams: HyperParameters): 52 | self.hparams: WGANGP.HyperParameters = hparams 53 | print(self.hparams.gp_coefficient) 54 | 55 | 56 | parser = ArgumentParser() 57 | parser.add_arguments(WGANGP.HyperParameters, "hparams") 58 | args = parser.parse_args() 59 | print(args.hparams) 60 | expected = """ 61 | WGANGP.HyperParameters(learning_rate=0.0001, optimizer='ADAM', n_disc_iters_per_g_iter=1, e_drift=0.0001, gp_coefficient=10.0) 62 | """ 63 | -------------------------------------------------------------------------------- /examples/inheritance/ml_inheritance_2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from simple_parsing import ArgumentParser, choice 4 | from simple_parsing.helpers import Serializable, list_field 5 | 6 | # import tensorflow as tf 7 | 8 | 9 | @dataclass 10 | class ConvBlock(Serializable): 11 | """A Block of Conv Layers.""" 12 | 13 | n_layers: int = 4 # number of layers 14 | n_filters: list[int] = 
list_field(16, 32, 64, 64) # filters per layer 15 | 16 | 17 | @dataclass 18 | class GeneratorHParams(ConvBlock): 19 | """Settings of the Generator model.""" 20 | 21 | optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") 22 | 23 | 24 | @dataclass 25 | class DiscriminatorHParams(ConvBlock): 26 | """Settings of the Discriminator model.""" 27 | 28 | optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") 29 | 30 | 31 | @dataclass 32 | class GanHParams(Serializable): 33 | """Hyperparameters of the Generator and Discriminator networks.""" 34 | 35 | gen: GeneratorHParams 36 | disc: DiscriminatorHParams 37 | learning_rate: float = 1e-4 38 | n_disc_iters_per_g_iter: int = 1 # Number of Discriminator iterations per Generator iteration. 39 | 40 | 41 | class GAN: 42 | """Generative Adversarial Network Model.""" 43 | 44 | def __init__(self, hparams: GanHParams): 45 | self.hparams = hparams 46 | 47 | 48 | @dataclass 49 | class WGanHParams(GanHParams): 50 | """HParams of the WGAN model.""" 51 | 52 | e_drift: float = 1e-4 53 | """Coefficient from the progan authors which penalizes critic outputs for having a large 54 | magnitude.""" 55 | 56 | 57 | class WGAN(GAN): 58 | """Wasserstein GAN.""" 59 | 60 | def __init__(self, hparams: WGanHParams): 61 | self.hparams = hparams 62 | 63 | 64 | @dataclass 65 | class CriticHParams(DiscriminatorHParams): 66 | """HyperParameters specific to a Critic.""" 67 | 68 | lambda_coefficient: float = 1e-5 69 | 70 | 71 | @dataclass 72 | class WGanGPHParams(WGanHParams): 73 | """Hyperparameters of the WGAN with Gradient Penalty.""" 74 | 75 | e_drift: float = 1e-4 76 | """Coefficient from the progan authors which penalizes critic outputs for having a large 77 | magnitude.""" 78 | gp_coefficient: float = 10.0 79 | """Multiplying coefficient for the gradient penalty term of the loss equation. 80 | 81 | (10.0 is the default value, and was used by the PROGAN authors.) 
82 | """ 83 | disc: CriticHParams = field(default_factory=CriticHParams) 84 | # overwrite the usual 'disc' field of the WGanHParams dataclass. 85 | """ Parameters of the Critic. """ 86 | 87 | 88 | class WGANGP(WGAN): 89 | """Wasserstein GAN with Gradient Penalty.""" 90 | 91 | def __init__(self, hparams: WGanGPHParams): 92 | self.hparams = hparams 93 | 94 | 95 | parser = ArgumentParser() 96 | parser.add_arguments(WGanGPHParams, "hparams") 97 | args = parser.parse_args() 98 | 99 | print(args.hparams) 100 | 101 | expected = """ 102 | WGanGPHParams(gen=GeneratorHParams(n_layers=4, n_filters=[16, 32, 64, 64], \ 103 | optimizer='ADAM'), disc=CriticHParams(n_layers=4, n_filters=[16, 32, 64, 64], \ 104 | optimizer='ADAM', lambda_coefficient=1e-05), learning_rate=0.0001, \ 105 | n_disc_iters_per_g_iter=1, e_drift=0.0001, gp_coefficient=10.0) 106 | """ 107 | -------------------------------------------------------------------------------- /examples/merging/README.md: -------------------------------------------------------------------------------- 1 | # Parsing Multiple Dataclasses with Merging 2 | 3 | Here, we demonstrate parsing multiple classes each of which has a list attribute. 4 | There are a few options for doing this. For example, if we want to let each instance have a distinct prefix for its arguments, we could use the ConflictResolution.AUTO option. 5 | 6 | In the following examples, we instead want to create a multiple instances of the argument dataclasses from the command line, but we don't want to have a different prefix for each instance. 7 | 8 | To do this, we pass the `ConflictResolution.ALWAYS_MERGE` option to the argument parser constructor. This creates a single argument for each attribute that will be set as multiple (i.e., if the attribute was of type `str`, the argument becomes a list of `str`, one for each class instance). 9 | 10 | For more info, check out the docstring of the `ConflictResolution` enum. 
11 | 12 | ## Examples: 13 | 14 | - [multiple_example.py](multiple_example.py) 15 | - [multiple_lists_example.py](multiple_lists_example.py) 16 | -------------------------------------------------------------------------------- /examples/merging/multiple_example.py: -------------------------------------------------------------------------------- 1 | """Example of how to create multiple instances of a class from the command-line. 2 | 3 | # NOTE: If your dataclass has a list attribute, and you wish to parse multiple instances of that class from the command line, 4 | # simply enclose each list with single or double quotes. 5 | # For this example, something like: 6 | >>> python examples/multiple_instances_example.py --num_instances 2 --foo 1 2 --list_of_ints "3 5 7" "4 6 10" 7 | """ 8 | from dataclasses import dataclass 9 | 10 | from simple_parsing import ArgumentParser, ConflictResolution 11 | 12 | parser = ArgumentParser(conflict_resolution=ConflictResolution.ALWAYS_MERGE) 13 | 14 | 15 | @dataclass 16 | class Config: 17 | """A class which groups related parameters.""" 18 | 19 | run_name: str = "train" # Some parameter for the run name. 20 | some_int: int = 10 # an optional int parameter. 21 | log_dir: str = "logs" # an optional string parameter. 22 | """The logging directory to use. 23 | 24 | (This is an attribute docstring for the log_dir attribute, and shows up when using the "--help" 25 | argument!) 
26 | """ 27 | 28 | 29 | parser.add_arguments(Config, "train_config") 30 | parser.add_arguments(Config, "valid_config") 31 | 32 | args = parser.parse_args() 33 | 34 | train_config: Config = args.train_config 35 | valid_config: Config = args.valid_config 36 | 37 | print(train_config, valid_config, sep="\n") 38 | 39 | expected = """ 40 | Config(run_name='train', some_int=10, log_dir='logs') 41 | Config(run_name='train', some_int=10, log_dir='logs') 42 | """ 43 | -------------------------------------------------------------------------------- /examples/merging/multiple_lists_example.py: -------------------------------------------------------------------------------- 1 | """Here, we demonstrate parsing multiple classes each of which has a list attribute. There are a 2 | few options for doing this. For example, if we want to let each instance have a distinct prefix for 3 | its arguments, we could use the ConflictResolution.AUTO option. 4 | 5 | Here, we want to create a few instances of `CNNStack` from the command line, 6 | but don't want to have a different prefix for each instance. 7 | To do this, we pass the `ConflictResolution.ALWAYS_MERGE` option to the argument parser constructor. 8 | This creates a single argument for each attribute, that will be set as multiple 9 | (i.e., if the attribute is a `str`, the argument becomes a list of `str`, one for each class instance). 10 | 11 | For more info, check out the docstring of the `ConflictResolution` enum. 
12 | """ 13 | 14 | from dataclasses import dataclass, field 15 | 16 | from simple_parsing import ArgumentParser, ConflictResolution 17 | 18 | 19 | @dataclass 20 | class CNNStack: 21 | name: str = "stack" 22 | num_layers: int = 3 23 | kernel_sizes: tuple[int, int, int] = (7, 5, 5) 24 | num_filters: list[int] = field(default_factory=[32, 64, 64].copy) 25 | 26 | 27 | parser = ArgumentParser(conflict_resolution=ConflictResolution.ALWAYS_MERGE) 28 | 29 | num_stacks = 3 30 | for i in range(num_stacks): 31 | parser.add_arguments(CNNStack, dest=f"stack_{i}", default=CNNStack()) 32 | 33 | args = parser.parse_args() 34 | stack_0 = args.stack_0 35 | stack_1 = args.stack_1 36 | stack_2 = args.stack_2 37 | 38 | # BUG: When the list length and the number of instances to parse is the same, 39 | # AND there is no default value passed to `add_arguments`, it gets parsed as 40 | # multiple lists each with only one element, rather than duplicating the field's 41 | # default value correctly. 42 | 43 | print(stack_0, stack_1, stack_2, sep="\n") 44 | expected = """\ 45 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64]) 46 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64]) 47 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64]) 48 | """ 49 | 50 | # Example of how to pass different lists for each instance: 51 | 52 | args = parser.parse_args("--num_filters [1,2,3] [4,5,6] [7,8,9] ".split()) 53 | stack_0 = args.stack_0 54 | stack_1 = args.stack_1 55 | stack_2 = args.stack_2 56 | 57 | # BUG: TODO: Fix the multiple + list attributes bug. 
58 | print(stack_0, stack_1, stack_2, sep="\n") 59 | expected += """\ 60 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[1, 2, 3]) 61 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[4, 5, 6]) 62 | CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[7, 8, 9]) 63 | """ 64 | -------------------------------------------------------------------------------- /examples/partials/README.md: -------------------------------------------------------------------------------- 1 | # Partials - Configuring arbitrary classes / callables 2 | -------------------------------------------------------------------------------- /examples/partials/partials_example.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | 5 | from simple_parsing import ArgumentParser 6 | from simple_parsing.helpers import subgroups 7 | from simple_parsing.helpers.partial import Partial, config_for 8 | 9 | 10 | # Suppose we want to choose between the Adam and SGD optimizers from PyTorch: 11 | # (NOTE: We don't import pytorch here, so we just create the types to illustrate) 12 | class Optimizer: 13 | def __init__(self, params): 14 | ... 
15 | 16 | 17 | class Adam(Optimizer): 18 | def __init__( 19 | self, 20 | params, 21 | lr: float = 3e-4, 22 | beta1: float = 0.9, 23 | beta2: float = 0.999, 24 | eps: float = 1e-08, 25 | ): 26 | self.params = params 27 | self.lr = lr 28 | self.beta1 = beta1 29 | self.beta2 = beta2 30 | self.eps = eps 31 | 32 | 33 | class SGD(Optimizer): 34 | def __init__( 35 | self, 36 | params, 37 | lr: float = 3e-4, 38 | weight_decay: float | None = None, 39 | momentum: float = 0.9, 40 | eps: float = 1e-08, 41 | ): 42 | self.params = params 43 | self.lr = lr 44 | self.weight_decay = weight_decay 45 | self.momentum = momentum 46 | self.eps = eps 47 | 48 | 49 | # Dynamically create a dataclass that will be used for the above type: 50 | # NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a 51 | # required argument. 52 | # AdamConfig = Partial[Adam] # would treat 'params' as a required argument. 53 | # SGDConfig = Partial[SGD] # same here 54 | AdamConfig: type[Partial[Adam]] = config_for(Adam, ignore_args="params") 55 | SGDConfig: type[Partial[SGD]] = config_for(SGD, ignore_args="params") 56 | 57 | 58 | @dataclass 59 | class Config: 60 | # Which optimizer to use. 
61 | optimizer: Partial[Optimizer] = subgroups( 62 | { 63 | "sgd": SGDConfig, 64 | "adam": AdamConfig, 65 | }, 66 | default_factory=AdamConfig, 67 | ) 68 | 69 | 70 | parser = ArgumentParser() 71 | parser.add_arguments(Config, "config") 72 | args = parser.parse_args() 73 | 74 | 75 | config: Config = args.config 76 | print(config) 77 | expected = "Config(optimizer=AdamConfig(lr=0.0003, beta1=0.9, beta2=0.999, eps=1e-08))" 78 | 79 | my_model_parameters = [123] # nn.Sequential(...).parameters() 80 | 81 | optimizer = config.optimizer(params=my_model_parameters) 82 | print(vars(optimizer)) 83 | expected += """ 84 | {'params': [123], 'lr': 0.0003, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08} 85 | """ 86 | -------------------------------------------------------------------------------- /examples/prefixing/README.md: -------------------------------------------------------------------------------- 1 | # Prefixing Mechanics 2 | 3 | Before starting to use multiple dataclasses or nesting them, it is good to first understand the prefixing mechanism used by `simple-parsing`. 4 | 5 | What's important to consider is that in `argparse`, arguments can only be provided as a "flat" list. 6 | 7 | In order to be able to "reuse" arguments and parse multiple instances of the same class from the command-line, we therefore have to choose between these options: 8 | 9 | 1. Give each individual argument a differentiating prefix; (default) 10 | 2. Disallow the reuse of arguments; 11 | 3. Parse a List of values instead of a single value, and later redistribute the value to the instances. 12 | 13 | You can control which of these three behaviours the parser is to use by setting the `conflict_resolution` argument of `simple_parsing.ArgumentParser`'s `__init__` method. 14 | 15 | - For option 1, use either the `ConflictResolution.AUTO` or `ConflictResolution.EXPLICIT` options 16 | - For option 2, use the `ConflictResolution.NONE` option. 
17 | - For option 3, use the `ConflictResolution.ALWAYS_MERGE` option. 18 | -------------------------------------------------------------------------------- /examples/prefixing/manual_prefix_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | 6 | @dataclass 7 | class Config: 8 | """Simple example of a class that can be reused.""" 9 | 10 | log_dir: str = "logs" 11 | 12 | 13 | parser = ArgumentParser() 14 | parser.add_arguments(Config, "train_config", prefix="train_") 15 | parser.add_arguments(Config, "valid_config", prefix="valid_") 16 | args = parser.parse_args() 17 | print(vars(args)) 18 | -------------------------------------------------------------------------------- /examples/serialization/README.md: -------------------------------------------------------------------------------- 1 | # Serialization 2 | 3 | The `Serializable` class makes it easy to serialize any dataclass to and from json or yaml. 4 | It is also very easy to add support for serializing/deserializing your own custom types! 5 | 6 | ```python 7 | >>> from simple_parsing.helpers import Serializable 8 | >>> from dataclasses import dataclass 9 | >>> 10 | >>> @dataclass 11 | ... class Person(Serializable): 12 | ... name: str = "Bob" 13 | ... age: int = 20 14 | ... 15 | >>> @dataclass 16 | ... class Student(Person): 17 | ... domain: str = "Computer Science" 18 | ... average_grade: float = 0.80 19 | ... 20 | >>> # Serialization: 21 | ... # We can dump to yaml or json: 22 | ... charlie = Person(name="Charlie") 23 | >>> print(charlie.dumps_yaml()) 24 | age: 20 25 | name: Charlie 26 | 27 | >>> print(charlie.dumps_json()) 28 | {"name": "Charlie", "age": 20} 29 | >>> print(charlie.dumps()) # JSON by default 30 | {"name": "Charlie", "age": 20} 31 | >>> # Deserialization: 32 | ... 
bob = Student() 33 | >>> print(bob) 34 | Student(name='Bob', age=20, domain='Computer Science', average_grade=0.8) 35 | >>> bob.save("bob.yaml") 36 | >>> # Can load a Student from the base class: this will use the first subclass 37 | ... # that has all the required fields. 38 | ... _bob = Person.load("bob.yaml", drop_extra_fields=False) 39 | >>> assert isinstance(_bob, Student) 40 | >>> assert _bob == bob 41 | ``` 42 | 43 | ## Adding custom types 44 | 45 | Register a new encoding function using `encode`, and a new decoding function using `register_decoding_fn` 46 | 47 | For example: Consider the same example as above, but we add a Tensor attribute from `pytorch`. 48 | 49 | ```python 50 | from dataclasses import dataclass 51 | from typing import List 52 | 53 | import torch 54 | from torch import Tensor 55 | 56 | from simple_parsing.helpers import Serializable 57 | from simple_parsing.helpers.serialization import encode, register_decoding_fn 58 | 59 | expected: str = "" 60 | 61 | @dataclass 62 | class Person(Serializable): 63 | name: str = "Bob" 64 | age: int = 20 65 | t: Tensor = torch.arange(4) 66 | 67 | 68 | @dataclass 69 | class Student(Person): 70 | domain: str = "Computer Science" 71 | average_grade: float = 0.80 72 | 73 | 74 | @encode.register 75 | def encode_tensor(obj: Tensor) -> List: 76 | """ We choose to encode a tensor as a list, for instance """ 77 | return obj.tolist() 78 | 79 | # We will use `torch.as_tensor` as our decoding function 80 | register_decoding_fn(Tensor, torch.as_tensor) 81 | 82 | # Serialization: 83 | # We can dump to yaml or json: 84 | charlie = Person(name="Charlie") 85 | print(charlie.dumps_yaml()) 86 | expected += """\ 87 | age: 20 88 | name: Charlie 89 | t: 90 | - 0 91 | - 1 92 | - 2 93 | - 3 94 | 95 | """ 96 | 97 | 98 | print(charlie.dumps_json()) 99 | expected += """\ 100 | {"name": "Charlie", "age": 20, "t": [0, 1, 2, 3]} 101 | """ 102 | 103 | # Deserialization: 104 | bob = Student() 105 | print(bob) 106 | expected += """\ 107 | 
Student(name='Bob', age=20, t=tensor([0, 1, 2, 3]), domain='Computer Science', average_grade=0.8) 108 | """ 109 | 110 | # Can load a Student from the base class: this will use the first subclass 111 | # that has all the required fields. 112 | bob.save("bob.yaml") 113 | _bob = Person.load("bob.yaml", drop_extra_fields=False) 114 | assert isinstance(_bob, Student), _bob 115 | # Note: using _bob == bob doesn't work here because of Tensor comparison, 116 | # But this basically shows the same thing as the previous example. 117 | assert str(_bob) == str(bob) 118 | 119 | ``` 120 | -------------------------------------------------------------------------------- /examples/serialization/bob.json: -------------------------------------------------------------------------------- 1 | {"name": "Bob", "age": 20, "domain": "Computer Science", "average_grade": 0.8} 2 | -------------------------------------------------------------------------------- /examples/serialization/custom_types_example.py: -------------------------------------------------------------------------------- 1 | # Cleaning up 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from simple_parsing.helpers import Serializable 9 | from simple_parsing.helpers.serialization import encode, register_decoding_fn 10 | 11 | expected: str = "" 12 | 13 | 14 | @dataclass 15 | class Person(Serializable): 16 | name: str = "Bob" 17 | age: int = 20 18 | t: Tensor = torch.arange(4) 19 | 20 | 21 | @dataclass 22 | class Student(Person): 23 | domain: str = "Computer Science" 24 | average_grade: float = 0.80 25 | 26 | 27 | @encode.register 28 | def encode_tensor(obj: Tensor) -> list: 29 | """We choose to encode a tensor as a list, for instance.""" 30 | return obj.tolist() 31 | 32 | 33 | # We will use `torch.as_tensor` as our decoding function 34 | register_decoding_fn(Tensor, torch.as_tensor) 35 | 36 | # Serialization: 37 | # We can dump to yaml or json: 38 | charlie = 
Person(name="Charlie") 39 | print(charlie.dumps_yaml()) 40 | expected += """\ 41 | age: 20 42 | name: Charlie 43 | t: 44 | - 0 45 | - 1 46 | - 2 47 | - 3 48 | 49 | """ 50 | 51 | 52 | print(charlie.dumps_json()) 53 | expected += """\ 54 | {"name": "Charlie", "age": 20, "t": [0, 1, 2, 3]} 55 | """ 56 | 57 | # Deserialization: 58 | bob = Student() 59 | print(bob) 60 | expected += """\ 61 | Student(name='Bob', age=20, t=tensor([0, 1, 2, 3]), domain='Computer Science', average_grade=0.8) 62 | """ 63 | 64 | # Can load a Student from the base class: this will use the first subclass 65 | # that has all the required fields. 66 | bob.save("bob.yaml") 67 | _bob = Person.load("bob.yaml", drop_extra_fields=False) 68 | assert isinstance(_bob, Student), _bob 69 | # Note: using _bob == bob doesn't work here because of Tensor comparison, 70 | # But this basically shows the same thing as the previous example. 71 | assert str(_bob) == str(bob) 72 | 73 | os.remove("bob.yaml") 74 | -------------------------------------------------------------------------------- /examples/serialization/serialization_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from simple_parsing.helpers import Serializable\n", 10 | "from dataclasses import dataclass\n", 11 | "\n", 12 | "@dataclass\n", 13 | "class Person(Serializable):\n", 14 | " name: str = \"Bob\"\n", 15 | " age: int = 20\n", 16 | "\n", 17 | "@dataclass\n", 18 | "class Student(Person):\n", 19 | " domain: str = \"Computer Science\"\n", 20 | " average_grade: float = 0.80\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "tags": [] 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "\n", 32 | "# Serialization:\n", 33 | "# We can dump to yaml or json:\n", 34 | "charlie = Person(name=\"Charlie\")\n", 35 | 
"print(charlie.dumps_yaml())\n", 36 | "print(charlie.dumps_json())\n", 37 | "print(charlie.dumps()) # JSON by default" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "# Deserialization:\n", 49 | "bob = Student()\n", 50 | "bob.save(\"bob.yaml\")\n", 51 | "bob.save(\"bob.json\")\n", 52 | "# Can load a Student from the base class: this will use the first subclass\n", 53 | "# that has all the required fields.\n", 54 | "_bob = Person.load(\"bob.yaml\", drop_extra_fields=False)\n", 55 | "assert isinstance(_bob, Student), _bob\n", 56 | "assert _bob == bob" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.6.10-final" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 2 87 | } 88 | -------------------------------------------------------------------------------- /examples/serialization/serialization_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing.helpers import Serializable 5 | 6 | 7 | @dataclass 8 | class Person(Serializable): 9 | name: str = "Bob" 10 | age: int = 20 11 | 12 | 13 | @dataclass 14 | class Student(Person): 15 | domain: str = "Computer Science" 16 | average_grade: float = 0.80 17 | 18 | 19 | expected: str = "" 20 | 21 | # Serialization: 22 | # We can dump to yaml or json: 23 | charlie = Person(name="Charlie") 24 | print(charlie.dumps_yaml()) 
25 | expected += """\ 26 | age: 20 27 | name: Charlie 28 | 29 | """ 30 | print(charlie.dumps_json()) 31 | expected += """\ 32 | {"name": "Charlie", "age": 20} 33 | """ 34 | print(charlie.dumps()) # JSON by default 35 | expected += """\ 36 | {"name": "Charlie", "age": 20} 37 | """ 38 | # Deserialization: 39 | bob = Student() 40 | print(bob) 41 | expected += """\ 42 | Student(name='Bob', age=20, domain='Computer Science', average_grade=0.8) 43 | """ 44 | 45 | bob.save("bob.yaml") 46 | # Can load a Student from the base class: this will use the first subclass 47 | # that has all the required fields. 48 | _bob = Person.load("bob.yaml", drop_extra_fields=False) 49 | assert isinstance(_bob, Student), _bob 50 | assert _bob == bob 51 | 52 | # Cleaning up 53 | 54 | os.remove("bob.yaml") 55 | -------------------------------------------------------------------------------- /examples/simple/_before.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 2 | 3 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 4 | 5 | group = parser.add_argument_group( 6 | title="Options", description="Set of options for the training of a Model." 
7 | ) 8 | group.add_argument("--num_layers", default=4, help="Number of layers to use") 9 | group.add_argument("--num_units", default=64, help="Number of units per layer") 10 | group.add_argument("--learning_rate", default=0.001, help="Learning rate to use") 11 | group.add_argument( 12 | "--optimizer", 13 | default="ADAM", 14 | choices=["ADAM", "SGD", "RMSPROP"], 15 | help="Which optimizer to use", 16 | ) 17 | 18 | args = parser.parse_args() 19 | print(args) 20 | expected = """ 21 | Namespace(learning_rate=0.001, num_layers=4, num_units=64, optimizer='ADAM') 22 | """ 23 | 24 | parser.print_help() 25 | expected += """ 26 | usage: _before.py [-h] [--num_layers NUM_LAYERS] [--num_units NUM_UNITS] 27 | [--learning_rate LEARNING_RATE] 28 | [--optimizer {ADAM,SGD,RMSPROP}] 29 | 30 | optional arguments: 31 | -h, --help show this help message and exit 32 | 33 | Options: 34 | Set of options for the training of a Model. 35 | 36 | --num_layers NUM_LAYERS 37 | Number of layers to use (default: 4) 38 | --num_units NUM_UNITS 39 | Number of units per layer (default: 64) 40 | --learning_rate LEARNING_RATE 41 | Learning rate to use (default: 0.001) 42 | --optimizer {ADAM,SGD,RMSPROP} 43 | Which optimizer to use (default: ADAM) 44 | """ 45 | -------------------------------------------------------------------------------- /examples/simple/basic.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | 6 | @dataclass 7 | class HParams: 8 | """Set of options for the training of a Model.""" 9 | 10 | num_layers: int = 4 11 | num_units: int = 64 12 | optimizer: str = "ADAM" 13 | learning_rate: float = 0.001 14 | 15 | 16 | parser = ArgumentParser() 17 | parser.add_arguments(HParams, dest="hparams") 18 | 19 | 20 | args = parser.parse_args() 21 | 22 | 23 | print(args.hparams) 24 | expected = """ 25 | HParams(num_layers=4, num_units=64, optimizer='ADAM', 
learning_rate=0.001) 26 | """ 27 | 28 | 29 | parser.print_help() 30 | expected += """ 31 | usage: basic.py [-h] [--num_layers int] [--num_units int] [--optimizer str] 32 | [--learning_rate float] 33 | 34 | optional arguments: 35 | -h, --help show this help message and exit 36 | 37 | HParams ['hparams']: 38 | Set of options for the training of a Model. 39 | 40 | --num_layers int (default: 4) 41 | --num_units int (default: 64) 42 | --optimizer str (default: ADAM) 43 | --learning_rate float 44 | (default: 0.001) 45 | """ 46 | 47 | 48 | print(parser.equivalent_argparse_code()) 49 | expected += """ 50 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | 52 | group = parser.add_argument_group(title="HParams ['hparams']", description="Set of options for the training of a Model.") 53 | group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 4, 'help': ' '}) 54 | group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 64, 'help': ' '}) 55 | group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '}) 56 | group.add_argument(*['--learning_rate'], **{'type': float, 'required': False, 'dest': 'hparams.learning_rate', 'default': 0.001, 'help': ' '}) 57 | 58 | args = parser.parse_args() 59 | print(args) 60 | """ 61 | -------------------------------------------------------------------------------- /examples/simple/choice.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser, choice 4 | 5 | 6 | @dataclass 7 | class HParams: 8 | """Set of options for the training of a Model.""" 9 | 10 | num_layers: int = 4 11 | num_units: int = 64 12 | optimizer: str = choice("ADAM", "SGD", "RMSPROP", default="ADAM") 13 | learning_rate: float = 0.001 14 | 15 | 16 | 
parser = ArgumentParser() 17 | parser.add_arguments(HParams, dest="hparams") 18 | args = parser.parse_args() 19 | 20 | print(args.hparams) 21 | expected = """ 22 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001) 23 | """ 24 | 25 | parser.print_help() 26 | expected += """ 27 | usage: choice.py [-h] [--num_layers int] [--num_units int] 28 | [--optimizer {ADAM,SGD,RMSPROP}] [--learning_rate float] 29 | 30 | optional arguments: 31 | -h, --help show this help message and exit 32 | 33 | HParams ['hparams']: 34 | Set of options for the training of a Model. 35 | 36 | --num_layers int (default: 4) 37 | --num_units int (default: 64) 38 | --optimizer {ADAM,SGD,RMSPROP} 39 | (default: ADAM) 40 | --learning_rate float 41 | (default: 0.001) 42 | """ 43 | -------------------------------------------------------------------------------- /examples/simple/flag.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | from simple_parsing.helpers import flag 5 | 6 | 7 | def parse(cls, args: str = ""): 8 | """Removes some boilerplate code from the examples.""" 9 | parser = ArgumentParser() # Create an argument parser 10 | parser.add_arguments(cls, dest="hparams") # add arguments for the dataclass 11 | ns = parser.parse_args(args.split()) # parse the given `args` 12 | return ns.hparams 13 | 14 | 15 | @dataclass 16 | class HParams: 17 | """Set of options for the training of a Model.""" 18 | 19 | num_layers: int = 4 20 | num_units: int = 64 21 | optimizer: str = "ADAM" 22 | learning_rate: float = 0.001 23 | train: bool = flag(default=True, negative_prefix="--no-") 24 | 25 | 26 | # Example 1 using default flag, i.e. 
train set to True 27 | args = parse(HParams) 28 | 29 | print(args) 30 | expected = """ 31 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001, train=True) 32 | """ 33 | 34 | # Example 2 using the flags negative prefix 35 | assert parse(HParams, "--no-train") == HParams(train=False) 36 | 37 | 38 | # showing what --help outputs 39 | parser = ArgumentParser() # Create an argument parser 40 | parser.add_arguments(HParams, dest="hparams") # add arguments for the dataclass 41 | parser.print_help() 42 | expected += """ 43 | usage: flag.py [-h] [--num_layers int] [--num_units int] [--optimizer str] 44 | [--learning_rate float] [--train bool] 45 | 46 | optional arguments: 47 | -h, --help show this help message and exit 48 | 49 | HParams ['hparams']: 50 | Set of options for the training of a Model. 51 | 52 | --num_layers int (default: 4) 53 | --num_units int (default: 64) 54 | --optimizer str (default: ADAM) 55 | --learning_rate float 56 | (default: 0.001) 57 | --train bool, --no-train bool 58 | (default: True) 59 | """ 60 | -------------------------------------------------------------------------------- /examples/simple/help.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | 6 | @dataclass 7 | class HParams: 8 | """Set of options for the training of a ML Model. 9 | 10 | Some more detailed description can be placed here, and will show-up in 11 | the auto-generated "--help" text. 12 | 13 | Some other **cool** uses for this space: 14 | - Provide links to previous works: (easy to click on from the command-line) 15 | - MAML: https://arxiv.org/abs/1703.03400 16 | - google: https://www.google.com 17 | - This can also interact nicely with documentation tools like Sphinx! 18 | For instance, you could add links to other parts of your documentation. 
19 | 20 | The default HelpFormatter used by `simple_parsing` will keep the formatting 21 | of this section intact, will add an indicator of the default values, and 22 | will use the name of the attribute's type as the metavar in the help string. 23 | For more info, check out the `SimpleFormatter` class found in 24 | ./simple_parsing/utils.py 25 | """ 26 | 27 | num_layers: int = 4 # Number of layers in the model. 28 | num_units: int = 64 # Number of units (neurons) per layer. 29 | optimizer: str = "ADAM" # Which optimizer to use. 30 | learning_rate: float = 0.001 # Learning_rate used by the optimizer. 31 | 32 | alpha: float = 0.05 # TODO: Tune this. (This doesn't appear in '--help') 33 | """A detailed description of this new 'alpha' parameter, which can potentially span multiple 34 | lines.""" 35 | 36 | 37 | parser = ArgumentParser() 38 | parser.add_arguments(HParams, dest="hparams") 39 | args = parser.parse_args() 40 | 41 | print(args.hparams) 42 | expected = """ 43 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001, alpha=0.05) 44 | """ 45 | 46 | parser.print_help() 47 | expected += """ 48 | usage: help.py [-h] [--num_layers int] [--num_units int] [--optimizer str] 49 | [--learning_rate float] [--alpha float] 50 | 51 | optional arguments: 52 | -h, --help show this help message and exit 53 | 54 | HParams ['hparams']: 55 | Set of options for the training of a ML Model. 56 | 57 | Some more detailed description can be placed here, and will show-up in 58 | the auto-generated "--help" text. 59 | 60 | Some other **cool** uses for this space: 61 | - Provide links to previous works: (easy to click on from the command-line) 62 | - MAML: https://arxiv.org/abs/1703.03400 63 | - google: https://www.google.com 64 | - This can also interact nicely with documentation tools like Sphinx! 65 | For instance, you could add links to other parts of your documentation. 
66 | 67 | The default HelpFormatter used by `simple_parsing` will keep the formatting 68 | of this section intact, will add an indicator of the default values, and 69 | will use the name of the attribute's type as the metavar in the help string. 70 | For more info, check out the `SimpleFormatter` class found in 71 | ./simple_parsing/utils.py 72 | 73 | 74 | --num_layers int Number of layers in the model. (default: 4) 75 | --num_units int Number of units (neurons) per layer. (default: 64) 76 | --optimizer str Which optimizer to use. (default: ADAM) 77 | --learning_rate float 78 | Learning_rate used by the optimizer. (default: 0.001) 79 | --alpha float A detailed description of this new 'alpha' parameter, 80 | which can potentially span multiple lines. (default: 81 | 0.05) 82 | """ 83 | -------------------------------------------------------------------------------- /examples/simple/inheritance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | 6 | @dataclass 7 | class Method: 8 | """Set of options for the training of a Model.""" 9 | 10 | num_layers: int = 4 11 | num_units: int = 64 12 | optimizer: str = "ADAM" 13 | learning_rate: float = 0.001 14 | 15 | 16 | @dataclass 17 | class MAML(Method): 18 | """Overwrites some of the default values and adds new arguments/attributes.""" 19 | 20 | num_layers: int = 6 21 | num_units: int = 128 22 | 23 | # method 24 | name: str = "MAML" 25 | 26 | 27 | parser = ArgumentParser() 28 | parser.add_arguments(MAML, dest="hparams") 29 | args = parser.parse_args() 30 | 31 | 32 | print(args.hparams) 33 | expected = """ 34 | MAML(num_layers=6, num_units=128, optimizer='ADAM', learning_rate=0.001, name='MAML') 35 | """ 36 | 37 | parser.print_help() 38 | expected += """ 39 | usage: inheritance.py [-h] [--num_layers int] [--num_units int] 40 | [--optimizer str] [--learning_rate float] [--name str] 41 | 42 | optional 
arguments: 43 | -h, --help show this help message and exit 44 | 45 | MAML ['hparams']: 46 | Overwrites some of the default values and adds new arguments/attributes. 47 | 48 | --num_layers int (default: 6) 49 | --num_units int (default: 128) 50 | --optimizer str (default: ADAM) 51 | --learning_rate float 52 | (default: 0.001) 53 | --name str method (default: MAML) 54 | """ 55 | 56 | 57 | print(parser.equivalent_argparse_code()) 58 | expected += """ 59 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 60 | 61 | group = parser.add_argument_group(title="MAML ['hparams']", description="Overwrites some of the default values and adds new arguments/attributes.") 62 | group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 6, 'help': ' '}) 63 | group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 128, 'help': ' '}) 64 | group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '}) 65 | group.add_argument(*['--learning_rate'], **{'type': float, 'required': False, 'dest': 'hparams.learning_rate', 'default': 0.001, 'help': ' '}) 66 | group.add_argument(*['--name'], **{'type': str, 'required': False, 'dest': 'hparams.name', 'default': 'MAML', 'help': 'method'}) 67 | 68 | args = parser.parse_args() 69 | print(args) 70 | """ 71 | -------------------------------------------------------------------------------- /examples/simple/option_strings.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser, field 4 | from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode 5 | 6 | 7 | @dataclass 8 | class HParams: 9 | """Set of options for the training of a Model.""" 10 | 11 | num_layers: int = field(4, alias="-n") 12 | num_units: int = 
field(64, alias="-u") 13 | optimizer: str = field("ADAM", alias=["-o", "--opt"]) 14 | learning_rate: float = field(0.001, alias="-lr") 15 | 16 | 17 | parser = ArgumentParser() 18 | parser.add_arguments(HParams, dest="hparams") 19 | args = parser.parse_args() 20 | 21 | print(args.hparams) 22 | expected = """ 23 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001) 24 | """ 25 | 26 | parser.print_help() 27 | expected += """ 28 | usage: option_strings.py [-h] [-n int] [-u int] [-o str] [-lr float] 29 | 30 | optional arguments: 31 | -h, --help show this help message and exit 32 | 33 | HParams ['hparams']: 34 | Set of options for the training of a Model. 35 | 36 | -n int, --num_layers int 37 | (default: 4) 38 | -u int, --num_units int 39 | (default: 64) 40 | -o str, --opt str, --optimizer str 41 | (default: ADAM) 42 | -lr float, --learning_rate float 43 | (default: 0.001) 44 | """ 45 | 46 | # Now if we wanted to also be able to set the arguments using their full paths: 47 | parser = ArgumentParser(argument_generation_mode=ArgumentGenerationMode.BOTH) 48 | parser.add_arguments(HParams, dest="hparams") 49 | parser.print_help() 50 | expected += """ 51 | usage: option_strings.py [-h] [-n int] [-u int] [-o str] [-lr float] 52 | 53 | optional arguments: 54 | -h, --help show this help message and exit 55 | 56 | HParams ['hparams']: 57 | Set of options for the training of a Model. 58 | 59 | -n int, --num_layers int, --hparams.num_layers int 60 | (default: 4) 61 | -u int, --num_units int, --hparams.num_units int 62 | (default: 64) 63 | -o str, --opt str, --optimizer str, --hparams.optimizer str 64 | (default: ADAM) 65 | -lr float, --learning_rate float, --hparams.learning_rate float 66 | (default: 0.001) 67 | """ 68 | -------------------------------------------------------------------------------- /examples/simple/reuse.py: -------------------------------------------------------------------------------- 1 | """Modular and reusable! 
With SimpleParsing, you can easily add similar groups of command-line 2 | arguments by simply reusing the dataclasses you define! There is no longer need for any copy- 3 | pasting of blocks, or adding prefixes everywhere by hand. 4 | 5 | Instead, the ArgumentParser detects when more than one instance of the same `@dataclass` needs to 6 | be parsed, and automatically adds the relevant prefixes to the arguments for you. 7 | """ 8 | 9 | from dataclasses import dataclass 10 | 11 | from simple_parsing import ArgumentParser 12 | 13 | 14 | @dataclass 15 | class HParams: 16 | """Set of options for the training of a Model.""" 17 | 18 | num_layers: int = 4 19 | num_units: int = 64 20 | optimizer: str = "ADAM" 21 | learning_rate: float = 0.001 22 | 23 | 24 | parser = ArgumentParser() 25 | parser.add_arguments(HParams, dest="train") 26 | parser.add_arguments(HParams, dest="valid") 27 | args = parser.parse_args() 28 | 29 | print(args.train) 30 | print(args.valid) 31 | expected = """ 32 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001) 33 | HParams(num_layers=4, num_units=64, optimizer='ADAM', learning_rate=0.001) 34 | """ 35 | 36 | parser.print_help() 37 | expected += """ 38 | usage: reuse.py [-h] [--train.num_layers int] [--train.num_units int] 39 | [--train.optimizer str] [--train.learning_rate float] 40 | [--valid.num_layers int] [--valid.num_units int] 41 | [--valid.optimizer str] [--valid.learning_rate float] 42 | 43 | optional arguments: 44 | -h, --help show this help message and exit 45 | 46 | HParams ['train']: 47 | Set of options for the training of a Model. 48 | 49 | --train.num_layers int 50 | (default: 4) 51 | --train.num_units int 52 | (default: 64) 53 | --train.optimizer str 54 | (default: ADAM) 55 | --train.learning_rate float 56 | (default: 0.001) 57 | 58 | HParams ['valid']: 59 | Set of options for the training of a Model. 
60 | 61 | --valid.num_layers int 62 | (default: 4) 63 | --valid.num_units int 64 | (default: 64) 65 | --valid.optimizer str 66 | (default: ADAM) 67 | --valid.learning_rate float 68 | (default: 0.001) 69 | """ 70 | -------------------------------------------------------------------------------- /examples/simple/to_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict, dataclass 3 | 4 | from simple_parsing import ArgumentParser 5 | from simple_parsing.helpers import Serializable 6 | 7 | 8 | @dataclass 9 | class HParams(Serializable): 10 | """Set of options for the training of a Model.""" 11 | 12 | num_layers: int = 4 13 | num_units: int = 64 14 | optimizer: str = "ADAM" 15 | learning_rate: float = 0.001 16 | 17 | 18 | parser = ArgumentParser() 19 | parser.add_arguments(HParams, dest="hparams") 20 | args = parser.parse_args() 21 | 22 | 23 | hparams: HParams = args.hparams 24 | 25 | 26 | print(asdict(hparams)) 27 | expected = """ 28 | {'num_layers': 4, 'num_units': 64, 'optimizer': 'ADAM', 'learning_rate': 0.001} 29 | """ 30 | 31 | 32 | hparams.save_json("config.json") 33 | hparams_ = HParams.load_json("config.json") 34 | assert hparams == hparams_ 35 | 36 | 37 | os.remove("config.json") 38 | -------------------------------------------------------------------------------- /examples/subgroups/subgroups_example.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | from simple_parsing import ArgumentParser, subgroups 7 | from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode 8 | 9 | 10 | @dataclass 11 | class ModelConfig: 12 | ... 13 | 14 | 15 | @dataclass 16 | class DatasetConfig: 17 | ... 
18 | 19 | 20 | @dataclass 21 | class ModelAConfig(ModelConfig): 22 | lr: float = 3e-4 23 | optimizer: str = "Adam" 24 | betas: tuple[float, float] = (0.9, 0.999) 25 | 26 | 27 | @dataclass 28 | class ModelBConfig(ModelConfig): 29 | lr: float = 1e-3 30 | optimizer: str = "SGD" 31 | momentum: float = 1.234 32 | 33 | 34 | @dataclass 35 | class Dataset1Config(DatasetConfig): 36 | data_dir: str | Path = "data/foo" 37 | foo: bool = False 38 | 39 | 40 | @dataclass 41 | class Dataset2Config(DatasetConfig): 42 | data_dir: str | Path = "data/bar" 43 | bar: float = 1.2 44 | 45 | 46 | @dataclass 47 | class Config: 48 | # Which model to use 49 | model: ModelConfig = subgroups( 50 | {"model_a": ModelAConfig, "model_b": ModelBConfig}, 51 | default_factory=ModelAConfig, 52 | ) 53 | 54 | # Which dataset to use 55 | dataset: DatasetConfig = subgroups( 56 | {"dataset_1": Dataset1Config, "dataset_2": Dataset2Config}, 57 | default_factory=Dataset2Config, 58 | ) 59 | 60 | 61 | parser = ArgumentParser( 62 | argument_generation_mode=ArgumentGenerationMode.NESTED, nested_mode=NestedMode.WITHOUT_ROOT 63 | ) 64 | parser.add_arguments(Config, dest="config") 65 | args = parser.parse_args() 66 | 67 | config: Config = args.config 68 | 69 | print(config) 70 | expected = """ 71 | Config(model=ModelAConfig(lr=0.0003, optimizer='Adam', betas=(0.9, 0.999)), dataset=Dataset2Config(data_dir='data/bar', bar=1.2)) 72 | """ 73 | 74 | parser.print_help() 75 | expected += """ 76 | usage: subgroups_example.py [-h] [--model {model_a,model_b}] [--dataset {dataset_1,dataset_2}] [--model.lr float] [--model.optimizer str] [--model.betas float float] 77 | [--dataset.data_dir str|Path] [--dataset.bar float] 78 | 79 | options: 80 | -h, --help show this help message and exit 81 | 82 | Config ['config']: 83 | Config(model: 'ModelConfig' = ModelAConfig(lr=0.0003, optimizer='Adam', betas=(0.9, 0.999)), dataset: 'DatasetConfig' = Dataset2Config(data_dir='data/bar', bar=1.2)) 84 | 85 | --model {model_a,model_b} 86 | Which 
model to use (default: ModelAConfig(lr=0.0003, optimizer='Adam', betas=(0.9, 0.999))) 87 | --dataset {dataset_1,dataset_2} 88 | Which dataset to use (default: Dataset2Config(data_dir='data/bar', bar=1.2)) 89 | 90 | ModelAConfig ['config.model']: 91 | ModelAConfig(lr: 'float' = 0.0003, optimizer: 'str' = 'Adam', betas: 'tuple[float, float]' = (0.9, 0.999)) 92 | 93 | --model.lr float (default: 0.0003) 94 | --model.optimizer str 95 | (default: Adam) 96 | --model.betas float float 97 | (default: (0.9, 0.999)) 98 | 99 | Dataset2Config ['config.dataset']: 100 | Dataset2Config(data_dir: 'str | Path' = 'data/bar', bar: 'float' = 1.2) 101 | 102 | --dataset.data_dir str|Path 103 | (default: data/bar) 104 | --dataset.bar float (default: 1.2) 105 | """ 106 | -------------------------------------------------------------------------------- /examples/subparsers/README.md: -------------------------------------------------------------------------------- 1 | ### [(Examples Home)](../README.md) 2 | 3 | # Creating Commands with Subparsers 4 | 5 | Subparsers are one of the more advanced features of `argparse`. They allow the creation of subcommands, each having their own set of arguments. The `git` command, for instance, takes different arguments than the `pull` subcommand in `git pull`. 6 | 7 | For some more info on subparsers, check out the [argparse documentation](https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.add_subparsers). 8 | 9 | With `simple-parsing`, subparsers can easily be created by using a `Union` type annotation on a dataclass attribute. By annotating a variable with a Union type, for example `x: Union[T1, T2]`, we simply state that `x` can either be of type `T1` or `T2`. When the arguments to the `Union` type **are all dataclasses**, `simple-parsing` creates subparsers for each dataclass type, using the lowercased class name as the command name by default. 
10 | 11 | If you want to extend or change this behaviour (to have "t" and "train" map to the same training subcommand, for example), use the `subparsers` function, passing in a dictionary mapping command names to the appropriate type. 12 | 13 | 14 | 15 | ## Example: 16 | 17 | ```python 18 | from dataclasses import dataclass 19 | from typing import * 20 | from pathlib import Path 21 | from simple_parsing import ArgumentParser, subparsers 22 | 23 | @dataclass 24 | class Train: 25 | """Example of a command to start a Training run.""" 26 | # the training directory 27 | train_dir: Path = Path("~/train") 28 | 29 | def execute(self): 30 | print(f"Training in directory {self.train_dir}") 31 | 32 | 33 | @dataclass 34 | class Test: 35 | """Example of a command to start a Test run.""" 36 | # the testing directory 37 | test_dir: Path = Path("~/train") 38 | 39 | def execute(self): 40 | print(f"Testing in directory {self.test_dir}") 41 | 42 | 43 | @dataclass 44 | class Program: 45 | """Some top-level command""" 46 | command: Union[Train, Test] 47 | verbose: bool = False # log additional messages in the console. 
48 | 49 | def execute(self): 50 | print(f"Executing Program (verbose: {self.verbose})") 51 | return self.command.execute() 52 | 53 | 54 | parser = ArgumentParser() 55 | parser.add_arguments(Program, dest="prog") 56 | args = parser.parse_args() 57 | prog: Program = args.prog 58 | 59 | print("prog:", prog) 60 | prog.execute() 61 | ``` 62 | 63 | Here are some usage examples: 64 | 65 | - Executing the training command: 66 | 67 | ```console 68 | $ python examples/subparsers/subparsers_example.py train 69 | prog: Program(command=Train(train_dir=PosixPath('~/train')), verbose=False) 70 | Executing Program (verbose: False) 71 | Training in directory ~/train 72 | ``` 73 | 74 | - Passing a custom training directory: 75 | 76 | ```console 77 | $ python examples/subparsers/subparsers_example.py train --train_dir ~/train 78 | prog: Program(command=Train(train_dir=PosixPath('/home/fabrice/train')), verbose=False) 79 | Executing Program (verbose: False) 80 | Training in directory /home/fabrice/train 81 | ``` 82 | 83 | - Getting help for a subcommand: 84 | 85 | ```console 86 | $ python examples/subparsers/subparsers_example.py train --help 87 | usage: subparsers_example.py train [-h] [--train_dir Path] 88 | 89 | optional arguments: 90 | -h, --help show this help message and exit 91 | 92 | Train ['prog.command']: 93 | Example of a command to start a Training run. 94 | 95 | --train_dir Path the training directory (default: ~/train) 96 | ``` 97 | 98 | - Getting Help for the parent command: 99 | 100 | ```console 101 | $ python examples/subparsers/subparsers_example.py --help 102 | usage: subparsers_example.py [-h] [--verbose [str2bool]] {train,test} ... 103 | 104 | optional arguments: 105 | -h, --help show this help message and exit 106 | 107 | Program ['prog']: 108 | Some top-level command 109 | 110 | --verbose [str2bool] log additional messages in the console. 
(default: 111 | False) 112 | 113 | command: 114 | {train,test} 115 | ``` 116 | -------------------------------------------------------------------------------- /examples/subparsers/optional_subparsers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | 4 | from simple_parsing import ArgumentParser 5 | from simple_parsing.helpers.fields import subparsers 6 | 7 | 8 | @dataclass 9 | class AConfig: 10 | foo: int = 123 11 | 12 | 13 | @dataclass 14 | class BConfig: 15 | bar: float = 4.56 16 | 17 | 18 | @dataclass 19 | class Options: 20 | config: Union[AConfig, BConfig] = subparsers( 21 | {"a": AConfig, "b": BConfig}, default_factory=AConfig 22 | ) 23 | 24 | 25 | def main(): 26 | parser = ArgumentParser() 27 | 28 | parser.add_arguments(Options, dest="options") 29 | 30 | # Equivalent to: 31 | # subparsers = parser.add_subparsers(title="config", required=False) 32 | # parser.set_defaults(config=AConfig()) 33 | # a_parser = subparsers.add_parser("a", help="A help.") 34 | # a_parser.add_arguments(AConfig, dest="config") 35 | # b_parser = subparsers.add_parser("b", help="B help.") 36 | # b_parser.add_arguments(BConfig, dest="config") 37 | 38 | args = parser.parse_args() 39 | 40 | print(args) 41 | options: Options = args.options 42 | print(options) 43 | 44 | 45 | main() 46 | expected = """ 47 | Namespace(options=Options(config=AConfig(foo=123))) 48 | Options(config=AConfig(foo=123)) 49 | """ 50 | -------------------------------------------------------------------------------- /examples/subparsers/subparsers_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from simple_parsing import ArgumentParser 6 | 7 | 8 | @dataclass 9 | class Train: 10 | """Example of a command to start a Training run.""" 11 | 12 | # the training directory 13 | 
train_dir: Path = Path("~/train") 14 | 15 | def execute(self): 16 | print(f"Training in directory {self.train_dir}") 17 | 18 | 19 | @dataclass 20 | class Test: 21 | """Example of a command to start a Test run.""" 22 | 23 | # the testing directory 24 | test_dir: Path = Path("~/train") 25 | 26 | def execute(self): 27 | print(f"Testing in directory {self.test_dir}") 28 | 29 | 30 | @dataclass 31 | class Program: 32 | """Some top-level command.""" 33 | 34 | command: Union[Train, Test] 35 | verbose: bool = False # log additional messages in the console. 36 | 37 | def execute(self): 38 | print(f"Program (verbose: {self.verbose})") 39 | return self.command.execute() 40 | 41 | 42 | parser = ArgumentParser() 43 | parser.add_arguments(Program, dest="prog") 44 | args = parser.parse_args() 45 | prog: Program = args.prog 46 | 47 | print("prog:", prog) 48 | prog.execute() 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = [ 3 | { name = "Fabrice Normandin", email = "fabrice.normandin@gmail.com" }, 4 | ] 5 | name = "simple-parsing" 6 | dynamic = ["version"] 7 | description = "A small utility to simplify and clean up argument parsing scripts." 
8 | readme = "README.md" 9 | requires-python = ">=3.9" 10 | dependencies = ["docstring-parser~=0.15", "typing-extensions>=4.5.0"] 11 | license = { file = "LICENSE" } 12 | 13 | [project.optional-dependencies] 14 | yaml = ["pyyaml>=6.0.2"] 15 | toml = ["tomli>=2.2.1", "tomli-w>=1.0.0"] 16 | 17 | [build-system] 18 | requires = ["hatchling", "uv-dynamic-versioning"] 19 | build-backend = "hatchling.build" 20 | 21 | [tool.hatch.version] 22 | source = "uv-dynamic-versioning" 23 | 24 | [tool.ruff] 25 | line-length = 99 26 | 27 | [tool.pytest] 28 | addopts = ["--doctest-modules", "--benchmark-autosave"] 29 | testpaths = ["test", "simple_parsing"] 30 | norecursedirs = ["examples", "docs"] 31 | 32 | [tool.ruff.lint] 33 | select = ["E4", "E7", "E9", "I", "UP"] 34 | 35 | [tool.docformatter] 36 | in-place = true 37 | wrap-summaries = 99 38 | wrap-descriptions = 99 39 | 40 | [tool.codespell] 41 | skip = ["logs/**", "data/**"] 42 | 43 | [dependency-groups] 44 | dev = [ 45 | "matplotlib>=3.9.4", 46 | "numpy>=2.0.2", 47 | "pytest>=8.3.4", 48 | "pytest-benchmark>=5.1.0", 49 | "pytest-cov>=6.0.0", 50 | "pytest-regressions>=2.7.0", 51 | "pytest-xdist>=3.6.1", 52 | ] 53 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | orion 4 | -------------------------------------------------------------------------------- /simple_parsing/__init__.py: -------------------------------------------------------------------------------- 1 | """Simple, Elegant Argument parsing. 2 | 3 | @author: Fabrice Normandin 4 | """ 5 | from . 
import helpers, utils, wrappers 6 | from .conflicts import ConflictResolution 7 | from .decorators import main 8 | from .help_formatter import SimpleHelpFormatter 9 | from .helpers import ( 10 | Partial, 11 | Serializable, 12 | choice, 13 | config_for, 14 | field, 15 | flag, 16 | list_field, 17 | mutable_field, 18 | subgroups, 19 | subparsers, 20 | ) 21 | from .parsing import ( 22 | ArgumentGenerationMode, 23 | ArgumentParser, 24 | DashVariant, 25 | NestedMode, 26 | ParsingError, 27 | parse, 28 | parse_known_args, 29 | ) 30 | from .replace import replace, replace_subgroups 31 | from .utils import InconsistentArgumentError 32 | 33 | __all__ = [ 34 | "ArgumentGenerationMode", 35 | "ArgumentParser", 36 | "choice", 37 | "config_for", 38 | "ConflictResolution", 39 | "DashVariant", 40 | "field", 41 | "flag", 42 | "helpers", 43 | "InconsistentArgumentError", 44 | "list_field", 45 | "main", 46 | "mutable_field", 47 | "NestedMode", 48 | "parse_known_args", 49 | "parse", 50 | "ParsingError", 51 | "Partial", 52 | "replace", 53 | "replace_subgroups", 54 | "Serializable", 55 | "SimpleHelpFormatter", 56 | "subgroups", 57 | "subparsers", 58 | "utils", 59 | "wrappers", 60 | ] 61 | -------------------------------------------------------------------------------- /simple_parsing/annotation_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lebrice/SimpleParsing/2a69cebb38172b9317c0175aadde6d3b432371ce/simple_parsing/annotation_utils/__init__.py -------------------------------------------------------------------------------- /simple_parsing/help_formatter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import ONE_OR_MORE, OPTIONAL, PARSER, REMAINDER, ZERO_OR_MORE, Action 3 | from logging import getLogger 4 | from typing import Optional 5 | 6 | from .wrappers.field_metavar import get_metavar 7 | 8 | TEMPORARY_TOKEN = "<__TEMP__>" 
logger = getLogger(__name__)


class SimpleHelpFormatter(
    argparse.ArgumentDefaultsHelpFormatter,
    argparse.MetavarTypeHelpFormatter,
    argparse.RawDescriptionHelpFormatter,
):
    """Little shorthand for using some useful HelpFormatters from argparse.

    This class inherits from argparse's `ArgumentDefaultHelpFormatter`,
    `MetavarTypeHelpFormatter` and `RawDescriptionHelpFormatter` classes.

    This produces the following resulting actions:
    - adds a "(default: xyz)" for each argument with a default
    - uses the name of the argument type as the metavar. For example, gives
      "-n int" instead of "-n N" in the usage and description of the arguments.
    - Conserves the formatting of the class and argument docstrings, if given.
    """

    def _format_args(self, action: Action, default_metavar: str):
        """Format the metavar portion of `action`'s usage string, honouring `action.nargs`.

        Prefers an explicit `action.metavar` (or one derived from the field's type via
        `get_metavar`) over argparse's nargs-based placeholders, except for actions with
        `choices`, where argparse's own rendering of the choice list is kept.
        """
        _get_metavar = self._metavar_formatter(action, default_metavar)

        # NOTE: dropped the unused `action_type` local (leftover from commented-out
        # debug logging).
        metavar = action.metavar or get_metavar(action.type)
        if metavar and not action.choices:
            result = metavar
        elif action.nargs is None:
            result = "%s" % _get_metavar(1)
        elif action.nargs == OPTIONAL:
            result = "[%s]" % _get_metavar(1)
        elif action.nargs == ZERO_OR_MORE:
            result = "[%s [%s ...]]" % _get_metavar(2)  # noqa: UP031
        elif action.nargs == ONE_OR_MORE:
            result = "%s [%s ...]" % _get_metavar(2)  # noqa: UP031
        elif action.nargs == REMAINDER:
            result = "..."
        elif action.nargs == PARSER:
            result = "%s ..." % _get_metavar(1)
        else:
            # Fixed integer nargs: one placeholder per expected value.
            formats = ["%s" for _ in range(action.nargs)]
            result = " ".join(formats) % _get_metavar(action.nargs)
        return result

    def _get_default_metavar_for_optional(self, action: argparse.Action):
        """Base behaviour, with a fallback to a type-derived metavar when the base class
        raises (e.g. when `action.type` has no usable `__name__`)."""
        try:
            return super()._get_default_metavar_for_optional(action)
        except BaseException:
            logger.debug(f"Getting metavar for action with dest {action.dest}.")
            metavar = self._get_metavar_for_action(action)
            logger.debug(f"Result metavar: {metavar}")
            return metavar

    def _get_default_metavar_for_positional(self, action: argparse.Action):
        """Positional-argument counterpart of `_get_default_metavar_for_optional`."""
        try:
            return super()._get_default_metavar_for_positional(action)
        except BaseException:
            logger.debug(f"Getting metavar for action with dest {action.dest}.")
            metavar = self._get_metavar_for_action(action)
            logger.debug(f"Result metavar: {metavar}")
            return metavar

    def _get_metavar_for_action(self, action: argparse.Action) -> str:
        return self._get_metavar_for_type(action.type)

    def _get_metavar_for_type(self, t: type) -> str:
        # Fall back to `str(t)` for types `get_metavar` has no name for.
        return get_metavar(t) or str(t)

    def _get_help_string(self, action: Action) -> Optional[str]:
        # Strip the internal placeholder token before the help text reaches the user.
        help = super()._get_help_string(action=action)
        if help is not None:
            help = help.replace(TEMPORARY_TOKEN, "")
        return help


Formatter = SimpleHelpFormatter
def set_seed(seed: int) -> None:
    """Seed every RNG that is available: stdlib `random`, then NumPy and torch if installed."""
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        return
    try:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except AttributeError:
        # Tolerate partial / stubbed torch installs that lack these attributes.
        pass
_T = TypeVar("_T")


class npartial(functools.partial, Generic[_T]):
    """Partial that also invokes partials in args and kwargs before feeding them to the function.

    Useful for creating nested partials, e.g.:


    >>> from dataclasses import dataclass, field
    >>> @dataclass
    ... class Value:
    ...     v: int = 0
    >>> @dataclass
    ... class ValueWrapper:
    ...     value: Value
    ...
    >>> from functools import partial
    >>> @dataclass
    ... class WithRegularPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=partial(ValueWrapper, value=Value(v=123)),
    ...     )

    Here's the problem: This here is BAD! They both share the same instance of Value!

    >>> WithRegularPartial().wrapped.value is WithRegularPartial().wrapped.value
    True
    >>> @dataclass
    ... class WithNPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=npartial(ValueWrapper, value=npartial(Value, v=123)),
    ...     )
    >>> WithNPartial().wrapped.value is WithNPartial().wrapped.value
    False

    This is fine now!
    """

    def __call__(self, *args: Any, **keywords: Any) -> _T:
        # Merge stored and call-time arguments first (call-time keywords win),
        # then resolve any nested npartials by invoking them.
        merged_keywords = {**self.keywords, **keywords}
        combined_args = self.args + args
        resolved_args = tuple(
            value() if isinstance(value, npartial) else value for value in combined_args
        )
        resolved_keywords = {
            key: value() if isinstance(value, npartial) else value
            for key, value in merged_keywords.items()
        }
        return self.func(*resolved_args, **resolved_keywords)
24 | """ 25 | 26 | def dump(self, fp: IO[str], dump_fn=None, **kwargs) -> None: 27 | if dump_fn is None: 28 | dump_fn = yaml.dump 29 | dump_fn(self.to_dict(), fp, **kwargs) 30 | 31 | def dumps(self, dump_fn=None, **kwargs) -> str: 32 | if dump_fn is None: 33 | dump_fn = yaml.dump 34 | return dump_fn(self.to_dict(), **kwargs) 35 | 36 | @classmethod 37 | def load( 38 | cls: type[D], 39 | path: Path | str | IO[str], 40 | drop_extra_fields: bool | None = None, 41 | load_fn=None, 42 | **kwargs, 43 | ) -> D: 44 | if load_fn is None: 45 | load_fn = yaml.safe_load 46 | 47 | return super().load(path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) 48 | 49 | @classmethod 50 | def loads( 51 | cls: type[D], 52 | s: str, 53 | drop_extra_fields: bool | None = None, 54 | load_fn=None, 55 | **kwargs, 56 | ) -> D: 57 | if load_fn is None: 58 | load_fn = yaml.safe_load 59 | return super().loads(s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) 60 | 61 | @classmethod 62 | def _load( 63 | cls: type[D], 64 | fp: IO[str], 65 | drop_extra_fields: bool | None = None, 66 | load_fn=None, 67 | **kwargs, 68 | ) -> D: 69 | if load_fn is None: 70 | load_fn = yaml.safe_load 71 | return super()._load(fp, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) 72 | -------------------------------------------------------------------------------- /simple_parsing/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lebrice/SimpleParsing/2a69cebb38172b9317c0175aadde6d3b432371ce/simple_parsing/py.typed -------------------------------------------------------------------------------- /simple_parsing/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataclass_wrapper import DataclassWrapper 2 | from .field_wrapper import DashVariant, FieldWrapper 3 | 4 | __all__ = ["DataclassWrapper", "FieldWrapper", "DashVariant"] 5 | 
-------------------------------------------------------------------------------- /simple_parsing/wrappers/field_metavar.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from logging import getLogger 3 | from typing import Any, Callable, Optional, TypeVar 4 | 5 | from ..utils import get_type_arguments, is_optional, is_tuple, is_union 6 | 7 | T = TypeVar("T") 8 | 9 | logger = getLogger(__name__) 10 | 11 | _new_metavars: dict[type[T], Optional[str]] = { 12 | # the 'primitive' types don't get a 'new' metavar. 13 | t: t.__name__ 14 | for t in [str, float, int, bytes] 15 | } 16 | 17 | 18 | def log_results(fn: Callable[[type], str]): 19 | @functools.wraps(fn) 20 | def _wrapped(t: type) -> str: 21 | result = fn(t) 22 | # logger.debug(f"Metavar for type {t}: {result}") 23 | return result 24 | 25 | return _wrapped 26 | 27 | 28 | @log_results 29 | def get_metavar(t: type) -> str: 30 | """Gets the metavar to be used for that type in help strings. 31 | 32 | This is crucial when using a `weird` auto-generated parsing functions for 33 | things like Union, Optional, Etc etc. 34 | 35 | type the type arguments that were passed to `get_parsing_fn` that 36 | produced the given parsing_fn. 37 | 38 | returns None if the name shouldn't be changed. 39 | """ 40 | # TODO: Maybe we can create the name for each returned call, a bit like how 41 | # we dynamically create the parsing function itself? 
class Wrapper(ABC):
    """Abstract base class for the FieldWrapper and DataclassWrapper.

    A wrapper knows its `name` and its `parent` wrapper (if any), from which it
    derives its dotted destination (`dest`) in the argparse Namespace.
    """

    def __init__(self):
        # Cached dotted destination; recomputed on every `dest` access.
        self._dest: Optional[str] = None

    @abstractmethod
    def equivalent_argparse_code(self) -> str:
        """Return vanilla-argparse code equivalent to this wrapper's arguments."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of this wrapper (one component of the dotted `dest`)."""
        pass

    @property
    @abstractmethod
    def parent(self) -> Optional["Wrapper"]:
        """The enclosing wrapper, or None for a top-level wrapper."""
        pass

    @property
    def dest(self) -> str:
        """Where the attribute will be stored in the Namespace."""
        lineage_names: list[str] = [w.name for w in self.lineage()]
        # lineage() goes child -> root, so reverse to get "root....self".
        self._dest = ".".join(reversed([self.name] + lineage_names))
        assert self._dest is not None
        return self._dest

    def lineage(self) -> list["Wrapper"]:
        """Return the chain of ancestors, nearest parent first (excluding self)."""
        lineage: list[Wrapper] = []
        parent = self.parent
        while parent is not None:
            lineage.append(parent)
            parent = parent.parent
        return lineage

    @property
    def nesting_level(self) -> int:
        """Number of ancestors above this wrapper (0 for a top-level wrapper)."""
        # FIX: removed the unreachable duplicate while-loop implementation that
        # followed this return statement in the original.
        return len(self.lineage())
@dataclass
class BB(B):
    """A class that is not shown in the `A | B` annotation above, but that can be set as `item`."""

    extra_field: int = 123
    # Derived field, computed in __post_init__; excluded from the generated __init__.
    other_field: int = field(init=False)

    def __post_init__(self):
        self.other_field = self.extra_field * 2


@pytest.mark.parametrize(
    "obj",
    [
        Container(item=B(b="hey")),
        Container(item=BB(b="hey", extra_field=111)),
    ],
)
@pytest.mark.parametrize("file_type", [".json", pytest.param(".yaml", marks=needs_yaml)])
def test_encoding_with_dc_types(
    obj: Container, file_type: str, tmp_path: Path, file_regression: FileRegressionFixture
):
    # Round-trip: saving with `save_dc_types=True` embeds the concrete dataclass type
    # in the file, so `load` can reconstruct the exact subclass stored in `Container.item`.
    file = (tmp_path / "test").with_suffix(file_type)
    save(obj, file, save_dc_types=True)
    file_regression.check(file.read_text(), extension=file.suffix)

    assert load(Container, file) == obj


@pytest.fixture(autouse=True)
def reset_encoding_fns():
    # Snapshot and restore the registry of decoding functions around each test,
    # so types registered during one test don't leak into the next.
    from simple_parsing.helpers.serialization.decoding import _decoding_fns

    copy = _decoding_fns.copy()
    # info = get_decoding_fn.cache_info()

    yield

    _decoding_fns.clear()
    _decoding_fns.update(copy)


@pytest.mark.parametrize("file_type", [".json", pytest.param(".yaml", marks=needs_yaml)])
def test_encoding_inner_dc_types_raises_warning_and_doest_work(tmp_path: Path, file_type: str):
    """Saving a dataclass type defined inside a function warns, but loading still works here."""
    file = (tmp_path / "test").with_suffix(file_type)

    @dataclass(eq=True)
    class BBInner(B):
        something: float = 3.21

    obj = Container(item=BBInner(something=123.456))
    with pytest.warns(
        RuntimeWarning,
        match="BBInner'> is defined in a function scope, which might cause issues",
    ):
        save(obj, file, save_dc_types=True)

    # NOTE: This would work if `A` were made a subclass of `Serializable`, because we currently
    # don't pass the value of the `drop_extra_fields` flag to the decoding function for each field.
    # We only use it when deserializing the top-level dataclass.

    # Here we actually expect this to work (since BBInner should be found via
    # `B.__subclasses__()`).
    from simple_parsing.helpers.serialization.decoding import _decoding_fns

    print(_decoding_fns.keys())
    loaded_obj = load(Container, file, drop_extra_fields=False)
    # BUG: There is something a bit weird going on with this comparison: The two objects aren't
    # considered equal, but they seem identical 🤔
    assert str(loaded_obj) == str(obj)  # This comparison works!

    # NOTE: There appears to be some kind of caching mechanism. Running this test a few times in
    # succession fails the first time, and passes the remaining times. Seems like waiting 30
    # seconds or so invalidates some sort of caching mechanism, and makes the test fail again.

    # assert loaded_obj == obj  # BUG? This comparison fails, because:
    # assert type(loaded_obj.item) == type(obj.item)  # These two types are *sometimes* different?!
def test_replace_and_from_dict_already_call_post_init():
    """`__init__`, `dataclasses.replace` and `from_dict` should each run __post_init__ once."""
    n_post_init_calls = 0

    @dataclass
    class Bob:
        a: int = 123

        def __post_init__(self):
            nonlocal n_post_init_calls
            n_post_init_calls += 1

    assert n_post_init_calls == 0
    bob = Bob()
    assert n_post_init_calls == 1
    _ = replace(bob, a=456)
    assert n_post_init_calls == 2

    _ = from_dict(Bob, {"a": 456})
    assert n_post_init_calls == 3


@dataclass
class InnerConfig:
    arg1: int = 1
    arg2: str = "foo"
    # Filled from `arg1` in __post_init__; excluded from the generated __init__.
    arg1_post_init: str = field(init=False)

    def __post_init__(self):
        self.arg1_post_init = str(self.arg1)


@dataclass
class OuterConfig1:
    out_arg: int = 0
    inner: InnerConfig = field(default_factory=InnerConfig)


@dataclass
class OuterConfig2:
    out_arg: int = 0
    # Default factory pre-binds a non-default value for `arg2`.
    inner: InnerConfig = field(default_factory=functools.partial(InnerConfig, arg2="bar"))


@dataclass
class Level1:
    arg: int = 1


@dataclass
class Level2:
    arg: int = 1
    prev: Level1 = field(default_factory=Level1)


@dataclass
class Level3:
    # Three levels of nesting, used to check recursive (de)serialization.
    arg: int = 1
    prev: Level2 = field(default_factory=Level2)
@dataclass
class Foo:
    a: int = 1
    b: int = 2


def some_function(v1: int = 123, v2: int = 456):
    """Gives back the mean of two numbers."""
    return (v1 + v2) / 2


def test_partial_class_attribute():
    # `Partial[Foo]` generates a "FooConfig" dataclass whose fields mirror Foo's and
    # whose instances are functools.partial objects that build `Foo`s when called.
    @dataclass
    class Bob(TestSetup):
        foo_factory: Partial[Foo]

    parser = ArgumentParser()
    parser.add_arguments(Bob, dest="bob")
    args = parser.parse_args("--a 456".split())
    bob = args.bob
    foo_factory: Partial[Foo] = bob.foo_factory
    assert foo_factory.a == 456
    assert foo_factory.b == 2
    assert str(foo_factory) == "FooConfig(a=456, b=2)"

    foo = foo_factory()
    assert foo == Foo(a=456, b=2)
    assert is_dataclass(foo_factory)
    assert isinstance(foo_factory, functools.partial)


def test_partial_function_attribute():
    # Partial also works for plain functions: defaults come from the signature.
    @dataclass
    class Bob(TestSetup):
        some_fn: Partial[some_function]  # type: ignore

    bob = Bob.setup("--v2 781")
    assert str(bob.some_fn) == "some_function_config(v1=123, v2=781)"
    assert bob.some_fn() == some_function(v2=781)
    assert bob.some_fn(v1=3, v2=7) == some_function(3, 7)


def test_dynamic_classes_are_cached():
    assert Partial[Foo] is Partial[Foo]


def test_pickling():
    # TODO: Test that we can pickle / unpickle these dynamic classes objects.
    import pickle

    dynamic_class = Partial[some_function]

    serialized = pickle.dumps(dynamic_class)

    deserialized = pickle.loads(serialized)
    assert deserialized is dynamic_class


def some_function_with_required_arg(required_arg, v1: int = 123, v2: int = 456):
    """Gives back the mean of two numbers."""
    return required_arg, (v1 + v2) / 2


@dataclass
class FooWithRequiredArg(TestSetup):
    some_fn: Partial[some_function_with_required_arg]


def test_partial_for_fn_with_required_args():
    # `required_arg` has no default, so it is not part of the generated config dataclass;
    # it must be supplied when the partial is finally called.
    bob = FooWithRequiredArg.setup("--v1 1 --v2 2")
    assert is_dataclass(bob.some_fn)
    assert isinstance(bob.some_fn, functools.partial)

    assert "required_arg" not in [f.name for f in fields(bob.some_fn)]
    assert bob.some_fn(123) == (123, 1.5)


def test_getattr():
    # Keyword arguments stored on the partial are readable as attributes.
    bob = FooWithRequiredArg.setup("--v1 1 --v2 2")
    some_fn_partial = bob.some_fn
    assert some_fn_partial.v1 == 1
    assert some_fn_partial.v2 == 2


def test_works_with_frozen_instances_as_default():
    # Frozen configs are hashable, so they can be used as subgroup defaults.
    @dataclass
    class A:
        x: int
        y: bool = True

    AConfig = sp.config_for(A, ignore_args="x", frozen=True)

    a1_config = AConfig(y=False)
    a2_config = AConfig(y=True)

    assert isinstance(a1_config, functools.partial)
    assert isinstance(a1_config, Hashable)

    @dataclass(frozen=True)
    class ParentConfig:
        a: Partial[A] = sp.subgroups(
            {
                "a1": a1_config,
                "a2": a2_config,
            },
            default=a2_config,
        )

    b = sp.parse(ParentConfig, args="--a a2")
    assert b.a(x=1) == A(x=1, y=a2_config.y)
@dataclass
class Foo:
    a: int = 1
    b: int = 2


def some_function(v1: int = 123, v2: int = 456):
    """Gives back the mean of two numbers."""
    return (v1 + v2) / 2


def test_partial_class_attribute():
    # Same scenario as in test_partial.py, but with postponed annotations
    # (`from __future__ import annotations`) active in this module.
    @dataclass
    class Bob(TestSetup):
        foo_factory: Partial[Foo]

    parser = ArgumentParser()
    parser.add_arguments(Bob, dest="bob")
    args = parser.parse_args("--a 456".split())
    bob = args.bob
    foo_factory: Partial[Foo] = bob.foo_factory
    assert is_dataclass(foo_factory)
    assert foo_factory.a == 456
    assert foo_factory.b == 2
    assert str(foo_factory) == "FooConfig(a=456, b=2)"


def test_partial_function_attribute():
    @dataclass
    class Bob(TestSetup):
        some_fn: Partial[some_function]

    bob = Bob.setup("--v2 781")
    assert str(bob.some_fn) == "some_function_config(v1=123, v2=781)"
    assert bob.some_fn() == some_function(v2=781)
    assert bob.some_fn(v1=3, v2=7) == some_function(3, 7)


def test_dynamic_classes_are_cached():
    assert Partial[Foo] is Partial[Foo]


# bob = Bob(foo_factory=Foo, some_fn=some_function)
# Round-trip save/load tests for HyperParameters across the supported file formats.
# `save`/`load` dispatch on the file extension; format-specific helpers are tested too.


@needs_yaml
def test_save_yaml(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.yml")
    hparams.save_yaml(tmp_path)

    _hparams = HyperParameters.load_yaml(tmp_path)
    assert hparams == _hparams


def test_save_json(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.json")
    hparams.save_json(tmp_path)
    _hparams = HyperParameters.load_json(tmp_path)
    assert hparams == _hparams


@needs_yaml
def test_save_yml(tmpdir: Path):
    # Generic `save` with a .yml suffix should select the YAML backend.
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.yml")
    hparams.save(tmp_path)

    _hparams = HyperParameters.load(tmp_path)
    assert hparams == _hparams


def test_save_pickle(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.pkl")
    hparams.save(tmp_path)

    _hparams = HyperParameters.load(tmp_path)
    assert hparams == _hparams


def test_save_numpy(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.npy")
    hparams.save(tmp_path)

    _hparams = HyperParameters.load(tmp_path)
    assert hparams == _hparams


try:
    import torch
except ImportError:
    torch = None


@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
def test_save_torch(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.pth")
    hparams.save(tmp_path)

    _hparams = HyperParameters.load(tmp_path)
    assert hparams == _hparams


@needs_toml
def test_save_toml(tmpdir: Path):
    hparams = HyperParameters.setup("")
    tmp_path = Path(tmpdir / "temp.toml")
    hparams.save(tmp_path)

    _hparams = HyperParameters.load(tmp_path)
    assert hparams == _hparams
34 | (lambda: Wrapped(value=Value(789)), 789), 35 | ], 36 | ) 37 | def test_defaults_from_field_default_factory_show_in_help( 38 | default_factory: Callable[[], Wrapped] | dataclasses._MISSING_TYPE, 39 | expected_default_in_help_str: int, 40 | ): 41 | """When using a functools.partial as the default factory for a field, we want to be able to 42 | show the right default values in the help string: those from the factory, not those from the 43 | dataclass field. 44 | 45 | This isn't *that* big a deal, but it would be nice. 46 | """ 47 | 48 | @dataclass 49 | class Config(TestSetup): 50 | """A dataclass with a single field, which is a Wrapped object.""" 51 | 52 | wrapped: Wrapped = field(default_factory=default_factory) # type: ignore 53 | 54 | help_text = Config.get_help_text() 55 | 56 | assert f"--v int (default: {expected_default_in_help_str})" in help_text 57 | -------------------------------------------------------------------------------- /test/nesting/test_nesting_auto.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from simple_parsing import ConflictResolution 4 | 5 | from .example_use_cases import HyperParameters 6 | 7 | 8 | def test_real_use_case(silent): 9 | hparams = HyperParameters.setup( 10 | "--age_group.num_layers 5 " "--age_group.num_units 65 ", 11 | conflict_resolution_mode=ConflictResolution.AUTO, 12 | ) 13 | assert isinstance(hparams, HyperParameters) 14 | # print(hparams.get_help_text()) 15 | assert hparams.gender.num_layers == 1 16 | assert hparams.gender.num_units == 32 17 | assert hparams.age_group.num_layers == 5 18 | assert hparams.age_group.num_units == 65 19 | assert hparams.age_group.use_likes is True 20 | -------------------------------------------------------------------------------- /test/nesting/test_nesting_defaults.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | from dataclasses import dataclass 
from pathlib import Path
from typing import Optional

from simple_parsing import ArgumentParser
from simple_parsing.helpers import field
from simple_parsing.helpers.serialization.serializable import Serializable

from ..testutils import needs_yaml


@dataclass
class AdvTraining(Serializable):
    # Adversarial-training settings (both fields required).
    epsilon: float
    iters: int


@dataclass
class DPTraining(Serializable):
    # Differential-privacy settings (both fields required).
    epsilon: float
    delta: float


@dataclass
class DatasetConfig(Serializable):
    name: str
    prop: str
    # NOTE: field name `drop_senstive_cols` keeps the original spelling;
    # it is referenced by name on the CLI below.
    split: str = field(choices=["victim", "adv"])
    value: float
    drop_senstive_cols: Optional[bool] = False
    scale: Optional[float] = 1.0


@dataclass
class TrainConfig(Serializable):
    data_config: DatasetConfig
    epochs: int
    ...
    dp_config: Optional[DPTraining] = None
    adv_config: Optional[AdvTraining] = None
    ...
    cpu: bool = False


@needs_yaml
def test_comment_pull115(tmp_path):
    """CLI values layered on top of a config file: CLI wins where both provide a value."""
    config_in_file = TrainConfig(
        data_config=DatasetConfig(name="bob", split="victim", prop="123", value=1.23),
        epochs=1,
    )
    config_in_file.save(tmp_path / "config.yaml")

    parser = ArgumentParser(add_help=False)
    parser.add_argument("--config_file", help="Specify config file", type=Path)
    args, remaining_argv = parser.parse_known_args(
        shlex.split(
            f"--config_file {tmp_path / 'config.yaml'} --cpu --epochs 2 --name bob "
            f"--prop 123 --value 1.23 --split victim "
            f"--drop_senstive_cols True --scale 1.0 --epochs 2 "
            f"--dp_config.epsilon 0.1 --delta 0.2 --adv_config.epsilon 0.1 "
            f"--iters 10"
        )
    )

    # Attempt to extract as much information from config file as you can
    config = TrainConfig.load(args.config_file, drop_extra_fields=False)
    # Also give user the option to provide config values over CLI
    parser = ArgumentParser(parents=[parser])
    parser.add_arguments(TrainConfig, dest="train_config", default=config)
    args = parser.parse_args(remaining_argv)

    config_in_args = TrainConfig(
        data_config=DatasetConfig(
            name="bob",
            split="victim",
            prop="123",
            value=1.23,
            drop_senstive_cols=True,
            scale=1.0,
        ),
        epochs=2,
        cpu=True,
        adv_config=AdvTraining(epsilon=0.1, iters=10),
        dp_config=DPTraining(epsilon=0.1, delta=0.2),
    )

    # File values first, then CLI values override them.
    expected_dict = config_in_file.to_dict()
    expected_dict.update(config_in_args.to_dict())

    assert args.train_config.to_dict() == expected_dict
    assert args.train_config == TrainConfig.from_dict(expected_dict)


# --------------------------------------------------------------------------------
# /test/nesting/test_nesting_explicit.py
# --------------------------------------------------------------------------------
from simple_parsing import ConflictResolution

from .example_use_cases import HyperParameters, TaskHyperParameters


def test_real_use_case(silent):
    """EXPLICIT conflict resolution: arguments are addressed by full dotted path."""
    hparams = HyperParameters.setup(
        "--hyper_parameters.age_group.num_layers 5",
        conflict_resolution_mode=ConflictResolution.EXPLICIT,
    )
    assert isinstance(hparams, HyperParameters)
    # print(hparams.get_help_text())
    assert hparams.age_group.num_layers == 5
    assert hparams.gender.num_layers == 1
    assert hparams.gender.num_units == 32
    assert isinstance(hparams.age_group, TaskHyperParameters)
    assert hparams.age_group.use_likes is True


# --------------------------------------------------------------------------------
# /test/nesting/test_nesting_merge.py
# --------------------------------------------------------------------------------
from simple_parsing import ArgumentParser, ConflictResolution

from .example_use_cases import HyperParameters


def test_parser_preprocessing_steps():
    """White-box check of the ALWAYS_MERGE preprocessing pipeline.

    NOTE(review): this test relies on parser internals (`_wrappers`,
    `_conflict_resolver`, `_preprocessing_done`) and exact statement order;
    it mirrors what `parse_args` does internally, step by step.
    """
    parser = ArgumentParser(conflict_resolution=ConflictResolution.ALWAYS_MERGE)
    parser.add_arguments(HyperParameters, "hparams")

    wrappers = parser._wrappers
    # Fix the potential conflicts between dataclass fields with the same names.
    merged_wrappers = parser._conflict_resolver.resolve_and_flatten(wrappers)
    from simple_parsing.parsing import _unflatten_wrappers

    assert merged_wrappers[1].parent is merged_wrappers[0]
    assert merged_wrappers[1] in merged_wrappers[0]._children

    assert _unflatten_wrappers(merged_wrappers) == [merged_wrappers[0]]

    wrappers = merged_wrappers

    # After merging: one wrapper for the root, one shared by the three subgroups.
    assert len(wrappers) == 2
    hparams_dc_wrapper = wrappers[0]
    assert hparams_dc_wrapper.destinations == ["hparams"]
    merged_dcs_wrapper = wrappers[1]
    assert merged_dcs_wrapper.destinations == [
        "hparams.gender",
        "hparams.age_group",
        "hparams.personality",
    ]
    num_layers_field_wrapper = next(f for f in merged_dcs_wrapper.fields if f.name == "num_layers")
    assert num_layers_field_wrapper.destinations == [
        "hparams.gender.num_layers",
        "hparams.age_group.num_layers",
        "hparams.personality.num_layers",
    ]
    assert num_layers_field_wrapper.option_strings == ["--num_layers"]
    assert num_layers_field_wrapper.arg_options == {
        "default": [1, 2, 1],
        "type": int,
        "required": False,
        "help": "number of dense layers",
        "nargs": "*",
        # NOTE: This `dest` is where all the merged values are stored.
        "dest": "hparams.gender.num_layers",
    }

    parser._wrappers = wrappers
    parser._preprocessing_done = True
    # Create one argument group per dataclass
    for wrapped_dataclass in wrappers:
        print(
            f"Parser {id(parser)} is Adding arguments for dataclass: {wrapped_dataclass.dataclass} "
            f"at destinations {wrapped_dataclass.destinations}"
        )
        wrapped_dataclass.add_arguments(parser=parser)
    assert "--num_layers" in parser._option_string_actions


def test_hparam_use_case(silent):
    """ALWAYS_MERGE: one multi-valued option fans out to each conflicting destination."""
    hparams = HyperParameters.setup(
        "--num_layers 5 6 7", conflict_resolution_mode=ConflictResolution.ALWAYS_MERGE
    )
    assert isinstance(hparams, HyperParameters)
    # print(hparams.get_help_text())
    assert hparams.gender.num_layers == 5
    assert hparams.age_group.num_layers == 6
    assert hparams.personality.num_layers == 7

    assert hparams.gender.num_units == 32
    assert hparams.age_group.num_units == 64
    assert hparams.personality.num_units == 8

    assert hparams.gender.use_likes is True
    assert hparams.age_group.use_likes is True
    assert hparams.personality.use_likes is False


# --------------------------------------------------------------------------------
# /test/postponed_annotations/__init__.py
# https://raw.githubusercontent.com/lebrice/SimpleParsing/2a69cebb38172b9317c0175aadde6d3b432371ce/test/postponed_annotations/__init__.py
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /test/postponed_annotations/a.py
# --------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path


@dataclass
class A:
    # `Path | None` is only legal as a postponed annotation on older Pythons.
    p: Path | None
# --------------------------------------------------------------------------------
# /test/postponed_annotations/b.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass

from ..test_utils import TestSetup
from .a import A


@dataclass
class B(A, TestSetup):
    # Inherits `p: Path | None` from A (a.py uses postponed annotations).
    v: int


# --------------------------------------------------------------------------------
# /test/postponed_annotations/multi_inherits.py
# --------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass

from ..test_utils import TestSetup
from .b import B


# A four-level inheritance chain, used to test multi-depth postponed annotations.
@dataclass
class P1(TestSetup):
    a1: int = 1


@dataclass
class P2(P1):
    a2: int = 2


@dataclass
class P3(P2):
    a3: int = 3


@dataclass
class P4(P3):
    a4: int = 4


@dataclass
class C(B):
    # Extends B (and transitively A) from a *different* module.
    m: str


# --------------------------------------------------------------------------------
# /test/postponed_annotations/overwrite_attribute.py
# --------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass

from ..test_utils import TestSetup
from .overwrite_base import Base, ParamCls


@dataclass
class ParamClsSubclass(ParamCls):
    v: bool


@dataclass
class Subclass(Base, TestSetup):
    # Overrides `attribute` declared on Base with a narrower subclass type.
    attribute: ParamClsSubclass


# --------------------------------------------------------------------------------
# /test/postponed_annotations/overwrite_base.py
# --------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ParamCls:
    # Intentionally empty: overwrite_subclass.py defines a *different* ParamCls
    # with the same name to test name shadowing across modules.
    ...
9 | 10 | 11 | @dataclass 12 | class Base: 13 | attribute: ParamCls 14 | -------------------------------------------------------------------------------- /test/postponed_annotations/overwrite_subclass.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | 5 | from ..test_utils import TestSetup 6 | from .overwrite_base import Base 7 | 8 | 9 | @dataclass 10 | class ParamCls: 11 | something_else: bool = True 12 | 13 | 14 | @dataclass 15 | class Subclass(Base, TestSetup): 16 | other_attribute: ParamCls 17 | -------------------------------------------------------------------------------- /test/postponed_annotations/test_postponed_annotations.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | from simple_parsing import Serializable 9 | 10 | from ..test_utils import TestSetup 11 | from .b import B 12 | from .multi_inherits import P4, C 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "argv, v, p", 17 | [ 18 | ("--v 1", 1, None), 19 | ("--v 2 --p test/", 2, Path("test/")), 20 | ("--v 3 --p test/test1", 3, Path("test/test1")), 21 | pytest.param( 22 | "", None, None, marks=pytest.mark.xfail(reason="no default value in the dataclass") 23 | ), 24 | ], 25 | ) 26 | def test_postponed_annotations_with_baseclass(argv: str, v: int | None, p: Path | None): 27 | assert B.setup(argv) == B(v=v, p=p) 28 | 29 | 30 | @dataclass 31 | class MyArguments(Serializable, TestSetup): 32 | arg1: str = "this_argument" 33 | arg2: str | Path = Path("test_dir") 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "argv, arg1, arg2", 38 | [ 39 | ("", "this_argument", "test_dir"), 40 | ("--arg1 test1", "test1", Path("test_dir")), 41 | ("--arg2 test_path_dir", "this_argument", Path("test_path_dir")), 42 | ], 43 | ) 44 | def 
test_postponed_annotations_with_Serializable_base(argv: str, arg1: str, arg2: str | Path): 45 | actual_args = MyArguments.setup(argv) 46 | target_args = MyArguments(arg1=arg1, arg2=arg2) 47 | assert actual_args.arg1 == target_args.arg1 48 | assert str(actual_args.arg2) == str(target_args.arg2) 49 | 50 | 51 | def test_postponed_annotations_with_multi_depth_inherits_1(): 52 | assert P4.setup("--a1 4 --a2 3 --a3 2 --a4 1") == P4(4, 3, 2, 1) 53 | 54 | 55 | def test_postponed_annotations_with_multi_depth_inherits_2(): 56 | assert C.setup("--p test/test1 --v 1 --m string") == C(Path("test/test1"), 1, "string") 57 | 58 | 59 | def test_overwrite_base(): 60 | """Test that postponed annotations don't break types with the same name in multiple files.""" 61 | import test.postponed_annotations.overwrite_base as overwrite_base 62 | import test.postponed_annotations.overwrite_subclass as overwrite_subclass 63 | 64 | assert overwrite_subclass.Subclass.setup( 65 | "--something_else False" 66 | ) == overwrite_subclass.Subclass( 67 | attribute=overwrite_base.ParamCls(), 68 | other_attribute=overwrite_subclass.ParamCls(False), 69 | ) 70 | 71 | 72 | def test_overwrite_field(): 73 | """Test that postponed annotations don't break attribute overwriting in multiple files.""" 74 | import test.postponed_annotations.overwrite_attribute as overwrite_attribute 75 | import test.postponed_annotations.overwrite_base as overwrite_base 76 | 77 | instance = overwrite_attribute.Subclass.setup("--v True") 78 | assert ( 79 | type(instance.attribute) != overwrite_base.ParamCls 80 | ), "attribute type from Base class correctly ignored" 81 | assert instance == overwrite_attribute.Subclass( 82 | attribute=overwrite_attribute.ParamClsSubclass(True) 83 | ), "parsed attribute value is correct" 84 | -------------------------------------------------------------------------------- /test/test_aliases.py: -------------------------------------------------------------------------------- 1 | from dataclasses import 
dataclass 2 | from test.testutils import TestSetup 3 | 4 | from simple_parsing import field 5 | 6 | 7 | def test_aliases_with_given_dashes(): 8 | @dataclass 9 | class Foo(TestSetup): 10 | output_dir: str = field(default="/out", alias=["-o", "--out"]) 11 | 12 | foo = Foo.setup("--output_dir /bob") 13 | assert foo.output_dir == "/bob" 14 | 15 | foo = Foo.setup("-o /cat") 16 | assert foo.output_dir == "/cat" 17 | 18 | foo = Foo.setup("--out /john") 19 | assert foo.output_dir == "/john" 20 | 21 | 22 | def test_aliases_without_dashes(): 23 | @dataclass 24 | class Foo(TestSetup): 25 | output_dir: str = field(default="/out", alias=["o", "out"]) 26 | 27 | foo = Foo.setup("--output_dir /bob") 28 | assert foo.output_dir == "/bob" 29 | 30 | foo = Foo.setup("-o /cat") 31 | assert foo.output_dir == "/cat" 32 | 33 | foo = Foo.setup("--out /john") 34 | assert foo.output_dir == "/john" 35 | -------------------------------------------------------------------------------- /test/test_conf_path.py: -------------------------------------------------------------------------------- 1 | """Tests for config-path option.""" 2 | 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from simple_parsing.parsing import ArgumentParser, parse 10 | 11 | 12 | @dataclass 13 | class BarConf: 14 | foo: str 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "conf_arg_name", ["config-file", "config_file", "foo.bar.baz?", "bob bob bob"] 19 | ) 20 | def test_config_path_arg(tmp_path: Path, conf_arg_name: str): 21 | """Test config_path with valid strings.""" 22 | # Create config file 23 | conf_path = tmp_path / "foo.yml" 24 | with conf_path.open("w") as f: 25 | json.dump({"foo": "bee"}, f) 26 | 27 | # with pytest.raises(ValueError): 28 | parser = ArgumentParser(BarConf, add_config_path_arg=conf_arg_name) 29 | args = parser.parse_args([f"--{conf_arg_name}", str(conf_path)]) 30 | print(args) 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "conf_arg_name", 35 
| [ 36 | "-------", 37 | ], 38 | ) 39 | def test_pass_invalid_value_to_add_config_path_arg(tmp_path: Path, conf_arg_name: str): 40 | """Test config_path with invalid strings.""" 41 | # Create config file 42 | conf_path = tmp_path / "foo.yml" 43 | with conf_path.open("w") as f: 44 | json.dump({"foo": "bee"}, f) 45 | 46 | parser = ArgumentParser(BarConf, add_config_path_arg=conf_arg_name) 47 | with pytest.raises(ValueError): 48 | parser.parse_args([f"--{conf_arg_name}", str(conf_path)]) 49 | 50 | 51 | def test_config_path_same_as_dst_error(): 52 | """Raise an error if add_config_path_arg and dest are the equal.""" 53 | with pytest.raises(ValueError, match="`add_config_path_arg` cannot be the same as `dest`."): 54 | parse(BarConf, dest="boo", add_config_path_arg="boo") 55 | -------------------------------------------------------------------------------- /test/test_conflicts.py: -------------------------------------------------------------------------------- 1 | """Tests for weird conflicts.""" 2 | import argparse 3 | import functools 4 | from dataclasses import dataclass, field 5 | 6 | from simple_parsing import ArgumentParser 7 | 8 | from .testutils import TestSetup, raises 9 | 10 | 11 | def test_arg_and_dataclass_with_same_name(silent): 12 | @dataclass 13 | class SomeClass: 14 | a: int = 1 # some docstring for attribute 'a' 15 | 16 | parser = ArgumentParser() 17 | parser.add_argument("--a", default=123) 18 | with raises(argparse.ArgumentError): 19 | parser.add_arguments(SomeClass, dest="some_class") 20 | parser.parse_args("") 21 | 22 | 23 | def test_arg_and_dataclass_with_same_name_after_prefixing(silent): 24 | @dataclass 25 | class SomeClass: 26 | a: int = 1 # some docstring for attribute 'a' 27 | 28 | @dataclass 29 | class Parent: 30 | pre: SomeClass = field(default_factory=lambda: SomeClass) 31 | bla: SomeClass = field(default_factory=lambda: SomeClass) 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument("--pre.a", default=123, type=int) 35 | with 
raises(argparse.ArgumentError): 36 | parser.add_arguments(Parent, dest="some_class") 37 | parser.parse_args("--pre.a 123 --pre.a 456".split()) 38 | 39 | 40 | def test_weird_hierarchy(): 41 | @dataclass 42 | class Base: 43 | v: float = 0.0 44 | 45 | @dataclass 46 | class A(Base): 47 | pass 48 | 49 | @dataclass 50 | class B(Base): 51 | pass 52 | 53 | @dataclass 54 | class C(Base): 55 | pass 56 | 57 | @dataclass 58 | class Options: 59 | a: A = field(default_factory=functools.partial(A, 0.1)) 60 | b: B = field(default_factory=functools.partial(B, 0.2)) 61 | 62 | @dataclass 63 | class Settings(TestSetup): 64 | opt: Options = field(default_factory=Options) 65 | c: Base = field(default_factory=functools.partial(C, 0.3)) 66 | 67 | opt = Settings.setup("") 68 | print(opt) 69 | 70 | 71 | def test_parent_child_conflict(): 72 | @dataclass 73 | class HParams: 74 | batch_size: int = 32 75 | 76 | @dataclass 77 | class Parent2(TestSetup): 78 | batch_size: int = 48 79 | child: HParams = field(default_factory=HParams) 80 | 81 | p: Parent2 = Parent2.setup() 82 | assert p.child.batch_size == 32 83 | assert p.batch_size == 48 84 | -------------------------------------------------------------------------------- /test/test_default_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing import ArgumentParser 5 | 6 | from .testutils import raises_missing_required_arg, shlex 7 | 8 | 9 | def test_no_default_argument(simple_attribute, silent): 10 | some_type, passed_value, expected_value = simple_attribute 11 | 12 | @dataclass 13 | class SomeClass: 14 | a: some_type 15 | 16 | parser = ArgumentParser() 17 | parser.add_arguments(SomeClass, dest="some_class") 18 | 19 | args = parser.parse_args(shlex.split(f"--a {passed_value}")) 20 | assert args == argparse.Namespace(some_class=SomeClass(a=expected_value)) 21 | 22 | with raises_missing_required_arg(): 23 | parser.parse_args("") 
def test_default_dataclass_argument(simple_attribute, silent):
    """A dataclass instance passed as `default` makes all its fields optional."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type

    parser = ArgumentParser()
    parser.add_arguments(SomeClass, dest="some_class", default=SomeClass(a=expected_value))

    args = parser.parse_args("")
    assert args == argparse.Namespace(some_class=SomeClass(a=expected_value))


def test_default_dict_argument(simple_attribute, silent):
    """A plain dict can also be passed as `default`."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type

    parser = ArgumentParser()
    parser.add_arguments(SomeClass, dest="some_class", default={"a": expected_value})

    args = parser.parse_args("")
    assert args == argparse.Namespace(some_class=SomeClass(a=expected_value))


def test_default_dict_argument_override_cmdline(simple_attribute, silent):
    """Command-line values take precedence over dict defaults."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type

    parser = ArgumentParser()
    parser.add_arguments(SomeClass, dest="some_class", default={"a": 0})

    args = parser.parse_args(shlex.split(f"--a {passed_value}"))
    assert args == argparse.Namespace(some_class=SomeClass(a=expected_value))


def test_partial_default_dict_argument(simple_attribute, silent):
    """A partial dict default only makes the covered fields optional."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type
        b: int

    parser = ArgumentParser()
    parser.add_arguments(SomeClass, dest="some_class", default={"a": expected_value})

    args = parser.parse_args(shlex.split("--b 0"))
    assert args == argparse.Namespace(some_class=SomeClass(a=expected_value, b=0))
    # `b` has no default anywhere, so it stays required.
    with raises_missing_required_arg():
        parser.parse_args(shlex.split(f"--a {passed_value}"))
    with raises_missing_required_arg():
        parser.parse_args("")


# --------------------------------------------------------------------------------
# /test/test_fields.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import pytest

from simple_parsing import ArgumentParser, ConflictResolution, field
from simple_parsing.utils import str2bool
from simple_parsing.wrappers.field_parsing import parse_enum


def test_cmd_false_doesnt_create_conflicts():
    """A field with cmd=False adds no CLI option, so no conflict with B.batch_size."""

    @dataclass
    class A:
        batch_size: int = field(default=10, cmd=False)

    @dataclass
    class B:
        batch_size: int = 20

    # @dataclass
    # class Foo(TestSetup):
    #     a: A = mutable_field(A)
    #     b: B = mutable_field(B)

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(A, "a")
    parser.add_arguments(B, "b")
    args = parser.parse_args("--batch_size 32".split())
    a: A = args.a
    b: B = args.b
    # A keeps its (non-CLI) default; only B is affected by --batch_size.
    assert a == A()
    assert b == B(batch_size=32)


class Color(Enum):
    blue: str = "BLUE"
    red: str = "RED"
    green: str = "GREEN"
    orange: str = "ORANGE"


@pytest.mark.xfail(
    reason="Removed this function. TODO: see https://github.com/lebrice/SimpleParsing/issues/150."
)
@pytest.mark.parametrize(
    "annotation, expected_options",
    [
        (tuple[int, int], dict(nargs=2, type=int)),
        (tuple[Color, Color], dict(nargs=2, type=parse_enum(Color))),
        (
            Optional[tuple[Color, Color]],
            dict(nargs=2, type=parse_enum(Color), required=False),
        ),
        (list[str], dict(nargs="*", type=str)),
        (Optional[list[str]], dict(nargs="*", type=str, required=False)),
        (Optional[str], dict(nargs="?", type=str, required=False)),
        (Optional[bool], dict(nargs="?", type=str2bool, required=False)),
        # (Optional[Tuple[Color, str]], dict(nargs=2, type=get_parsing_fn(Tuple[Color, str]), required=False)),
    ],
)
def test_generated_options_from_annotation(annotation: type, expected_options: dict):
    # Placeholder body: the helper this exercised was removed (see xfail reason).
    raise NotImplementedError(
        """
    TODO: Would be a good idea to refactor the FieldWrapper class a bit. The args_dict (a dict
    of all the argparse arguments for a given field, that get passed to parser.add_arguments
    in the FieldWrapper) is currently created using a mix of three things (with increasing
    priority):
    - The type annotation
    - The dataclass context (e.g. when adding an Optional[Dataclass] field on another
    dataclass, or when using the `default` or `prefix` arguments to `parser.add_arguments`.
    - The manual overrides (arguments of `parser.add_argument` passed to the `field` function)

    These three are currently a bit mixed together in the `FieldWrapper` class. It would be
    preferable to design a way for them to be cleanly separated.
    """
    )
    # from simple_parsing.wrappers.field_wrapper import get_argparse_options_for_annotation

    # actual_options = get_argparse_options_for_annotation(annotation)
    # for option, expected_value in expected_options.items():
    #     assert actual_options[option] == expected_value


# --------------------------------------------------------------------------------
# /test/test_forward_ref.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass

from simple_parsing import field

from .testutils import TestSetup


@dataclass
class Foo(TestSetup):
    a: int = 123

    b: str = "fooobar"
    c: tuple[int, float] = (123, 4.56)

    d: list[bool] = field(default_factory=list)


@dataclass
class Bar(TestSetup):
    # String annotations ("Foo", "float", "list[float]") are forward references.
    barry: Foo = field(default_factory=Foo)
    joe: "Foo" = field(default_factory=lambda: Foo(b="rrrrr"))
    z: "float" = 123.456
    some_list: "list[float]" = field(default_factory=[1.0, 2.0].copy)


def test_forward_ref():
    foo = Foo.setup()
    assert foo == Foo()

    foo = Foo.setup("--a 2 --b heyo --c 1 7.89")
    assert foo == Foo(a=2, b="heyo", c=(1, 7.89))


def test_forward_ref_nested():
    bar = Bar.setup()
    assert bar == Bar()
    assert bar.barry == Foo()
    bar = Bar.setup("--barry.a 2 --barry.b heyo --barry.c 1 7.89")
    assert bar.barry == Foo(a=2, b="heyo", c=(1, 7.89))
    assert isinstance(bar.joe, Foo)


# --------------------------------------------------------------------------------
# /test/test_generation_mode.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass

import pytest

from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode

from .testutils import TestSetup


@dataclass
class ModelOptions:
    path: str
    device: str


@dataclass
class ServerOptions(TestSetup):
    host: str
    port: int
    model: ModelOptions


# Shared expected result for every generation-mode variant below.
expected = ServerOptions(host="myserver", port=80, model=ModelOptions(path="a_path", device="cpu"))


def test_flat():
    """FLAT (default) mode: only un-prefixed option names are generated."""
    options = ServerOptions.setup(
        "--host myserver " "--port 80 " "--path a_path " "--device cpu",
    )
    assert options == expected

    with pytest.raises(SystemExit):
        ServerOptions.setup(
            "--opts.host myserver "
            "--opts.port 80 "
            "--opts.model.path a_path "
            "--opts.model.device cpu",
            dest="opts",
        )


@pytest.mark.parametrize("without_root", [True, False])
def test_both(without_root):
    """BOTH mode: flat and fully-nested option names are both accepted."""
    options = ServerOptions.setup(
        "--host myserver " "--port 80 " "--path a_path " "--device cpu",
        dest="opts",
        argument_generation_mode=ArgumentGenerationMode.BOTH,
    )
    assert options == expected

    args = (
        "--opts.host myserver "
        "--opts.port 80 "
        "--opts.model.path a_path "
        "--opts.model.device cpu"
    )
    if without_root:
        args = args.replace("opts.", "")
    options = ServerOptions.setup(
        args,
        dest="opts",
        argument_generation_mode=ArgumentGenerationMode.BOTH,
        nested_mode=NestedMode.WITHOUT_ROOT if without_root else NestedMode.DEFAULT,
    )
    assert options == expected


@pytest.mark.parametrize("without_root", [True, False])
def test_nested(without_root):
    """NESTED mode: only the dotted option names are accepted; flat ones exit."""
    with pytest.raises(SystemExit):
        options = ServerOptions.setup(
            "--host myserver " "--port 80 " "--path a_path " "--device cpu",
            dest="opts",
            argument_generation_mode=ArgumentGenerationMode.NESTED,
        )

    args = (
        "--opts.host myserver "
        "--opts.port 80 "
        "--opts.model.path a_path "
        "--opts.model.device cpu"
    )
    if without_root:
        args = args.replace("opts.", "")
    options = ServerOptions.setup(
        args,
        dest="opts",
        argument_generation_mode=ArgumentGenerationMode.NESTED,
        nested_mode=NestedMode.WITHOUT_ROOT if without_root else NestedMode.DEFAULT,
    )
    assert options == expected


# --------------------------------------------------------------------------------
# /test/test_inheritance.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass, field

from simple_parsing.helpers import Serializable, choice, list_field

from .testutils import ConflictResolution, TestSetup, xfail


@dataclass
class Base(TestSetup):
    """Some extension of base-class `Base`"""

    a: int = 1


@dataclass
class ExtendedB(Base, TestSetup):
    b: int = 2


@dataclass
class ExtendedC(Base, TestSetup):
    c: int = 3


@dataclass
class Inheritance(TestSetup):
    # Two siblings that both inherit field `a` from Base — a name conflict.
    ext_b: ExtendedB = field(default_factory=ExtendedB)
    ext_c: ExtendedC = field(default_factory=ExtendedC)


def test_simple_subclassing_no_args():
    extended = ExtendedB.setup()
    assert extended.a == 1
    assert extended.b == 2


def test_simple_subclassing_with_args():
    extended = ExtendedB.setup("--a 123 --b 56")
    assert extended.a == 123
    assert extended.b == 56


# @xfail(reason="TODO: make sure this is how people would want to use this feature.")
def test_subclasses_with_same_base_class_no_args():
    ext = Inheritance.setup()

    assert ext.ext_b.a == 1
    assert ext.ext_b.b == 2

    assert ext.ext_c.a == 1
    assert ext.ext_c.c == 3


def test_subclasses_with_same_base_class_with_args():
    # AUTO conflict resolution: the shared `a` must be addressed with a prefix,
    # while the unique `b` / `c` options stay un-prefixed.
    ext = Inheritance.setup(
        "--ext_b.a 10 --b 20 --ext_c.a 30 --c 40",
        conflict_resolution_mode=ConflictResolution.AUTO,
    )
    assert ext.ext_b.a == 10
    assert ext.ext_b.b == 20

    assert ext.ext_c.a == 30
    assert ext.ext_c.c == 40
| 65 | 66 | @xfail( 67 | reason=( 68 | "Merging is not working yet with triangle inheritance, since we wouldn't " 69 | "know how to assign which value to which attribute.." 70 | ) 71 | ) 72 | def test_subclasses_with_same_base_class_with_args_merge(): 73 | ext = Inheritance.setup( 74 | "--a 10 30 --b 20 --c 40", 75 | conflict_resolution_mode=ConflictResolution.ALWAYS_MERGE, 76 | ) 77 | 78 | assert ext.ext_b.a == 10 79 | assert ext.ext_b.b == 20 80 | 81 | assert ext.ext_c.a == 30 82 | assert ext.ext_c.c == 40 83 | 84 | 85 | def test_weird_structure(): 86 | """Both is-a, and has-a at the same time, a very weird inheritance structure.""" 87 | 88 | @dataclass 89 | class ConvBlock(Serializable): 90 | """A Block of Conv Layers.""" 91 | 92 | n_layers: int = 4 # number of layers 93 | n_filters: list[int] = list_field(16, 32, 64, 64) # filters per layer 94 | 95 | @dataclass 96 | class GeneratorHParams(ConvBlock): 97 | """Settings of the Generator model.""" 98 | 99 | conv: ConvBlock = field(default_factory=ConvBlock) 100 | optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") 101 | 102 | @dataclass 103 | class DiscriminatorHParams(ConvBlock): 104 | """Settings of the Discriminator model.""" 105 | 106 | conv: ConvBlock = field(default_factory=ConvBlock) 107 | optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") 108 | 109 | @dataclass 110 | class SomeWeirdClass(TestSetup): 111 | gen: GeneratorHParams 112 | disc: DiscriminatorHParams 113 | 114 | s = SomeWeirdClass.setup() 115 | assert s.gen.conv.n_layers == 4 116 | assert s.gen.n_layers == 4 117 | assert s.disc.conv.n_layers == 4 118 | assert s.disc.n_layers == 4 119 | -------------------------------------------------------------------------------- /test/test_initvar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import InitVar, dataclass 3 | from typing import Any 4 | 5 | import pytest 6 | from typing_extensions import Literal 7 
| 8 | from .testutils import TestSetup 9 | 10 | 11 | @pytest.mark.skipif( 12 | sys.version_info[:2] < (3, 8), 13 | reason="Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type", 14 | ) 15 | @pytest.mark.parametrize( 16 | "tp, passed_value, expected", 17 | [ 18 | (int, "1", 1), 19 | (float, "1.4", 1.4), 20 | (tuple[int, float], "2 -1.2", (2, -1.2)), 21 | (list[str], "12 abc", ["12", "abc"]), 22 | (Literal[1, 2, 3, "4"], "1", 1), 23 | (Literal[1, 2, 3, "4"], "4", "4"), 24 | ], 25 | ) 26 | def test_initvar(tp: type[Any], passed_value: str, expected: Any) -> None: 27 | @dataclass 28 | class Foo(TestSetup): 29 | init_var: InitVar[tp] 30 | 31 | def __post_init__(self, init_var: tp) -> None: 32 | assert init_var == expected 33 | 34 | Foo.setup(f"--init_var {passed_value}") 35 | -------------------------------------------------------------------------------- /test/test_issue64.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from dataclasses import dataclass 3 | from io import StringIO 4 | from test.testutils import assert_help_output_equals 5 | 6 | import pytest 7 | 8 | from simple_parsing import ArgumentParser 9 | 10 | 11 | @dataclass 12 | class Options: 13 | """These are the options.""" 14 | 15 | foo: str = "aaa" # Description 16 | bar: str = "bbb" 17 | 18 | 19 | @pytest.mark.xfail(reason="Issue64 is solved below.") 20 | def test_reproduce_issue64(): 21 | parser = ArgumentParser("issue64") 22 | 23 | parser.add_arguments(Options, dest="options") 24 | 25 | # args = parser.parse_args(["--help"]) 26 | 27 | s = StringIO() 28 | parser.print_help(file=s) 29 | s.seek(0) 30 | 31 | assert s.read() == textwrap.dedent( 32 | """\ 33 | usage: issue64 [-h] [--foo str] [--bar str] 34 | 35 | optional arguments: 36 | -h, --help show this help message and exit 37 | 38 | Options ['options']: 39 | These are the options 40 | 41 | --foo str Description (default: aaa) 42 | --bar str 43 | """ 44 | ) 45 | 46 | 
47 | def test_vanilla_argparse_issue64(): 48 | """This test shows that the ArgumentDefaultsHelpFormatter of argparse doesn't add the 49 | "(default: xyz)" if the 'help' argument isn't already passed! 50 | 51 | This begs the question: Should simple-parsing add a 'help' argument always, so that the 52 | formatter can then add the default string after? 53 | """ 54 | import argparse 55 | 56 | parser = ArgumentParser("issue64", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 57 | 58 | group = parser.add_argument_group("Options ['options']", description="These are the options") 59 | group.add_argument("--foo", type=str, metavar="str", default="aaa", help="Description") 60 | group.add_argument("--bar", type=str, metavar="str", default="bbb") 61 | 62 | from io import StringIO 63 | 64 | s = StringIO() 65 | parser.print_help(file=s) 66 | s.seek(0) 67 | output = str(s.read()) 68 | assert_help_output_equals( 69 | actual=output, 70 | expected=textwrap.dedent( 71 | """\ 72 | usage: issue64 [-h] [--foo str] [--bar str] 73 | 74 | optional arguments: 75 | -h, --help show this help message and exit 76 | 77 | Options ['options']: 78 | These are the options 79 | 80 | --foo str Description (default: aaa) 81 | --bar str 82 | """ 83 | ), 84 | ) 85 | 86 | 87 | def test_solved_issue64(): 88 | """Test that shows that Issue 64 is solved now, by adding a single space as the 'help' 89 | argument, the help formatter can then add the "(default: bbb)" after the argument.""" 90 | parser = ArgumentParser("issue64") 91 | parser.add_arguments(Options, dest="options") 92 | 93 | s = StringIO() 94 | parser.print_help(file=s) 95 | s.seek(0) 96 | output = str(s.read()) 97 | assert_help_output_equals( 98 | actual=output, 99 | expected=textwrap.dedent( 100 | """\ 101 | usage: issue64 [-h] [--foo str] [--bar str] 102 | 103 | optional arguments: 104 | -h, --help show this help message and exit 105 | 106 | Options ['options']: 107 | These are the options 108 | 109 | --foo str Description (default: aaa) 
110 | --bar str (default: bbb) 111 | """ 112 | ), 113 | ) 114 | -------------------------------------------------------------------------------- /test/test_issue_107.py: -------------------------------------------------------------------------------- 1 | """Test for https://github.com/lebrice/SimpleParsing/issues/107.""" 2 | from dataclasses import dataclass 3 | from typing import Any 4 | 5 | import pytest 6 | 7 | from simple_parsing.helpers.serialization.serializable import Serializable 8 | 9 | 10 | @dataclass 11 | class Foo(Serializable): 12 | a: bool = False 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "passed, expected", 17 | [ 18 | ("True", True), 19 | ("False", False), 20 | (True, True), 21 | (False, False), 22 | ("true", True), 23 | ("false", False), 24 | ("1", True), 25 | ("0", False), 26 | (1, True), 27 | (0, False), 28 | ], 29 | ) 30 | def test_parsing_of_bool_works_as_expected(passed: Any, expected: bool): 31 | assert Foo.from_dict({"a": passed}) == Foo(a=expected) 32 | -------------------------------------------------------------------------------- /test/test_issue_132.py: -------------------------------------------------------------------------------- 1 | """Test for https://github.com/lebrice/SimpleParsing/issues/132.""" 2 | from dataclasses import dataclass 3 | 4 | from simple_parsing import field 5 | 6 | from .conftest import SimpleAttributeTuple 7 | from .testutils import TestSetup 8 | 9 | 10 | def test_field_with_custom_required_arg_is_optional( 11 | simple_attribute: SimpleAttributeTuple, 12 | ): 13 | """Test that the `field` function can be used as a work-around for issue 132. 14 | 15 | When passing `required` (or any other of the usual arguments of `parser.add_argument`) to the 16 | `field` function, they get saved into the `FieldWrapper`, and used to populate the args_dict 17 | that gets passed to `parser.add_arguments(*field_wrapper.option_strings, **args_dict)`. 
    Therefore, using `required=False` as a custom argument is a work-around, while we fix this
    issue.
    """
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class Foo(TestSetup):
        a: some_type = field(default=None, required=False)  # type: ignore

    # Omitting the argument keeps the None default; passing it parses the value.
    assert Foo.setup() == Foo(a=None)
    assert Foo.setup(f"--a {passed_value}") == Foo(a=expected_value)


def test_field_with_none_default_is_optional(simple_attribute: SimpleAttributeTuple):
    """Test that when the default value is None, the argument is treated as optional."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class Foo(TestSetup):
        a: some_type = None  # type: ignore

    assert Foo.setup() == Foo(a=None)
    assert Foo.setup(f"--a {passed_value}") == Foo(a=expected_value)


def test_dataclass_field_with_none_default_is_optional(
    simple_attribute: SimpleAttributeTuple,
):
    """Test that when the default value is None, the argument is treated as optional."""
    # Same idea as above, but the None default is on a *nested dataclass* field:
    # passing any of the inner arguments should materialize the whole Foo.
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class Foo(TestSetup):
        a: some_type  # type: ignore

    @dataclass
    class Bar(TestSetup):
        foo: Foo = None  # type: ignore

    assert Bar.setup() == Bar(foo=None)  # type: ignore
    assert Bar.setup(f"--a {passed_value}") == Bar(foo=Foo(a=expected_value))
-------------------------------------------------------------------------------- /test/test_issue_144.py: --------------------------------------------------------------------------------
"""Tests for issue 144: https://github.com/lebrice/SimpleParsing/issues/144."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Union

import pytest

from simple_parsing.helpers.serialization.serializable import Serializable


class TestOptional:
@dataclass 14 | class Foo(Serializable): 15 | foo: int | None = 123 16 | 17 | @pytest.mark.parametrize("d", [{"foo": None}, {"foo": 1}]) 18 | def test_round_trip(self, d: dict): 19 | # NOTE: this double round-trip makes the comparison agnostic to any conversion that may 20 | # happen between the raw dict values and the arguments of the dataclasses. 21 | assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d) 22 | 23 | 24 | class TestUnion: 25 | @dataclass 26 | class Foo(Serializable): 27 | foo: Union[int, dict[int, bool]] = 123 # noqa: UP007 28 | 29 | @pytest.mark.parametrize("d", [{"foo": None}, {"foo": {1: "False"}}]) 30 | def test_round_trip(self, d: dict): 31 | # NOTE: this double round-trip makes the comparison agnostic to any conversion that may 32 | # happen between the raw dict values and the arguments of the dataclasses. 33 | assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d) 34 | 35 | 36 | class TestList: 37 | @dataclass 38 | class Foo(Serializable): 39 | foo: list[int] = field(default_factory=list) 40 | 41 | @pytest.mark.parametrize("d", [{"foo": []}, {"foo": [123, 456]}]) 42 | def test_round_trip(self, d: dict): 43 | # NOTE: this double round-trip makes the comparison agnostic to any conversion that may 44 | # happen between the raw dict values and the arguments of the dataclasses. 45 | assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d) 46 | 47 | 48 | class TestTuple: 49 | @dataclass 50 | class Foo(Serializable): 51 | foo: tuple[int, float, bool] 52 | 53 | @pytest.mark.parametrize("d", [{"foo": (1, 1.2, False)}, {"foo": ("1", "1.2", "True")}]) 54 | def test_round_trip(self, d: dict): 55 | # NOTE: this double round-trip makes the comparison agnostic to any conversion that may 56 | # happen between the raw dict values and the arguments of the dataclasses. 
57 | assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d) 58 | 59 | 60 | class TestDict: 61 | @dataclass 62 | class Foo(Serializable): 63 | foo: dict[int, float] = field(default_factory=dict) 64 | 65 | @pytest.mark.parametrize("d", [{"foo": {}}, {"foo": {"123": "4.56"}}]) 66 | def test_round_trip(self, d: dict): 67 | # NOTE: this double round-trip makes the comparison agnostic to any conversion that may 68 | # happen between the raw dict values and the arguments of the dataclasses. 69 | assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d) 70 | -------------------------------------------------------------------------------- /test/test_issue_48.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from dataclasses import dataclass 3 | from io import StringIO 4 | from test.testutils import assert_help_output_equals 5 | 6 | from simple_parsing import ArgumentParser, field 7 | 8 | 9 | @dataclass 10 | class InputArgs: 11 | # Start date from which to collect data about base users. Input in iso format (YYYY-MM-DD). 12 | # The date is included in the data 13 | start_date: str = field(alias="s", metadata={"a": "b"}) 14 | 15 | # End date for collecting base users. Input in iso format (YYYY-MM-DD). The date is included in the data. 
16 | # Should not be before `start_date` 17 | end_date: str = field(alias="e") 18 | 19 | 20 | def test_issue_48(): 21 | parser = ArgumentParser("Prepare input data for training") 22 | parser.add_arguments(InputArgs, dest="args") 23 | s = StringIO() 24 | parser.print_help(file=s) 25 | s.seek(0) 26 | output = str(s.read()) 27 | assert_help_output_equals( 28 | actual=output, 29 | expected=textwrap.dedent( 30 | """\ 31 | usage: Prepare input data for training [-h] -s str -e str 32 | 33 | optional arguments: 34 | -h, --help show this help message and exit 35 | 36 | InputArgs ['args']: 37 | InputArgs(start_date:str, end_date:str) 38 | 39 | -s str, --start_date str 40 | Start date from which to collect data about base 41 | users. Input in iso format (YYYY-MM-DD). The date is 42 | included in the data (default: None) 43 | -e str, --end_date str 44 | End date for collecting base users. Input in iso 45 | format (YYYY-MM-DD). The date is included in the data. 46 | Should not be before `start_date` (default: None) 47 | """ 48 | ), 49 | ) 50 | 51 | # args = parser.parse_args() 52 | -------------------------------------------------------------------------------- /test/test_issue_96.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import ArgumentParser 4 | 5 | from .testutils import TestSetup, exits_and_writes_to_stderr 6 | 7 | 8 | def test_repro_issue_96(): 9 | @dataclass 10 | class Options(TestSetup): 11 | list_items: list[str] # SOMETHING 12 | 13 | parser = ArgumentParser(add_option_string_dash_variants=True) 14 | parser.add_arguments(Options, dest="options") 15 | 16 | with exits_and_writes_to_stderr(match="the following arguments are required: --list_items"): 17 | assert Options.setup("") 18 | 19 | assert Options.setup("--list_items foo") == Options(list_items=["foo"]) 20 | assert Options.setup("--list_items") == Options(list_items=[]) 21 | 
-------------------------------------------------------------------------------- /test/test_optional_union.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | from .testutils import TestSetup 6 | 7 | 8 | def test_optional_union(): 9 | @dataclass 10 | class Config(TestSetup): 11 | path: Optional[Union[Path, str]] = None 12 | 13 | config = Config.setup("") 14 | assert config.path is None 15 | 16 | config = Config.setup("--path") 17 | assert config.path is None 18 | 19 | config = Config.setup("--path bob") 20 | assert config.path == Path("bob") 21 | -------------------------------------------------------------------------------- /test/test_performance.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import sys 4 | from pathlib import Path 5 | from typing import Callable, TypeVar 6 | 7 | import pytest 8 | from pytest_benchmark.fixture import BenchmarkFixture 9 | 10 | from .testutils import needs_yaml 11 | 12 | C = TypeVar("C", bound=Callable) 13 | 14 | 15 | def import_sp(): 16 | assert "simple_parsing" not in sys.modules 17 | __import__("simple_parsing") 18 | 19 | 20 | def unimport_sp(): 21 | if "simple_parsing" in sys.modules: 22 | import simple_parsing # noqa 23 | 24 | del simple_parsing 25 | importlib.invalidate_caches() 26 | sys.modules.pop("simple_parsing") 27 | assert "simple_parsing" not in sys.modules 28 | 29 | 30 | def clear_lru_caches(): 31 | from simple_parsing.docstring import dp_parse, inspect_getdoc, inspect_getsource 32 | 33 | dp_parse.cache_clear() 34 | inspect_getdoc.cache_clear() 35 | inspect_getsource.cache_clear() 36 | 37 | 38 | def call_before(before: Callable[[], None], fn: C) -> C: 39 | @functools.wraps(fn) 40 | def wrapped(*args, **kwargs): 41 | before() 42 | return fn(*args, **kwargs) 43 | 44 | return wrapped # type: 
ignore 45 | 46 | 47 | @pytest.mark.benchmark( 48 | group="import", 49 | ) 50 | def test_import_performance(benchmark: BenchmarkFixture): 51 | # NOTE: Issue is that the `conftest.py` actually already imports simple-parsing! 52 | benchmark(call_before(unimport_sp, import_sp)) 53 | 54 | 55 | @pytest.mark.benchmark( 56 | group="parse", 57 | ) 58 | def test_parse_performance(benchmark: BenchmarkFixture): 59 | from test.nesting.example_use_cases import HyperParameters 60 | 61 | import simple_parsing as sp 62 | 63 | benchmark( 64 | call_before(clear_lru_caches, sp.parse), 65 | HyperParameters, 66 | args="--age_group.num_layers 5 --age_group.num_units 65 ", 67 | ) 68 | 69 | 70 | @pytest.mark.benchmark( 71 | group="serialization", 72 | ) 73 | @pytest.mark.parametrize("filetype", [pytest.param(".yaml", marks=needs_yaml), ".json", ".pkl"]) 74 | def test_serialization_performance(benchmark: BenchmarkFixture, tmp_path: Path, filetype: str): 75 | from test.test_huggingface_compat import TrainingArguments 76 | 77 | from simple_parsing.helpers.serialization import load, save 78 | 79 | args = TrainingArguments() 80 | path = (tmp_path / "bob").with_suffix(filetype) 81 | 82 | def save_and_load(): 83 | clear_lru_caches() 84 | # NOTE: can't just use unlink(missing_ok=True) since python3.7 doesn't have it. 
85 | if path.exists(): 86 | path.unlink() 87 | save(args, path) 88 | assert load(TrainingArguments, path) == args 89 | 90 | benchmark(save_and_load) 91 | -------------------------------------------------------------------------------- /test/test_positional.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from simple_parsing import field 4 | 5 | from .testutils import TestSetup 6 | 7 | 8 | def test_single_posarg(): 9 | @dataclass 10 | class Foo(TestSetup): 11 | output_dir: str = field(positional=True) 12 | extra_flag: bool = False 13 | 14 | foo = Foo.setup("/bob --extra_flag") 15 | assert foo.output_dir == "/bob" 16 | assert foo.extra_flag 17 | 18 | 19 | def test_repeated_posarg(): 20 | @dataclass 21 | class Foo(TestSetup): 22 | output_dir: list[str] = field(positional=True) 23 | extra_flag: bool = False 24 | 25 | # Here we see why 'invoke' wrote their own parser. Doesn't seem obvious how to explain to argparse that 26 | # --extra_flag /cherry (which is a little ambiguous) or --extra_flag True /cherry is something we'd like to allow. 
    foo = Foo.setup("/alice /bob /cherry --extra_flag")
    assert foo.output_dir == ["/alice", "/bob", "/cherry"]
    assert foo.extra_flag
-------------------------------------------------------------------------------- /test/test_replace_subgroups.py: --------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass, field

from simple_parsing import replace_subgroups, subgroups


@dataclass
class A:
    a: float = 0.0


@dataclass
class B:
    b: str = "bar"


@dataclass
class AorB:
    a_or_b: A | B = subgroups({"a": A, "b": B}, default_factory=A)


@dataclass(frozen=True)
class FrozenConfig:
    a: int = 1
    b: str = "bob"


odd = FrozenConfig(a=1, b="odd")
even = FrozenConfig(a=2, b="even")


@dataclass
class Config:
    subgroup: A | B = subgroups({"a": A, "b": B}, default_factory=A)
    frozen_subgroup: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd)
    optional: A | None = None
    implicit_optional: A = None
    union: A | B = field(default_factory=A)
    nested_subgroup: AorB = field(default_factory=AorB)


def test_replace_subgroups():
    """Each kind of field on `Config` can be swapped via `replace_subgroups`."""
    c = Config()
    # Subgroups are selectable by their key...
    assert replace_subgroups(c, {"subgroup": "b"}) == Config(subgroup=B())
    assert replace_subgroups(c, {"frozen_subgroup": "odd"}) == Config(frozen_subgroup=odd)
    # ...while optional / union fields are selectable by the dataclass itself.
    # (`implicit_optional` is annotated `A` but defaults to None on purpose.)
    assert replace_subgroups(c, {"optional": A}) == Config(optional=A())
    assert replace_subgroups(c, {"implicit_optional": A}) == Config(implicit_optional=A())
    assert replace_subgroups(c, {"union": B}) == Config(union=B())
    # Nested subgroups accept both a dotted key and a nested dict.
    assert replace_subgroups(c, {"nested_subgroup.a_or_b": "b"}) == Config(
        nested_subgroup=AorB(a_or_b=B())
    )
    assert replace_subgroups(c, {"nested_subgroup": {"a_or_b": "b"}}) == Config(
        nested_subgroup=AorB(a_or_b=B())
    )
-------------------------------------------------------------------------------- /test/test_subgroups/test_help[Config---help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclass 7 | class Config(TestSetup): 8 | # Which model to use 9 | model: ModelConfig = subgroups( 10 | {"model_a": ModelAConfig, "model_b": ModelBConfig}, 11 | default_factory=ModelAConfig, 12 | ) 13 | 14 | ``` 15 | 16 | and command: '--help' 17 | 18 | We expect to get: 19 | 20 | ```console 21 | usage: pytest [-h] [--model {model_a,model_b}] [--lr float] [--optimizer str] 22 | [--betas float float] 23 | 24 | options: 25 | -h, --help show this help message and exit 26 | 27 | Config ['config']: 28 | Config(model: 'ModelConfig' = ) 29 | 30 | --model {model_a,model_b} 31 | Which model to use (default: model_a) 32 | 33 | ModelAConfig ['config.model']: 34 | ModelAConfig(lr: 'float' = 0.0003, optimizer: 'str' = 'Adam', betas: 'tuple[float, float]' = (0.9, 0.999)) 35 | 36 | --lr float (default: 0.0003) 37 | --optimizer str (default: Adam) 38 | --betas float float (default: (0.9, 0.999)) 39 | 40 | ``` 41 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[Config---model=model_a --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclass 7 | class Config(TestSetup): 8 | # Which model to use 9 | model: ModelConfig = subgroups( 10 | {"model_a": ModelAConfig, "model_b": ModelBConfig}, 11 | default_factory=ModelAConfig, 12 | ) 13 | 14 | ``` 15 | 16 | and command: '--model=model_a --help' 17 | 18 | We expect to get: 19 | 20 | ```console 21 | usage: pytest [-h] [--model {model_a,model_b}] [--lr float] [--optimizer str] 22 | [--betas float float] 23 | 24 | 
options: 25 | -h, --help show this help message and exit 26 | 27 | Config ['config']: 28 | Config(model: 'ModelConfig' = ) 29 | 30 | --model {model_a,model_b} 31 | Which model to use (default: model_a) 32 | 33 | ModelAConfig ['config.model']: 34 | ModelAConfig(lr: 'float' = 0.0003, optimizer: 'str' = 'Adam', betas: 'tuple[float, float]' = (0.9, 0.999)) 35 | 36 | --lr float (default: 0.0003) 37 | --optimizer str (default: Adam) 38 | --betas float float (default: (0.9, 0.999)) 39 | 40 | ``` 41 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[Config---model=model_b --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclass 7 | class Config(TestSetup): 8 | # Which model to use 9 | model: ModelConfig = subgroups( 10 | {"model_a": ModelAConfig, "model_b": ModelBConfig}, 11 | default_factory=ModelAConfig, 12 | ) 13 | 14 | ``` 15 | 16 | and command: '--model=model_b --help' 17 | 18 | We expect to get: 19 | 20 | ```console 21 | usage: pytest [-h] [--model {model_a,model_b}] [--lr float] [--optimizer str] 22 | [--momentum float] 23 | 24 | options: 25 | -h, --help show this help message and exit 26 | 27 | Config ['config']: 28 | Config(model: 'ModelConfig' = ) 29 | 30 | --model {model_a,model_b} 31 | Which model to use (default: model_a) 32 | 33 | ModelBConfig ['config.model']: 34 | ModelBConfig(lr: 'float' = 0.001, optimizer: 'str' = 'SGD', momentum: 'float' = 1.234) 35 | 36 | --lr float (default: 0.001) 37 | --optimizer str (default: SGD) 38 | --momentum float (default: 1.234) 39 | 40 | ``` 41 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for 
test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclasses.dataclass 7 | class ConfigWithFrozen(TestSetup): 8 | conf: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) 9 | 10 | ``` 11 | 12 | and command: '--conf=even --a 100 --help' 13 | 14 | We expect to get: 15 | 16 | ```console 17 | usage: pytest [-h] [--conf {odd,even}] [-a int] [-b str] 18 | 19 | options: 20 | -h, --help show this help message and exit 21 | 22 | ConfigWithFrozen ['config_with_frozen']: 23 | ConfigWithFrozen(conf: 'FrozenConfig' = 'odd') 24 | 25 | --conf {odd,even} (default: odd) 26 | 27 | FrozenConfig ['config_with_frozen.conf']: 28 | FrozenConfig(a: 'int' = 1, b: 'str' = 'bob') 29 | 30 | -a int, --a int (default: 2) 31 | -b str, --b str (default: even) 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclasses.dataclass 7 | class ConfigWithFrozen(TestSetup): 8 | conf: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) 9 | 10 | ``` 11 | 12 | and command: '--conf=even --help' 13 | 14 | We expect to get: 15 | 16 | ```console 17 | usage: pytest [-h] [--conf {odd,even}] [-a int] [-b str] 18 | 19 | options: 20 | -h, --help show this help message and exit 21 | 22 | ConfigWithFrozen ['config_with_frozen']: 23 | ConfigWithFrozen(conf: 'FrozenConfig' = 'odd') 24 | 25 | --conf {odd,even} (default: odd) 26 | 27 | FrozenConfig ['config_with_frozen.conf']: 28 | FrozenConfig(a: 'int' = 1, b: 'str' = 'bob') 29 | 30 | -a int, --a int (default: 2) 31 | -b str, --b str (default: even) 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 
123 --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclasses.dataclass 7 | class ConfigWithFrozen(TestSetup): 8 | conf: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) 9 | 10 | ``` 11 | 12 | and command: '--conf=odd --a 123 --help' 13 | 14 | We expect to get: 15 | 16 | ```console 17 | usage: pytest [-h] [--conf {odd,even}] [-a int] [-b str] 18 | 19 | options: 20 | -h, --help show this help message and exit 21 | 22 | ConfigWithFrozen ['config_with_frozen']: 23 | ConfigWithFrozen(conf: 'FrozenConfig' = 'odd') 24 | 25 | --conf {odd,even} (default: odd) 26 | 27 | FrozenConfig ['config_with_frozen.conf']: 28 | FrozenConfig(a: 'int' = 1, b: 'str' = 'bob') 29 | 30 | -a int, --a int (default: 1) 31 | -b str, --b str (default: odd) 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclasses.dataclass 7 | class ConfigWithFrozen(TestSetup): 8 | conf: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) 9 | 10 | ``` 11 | 12 | and command: '--conf=odd --help' 13 | 14 | We expect to get: 15 | 16 | ```console 17 | usage: pytest [-h] [--conf {odd,even}] [-a int] [-b str] 18 | 19 | options: 20 | -h, --help show this help message and exit 21 | 22 | ConfigWithFrozen ['config_with_frozen']: 23 | ConfigWithFrozen(conf: 'FrozenConfig' = 'odd') 24 | 25 | --conf {odd,even} (default: odd) 26 | 27 | FrozenConfig ['config_with_frozen.conf']: 28 | FrozenConfig(a: 'int' = 1, b: 'str' = 'bob') 29 | 30 | -a int, --a int (default: 1) 31 | -b str, --b str (default: odd) 32 | 33 | ``` 34 | 
-------------------------------------------------------------------------------- /test/test_subgroups/test_help[ConfigWithFrozen---help].md: -------------------------------------------------------------------------------- 1 | # Regression file for test_subgroups.py::test_help 2 | 3 | Given Source code: 4 | 5 | ```python 6 | @dataclasses.dataclass 7 | class ConfigWithFrozen(TestSetup): 8 | conf: FrozenConfig = subgroups({"odd": odd, "even": even}, default=odd) 9 | 10 | ``` 11 | 12 | and command: '--help' 13 | 14 | We expect to get: 15 | 16 | ```console 17 | usage: pytest [-h] [--conf {odd,even}] [-a int] [-b str] 18 | 19 | options: 20 | -h, --help show this help message and exit 21 | 22 | ConfigWithFrozen ['config_with_frozen']: 23 | ConfigWithFrozen(conf: 'FrozenConfig' = 'odd') 24 | 25 | --conf {odd,even} (default: odd) 26 | 27 | FrozenConfig ['config_with_frozen.conf']: 28 | FrozenConfig(a: 'int' = 1, b: 'str' = 'bob') 29 | 30 | -a int, --a int (default: 1) 31 | -b str, --b str (default: odd) 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /test/test_union.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | 4 | from .testutils import TestSetup, exits_and_writes_to_stderr 5 | 6 | 7 | def test_union_type(): 8 | @dataclass 9 | class Foo(TestSetup): 10 | x: Union[int, float, str] = 0 11 | 12 | foo = Foo.setup("--x 1.2") 13 | assert foo.x == 1.2 14 | 15 | foo = Foo.setup("--x bob") 16 | assert foo.x == "bob" 17 | 18 | foo = Foo.setup("--x 2") 19 | assert foo.x == 2 and isinstance(foo.x, int) 20 | 21 | 22 | def test_union_type_raises_error(): 23 | @dataclass 24 | class Foo2(TestSetup): 25 | x: Union[int, float] = 0 26 | 27 | foo = Foo2.setup("--x 1.2") 28 | assert foo.x == 1.2 29 | 30 | with exits_and_writes_to_stderr(match="invalid int|float value: 'bob'"): 31 | foo = Foo2.setup("--x bob") 32 | 33 | foo = 
Foo2.setup("--x 2") 34 | assert foo.x == 2 and isinstance(foo.x, int) 35 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lebrice/SimpleParsing/2a69cebb38172b9317c0175aadde6d3b432371ce/test/utils/__init__.py -------------------------------------------------------------------------------- /test/utils/test_mutable_field.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Generic 4 | 5 | import pytest 6 | from typing_extensions import NamedTuple # For Generic NamedTuples 7 | 8 | from simple_parsing import mutable_field 9 | 10 | from ..conftest import T, default_values_for_type, simple_arguments 11 | from ..testutils import TestSetup 12 | 13 | 14 | @dataclass 15 | class A: 16 | a: str = "bob" 17 | 18 | 19 | @dataclass 20 | class B: 21 | # # shared_list: List = [] # not allowed. 
22 | # different_list: List = field(default_factory=list) 23 | if sys.version_info < (3, 11): 24 | shared: A = A() 25 | different: A = mutable_field(A, a="123") 26 | 27 | 28 | def test_mutable_field_sharing(): 29 | b1 = B() 30 | b2 = B() 31 | if sys.version_info < (3, 11): 32 | assert b1.shared is b2.shared 33 | assert b1.different is not b2.different 34 | 35 | 36 | class SimpleAttributeWithTwoDefaults(NamedTuple, Generic[T]): 37 | field_type: type[T] 38 | passed_cmdline_value: str 39 | expected_value: T 40 | default_value: T 41 | other_default_value: T 42 | 43 | 44 | @pytest.fixture( 45 | params=[ 46 | SimpleAttributeWithTwoDefaults( 47 | some_type, 48 | passed_value, 49 | expected_value, 50 | default_value, 51 | other_default_value=default_values_for_type[some_type][ 52 | (i + 1) % len(default_values_for_type[some_type]) 53 | ], 54 | ) 55 | for some_type, passed_value, expected_value in simple_arguments 56 | for i, default_value in enumerate(default_values_for_type[some_type]) 57 | ] 58 | ) 59 | def simple_attribute_with_two_defaults(request: pytest.FixtureRequest): 60 | return request.param 61 | 62 | 63 | def test_uses_default_from_field_kwargs( 64 | simple_attribute_with_two_defaults: SimpleAttributeWithTwoDefaults, 65 | ): 66 | ( 67 | field_type, 68 | passed_cmdline_value, 69 | expected_value, 70 | default_value, 71 | other_default_value, 72 | ) = simple_attribute_with_two_defaults 73 | 74 | @dataclass 75 | class Inner: 76 | a: field_type = other_default_value # type: ignore 77 | 78 | @dataclass 79 | class B(TestSetup): 80 | inner: Inner = mutable_field(Inner, a=default_value) 81 | 82 | assert Inner() == Inner(a=other_default_value) 83 | # Constructing the field works just like a regular dataclass field with a default factory: 84 | assert B() == B(inner=Inner(a=default_value)) 85 | # No arguments passed: Should do the same thing: 86 | assert B.setup("") == B(inner=Inner(a=default_value)) 87 | 88 | # Now, passing a value should 89 | assert 
B.setup(f"--a={passed_cmdline_value}") == B(inner=Inner(a=expected_value)) 90 | -------------------------------------------------------------------------------- /test/utils/test_yaml.py: -------------------------------------------------------------------------------- 1 | """Tests for serialization to/from yaml files.""" 2 | import textwrap 3 | from dataclasses import dataclass 4 | 5 | import pytest 6 | 7 | from simple_parsing import list_field 8 | 9 | yaml = pytest.importorskip("yaml") 10 | 11 | from simple_parsing.helpers.serialization.yaml_serialization import YamlSerializable # noqa: E402 12 | 13 | 14 | @dataclass 15 | class Point(YamlSerializable): 16 | x: int = 0 17 | y: int = 0 18 | 19 | 20 | @dataclass 21 | class Config(YamlSerializable): 22 | name: str = "train" 23 | bob: int = 123 24 | some_float: float = 1.23 25 | 26 | points: list[Point] = list_field() 27 | 28 | 29 | def test_dumps(): 30 | p1 = Point(x=1, y=6) 31 | p2 = Point(x=3, y=1) 32 | config = Config(name="heyo", points=[p1, p2]) 33 | assert config.dumps() == textwrap.dedent( 34 | """\ 35 | bob: 123 36 | name: heyo 37 | points: 38 | - x: 1 39 | y: 6 40 | - x: 3 41 | y: 1 42 | some_float: 1.23 43 | """ 44 | ) 45 | 46 | 47 | def test_dumps_loads(): 48 | p1 = Point(x=1, y=6) 49 | p2 = Point(x=3, y=1) 50 | config = Config(name="heyo", points=[p1, p2]) 51 | assert Config.loads(config.dumps()) == config 52 | 53 | assert config == Config.loads( 54 | textwrap.dedent( 55 | """\ 56 | bob: 123 57 | name: heyo 58 | points: 59 | - x: 1 60 | y: 6 61 | - x: 3 62 | y: 1 63 | some_float: 1.23 64 | """ 65 | ) 66 | ) 67 | 68 | 69 | # def test_save_yml(HyperParameters, tmpdir: Path): 70 | # hparams = HyperParameters.setup("") 71 | # tmp_path = Path(tmpdir / "temp.pth") 72 | # hparams.save(tmp_path) 73 | 74 | # _hparams = HyperParameters.load(tmp_path) 75 | # assert hparams == _hparams 76 | --------------------------------------------------------------------------------