├── .github
├── PULL_REQUEST_TEMPLATE.md
└── workflows
│ └── python-check-publish.yml
├── .gitignore
├── .readthedocs.yml
├── CONTRIBUTING.md
├── LICENSE.md
├── NOTICE
├── README.md
├── docs
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── advanced.md
├── api.rst
├── conf.py
├── index.rst
├── requirements.txt
└── smart-arg-demo.gif
├── publish.sh
├── setup.cfg
├── setup.py
├── smart-arg-demo.gif
├── smart_arg.py
└── test
├── smart_arg_demo.py
└── test_smart_arg.py
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
4 |
5 | Fixes # (issue)
6 |
7 | ## Type of change
8 |
9 | Please delete options that are not relevant.
10 |
11 | - [ ] Bug fix (non-breaking change which fixes an issue)
12 | - [ ] New feature (non-breaking change which adds functionality)
13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
14 |
15 | ## List all changes
16 | Please list all changes in the commit.
17 | * change1
18 | * change2
19 |
20 | # Testing
21 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
22 |
23 |
24 | **Test Configuration**:
25 | * Firmware version:
26 | * Hardware:
27 | * Toolchain:
28 | * SDK:
29 |
30 | # Checklist
31 |
32 | - [ ] My code follows the style guidelines of this project
33 | - [ ] I have performed a self-review of my own code
34 | - [ ] I have commented my code, particularly in hard-to-understand areas
35 | - [ ] I have made corresponding changes to the documentation
36 | - [ ] My changes generate no new warnings
37 | - [ ] I have added tests that prove my fix is effective or that my feature works
38 | - [ ] New and existing unit tests pass locally with my changes
39 | - [ ] Any dependent changes have been merged and published in downstream modules
40 |
--------------------------------------------------------------------------------
/.github/workflows/python-check-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint across a matrix of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python check and publish
5 |
6 | on:
7 | pull_request:
8 | push:
9 | branches:
10 | - master
11 |
12 | jobs:
13 | check:
14 | strategy:
15 | matrix:
16 | python-version: [3.6, 3.7, 3.8, 3.9]
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip setuptools
28 | python setup.py develop
29 | - name: flake8
30 | run: |
31 | pip install flake8
32 | # stop the build if there are Python syntax errors or undefined names
33 | flake8 smart_arg.py
34 | - name: mypy
35 | run: |
36 | pip install types-pkg_resources
37 | [ ${{ matrix.python-version }} == 3.6 ] && pip install dataclasses types-dataclasses
38 | pip install mypy
39 | mypy
40 | - name: Test with pytest
41 | run: |
42 | pip install pytest
43 | pytest
44 |
45 | publish:
46 | runs-on: ubuntu-latest
47 | needs: check
48 | if: >
49 | github.ref == 'refs/heads/master' && github.event_name == 'push' && github.repository_owner == 'linkedin'
50 | && !contains(github.event.head_commit.message, 'NO_PUBLISH')
51 |
52 | steps:
53 | - uses: actions/checkout@v2
54 | with:
55 | fetch-depth: '0' # To fetch tags too
56 | - name: Set up Python
57 | uses: actions/setup-python@v2
58 | with:
59 | python-version: '3.7'
60 | - name: Install dependencies
61 | run: |
62 | python -m pip install --upgrade pip setuptools twine
63 | python setup.py develop
64 | - name: Build, publish and tag
65 | env:
66 | TWINE_USERNAME: __token__
67 | TWINE_PASSWORD: ${{ secrets.pypi_token }}
68 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
69 | run: |
70 | bash publish.sh $(python -c "import smart_arg; print(smart_arg._base_version)")
71 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.egg
3 | *.egg-info/
4 | *.iml
5 | *.ipr
6 | *.iws
7 | *.pyc
8 | *.pyo
9 | *.sublime-*
10 | .*.swo
11 | .*.swp
12 | .cache/
13 | .coverage
14 | .direnv/
15 | .env
16 | .envrc
17 | .gradle/
18 | .idea/
19 | .tox*
20 | .venv*
21 | .vscode/
22 | /*/*pinned.txt
23 | /*/.mypy_cache/
24 | /*/MANIFEST
25 | /*/activate
26 | /*/build/
27 | /*/config
28 | /*/coverage.xml
29 | /*/dist/
30 | /*/htmlcov/
31 | /*/product-spec.json
32 | /build/
33 | /config/external/
34 | /dist/
35 | /ligradle/
36 | TEST-*.xml
37 | __pycache__/
38 |
39 |
40 |
41 | # Added by mp-maker
42 | build
43 | .shelf
44 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | python:
4 | version: 3.7
5 | install:
6 | - requirements: docs/requirements.txt
7 | - method: pip
8 | path: .
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # CONTRIBUTION
2 |
3 | ## Contribution Agreement
4 |
5 | As a contributor, you represent that the code you submit is your original work or
6 | that of your employer (in which case you represent you have the right to bind your
7 | employer). By submitting code, you (and, if applicable, your employer) are
8 | licensing the submitted code to LinkedIn and the open source community subject
9 | to the BSD 2-Clause license.
10 |
11 | ## Responsible Disclosure of Security Vulnerabilities
12 |
13 | **Do not file an issue on Github for security issues.** Please review
14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should
15 | be encrypted using PGP ([public key][pubkey]) and sent to
16 | [security@linkedin.com][disclosure_email] preferably with the title
17 | "Vulnerability in Github LinkedIn/smart-arg - <short summary>".
18 |
19 | ## Setup for Development
20 |
21 | ```shell-session
22 | # # Uncomment the next few lines to set up a virtual environment and install the packages as needed
23 | # python3 -m venv .venv
24 | # . .venv/bin/activate
25 | # pip install -U setuptools pytest flake8 mypy
26 |
27 | python3 setup.py develop
28 |
29 | # Happily make changes
30 | # ...
31 |
32 | mypy
33 | flake8 smart_arg.py
34 | pytest
35 | ```
36 |
37 |
38 | ## Tips for Getting Your Pull Request Accepted
39 |
40 | 1. Make sure all new features are tested and the tests pass.
41 | 2. Bug fixes must include a test case demonstrating the error that it fixes.
42 | 3. Open an issue first and seek advice for your change before submitting
43 | a pull request. Large features which have never been discussed are
44 | unlikely to be accepted. **You have been warned.**
45 |
46 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924
47 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676
48 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/smart-arg%20-%20%3Csummary%3E
49 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | BSD 2-CLAUSE LICENSE
2 | ====
3 | BSD 2-CLAUSE LICENSE
4 | Copyright 2020 LinkedIn Corporation
5 | All Rights Reserved.
6 |
7 | Redistribution and use in source and binary forms, with or
8 | without modification, are permitted provided that the following
9 | conditions are met:
10 |
11 | 1. Redistributions of source code must retain the above copyright
12 | notice, this list of conditions and the following disclaimer.
13 | 2. Redistributions in binary form must reproduce the above
14 | copyright notice, this list of conditions and the following
15 | disclaimer in the documentation and/or other materials provided
16 | with the distribution.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2020 LinkedIn Corporation
2 | All Rights Reserved.
3 |
4 | Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information.
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Smart Argument Suite (`smart-arg`)
2 |
3 | [](https://GitHub.com/linkedin/smart-arg/tags/)
4 | [](https://pypi.python.org/pypi/smart-arg/)
5 |
6 | Smart Argument Suite (`smart-arg`) is a slim and handy Python library that helps one work safely and conveniently
7 | with the arguments that are represented by an immutable argument container class' fields
8 | ([`NamedTuple`](https://docs.python.org/3.7/library/typing.html?highlight=namedtuple#typing.NamedTuple) or
9 | [`dataclass`](https://docs.python.org/3.7/library/dataclasses.html#dataclasses.dataclass) out-of-box),
10 | and passed through command-line interfaces.
11 |
12 | `smart-arg` promotes arguments type-safety, enables IDEs' code autocompletion and type hints
13 | functionalities, and helps one produce correct code.
14 |
15 | 
16 |
17 | ## Quick start
18 |
19 | The [`smart-arg`](https://pypi.org/project/smart-arg/) package is available through `pip`.
20 | ```shell
21 | pip3 install smart-arg
22 | ```
23 |
24 | Users can bring or define, if not already, their argument container class -- a `NamedTuple` or `dataclass`,
25 | and then annotate it with `smart-arg` decorator `@arg_suite` in their Python scripts.
26 |
27 | Now an argument container class instance, e.g. `my_arg` of `MyArg` class, once created, is ready to be serialized by the `smart-arg` API --
28 | `my_arg.__to_argv__()` to a sequence of strings, passed through the command-line interface
29 | and then deserialized back to an instance again by `my_arg = MyArg.__from_argv__(sys.argv[1:])`.
30 |
31 | ```python
32 | import sys
33 | from typing import NamedTuple, List, Tuple, Dict, Optional
34 | from smart_arg import arg_suite
35 |
36 |
37 | # Define the argument container class
38 | @arg_suite
39 | class MyArg(NamedTuple):
40 | """
41 | MyArg is smart! (docstring goes to description)
42 | """
43 | nn: List[int] # Comments go to argparse help
44 | a_tuple: Tuple[str, int] # a random tuple argument
45 | encoder: str # Text encoder type
46 | h_param: Dict[str, int] # Hyperparameters
47 | batch_size: Optional[int] = None
48 | adp: bool = True # bool is a bit tricky
49 | embedding_dim: int = 100 # Size of embedding vector
50 | lr: float = 1e-3 # Learning rate
51 |
52 |
53 | def cli_interfaced_job_scheduler():
54 | """
55 | This is to be called by the job scheduler to set up the job launching command,
56 | i.e., producer side of the Python job arguments
57 | """
58 | # Create the argument container instance
59 | my_arg = MyArg(nn=[3], a_tuple=("str", 1), encoder='lstm', h_param={}, adp=False) # The patched argument container class requires keyword arguments to instantiate the class
60 |
61 | # Serialize the argument to command-line representation
62 | argv = my_arg.__to_argv__()
63 | cli = 'my_job.py ' + ' '.join(argv)
64 | # Schedule the job with command line `cli`
65 | print(f"Executing job:\n{cli}")
66 | # Executing job:
67 | # my_job.py --nn 3 --a_tuple str 1 --encoder lstm --h_param --batch_size None --adp False --embedding_dim 100 --lr 0.001
68 |
69 |
70 | def my_job(my_arg: MyArg):
71 | """
72 | This is the actual job defined by the input argument my_arg,
73 | i.e., consumer side of the Python job arguments
74 | """
75 | print(my_arg)
76 | # MyArg(nn=[3], a_tuple=('str', 1), encoder='lstm', h_param={}, batch_size=None, adp=False, embedding_dim=100, lr=0.001)
77 |
78 | # `my_arg` can be used in the rest of the script in a typed manner, with the help of IDEs (type hints and auto-completion)
79 | # ...
80 | print(f"My network has {len(my_arg.nn)} layers with sizes of {my_arg.nn}.")
81 | # My network has 1 layers with sizes of [3].
82 |
83 |
84 | # my_job.py
85 | if __name__ == '__main__':
86 | # Deserialize the command-line representation of the argument back to a container instance
87 | arg_deserialized: MyArg = MyArg.__from_argv__(sys.argv[1:]) # Equivalent to `MyArg(None)`, one positional arg required to indicate the arg is a command-line representation.
88 | my_job(arg_deserialized)
89 | ```
90 |
91 | ```shell-session
92 | > python my_job.py -h
93 | usage: my_job.py [-h] --nn [int [int ...]] --a_tuple str int --encoder str
94 | --h_param [str:int [str:int ...]] [--batch_size int]
95 | [--adp {True,False}] [--embedding_dim int] [--lr float]
96 |
97 | MyArg is smart! (docstring goes to description)
98 |
99 | optional arguments:
100 | -h, --help show this help message and exit
101 | --nn [int [int ...]] (List[int], required) Comments go to argparse help
102 | --a_tuple str int (Tuple[str, int], required) a random tuple argument
103 | --encoder str (str, required) Text encoder type
104 | --h_param [str:int [str:int ...]]
105 | (Dict[str, int], required) Hyperparameters
106 | --batch_size int (Optional[int], default: None)
107 | --adp {True,False} (bool, default: True) bool is a bit tricky
108 | --embedding_dim int (int, default: 100) Size of embedding vector
109 | --lr float (float, default: 0.001) Learning rate
110 |
111 | ```
112 | ## Promoted practices
113 | * Focus on defining the arguments diligently, and let the `smart-arg`
114 | (backed by [argparse.ArgumentParser](https://docs.python.org/3/library/argparse.html#argumentparser-objects))
115 | work its magic around command-line interface.
116 | * Always work directly with argument container class instances when possible, even if you only need to generate the command-line representation.
117 | * Stick to the default behavior and the basic features, think twice before using any of the [advanced features](https://smart-arg.readthedocs.io/en/latest/advanced.html#advanced-usages).
118 |
119 |
120 | ## More detail
121 | For more features and implementation detail, please refer to the [documentation](https://smart-arg.readthedocs.io/).
122 |
123 | ## Contributing
124 |
125 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us.
126 |
127 | ## License
128 |
129 | This project is licensed under the BSD 2-CLAUSE LICENSE - see the [LICENSE.md](LICENSE.md) file for details
130 |
--------------------------------------------------------------------------------
/docs/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ../CONTRIBUTING.md
--------------------------------------------------------------------------------
/docs/LICENSE.md:
--------------------------------------------------------------------------------
1 | ../LICENSE.md
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/docs/advanced.md:
--------------------------------------------------------------------------------
1 | ## More Features
2 | ### Post processing & validation
3 |
4 | A user can define a method in the argument container class to do post processing after the argument is created but
5 | before it is returned to the caller.
6 | For example, when a field's default value depends on some other input fields, one could use a default
7 | placeholder `LateInit`, and a `__post_init__` function to define the actual value.
8 |
9 | The `__post_init__` function can also be used to validate the argument after instantiation:
10 |
11 | ```python
12 | from typing import NamedTuple
13 |
14 | from smart_arg import arg_suite, LateInit
15 |
16 |
17 | @arg_suite
18 | class MyArg(NamedTuple):
19 | network: str
20 | _network = {'choices': ['cnn', 'mlp']} # Use enum instead if you can
21 | num_layers: int = LateInit
22 |
23 | def __post_init__(self) -> 'MyArg':
24 | assert self.network in self._network['choices'], f'Invalid network {self.network}'
25 | if self.num_layers is LateInit:
26 | if self.network == 'cnn':
27 | num_layers = 3
28 | elif self.network == 'mlp':
29 | num_layers = 5
30 | else:
31 | raise RuntimeError("Should not be reachable!")
32 | self.num_layers=num_layers
33 | else:
34 | assert self.num_layers >= 0, f"number_layers: {self.num_layers} can not be negative"
35 |
36 | ```
37 | Notes:
38 | * If any fields are assigned a default placeholder `LateInit`, a `__post_init__` is expected
39 | to be defined and to replace any `LateInit` with actual values; otherwise the internal validation fails and triggers a `SmartArgError`.
40 | * Field mutations are only allowed in the `__post_init__` as part of the class instance construction. No mutations are
41 | allowed after construction, including manually calling `__post_init__`.
42 | * `__post_init__` of a `NamedTuple` works similarly to that of a `dataclass`.
43 |
44 | ## Supported Argument Container Classes
45 | ### [`NamedTuple`](https://docs.python.org/3.7/library/typing.html?highlight=namedtuple#typing.NamedTuple)
46 | * Strong immutability
47 | * Part of the Python distribution
48 | * No inheritance
49 |
50 | ### [`dataclass`](https://docs.python.org/3.7/library/dataclasses.html)
51 | * Weak immutability with native `@dataclass(frozen=True)` or `smart-arg` patched (when not `frozen`)
52 | * `pip install dataclasses` is needed for Python 3.6
53 | * Inheritance support
54 | * Native `__post_init__` support
55 | ## Advanced Usages
56 |
57 | By default, `smart-arg` supports the following types as fields of an argument container class:
58 | * primitives: `int`, `float`, `bool`, `str`, `enum.Enum`
59 | * `Tuple`: elements of the tuple are expected to be primitives
60 | * `Sequence`/`Set`: `Sequence[int]`, `Sequence[float]`, `Sequence[bool]`, `Sequence[str]`, `Sequence[enum.Enum]`, `Set[int]`, `Set[float]`, `Set[bool]`, `Set[str]`, `Set[enum.Enum]`
61 | * `Dict`: `Dict[int, int]`, `Dict[int, float]`, `Dict[int, bool]`, `Dict[int, str]`, `Dict[float, int]`, `Dict[float, float]`,
62 | `Dict[float, bool]`, `Dict[float, str]`, `Dict[bool, int]`, `Dict[bool, float]`, `Dict[bool, bool]`, `Dict[bool, str]`,
63 | `Dict[str, int]`, `Dict[str, float]`, `Dict[str, bool]`, `Dict[str, str]`, `Dict[enum.Enum, int/float/bool/str]`, `Dict[int/float/bool/str, enum.Enum]`
64 | * `Optional[AnyOtherSupportedType]`: Beware that any optional field is required to **default to `None`**.
65 |
66 | ### override argument Ser/De
67 | A user can change the parsing behavior of a certain field of an argument container class.
68 | One can only do this when the field's type is already supported by `smart-arg`.
69 |
70 |
71 | This is done by defining a private companion field that starts with "``__``" (double underscores) to overwrite the keyword arguments
72 | to [ArgumentParser.add_argument](https://docs.python.org/3/library/argparse.html#the-add-argument-method) with a dictionary.
73 | The key '_serialization' defines an [`iterator/generator`](https://wiki.python.org/moin/Generators) and all other keys go to `argparse`
74 | for deserialization/parsing.
75 |
76 | ALERT: this can lead to **inconsistent behaviors** when one also generates the command-line
77 | representation of an argument container class instance, since it can only modify the deserialization
78 | behavior from the command-line representation.
79 | ```python
80 | from typing import NamedTuple, Sequence
81 |
82 | from smart_arg import arg_suite
83 |
84 |
85 | @arg_suite
86 | class MyTup(NamedTuple):
87 | a_list: Sequence[int]
88 | __a_list = {'choices': [200, 300], 'nargs': '+'}
89 | ```
90 |
91 | ### override or extend the support of primitive and other types
92 | A user can use this provided functionality to change serialization and deserialization behavior for supported types and add support for additional types.
93 | * A user can override the handling of primitive types by defining additional `PrimitiveHandlerAddon`s. The basic primitive handler
94 | is defined in the source code as `PrimitiveHandlerAddon`. A user can pass the customized handlers to the decorator.
95 | * The same applies to type handling: provide additional `TypeHandler`s and pass them in as a decorator argument. A `TypeHandler` deals with complex types
96 | other than primitive ones, such as `Sequence`, `Set`, `Dict`, `Tuple`, etc.
97 |
98 | ```python
99 | from math import sqrt
100 | from typing import NamedTuple, Any, Type
101 |
102 | from smart_arg import PrimitiveHandlerAddon, TypeHandler, custom_arg_suite
103 |
104 |
105 | # overwrite int primitive type handling by squaring it
106 | class IntHandlerAddon(PrimitiveHandlerAddon):
107 | @staticmethod
108 | def build_type(arg_type) -> Any:
109 | return lambda s: int(s) ** 2
110 |
111 | @staticmethod
112 | def build_str(arg) -> str:
113 | return str(int(sqrt(arg)))
114 |
115 | @staticmethod
116 | def handles(t: Type) -> bool:
117 | return t == int
118 |
119 |
120 | class IntTypeHandler(TypeHandler):
121 | def _build_other(self, kwargs, arg_type) -> None:
122 | kwargs.type = self.primitive_addons[0].build_type(arg_type)
123 |
124 | def handles(self, t: Type) -> bool:
125 | return t == int
126 |
127 |
128 | my_suite = custom_arg_suite(primitive_handler_addons=[IntHandlerAddon], type_handlers=[IntTypeHandler])
129 |
130 |
131 | @my_suite
132 | class MyTuple(NamedTuple):
133 | a_int: int
134 |
135 | ```
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | -------------
3 | .. automodule:: smart_arg
4 | :members:
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import datetime
3 | import smart_arg
4 |
5 | # Add any Sphinx extension module names here, as strings. They can be extensions
6 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
7 | extensions = [
8 | 'sphinx.ext.autodoc',
9 | 'sphinx_autodoc_typehints',
10 | 'sphinx.ext.intersphinx',
11 | 'recommonmark',
12 | ]
13 |
14 | version = release = smart_arg.__version__
15 | # Add any paths that contain templates here, relative to this directory.
16 | templates_path = ['_templates']
17 |
18 | # The suffix of source filenames.
19 | source_suffix = '.rst'
20 |
21 | # The master toctree document.
22 | master_doc = 'index'
23 |
24 | project = u'Smart Argument Suite'
25 |
26 | # General information about the project.
27 | copyright = f'2020-{datetime.datetime.today().year}, LinkedIn'
28 |
29 | # List of patterns, relative to source directory, that match files and
30 | # directories to ignore when looking for source files.
31 | exclude_patterns = ['_build']
32 |
33 | # The name of the Pygments (syntax highlighting) style to use.
34 | pygments_style = 'sphinx'
35 |
36 | # -- Options for HTML output ---------------------------------------------------
37 |
38 | html_theme = 'default'
39 | # html_static_path = ['_static']
40 |
41 | # Example configuration for intersphinx: refer to the Python standard library.
42 | intersphinx_mapping = {
43 | 'python': ('http://docs.python.org/', None),
44 | }
45 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to ``smart-arg``'s documentation!
2 | ============================================
3 | Contents:
4 |
5 | .. toctree::
6 | :maxdepth: 2
7 | :glob:
8 |
9 | README
10 | advanced
11 | api
12 | CONTRIBUTING
13 | LICENSE
14 |
15 |
16 | Indices and tables
17 | ==================
18 |
19 | * :ref:`genindex`
20 | * :ref:`modindex`
21 | * :ref:`search`
22 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx_autodoc_typehints
--------------------------------------------------------------------------------
/docs/smart-arg-demo.gif:
--------------------------------------------------------------------------------
1 | ../smart-arg-demo.gif
--------------------------------------------------------------------------------
/publish.sh:
--------------------------------------------------------------------------------
1 | set -e
2 |
3 | function git_tag() {
4 | # get repo name from git
5 | repo=$(git config --get remote.origin.url | sed 's|^.*[/:]\(.*/.*\)\(\.git\)\?$|\1|')
6 | commit=$(git rev-parse HEAD)
7 | new_tag=$1
8 | git tag "$new_tag"
9 |
10 | echo >&2 POST the new tag ref "$new_tag" to "$repo" via Github API
11 | HTTP_STATUS=$(
12 | curl -w "%{http_code}" -o >(cat >&2) -si -X POST "https://api.github.com/repos/$repo/git/refs" \
13 | -H "Authorization: token $GITHUB_TOKEN" \
14 | -d @- </dev/null); then
32 | echo >&2 last_matched_tag is "$last_matched_tag" for "$star_version"
33 | rest=${last_matched_tag#v$star_version}
34 | increment_version=$((${rest%%.*} + 1))
35 | [ "$(git rev-parse HEAD)" == "$(git rev-parse "$last_matched_tag")" ] && echo HEAD already at "$last_matched_tag" >&2 && return 126
36 | fi
37 | resolved_version=$star_version_prefix$((increment_version))
38 | echo "$star_version is resolved to version '$resolved_version'." >&2
39 | echo "$resolved_version"
40 | else
41 | echo "Using the input star version '$star_version' as a fixed version." >&2
42 | echo "$star_version"
43 | fi
44 | else
45 | echo -e "Unsupported star version: '$star_version'.\nOnly supports star minor or patch version." >&2
46 | return 127
47 | fi
48 | }
49 |
50 | # Set star_version to the first argument or $STAR_VERSION or $($STAR_VERSION_CMD) in this order
51 | star_version=${1-${STAR_VERSION-$($STAR_VERSION_CMD)}}
52 | new_version=$(resolve_version "$star_version")
53 | git_tag v"$new_version"
54 |
55 | rm -rf dist
56 | python setup.py sdist
57 | twine upload dist/* --verbose
58 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E731, W503
3 | max-line-length = 160
4 |
5 | [tool:pytest]
6 | testpaths = test
7 |
8 | [coverage:report]
9 | fail_under = 90
10 | show_missing = true
11 |
12 | [coverage:run]
13 | branch = true
14 |
15 | [mypy]
16 | files = smart_arg.py
17 | ignore_missing_imports = true
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Argument class <=> Human friendly cli"""
2 |
3 | from setuptools import setup
4 |
5 | doc = 'https://smart-arg.readthedocs.io'
6 | with open('README.md', encoding='utf-8') as f:
7 | readme = f.read()
8 |
9 | setup(
10 | name='smart-arg',
11 | use_scm_version={
12 | 'root': '.',
13 | 'relative_to': __file__
14 | },
15 | setup_requires=['setuptools_scm'],
16 | description=__doc__,
17 | long_description=readme,
18 | long_description_content_type='text/markdown',
19 | license='BSD-2-CLAUSE',
20 | python_requires='>=3.6',
21 | url=doc,
22 | download_url='https://pypi.python.org/pypi/smart-arg',
23 | project_urls={
24 | 'Documentation': doc,
25 | 'Source': 'https://github.com/linkedin/smart-arg.git',
26 | 'Tracker': 'https://github.com/linkedin/smart-arg/issues',
27 | },
28 | py_modules=['smart_arg'],
29 | install_requires=[],
30 | tests_require=['pytest', 'mypy'],
31 | classifiers=[
32 | 'Programming Language :: Python',
33 | 'Programming Language :: Python :: 3',
34 | 'Operating System :: OS Independent',
35 | 'License :: OSI Approved',
36 | 'Typing :: Typed'
37 | ],
38 | keywords=[
39 | 'typing',
40 | 'argument parser',
41 | 'reverse argument parser',
42 | 'human friendly',
43 | 'configuration (de)serialization',
44 | 'python',
45 | 'cli'
46 | ]
47 | )
48 |
--------------------------------------------------------------------------------
/smart-arg-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/smart-arg/cc762a7f8b0bb2abf98d380492ebc7e40fae927a/smart-arg-demo.gif
--------------------------------------------------------------------------------
/smart_arg.py:
--------------------------------------------------------------------------------
1 | """Smart Argument Suite
2 |
3 | This module is an argument serialization and deserialization library that:
4 |
5 | - handles two-way conversions: a typed, preferably immutable argument container object
6 | (a `NamedTuple` or `dataclass` instance) <==> a command-line argv
7 | - enables IDE type hints and code auto-completion by using `NamedTuple` or `dataclass`
8 | - promotes type-safety of cli-ready arguments
9 |
10 |
11 | The following is a simple usage example::
12 |
13 | # Define the argument container class:
14 | @arg_suite
15 | class MyArg(NamedTuple):
16 | '''MyArg is smart and the docstring goes to description'''
17 |
18 | nn: List[int] # Comments go to ArgumentParser argument help
19 |
20 | a_tuple: Tuple[str, int] # Arguments without defaults are treated as "required = True"
21 | h_param: Dict[str, int] # Also supports List, Set
22 |
23 | ###### arguments without defaults go first as required by NamedTuple ######
24 |
25 | l_r: float = 0.01 # Arguments with defaults are treated as "required = False"
26 | n: Optional[int] = None # Optional can only default to `None`
27 |
28 |
29 | # Create the corresponding argument container class instance from the command line argument list:
30 | # by using
31 | arg: MyArg = MyArg.__from_argv__(sys.argv[1:]) # with factory method, need the manual type hint `:MyArg` to help IDE
32 | # or the monkey-patched constructor of the annotated argument container class
33 | arg = MyArg(sys.argv[1:]) # with monkey-patched constructor with one positional argument, no manual hint needed
34 |
35 | # Create a NamedTuple argument instance and generate its command line counterpart:
36 | # the monkey-patched constructor only takes named arguments for directly creating the NamedTuple
37 | arg = MyArg(nn=[200, 300], a_tuple=('t', 0), h_param={'batch_size': 1000})
38 | # generate the corresponding command line argument list
39 | arg.__to_argv__()
40 |
41 |
42 | The module contains the following public classes/functions:
43 |
44 | - `arg_suite` -- The main entry point for Smart Argument Suite. As the
45 | example above shows, this decorator will attach an `ArgSuite` instance to
46 | the argument container class `NamedTuple` or `dataclass` and patch it.
47 |
48 | - `PrimitiveHandlerAddon` -- The base class that defines the basic operations on primitive types;
49 | users can implement their own to change the behavior.
50 |
51 | - `TypeHandler` -- The base class to handle types, users can extend to expand the
52 | support or change existing behaviors.
53 |
54 | All other classes and methods in this module are considered implementation details."""
55 |
56 | import logging
57 | import os
58 | import sys
59 | from argparse import Action, ArgumentParser
60 | from collections import abc
61 | from enum import EnumMeta, Enum
62 | from itertools import chain
63 | from types import SimpleNamespace
64 | from typing import Any, Callable, Dict, Generic, Iterable, List, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union
65 | from warnings import warn
66 |
67 | import pkg_resources
68 |
_base_version = '1.1.*'  # star version should be auto-resolved to concrete number when run setup.py
# Version of the installed distribution, looked up under this module's name at import time.
__version__ = pkg_resources.get_distribution(__name__).version
# Public names of this module.
__all__ = (
    'arg_suite',
    'custom_arg_suite',
    'frozenlist',
    'LateInit',
    'SmartArgError',
    'TypeHandler',
    'PrimitiveHandlerAddon',
)

ArgType = TypeVar('ArgType', bound=NamedTuple)  # NamedTuple is not a real class bound, but this makes mypy happier
NoneType = None.__class__
# Per-field metadata (source comment, default, unwrapped type, `Optional` flag) passed to the type handlers.
FieldMeta = NamedTuple('FieldMeta', (('comment', str), ('default', Any), ('type', Type), ('optional', bool)))
# Attribute bag whose attributes are later expanded into the keyword arguments of `ArgumentParser.add_argument`.
KwargsType = SimpleNamespace

logger = logging.getLogger(__name__)
# Name of the environment variable that, when set, attaches a stream handler and sets the log level of this module.
SMART_ARG_LOG_LEVEL = 'SMART_ARG_LOG_LEVEL'
if SMART_ARG_LOG_LEVEL in os.environ:
    logger.addHandler(logging.StreamHandler())
    log_level = os.environ[SMART_ARG_LOG_LEVEL].upper()  # upper-cased before being handed to `setLevel`
    logger.setLevel(log_level)
    logger.info(f"Detected environment var 'SMART_ARG_LOG_LEVEL', set log level to '{log_level}' and log to stream.")
93 |
# Bind `get_origin` and `get_args` (used throughout this module to inspect generic type hints)
# according to the running interpreter version.
if sys.version_info >= (3, 8):
    from typing import get_origin, get_args
elif sys.version_info >= (3, 7):
    # Python == 3.7.x. Defining the back-ported get_origin, get_args
    # 3.7 `List.__origin__ == list`
    get_origin, get_args = lambda tp: getattr(tp, '__origin__', None), lambda tp: getattr(tp, '__args__', [])
elif sys.version_info >= (3, 6):
    # Python == 3.6.x. Defining the back-ported get_origin, get_args
    # 3.6 `List.__origin__ == List`, `Optional` does not have `__dict__`
    get_origin, get_args = lambda tp: getattr(tp, '__extra__', ()) or getattr(tp, '__origin__', None), lambda tp: getattr(tp, '__args__', ()) or []
else:
    # Older interpreters: warn, then try the third-party `typing_inspect`; the ImportError is re-raised after a second warning.
    try:
        warn(f"Unsupported and untested python version {sys.version_info} < 3.6. "
             f"The package may or may not work properly. "
             f"Try 'from typing_inspect import get_origin, get_args' now.")
        from typing_inspect import get_args, get_origin
    except ImportError:
        warn(f"`from typing_inspect import get_origin, get_args` failed for "
             f"unsupported python version {sys.version_info} < 3.6. It might work with 'pip install typing_inspect'.")
        raise
114 |
115 | frozenlist = tuple # tuple to simulate an immutable list. [https://www.python.org/dev/peps/pep-0603/#rationale]
116 | _black_hole = lambda *_, **__: None
117 | _mro_all = lambda arg_class, get_dict: {k: v for b in (*arg_class.__mro__[-1:0:-1], arg_class) if b.__module__ != 'builtins' for k, v in get_dict(b).items()}
118 | # note: Fields without type annotation won't be regarded as a property/entry _fields.
119 | _annotations = lambda arg_class: _mro_all(arg_class, lambda b: getattr(b, '__annotations__', {}))
120 | try:
121 | from dataclasses import MISSING, asdict, is_dataclass
122 | except ImportError:
123 | logger.warning("Importing dataclasses failed. You might need to 'pip install dataclasses' on python 3.6.")
124 | class MISSING: pass # type: ignore # noqa: E701
125 | is_dataclass = _black_hole # type: ignore # Always return None
126 |
127 |
class LateInit:
    """special singleton/class to mark late initialized fields

    The class object itself is used as the default value of such a field; a field that is still
    `LateInit` when the instance is validated is reported as an error."""
130 |
131 |
class SmartArgError(Exception):  # TODO Extend to better represent different types of errors.
    """Base exception for smart-arg.

    Raised by this module for unsupported types, invalid field definitions and failed validations."""
134 |
135 |
136 | def _raise_if(message: str, condition: Any = True):
137 | if condition:
138 | raise SmartArgError(message)
139 |
140 |
141 | def _first_handles(with_handles, arg_type, default: Optional[bool] = None):
142 | handle = next(filter(lambda h: h.handles(arg_type), with_handles), default)
143 | return handle if handle is not None else _raise_if(f"{arg_type!r} is not supported.")
144 |
145 |
class PrimitiveHandlerAddon:
    """Primitive handler addon defines some basic operations on primitive types. Only `staticmethod` is expected.
    Users can extend/modify the primitive handling by inheriting this class."""
    @staticmethod
    def build_type(arg_type: Type) -> Union[Type, Callable[[str], Any]]:
        """Return the callable that converts one command-line string to a value of `arg_type`.

        :param arg_type: `bool`, an `Enum` class, or any other type that can be called with a string
        :return: a strict `'True'`/`'False'` parser for `bool`, a member-name lookup for enums, `arg_type` itself otherwise
        :raise: SmartArgError (from the returned callable) for an invalid bool or enum string"""
        if arg_type is bool:
            return lambda s: True if s == 'True' else False if s == 'False' else _raise_if(f"Invalid bool: {s!r}")
        if type(arg_type) is EnumMeta:
            # Look the name up among the declared members only. A plain `getattr(...) or ...` would reject
            # falsy members (e.g. an `IntEnum` member equal to 0) and accept non-member attributes like `__doc__`.
            members = arg_type.__members__
            return lambda s: members[s] if s in members else _raise_if(f"Invalid enum {s!r} for {arg_type!r}")
        return arg_type

    @staticmethod
    def build_str(arg: Any) -> str:
        """Define to serialize the `arg` to a string.

        :param arg: The argument
        :type arg: Any type that is supported by this class
        :return: The string serialization of `arg`: the member name for an enum member, `str(arg)` otherwise"""
        return str(arg.name) if isinstance(arg, Enum) else str(arg)

    @staticmethod
    def build_metavar(arg_type: Type) -> str:
        """Define the hint string in argument help message for `arg_type`."""
        if arg_type is bool:
            return '{True, False}'
        if type(arg_type) is EnumMeta:
            return f"{{{', '.join(map(str, arg_type._member_names_))}}}"
        return arg_type.__name__

    @staticmethod
    def build_choices(arg_type) -> Optional[Iterable[Any]]:
        """Enumerate `arg_type` if possible, or return `None`."""
        if arg_type is bool:
            return True, False
        return arg_type if type(arg_type) is EnumMeta else None

    @staticmethod
    def handles(t: Type) -> bool:
        """Tell whether `t` is one of `str`, `int`, `float`, `bool` or an `Enum` class."""
        return t in (str, int, float, bool) or type(t) is EnumMeta
181 |
182 |
class TypeHandler:
    """Base type handler. Subclasses usually override `_build_other` for deserialization and `gen_cli_arg`
    for serialization, and have to override `handles`."""
    def __init__(self, primitive_addons: Sequence[Type[PrimitiveHandlerAddon]]):
        self.primitive_addons = primitive_addons

    def _build_common(self, kwargs: KwargsType, field_meta: FieldMeta, parent_required: bool) -> None:
        """Set `required` and `help` on `kwargs`, the keyword arguments for ArgumentParser.add_argument

        :param kwargs: the keyword argument KwargsType object to fill in
        :param field_meta: the meta information extracted from the argument container class
        :param parent_required: whether the parent container is a required field"""
        default = field_meta.default
        type_str = self._type_to_str(field_meta.type)
        if field_meta.optional:
            type_str = 'Optional[' + type_str + ']'
        # Required only when there is no default and the parent is required as well
        kwargs.required = parent_required and default is MISSING
        if kwargs.required:
            requirement = 'required'
        elif default is LateInit:
            requirement = 'to be post-processed'
        elif default is MISSING:
            requirement = 'required if its container is being specified or marked'
        else:
            # informational only. The default is set when creating the argument container instance.
            requirement = f'default: {default}'
        # The trailing part of the help message comes from the source code comment
        kwargs.help = ''.join(('(', type_str, '; ', requirement, ') ', field_meta.comment))

    def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None:
        """Hook to set `nargs`, `type` and `metavar` on `kwargs`; the base implementation sets nothing.

        :param kwargs: the keyword argument KwargsType object
        :param arg_type: the type of the argument extracted from the argument container class"""

    def gen_kwargs(self, field_meta: FieldMeta, parent_required: bool) -> KwargsType:
        """Create the keyword argument object for one field: common part first, type specific part second.

        :param field_meta: argument metadata information
        :param parent_required: if the parent container is required
        :return: keyword argument object"""
        built = KwargsType()
        self._build_common(built, field_meta, parent_required)
        self._build_other(built, field_meta.type)
        return built

    def gen_cli_arg(self, arg: Any) -> Iterable[str]:
        """Serialize `arg` to command line strings, one per element.

        A string or a non-iterable value is treated as a single element.

        :param arg: value of the argument
        :return: iterable command line str"""
        is_single = isinstance(arg, str) or not isinstance(arg, Iterable)
        for item in ((arg,) if is_single else arg):
            yield _first_handles(self.primitive_addons, type(item)).build_str(item)

    def _type_to_str(self, t: Union[type, Type]) -> str:
        """Render the type for the ArgumentParser help message as `Name[arg, ...]`

        :param t: type of the argument, i.e. float, Dict[str, int], Set[int], List[str] etc.
        :return: string representation of the argument type"""
        type_name = getattr(t, "_name", "") or t.__name__
        arg_names = ", ".join(a.__name__ for a in get_args(t))
        return f'{type_name}[{arg_names}]'

    def handles(self, t: Type) -> bool:
        """Tell whether this handler supports the type `t`; to be implemented by subclasses."""
        raise NotImplementedError
240 |
241 |
class PrimitiveHandler(TypeHandler):
    """Handler for the types claimed by one of the primitive addons."""
    def handles(self, t: Type) -> bool:
        # The first addon claiming `t` (truthy) or `False`
        return _first_handles(self.primitive_addons, t, False)

    def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None:
        """Set `type`, `metavar` and `choices` from the first addon handling `arg_type`."""
        addon = _first_handles(self.primitive_addons, arg_type)
        for key in ('type', 'metavar', 'choices'):
            setattr(kwargs, key, getattr(addon, f'build_{key}')(arg_type))

    def _type_to_str(self, t: Union[type, Type]) -> str:
        # Primitive types have no type arguments, the plain name is enough
        return t.__name__
254 |
255 |
class TupleHandler(TypeHandler):
    """Handler for `Tuple[...]` with at least one type argument."""
    class __BuildType:
        """Stateful converter: the n-th call converts with the n-th element type, wrapping around at the end."""
        def __init__(self, types, p_addons):
            self.__name__ = f'Tuple[{", ".join(t.__name__ for t in types)}]'
            self.counter = 0
            self.types = types
            self.p_addons = p_addons

        def __call__(self, s):
            if self.counter == len(self.types):
                self.counter = 0  # a full tuple has been consumed, start over
            element_type = self.types[self.counter]
            self.counter += 1
            converter = _first_handles(self.p_addons, element_type).build_type(element_type)
            return converter(s)

    def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None:
        """Set `nargs` to the tuple length, one `metavar` per element and the stateful `type` converter."""
        element_types = get_args(arg_type)
        kwargs.nargs = len(element_types)
        kwargs.metavar = tuple(_first_handles(self.primitive_addons, t).build_metavar(t) for t in element_types)
        kwargs.type = TupleHandler.__BuildType(element_types, self.primitive_addons)

    def handles(self, t: Type) -> bool:
        # Truthy (the type arguments) only for a parameterized tuple
        return get_origin(t) is tuple and get_args(t)  # type: ignore
278 |
279 |
class CollectionHandler(TypeHandler):
    """Handler for `List`, `Set`, `FrozenSet` and `Sequence` of one primitive element type."""
    def _build_other(self, kwargs: KwargsType, arg_type: Type) -> None:
        """Accept any number of values (`nargs='*'`), each converted by the element type addon."""
        element_type = get_args(arg_type)[0]
        kwargs.nargs = '*'
        addon = _first_handles(self.primitive_addons, element_type)
        kwargs.metavar = addon.build_metavar(element_type)
        kwargs.type = addon.build_type(element_type)

    def handles(self, t: Type) -> bool:
        type_args = get_args(t)
        if len(type_args) != 1:
            return False
        if get_origin(t) not in (list, set, frozenset, abc.Sequence):
            return False
        # Truthy (the addon) only when the element type is a supported primitive
        return _first_handles(self.primitive_addons, type_args[0], False)
291 |
292 |
class DictHandler(TypeHandler):
    """Handler for `Dict[K, V]` with primitive `K` and `V`; entries are written as `key:value` on the command line."""
    def _build_other(self, kwargs, arg_type) -> None:
        """Accept any number of `key:value` items (`nargs='*'`), each converted to a `(key, value)` pair.

        :param kwargs: the keyword argument KwargsType object
        :param arg_type: the `Dict[K, V]` type of the argument"""
        key_type, value_type = get_args(arg_type)[:2]
        key_addon = _first_handles(self.primitive_addons, key_type)
        value_addon = _first_handles(self.primitive_addons, value_type)
        to_key, to_value = key_addon.build_type(key_type), value_addon.build_type(value_type)
        kwargs.nargs = '*'

        def dict_type(s: str):
            # Split on the first colon only so that values containing ':' (urls, timestamps...) are accepted.
            # An item without any colon still raises `ValueError`, which argparse reports as an invalid value.
            k, v = s.split(":", 1)
            return to_key(k), to_value(v)
        kwargs.type = dict_type
        kwargs.metavar = f'{key_addon.build_metavar(key_type)}:{value_addon.build_metavar(value_type)}'

    def gen_cli_arg(self, arg):
        """Serialize the dictionary `arg` to `key:value` strings, one per entry."""
        arg_to_str = lambda arg_v: _first_handles(self.primitive_addons, type(arg_v)).build_str(arg_v)
        for key, value in arg.items():
            yield f'{arg_to_str(key)}:{arg_to_str(value)}'

    def handles(self, t: Type) -> bool:
        args, addons = get_args(t), self.primitive_addons
        return len(args) == 2 and get_origin(t) is dict and _first_handles(addons, args[0], False) and _first_handles(addons, args[1], False)
314 |
315 |
class _namedtuple:  # TODO expand lambdas to static methods or use a better holder representation
    """Proxy holding the helper functions that support `NamedTuple` argument container classes"""
    @staticmethod
    def new_instance(arg_class, kwargs):
        """:return A new instance of `arg_class`: call original __new__ -> __post_init__ -> post_validation"""
        instance = arg_class.__original_new__(arg_class, **kwargs)
        post_init = getattr(arg_class, '__post_init__', None)
        if post_init:
            # `__post_init__` works on a mutable stand-in, the immutable instance is rebuilt from it afterwards
            mutable_view = SimpleNamespace(**instance._asdict())
            post_init(mutable_view)
            instance = arg_class.__original_new__(arg_class, **vars(mutable_view))
        arg_class.__arg_suite__.post_validation(instance)
        return instance

    @staticmethod
    def proxy(t: Type):
        """:return This proxy class if `t` is a `NamedTuple` or `None`"""
        bases = getattr(t, '__bases__', [])
        fields = getattr(t, '_fields', [])
        hints = getattr(t, '__annotations__', {})
        if len(bases) != 1 or bases[0] is not tuple:
            return None
        if not (isinstance(fields, tuple) and isinstance(hints, dict)):
            return None
        return _namedtuple if all(type(n) is str for n in chain(fields, hints.keys())) else None

    @staticmethod
    def asdict(args):
        """:return The fields of the instance `args` as a dictionary"""
        return args._asdict()

    @staticmethod
    def field_default(arg_class, raw_arg_name):
        """:return The declared default of the field or `MISSING`"""
        return arg_class._field_defaults.get(raw_arg_name, MISSING)
    patch = _black_hole  # No need to patch
339 |
340 |
class _dataclasses:
    """dataclass proxy

    Holds the same helper set as the `NamedTuple` proxy (`patch`, `field_default`, `proxy`, `asdict`,
    `new_instance`) for `dataclass` argument container classes."""
    @staticmethod
    def patch(cls):
        """Patch the argument dataclass so that `post_validation` is called, and it's immutable if not `frozen` after initialization"""
        def raise_if_frozen(self, fun, name, *args, **kwargs):
            # Refuse the attribute operation once the instance carries a truthy `__frozen__`, otherwise delegate to `object`.
            _raise_if(f"cannot assign to/delete field {name!r}", getattr(self, '__frozen__', False))
            getattr(object, fun)(self, name, *args, **kwargs)

        def init(self, *args, **kwargs):
            # Positional arguments on an instance that is already frozen: nothing left to initialize (see the log message).
            if args and hasattr(self, '__frozen__'):
                logger.debug(f"Assuming {self} is from the patched __new__ with __from_argv__, already initialized, skipping init.")
                return
            self.__original_init__(*args, **kwargs)
            # Bypass the patched `__setattr__` to mark the instance frozen, then validate it.
            object.__setattr__(self, '__frozen__', True)
            self.__class__.__arg_suite__.post_validation(self)
        # Swap in the wrapper while keeping the original initializer reachable as `__original_init__`.
        cls.__init__, cls.__original_init__ = init, cls.__init__
        cls.__setattr__ = lambda self, name, value: raise_if_frozen(self, '__setattr__', name, value)
        cls.__delattr__ = lambda self, name, : raise_if_frozen(self, '__delattr__', name)

    @staticmethod
    def field_default(arg_class, raw_arg_name):
        """:return The default of the dataclass field: the `default_factory` result if a factory is set, else `default`"""
        f = arg_class.__dataclass_fields__[raw_arg_name]
        return f.default_factory() if f.default_factory is not MISSING else f.default  # Assuming the default_factory has no side effects.
    # This proxy class if `t` is a dataclass, `None` otherwise.
    proxy = lambda t: _dataclasses if is_dataclass(t) else None
    # `asdict` inside the lambda resolves to the module-level `dataclasses.asdict`, not to this class attribute.
    asdict = lambda args: asdict(args)
    # The keyword arguments are ignored here; presumably the patched `__init__` receives them afterwards - TODO confirm.
    new_instance = lambda arg_class, _: arg_class.__original_new__(arg_class)
368 |
369 |
_type_proxies = [_namedtuple, _dataclasses]  # Supported container types. Users can extend it if they know what they are doing. Not officially supported yet.


def _get_type_proxy(arg_class):
    """Return the first registered proxy that recognizes `arg_class`, or `False` if none does."""
    for candidate in _type_proxies:
        if candidate.proxy(arg_class):
            return candidate
    return False
375 |
376 |
377 | def _unwrap_optional(arg_type):
378 | type_origin, type_args, optional = get_origin(arg_type), get_args(arg_type), False
379 | if type_origin is Union and len(type_args) == 2 and type_args[1] is NoneType: # `Optional` support
380 | arg_type, optional = type_args[0], True # Unwrap `Optional` and validate
381 | return arg_type, optional
382 |
383 |
class ArgSuite(Generic[ArgType]):
    """Generates the corresponding `ArgumentParser` and handles the two-way conversions.

    An instance is attached to the decorated argument container class as `__arg_suite__`; the class also gets
    `__from_argv__`, `__to_argv__` and a patched `__new__`."""
    @staticmethod
    def new_arg(arg_class, *args, **kwargs):
        """Monkey-Patched argument container class __new__.
        If any positional arguments exist, it would assume that the user is trying to parse a sequence of strings.
        It accepts at most two positional arguments, and raises a `SmartArgError` otherwise.
        If no positional arguments exist, it would call the argument container class instance creator with all keyword arguments.

        :param arg_class: Decorated class
        :param args: Optional positional argument, to be parsed to the arg_class type.
            `args[0]`: the `argv` sequence of strings to be parsed by the parser. ``None`` will
            be interpreted as ``sys.argv[1:]``
            `args[1]`: default to `None`, indicating using the default separator for the argument container class

        :type `(Optional[Sequence[str]], Optional[str])`
        :param kwargs: Optional keyword arguments, to be passed to the argument container class specific instance creator.
        :raise: SmartArgError if positional and keyword arguments are mixed or the positional arguments are malformed"""
        logger.info(f"Patched __new__ for {arg_class} is called with {args} and {kwargs}.")
        if args:
            warn(f"Calling the patched constructor of {arg_class} with argv is deprecated, please use {arg_class}.__from_argv__ instead.")
            _raise_if(f"Calling '{arg_class}(positional {args}, keyword {kwargs})' is not allowed:\n"
                      f"Only accept positional arguments to parse to the '{arg_class}'\nkeyword arguments can only be used to create an instance directly.",
                      kwargs or len(args) > 2 or len(args) == 2 and args[1].__class__ not in (NoneType, str)
                      or not (args[0] is None or isinstance(args[0], Sequence) and all(map(lambda a: a.__class__ is str, args[0]))))

            # Forward the optional separator (`args[1]`) as well: it used to be validated and then silently dropped.
            return arg_class.__from_argv__(*args)
        else:
            return _get_type_proxy(arg_class).new_instance(arg_class, kwargs)

    def __init__(self, type_handlers: Sequence[TypeHandler], arg_class):
        """Patch `arg_class` and build the `ArgumentParser` for it.

        :param type_handlers: the handlers to try, in order, for each field type
        :param arg_class: the `NamedTuple` or `dataclass` argument container class
        :raise: SmartArgError if `arg_class` is not a supported container or one of its fields is not supported"""
        type_proxy = _get_type_proxy(arg_class)
        _raise_if(f"Unsupported argument container class {arg_class}.", not type_proxy)
        self.handlers = type_handlers
        self.handler_actions: Dict[str, Tuple[Union[TypeHandler, Type], Action]] = {}
        type_proxy.patch(arg_class)
        # A big assumption here is that the argument container classes never override __new__
        if not hasattr(arg_class, '__original_new__'):
            arg_class.__original_new__ = arg_class.__new__
        arg_class.__new__ = ArgSuite.new_arg
        arg_class.__to_argv__ = lambda arg_self, separator='': self.to_argv(arg_self, separator)  # arg_class instance level method
        arg_class.__from_argv__ = self.parse_to_arg
        arg_class.__arg_suite__ = self
        self._arg_class = arg_class
        # `argument_default=MISSING` lets `parse_to_arg` tell the arguments that were not on the command line
        self._parser = ArgumentParser(description=self._arg_class.__doc__, argument_default=MISSING,  # type: ignore
                                      fromfile_prefix_chars='@', allow_abbrev=False)
        self._parser.convert_arg_line_to_args = lambda arg_line: arg_line.split()  # type: ignore
        self._gen_arguments_from_class(self._arg_class, '', True, [], type_proxy)

    @staticmethod
    def _validate_fields(arg_class: Type) -> None:
        """Validate fields in `arg_class`.

        :raise: SmartArgError if a typed field ends with "_", if a name-mangled private attribute has no typed
            field to override, or if a non-callable untyped attribute does not start with "_"."""
        arg_fields = _annotations(arg_class).keys()
        invalid_fields = tuple(filter(lambda s: s.endswith('_'), arg_fields))
        _raise_if(f"'{arg_class}': found invalid (ending with '_') fields : {invalid_fields}.", invalid_fields)
        private_prefix = f'_{arg_class.__name__}__'
        l_prefix = len(private_prefix)
        # skip callable methods and typed fields
        for f in filter(lambda f: not callable(getattr(arg_class, f)) and f not in arg_fields, vars(arg_class).keys()):
            is_private = f.startswith(private_prefix)
            _raise_if(f"'{arg_class}': there is no field '{f[l_prefix:]}' for '{f}' to override.", is_private and f[l_prefix:] not in arg_fields)
            _raise_if(f"'{arg_class}': found invalid (untyped) field '{f}'.", not (is_private or f.startswith('_')))

    def _gen_arguments_from_class(self, arg_class, prefix: str, parent_required, arg_classes: List, type_proxy) -> None:
        """Add an argument to `self._parser` for each field of `arg_class`, recursing into nested containers.

        :param arg_class: the (possibly nested) argument container class
        :param prefix: the dotted prefix of the argument names, empty for the root class
        :param parent_required: whether the enclosing container is required
        :param arg_classes: the container classes currently being processed, used to detect recursion
        :param type_proxy: the container proxy of `arg_class`
        :raise: SmartArgError if a corresponding handler for the argument type is not found."""
        _raise_if(f"Recursively nested argument container class '{arg_class}' is not supported.", arg_class in arg_classes)
        suite = getattr(arg_class, '__arg_suite__', None)
        _raise_if(f"Nested argument container class '{arg_class}' with '__post_init__' expected to be decorated.",
                  not (suite and suite._arg_class is arg_class) and hasattr(arg_class, '__post_init__'))
        self._validate_fields(arg_class)
        comments = _mro_all(arg_class, self.get_comments_from_source)
        for raw_arg_name, arg_type in _annotations(arg_class).items():
            arg_name = f'{prefix}{raw_arg_name}'
            try:
                default = type_proxy.field_default(arg_class, raw_arg_name)
                arg_type, optional = _unwrap_optional(arg_type)
                _raise_if(f"Optional field: {arg_name!r}={default!r} must default to `None` or `LateInit`", optional and default not in (None, LateInit))
                sub_type_proxy = _get_type_proxy(arg_type)
                if sub_type_proxy:
                    required = parent_required and default is MISSING
                    arg_classes.append(arg_class)
                    self._gen_arguments_from_class(arg_type, f'{arg_name}.', required, arg_classes, sub_type_proxy)
                    arg_classes.pop()
                    kwargs = KwargsType(nargs='?',
                                        required=required,
                                        metavar='Does NOT accept any value',
                                        # capture the value of arg_name in the lambda default argument since arg_name is not immutable
                                        type=lambda _, arg=arg_name: _raise_if(f"Nested argument container marker {arg!r} does not accept any value."),
                                        help=f"nested argument container marker{'' if default in (LateInit, MISSING) else f'; default: {default}'}")
                else:
                    handler = _first_handles(self.handlers, arg_type)
                    # Work on a copy: popping from the class attribute itself would lose `_comment_to_help`
                    # the next time this class is processed (e.g. when it is nested in another container).
                    user_override: dict = dict(getattr(arg_class, f'_{arg_class.__name__}__{raw_arg_name}', {}))
                    comment = user_override.pop('_comment_to_help', comments.get(raw_arg_name, ''))
                    field_meta = FieldMeta(comment=comment, default=default, type=arg_type, optional=optional)
                    kwargs = handler.gen_kwargs(field_meta, parent_required)
                    if user_override.get('choices', None):
                        logger.info(f"Instead of defining `choices`, please consider using Enum for {arg_name}")
                    kwargs.__dict__.update(user_override)  # apply user override to the keyword argument object
                    kwargs.__dict__.pop('_serialization', None)  # serialization only option, unknown to `add_argument`
                logger.debug(f"Adding kwargs {kwargs} for --{arg_name}")
                self.handler_actions[arg_name] = (sub_type_proxy or handler), self._parser.add_argument(f'--{arg_name}', **vars(kwargs))
            except BaseException as b_e:
                logger.critical(f"Failed creating argument parser for {arg_name!r}:{arg_type!r} with exception {b_e}.")
                raise

    @staticmethod
    def strip_argv(separator: str, argv: Optional[Sequence[str]]) -> Sequence[str]:
        """Strip any elements outside `{separator}+` and `{separator}-` of `argv`.

        :param separator: A string marker prefix to mark the boundaries of the belonging arguments in argv
        :param argv: Input argument list, treated as `sys.argv[1:]` if `None`
        :return: Stripped `argv`; `argv` itself if it contains no marker
        :raise: SmartArgError unless there is either no marker or exactly one opening marker followed by one closing marker"""
        if argv is None:
            argv = sys.argv[1:]
        l_s, r_s = separator + '+', separator + '-'
        lc, rc = argv.count(l_s), argv.count(r_s)
        if lc == rc:
            if lc == 0:
                return argv
            elif lc == 1:
                b, e = argv.index(l_s), argv.index(r_s)
                if e > b:
                    return argv[b + 1: e]
        raise SmartArgError(f"Expecting up to 1 pair of separator markers'{l_s}' and '{r_s}' in {argv}")

    def parse_to_arg(self, argv: Optional[Sequence[str]] = None, separator: Optional[str] = '', *, error_on_unknown: bool = True) -> ArgType:
        """Parse the command line to decorated ArgType

        :param argv: the command line list, `sys.argv[1:]` if `None`
        :param separator: Optional marker to mark the sub-sequence of `argv` ['{separator}+' to '{separator}-'] to parse;
            an empty or `None` separator falls back to the name of the argument container class
        :param error_on_unknown: When `True`, raise if there is any unknown argument in the marked sub-sequence.
        :return: parsed decorated object from command line"""
        def to_arg(arg_class: Type[ArgType], prefix) -> ArgType:
            nest_arg = {}
            for raw_arg_name, arg_type in _annotations(arg_class).items():
                arg_name = f'{prefix}{raw_arg_name}'
                handler_or_proxy, _ = self.handler_actions[arg_name]
                if isinstance(handler_or_proxy, TypeHandler):
                    value = arg_dict.get(arg_name, MISSING)
                    type_to_new = get_origin(_unwrap_optional(arg_type)[0])
                    if type_to_new is abc.Sequence:
                        # The abstract `Sequence` can not be instantiated, use the immutable list replacement.
                        type_to_new = frozenlist
                    # argparse reading variable length arguments are all lists, need to apply the origin type for the conversion to correct type.
                    if isinstance(value, list) and type_to_new is not None:
                        value = type_to_new(value)
                else:  # deal with nested container, consider defined only if there's subfield
                    nested_prefix = f'{arg_name}.'
                    # Match the marker itself or its dotted sub-fields only; a bare `startswith(arg_name)`
                    # would also match an unrelated sibling such as `nested2` for `nested`.
                    is_nested_items_defined = any(name == arg_name or name.startswith(nested_prefix) for name in arg_dict)
                    value = to_arg(_unwrap_optional(arg_type)[0], nested_prefix) if is_nested_items_defined else MISSING
                if value is not MISSING:
                    nest_arg[raw_arg_name] = value
            return arg_class(**nest_arg)
        argv = self.strip_argv(separator or self._arg_class.__name__, argv)
        ns, unknown = self._parser.parse_known_args(argv)
        if error_on_unknown and unknown:
            self._parser.error(f"unrecognized arguments: {' '.join(unknown)}")
        logger.info(f"{argv} is parsed to {ns}")
        arg_dict: Dict[str, Any] = {name: value for name, value in vars(ns).items() if value is not MISSING}  # Filter out defaults added by the parser.
        return to_arg(self._arg_class, '')

    def post_validation(self, arg: ArgType, prefix: str = '') -> None:
        """This is called after __post_init__ to validate the fields.

        :param arg: the argument container instance to validate
        :param prefix: the dotted name prefix of `arg` when it is nested
        :raise: SmartArgError if a field is still `LateInit` or its value does not conform to the annotated type"""
        arg_class = arg.__class__
        for name, t in _annotations(arg_class).items():
            attr = getattr(arg, name)
            _raise_if(f"Field '{name}' is still not initialized after post processing for {arg_class}", attr is LateInit)
            attr_class = attr.__class__
            type_origin, type_args = get_origin(t), get_args(t)
            if _get_type_proxy(t) or type_origin is Union and len(type_args) == 2 and type_args[1] is NoneType and _get_type_proxy(type_args[0]):
                _raise_if(f"Field '{name}' value of {attr}:{attr_class} is not of the expected type '{t}' for {arg_class}", attr_class not in (t, *type_args))
                self.post_validation(attr, f'{prefix}{name}.')
            else:
                arg_type = get_origin(t) or t
                if arg_type is get_origin(Optional[Any]):
                    arg_type = get_args(t)  # check against the members of the union
                if arg_type is list and isinstance(attr, tuple) or arg_type is set and isinstance(attr, frozenset):  # allow for immutable replacements
                    continue
                try:
                    conforming = isinstance(attr, arg_type)  # e.g. list and List
                except TypeError:  # best effort to check the instance type
                    logger.warning(f"Unable to check if {attr!r} is of type {t} for field {name!r} of argument container class {arg_class}")
                    conforming = True
                _raise_if(f"Field {name!r} has value of {attr!r} of type {attr_class} which is not of the expected type '{t}' for {arg_class}", not conforming)

    def _gen_cmd_argv(self, args: ArgType, prefix) -> Iterable[str]:
        """Yield the command line strings of every field of `args` whose value differs from its default.

        :param args: the argument container instance
        :param prefix: the dotted name prefix of `args` when it is nested"""
        arg_class = args.__class__
        proxy = _get_type_proxy(arg_class)
        for name, arg in proxy.asdict(args).items():
            default = proxy.field_default(arg_class, name)
            if arg != default:
                handler_or_proxy, action = self.handler_actions[f'{prefix}{name}']
                serialization_user_override = getattr(arg_class, f'_{arg_class.__name__}__{name}', {}).get('_serialization', None)
                yield action.option_strings[0]
                # Precedence: user provided serialization, then the type handler, then recursion into the nested container
                yield from serialization_user_override(arg) if serialization_user_override else \
                    handler_or_proxy.gen_cli_arg(arg) if isinstance(handler_or_proxy, TypeHandler) else \
                    self._gen_cmd_argv(arg, f'{prefix}{name}.')

    @staticmethod
    def get_comments_from_source(arg_cls: Type[ArgType]) -> Dict[str, str]:
        """Get in-line comments for the input class fields. Only single line of trailing comment is supported.
        Any failure to read or tokenize the source is logged and yields the comments collected so far.

        :param arg_cls: the input class
        :return: a dictionary with key of class field name and value of in-line comment"""
        comments = {}
        indent = 0
        field = None
        import inspect, tokenize  # noqa: E401
        try:
            for token in tokenize.generate_tokens(iter(inspect.getsourcelines(arg_cls)[0]).__next__):
                if token.type == tokenize.NEWLINE:
                    field = None
                elif token.type == tokenize.INDENT:
                    indent += 1
                elif token.type == tokenize.DEDENT:
                    indent -= 1
                elif token.type == tokenize.NAME and indent == 1 and not field:
                    field = token  # first name of a logical line directly in the class body
                elif token.type == tokenize.COMMENT and field:
                    # TODO nicer way to deal with with long comments or support multiple lines
                    # Drop the leading '#' and cut at '# noqa:' if present (`find` returns -1 otherwise, which removes the padding space)
                    comments[field.string] = (token.string + ' ')[1:token.string.lower().find('# noqa:')]  # TODO consider move processing out
        except Exception as e:
            logger.error(f'Failed to parse comments from source of class {arg_cls}, continue without them.', exc_info=e)
        return comments

    def to_argv(self, arg: ArgType, separator: Optional[str] = '') -> Sequence[str]:
        """Generate the command line arguments

        :param separator: separator marker, empty str to disable, None to default to class name
        :param arg: the annotated argument container class object
        :return: command line sequence"""
        argv = self._gen_cmd_argv(arg, '')
        if separator is None:
            separator = self._arg_class.__name__
        return (separator + '+', *argv, separator + '-') if separator else tuple(argv)
617 |
618 |
class ArgSuiteDecorator:
    """Build a decorator that converts back and forth between a command-line and a `NamedTuple` or `dataclass`.

    The decorator monkey patches the constructor, so that the IDE would infer the type of
    the deserialized `arg` for code auto-completion and type check::

        arg: ArgClass = ArgClass.__from_argv__(my_argv)  # Factory method, need the manual type hint `:ArgClass`
        arg = ArgClass(my_argv)  # with monkey-patched constructor, no manual hint needed

    For the ArgClass attributes without types but with values: only private attributes starting with "__" that
    overwrite the parameters of an existing typed field, and other attributes starting with "_", are allowed;
    anything else raises SmartArgError.

    For the handlers/addons, the first one in the sequence that claims to handle a type takes the precedence;
    user provided ones are placed before the built-in ones.

    Usage::

        @arg_suite  # `arg_suite` is a shorthand for `custom_arg_suite()`
        class MyArg(NamedTuple):
            field_one: str
            __field_one = {"choices": ["one", "two"]}  # advanced usage: overwrite the `field_one` parameter
            field_two: List[int]
            ...

    :param primitive_handler_addons: the primitive types handling in addition to the provided primitive type basic operations
    :param type_handlers: the types handling in addition to the provided types handling.
    :return: the argument container class decorator"""
    def __init__(self, type_handlers: Sequence[Type[TypeHandler]] = (), primitive_handler_addons: Sequence[Type[PrimitiveHandlerAddon]] = ()) -> None:
        addons = (*primitive_handler_addons, PrimitiveHandlerAddon)
        builtin_handlers = (PrimitiveHandler, CollectionHandler, DictHandler, TupleHandler)
        # Every handler class is instantiated with the same addon sequence
        self.handlers = tuple(handler(addons) for handler in chain(type_handlers, builtin_handlers))

    def __call__(self, cls):
        # Creating the suite patches `cls` in place; the very same class object is handed back
        ArgSuite(self.handlers, cls)
        return cls
653 |
# Snake case decorator alias for the class name in camel case
custom_arg_suite = ArgSuiteDecorator
# Default argument container class decorator to expose smart arg functionalities.
arg_suite = custom_arg_suite()
656 |
--------------------------------------------------------------------------------
/test/smart_arg_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from enum import Enum, auto
3 | from typing import List, NamedTuple, Tuple, Optional, Dict, Sequence
4 |
5 | from smart_arg import arg_suite, LateInit, frozenlist
6 |
7 |
class Encoder(Enum):
    """Word encoder types the demo configuration can choose from."""
    FASTTEXT = auto()
    WORD2VEC = auto()
11 |
12 |
# Plain (undecorated) NamedTuple used as the type of `MyModelConfig.nested`. It carries one more level
# of nesting, `NNested`, which is pre-populated through the default value of `n_nested`.
class NestedArg(NamedTuple):
    class NNested(NamedTuple):
        n_name: str
    n_nested: NNested = NNested(n_name='nested name')
17 |
18 |
# Demo argument container that the `__main__` section of this script serializes and deserializes.
# NOTE: the class docstring and the trailing field comments are kept verbatim on purpose: by their own
# wording they end up as the argparse description / help, i.e. they are runtime text, not just remarks.
@arg_suite
class MyModelConfig(NamedTuple):
    """MyModelConfig docstring goes to description"""
    adp: bool
    nn: List[int]  # Comments go to argparse help
    a_tuple: Tuple[str, int]
    h_param: Dict[str, int] = {}  # Hyperparameters
    encoder: Encoder = Encoder.FASTTEXT  # Word encoder type
    nested: Optional[NestedArg] = None  # nested args
    n: Optional[int] = LateInit  # An argument that can be auto-set in post processing
    immutable_list: Sequence[int] = frozenlist((1, 2))  # A frozenlist is an alias to tuple
    embedding_dim: int = 100  # Size of embedding vector
    lr: float = 1e-3  # Learning rate

    def __post_init__(self):
        assert self.nn, "Expect nn to be non-empty."  # validation
        if self.n is LateInit:  # post processing
            self.n = self.nn[0]
37 |
38 |
if __name__ == '__main__':
    # Reference object, built directly through the constructor with keyword arguments.
    my_config = MyModelConfig(nn=[200, 300], a_tuple=("s", 5), adp=True, h_param={'n': 0, 'y': 1}, nested=NestedArg())

    my_config_argv = my_config.__to_argv__()
    print(f"Serialized reference to the expected argument container object:\n{my_config_argv!r}")

    print(f"Deserializing from the command line string list: {sys.argv[1:]!r}.")
    # Called without an explicit argv; the message above reports `sys.argv[1:]` as the input.
    deserialized = MyModelConfig.__from_argv__()

    # Round trip: the serialized reference must deserialize back to an equal object.
    re_deserialized = MyModelConfig.__from_argv__(my_config_argv)
    if deserialized == my_config == re_deserialized:
        print(f"All matched, argument container object: '{deserialized!r}'")
    else:
        err_code = 168
        print(f"Error({err_code}):\n"
              f"Expected argument container object:\n {my_config!r}\n"
              f"Re-deserialized:\n {re_deserialized!r}\n"
              f"Deserialized:\n {deserialized!r}")
        # `sys.exit` rather than the builtin `exit`: the latter is injected by the `site` module and is
        # missing when the interpreter runs without it (e.g. `python -S`), while `sys` is always importable.
        sys.exit(err_code)
58 |
--------------------------------------------------------------------------------
/test/test_smart_arg.py:
--------------------------------------------------------------------------------
1 | """Tests for smart-arg-suite."""
2 | import os
3 | import subprocess
4 | import sys
5 | from contextlib import redirect_stderr
6 | from dataclasses import dataclass, replace, field
7 | from enum import Enum
8 | from math import sqrt
9 | from types import SimpleNamespace
10 | from typing import List, NamedTuple, Tuple, Optional, Dict, Set, Type, Any, Sequence, FrozenSet
11 |
12 | import pytest
13 |
14 | from smart_arg import arg_suite, custom_arg_suite, LateInit, SmartArgError, TypeHandler, _first_handles, PrimitiveHandlerAddon, frozenlist
15 |
16 |
# Shared fixture: a frozen dataclass argument container covering primitive, list/set, tuple, dict, optional
# and immutable-default fields plus `dataclasses.field` defaults.
# NOTE: the trailing comments on the fields are asserted as argparse help text by `test_basic_parse_to_arg`
# (e.g. '(int; required) a is int'), so comments and docstring inside the class are kept verbatim.
@arg_suite
@dataclass(frozen=True)
class MyTupBasic:
    """
    MyTup docstring goes to description
    """
    def _serialization(a_str):
        if a_str not in MyTupBasic.__a_str['choices']:
            raise ValueError
        yield a_str

    a_int: int  # a is int
    a_float: float  # a is float
    a_bool: bool
    a_str: str  # comment will NOT go to a_str's help
    _comment_in_help = """
        multiline override
        trailing comment in help"""
    __a_str = {'choices': ['hello', 'bonjour', 'hola'], '_serialization': _serialization, '_comment_to_help': _comment_in_help}  # will overwrite the a_str argument choices, serialization and comment
    b_list_int: List[int]
    b_set_str: Set[str]
    d_tuple_3: Tuple[int, float, bool]
    d_tuple_2: Tuple[str, int]
    e_dict_str_int: Dict[str, int]
    e_dict_int_bool: Dict[int, bool]
    c_optional_float: Optional[float] = None
    immutable_list: List[int] = frozenlist((1, 2))
    immutable_seq: Sequence[int] = frozenlist((1, 2))
    immutable_set: Set[int] = frozenset((1, 2))
    immutable_f_set: FrozenSet[int] = frozenset((1, 2))
    with_default_factory: List[float] = field(default_factory=list)
    with_default: int = field(default=10)
49 |
50 |
# Reference instance: the object the command line in `test_basic_parse_to_arg` is expected to deserialize to.
my_tup_basic = MyTupBasic(
    a_int=32, a_float=0.3, a_bool=True, a_str='hello',
    b_list_int=[1, 2, 3], b_set_str={'set1', 'set2'},
    d_tuple_3=(10, 0.5, False), d_tuple_2=('tuple', 12),
    e_dict_str_int={'size': 32, 'area': 90},
    e_dict_int_bool={10: True, 20: False, 30: True},
    c_optional_float=None,
)
63 |
64 |
def test_basic_parse_to_arg():
    """Round-trip `MyTupBasic` through the command line and inspect the generated argparse actions."""
    with pytest.raises(TypeError, match="missing 10"):
        MyTupBasic()
    arg_cmd = (
        'MyTupBasic+ --a_int 32 --a_float 0.3 --a_bool True --a_str hello --b_list_int 1 2 3 --b_set_str set1 set2 '
        '--d_tuple_3 10 0.5 False --d_tuple_2 tuple 12 --e_dict_str_int size:32 area:90 --e_dict_int_bool 10:True 20:False 30:True MyTupBasic-'
    )
    parsed_arg_from_factory: MyTupBasic = MyTupBasic.__from_argv__(arg_cmd.split())
    assert my_tup_basic == parsed_arg_from_factory
    # Token order is not pinned, hence the set comparison.
    serialized_cmd_line = my_tup_basic.__to_argv__(separator=None)
    assert set(serialized_cmd_line) == set(arg_cmd.split())

    actions = MyTupBasic.__arg_suite__._parser._option_string_actions
    assert actions['--c_optional_float'].help == '(Optional[float]; default: None) '
    assert actions['--a_int'].help == '(int; required) a is int'
    assert actions['--a_float'].help == '(float; required) a is float'
    assert actions['--a_str'].choices == ['hello', 'bonjour', 'hola']
    assert actions['--a_str'].help == f"(str; required) {MyTupBasic._comment_in_help}"

    parsed_arg = MyTupBasic(arg_cmd.split())  # Patched constructor
    assert parsed_arg == my_tup_basic

    not_a_choice = replace(parsed_arg, a_str='')
    with pytest.raises(ValueError):  # serialization is overridden
        not_a_choice.__to_argv__([])
85 |
86 |
# Minimal write-only sink: accepts any arguments and drops them.
_null_stderr = SimpleNamespace(write=lambda *_: None)
# Context manager that redirects `sys.stderr` to the sink above while active.
muted = redirect_stderr(_null_stderr)
88 |
89 |
def test_parse_error():
    """Parsing an empty command line for `MyTupBasic` ends in `SystemExit`."""
    with muted, pytest.raises(SystemExit):
        MyTupBasic.__from_argv__([])
93 |
94 |
def test_optional():
    """Check `Optional` fields: parsing, help text, and rejection of an `Optional` without a default."""
    @arg_suite
    class MyTup(NamedTuple):
        ints: Optional[List[int]] = None

    with muted, pytest.raises(SystemExit):
        MyTup.__from_argv__(['--ints', 'None'])
    ints_action = MyTup.__arg_suite__._parser._option_string_actions['--ints']
    assert ints_action.help == '(Optional[List[int]]; default: None) '
    nothing_given = MyTup.__from_argv__([])
    assert nothing_given.ints is None
    two_values = MyTup.__from_argv__(['--ints', '1', '2'])
    assert two_values.ints == [1, 2]
    flag_only = MyTup.__from_argv__(['--ints'])
    assert flag_only.ints == []

    class InvalidOptional(NamedTuple):
        no: Optional[int]

    with pytest.raises(SmartArgError):
        arg_suite(InvalidOptional)
111 |
112 |
def test_post_process():
    """A `LateInit` field is an error unless `__post_init__` assigns it."""
    @arg_suite
    class MyTup(NamedTuple):
        a_int: Optional[int] = LateInit  # if a_int is not in the argument, post_process will initialize it

    # No `__post_init__` is defined, so each of these construction paths is expected to fail.
    with pytest.raises(SmartArgError):
        MyTup()
    with pytest.raises(SmartArgError):
        MyTup(a_int='not a int')
    with pytest.raises(SmartArgError):
        MyTup.__from_argv__([])

    @arg_suite
    class MyTup(NamedTuple):
        a_int: Optional[int] = LateInit  # if a_int is not in the argument, post_process will initialize it

        def __post_init__(self):
            self.a_int = 10 if self.a_int is LateInit else self.a_int

    from_empty_argv = MyTup.__from_argv__([])
    assert from_empty_argv.a_int == 10
    from_constructor = MyTup()
    assert from_constructor.a_int == 10
    explicit_zero = MyTup(a_int=0)
    assert explicit_zero.a_int == 0
132 |
133 |
def test_validate():
    """`__post_init__` runs during deserialization, and callable class members do not upset the suite."""
    @arg_suite
    class MyTup(NamedTuple):
        a_int: int = LateInit  # if a_int is not in the argument, __post_init__ needs to initialize it

    validated = False

    def validate(s):
        nonlocal validated
        validated = True
        raise AttributeError()

    # Attached after decoration; the error raised by the hook is expected to surface from `__from_argv__`.
    MyTup.__post_init__ = validate
    with pytest.raises(AttributeError):
        MyTup.__from_argv__(['--a_int', '1'])
    assert validated, "`validate` might not be executed."

    @arg_suite
    class MyTuple(NamedTuple):
        abc: str

        def __post_init__(self):
            if self.abc != 'abc':
                raise AttributeError()
            return self

        @staticmethod
        def format():
            return "hello"

        @classmethod
        def format2(cls):
            return "hello2"

        format3 = lambda: "hello3"

    # no exception for callable methods
    tup = MyTuple.__from_argv__(['--abc', 'abc'])
    assert tup.abc == 'abc'
    assert tup.format() == 'hello'
    assert MyTuple.format2() == 'hello2'
    assert MyTuple.format3() == 'hello3'
175 |
176 |
def test_validate_fields():
    """Each malformed argument container class below is rejected with `SmartArgError` when decorated."""
    class DanglingParamOverride(NamedTuple):
        __a_str = "abc"

    class MyNonType(NamedTuple):
        a_str = "abc"

    class MyNonTypeTuple(NamedTuple):
        a_tuple: Tuple

    class TrailingUnderscore(NamedTuple):
        a_: int

    for invalid_cls in (DanglingParamOverride, MyNonType, MyNonTypeTuple, TrailingUnderscore):
        with pytest.raises(SmartArgError):
            arg_suite(invalid_cls)
194 |
195 |
def test_primitive_addon():
    """Plug a custom primitive addon and a custom type handler for `int` into `custom_arg_suite`.

    The addon squares the value when parsing and takes the square root when serializing, so '3' parses
    to 9 and serializes back to '3'; the handler additionally replaces the help text."""
    class IntHandlerAddon(PrimitiveHandlerAddon):
        @staticmethod
        def handles(t: Type) -> bool:
            return t == int

        @staticmethod
        def build_type(arg_type) -> Any:
            return lambda s: int(s) ** 2

        @staticmethod
        def build_str(arg) -> str:
            return str(int(sqrt(arg)))

    class IntTypeHandler(TypeHandler):
        def handles(self, t: Type) -> bool:
            return t == int

        def _build_common(self, kwargs, field_meta, parent_required) -> None:
            super()._build_common(kwargs, field_meta, parent_required)
            kwargs.help = '(int, squared)'

        def _build_other(self, kwargs, arg_type) -> None:
            addon = _first_handles(self.primitive_addons, arg_type)
            kwargs.type = addon.build_type(arg_type)

    @custom_arg_suite(primitive_handler_addons=[IntHandlerAddon], type_handlers=[IntTypeHandler])
    class MyTuple(NamedTuple):
        a_int: int

    argv = ('--a_int', '3')
    tup = MyTuple.__from_argv__(argv)
    assert tup.a_int == 9
    assert tup.__to_argv__() == argv
    a_int_action = MyTuple.__arg_suite__._parser._option_string_actions['--a_int']
    assert a_int_action.help == '(int, squared)'
231 |
232 |
def test_unsupported_types():
    """Decorating a class with a list-of-lists field raises `SmartArgError`."""
    class MyTuple(NamedTuple):
        a: List[List[int]]

    with pytest.raises(SmartArgError):
        arg_suite(MyTuple)
238 |
239 |
def test_nested():
    """Argument containers can be nested; an undecorated nested class must not define `__post_init__`."""
    @arg_suite
    class MyTup(NamedTuple):
        def __post_init__(self):
            self.b_int = 10 if self.b_int is LateInit else self.b_int

        b_int: int = LateInit  # if a_int is not in the argument, __post_init__ will initialize it

    @arg_suite
    class Nested(NamedTuple):
        a_int: int
        nested: Optional[MyTupBasic] = None  # Optional nested
        another_int: int = 0
        another_nested: MyTup = MyTup()

    nested = Nested(a_int=0)
    assert nested.another_nested.b_int == 10
    argv = nested.__to_argv__()
    assert Nested(argv) == nested
    # The first two tokens alone deserialize to an equal object.
    assert Nested(argv[:2]) == nested
    with pytest.raises(SmartArgError):
        Nested(['--a_int', '0', '--nested', 'Not Allowed'])
    with pytest.raises(SmartArgError):
        Nested(a_int=0, nested='Not nested MyTupBasic')

    class NotDecoratedWithPost(NamedTuple):
        def __post_init__(self):
            pass

    class Nested(NamedTuple):
        nested: NotDecoratedWithPost

    # Nested class is not allowed to have __post_init__ if not decorated
    with pytest.raises(SmartArgError):
        arg_suite(Nested)
    del NotDecoratedWithPost.__post_init__
    arg_suite(Nested)  # should not raise
273 |
274 |
def test_cli_execution():
    """Run the demo script in a subprocess and check its exit codes for several command lines.

    The command is handed to `subprocess.run` as an argument list without a shell, so an interpreter or
    checkout path containing spaces or shell metacharacters can neither break the command nor be
    interpreted by a shell."""
    demo_script = os.path.join(os.path.dirname(__file__), "smart_arg_demo.py")
    args = ['--nn', '200', '300', '--a_tuple', 's', '5', '--encoder', 'FASTTEXT', '--h_param', 'y:1', 'n:0',
            '--nested', '--embedding_dim', '100', '--lr', '0.001', '--adp', 'True']
    # Fall back to whatever `python` is on the PATH when the interpreter path is unknown.
    cmd_line = [sys.executable or "python", demo_script, *args]
    kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE}
    completed_process = subprocess.run(cmd_line, **kwargs)
    assert completed_process.stderr.decode('utf-8') == ''
    assert completed_process.returncode == 0

    # 168 is the exit code the demo script uses when the deserialized object differs from its reference.
    completed_process = subprocess.run([*cmd_line, '--adp', 'False'], **kwargs)
    assert completed_process.returncode == 168, "deserialization succeeded, but not as expected"

    completed_process = subprocess.run([*cmd_line, '--nested', 'OH NO!'], **kwargs)
    assert completed_process.returncode == 1, "nested field should not have any value defined directly"

    separator = 'MyModelConfig'
    completed_process = subprocess.run([*cmd_line, f'{separator}+', *args, f'{separator}-', '--nested', 'OH NO!'], **kwargs)
    assert completed_process.stderr.decode('utf-8') == ''
    assert completed_process.returncode == 0, "Argument outside of separators should be ignored."
294 |
295 |
def test_dataclass():
    """Exercise a decorated dataclass built from two dataclass bases, including immutability after init."""
    @dataclass
    class GdmixParams:
        ACTIONS = ("action_inference", "action_train")
        action: str = ACTIONS[1]  # Train or inference.
        __action = {"choices": ACTIONS}
        STAGES = ("fixed_effect", "random_effect")
        stage: str = STAGES[0]  # Fixed or random effect.
        __stage = {"choices": STAGES}
        MODEL_TYPES = ("logistic_regression", "detext")
        model_type: str = MODEL_TYPES[0]  # The model type to train, e.g, logistic regression, detext, etc.
        __model_type = {"choices": MODEL_TYPES}

        # Input / output files or directories
        training_output_dir: Optional[str] = None  # Training output directory.
        validation_output_dir: Optional[str] = None  # Validation output directory.

        # Driver arguments for random effect training
        partition_list_file: Optional[str] = None  # File containing a list of all the partition ids, for random effect only

        def __post_init__(self):
            assert self.action in self.ACTIONS, "Action must be either train or inference"
            assert self.stage in self.STAGES, "Stage must be either fixed_effect or random_effect"
            assert self.model_type in self.MODEL_TYPES, "Model type must be either logistic_regression or detext"

    @dataclass
    class SchemaParams:
        # Schema names
        sample_id: str  # Sample id column name in the input file.
        sample_weight: Optional[str] = None  # Sample weight column name in the input file.
        label: Optional[str] = None  # Label column name in the train/validation file.
        prediction_score: Optional[str] = None  # Prediction score column name in the generated result file.
        prediction_score_per_coordinate: str = "predictionScorePerCoordinate"  # ColumnName of the prediction score without the offset.

    @arg_suite
    @dataclass
    class Params(GdmixParams, SchemaParams):
        """GDMix Driver"""

        def __post_init__(self):
            super().__post_init__()
            assert (self.action == self.ACTIONS[1] and self.label) or (self.action == self.ACTIONS[0] and self.prediction_score)
            self.prediction_score = self.prediction_score

    argv = [
        '--sample_id', 'uid',
        '--sample_weight', 'weight',
        '--feature_bags', 'global',
        '--train_data_path', 'resources/train',
        '--validation_data_path', 'resources/validate',
        '--model_output_dir', 'dummy_model_output_dir',
        '--metadata_file', 'resources/fe_lbfgs/metadata/tensor_metadata.json',
        '--feature_file', 'test/resources/fe_lbfgs/featureList/global',
    ]
    with muted, pytest.raises(SystemExit):
        Params.__from_argv__(argv)
    with pytest.raises(AssertionError):
        Params.__from_argv__(argv, error_on_unknown=False)
    args: Params = Params.__from_argv__(argv + ['--label', 'bluh'], error_on_unknown=False)
    with pytest.raises(AssertionError):
        Params(sample_id='uid', action='no_such_action')
    with pytest.raises(SmartArgError):  # mutation not allowed after init
        args.__post_init__()
    object.__delattr__(args, '__frozen__')
    args.__post_init__()  # mutation allowed after '__frozen__` mark removed
    assert args == Params(sample_id='uid', sample_weight='weight', label='bluh')
    assert args.__to_argv__() == ('--sample_id', 'uid', '--sample_weight', 'weight', '--label', 'bluh')

    @arg_suite
    @dataclass
    class NoPostInit:
        def mutate(self):
            self.frozen = False

        frozen: bool = True

    # Built outside the `raises` block so that only the mutation itself is expected to fail.
    already_built = NoPostInit()
    with pytest.raises(SmartArgError):  # mutation not allowed after init
        already_built.mutate()
364 |
365 |
def test_basic_enum():
    """Enum members are parsed from and serialized to their names, standalone and inside containers."""
    class Color(Enum):
        RED = 1
        BLUE = 2
        GREEN = 3

    @arg_suite
    class MyEnumBasic(NamedTuple):
        a_int: int
        my_color_dict: Dict[int, Color]
        my_color_list: List[Color]
        my_color_tuple: Tuple[Color, int]
        default_color: Color = Color.RED

    arg_cmd = [
        '--a_int', '1',
        '--my_color_dict', '10:RED', '20:BLUE',
        '--my_color_list', 'GREEN',
        '--my_color_tuple', 'BLUE', '100',
        '--default_color', 'GREEN',
    ]
    basic_tup = MyEnumBasic(
        a_int=1,
        my_color_dict={10: Color.RED, 20: Color.BLUE},
        my_color_list=[Color.GREEN],
        my_color_tuple=(Color.BLUE, 100),
        default_color=Color.GREEN,
    )

    parsed_tup: MyEnumBasic = MyEnumBasic.__from_argv__(arg_cmd)
    assert basic_tup == parsed_tup
    serialized_cmd_line = basic_tup.__to_argv__()
    assert set(serialized_cmd_line) == set(arg_cmd)
    actions = MyEnumBasic.__arg_suite__._parser._option_string_actions
    assert actions['--my_color_dict'].metavar == "int:{RED, BLUE, GREEN}"
    assert actions['--default_color'].choices == Color

    arg_cmd2 = ['--a_int', '1', '--my_color_dict', '10:red', '--my_color_list', 'GREEN', '--my_color_tuple', 'BLUE', '100']
    with pytest.raises(SmartArgError):  # capital case needed for enum `RED`
        MyEnumBasic.__from_argv__(arg_cmd2)
395 |
--------------------------------------------------------------------------------