├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── python-check-publish.yml ├── .gitignore ├── .readthedocs.yml ├── CONTRIBUTING.md ├── LICENSE.md ├── NOTICE ├── README.md ├── docs ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── advanced.md ├── api.rst ├── conf.py ├── index.rst ├── requirements.txt └── smart-arg-demo.gif ├── publish.sh ├── setup.cfg ├── setup.py ├── smart-arg-demo.gif ├── smart_arg.py └── test ├── smart_arg_demo.py └── test_smart_arg.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | Please delete options that are not relevant. 10 | 11 | - [ ] Bug fix (non-breaking change which fixes an issue) 12 | - [ ] New feature (non-breaking change which adds functionality) 13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 14 | 15 | ## List all changes 16 | Please list all changes in the commit. 17 | * change1 18 | * change2 19 | 20 | # Testing 21 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. 
Please also list any relevant details for your test configuration 22 | 23 | 24 | **Test Configuration**: 25 | * Firmware version: 26 | * Hardware: 27 | * Toolchain: 28 | * SDK: 29 | 30 | # Checklist 31 | 32 | - [ ] My code follows the style guidelines of this project 33 | - [ ] I have performed a self-review of my own code 34 | - [ ] I have commented my code, particularly in hard-to-understand areas 35 | - [ ] I have made corresponding changes to the documentation 36 | - [ ] My changes generate no new warnings 37 | - [ ] I have added tests that prove my fix is effective or that my feature works 38 | - [ ] New and existing unit tests pass locally with my changes 39 | - [ ] Any dependent changes have been merged and published in downstream modules 40 | -------------------------------------------------------------------------------- /.github/workflows/python-check-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python check and publish 5 | 6 | on: 7 | pull_request: 8 | push: 9 | branches: 10 | - master 11 | 12 | jobs: 13 | check: 14 | strategy: 15 | matrix: 16 | python-version: [3.6, 3.7, 3.8, 3.9] 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip setuptools 28 | python setup.py develop 29 | - name: flake8 30 | run: | 31 | pip install flake8 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 smart_arg.py 34 | - name: mypy 35 | run: | 36 | pip install types-pkg_resources 37 | [ ${{ 
matrix.python-version }} == 3.6 ] && pip install dataclasses types-dataclasses 38 | pip install mypy 39 | mypy 40 | - name: Test with pytest 41 | run: | 42 | pip install pytest 43 | pytest 44 | 45 | publish: 46 | runs-on: ubuntu-latest 47 | needs: check 48 | if: > 49 | github.ref == 'refs/heads/master' && github.event_name == 'push' && github.repository_owner == 'linkedin' 50 | && !contains(github.event.head_commit.message, 'NO_PUBLISH') 51 | 52 | steps: 53 | - uses: actions/checkout@v2 54 | with: 55 | fetch-depth: '0' # To fetch tags too 56 | - name: Set up Python 57 | uses: actions/setup-python@v2 58 | with: 59 | python-version: '3.7' 60 | - name: Install dependencies 61 | run: | 62 | python -m pip install --upgrade pip setuptools twine 63 | python setup.py develop 64 | - name: Build, publish and tag 65 | env: 66 | TWINE_USERNAME: __token__ 67 | TWINE_PASSWORD: ${{ secrets.pypi_token }} 68 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 69 | run: | 70 | bash publish.sh $(python -c "import smart_arg; print(smart_arg._base_version)") 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.egg 3 | *.egg-info/ 4 | *.iml 5 | *.ipr 6 | *.iws 7 | *.pyc 8 | *.pyo 9 | *.sublime-* 10 | .*.swo 11 | .*.swp 12 | .cache/ 13 | .coverage 14 | .direnv/ 15 | .env 16 | .envrc 17 | .gradle/ 18 | .idea/ 19 | .tox* 20 | .venv* 21 | .vscode/ 22 | /*/*pinned.txt 23 | /*/.mypy_cache/ 24 | /*/MANIFEST 25 | /*/activate 26 | /*/build/ 27 | /*/config 28 | /*/coverage.xml 29 | /*/dist/ 30 | /*/htmlcov/ 31 | /*/product-spec.json 32 | /build/ 33 | /config/external/ 34 | /dist/ 35 | /ligradle/ 36 | TEST-*.xml 37 | __pycache__/ 38 | 39 | 40 | 41 | # Added by mp-maker 42 | build 43 | .shelf 44 | -------------------------------------------------------------------------------- /.readthedocs.yml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | version: 3.7 5 | install: 6 | - requirements: docs/requirements.txt 7 | - method: pip 8 | path: . -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CONTRIBUTION 2 | 3 | ## Contribution Agreement 4 | 5 | As a contributor, you represent that the code you submit is your original work or 6 | that of your employer (in which case you represent you have the right to bind your 7 | employer). By submitting code, you (and, if applicable, your employer) are 8 | licensing the submitted code to LinkedIn and the open source community subject 9 | to the BSD 2-Clause license. 10 | 11 | ## Responsible Disclosure of Security Vulnerabilities 12 | 13 | **Do not file an issue on Github for security issues.** Please review 14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should 15 | be encrypted using PGP ([public key][pubkey]) and sent to 16 | [security@linkedin.com][disclosure_email] preferably with the title 17 | "Vulnerability in Github LinkedIn/smart-arg - <short summary>". 18 | 19 | ## Setup for Development 20 | 21 | ```shell-session 22 | # # Uncomment the next few lines to set up a virtual environment and install the packages as needed 23 | # python3 -m venv .venv 24 | # . .venv/bin/activate 25 | # pip install -U setuptools pytest flake8 mypy 26 | 27 | python3 setup.py develop 28 | 29 | # Happily make changes 30 | # ... 31 | 32 | mypy 33 | flake8 smart_arg.py 34 | pytest 35 | ``` 36 | 37 | 38 | ## Tips for Getting Your Pull Request Accepted 39 | 40 | 1. Make sure all new features are tested and the tests pass. 41 | 2. Bug fixes must include a test case demonstrating the error that it fixes. 42 | 3. Open an issue first and seek advice for your change before submitting 43 | a pull request.
Large features which have never been discussed are 44 | unlikely to be accepted. **You have been warned.** 45 | 46 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924 47 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676 48 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/smart-arg%20-%20%3Csummary%3E 49 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 2-CLAUSE LICENSE 2 | ==== 3 | BSD 2-CLAUSE LICENSE 4 | Copyright 2020 LinkedIn Corporation 5 | All Rights Reserved. 6 | 7 | Redistribution and use in source and binary forms, with or 8 | without modification, are permitted provided that the following 9 | conditions are met: 10 | 11 | 1. Redistributions of source code must retain the above copyright 12 | notice, this list of conditions and the following disclaimer. 13 | 2. Redistributions in binary form must reproduce the above 14 | copyright notice, this list of conditions and the following 15 | disclaimer in the documentation and/or other materials provided 16 | with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2020 LinkedIn Corporation 2 | All Rights Reserved. 3 | 4 | Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information. 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Smart Argument Suite (`smart-arg`) 2 | 3 | [![GitHub tag](https://img.shields.io/github/tag/linkedin/smart-arg.svg)](https://GitHub.com/linkedin/smart-arg/tags/) 4 | [![PyPI version](https://img.shields.io/pypi/v/smart-arg.svg)](https://pypi.python.org/pypi/smart-arg/) 5 | 6 | Smart Argument Suite (`smart-arg`) is a slim and handy Python library that helps one work safely and conveniently 7 | with the arguments that are represented by an immutable argument container class' fields 8 | ([`NamedTuple`](https://docs.python.org/3.7/library/typing.html?highlight=namedtuple#typing.NamedTuple) or 9 | [`dataclass`](https://docs.python.org/3.7/library/dataclasses.html#dataclasses.dataclass) out-of-box), 10 | and passed through command-line interfaces. 
11 | 12 | `smart-arg` promotes arguments type-safety, enables IDEs' code autocompletion and type hints 13 | functionalities, and helps one produce correct code. 14 | 15 | ![](smart-arg-demo.gif) 16 | 17 | ## Quick start 18 | 19 | The [`smart-arg`](https://pypi.org/project/smart-arg/) package is available through `pip`. 20 | ```shell 21 | pip3 install smart-arg 22 | ``` 23 | 24 | Users can bring or define, if not already, their argument container class -- a `NamedTuple` or `dataclass`, 25 | and then annotate it with `smart-arg` decorator `@arg_suite` in their Python scripts. 26 | 27 | Now an argument container class instance, e.g. `my_arg` of `MyArg` class, once created, is ready to be serialized by the `smart-arg` API -- 28 | `my_arg.__to_argv__()` to a sequence of strings, passed through the command-line interface 29 | and then deserialized back to an instance again by `my_arg = MyArg.__from_argv__(sys.argv[1:])`. 30 | 31 | ```python 32 | import sys 33 | from typing import NamedTuple, List, Tuple, Dict, Optional 34 | from smart_arg import arg_suite 35 | 36 | 37 | # Define the argument container class 38 | @arg_suite 39 | class MyArg(NamedTuple): 40 | """ 41 | MyArg is smart! 
(docstring goes to description) 42 | """ 43 | nn: List[int] # Comments go to argparse help 44 | a_tuple: Tuple[str, int] # a random tuple argument 45 | encoder: str # Text encoder type 46 | h_param: Dict[str, int] # Hyperparameters 47 | batch_size: Optional[int] = None 48 | adp: bool = True # bool is a bit tricky 49 | embedding_dim: int = 100 # Size of embedding vector 50 | lr: float = 1e-3 # Learning rate 51 | 52 | 53 | def cli_interfaced_job_scheduler(): 54 | """ 55 | This is to be called by the job scheduler to set up the job launching command, 56 | i.e., producer side of the Python job arguments 57 | """ 58 | # Create the argument container instance 59 | my_arg = MyArg(nn=[3], a_tuple=("str", 1), encoder='lstm', h_param={}, adp=False) # The patched argument container class requires keyword arguments to instantiate the class 60 | 61 | # Serialize the argument to command-line representation 62 | argv = my_arg.__to_argv__() 63 | cli = 'my_job.py ' + ' '.join(argv) 64 | # Schedule the job with command line `cli` 65 | print(f"Executing job:\n{cli}") 66 | # Executing job: 67 | # my_job.py --nn 3 --a_tuple str 1 --encoder lstm --h_param --batch_size None --adp False --embedding_dim 100 --lr 0.001 68 | 69 | 70 | def my_job(my_arg: MyArg): 71 | """ 72 | This is the actual job defined by the input argument my_arg, 73 | i.e., consumer side of the Python job arguments 74 | """ 75 | print(my_arg) 76 | # MyArg(nn=[3], a_tuple=('str', 1), encoder='lstm', h_param={}, batch_size=None, adp=False, embedding_dim=100, lr=0.001) 77 | 78 | # `my_arg` can be used later in the script in a typed manner, with the help of IDEs (type hints and auto-completion) 79 | # ... 80 | print(f"My network has {len(my_arg.nn)} layers with sizes of {my_arg.nn}.") 81 | # My network has 1 layers with sizes of [3].
82 | 83 | 84 | # my_job.py 85 | if __name__ == '__main__': 86 | # Deserialize the command-line representation of the argument back to a container instance 87 | arg_deserialized: MyArg = MyArg.__from_argv__(sys.argv[1:]) # Equivalent to `MyArg(None)`, one positional arg required to indicate the arg is a command-line representation. 88 | my_job(arg_deserialized) 89 | ``` 90 | 91 | ```shell-session 92 | > python my_job.py -h 93 | usage: my_job.py [-h] --nn [int [int ...]] --a_tuple str int --encoder str 94 | --h_param [str:int [str:int ...]] [--batch_size int] 95 | [--adp {True,False}] [--embedding_dim int] [--lr float] 96 | 97 | MyArg is smart! (docstring goes to description) 98 | 99 | optional arguments: 100 | -h, --help show this help message and exit 101 | --nn [int [int ...]] (List[int], required) Comments go to argparse help 102 | --a_tuple str int (Tuple[str, int], required) a random tuple argument 103 | --encoder str (str, required) Text encoder type 104 | --h_param [str:int [str:int ...]] 105 | (Dict[str, int], required) Hyperparameters 106 | --batch_size int (Optional[int], default: None) 107 | --adp {True,False} (bool, default: True) bool is a bit tricky 108 | --embedding_dim int (int, default: 100) Size of embedding vector 109 | --lr float (float, default: 0.001) Learning rate 110 | 111 | ``` 112 | ## Promoted practices 113 | * Focus on defining the arguments diligently, and let the `smart-arg` 114 | (backed by [argparse.ArgumentParser](https://docs.python.org/3/library/argparse.html#argumentparser-objects)) 115 | work its magic around command-line interface. 116 | * Always work directly with argument container class instances when possible, even if you only need to generate the command-line representation. 117 | * Stick to the default behavior and the basic features, think twice before using any of the [advanced features](https://smart-arg.readthedocs.io/en/latest/advanced.html#advanced-usages). 
118 | 119 | 120 | ## More detail 121 | For more features and implementation detail, please refer to the [documentation](https://smart-arg.readthedocs.io/). 122 | 123 | ## Contributing 124 | 125 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our contribution agreement and the process for submitting pull requests to us. 126 | 127 | ## License 128 | 129 | This project is licensed under the BSD 2-CLAUSE LICENSE - see the [LICENSE.md](LICENSE.md) file for details 130 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ../CONTRIBUTING.md -------------------------------------------------------------------------------- /docs/LICENSE.md: -------------------------------------------------------------------------------- 1 | ../LICENSE.md -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /docs/advanced.md: -------------------------------------------------------------------------------- 1 | ## More Features 2 | ### Post processing & validation 3 | 4 | A user can define a method in the argument container class to do post processing after the argument is created but 5 | before it is returned to the caller. 6 | For example, when a field's default value depends on some other input fields, one could use a default 7 | placeholder `LateInit`, and a `__post_init__` function to define the actual value.
8 | 9 | The `__post_init__` function can also be used to validate the argument after instantiation: 10 | 11 | ```python 12 | from typing import NamedTuple 13 | 14 | from smart_arg import arg_suite, LateInit 15 | 16 | 17 | @arg_suite 18 | class MyArg(NamedTuple): 19 | network: str 20 | _network = {'choices': ['cnn', 'mlp']} # Use enum instead if you can 21 | num_layers: int = LateInit 22 | 23 | def __post_init__(self) -> 'MyArg': 24 | assert self.network in self._network['choices'], f'Invalid network {self.network}' 25 | if self.num_layers is LateInit: 26 | if self.network == 'cnn': 27 | num_layers = 3 28 | elif self.network == 'mlp': 29 | num_layers = 5 30 | else: 31 | raise RuntimeError("Should not be reachable!") 32 | self.num_layers=num_layers 33 | else: 34 | assert self.num_layers >= 0, f"num_layers: {self.num_layers} cannot be negative" 35 | 36 | ``` 37 | Notes: 38 | * If any fields are assigned the default placeholder `LateInit`, a `__post_init__` is expected 39 | to be defined to replace every `LateInit` with an actual value; otherwise the internal validation will fail and trigger a `SmartArgError`. 40 | * Field mutations are only allowed in the `__post_init__` as part of the class instance construction. No mutations are 41 | allowed after construction, including manually calling `__post_init__`. 42 | * `__post_init__` of a `NamedTuple` works similarly to that of a `dataclass`.
43 | 44 | ## Supported Argument Container Classes 45 | ### [`NamedTuple`](https://docs.python.org/3.7/library/typing.html?highlight=namedtuple#typing.NamedTuple) 46 | * Strong immutability 47 | * Part of the Python distribution 48 | * No inheritance 49 | 50 | ### [`dataclass`](https://docs.python.org/3.7/library/dataclasses.html) 51 | * Weak immutability with native `@dataclass(frozen=True)` or `smart-arg` patched (when not `frozen`) 52 | * `pip install dataclasses` is needed for Python 3.6 53 | * Inheritance support 54 | * Native `__post_init__` support 55 | ## Advanced Usages 56 | 57 | By default, `smart-arg` supports the following types as fields of an argument container class: 58 | * primitives: `int`, `float`, `bool`, `str`, `enum.Enum` 59 | * `Tuple`: elements of the tuple are expected to be primitives 60 | * `Sequence`/`Set`: `Sequence[int]`, `Sequence[float]`, `Sequence[bool]`, `Sequence[str]`, `Sequence[enum.Enum]`, `Set[int]`, `Set[float]`, `Set[bool]`, `Set[str]`, `Set[enum.Enum]` 61 | * `Dict`: `Dict[int, int]`, `Dict[int, float]`, `Dict[int, bool]`, `Dict[int, str]`, `Dict[float, int]`, `Dict[float, float]`, 62 | `Dict[float, bool]`, `Dict[float, str]`, `Dict[bool, int]`, `Dict[bool, float]`, `Dict[bool, bool]`, `Dict[bool, str]`, 63 | `Dict[str, int]`, `Dict[str, float]`, `Dict[str, bool]`, `Dict[str, str]`, `Dict[enum.Enum, int/float/bool/str]`, `Dict[int/float/bool/str, enum.Enum]` 64 | * `Optional[AnyOtherSupportedType]`: Beware that any optional field is required to **default to `None`**. 65 | 66 | ### override argument Ser/De 67 | A user can change the parsing behavior of a certain field of an argument container class. 68 | One can only do this when the field's type is already supported by `smart-arg`.
69 | 70 | 71 | This is done by defining a private companion field whose name starts with "``__``" (double underscores) to override the keyword arguments 72 | to [ArgumentParser.add_argument](https://docs.python.org/3/library/argparse.html#the-add-argument-method) with a dictionary. 73 | The key '_serialization' defines an [`iterator/generator`](https://wiki.python.org/moin/Generators) and all other keys go to `argparse` 74 | for deserialization/parsing. 75 | 76 | ALERT: this can lead to **inconsistent behaviors** when one also generates the command-line 77 | representation of an argument container class instance, since it can only modify the deserialization 78 | behavior from the command-line representation. 79 | ```python 80 | from typing import NamedTuple, Sequence 81 | 82 | from smart_arg import arg_suite 83 | 84 | 85 | @arg_suite 86 | class MyTup(NamedTuple): 87 | a_list: Sequence[int] 88 | __a_list = {'choices': [200, 300], 'nargs': '+'} 89 | ``` 90 | 91 | ### override or extend the support of primitive and other types 92 | A user can use this provided functionality to change serialization and deserialization behavior for supported types and add support for additional types. 93 | * Users can override the handling of primitive types by defining additional `PrimitiveHandlerAddon`s. The basic primitive handler 94 | is defined in the source code as `PrimitiveHandlerAddon`. A user can pass the customized handlers to the decorator. 95 | * The same applies to type handlers: provide additional `TypeHandler`s and pass them in the decorator argument. `TypeHandler` deals with complex types 96 | other than primitive ones, such as `Sequence`, `Set`, `Dict`, `Tuple`, etc.
97 | 98 | ```python 99 | from math import sqrt 100 | from typing import NamedTuple, Any, Type 101 | 102 | from smart_arg import PrimitiveHandlerAddon, TypeHandler, custom_arg_suite 103 | 104 | 105 | # overwrite int primitive type handling by squaring it 106 | class IntHandlerAddon(PrimitiveHandlerAddon): 107 | @staticmethod 108 | def build_type(arg_type) -> Any: 109 | return lambda s: int(s) ** 2 110 | 111 | @staticmethod 112 | def build_str(arg) -> str: 113 | return str(int(sqrt(arg))) 114 | 115 | @staticmethod 116 | def handles(t: Type) -> bool: 117 | return t == int 118 | 119 | 120 | class IntTypeHandler(TypeHandler): 121 | def _build_other(self, kwargs, arg_type) -> None: 122 | kwargs.type = self.primitive_addons[0].build_type(arg_type) 123 | 124 | def handles(self, t: Type) -> bool: 125 | return t == int 126 | 127 | 128 | my_suite = custom_arg_suite(primitive_handler_addons=[IntHandlerAddon], type_handlers=[IntTypeHandler]) 129 | 130 | 131 | @my_suite 132 | class MyTuple(NamedTuple): 133 | a_int: int 134 | 135 | ``` -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ------------- 3 | .. automodule:: smart_arg 4 | :members: -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | import smart_arg 4 | 5 | # Add any Sphinx extension module names here, as strings. They can be extensions 6 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 7 | extensions = [ 8 | 'sphinx.ext.autodoc', 9 | 'sphinx_autodoc_typehints', 10 | 'sphinx.ext.intersphinx', 11 | 'recommonmark', 12 | ] 13 | 14 | version = release = smart_arg.__version__ 15 | # Add any paths that contain templates here, relative to this directory. 
16 | templates_path = ['_templates'] 17 | 18 | # The suffix of source filenames. 19 | source_suffix = '.rst' 20 | 21 | # The master toctree document. 22 | master_doc = 'index' 23 | 24 | project = u'Smart Argument Suite' 25 | 26 | # General information about the project. 27 | copyright = f'2020-{datetime.datetime.today().year}, LinkedIn' 28 | 29 | # List of patterns, relative to source directory, that match files and 30 | # directories to ignore when looking for source files. 31 | exclude_patterns = ['_build'] 32 | 33 | # The name of the Pygments (syntax highlighting) style to use. 34 | pygments_style = 'sphinx' 35 | 36 | # -- Options for HTML output --------------------------------------------------- 37 | 38 | html_theme = 'default' 39 | # html_static_path = ['_static'] 40 | 41 | # Example configuration for intersphinx: refer to the Python standard library. 42 | intersphinx_mapping = { 43 | 'python': ('http://docs.python.org/', None), 44 | } 45 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to ``smart-arg``'s documentation! 2 | ============================================ 3 | Contents: 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | :glob: 8 | 9 | README 10 | advanced 11 | api 12 | CONTRIBUTING 13 | LICENSE 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx_autodoc_typehints -------------------------------------------------------------------------------- /docs/smart-arg-demo.gif: -------------------------------------------------------------------------------- 1 | ../smart-arg-demo.gif -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | function git_tag() { 4 | # get repo name from git 5 | repo=$(git config --get remote.origin.url | sed 's|^.*[/:]\(.*/.*\)\(\.git\)\?$|\1|') 6 | commit=$(git rev-parse HEAD) 7 | new_tag=$1 8 | git tag "$new_tag" 9 | 10 | echo >&2 POST the new tag ref "$new_tag" to "$repo" via Github API 11 | HTTP_STATUS=$( 12 | curl -w "%{http_code}" -o >(cat >&2) -si -X POST "https://api.github.com/repos/$repo/git/refs" \ 13 | -H "Authorization: token $GITHUB_TOKEN" \ 14 | -d @- </dev/null); then 32 | echo >&2 last_matched_tag is "$last_matched_tag" for "$star_version" 33 | rest=${last_matched_tag#v$star_version} 34 | increment_version=$((${rest%%.*} + 1)) 35 | [ "$(git rev-parse HEAD)" == "$(git rev-parse "$last_matched_tag")" ] && echo HEAD already at "$last_matched_tag" >&2 && return 126 36 | fi 37 | resolved_version=$star_version_prefix$((increment_version)) 38 | echo "$star_version is resolved to version '$resolved_version'." >&2 39 | echo "$resolved_version" 40 | else 41 | echo "Using the input star version '$star_version' as a fixed version." 
>&2 42 | echo "$star_version" 43 | fi 44 | else 45 | echo -e "Unsupported star version: '$star_version'.\nOnly supports star minor or patch version." >&2 46 | return 127 47 | fi 48 | } 49 | 50 | # Set star_version to the first argument or $STAR_VERSION or $($STAR_VERSION_CMD) in this order 51 | star_version=${1-${STAR_VERSION-$($STAR_VERSION_CMD)}} 52 | new_version=$(resolve_version "$star_version") 53 | git_tag v"$new_version" 54 | 55 | rm -rf dist 56 | python setup.py sdist 57 | twine upload dist/* --verbose 58 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E731, W503 3 | max-line-length = 160 4 | 5 | [tool:pytest] 6 | testpaths = test 7 | 8 | [coverage:report] 9 | fail_under = 90 10 | show_missing = true 11 | 12 | [coverage:run] 13 | branch = true 14 | 15 | [mypy] 16 | files = smart_arg.py 17 | ignore_missing_imports = true -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Argument class <=> Human friendly cli""" 2 | 3 | from setuptools import setup 4 | 5 | doc = 'https://smart-arg.readthedocs.io' 6 | with open('README.md', encoding='utf-8') as f: 7 | readme = f.read() 8 | 9 | setup( 10 | name='smart-arg', 11 | use_scm_version={ 12 | 'root': '.', 13 | 'relative_to': __file__ 14 | }, 15 | setup_requires=['setuptools_scm'], 16 | description=__doc__, 17 | long_description=readme, 18 | long_description_content_type='text/markdown', 19 | license='BSD-2-CLAUSE', 20 | python_requires='>=3.6', 21 | url=doc, 22 | download_url='https://pypi.python.org/pypi/smart-arg', 23 | project_urls={ 24 | 'Documentation': doc, 25 | 'Source': 'https://github.com/linkedin/smart-arg.git', 26 | 'Tracker': 'https://github.com/linkedin/smart-arg/issues', 27 | }, 28 | py_modules=['smart_arg'], 29 | 
install_requires=[], 30 | tests_require=['pytest', 'mypy'], 31 | classifiers=[ 32 | 'Programming Language :: Python', 33 | 'Programming Language :: Python :: 3', 34 | 'Operating System :: OS Independent', 35 | 'License :: OSI Approved', 36 | 'Typing :: Typed' 37 | ], 38 | keywords=[ 39 | 'typing', 40 | 'argument parser', 41 | 'reverse argument parser', 42 | 'human friendly', 43 | 'configuration (de)serialization', 44 | 'python', 45 | 'cli' 46 | ] 47 | ) 48 | -------------------------------------------------------------------------------- /smart-arg-demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/smart-arg/cc762a7f8b0bb2abf98d380492ebc7e40fae927a/smart-arg-demo.gif -------------------------------------------------------------------------------- /smart_arg.py: -------------------------------------------------------------------------------- 1 | """Smart Argument Suite 2 | 3 | This module is an argument serialization and deserialization library that: 4 | 5 | - handles two-way conversions: a typed, preferably immutable argument container object 6 | (a `NamedTuple` or `dataclass` instance) <==> a command-line argv 7 | - enables IDE type hints and code auto-completion by using `NamedTuple` or `dataclass` 8 | - promotes type-safety of cli-ready arguments 9 | 10 | 11 | The following is a simple usage example:: 12 | 13 | # Define the argument container class: 14 | @arg_suite 15 | class MyArg(NamedTuple): 16 | '''MyArg is smart and the docstring goes to description''' 17 | 18 | nn: List[int] # Comments go to ArgumentParser argument help 19 | 20 | a_tuple: Tuple[str, int] # Arguments without defaults are treated as "required = True" 21 | h_param: Dict[str, int] # Also supports List, Set 22 | 23 | ###### arguments without defaults go first as required by NamedTuple ###### 24 | 25 | l_r: float = 0.01 # Arguments with defaults are treated as "required = False" 26 | n: Optional[int] = None # 
Optional can only default to `None` 27 | 28 | 29 | # Create the corresponding argument container class instance from the command line argument list: 30 | # by using 31 | arg: MyArg = MyArg.__from_argv__(sys.argv[1:]) # with the factory method; needs the manual type hint `: MyArg` to help the IDE 32 | # or the monkey-patched constructor of the annotated argument container class 33 | arg = MyArg(sys.argv[1:]) # with the monkey-patched constructor taking one positional argument; no manual hint needed 34 | 35 | # Create a NamedTuple argument instance and generate its command line counterpart: 36 | # the monkey-patched constructor only takes keyword arguments for directly creating the NamedTuple 37 | arg = MyArg(nn=[200, 300], a_tuple=('t', 0), h_param={'batch_size': 1000}) 38 | # generate the corresponding command line argument list 39 | arg.__to_argv__() 40 | 41 | 42 | The module contains the following public classes/functions: 43 | 44 | - `arg_suite` -- The main entry point for Smart Argument Suite. As the 45 | example above shows, this decorator will attach an `ArgSuite` instance to 46 | the argument container class `NamedTuple` or `dataclass` and patch it. 47 | 48 | - `PrimitiveHandlerAddon` -- The base class for basic operations on primitive types; 49 | users can implement their own to change the behavior. 50 | 51 | - `TypeHandler` -- The base class for handling types; users can extend it to expand the 52 | support or change existing behaviors.
53 | 54 | All other classes and methods in this module are considered implementation details.""" 55 | 56 | import logging 57 | import os 58 | import sys 59 | from argparse import Action, ArgumentParser 60 | from collections import abc 61 | from enum import EnumMeta, Enum 62 | from itertools import chain 63 | from types import SimpleNamespace 64 | from typing import Any, Callable, Dict, Generic, Iterable, List, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union 65 | from warnings import warn 66 | 67 | import pkg_resources 68 | 69 | _base_version = '1.1.*' # star version should be auto-resolved to concrete number when run setup.py 70 | __version__ = pkg_resources.get_distribution(__name__).version 71 | __all__ = ( 72 | 'arg_suite', 73 | 'custom_arg_suite', 74 | 'frozenlist', 75 | 'LateInit', 76 | 'SmartArgError', 77 | 'TypeHandler', 78 | 'PrimitiveHandlerAddon', 79 | ) 80 | 81 | ArgType = TypeVar('ArgType', bound=NamedTuple) # NamedTuple is not a real class bound, but this makes mypy happier 82 | NoneType = None.__class__ 83 | FieldMeta = NamedTuple('FieldMeta', (('comment', str), ('default', Any), ('type', Type), ('optional', bool))) 84 | KwargsType = SimpleNamespace 85 | 86 | logger = logging.getLogger(__name__) 87 | SMART_ARG_LOG_LEVEL = 'SMART_ARG_LOG_LEVEL' 88 | if SMART_ARG_LOG_LEVEL in os.environ: 89 | logger.addHandler(logging.StreamHandler()) 90 | log_level = os.environ[SMART_ARG_LOG_LEVEL].upper() 91 | logger.setLevel(log_level) 92 | logger.info(f"Detected environment var 'SMART_ARG_LOG_LEVEL', set log level to '{log_level}' and log to stream.") 93 | 94 | if sys.version_info >= (3, 8): 95 | from typing import get_origin, get_args 96 | elif sys.version_info >= (3, 7): 97 | # Python == 3.7.x. Defining the back-ported get_origin, get_args 98 | # 3.7 `List.__origin__ == list` 99 | get_origin, get_args = lambda tp: getattr(tp, '__origin__', None), lambda tp: getattr(tp, '__args__', []) 100 | elif sys.version_info >= (3, 6): 101 | # Python == 3.6.x. 
Defining the back-ported get_origin, get_args 102 | # 3.6 `List.__origin__ == List`, `Optional` does not have `__dict__` 103 | get_origin, get_args = lambda tp: getattr(tp, '__extra__', ()) or getattr(tp, '__origin__', None), lambda tp: getattr(tp, '__args__', ()) or [] 104 | else: 105 | try: 106 | warn(f"Unsupported and untested python version {sys.version_info} < 3.6. " 107 | f"The package may or may not work properly. " 108 | f"Try 'from typing_inspect import get_origin, get_args' now.") 109 | from typing_inspect import get_args, get_origin 110 | except ImportError: 111 | warn(f"`from typing_inspect import get_origin, get_args` failed for " 112 | f"unsupported python version {sys.version_info} < 3.6. It might work with 'pip install typing_inspect'.") 113 | raise 114 | 115 | frozenlist = tuple # tuple to simulate an immutable list. [https://www.python.org/dev/peps/pep-0603/#rationale] 116 | _black_hole = lambda *_, **__: None 117 | _mro_all = lambda arg_class, get_dict: {k: v for b in (*arg_class.__mro__[-1:0:-1], arg_class) if b.__module__ != 'builtins' for k, v in get_dict(b).items()} 118 | # note: Fields without type annotation won't be regarded as a property/entry _fields. 119 | _annotations = lambda arg_class: _mro_all(arg_class, lambda b: getattr(b, '__annotations__', {})) 120 | try: 121 | from dataclasses import MISSING, asdict, is_dataclass 122 | except ImportError: 123 | logger.warning("Importing dataclasses failed. You might need to 'pip install dataclasses' on python 3.6.") 124 | class MISSING: pass # type: ignore # noqa: E701 125 | is_dataclass = _black_hole # type: ignore # Always return None 126 | 127 | 128 | class LateInit: 129 | """special singleton/class to mark late initialized fields""" 130 | 131 | 132 | class SmartArgError(Exception): # TODO Extend to better represent different types of errors. 
133 | """Base exception for smart-arg.""" 134 | 135 | 136 | def _raise_if(message: str, condition: Any = True): 137 | if condition: 138 | raise SmartArgError(message) 139 | 140 | 141 | def _first_handles(with_handles, arg_type, default: Optional[bool] = None): 142 | handle = next(filter(lambda h: h.handles(arg_type), with_handles), default) 143 | return handle if handle is not None else _raise_if(f"{arg_type!r} is not supported.") 144 | 145 | 146 | class PrimitiveHandlerAddon: 147 | """Primitive handler addon defines some basic operations on primitive types. Only `staticmethod` is expected. 148 | Users can extend/modify the primitive handling by inheriting this class.""" 149 | @staticmethod 150 | def build_type(arg_type: Type) -> Union[Type, Callable[[str], Any]]: 151 | return (lambda s: True if s == 'True' else False if s == 'False' else _raise_if(f"Invalid bool: {s!r}")) if arg_type is bool else \ 152 | (lambda s: getattr(arg_type, s, None) or _raise_if(f"Invalid enum {s!r} for {arg_type!r}")) if type(arg_type) is EnumMeta else \ 153 | arg_type 154 | 155 | @staticmethod 156 | def build_str(arg: Any) -> str: 157 | """Define to serialize the `arg` to a string. 
158 | 159 | :param arg: The argument 160 | :type arg: Any type that is supported by this class 161 | :return: The string serialization of `arg`""" 162 | return str(arg.name) if isinstance(arg, Enum) else str(arg) 163 | 164 | @staticmethod 165 | def build_metavar(arg_type: Type) -> str: 166 | """Define the hint string in argument help message for `arg_type`.""" 167 | return '{True, False}' if arg_type is bool else \ 168 | f"{{{', '.join(map(str, arg_type._member_names_))}}}" if type(arg_type) is EnumMeta else \ 169 | arg_type.__name__ 170 | 171 | @staticmethod 172 | def build_choices(arg_type) -> Optional[Iterable[Any]]: 173 | """Enumerate `arg_type` if possible, or return `None`.""" 174 | return (True, False) if arg_type is bool else \ 175 | arg_type if type(arg_type) is EnumMeta else \ 176 | None 177 | 178 | @staticmethod 179 | def handles(t: Type) -> bool: 180 | return t in (str, int, float, bool) or type(t) is EnumMeta 181 | 182 | 183 | class TypeHandler: 184 | """Base type handler. A subclass typically implements `gen_cli_arg` for serialization and `_build_other` for deserialization.""" 185 | def __init__(self, primitive_addons: Sequence[Type[PrimitiveHandlerAddon]]): 186 | self.primitive_addons = primitive_addons 187 | 188 | def _build_common(self, kwargs: KwargsType, field_meta: FieldMeta, parent_required: bool) -> None: 189 | """Build `help`, `default` and `required` for keyword arguments for ArgumentParser.add_argument 190 | 191 | :param parent_required: is parent a required field or not 192 | :param kwargs: the keyword argument KwargsType object 193 | :param field_meta: the meta information extracted from the NamedTuple class""" 194 | # Build help message 195 | arg_type = field_meta.type 196 | help_builder = ['(', 'Optional[' if field_meta.optional else '', self._type_to_str(arg_type), ']' if field_meta.optional else ''] 197 | # Get default if specified and set required if no default 198 | kwargs.required = parent_required and field_meta.default is MISSING 
199 | message_required = 'required' if kwargs.required else \ 200 | 'to be post-processed' if field_meta.default is LateInit else \ 201 | 'required if its container is being specified or marked' if field_meta.default is MISSING else \ 202 | f'default: {field_meta.default}' # informational only. The default is set when creating the argument container instance. 203 | help_builder.extend(('; ', message_required, ') ', field_meta.comment)) # Add from source code comment 204 | kwargs.help = ''.join(help_builder) 205 | 206 | def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None: 207 | """Build `nargs`, `type` and `metavar` for the keyword argument KwargsType object 208 | 209 | :param kwargs: the keyword argument KwargsType object 210 | :param arg_type: the type of the argument extracted from NamedTuple (primitive types)""" 211 | 212 | def gen_kwargs(self, field_meta: FieldMeta, parent_required: bool) -> KwargsType: 213 | """Build keyword argument object KwargsType 214 | 215 | :param parent_required: if the parent container is required 216 | :param field_meta: argument metadata information 217 | :return: keyword argument object""" 218 | kwargs = KwargsType() 219 | self._build_common(kwargs, field_meta, parent_required) 220 | self._build_other(kwargs, field_meta.type) 221 | return kwargs 222 | 223 | def gen_cli_arg(self, arg: Any) -> Iterable[str]: 224 | """Generate command line for argument. This defines the serialization process. 225 | 226 | :param arg: value of the argument 227 | :return: iterable command line str""" 228 | args = (arg,) if isinstance(arg, str) or not isinstance(arg, Iterable) else arg 229 | yield from map(lambda a: _first_handles(self.primitive_addons, type(a)).build_str(a), args) 230 | 231 | def _type_to_str(self, t: Union[type, Type]) -> str: 232 | """Convert type to string for ArgumentParser help message 233 | 234 | :param t: type of the argument, i.e. float, Dict[str, int], Set[int], List[str] etc. 
235 | :return: string representation of the argument type""" 236 | return f'{getattr(t, "_name", "") or t.__name__}[{", ".join(map(lambda a: a.__name__, get_args(t)))}]' 237 | 238 | def handles(self, t: Type) -> bool: 239 | raise NotImplementedError 240 | 241 | 242 | class PrimitiveHandler(TypeHandler): 243 | def handles(self, t: Type) -> bool: 244 | return _first_handles(self.primitive_addons, t, False) 245 | 246 | def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None: 247 | addon = _first_handles(self.primitive_addons, arg_type) 248 | kwargs.type = addon.build_type(arg_type) 249 | kwargs.metavar = addon.build_metavar(arg_type) 250 | kwargs.choices = addon.build_choices(arg_type) 251 | 252 | def _type_to_str(self, t: Union[type, Type]) -> str: 253 | return t.__name__ 254 | 255 | 256 | class TupleHandler(TypeHandler): 257 | class __BuildType: 258 | def __init__(self, types, p_addons): 259 | self.__name__ = f'Tuple[{", ".join(t.__name__ for t in types)}]' 260 | self.counter, self.types, self.p_addons = 0, types, p_addons 261 | 262 | def __call__(self, s): 263 | if self.counter == len(self.types): 264 | self.counter = 0 265 | t = self.types[self.counter] 266 | self.counter += 1 267 | return _first_handles(self.p_addons, t).build_type(t)(s) 268 | 269 | def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None: 270 | # get the tuple element types 271 | types = get_args(arg_type) 272 | kwargs.nargs = len(types) 273 | kwargs.metavar = tuple(map(lambda t: _first_handles(self.primitive_addons, t).build_metavar(t), types)) 274 | kwargs.type = TupleHandler.__BuildType(types, self.primitive_addons) 275 | 276 | def handles(self, t: Type) -> bool: 277 | return get_origin(t) is tuple and get_args(t) # type: ignore 278 | 279 | 280 | class CollectionHandler(TypeHandler): 281 | def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None: 282 | kwargs.nargs = '*' 283 | unboxed_type = get_args(arg_type)[0] 284 | addon = 
_first_handles(self.primitive_addons, unboxed_type) 285 | kwargs.metavar = addon.build_metavar(unboxed_type) 286 | kwargs.type = addon.build_type(unboxed_type) 287 | 288 | def handles(self, t: Type) -> bool: 289 | args = get_args(t) 290 | return len(args) == 1 and get_origin(t) in (list, set, frozenset, abc.Sequence) and _first_handles(self.primitive_addons, args[0], False) 291 | 292 | 293 | class DictHandler(TypeHandler): 294 | def _build_other(self, kwargs, arg_type) -> None: 295 | addon_method = lambda t, method: getattr(_first_handles(self.primitive_addons, t), method)(t) # Find the addon for a type and a method 296 | kv_apply = lambda method, arg_types=get_args(arg_type): (addon_method(arg_types[0], method), addon_method(arg_types[1], method)) # Apply on k/v pair 297 | k_type, v_type = kv_apply('build_type') 298 | kwargs.nargs = '*' 299 | 300 | def dict_type(s: str): 301 | k, v = s.split(":") 302 | return k_type(k), v_type(v) 303 | kwargs.type = dict_type 304 | k, v = kv_apply('build_metavar') 305 | kwargs.metavar = f'{k}:{v}' 306 | 307 | def gen_cli_arg(self, arg): 308 | arg_to_str = lambda arg_v: _first_handles(self.primitive_addons, type(arg_v)).build_str(arg_v) 309 | yield from map(lambda kv: f'{arg_to_str(kv[0])}:{arg_to_str(kv[1])}', arg.items()) 310 | 311 | def handles(self, t: Type) -> bool: 312 | args, addons = get_args(t), self.primitive_addons 313 | return len(args) == 2 and get_origin(t) is dict and _first_handles(addons, args[0], False) and _first_handles(addons, args[1], False) 314 | 315 | 316 | class _namedtuple: # TODO expand lambdas to static methods or use a better holder representation 317 | """A NamedTuple proxy, helper function holder for NamedTuple support""" 318 | @staticmethod 319 | def new_instance(arg_class, kwargs): 320 | """:return A new instance of `arg_class`: call original __new__ -> __post_init__ -> post_validation""" 321 | new_instance = arg_class.__original_new__(arg_class, **kwargs) 322 | post_init = getattr(arg_class, 
'__post_init__', None) 323 | if post_init: 324 | fake_namedtuple = SimpleNamespace(**new_instance._asdict()) 325 | post_init(fake_namedtuple) # make the faked NamedTuple mutable in post_init only while initialization 326 | new_instance = arg_class.__original_new__(arg_class, **vars(fake_namedtuple)) 327 | arg_class.__arg_suite__.post_validation(new_instance) 328 | return new_instance 329 | 330 | @staticmethod 331 | def proxy(t: Type): 332 | """:return This proxy class if `t` is a `NamedTuple` or `None`""" 333 | b, f, f_t = getattr(t, '__bases__', []), getattr(t, '_fields', []), getattr(t, '__annotations__', {}) 334 | return _namedtuple if (len(b) == 1 and b[0] is tuple and isinstance(f, tuple) and isinstance(f_t, dict) 335 | and all(map(lambda n: type(n) is str, chain(f, f_t.keys())))) else None 336 | asdict = lambda args: args._asdict() 337 | field_default = lambda arg_class, raw_arg_name: arg_class._field_defaults.get(raw_arg_name, MISSING) 338 | patch = _black_hole # No need to patch 339 | 340 | 341 | class _dataclasses: 342 | """dataclass proxy""" 343 | @staticmethod 344 | def patch(cls): 345 | """Patch the argument dataclass so that `post_validation` is called, and it's immutable if not `frozen` after initialization""" 346 | def raise_if_frozen(self, fun, name, *args, **kwargs): 347 | _raise_if(f"cannot assign to/delete field {name!r}", getattr(self, '__frozen__', False)) 348 | getattr(object, fun)(self, name, *args, **kwargs) 349 | 350 | def init(self, *args, **kwargs): 351 | if args and hasattr(self, '__frozen__'): 352 | logger.debug(f"Assuming {self} is from the patched __new__ with __from_argv__, already initialized, skipping init.") 353 | return 354 | self.__original_init__(*args, **kwargs) 355 | object.__setattr__(self, '__frozen__', True) 356 | self.__class__.__arg_suite__.post_validation(self) 357 | cls.__init__, cls.__original_init__ = init, cls.__init__ 358 | cls.__setattr__ = lambda self, name, value: raise_if_frozen(self, '__setattr__', name, 
value) 359 | cls.__delattr__ = lambda self, name, : raise_if_frozen(self, '__delattr__', name) 360 | 361 | @staticmethod 362 | def field_default(arg_class, raw_arg_name): 363 | f = arg_class.__dataclass_fields__[raw_arg_name] 364 | return f.default_factory() if f.default_factory is not MISSING else f.default # Assuming the default_factory has no side effects. 365 | proxy = lambda t: _dataclasses if is_dataclass(t) else None 366 | asdict = lambda args: asdict(args) 367 | new_instance = lambda arg_class, _: arg_class.__original_new__(arg_class) 368 | 369 | 370 | _type_proxies = [_namedtuple, _dataclasses] # Supported container types. Users can extend it if they know what they are doing. Not officially supported yet. 371 | 372 | 373 | def _get_type_proxy(arg_class): 374 | return next(filter(lambda p: p.proxy(arg_class), _type_proxies), False) 375 | 376 | 377 | def _unwrap_optional(arg_type): 378 | type_origin, type_args, optional = get_origin(arg_type), get_args(arg_type), False 379 | if type_origin is Union and len(type_args) == 2 and type_args[1] is NoneType: # `Optional` support 380 | arg_type, optional = type_args[0], True # Unwrap `Optional` and validate 381 | return arg_type, optional 382 | 383 | 384 | class ArgSuite(Generic[ArgType]): 385 | """Generates the corresponding `ArgumentParser` and handles the two-way conversions.""" 386 | @staticmethod 387 | def new_arg(arg_class, *args, **kwargs): 388 | """Monkey-Patched argument container class __new__. 389 | If any positional arguments exist, it would assume that the user is trying to parse a sequence of strings. 390 | It would also assume there is only one positional argument, and raise an `SmartArgError` otherwise. 391 | If no positional arguments exist, it would call the argument container class instance creator with all keyword arguments. 392 | 393 | :param arg_class: Decorated class 394 | :param args: Optional positional argument, to be parsed to the arg_class type. 
395 | `args[0]`: an optional marker to mark the sub-sequence of `argv` to be parsed by the parser. ``None`` will 396 | be interpreted as ``sys.argv[1:]`` 397 | `args[1]`: default to `None`, indicating using the default separator for the argument container class 398 | 399 | :type `(Optional[Sequence[str]], Optional[str])` 400 | :param kwargs: Optional keyword arguments, to be passed to the argument container class specific instance creator.""" 401 | logger.info(f"Patched __new__ for {arg_class} is called with {args} and {kwargs}.") 402 | if args: 403 | warn(f"Calling the patched constructor of {arg_class} with argv is deprecated, please use {arg_class}.__from_argv__ instead.") 404 | _raise_if(f"Calling '{arg_class}(positional {args}, keyword {kwargs})' is not allowed:\n" 405 | f"Only accept positional arguments to parse to the '{arg_class}'\nkeyword arguments can only be used to create an instance directly.", 406 | kwargs or len(args) > 2 or len(args) == 2 and args[1].__class__ not in (NoneType, str) 407 | or not (args[0] is None or isinstance(args[0], Sequence) and all(map(lambda a: a.__class__ is str, args[0])))) 408 | 409 | return arg_class.__from_argv__(args[0]) 410 | else: 411 | return _get_type_proxy(arg_class).new_instance(arg_class, kwargs) 412 | 413 | def __init__(self, type_handlers: Sequence[TypeHandler], arg_class): 414 | type_proxy = _get_type_proxy(arg_class) 415 | _raise_if(f"Unsupported argument container class {arg_class}.", not type_proxy) 416 | self.handlers = type_handlers 417 | self.handler_actions: Dict[str, Tuple[Union[TypeHandler, Type], Action]] = {} 418 | type_proxy.patch(arg_class) 419 | # A big assumption here is that the argument container classes never override __new__ 420 | if not hasattr(arg_class, '__original_new__'): 421 | arg_class.__original_new__ = arg_class.__new__ 422 | arg_class.__new__ = ArgSuite.new_arg 423 | arg_class.__to_argv__ = lambda arg_self, separator='': self.to_argv(arg_self, separator) # arg_class instance level 
method 424 | arg_class.__from_argv__ = self.parse_to_arg 425 | arg_class.__arg_suite__ = self 426 | self._arg_class = arg_class 427 | self._parser = ArgumentParser(description=self._arg_class.__doc__, argument_default=MISSING, # type: ignore 428 | fromfile_prefix_chars='@', allow_abbrev=False) 429 | self._parser.convert_arg_line_to_args = lambda arg_line: arg_line.split() # type: ignore 430 | self._gen_arguments_from_class(self._arg_class, '', True, [], type_proxy) 431 | 432 | @staticmethod 433 | def _validate_fields(arg_class: Type) -> None: 434 | """Validate fields in `arg_class`. 435 | 436 | :raise: SmartArgError if the decorated argument container class has non-typed field with defaults and such field 437 | does not startswith "_" to overwrite the existing argument field property.""" 438 | arg_fields = _annotations(arg_class).keys() 439 | invalid_fields = tuple(filter(lambda s: s.endswith('_'), arg_fields)) 440 | _raise_if(f"'{arg_class}': found invalid (ending with '_') fields : {invalid_fields}.", invalid_fields) 441 | private_prefix = f'_{arg_class.__name__}__' 442 | l_prefix = len(private_prefix) 443 | # skip callable methods and typed fields 444 | for f in filter(lambda f: not callable(getattr(arg_class, f)) and f not in arg_fields, vars(arg_class).keys()): 445 | is_private = f.startswith(private_prefix) 446 | _raise_if(f"'{arg_class}': there is no field '{f[l_prefix:]}' for '{f}' to override.", is_private and f[l_prefix:] not in arg_fields) 447 | _raise_if(f"'{arg_class}': found invalid (untyped) field '{f}'.", not (is_private or f.startswith('_'))) 448 | 449 | def _gen_arguments_from_class(self, arg_class, prefix: str, parent_required, arg_classes: List, type_proxy) -> None: 450 | """Add argument to the self._parser for each field in the self._arg_class 451 | :raise: SmartArgError if a corresponding handler for the argument type is not found.""" 452 | _raise_if(f"Recursively nested argument container class '{arg_class}' is not supported.", arg_class in 
arg_classes) 453 | suite = getattr(arg_class, '__arg_suite__', None) 454 | _raise_if(f"Nested argument container class '{arg_class}' with '__post_init__' expected to be decorated.", 455 | not (suite and suite._arg_class is arg_class) and hasattr(arg_class, '__post_init__')) 456 | self._validate_fields(arg_class) 457 | comments = _mro_all(arg_class, self.get_comments_from_source) 458 | for raw_arg_name, arg_type in _annotations(arg_class).items(): 459 | arg_name = f'{prefix}{raw_arg_name}' 460 | try: 461 | default = type_proxy.field_default(arg_class, raw_arg_name) 462 | arg_type, optional = _unwrap_optional(arg_type) 463 | _raise_if(f"Optional field: {arg_name!r}={default!r} must default to `None` or `LateInit`", optional and default not in (None, LateInit)) 464 | sub_type_proxy = _get_type_proxy(arg_type) 465 | if sub_type_proxy: 466 | required = parent_required and default is MISSING 467 | arg_classes.append(arg_class) 468 | self._gen_arguments_from_class(arg_type, f'{arg_name}.', required, arg_classes, sub_type_proxy) 469 | arg_classes.pop() 470 | kwargs = KwargsType(nargs='?', 471 | required=required, 472 | metavar='Does NOT accept any value', 473 | # capture the value of arg_name in the lambda default argument since arg_name is not immutable 474 | type=lambda _, arg=arg_name: _raise_if(f"Nested argument container marker {arg!r} does not accept any value."), 475 | help=f"nested argument container marker{'' if default in (LateInit, MISSING) else f'; default: {default}'}") 476 | else: 477 | handler = _first_handles(self.handlers, arg_type) 478 | user_override: dict = getattr(arg_class, f'_{arg_class.__name__}__{raw_arg_name}', {}) 479 | comment = user_override.pop('_comment_to_help', comments.get(raw_arg_name, '')) 480 | field_meta = FieldMeta(comment=comment, default=default, type=arg_type, optional=optional) 481 | kwargs = handler.gen_kwargs(field_meta, parent_required) 482 | if user_override.get('choices', None): 483 | logger.info(f"Instead of defining 
`choices`, please consider using Enum for {arg_name}") 484 | kwargs.__dict__.update(user_override) # apply user override to the keyword argument object 485 | kwargs.__dict__.pop('_serialization', None) 486 | logger.debug(f"Adding kwargs {kwargs} for --{arg_name}") 487 | self.handler_actions[arg_name] = (sub_type_proxy or handler), self._parser.add_argument(f'--{arg_name}', **vars(kwargs)) 488 | except BaseException as b_e: 489 | logger.critical(f"Failed creating argument parser for {arg_name!r}:{arg_type!r} with exception {b_e}.") 490 | raise 491 | 492 | @staticmethod 493 | def strip_argv(separator: str, argv: Optional[Sequence[str]]) -> Sequence[str]: 494 | """Strip any elements outside `{separator}+` and `{separator}-` of `argv`. 495 | :param separator: A string marker prefix to mark the boundaries of the belonging arguments in argv 496 | :param argv: Input argument list, treated as `sys.argv[1:]` if `None` 497 | :return: Stripped `argv`""" 498 | if argv is None: 499 | argv = sys.argv[1:] 500 | l_s, r_s = separator + '+', separator + '-' 501 | lc, rc = argv.count(l_s), argv.count(r_s) 502 | if lc == rc: 503 | if lc == 0: 504 | return argv 505 | elif lc == 1: 506 | b, e = argv.index(l_s), argv.index(r_s) 507 | if e > b: 508 | return argv[b + 1: e] 509 | raise SmartArgError(f"Expecting up to 1 pair of separator markers'{l_s}' and '{r_s}' in {argv}") 510 | 511 | def parse_to_arg(self, argv: Optional[Sequence[str]] = None, separator: Optional[str] = '', *, error_on_unknown: bool = True) -> ArgType: 512 | """Parse the command line to decorated ArgType 513 | 514 | :param separator: Optional marker to mark the sub-sequence of `argv` ['{separator}+' to '{separator}-'] to parse 515 | :param error_on_unknown: When `True`, raise if there is any unknown argument in the marked sub-sequence. 
516 | :param argv: the command line list 517 | :return: parsed decorated object from command line""" 518 | def to_arg(arg_class: Type[ArgType], prefix) -> ArgType: 519 | nest_arg = {} 520 | for raw_arg_name, arg_type in _annotations(arg_class).items(): 521 | arg_name = f'{prefix}{raw_arg_name}' 522 | handler_or_proxy, _ = self.handler_actions[arg_name] 523 | if isinstance(handler_or_proxy, TypeHandler): 524 | value = arg_dict.get(arg_name, MISSING) 525 | type_origin = get_origin(arg_type) 526 | type_args = get_args(arg_type) 527 | type_to_new = get_origin(type_args[0]) if type_origin is Union and len(type_args) == 2 and type_args[1] is NoneType else type_origin 528 | # argparse reading variable length arguments are all lists, need to apply the origin type for the conversion to correct type. 529 | value = type_to_new(value) if value is not None and isinstance(value, List) else value # type: ignore 530 | else: # deal with nested container, consider defined only if there's subfield 531 | is_nested_items_defined = any(filter(lambda name: name.startswith(arg_name), arg_dict.keys())) # type: ignore 532 | value = to_arg(_unwrap_optional(arg_type)[0], f'{arg_name}.') if is_nested_items_defined else MISSING 533 | if value is not MISSING: 534 | nest_arg[raw_arg_name] = value 535 | return arg_class(**nest_arg) 536 | argv = self.strip_argv(separator or self._arg_class.__name__, argv) 537 | ns, unknown = self._parser.parse_known_args(argv) 538 | error_on_unknown and unknown and self._parser.error(f"unrecognized arguments: {' '.join(unknown)}") 539 | logger.info(f"{argv} is parsed to {ns}") 540 | arg_dict: Dict[str, Any] = {name: value for name, value in vars(ns).items() if value is not MISSING} # Filter out defaults added by the parser. 
541 | return to_arg(self._arg_class, '') 542 | 543 | def post_validation(self, arg: ArgType, prefix: str = '') -> None: 544 | """This is called after __post_init__ to validate the fields.""" 545 | arg_class = arg.__class__ 546 | for name, t in _annotations(arg_class).items(): 547 | attr = getattr(arg, name) 548 | _raise_if(f"Field '{name}' is still not initialized after post processing for {arg_class}", attr is LateInit) 549 | attr_class = attr.__class__ 550 | type_origin, type_args = get_origin(t), get_args(t) 551 | if _get_type_proxy(t) or type_origin is Union and len(type_args) == 2 and type_args[1] is NoneType and _get_type_proxy(type_args[0]): 552 | _raise_if(f"Field {name}' value of {attr}:{attr_class} is not of the expected type '{t}' for {arg_class}", attr_class not in (t, *type_args)) 553 | self.post_validation(attr, f'{prefix}{name}.') 554 | else: 555 | arg_type = get_origin(t) or t 556 | if arg_type is get_origin(Optional[Any]): 557 | arg_type = get_args(t) 558 | if arg_type is list and isinstance(attr, tuple) or arg_type is set and isinstance(attr, frozenset): # allow for immutable replacements 559 | continue 560 | try: 561 | conforming = isinstance(attr, arg_type) # e.g. 
list and List 562 | except TypeError: # best effort to check the instance type 563 | logger.warning(f"Unable to check if {attr!r} is of type {t} for field {name!r} of argument container class {arg_class}") 564 | conforming = True 565 | _raise_if(f"Field {name!r} has value of {attr!r} of type {attr_class} which is not of the expected type '{t}' for {arg_class}", not conforming) 566 | 567 | def _gen_cmd_argv(self, args: ArgType, prefix) -> Iterable[str]: 568 | arg_class = args.__class__ 569 | proxy = _get_type_proxy(arg_class) 570 | for name, arg in proxy.asdict(args).items(): 571 | default = proxy.field_default(arg_class, name) 572 | if arg != default: 573 | handler_or_proxy, action = self.handler_actions[f'{prefix}{name}'] 574 | serialization_user_override = getattr(arg_class, f'_{arg_class.__name__}__{name}', {}).get('_serialization', None) 575 | yield action.option_strings[0] 576 | yield from serialization_user_override(arg) if serialization_user_override else \ 577 | handler_or_proxy.gen_cli_arg(arg) if isinstance(handler_or_proxy, TypeHandler) else \ 578 | self._gen_cmd_argv(arg, f'{prefix}{name}.') 579 | 580 | @staticmethod 581 | def get_comments_from_source(arg_cls: Type[ArgType]) -> Dict[str, str]: 582 | """Get in-line comments for the input class fields. Only single line of trailing comment is supported. 
583 | 584 | :param arg_cls: the input class 585 | :return: a dictionary with key of class field name and value of in-line comment""" 586 | comments = {} 587 | indent = 0 588 | field = None 589 | import inspect, tokenize # noqa: E401 590 | try: 591 | for token in tokenize.generate_tokens(iter(inspect.getsourcelines(arg_cls)[0]).__next__): 592 | if token.type == tokenize.NEWLINE: 593 | field = None 594 | elif token.type == tokenize.INDENT: 595 | indent += 1 596 | elif token.type == tokenize.DEDENT: 597 | indent -= 1 598 | elif token.type == tokenize.NAME and indent == 1 and not field: 599 | field = token 600 | elif token.type == tokenize.COMMENT and field: 601 | # TODO nicer way to deal with with long comments or support multiple lines 602 | comments[field.string] = (token.string + ' ')[1:token.string.lower().find('# noqa:')] # TODO consider move processing out 603 | except Exception as e: 604 | logger.error(f'Failed to parse comments from source of class {arg_cls}, continue without them.', exc_info=e) 605 | return comments 606 | 607 | def to_argv(self, arg: ArgType, separator: Optional[str] = '') -> Sequence[str]: 608 | """Generate the command line arguments 609 | 610 | :param separator: separator marker, empty str to disable, None to default to class name 611 | :param arg: the annotated argument container class object 612 | :return: command line sequence""" 613 | argv = self._gen_cmd_argv(arg, '') 614 | if separator is None: 615 | separator = self._arg_class.__name__ 616 | return (separator + '+', *argv, separator + '-') if separator else tuple(argv) 617 | 618 | 619 | class ArgSuiteDecorator: 620 | """Generate a decorator to easily convert back and forth from command-line to `NamedTuple` or `dataclass`. 
621 | 622 | The decorator monkey patches the constructor, so that the IDE would infer the type of 623 | the deserialized `arg` for code auto-completion and type check:: 624 | 625 | arg: ArgClass = ArgClass.__from_argv__(my_argv) # Factory method, need the manual type hint `:ArgClass` 626 | arg = ArgClass(my_argv) # with monkey-patched constructor, no manual hint needed 627 | 628 | For the ArgClass fields without types but with default values: only private fields starts with "__" to overwrite the existed 629 | argument parameters are allowed; others will throw SmartArgError. 630 | 631 | For the handlers/addons, first one in the sequence that claims to handle a type takes the precedence. 632 | 633 | Usage:: 634 | 635 | @arg_suite # `arg_suite` is a shorthand for `custom_arg_suite()` 636 | class MyArg(NamedTuple): 637 | field_one: str 638 | _field_one = {"choices": ["one", "two"]} # advanced usage: overwrite the `field_one` parameter 639 | field_two: List[int] 640 | ... 641 | 642 | :param primitive_handler_addons: the primitive types handling in addition to the provided primitive type basic operations 643 | :param type_handlers: the types handling in addition to the provided types handling. 644 | :return: the argument container class decorator""" 645 | def __init__(self, type_handlers: Sequence[Type[TypeHandler]] = (), primitive_handler_addons: Sequence[Type[PrimitiveHandlerAddon]] = ()) -> None: 646 | addons = (*primitive_handler_addons, PrimitiveHandlerAddon) 647 | self.handlers = tuple(map(lambda handler: handler(addons), chain(type_handlers, (PrimitiveHandler, CollectionHandler, DictHandler, TupleHandler)))) 648 | 649 | def __call__(self, cls): 650 | ArgSuite(self.handlers, cls) 651 | return cls 652 | 653 | 654 | custom_arg_suite = ArgSuiteDecorator # Snake case decorator alias for the class name in camel case 655 | arg_suite = custom_arg_suite() # Default argument container class decorator to expose smart arg functionalities. 
656 | -------------------------------------------------------------------------------- /test/smart_arg_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from enum import Enum, auto 3 | from typing import List, NamedTuple, Tuple, Optional, Dict, Sequence 4 | 5 | from smart_arg import arg_suite, LateInit, frozenlist 6 | 7 | 8 | class Encoder(Enum): 9 | FASTTEXT = auto() 10 | WORD2VEC = auto() 11 | 12 | 13 | class NestedArg(NamedTuple): 14 | class NNested(NamedTuple): 15 | n_name: str 16 | n_nested: NNested = NNested(n_name='nested name') 17 | 18 | 19 | @arg_suite 20 | class MyModelConfig(NamedTuple): 21 | """MyModelConfig docstring goes to description""" 22 | adp: bool 23 | nn: List[int] # Comments go to argparse help 24 | a_tuple: Tuple[str, int] 25 | h_param: Dict[str, int] = {} # Hyperparameters 26 | encoder: Encoder = Encoder.FASTTEXT # Word encoder type 27 | nested: Optional[NestedArg] = None # nested args 28 | n: Optional[int] = LateInit # An argument that can be auto-set in post processing 29 | immutable_list: Sequence[int] = frozenlist((1, 2)) # A frozenlist is an alias to tuple 30 | embedding_dim: int = 100 # Size of embedding vector 31 | lr: float = 1e-3 # Learning rate 32 | 33 | def __post_init__(self): 34 | assert self.nn, "Expect nn to be non-empty." 
# validation 35 | if self.n is LateInit: # post processing 36 | self.n = self.nn[0] 37 | 38 | 39 | if __name__ == '__main__': 40 | my_config = MyModelConfig(nn=[200, 300], a_tuple=("s", 5), adp=True, h_param={'n': 0, 'y': 1}, nested=NestedArg()) 41 | 42 | my_config_argv = my_config.__to_argv__() 43 | print(f"Serialized reference to the expected argument container object:\n{my_config_argv!r}") 44 | 45 | print(f"Deserializing from the command line string list: {sys.argv[1:]!r}.") 46 | deserialized = MyModelConfig.__from_argv__() 47 | 48 | re_deserialized = MyModelConfig.__from_argv__(my_config_argv) 49 | if deserialized == my_config == re_deserialized: 50 | print(f"All matched, argument container object: '{deserialized!r}'") 51 | else: 52 | err_code = 168 53 | print(f"Error({err_code}):\n" 54 | f"Expected argument container object:\n {my_config!r}\n" 55 | f"Re-deserialized:\n {re_deserialized!r}\n" 56 | f"Deserialized:\n {deserialized!r}") 57 | exit(err_code) 58 | -------------------------------------------------------------------------------- /test/test_smart_arg.py: -------------------------------------------------------------------------------- 1 | """Tests for smart-arg-suite.""" 2 | import os 3 | import subprocess 4 | import sys 5 | from contextlib import redirect_stderr 6 | from dataclasses import dataclass, replace, field 7 | from enum import Enum 8 | from math import sqrt 9 | from types import SimpleNamespace 10 | from typing import List, NamedTuple, Tuple, Optional, Dict, Set, Type, Any, Sequence, FrozenSet 11 | 12 | import pytest 13 | 14 | from smart_arg import arg_suite, custom_arg_suite, LateInit, SmartArgError, TypeHandler, _first_handles, PrimitiveHandlerAddon, frozenlist 15 | 16 | 17 | @arg_suite 18 | @dataclass(frozen=True) 19 | class MyTupBasic: 20 | """ 21 | MyTup docstring goes to description 22 | """ 23 | def _serialization(a_str): 24 | if a_str not in MyTupBasic.__a_str['choices']: 25 | raise ValueError 26 | yield a_str 27 | 28 | a_int: int # a 
is int 29 | a_float: float # a is float 30 | a_bool: bool 31 | a_str: str # comment will NOT go to a_str's help 32 | _comment_in_help = """ 33 | multiline override 34 | trailing comment in help""" 35 | __a_str = {'choices': ['hello', 'bonjour', 'hola'], '_serialization': _serialization, '_comment_to_help': _comment_in_help} # will overwrite the a_str argument choices, serialization and comment 36 | b_list_int: List[int] 37 | b_set_str: Set[str] 38 | d_tuple_3: Tuple[int, float, bool] 39 | d_tuple_2: Tuple[str, int] 40 | e_dict_str_int: Dict[str, int] 41 | e_dict_int_bool: Dict[int, bool] 42 | c_optional_float: Optional[float] = None 43 | immutable_list: List[int] = frozenlist((1, 2)) 44 | immutable_seq: Sequence[int] = frozenlist((1, 2)) 45 | immutable_set: Set[int] = frozenset((1, 2)) 46 | immutable_f_set: FrozenSet[int] = frozenset((1, 2)) 47 | with_default_factory: List[float] = field(default_factory=list) 48 | with_default: int = field(default=10) 49 | 50 | 51 | my_tup_basic = MyTupBasic( 52 | a_int=32, 53 | a_float=0.3, 54 | a_bool=True, 55 | a_str='hello', 56 | b_list_int=[1, 2, 3], 57 | b_set_str={'set1', 'set2'}, 58 | c_optional_float=None, 59 | d_tuple_2=('tuple', 12), 60 | d_tuple_3=(10, 0.5, False), 61 | e_dict_str_int={'size': 32, 'area': 90}, 62 | e_dict_int_bool={10: True, 20: False, 30: True}) 63 | 64 | 65 | def test_basic_parse_to_arg(): 66 | with pytest.raises(TypeError, match="missing 10"): 67 | MyTupBasic() 68 | arg_cmd = 'MyTupBasic+ --a_int 32 --a_float 0.3 --a_bool True --a_str hello --b_list_int 1 2 3 --b_set_str set1 set2 ' + \ 69 | '--d_tuple_3 10 0.5 False --d_tuple_2 tuple 12 --e_dict_str_int size:32 area:90 --e_dict_int_bool 10:True 20:False 30:True MyTupBasic-' 70 | parsed_arg_from_factory: MyTupBasic = MyTupBasic.__from_argv__(arg_cmd.split()) 71 | assert my_tup_basic == parsed_arg_from_factory 72 | serialized_cmd_line = my_tup_basic.__to_argv__(separator=None) 73 | assert set(serialized_cmd_line) == set(arg_cmd.split()) 74 | my_parser 
= MyTupBasic.__arg_suite__._parser 75 | assert my_parser._option_string_actions['--c_optional_float'].help == '(Optional[float]; default: None) ' 76 | assert my_parser._option_string_actions['--a_int'].help == '(int; required) a is int' 77 | assert my_parser._option_string_actions['--a_float'].help == '(float; required) a is float' 78 | assert my_parser._option_string_actions['--a_str'].choices == ['hello', 'bonjour', 'hola'] 79 | assert my_parser._option_string_actions['--a_str'].help == f"(str; required) {MyTupBasic._comment_in_help}" 80 | 81 | parsed_arg = MyTupBasic(arg_cmd.split()) # Patched constructor 82 | assert parsed_arg == my_tup_basic 83 | 84 | pytest.raises(ValueError, replace(parsed_arg, a_str='').__to_argv__, []) # serialization is overridden 85 | 86 | 87 | muted = redirect_stderr(SimpleNamespace(write=lambda *_: None)) 88 | 89 | 90 | def test_parse_error(): 91 | with muted: 92 | pytest.raises(SystemExit, MyTupBasic.__from_argv__, []) 93 | 94 | 95 | def test_optional(): 96 | @arg_suite 97 | class MyTup(NamedTuple): 98 | ints: Optional[List[int]] = None 99 | 100 | with muted: 101 | pytest.raises(SystemExit, MyTup.__from_argv__, ['--ints', 'None']) 102 | assert MyTup.__arg_suite__._parser._option_string_actions['--ints'].help == '(Optional[List[int]]; default: None) ' 103 | assert MyTup.__from_argv__([]).ints is None 104 | assert MyTup.__from_argv__(['--ints', '1', '2']).ints == [1, 2] 105 | assert MyTup.__from_argv__(['--ints']).ints == [] 106 | 107 | class InvalidOptional(NamedTuple): 108 | no: Optional[int] 109 | 110 | pytest.raises(SmartArgError, arg_suite, InvalidOptional) 111 | 112 | 113 | def test_post_process(): 114 | @arg_suite 115 | class MyTup(NamedTuple): 116 | a_int: Optional[int] = LateInit # if a_int is not in the argument, post_process will initialize it 117 | 118 | pytest.raises(SmartArgError, MyTup) 119 | pytest.raises(SmartArgError, MyTup, a_int='not a int') 120 | pytest.raises(SmartArgError, MyTup.__from_argv__, []) 121 | 122 | 
@arg_suite 123 | class MyTup(NamedTuple): 124 | a_int: Optional[int] = LateInit # if a_int is not in the argument, post_process will initialize it 125 | 126 | def __post_init__(self): 127 | self.a_int = 10 if self.a_int is LateInit else self.a_int 128 | 129 | assert MyTup.__from_argv__([]).a_int == 10 130 | assert MyTup().a_int == 10 131 | assert MyTup(a_int=0).a_int == 0 132 | 133 | 134 | def test_validate(): 135 | @arg_suite 136 | class MyTup(NamedTuple): 137 | a_int: int = LateInit # if a_int is not in the argument, __post_init__ needs to initialize it 138 | 139 | validated = False 140 | 141 | def validate(s): 142 | nonlocal validated 143 | validated = True 144 | raise AttributeError() 145 | 146 | MyTup.__post_init__ = validate 147 | pytest.raises(AttributeError, MyTup.__from_argv__, ['--a_int', '1']) 148 | assert validated, "`validate` might not be executed." 149 | 150 | @arg_suite 151 | class MyTuple(NamedTuple): 152 | abc: str 153 | 154 | def __post_init__(self): 155 | if self.abc != 'abc': 156 | raise AttributeError() 157 | return self 158 | 159 | @staticmethod 160 | def format(): 161 | return "hello" 162 | 163 | @classmethod 164 | def format2(cls): 165 | return "hello2" 166 | 167 | format3 = lambda: "hello3" 168 | 169 | # no exception for callable methods 170 | tup = MyTuple.__from_argv__(['--abc', 'abc']) 171 | assert tup.abc == 'abc' 172 | assert tup.format() == 'hello' 173 | assert MyTuple.format2() == 'hello2' 174 | assert MyTuple.format3() == 'hello3' 175 | 176 | 177 | def test_validate_fields(): 178 | class DanglingParamOverride(NamedTuple): 179 | __a_str = "abc" 180 | 181 | class MyNonType(NamedTuple): 182 | a_str = "abc" 183 | 184 | class MyNonTypeTuple(NamedTuple): 185 | a_tuple: Tuple 186 | 187 | class TrailingUnderscore(NamedTuple): 188 | a_: int 189 | 190 | pytest.raises(SmartArgError, arg_suite, DanglingParamOverride) 191 | pytest.raises(SmartArgError, arg_suite, MyNonType) 192 | pytest.raises(SmartArgError, arg_suite, MyNonTypeTuple) 193 | 
pytest.raises(SmartArgError, arg_suite, TrailingUnderscore) 194 | 195 | 196 | def test_primitive_addon(): 197 | class IntHandlerAddon(PrimitiveHandlerAddon): 198 | @staticmethod 199 | def build_type(arg_type) -> Any: 200 | return lambda s: int(s) ** 2 201 | 202 | @staticmethod 203 | def build_str(arg) -> str: 204 | return str(int(sqrt(arg))) 205 | 206 | @staticmethod 207 | def handles(t: Type) -> bool: 208 | return t == int 209 | 210 | class IntTypeHandler(TypeHandler): 211 | def _build_common(self, kwargs, field_meta, parent_required) -> None: 212 | super()._build_common(kwargs, field_meta, parent_required) 213 | kwargs.help = '(int, squared)' 214 | 215 | def _build_other(self, kwargs, arg_type) -> None: 216 | kwargs.type = _first_handles(self.primitive_addons, arg_type).build_type(arg_type) 217 | 218 | def handles(self, t: Type) -> bool: 219 | return t == int 220 | 221 | @custom_arg_suite(primitive_handler_addons=[IntHandlerAddon], type_handlers=[IntTypeHandler]) 222 | class MyTuple(NamedTuple): 223 | a_int: int 224 | 225 | argv = ('--a_int', '3') 226 | tup = MyTuple.__from_argv__(argv) 227 | assert tup.a_int == 9 228 | assert tup.__to_argv__() == argv 229 | my_parser = MyTuple.__arg_suite__._parser 230 | assert my_parser._option_string_actions['--a_int'].help == '(int, squared)' 231 | 232 | 233 | def test_unsupported_types(): 234 | class MyTuple(NamedTuple): 235 | a: List[List[int]] 236 | 237 | pytest.raises(SmartArgError, arg_suite, MyTuple) 238 | 239 | 240 | def test_nested(): 241 | @arg_suite 242 | class MyTup(NamedTuple): 243 | def __post_init__(self): 244 | self.b_int = 10 if self.b_int is LateInit else self.b_int 245 | 246 | b_int: int = LateInit # if a_int is not in the argument, __post_init__ will initialize it 247 | 248 | @arg_suite 249 | class Nested(NamedTuple): 250 | a_int: int 251 | nested: Optional[MyTupBasic] = None # Optional nested 252 | another_int: int = 0 253 | another_nested: MyTup = MyTup() 254 | 255 | nested = Nested(a_int=0) 256 | assert 
nested.another_nested.b_int == 10 257 | argv = nested.__to_argv__() 258 | assert Nested(argv) == nested 259 | assert Nested(argv[0:2]) == nested 260 | pytest.raises(SmartArgError, Nested, ['--a_int', '0', '--nested', 'Not Allowed']) 261 | pytest.raises(SmartArgError, Nested, a_int=0, nested='Not nested MyTupBasic') 262 | 263 | class NotDecoratedWithPost(NamedTuple): 264 | def __post_init__(self): pass 265 | 266 | class Nested(NamedTuple): 267 | nested: NotDecoratedWithPost 268 | 269 | # Nested class is not allowed to have __post_init__ if not decorated 270 | pytest.raises(SmartArgError, arg_suite, Nested) 271 | del NotDecoratedWithPost.__post_init__ 272 | arg_suite(Nested) # should not raise 273 | 274 | 275 | def test_cli_execution(): 276 | cmd_line = f'{sys.executable if sys.executable else "python"} {os.path.join(os.path.dirname(__file__), "smart_arg_demo.py")}' 277 | args = ' --nn 200 300 --a_tuple s 5 --encoder FASTTEXT --h_param y:1 n:0 --nested --embedding_dim 100 --lr 0.001 --adp True' 278 | cmd_line += args 279 | kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE, 'shell': True} 280 | completed_process = subprocess.run(cmd_line, **kwargs) 281 | assert completed_process.stderr.decode('utf-8') == '' 282 | assert completed_process.returncode == 0 283 | 284 | completed_process = subprocess.run(f'{cmd_line} --adp False', **kwargs) 285 | assert completed_process.returncode == 168, "deserialization succeeded, but not as expected" 286 | 287 | completed_process = subprocess.run(f'{cmd_line} --nested "OH NO!"', **kwargs) 288 | assert completed_process.returncode == 1, "nested field should not have any value defined directly" 289 | 290 | separator = 'MyModelConfig' 291 | completed_process = subprocess.run(f'{cmd_line} {separator}+ {args} {separator}- --nested "OH NO!"', **kwargs) 292 | assert completed_process.stderr.decode('utf-8') == '' 293 | assert completed_process.returncode == 0, "Argument outside of separators should be ignored." 
294 | 295 | 296 | def test_dataclass(): 297 | @dataclass 298 | class GdmixParams: 299 | ACTIONS = ("action_inference", "action_train") 300 | action: str = ACTIONS[1] # Train or inference. 301 | __action = {"choices": ACTIONS} 302 | STAGES = ("fixed_effect", "random_effect") 303 | stage: str = STAGES[0] # Fixed or random effect. 304 | __stage = {"choices": STAGES} 305 | MODEL_TYPES = ("logistic_regression", "detext") 306 | model_type: str = MODEL_TYPES[0] # The model type to train, e.g, logistic regression, detext, etc. 307 | __model_type = {"choices": MODEL_TYPES} 308 | 309 | # Input / output files or directories 310 | training_output_dir: Optional[str] = None # Training output directory. 311 | validation_output_dir: Optional[str] = None # Validation output directory. 312 | 313 | # Driver arguments for random effect training 314 | partition_list_file: Optional[str] = None # File containing a list of all the partition ids, for random effect only 315 | 316 | def __post_init__(self): 317 | assert self.action in self.ACTIONS, "Action must be either train or inference" 318 | assert self.stage in self.STAGES, "Stage must be either fixed_effect or random_effect" 319 | assert self.model_type in self.MODEL_TYPES, "Model type must be either logistic_regression or detext" 320 | 321 | @dataclass 322 | class SchemaParams: 323 | # Schema names 324 | sample_id: str # Sample id column name in the input file. 325 | sample_weight: Optional[str] = None # Sample weight column name in the input file. 326 | label: Optional[str] = None # Label column name in the train/validation file. 327 | prediction_score: Optional[str] = None # Prediction score column name in the generated result file. 328 | prediction_score_per_coordinate: str = "predictionScorePerCoordinate" # ColumnName of the prediction score without the offset. 
329 | 330 | @arg_suite 331 | @dataclass 332 | class Params(GdmixParams, SchemaParams): 333 | """GDMix Driver""" 334 | 335 | def __post_init__(self): 336 | super().__post_init__() 337 | assert (self.action == self.ACTIONS[1] and self.label) or (self.action == self.ACTIONS[0] and self.prediction_score) 338 | self.prediction_score = self.prediction_score 339 | 340 | argv = ['--sample_id', 'uid', '--sample_weight', 'weight', '--feature_bags', 'global', '--train_data_path', 341 | 'resources/train', '--validation_data_path', 342 | 'resources/validate', '--model_output_dir', 'dummy_model_output_dir', '--metadata_file', 343 | 'resources/fe_lbfgs/metadata/tensor_metadata.json', '--feature_file', 344 | 'test/resources/fe_lbfgs/featureList/global'] 345 | with muted: pytest.raises(SystemExit, Params.__from_argv__, argv) 346 | pytest.raises(AssertionError, Params.__from_argv__, argv, error_on_unknown=False) 347 | args: Params = Params.__from_argv__(argv + ['--label', 'bluh'], error_on_unknown=False) 348 | pytest.raises(AssertionError, Params, sample_id='uid', action='no_such_action') 349 | pytest.raises(SmartArgError, args.__post_init__) # mutation not allowed after init 350 | object.__delattr__(args, '__frozen__') 351 | args.__post_init__() # mutation allowed after '__frozen__` mark removed 352 | assert args == Params(sample_id='uid', sample_weight='weight', label='bluh') 353 | assert args.__to_argv__() == ('--sample_id', 'uid', '--sample_weight', 'weight', '--label', 'bluh') 354 | 355 | @arg_suite 356 | @dataclass 357 | class NoPostInit: 358 | def mutate(self): 359 | self.frozen = False 360 | 361 | frozen: bool = True 362 | 363 | pytest.raises(SmartArgError, NoPostInit().mutate) # mutation not allowed after init 364 | 365 | 366 | def test_basic_enum(): 367 | class Color(Enum): 368 | RED = 1 369 | BLUE = 2 370 | GREEN = 3 371 | 372 | @arg_suite 373 | class MyEnumBasic(NamedTuple): 374 | a_int: int 375 | my_color_dict: Dict[int, Color] 376 | my_color_list: List[Color] 377 | 
my_color_tuple: Tuple[Color, int] 378 | default_color: Color = Color.RED 379 | 380 | arg_cmd = ['--a_int', '1', '--my_color_dict', '10:RED', '20:BLUE', '--my_color_list', 'GREEN', 381 | '--my_color_tuple', 'BLUE', '100', '--default_color', 'GREEN'] 382 | basic_tup = MyEnumBasic(a_int=1, my_color_dict={10: Color.RED, 20: Color.BLUE}, my_color_list=[Color.GREEN], 383 | my_color_tuple=(Color.BLUE, 100), default_color=Color.GREEN) 384 | 385 | parsed_tup: MyEnumBasic = MyEnumBasic.__from_argv__(arg_cmd) 386 | assert basic_tup == parsed_tup 387 | serialized_cmd_line = basic_tup.__to_argv__() 388 | assert set(serialized_cmd_line) == set(arg_cmd) 389 | my_parser = MyEnumBasic.__arg_suite__._parser 390 | assert my_parser._option_string_actions['--my_color_dict'].metavar == "int:{RED, BLUE, GREEN}" 391 | assert my_parser._option_string_actions['--default_color'].choices == Color 392 | 393 | arg_cmd2 = ['--a_int', '1', '--my_color_dict', '10:red', '--my_color_list', 'GREEN', '--my_color_tuple', 'BLUE', '100'] 394 | pytest.raises(SmartArgError, MyEnumBasic.__from_argv__, arg_cmd2) # capital case needed for enum `RED` 395 | --------------------------------------------------------------------------------