├── .github └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── demo.py ├── demo_data_model.py ├── images ├── logo.png ├── tap.png └── tap_logo.png ├── pyproject.toml ├── src └── tap │ ├── __init__.py │ ├── py.typed │ ├── tap.py │ ├── tapify.py │ └── utils.py └── tests ├── test_actions.py ├── test_integration.py ├── test_load_config_files.py ├── test_subparser.py ├── test_tapify.py ├── test_to_tap_class.py └── test_utils.py /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest, macos-latest, windows-latest] 19 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 20 | 21 | steps: 22 | - uses: actions/checkout@main 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@main 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Set temp directories on Windows 28 | if: matrix.os == 'windows-latest' 29 | run: | 30 | echo "TMPDIR=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV 31 | echo "TEMP=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV 32 | echo "TMP=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV 33 | - name: Install dependencies 34 | run: | 35 | git config --global user.email "you@example.com" 36 | git config --global user.name "Your Name" 37 | python -m pip install --upgrade pip 38 | python -m pip install -e ".[dev-no-pydantic]" 39 | - name: Lint with flake8 40 | run: | 41 | # stop the build if there are Python syntax errors or undefined names 42 | 
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 43 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 44 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 45 | - name: Test without pydantic 46 | run: | 47 | pytest --cov=tap 48 | - name: Test with pydantic v1 49 | run: | 50 | python -m pip install "pydantic < 2" 51 | pytest --cov=tap --cov-append 52 | - name: Test with pydantic v2 53 | run: | 54 | python -m pip install "pydantic >= 2" 55 | pytest --cov=tap --cov-append 56 | 57 | - name: Upload coverage reports to Codecov 58 | uses: codecov/codecov-action@v4 59 | with: 60 | token: ${{ secrets.CODECOV_TOKEN }} 61 | verbose: true 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | __pycache__ 4 | .DS_Store 5 | *.json 6 | *.egg-info 7 | build 8 | .eggs 9 | .coverage 10 | dist 11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Jesse Michel and Kyle Swanson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # Typed Argument Parser (Tap) 6 | 7 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/typed-argument-parser)](https://badge.fury.io/py/typed-argument-parser) 8 | [![PyPI version](https://badge.fury.io/py/typed-argument-parser.svg)](https://badge.fury.io/py/typed-argument-parser) 9 | [![Downloads](https://pepy.tech/badge/typed-argument-parser)](https://pepy.tech/project/typed-argument-parser) 10 | [![Build Status](https://github.com/swansonk14/typed-argument-parser/workflows/tests/badge.svg)](https://github.com/swansonk14/typed-argument-parser) 11 | [![codecov](https://codecov.io/gh/swansonk14/typed-argument-parser/branch/main/graph/badge.svg)](https://codecov.io/gh/swansonk14/typed-argument-parser) 12 | [![license](https://img.shields.io/github/license/swansonk14/typed-argument-parser.svg)](https://github.com/swansonk14/typed-argument-parser/blob/main/LICENSE.txt) 13 | 14 | Tap is a typed modernization of Python's [argparse](https://docs.python.org/3/library/argparse.html) library. 15 | 16 | Tap provides the following benefits: 17 | - Static type checking 18 | - Code completion 19 | - Source code navigation (e.g. go to definition and go to implementation) 20 | 21 | ![Tap](https://github.com/swansonk14/typed-argument-parser/raw/main/images/tap.png) 22 | 23 | See [this poster](https://docs.google.com/presentation/d/1AirN6gpiq4P1L8K003EsXmobVxP3A4AVEIR2KOEQN7Y/edit?usp=sharing), which we presented at [PyCon 2020](https://us.pycon.org/2020/), for a presentation of some of the relevant concepts we used to guide the development of Tap. 24 | 25 | As of version 1.8.0, Tap includes `tapify`, which runs functions or initializes classes with arguments parsed from the command line. We show an example below. 
26 | 27 | ```python 28 | # square.py 29 | from tap import tapify 30 | 31 | def square(num: float) -> float: 32 | return num ** 2 33 | 34 | if __name__ == '__main__': 35 | print(f'The square of your number is {tapify(square)}.') 36 | ``` 37 | 38 | Running `python square.py --num 2` will print `The square of your number is 4.0.`. Please see [tapify](#tapify) for more details. 39 | 40 | ## Installation 41 | 42 | Tap requires Python 3.9+ 43 | 44 | To install Tap from PyPI run: 45 | 46 | ``` 47 | pip install typed-argument-parser 48 | ``` 49 | 50 |
51 | To install Tap from source, run the following commands: 52 | 53 | ``` 54 | git clone https://github.com/swansonk14/typed-argument-parser.git 55 | cd typed-argument-parser 56 | pip install -e . 57 | ``` 58 | 59 |
60 | 61 |
62 | To develop this package, install development requirements (in a virtual environment): 63 | 64 | ``` 65 | python -m pip install -e ".[dev]" 66 | ``` 67 | 68 | Style: 69 | - Please use [`black`](https://github.com/psf/black) formatting 70 | - Set your vertical line ruler to 121 71 | - Use [`flake8`](https://github.com/PyCQA/flake8) linting. 72 | 73 | To run tests, run: 74 | 75 | ``` 76 | pytest 77 | ``` 78 | 79 |
80 | 81 | ## Table of Contents 82 | 83 | * [Installation](#installation) 84 | * [Table of Contents](#table-of-contents) 85 | * [Tap is Python-native](#tap-is-python-native) 86 | * [Tap features](#tap-features) 87 | + [Arguments](#arguments) 88 | + [Tap help](#tap-help) 89 | + [Configuring arguments](#configuring-arguments) 90 | - [Adding special argument behavior](#adding-special-argument-behavior) 91 | - [Adding subparsers](#adding-subparsers) 92 | + [Types](#types) 93 | + [Argument processing](#argument-processing) 94 | + [Processing known args](#processing-known-args) 95 | + [Subclassing](#subclassing) 96 | + [Printing](#printing) 97 | + [Reproducibility](#reproducibility) 98 | + [Saving and loading arguments](#saving-and-loading-arguments) 99 | + [Loading from configuration files](#loading-from-configuration-files) 100 | * [tapify](#tapify) 101 | + [Examples](#examples) 102 | - [Function](#function) 103 | - [Class](#class) 104 | - [Dataclass](#dataclass) 105 | + [tapify help](#tapify-help) 106 | + [Command line vs explicit arguments](#command-line-vs-explicit-arguments) 107 | + [Known args](#known-args) 108 | * [Convert to a `Tap` class](#convert-to-a-tap-class) 109 | + [`to_tap_class` examples](#to_tap_class-examples) 110 | - [Simple](#simple) 111 | - [Complex](#complex) 112 | 113 | ## Tap is Python-native 114 | 115 | To see this, let's look at an example: 116 | 117 | ```python 118 | """main.py""" 119 | 120 | from tap import Tap 121 | 122 | class SimpleArgumentParser(Tap): 123 | name: str # Your name 124 | language: str = 'Python' # Programming language 125 | package: str = 'Tap' # Package name 126 | stars: int # Number of stars 127 | max_stars: int = 5 # Maximum stars 128 | 129 | args = SimpleArgumentParser().parse_args() 130 | 131 | print(f'My name is {args.name} and I give the {args.language} package ' 132 | f'{args.package} {args.stars}/{args.max_stars} stars!') 133 | ``` 134 | 135 | You use Tap the same way you use standard argparse. 
136 | 137 | ``` 138 | >>> python main.py --name Jesse --stars 5 139 | My name is Jesse and I give the Python package Tap 5/5 stars! 140 | ``` 141 | 142 | The equivalent argparse code is: 143 | ```python 144 | """main.py""" 145 | 146 | from argparse import ArgumentParser 147 | 148 | parser = ArgumentParser() 149 | parser.add_argument('--name', type=str, required=True, 150 | help='Your name') 151 | parser.add_argument('--language', type=str, default='Python', 152 | help='Programming language') 153 | parser.add_argument('--package', type=str, default='Tap', 154 | help='Package name') 155 | parser.add_argument('--stars', type=int, required=True, 156 | help='Number of stars') 157 | parser.add_argument('--max_stars', type=int, default=5, 158 | help='Maximum stars') 159 | args = parser.parse_args() 160 | 161 | print(f'My name is {args.name} and I give the {args.language} package ' 162 | f'{args.package} {args.stars}/{args.max_stars} stars!') 163 | ``` 164 | 165 | The advantages of being Python-native include being able to: 166 | - Overwrite convenient built-in methods (e.g. `process_args` ensures consistency among arguments) 167 | - Add custom methods 168 | - Inherit from your own template classes 169 | 170 | ## Tap features 171 | 172 | Now we are going to highlight some of our favorite features and give examples of how they work in practice. 173 | 174 | ### Arguments 175 | 176 | Arguments are specified as class variables defined in a subclass of `Tap`. Variables defined as `name: type` are required arguments while variables defined as `name: type = value` are not required and default to the provided value. 177 | 178 | ```python 179 | class MyTap(Tap): 180 | required_arg: str 181 | default_arg: str = 'default value' 182 | ``` 183 | 184 | ### Tap help 185 | 186 | Single line and/or multiline comments which appear after the argument are automatically parsed into the help string provided when running `python main.py -h`. 
The type and default values of arguments are also provided in the help string. 187 | 188 | ```python 189 | """main.py""" 190 | 191 | from tap import Tap 192 | 193 | class MyTap(Tap): 194 | x: float # What am I? 195 | pi: float = 3.14 # I'm pi! 196 | """Pi is my favorite number!""" 197 | 198 | args = MyTap().parse_args() 199 | ``` 200 | 201 | Running `python main.py -h` results in the following: 202 | 203 | ``` 204 | >>> python main.py -h 205 | usage: main.py --x X [--pi PI] [-h] 206 | 207 | optional arguments: 208 | --x X (float, required) What am I? 209 | --pi PI (float, default=3.14) I'm pi! Pi is my favorite number! 210 | -h, --help show this help message and exit 211 | ``` 212 | 213 | ### Configuring arguments 214 | To specify behavior beyond what can be specified using arguments as class variables, override the `configure` method. 215 | `configure` provides access to advanced argument parsing features such as `add_argument` and `add_subparser`. 216 | Since Tap is a wrapper around argparse, Tap provides all of the same functionality. 217 | We detail these two functions below. 218 | 219 | #### Adding special argument behavior 220 | In the `configure` method, call `self.add_argument` just as you would use argparse's `add_argument`. For example, 221 | 222 | ```python 223 | from tap import Tap 224 | 225 | class MyTap(Tap): 226 | positional_argument: str 227 | list_of_three_things: List[str] 228 | argument_with_really_long_name: int 229 | 230 | def configure(self): 231 | self.add_argument('positional_argument') 232 | self.add_argument('--list_of_three_things', nargs=3) 233 | self.add_argument('-arg', '--argument_with_really_long_name') 234 | ``` 235 | 236 | #### Adding subparsers 237 | To add a subparser, override the `configure` method and call `self.add_subparser`. Optionally, to specify keyword arguments (e.g., `help`) to the subparser collection, call `self.add_subparsers`. 
For example, 238 | 239 | ```python 240 | class SubparserA(Tap): 241 | bar: int # bar help 242 | 243 | class SubparserB(Tap): 244 | baz: Literal['X', 'Y', 'Z'] # baz help 245 | 246 | class Args(Tap): 247 | foo: bool = False # foo help 248 | 249 | def configure(self): 250 | self.add_subparsers(help='sub-command help') 251 | self.add_subparser('a', SubparserA, help='a help') 252 | self.add_subparser('b', SubparserB, help='b help') 253 | ``` 254 | 255 | ### Types 256 | 257 | Tap automatically handles all the following types: 258 | 259 | ```python 260 | str, int, float, bool 261 | Optional, Optional[str], Optional[int], Optional[float], Optional[bool] 262 | List, List[str], List[int], List[float], List[bool] 263 | Set, Set[str], Set[int], Set[float], Set[bool] 264 | Tuple, Tuple[Type1, Type2, etc.], Tuple[Type, ...] 265 | Literal 266 | ``` 267 | 268 | If you're using Python 3.9+, then you can replace `List` with `list`, `Set` with `set`, and `Tuple` with `tuple`. 269 | 270 | Tap also supports `Union`, but this requires additional specification (see [Union](#-union-) section below). 271 | 272 | Additionally, any type that can be instantiated with a string argument can be used. For example, in 273 | ```python 274 | from pathlib import Path 275 | from tap import Tap 276 | 277 | class Args(Tap): 278 | path: Path 279 | 280 | args = Args().parse_args() 281 | ``` 282 | `args.path` is a `Path` instance containing the string passed in through the command line. 283 | 284 | #### `str`, `int`, and `float` 285 | 286 | Each is automatically parsed to their respective types, just like argparse. 287 | 288 | #### `bool` 289 | 290 | If an argument `arg` is specified as `arg: bool` or `arg: bool = False`, then adding the `--arg` flag to the command line will set `arg` to `True`. If `arg` is specified as `arg: bool = True`, then adding `--arg` sets `arg` to `False`. 
291 | 292 | Note that if the `Tap` instance is created with `explicit_bool=True`, then booleans can be specified on the command line as `--arg True` or `--arg False` rather than `--arg`. Additionally, booleans can be specified by prefixes of `True` and `False` with any capitalization as well as `1` or `0` (e.g. for True, `--arg tRu`, `--arg T`, `--arg 1` all suffice). 293 | 294 | #### `Optional` 295 | 296 | These arguments are parsed in exactly the same way as `str`, `int`, `float`, and `bool`. Note bools can be specified using the same rules as above and that `Optional` is equivalent to `Optional[str]`. 297 | 298 | #### `List` 299 | 300 | If an argument `arg` is a `List`, simply specify the values separated by spaces just as you would with regular argparse. For example, `--arg 1 2 3` parses to `arg = [1, 2, 3]`. 301 | 302 | #### `Set` 303 | 304 | Identical to `List` but parsed into a set rather than a list. 305 | 306 | #### `Tuple` 307 | 308 | Tuples can be used to specify a fixed number of arguments with specified types using the syntax `Tuple[Type1, Type2, etc.]` (e.g. `Tuple[str, int, bool, str]`). Tuples with a variable number of arguments are specified by `Tuple[Type, ...]` (e.g. `Tuple[int, ...]`). Note `Tuple` defaults to `Tuple[str, ...]`. 309 | 310 | #### `Literal` 311 | 312 | Literal is analogous to argparse's [choices](https://docs.python.org/3/library/argparse.html#choices), which specifies the values that an argument can take. For example, if arg can only be one of 'H', 1, False, or 1.0078 then you would specify that `arg: Literal['H', 1, False, 1.0078]`. For instance, `--arg False` assigns arg to False and `--arg True` throws an error. 313 | 314 | #### `Union` 315 | 316 | Union types must include the `type` keyword argument in `add_argument` in order to specify which type to use, as in the example below. 317 | 318 | ```python 319 | def to_number(string: str) -> Union[float, int]: 320 | return float(string) if '.' 
in string else int(string) 321 | 322 | class MyTap(Tap): 323 | number: Union[float, int] 324 | 325 | def configure(self): 326 | self.add_argument('--number', type=to_number) 327 | ``` 328 | 329 | In Python 3.10+, `Union[Type1, Type2, etc.]` can be replaced with `Type1 | Type2 | etc.`, but the `type` keyword argument must still be provided in `add_argument`. 330 | 331 | #### Complex Types 332 | 333 | Tap can also support more complex types than the ones specified above. If the desired type is constructed with a single string as input, then the type can be specified directly without additional modifications. For example, 334 | 335 | ```python 336 | class Person: 337 | def __init__(self, name: str) -> None: 338 | self.name = name 339 | 340 | class Args(Tap): 341 | person: Person 342 | 343 | args = Args().parse_args('--person Tapper'.split()) 344 | print(args.person.name) # Tapper 345 | ``` 346 | 347 | If the desired type has a more complex constructor, then the `type` keyword argument must be provided in `add_argument`. For example, 348 | 349 | ```python 350 | class AgedPerson: 351 | def __init__(self, name: str, age: int) -> None: 352 | self.name = name 353 | self.age = age 354 | 355 | def to_aged_person(string: str) -> AgedPerson: 356 | name, age = string.split(',') 357 | return AgedPerson(name=name, age=int(age)) 358 | 359 | class Args(Tap): 360 | aged_person: AgedPerson 361 | 362 | def configure(self) -> None: 363 | self.add_argument('--aged_person', type=to_aged_person) 364 | 365 | args = Args().parse_args('--aged_person Tapper,27'.split()) 366 | print(f'{args.aged_person.name} is {args.aged_person.age}') # Tapper is 27 367 | ``` 368 | 369 | 370 | ### Argument processing 371 | 372 | With complex argument parsing, arguments often end up having interdependencies. This means that it may be necessary to disallow certain combinations of arguments or to modify some arguments based on other arguments. 
373 | 374 | To handle such cases, simply override `process_args` and add the required logic. `process_args` is automatically called when `parse_args` is called. 375 | 376 | ```python 377 | class MyTap(Tap): 378 | package: str 379 | is_cool: bool 380 | stars: int 381 | 382 | def process_args(self): 383 | # Validate arguments 384 | if self.is_cool and self.stars < 4: 385 | raise ValueError('Cool packages cannot have fewer than 4 stars') 386 | 387 | # Modify arguments 388 | if self.package == 'Tap': 389 | self.is_cool = True 390 | self.stars = 5 391 | ``` 392 | 393 | ### Processing known args 394 | 395 | Similar to argparse's `parse_known_args`, Tap is capable of parsing only arguments that it is aware of without raising an error due to additional arguments. This can be done by calling `parse_args` with `known_only=True`. The remaining un-parsed arguments are then available by accessing the `extra_args` field of the Tap object. 396 | 397 | ```python 398 | class MyTap(Tap): 399 | package: str 400 | 401 | args = MyTap().parse_args(['--package', 'Tap', '--other_arg', 'value'], known_only=True) 402 | print(args.extra_args) # ['--other_arg', 'value'] 403 | ``` 404 | 405 | ### Subclassing 406 | 407 | It is sometimes useful to define a template Tap and then subclass it for different use cases. Since Tap is a native Python class, inheritance is built-in, making it easy to customize from a template Tap. 408 | 409 | In the example below, `StarsTap` and `AwardsTap` inherit the arguments (`package` and `is_cool`) and the methods (`process_args`) from `BaseTap`. 
410 | 411 | ```python 412 | class BaseTap(Tap): 413 | package: str 414 | is_cool: bool 415 | 416 | def process_args(self): 417 | if self.package == 'Tap': 418 | self.is_cool = True 419 | 420 | 421 | class StarsTap(BaseTap): 422 | stars: int 423 | 424 | 425 | class AwardsTap(BaseTap): 426 | awards: List[str] 427 | ``` 428 | 429 | ### Printing 430 | 431 | Tap uses Python's [pretty printer](https://docs.python.org/3/library/pprint.html) to print out arguments in an easy-to-read format. 432 | 433 | ```python 434 | """main.py""" 435 | 436 | from tap import Tap 437 | from typing import List 438 | 439 | class MyTap(Tap): 440 | package: str 441 | is_cool: bool = True 442 | awards: List[str] = ['amazing', 'wow', 'incredible', 'awesome'] 443 | 444 | args = MyTap().parse_args() 445 | print(args) 446 | ``` 447 | 448 | Running `python main.py --package Tap` results in: 449 | 450 | ``` 451 | >>> python main.py --package Tap 452 | {'awards': ['amazing', 'wow', 'incredible', 'awesome'], 453 | 'is_cool': True, 454 | 'package': 'Tap'} 455 | ``` 456 | 457 | ### Reproducibility 458 | 459 | Tap makes reproducibility easy, especially when running code in a git repo. 460 | 461 | #### Reproducibility info 462 | 463 | Specifically, Tap has a method called `get_reproducibility_info` that returns a dictionary containing all the information necessary to replicate the settings under which the code was run. This dictionary includes: 464 | - Python command 465 | - The Python command that was used to run the program 466 | - Ex. `python main.py --package Tap` 467 | - Time 468 | - The time when the command was run 469 | - Ex. `Thu Aug 15 00:09:13 2019` 470 | - Git root 471 | - The root of the git repo containing the code that was run 472 | - Ex. `/Users/swansonk14/typed-argument-parser` 473 | - Git url 474 | - The url to the git repo, specifically pointing to the current git hash (i.e. the hash of HEAD in the local repo) 475 | - Ex. 
[https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998](https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998) 476 | - Uncommitted changes 477 | - Whether there are any uncommitted changes in the git repo (i.e. whether the code is different from the code at the above git hash) 478 | - Ex. `True` or `False` 479 | 480 | ### Converting Tap to and from dictionaries 481 | 482 | Tap has methods `as_dict` and `from_dict` that convert Tap objects to and from dictionaries. 483 | For example, 484 | 485 | ```python 486 | """main.py""" 487 | from tap import Tap 488 | 489 | class Args(Tap): 490 | package: str 491 | is_cool: bool = True 492 | stars: int = 5 493 | 494 | args = Args().parse_args(["--package", "Tap"]) 495 | 496 | args_data = args.as_dict() 497 | print(args_data) # {'package': 'Tap', 'is_cool': True, 'stars': 5} 498 | 499 | args_data['stars'] = 2000 500 | args = args.from_dict(args_data) 501 | print(args.stars) # 2000 502 | ``` 503 | 504 | Note that `as_dict` does not include attributes set directly on an instance (e.g., `arg` is not included even after setting `args.arg = "hi"` in the code above because `arg` is not an attribute of the `Args` class). 505 | Also note that `from_dict` ensures that all required arguments are set. 506 | 507 | ### Saving and loading arguments 508 | 509 | #### Save 510 | 511 | Tap has a method called `save` which saves all arguments, along with the reproducibility info, to a JSON file. 
512 | 513 | ```python 514 | """main.py""" 515 | 516 | from tap import Tap 517 | 518 | class MyTap(Tap): 519 | package: str 520 | is_cool: bool = True 521 | stars: int = 5 522 | 523 | args = MyTap().parse_args() 524 | args.save('args.json') 525 | ``` 526 | 527 | After running `python main.py --package Tap`, the file `args.json` will contain: 528 | 529 | ``` 530 | { 531 | "is_cool": true, 532 | "package": "Tap", 533 | "reproducibility": { 534 | "command_line": "python main.py --package Tap", 535 | "git_has_uncommitted_changes": false, 536 | "git_root": "/Users/swansonk14/typed-argument-parser", 537 | "git_url": "https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998", 538 | "time": "Thu Aug 15 00:18:31 2019" 539 | }, 540 | "stars": 5 541 | } 542 | ``` 543 | 544 | Note: More complex types will be encoded in JSON as a pickle string. 545 | 546 | #### Load 547 | > :exclamation: :warning:
548 | > Never call `args.load('args.json')` on untrusted files. Argument loading uses the `pickle` module to decode complex types automatically. Unpickling of untrusted data is a security risk and can lead to arbitrary code execution. See [the warning in the pickle docs](https://docs.python.org/3/library/pickle.html).
549 | > :exclamation: :warning: 550 | 551 | Arguments can be loaded from a JSON file rather than parsed from the command line. 552 | 553 | ```python 554 | """main.py""" 555 | 556 | from tap import Tap 557 | 558 | class MyTap(Tap): 559 | package: str 560 | is_cool: bool = True 561 | stars: int = 5 562 | 563 | args = MyTap() 564 | args.load('args.json') 565 | ``` 566 | 567 | Note: All required arguments (in this case `package`) must be present in the JSON file if not already set in the Tap object. 568 | 569 | #### Load from dict 570 | 571 | Arguments can be loaded from a Python dictionary rather than parsed from the command line. 572 | 573 | ```python 574 | """main.py""" 575 | 576 | from tap import Tap 577 | 578 | class MyTap(Tap): 579 | package: str 580 | is_cool: bool = True 581 | stars: int = 5 582 | 583 | args = MyTap() 584 | args.from_dict({ 585 | 'package': 'Tap', 586 | 'stars': 20 587 | }) 588 | ``` 589 | 590 | Note: As with `load`, all required arguments must be present in the dictionary if not already set in the Tap object. All values in the provided dictionary will overwrite values currently in the Tap object. 591 | 592 | ### Loading from configuration files 593 | Configuration files can be loaded along with arguments with the optional flag `config_files: List[str]`. Arguments passed in from the command line overwrite arguments from the configuration files. Arguments in configuration files that appear later in the list overwrite the arguments in previous configuration files. 
594 | 595 | For example, if you have the config file `my_config.txt` 596 | ``` 597 | --arg1 1 598 | --arg2 two 599 | ``` 600 | then you can write 601 | ```python 602 | from tap import Tap 603 | 604 | class Args(Tap): 605 | arg1: int 606 | arg2: str 607 | 608 | args = Args(config_files=['my_config.txt']).parse_args() 609 | ``` 610 | 611 | Config files are parsed using `shlex.split` from the python standard library, which supports shell-style string quoting, as well as line-end comments starting with `#`. 612 | 613 | For example, if you have the config file `my_config_shlex.txt` 614 | ``` 615 | --arg1 21 # Important arg value 616 | 617 | # Multi-word quoted string 618 | --arg2 "two three four" 619 | ``` 620 | then you can write 621 | ```python 622 | from tap import Tap 623 | 624 | class Args(Tap): 625 | arg1: int 626 | arg2: str 627 | 628 | args = Args(config_files=['my_config_shlex.txt']).parse_args() 629 | ``` 630 | to get the resulting `args = {'arg1': 21, 'arg2': 'two three four'}` 631 | 632 | The legacy parsing behavior of using standard string split can be re-enabled by passing `legacy_config_parsing=True` to `parse_args`. 633 | 634 | ## tapify 635 | 636 | `tapify` makes it possible to run functions or initialize objects via command line arguments. This is inspired by Google's [Python Fire](https://github.com/google/python-fire), but `tapify` also automatically casts command line arguments to the appropriate types based on the type hints. Under the hood, `tapify` implicitly creates a Tap object and uses it to parse the command line arguments, which it then uses to run the function or initialize the class. We show a few examples below. 637 | 638 | ### Examples 639 | 640 | #### Function 641 | 642 | ```python 643 | # square_function.py 644 | from tap import tapify 645 | 646 | def square(num: float) -> float: 647 | """Square a number. 648 | 649 | :param num: The number to square. 
650 | """ 651 | return num ** 2 652 | 653 | if __name__ == '__main__': 654 | squared = tapify(square) 655 | print(f'The square of your number is {squared}.') 656 | ``` 657 | 658 | Running `python square_function.py --num 5` prints `The square of your number is 25.0.`. 659 | 660 | #### Class 661 | 662 | ```python 663 | # square_class.py 664 | from tap import tapify 665 | 666 | class Squarer: 667 | def __init__(self, num: float) -> None: 668 | """Initialize the Squarer with a number to square. 669 | 670 | :param num: The number to square. 671 | """ 672 | self.num = num 673 | 674 | def get_square(self) -> float: 675 | """Get the square of the number.""" 676 | return self.num ** 2 677 | 678 | if __name__ == '__main__': 679 | squarer = tapify(Squarer) 680 | print(f'The square of your number is {squarer.get_square()}.') 681 | ``` 682 | 683 | Running `python square_class.py --num 2` prints `The square of your number is 4.0.`. 684 | 685 | #### Dataclass 686 | 687 | ```python 688 | # square_dataclass.py 689 | from dataclasses import dataclass 690 | 691 | from tap import tapify 692 | 693 | @dataclass 694 | class Squarer: 695 | """Squarer with a number to square. 696 | 697 | :param num: The number to square. 698 | """ 699 | num: float 700 | 701 | def get_square(self) -> float: 702 | """Get the square of the number.""" 703 | return self.num ** 2 704 | 705 | if __name__ == '__main__': 706 | squarer = tapify(Squarer) 707 | print(f'The square of your number is {squarer.get_square()}.') 708 | ``` 709 | 710 | Running `python square_dataclass.py --num -1` prints `The square of your number is 1.0.`. 711 | 712 |
713 | Argument descriptions 714 | 715 | For dataclasses, the argument's description (which is displayed in the `-h` help message) can either be specified in the 716 | class docstring or the field's description in `metadata`. If both are specified, the description from the docstring is 717 | used. In the example below, the description is provided in `metadata`. 718 | 719 | ```python 720 | # square_dataclass.py 721 | from dataclasses import dataclass, field 722 | 723 | from tap import tapify 724 | 725 | @dataclass 726 | class Squarer: 727 | """Squarer with a number to square. 728 | """ 729 | num: float = field(metadata={"description": "The number to square."}) 730 | 731 | def get_square(self) -> float: 732 | """Get the square of the number.""" 733 | return self.num ** 2 734 | 735 | if __name__ == '__main__': 736 | squarer = tapify(Squarer) 737 | print(f'The square of your number is {squarer.get_square()}.') 738 | ``` 739 | 740 |
741 | 742 | #### Pydantic 743 | 744 | Pydantic [Models](https://docs.pydantic.dev/latest/concepts/models/) and 745 | [dataclasses](https://docs.pydantic.dev/latest/concepts/dataclasses/) can be `tapify`d. 746 | 747 | ```python 748 | # square_pydantic.py 749 | from pydantic import BaseModel, Field 750 | 751 | from tap import tapify 752 | 753 | class Squarer(BaseModel): 754 | """Squarer with a number to square. 755 | """ 756 | num: float = Field(description="The number to square.") 757 | 758 | def get_square(self) -> float: 759 | """Get the square of the number.""" 760 | return self.num ** 2 761 | 762 | if __name__ == '__main__': 763 | squarer = tapify(Squarer) 764 | print(f'The square of your number is {squarer.get_square()}.') 765 | ``` 766 | 767 |
768 | Argument descriptions 769 | 770 | For Pydantic v2 models and dataclasses, the argument's description (which is displayed in the `-h` help message) can 771 | either be specified in the class docstring or the field's `description`. If both are specified, the description from the 772 | docstring is used. In the example below, the description is provided in the docstring. 773 | 774 | For Pydantic v1 models and dataclasses, the argument's description must be provided in the class docstring: 775 | 776 | ```python 777 | # square_pydantic.py 778 | from pydantic import BaseModel 779 | 780 | from tap import tapify 781 | 782 | class Squarer(BaseModel): 783 | """Squarer with a number to square. 784 | 785 | :param num: The number to square. 786 | """ 787 | num: float 788 | 789 | def get_square(self) -> float: 790 | """Get the square of the number.""" 791 | return self.num ** 2 792 | 793 | if __name__ == '__main__': 794 | squarer = tapify(Squarer) 795 | print(f'The square of your number is {squarer.get_square()}.') 796 | ``` 797 | 798 |
799 | 800 | ### tapify help 801 | 802 | The help string on the command line is set based on the docstring for the function or class. For example, running `python square_function.py -h` will print: 803 | 804 | ``` 805 | usage: square_function.py [-h] --num NUM 806 | 807 | Square a number. 808 | 809 | options: 810 | -h, --help show this help message and exit 811 | --num NUM (float, required) The number to square. 812 | ``` 813 | 814 | Note that for classes, if there is a docstring in the `__init__` method, then `tapify` sets the help string description to that docstring. Otherwise, it uses the docstring from the top of the class. 815 | 816 | ### Command line vs explicit arguments 817 | 818 | `tapify` can simultaneously use both arguments passed from the command line and arguments passed in explicitly in the `tapify` call. Arguments provided in the `tapify` call override function defaults, and arguments provided via the command line override both arguments provided in the `tapify` call and function defaults. We show an example below. 819 | 820 | ```python 821 | # add.py 822 | from tap import tapify 823 | 824 | def add(num_1: float, num_2: float = 0.0, num_3: float = 0.0) -> float: 825 | """Add numbers. 826 | 827 | :param num_1: The first number. 828 | :param num_2: The second number. 829 | :param num_3: The third number. 830 | """ 831 | return num_1 + num_2 + num_3 832 | 833 | if __name__ == '__main__': 834 | added = tapify(add, num_2=2.2, num_3=4.1) 835 | print(f'The sum of your numbers is {added}.') 836 | ``` 837 | 838 | Running `python add.py --num_1 1.0 --num_2 0.9` prints `The sum of your numbers is 6.0.`. (Note that `add` took `num_1 = 1.0` and `num_2 = 0.9` from the command line and `num_3=4.1` from the `tapify` call due to the order of precedence.) 839 | 840 | ### Known args 841 | 842 | Calling `tapify` with `known_only=True` allows `tapify` to ignore additional arguments from the command line that are not needed for the function or class. 
If `known_only=False` (the default), then `tapify` will raise an error when additional arguments are provided. We show an example below where `known_only=True` might be useful for running multiple `tapify` calls. 843 | 844 | ```python 845 | # person.py 846 | from tap import tapify 847 | 848 | def print_name(name: str) -> None: 849 | """Print a person's name. 850 | 851 | :param name: A person's name. 852 | """ 853 | print(f'My name is {name}.') 854 | 855 | def print_age(age: int) -> None: 856 | """Print a person's age. 857 | 858 | :param age: A person's age. 859 | """ 860 | print(f'My age is {age}.') 861 | 862 | if __name__ == '__main__': 863 | tapify(print_name, known_only=True) 864 | tapify(print_age, known_only=True) 865 | ``` 866 | 867 | Running `python person.py --name Jesse --age 1` prints `My name is Jesse.` followed by `My age is 1.`. Without `known_only=True`, the `tapify` calls would raise an error due to the extra argument. 868 | 869 | ### Explicit boolean arguments 870 | 871 | Tapify supports explicit specification of boolean arguments (see [bool](#bool) for more details). By default, `explicit_bool=False` and it can be set with `tapify(..., explicit_bool=True)`. 872 | 873 | ## Convert to a `Tap` class 874 | 875 | `to_tap_class` turns a function or class into a `Tap` class. The returned class can be [subclassed](#subclassing) to add 876 | special argument behavior. For example, you can override [`configure`](#configuring-arguments) and 877 | [`process_args`](#argument-processing). 878 | 879 | If the object can be `tapify`d, then it can be `to_tap_class`d, and vice-versa. `to_tap_class` provides full control 880 | over argument parsing. 
881 | 882 | ### `to_tap_class` examples 883 | 884 | #### Simple 885 | 886 | ```python 887 | # main.py 888 | """ 889 | My script description 890 | """ 891 | 892 | from pydantic import BaseModel 893 | 894 | from tap import to_tap_class 895 | 896 | class Project(BaseModel): 897 | package: str 898 | is_cool: bool = True 899 | stars: int = 5 900 | 901 | if __name__ == "__main__": 902 | ProjectTap = to_tap_class(Project) 903 | tap = ProjectTap(description=__doc__) # from the top of this script 904 | args = tap.parse_args() 905 | project = Project(**args.as_dict()) 906 | print(f"Project instance: {project}") 907 | ``` 908 | 909 | Running `python main.py --package tap` will print `Project instance: package='tap' is_cool=True stars=5`. 910 | 911 | #### Complex 912 | 913 | The general pattern is: 914 | 915 | ```python 916 | from tap import to_tap_class 917 | 918 | class MyCustomTap(to_tap_class(my_class_or_function)): 919 | # Special argument behavior, e.g., override configure and/or process_args 920 | ``` 921 | 922 | Please see `demo_data_model.py` for an example of overriding [`configure`](#configuring-arguments) and 923 | [`process_args`](#argument-processing). 
924 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from tap import Tap 3 | 4 | 5 | def add_one(num: int) -> int: 6 | return num + 1 7 | 8 | 9 | # ----- ArgumentParser ----- 10 | 11 | parser = ArgumentParser() 12 | parser.add_argument("--rnn", type=str, required=True, help="RNN type") 13 | parser.add_argument("--hidden_size", type=int, default=300, help="Hidden size") 14 | parser.add_argument("--dropout", type=float, default=0.2, help="Dropout probability") 15 | 16 | 17 | args = parser.parse_args() 18 | 19 | print(args.hidden_size) # no autocomplete, no type inference, no source code navigation 20 | 21 | add_one(args.rnn) # no static type checking 22 | 23 | 24 | # ----- Tap ----- 25 | 26 | 27 | class MyTap(Tap): 28 | rnn: str # RNN type 29 | hidden_size: int = 300 # Hidden size 30 | dropout: float = 0.2 # Dropout probability 31 | 32 | 33 | args = MyTap().parse_args() 34 | 35 | print(args.hidden_size) # autocomplete, type inference, source code navigation 36 | 37 | add_one(args.rnn) # static type checking 38 | -------------------------------------------------------------------------------- /demo_data_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Works for Pydantic v1 and v2. 3 | 4 | Example commands: 5 | 6 | python demo_data_model.py -h 7 | 8 | python demo_data_model.py \ 9 | --arg_int 1 \ 10 | --arg_list x y z \ 11 | --argument_with_really_long_name 3 12 | 13 | python demo_data_model.py \ 14 | --arg_int 1 \ 15 | --arg_list x y z \ 16 | --arg_bool \ 17 | -arg 3.14 18 | """ 19 | from typing import List, Literal, Optional, Union 20 | 21 | from pydantic import BaseModel, Field 22 | from tap import tapify, to_tap_class, Tap 23 | 24 | 25 | class Model(BaseModel): 26 | """ 27 | My Pydantic Model which contains script args. 
28 | """ 29 | 30 | arg_int: int = Field(description="some integer") 31 | arg_bool: bool = Field(default=True) 32 | arg_list: Optional[List[str]] = Field(default=None, description="some list of strings") 33 | 34 | 35 | def main(model: Model) -> None: 36 | print("Parsed args into Model:") 37 | print(model) 38 | 39 | 40 | def to_number(string: str) -> Union[float, int]: 41 | return float(string) if "." in string else int(string) 42 | 43 | 44 | class ModelTap(to_tap_class(Model)): 45 | # You can supply additional arguments here 46 | argument_with_really_long_name: Union[float, int] = 3 47 | "This argument has a long name and will be aliased with a short one" 48 | 49 | def configure(self) -> None: 50 | # You can still add special argument behavior 51 | self.add_argument("-arg", "--argument_with_really_long_name", type=to_number) 52 | 53 | def process_args(self) -> None: 54 | # You can still validate and modify arguments 55 | # (You should do this in the Pydantic Model. I'm just demonstrating that this functionality is still possible) 56 | if self.argument_with_really_long_name > 4: 57 | raise ValueError("argument_with_really_long_name cannot be > 4") 58 | 59 | # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. 
Sorry 60 | if self.arg_bool and self.arg_list is not None: 61 | self.arg_list.append("processed") 62 | 63 | 64 | # class SubparserA(Tap): 65 | # bar: int # bar help 66 | 67 | 68 | # class SubparserB(Tap): 69 | # baz: Literal["X", "Y", "Z"] # baz help 70 | 71 | 72 | # class ModelTapWithSubparsing(to_tap_class(Model)): 73 | # foo: bool = False # foo help 74 | 75 | # def configure(self): 76 | # self.add_subparsers(help="sub-command help") 77 | # self.add_subparser("a", SubparserA, help="a help", description="Description (a)") 78 | # self.add_subparser("b", SubparserB, help="b help") 79 | 80 | 81 | if __name__ == "__main__": 82 | # You don't have to subclass to_tap_class(Model) if you just want a plain argument parser: 83 | # ModelTap = to_tap_class(Model) 84 | args = ModelTap(description="Script description").parse_args() 85 | # args = ModelTapWithSubparsing(description="Script description").parse_args() 86 | print("Parsed args:") 87 | print(args) 88 | # Run the main function 89 | model = Model(**args.as_dict()) 90 | main(model) 91 | 92 | 93 | # tapify works with Model. 
It immediately returns a Model instance instead of a Tap class 94 | # if __name__ == "__main__": 95 | # model = tapify(Model) 96 | # print(model) 97 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/logo.png -------------------------------------------------------------------------------- /images/tap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/tap.png -------------------------------------------------------------------------------- /images/tap_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/tap_logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0.0", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "typed-argument-parser" 7 | dynamic = ["version"] 8 | authors = [ 9 | {name = "Jesse Michel", email = "jessem.michel@gmail.com" }, 10 | {name = "Kyle Swanson", email = "swansonk.14@gmail.com" }, 11 | ] 12 | maintainers = [ 13 | {name = "Jesse Michel", email = "jessem.michel@gmail.com" }, 14 | {name = "Kyle Swanson", email = "swansonk.14@gmail.com" }, 15 | ] 16 | description = "Typed Argument Parser" 17 | readme = "README.md" 18 | license = { file = "LICENSE.txt" } 19 | dependencies = [ 20 | "docstring-parser >= 0.15", 21 | "packaging", 22 | "typing-inspect >= 0.7.1", 23 | ] 24 | requires-python 
= ">=3.9" 25 | classifiers = [ 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12", 31 | "Programming Language :: Python :: 3.13", 32 | "License :: OSI Approved :: MIT License", 33 | "Operating System :: OS Independent", 34 | "Typing :: Typed", 35 | ] 36 | keywords = [ 37 | "typing", 38 | "argument parser", 39 | "python", 40 | ] 41 | 42 | [project.optional-dependencies] 43 | dev-no-pydantic = [ 44 | "pytest", 45 | "pytest-cov", 46 | "flake8", 47 | ] 48 | dev = [ 49 | "typed-argument-parser[dev-no-pydantic]", 50 | "pydantic >= 2.5.0", 51 | ] 52 | 53 | [tool.setuptools] 54 | package-dir = {"" = "src"} 55 | 56 | [tool.setuptools.packages.find] 57 | where = ["src"] 58 | 59 | [tool.setuptools.dynamic] 60 | version = {attr = "tap.__version__"} 61 | 62 | [tool.setuptools.package-data] 63 | tap = ["py.typed"] 64 | 65 | [project.urls] 66 | Homepage = "https://github.com/swansonk14/typed-argument-parser" 67 | Issues = "https://github.com/swansonk14/typed-argument-parser/issues" 68 | -------------------------------------------------------------------------------- /src/tap/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Typed Argument Parser 3 | """ 4 | 5 | __version__ = "1.10.1" 6 | 7 | from argparse import ArgumentError, ArgumentTypeError 8 | from tap.tap import Tap 9 | from tap.tapify import tapify, to_tap_class 10 | 11 | __all__ = [ 12 | "ArgumentError", 13 | "ArgumentTypeError", 14 | "Tap", 15 | "tapify", 16 | "to_tap_class", 17 | "__version__", 18 | ] 19 | -------------------------------------------------------------------------------- /src/tap/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for PEP 561. 
2 | -------------------------------------------------------------------------------- /src/tap/tap.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import time 4 | from argparse import ArgumentParser, ArgumentTypeError 5 | from copy import deepcopy 6 | from functools import partial 7 | from pathlib import Path 8 | from pprint import pformat 9 | from shlex import quote, split 10 | from types import MethodType 11 | from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints 12 | from typing_inspect import is_literal_type 13 | 14 | from tap.utils import ( 15 | get_class_variables, 16 | get_args, 17 | get_argument_name, 18 | get_dest, 19 | get_origin, 20 | GitInfo, 21 | is_option_arg, 22 | is_positional_arg, 23 | type_to_str, 24 | get_literals, 25 | boolean_type, 26 | TupleTypeEnforcer, 27 | define_python_object_encoder, 28 | as_python_object, 29 | enforce_reproducibility, 30 | PathLike, 31 | ) 32 | 33 | if sys.version_info >= (3, 10): 34 | from types import UnionType 35 | 36 | 37 | # Constants 38 | EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple() 39 | BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple} 40 | UNION_TYPES = {Union} | ({UnionType} if sys.version_info >= (3, 10) else set()) 41 | OPTIONAL_TYPES = {Optional} | UNION_TYPES 42 | BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES 43 | 44 | 45 | TapType = TypeVar("TapType", bound="Tap") 46 | 47 | 48 | class Tap(ArgumentParser): 49 | """Tap is a typed argument parser that wraps Python's built-in ArgumentParser.""" 50 | 51 | def __init__( 52 | self, 53 | *args, 54 | underscores_to_dashes: bool = False, 55 | explicit_bool: bool = False, 56 | config_files: Optional[list[PathLike]] = None, 57 | **kwargs, 58 | ) -> None: 59 | """Initializes the Tap instance. 60 | 61 | :param args: Arguments passed to the super class ArgumentParser. 
62 | :param underscores_to_dashes: If True, convert underscores in flags to dashes. 63 | :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False" 64 | rather than "--arg". Additionally, booleans can be specified by prefixes of True and False 65 | with any capitalization as well as 1 or 0. 66 | :param config_files: A list of paths to configuration files containing the command line arguments 67 | (e.g., '--arg1 a1 --arg2 a2'). Arguments passed in from the command line 68 | overwrite arguments from the configuration files. Arguments in configuration files 69 | that appear later in the list overwrite the arguments in previous configuration files. 70 | :param kwargs: Keyword arguments passed to the super class ArgumentParser. 71 | """ 72 | # Whether the Tap object has been initialized 73 | self._initialized = False 74 | 75 | # Whether boolean flags have to be explicitly set to True or False 76 | self._explicit_bool = explicit_bool 77 | 78 | # Whether we convert underscores in the flag names to dashes 79 | self._underscores_to_dashes = underscores_to_dashes 80 | 81 | # Whether the arguments have been parsed (i.e. if parse_args has been called) 82 | self._parsed = False 83 | 84 | # Set extra arguments to empty list 85 | self.extra_args = [] 86 | 87 | # Create argument buffer 88 | self.argument_buffer = {} 89 | 90 | # Create a place to put the subparsers 91 | self._subparser_buffer: list[tuple[str, type, dict[str, Any]]] = [] 92 | 93 | # Get class variables help strings from the comments 94 | self.class_variables = self._get_class_variables() 95 | 96 | # Get annotations from self and all super classes up through tap 97 | self._annotations = self._get_annotations() 98 | 99 | # Set the default description to be the docstring 100 | kwargs.setdefault("description", self.__doc__) 101 | 102 | # Initialize the super class, i.e. 
ArgumentParser 103 | super(Tap, self).__init__(*args, **kwargs) 104 | 105 | # Stores the subparsers 106 | self._subparsers = None 107 | 108 | # Load in the configuration files 109 | self.args_from_configs = self._load_from_config_files(config_files) 110 | 111 | # Perform additional configuration such as adding arguments or adding subparsers 112 | self._configure() 113 | 114 | # Indicate that initialization is complete 115 | self._initialized = True 116 | 117 | def _add_argument(self, *name_or_flags, **kwargs) -> None: 118 | """Adds an argument to self (i.e. the super class ArgumentParser). 119 | 120 | Sets the following attributes of kwargs when not explicitly provided: 121 | - type: Set to the type annotation of the argument. 122 | - default: Set to the default value of the argument (if provided). 123 | - required: True if a default value of the argument is not provided, False otherwise. 124 | - action: Set to "store_true" if the argument is a required bool or a bool with default value False. 125 | Set to "store_false" if the argument is a bool with default value True. 126 | - nargs: Set to "*" if the type annotation is List[str], List[int], or List[float]. 127 | - help: Set to the argument documentation from the class docstring. 128 | 129 | :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo. 130 | :param kwargs: Keyword arguments. 
131 | """ 132 | # Set explicit bool 133 | explicit_bool = self._explicit_bool 134 | 135 | # Get variable name 136 | variable = get_argument_name(*name_or_flags) 137 | 138 | if self._underscores_to_dashes: 139 | variable = variable.replace("-", "_") 140 | 141 | # Get default if not specified 142 | if hasattr(self, variable): 143 | kwargs["default"] = kwargs.get("default", getattr(self, variable)) 144 | 145 | # Set required if option arg 146 | if ( 147 | is_option_arg(*name_or_flags) 148 | and variable != "help" 149 | and "default" not in kwargs 150 | and kwargs.get("action") != "version" 151 | ): 152 | kwargs["required"] = kwargs.get("required", not hasattr(self, variable)) 153 | 154 | # Set help if necessary 155 | if "help" not in kwargs: 156 | kwargs["help"] = "(" 157 | 158 | # Type 159 | if variable in self._annotations: 160 | kwargs["help"] += type_to_str(self._annotations[variable]) + ", " 161 | 162 | # Required/default 163 | if kwargs.get("required", False) or is_positional_arg(*name_or_flags): 164 | kwargs["help"] += "required" 165 | else: 166 | kwargs["help"] += f'default={kwargs.get("default", None)}' 167 | 168 | kwargs["help"] += ")" 169 | 170 | # Description 171 | if variable in self.class_variables: 172 | kwargs["help"] += " " + self.class_variables[variable]["comment"] 173 | 174 | # Set other kwargs where not provided 175 | if variable in self._annotations: 176 | # Get type annotation 177 | var_type = self._annotations[variable] 178 | 179 | # If type is not explicitly provided, set it if it's one of our supported default types 180 | if "type" not in kwargs: 181 | # Unbox Union[type] (Optional[type]) and set var_type = type 182 | if get_origin(var_type) in OPTIONAL_TYPES: 183 | var_args = get_args(var_type) 184 | 185 | # If type is Union or Optional without inner types, set type to equivalent of Optional[str] 186 | if len(var_args) == 0: 187 | var_args = (str, type(None)) 188 | 189 | # Raise error if type function is not explicitly provided for Union 
types (not including Optionals) 190 | if get_origin(var_type) in UNION_TYPES and not (len(var_args) == 2 and var_args[1] == type(None)): 191 | raise ArgumentTypeError( 192 | "For Union types, you must include an explicit type function in the configure method. " 193 | "For example,\n\n" 194 | "def to_number(string: str) -> Union[float, int]:\n" 195 | " return float(string) if '.' in string else int(string)\n\n" 196 | "class Args(Tap):\n" 197 | " arg: Union[float, int]\n" 198 | "\n" 199 | " def configure(self) -> None:\n" 200 | " self.add_argument('--arg', type=to_number)" 201 | ) 202 | 203 | if len(var_args) > 0: 204 | var_type = var_args[0] 205 | 206 | # If var_type is tuple as in Python 3.6, change to a typing type 207 | # (e.g., (typing.List, ) ==> typing.List[bool]) 208 | if isinstance(var_type, tuple): 209 | var_type = var_type[0][var_type[1:]] 210 | 211 | explicit_bool = True 212 | 213 | # First check whether it is a literal type or a boxed literal type 214 | if is_literal_type(var_type): 215 | var_type, kwargs["choices"] = get_literals(var_type, variable) 216 | 217 | elif ( 218 | get_origin(var_type) in (List, list, Set, set) 219 | and len(get_args(var_type)) > 0 220 | and is_literal_type(get_args(var_type)[0]) 221 | ): 222 | var_type, kwargs["choices"] = get_literals(get_args(var_type)[0], variable) 223 | if kwargs.get("action") not in {"append", "append_const"}: 224 | kwargs["nargs"] = kwargs.get("nargs", "*") 225 | 226 | # Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them 227 | elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0: 228 | loop = False 229 | types = list(get_args(var_type)) 230 | 231 | # Handle Tuple[type, ...] 
232 | if len(types) == 2 and types[1] == Ellipsis: 233 | types = types[0:1] 234 | loop = True 235 | kwargs["nargs"] = "*" 236 | # Handle Tuple[()] 237 | elif len(types) == 1 and types[0] == tuple(): 238 | types = [str] 239 | loop = True 240 | kwargs["nargs"] = "*" 241 | else: 242 | kwargs["nargs"] = len(types) 243 | 244 | # Handle Literal types 245 | types = [get_literals(tp, variable)[0] if is_literal_type(tp) else tp for tp in types] 246 | 247 | var_type = TupleTypeEnforcer(types=types, loop=loop) 248 | 249 | if get_origin(var_type) in BOXED_TYPES: 250 | # If List or Set or Tuple type, set nargs 251 | if get_origin(var_type) in BOXED_COLLECTION_TYPES and kwargs.get("action") not in { 252 | "append", 253 | "append_const", 254 | }: 255 | kwargs["nargs"] = kwargs.get("nargs", "*") 256 | 257 | # Extract boxed type for Optional, List, Set 258 | arg_types = get_args(var_type) 259 | 260 | # Set defaults type to str for Type and Type[()] 261 | if len(arg_types) == 0 or arg_types[0] == EMPTY_TYPE: 262 | var_type = str 263 | else: 264 | var_type = arg_types[0] 265 | 266 | # Handle the cases of List[bool], Set[bool], Tuple[bool] 267 | if var_type == bool: 268 | var_type = boolean_type 269 | 270 | # If bool then set action, otherwise set type 271 | if var_type == bool: 272 | if explicit_bool: 273 | kwargs["type"] = boolean_type 274 | kwargs["choices"] = [True, False] # this makes the help message more helpful 275 | else: 276 | action_cond = "true" if kwargs.get("required", False) or not kwargs["default"] else "false" 277 | kwargs["action"] = kwargs.get("action", f"store_{action_cond}") 278 | elif kwargs.get("action") not in {"count", "append_const"}: 279 | kwargs["type"] = var_type 280 | 281 | if self._underscores_to_dashes: 282 | # Replace "_" with "-" for arguments that aren't positional 283 | name_or_flags = tuple( 284 | name_or_flag.replace("_", "-") if name_or_flag.startswith("-") else name_or_flag 285 | for name_or_flag in name_or_flags 286 | ) 287 | 288 | # Deepcopy 
default to prevent mutation of values 289 | if "default" in kwargs: 290 | kwargs["default"] = deepcopy(kwargs["default"]) 291 | 292 | super(Tap, self).add_argument(*name_or_flags, **kwargs) 293 | 294 | def add_argument(self, *name_or_flags, **kwargs) -> None: 295 | """Adds an argument to the argument buffer, which will later be passed to _add_argument.""" 296 | if self._initialized: 297 | raise ValueError( 298 | "add_argument cannot be called after initialization. " 299 | "Arguments must be added either as class variables or by overriding " 300 | "configure and including a self.add_argument call there." 301 | ) 302 | 303 | variable = get_argument_name(*name_or_flags).replace("-", "_") 304 | self.argument_buffer[variable] = (name_or_flags, kwargs) 305 | 306 | def _add_arguments(self) -> None: 307 | """Add arguments to self in the order they are defined as class variables (so the help string is in order).""" 308 | # Add class variables (in order) 309 | for variable in self.class_variables: 310 | if variable in self.argument_buffer: 311 | name_or_flags, kwargs = self.argument_buffer[variable] 312 | self._add_argument(*name_or_flags, **kwargs) 313 | else: 314 | self._add_argument(f"--{variable}") 315 | 316 | # Add any arguments that were added manually in configure but aren't class variables (in order) 317 | for variable, (name_or_flags, kwargs) in self.argument_buffer.items(): 318 | if variable not in self.class_variables: 319 | self._add_argument(*name_or_flags, **kwargs) 320 | 321 | def process_args(self) -> None: 322 | """Perform additional argument processing and/or validation.""" 323 | pass 324 | 325 | def add_subparser(self, flag: str, subparser_type: type, **kwargs) -> None: 326 | """Add a subparser to the collection of subparsers""" 327 | help_desc = kwargs.get("help", subparser_type.__doc__) 328 | kwargs["help"] = help_desc 329 | 330 | self._subparser_buffer.append((flag, subparser_type, kwargs)) 331 | 332 | def _add_subparsers(self) -> None: 333 | """Add 
each of the subparsers to the Tap object.""" 334 | # Initialize the _subparsers object if not already created 335 | if self._subparsers is None and len(self._subparser_buffer) > 0: 336 | self._subparsers = super(Tap, self).add_subparsers() 337 | 338 | # Load each subparser 339 | for flag, subparser_type, kwargs in self._subparser_buffer: 340 | self._subparsers._parser_class = partial(subparser_type, underscores_to_dashes=self._underscores_to_dashes) 341 | self._subparsers.add_parser(flag, **kwargs) 342 | 343 | def add_subparsers(self, **kwargs) -> None: 344 | self._subparsers = super().add_subparsers(**kwargs) 345 | 346 | def _configure(self) -> None: 347 | """Executes the user-defined configuration.""" 348 | # Call the user-defined configuration 349 | self.configure() 350 | 351 | # Add arguments to self 352 | self._add_arguments() 353 | 354 | # Add subparsers to self 355 | self._add_subparsers() 356 | 357 | def configure(self) -> None: 358 | """Overwrite this method to configure the parser during initialization. 359 | 360 | For example, 361 | self.add_argument('--sum', 362 | dest='accumulate', 363 | action='store_const', 364 | const=sum, 365 | default=max) 366 | self.add_subparsers(help='sub-command help') 367 | self.add_subparser('a', SubparserA, help='a help') 368 | """ 369 | pass 370 | 371 | @staticmethod 372 | def get_reproducibility_info(repo_path: Optional[PathLike] = None) -> dict[str, str]: 373 | """Gets a dictionary of reproducibility information. 374 | 375 | Reproducibility information always includes: 376 | - command_line: The command line command used to execute the code. 377 | - time: The current time. 378 | 379 | If git is installed, reproducibility information also includes: 380 | - git_root: The root of the git repo where the command is run. 381 | - git_url: The url of the current hash of the git repo where the command is run. 382 | Ex. https://github.com/swansonk14/rationale-alignment/tree/. 383 | If it is a local repo, the url is None. 
384 | - git_has_uncommitted_changes: Whether the current git repo has uncommitted changes. 385 | 386 | :param repo_path: Path to the git repo to examine for reproducibility info. 387 | If None, uses the git repo of the Python file that is run. 388 | :return: A dictionary of reproducibility information. 389 | """ 390 | # Get the path to the Python file that is being run 391 | if repo_path is None: 392 | repo_path = (Path.cwd() / Path(sys.argv[0]).parent).resolve() 393 | 394 | reproducibility = { 395 | "command_line": f'python {" ".join(quote(arg) for arg in sys.argv)}', 396 | "time": time.strftime("%c"), 397 | } 398 | 399 | git_info = GitInfo(repo_path=repo_path) 400 | 401 | if git_info.has_git(): 402 | reproducibility["git_root"] = git_info.get_git_root() 403 | reproducibility["git_url"] = git_info.get_git_url(commit_hash=True) 404 | reproducibility["git_has_uncommitted_changes"] = str(git_info.has_uncommitted_changes()) 405 | 406 | return reproducibility 407 | 408 | def _log_all(self, repo_path: Optional[PathLike] = None) -> dict[str, Any]: 409 | """Gets all arguments along with reproducibility information. 410 | 411 | :param repo_path: Path to the git repo to examine for reproducibility info. 412 | If None, uses the git repo of the Python file that is run. 413 | :return: A dictionary containing all arguments along with reproducibility information. 414 | """ 415 | arg_log = self.as_dict() 416 | arg_log["reproducibility"] = self.get_reproducibility_info(repo_path=repo_path) 417 | 418 | return arg_log 419 | 420 | def parse_args( 421 | self: TapType, 422 | args: Optional[Sequence[str]] = None, 423 | known_only: bool = False, 424 | legacy_config_parsing: bool = False, 425 | ) -> TapType: 426 | """Parses arguments, sets attributes of self equal to the parsed arguments, and processes arguments. 427 | 428 | :param args: List of strings to parse. The default is taken from `sys.argv`. 
429 | :param known_only: If true, ignores extra arguments and only parses known arguments. 430 | Unparsed arguments are saved to self.extra_args. 431 | :param legacy_config_parsing: If true, config files are parsed using `str.split` instead of `shlex.split`. 432 | :return: self, which is a Tap instance containing all of the parsed args. 433 | """ 434 | # Prevent double parsing 435 | if self._parsed: 436 | raise ValueError("parse_args can only be called once.") 437 | 438 | # Collect arguments from all of the configs 439 | 440 | if legacy_config_parsing: 441 | splitter = lambda arg_string: arg_string.split() 442 | else: 443 | splitter = lambda arg_string: split(arg_string, comments=True) 444 | 445 | config_args = [arg for args_from_config in self.args_from_configs for arg in splitter(args_from_config)] 446 | 447 | # Add config args at lower precedence and extract args from the command line if they are not passed explicitly 448 | args = config_args + (sys.argv[1:] if args is None else list(args)) 449 | 450 | # Parse args using super class ArgumentParser's parse_args or parse_known_args function 451 | if known_only: 452 | default_namespace, self.extra_args = super(Tap, self).parse_known_args(args) 453 | else: 454 | default_namespace = super(Tap, self).parse_args(args) 455 | 456 | # Copy parsed arguments to self 457 | for variable, value in vars(default_namespace).items(): 458 | # Conversion from list to set or tuple 459 | if variable in self._annotations: 460 | if type(value) == list: 461 | var_type = get_origin(self._annotations[variable]) 462 | 463 | # Unpack nested boxed types such as Optional[List[int]] 464 | if var_type is Union: 465 | var_type = get_origin(get_args(self._annotations[variable])[0]) 466 | 467 | # If var_type is tuple as in Python 3.6, change to a typing type 468 | # (e.g., (typing.Tuple, ) ==> typing.Tuple) 469 | if isinstance(var_type, tuple): 470 | var_type = var_type[0] 471 | 472 | if var_type in (Set, set): 473 | value = set(value) 474 | elif 
@classmethod
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union[dict[str, Any], dict]:
    """Returns a dictionary mapping variable names to values.

    Variables and values are extracted from classes using `extract_func`, starting
    with this class and traversing up the super classes up to (but excluding) Tap.

    If super class and subclass have the same key, the subclass value is used.

    Super classes are traversed through breadth first search.

    :param extract_func: A function that extracts from a class a dictionary mapping variables to values.
    :return: A dictionary mapping variable names to values from the class dict.
    """
    visited = set()
    super_classes = [cls]
    dictionary = {}

    while super_classes:
        super_class = super_classes.pop(0)

        # Only proper subclasses of Tap contribute; Tap itself and unrelated bases are skipped
        if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
            # setdefault keeps the value recorded first, i.e. the one from the most derived class,
            # so a single pass is enough to add only the unseen variables
            for variable, value in extract_func(super_class).items():
                dictionary.setdefault(variable, value)

            super_classes += list(super_class.__bases__)
            visited.add(super_class)

    return dictionary
def _get_class_variables(self) -> dict:
    """Returns a dictionary mapping class variable names to their additional information.

    The expected names are the union of the annotated names and the class dict names.
    Information extracted through `get_class_variables` is reconciled with those names;
    if extraction raises, every expected name simply gets an empty comment.
    """
    class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()

    try:
        class_variables = self._get_from_self_and_super(extract_func=get_class_variables)

        # The source code may have been modified while the program is running, so the
        # extracted names can disagree with the names known at runtime
        missing_names = class_variable_names - class_variables.keys()
        stale_names = class_variables.keys() - class_variable_names

        class_variables.update({name: {"comment": ""} for name in missing_names})

        for name in stale_names:
            del class_variables[name]
    except Exception:
        # Reached when the source code cannot be extracted (e.g., inspect.getsource fails)
        class_variables = {name: {"comment": ""} for name in class_variable_names}

    return class_variables
def from_dict(self, args_dict: dict[str, Any], skip_unsettable: bool = False) -> TapType:
    """Populates this object from a dictionary of argument names to values.

    Every required argument must either be present in `args_dict` or already be set on self.

    :param args_dict: A dictionary from argument names to the values of the arguments.
    :param skip_unsettable: When True, skips attributes that cannot be set in the Tap object,
                            e.g. properties without setters.
    :return: Returns self.
    :raises ValueError: If a required argument is neither provided nor already set.
    :raises AttributeError: If an attribute cannot be set and `skip_unsettable` is False.
    """
    # Gather the destinations of all required arguments
    required_args = set()
    for action in self._actions:
        if action.required:
            required_args.add(action.dest)

    # A required argument is only missing if it is neither provided nor already an attribute
    missing_required_args = [arg for arg in required_args - args_dict.keys() if not hasattr(self, arg)]

    if missing_required_args:
        raise ValueError(
            f'Input dictionary "{args_dict}" does not include '
            f'all unset required arguments: "{missing_required_args}".'
        )

    # Load all arguments
    for key, value in args_dict.items():
        try:
            setattr(self, key, value)
        except AttributeError:
            if skip_unsettable:
                continue
            raise AttributeError(
                f'Cannot set attribute "{key}" to "{value}". '
                f"To skip arguments that cannot be set \n"
                f'\t"skip_unsettable = True"'
            )

    self._parsed = True

    return self
def load(
    self,
    path: PathLike,
    check_reproducibility: bool = False,
    skip_unsettable: bool = False,
    repo_path: Optional[PathLike] = None,
) -> TapType:
    """Reads arguments from a JSON file and sets them on self. Note: Due to JSON, tuples are loaded as lists.

    :param path: Path to the JSON file where the arguments will be loaded from.
    :param check_reproducibility: When True, the saved reproducibility information is compared against
                                  the current reproducibility information via `enforce_reproducibility`.
    :param skip_unsettable: When True, skips attributes that cannot be set in the Tap object,
                            e.g. properties without setters.
    :param repo_path: Path to the git repo to examine for reproducibility info.
                      If None, uses the git repo of the Python file that is run.
    :return: Returns self.
    """
    with open(path) as json_file:
        loaded_args = json.load(json_file, object_hook=as_python_object)

    # The stored reproducibility entry is not an argument, so it is always taken out of the
    # dictionary, whether or not it is going to be checked
    saved_reproducibility_data = loaded_args.pop("reproducibility", None)

    if check_reproducibility:
        enforce_reproducibility(
            saved_reproducibility_data,
            self.get_reproducibility_info(repo_path=repo_path),
            path,
        )

    self.from_dict(loaded_args, skip_unsettable=skip_unsettable)

    return self
def __deepcopy__(self, memo: Optional[dict[int, Any]] = None) -> TapType:
    """Deepcopy the Tap object.

    A new instance is created without running `__init__`, and every entry of the
    instance `__dict__` is deep-copied into it.

    :param memo: The memo dictionary used by `copy.deepcopy` to track already copied objects.
                 If None, a fresh dictionary is used.
    :return: The deep copy of self.
    """
    copied = type(self).__new__(type(self))

    if memo is None:
        memo = {}

    # Register the copy before descending so that references back to self resolve to the copy
    memo[id(self)] = copied

    for k, v in self.__dict__.items():
        copied.__dict__[k] = deepcopy(v, memo)

    return copied
3 | 4 | `to_tap_class`: convert a class or function into a `Tap` class, which can then be subclassed to add special argument 5 | handling 6 | """ 7 | 8 | import dataclasses 9 | import inspect 10 | from typing import Any, Callable, Optional, Sequence, TypeVar, Union 11 | 12 | from docstring_parser import Docstring, parse 13 | from packaging.version import Version 14 | 15 | try: 16 | import pydantic 17 | except ModuleNotFoundError: 18 | _IS_PYDANTIC_V1 = None 19 | # These are "empty" types. isinstance and issubclass will always be False 20 | BaseModel = type("BaseModel", (object,), {}) 21 | _PydanticField = type("_PydanticField", (object,), {}) 22 | _PYDANTIC_FIELD_TYPES = () 23 | else: 24 | _IS_PYDANTIC_V1 = Version(pydantic.__version__) < Version("2.0.0") 25 | from pydantic import BaseModel 26 | from pydantic.fields import FieldInfo as PydanticFieldBaseModel 27 | from pydantic.dataclasses import FieldInfo as PydanticFieldDataclass 28 | 29 | _PydanticField = Union[PydanticFieldBaseModel, PydanticFieldDataclass] 30 | # typing.get_args(_PydanticField) is an empty tuple for some reason. Just repeat 31 | _PYDANTIC_FIELD_TYPES = (PydanticFieldBaseModel, PydanticFieldDataclass) 32 | 33 | from tap import Tap 34 | 35 | OutputType = TypeVar("OutputType") 36 | 37 | _ClassOrFunction = Union[Callable[..., OutputType], type[OutputType]] 38 | 39 | 40 | @dataclasses.dataclass 41 | class _ArgData: 42 | """ 43 | Data about an argument which is sufficient to inform a Tap variable/argument. 44 | """ 45 | 46 | name: str 47 | 48 | annotation: type 49 | "The type of values this argument accepts" 50 | 51 | is_required: bool 52 | "Whether or not the argument must be passed in" 53 | 54 | default: Any 55 | "Value of the argument if the argument isn't passed in. 
def _tap_data_from_data_model(
    data_model: Any, func_kwargs: dict[str, Any], param_to_description: Optional[dict[str, str]] = None
) -> _TapData:
    """
    Extract data by inspecting the fields of `data_model`.

    Currently only works when `data_model` is a:
      - builtin dataclass (class or instance)
      - Pydantic dataclass (class or instance)
      - Pydantic BaseModel (class or instance).

    The advantage of this function over :func:`_tap_data_from_class_or_function` is that field/argument descriptions
    are extracted, b/c this function looks at the fields of the data model.

    :param data_model: The data model (class or instance) whose fields describe the arguments.
    :param func_kwargs: Values that override field defaults. A field named here becomes optional.
    :param param_to_description: Maps a field name to its description; takes precedence over the description stored
                                 on the field. If None, no such descriptions are used.
    :return: The extracted :class:`_TapData`.
    :raises TypeError: If `data_model` is not a supported data model or a field has an unsupported type.

    Note
    ----
    Deletes redundant keys from `func_kwargs`
    """
    param_to_description = param_to_description or {}

    def arg_data_from_dataclass(name: str, field: dataclasses.Field) -> _ArgData:
        has_default = field.default is not dataclasses.MISSING
        has_default_factory = field.default_factory is not dataclasses.MISSING

        # A field declared with default_factory has field.default set to MISSING. Build the real default from the
        # factory so that the MISSING sentinel is never used as the argument's default value
        default = field.default_factory() if has_default_factory else field.default

        description = param_to_description.get(name, field.metadata.get("description"))
        return _ArgData(
            name,
            field.type,
            not (has_default or has_default_factory),
            default,
            description,
        )

    def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optional[type] = None) -> _ArgData:
        annotation = field.annotation if annotation is None else annotation
        # Prefer the description from param_to_description (from the data model / class docstring) over the
        # field.description b/c a docstring can be modified on the fly w/o causing real issues
        description = param_to_description.get(name, field.description)
        return _ArgData(name, annotation, field.is_required(), field.default, description)

    # Determine what type of data model it is and extract fields accordingly
    if dataclasses.is_dataclass(data_model):
        name_to_field = {field.name: field for field in dataclasses.fields(data_model)}
        has_kwargs = False
        known_only = False
    elif _is_pydantic_base_model(data_model):
        name_to_field = data_model.model_fields
        # For backwards compatibility, only allow new kwargs to get assigned if the model is explicitly configured to
        # do so via extra="allow". See https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
        is_extra_ok = data_model.model_config.get("extra", "ignore") == "allow"
        has_kwargs = is_extra_ok
        known_only = is_extra_ok
    else:
        raise TypeError(
            "data_model must be a builtin or Pydantic dataclass (instance or class) or "
            f"a Pydantic BaseModel (instance or class). Got {type(data_model)}"
        )

    # It's possible to mix fields w/ classes, e.g., use pydantic Fields in a (builtin) dataclass, or use (builtin)
    # dataclass fields in a pydantic BaseModel. It's also possible to use (builtin) dataclass fields and pydantic
    # Fields in the same data model. Therefore, the type of the data model doesn't determine the type of each field.
    # The solution is to iterate through the fields and check each type.
    args_data: list[_ArgData] = []
    for name, field in name_to_field.items():
        if isinstance(field, dataclasses.Field):
            # Idiosyncrasy: if a pydantic Field is used in a pydantic dataclass, then field.default is a FieldInfo
            # object instead of the field's default value. Furthermore, field.annotation is always NoneType. Luckily,
            # the actual type of the field is stored in field.type
            if isinstance(field.default, _PYDANTIC_FIELD_TYPES):
                arg_data = arg_data_from_pydantic(name, field.default, annotation=field.type)
            else:
                arg_data = arg_data_from_dataclass(name, field)
        elif isinstance(field, _PYDANTIC_FIELD_TYPES):
            arg_data = arg_data_from_pydantic(name, field)
        else:
            raise TypeError(f"Each field must be a dataclass or Pydantic field. Got {type(field)}")
        # Handle case where func_kwargs is supplied
        if name in func_kwargs:
            arg_data.default = func_kwargs[name]
            arg_data.is_required = False
            del func_kwargs[name]
        args_data.append(arg_data)
    return _TapData(args_data, has_kwargs, known_only)
def _docstring(class_or_function) -> Docstring:
    """Parses the docstring that documents the parameters of `class_or_function`.

    For a class that is not a Pydantic BaseModel, the docstring of `__init__` is preferred and the
    class docstring is the fallback. For everything else, the object's own docstring is used.
    """
    if inspect.isclass(class_or_function) and not _is_pydantic_base_model(class_or_function):
        doc = class_or_function.__init__.__doc__ or class_or_function.__doc__
    else:
        doc = class_or_function.__doc__
    return parse(doc)
def _tap_class(args_data: Sequence[_ArgData]) -> type[Tap]:
    """
    Builds a :class:`tap.Tap` subclass from argument data. The arguments are added to the parser when the class
    is initialized.
    """

    class ArgParser(Tap):
        # _configure is overwritten rather than configure, so that a user who overwrites configure
        # in a subclass does not have to remember to call super().configure
        def _configure(self):
            for arg_data in args_data:
                variable = arg_data.name

                annotation = arg_data.annotation
                if annotation is Any:
                    annotation = str
                self._annotations[variable] = annotation
                self.class_variables[variable] = {"comment": arg_data.description or ""}

                # Only arguments that are not required carry a default value
                kwargs = {} if arg_data.is_required else {"required": False, "default": arg_data.default}
                self.add_argument(f"--{variable}", **kwargs)

            super()._configure()

    return ArgParser
def tapify(
    class_or_function: Union[Callable[..., OutputType], type[OutputType]],
    known_only: bool = False,
    command_line_args: Optional[list[str]] = None,
    explicit_bool: bool = False,
    underscores_to_dashes: bool = False,
    description: Optional[str] = None,
    **func_kwargs,
) -> OutputType:
    """Tapify initializes a class or runs a function by parsing arguments from the command line.

    :param class_or_function: The class or function to run with the provided arguments.
    :param known_only: If true, ignores extra arguments and only parses known arguments.
    :param command_line_args: A list of command line style arguments to parse (e.g., `['--arg', 'value']`). If None,
                              arguments are parsed from the command line (default behavior).
    :param explicit_bool: Booleans can be specified on the command line as `--arg True` or `--arg False` rather than
                          `--arg`. Additionally, booleans can be specified by prefixes of True and False with any
                          capitalization as well as 1 or 0.
    :param underscores_to_dashes: If True, convert underscores in flag names to dashes.
    :param description: The description displayed in the help message—the same description passed in
                        `argparse.ArgumentParser(description=...)`. By default, it's extracted from
                        `class_or_function`'s docstring.
    :param func_kwargs: Additional keyword arguments for the function. These act as default values when parsing the
                        command line arguments and overwrite the function defaults but are overwritten by the parsed
                        command line arguments.
    :return: The object returned by calling `class_or_function` with the parsed arguments.
    :raises ValueError: If unknown keyword arguments are passed while extra arguments are not allowed, or if the
                        extra command line arguments cannot be grouped into `--name value` pairs.
    """
    # We don't directly call to_tap_class b/c we need tap_data, not just tap_class
    docstring = _docstring(class_or_function)
    param_to_description = {param.arg_name: param.description for param in docstring.params}
    tap_data = _tap_data(class_or_function, param_to_description, func_kwargs)
    tap_class = _tap_class(tap_data.args_data)

    # Create a Tap object
    if description is None:
        description = "\n".join(filter(None, (docstring.short_description, docstring.long_description)))
    tap = tap_class(description=description, explicit_bool=explicit_bool, underscores_to_dashes=underscores_to_dashes)

    # If any func_kwargs remain, they are not used in the function, so raise an error
    known_only = known_only or tap_data.known_only
    if func_kwargs and not known_only:
        raise ValueError(f"Unknown keyword arguments: {func_kwargs}")

    # Parse command line arguments (kept in its own name instead of rebinding the command_line_args parameter)
    parsed_args: Tap = tap.parse_args(args=command_line_args, known_only=known_only)

    # Prepare command line arguments for class_or_function, respecting positional-only args
    class_or_function_args: list[Any] = []
    class_or_function_kwargs: dict[str, Any] = {}
    parsed_args_dict = parsed_args.as_dict()
    for arg_data in tap_data.args_data:
        arg_value = parsed_args_dict[arg_data.name]
        if arg_data.is_positional_only:
            class_or_function_args.append(arg_value)
        else:
            class_or_function_kwargs[arg_data.name] = arg_value

    # Get **kwargs from extra command line arguments
    if tap_data.has_kwargs:
        extra_args = tap.extra_args
        # The extra arguments are read as consecutive (flag, value) pairs, so a trailing flag has no value
        if len(extra_args) % 2 != 0:
            raise ValueError(
                f"Extra arguments must be given as `--name value` pairs to be passed as keyword arguments. "
                f"Got {extra_args}"
            )
        kwargs = {flag.lstrip("-"): value for flag, value in zip(extra_args[::2], extra_args[1::2])}
        class_or_function_kwargs.update(kwargs)

    # Initialize the class or run the function with the parsed arguments
    return class_or_function(*class_or_function_args, **class_or_function_kwargs)
class GitInfo:
    """Class with helper methods for extracting information about a git repo."""

    def __init__(self, repo_path: PathLike):
        """Stores the path that every git command of this object is run from.

        :param repo_path: Path inside the git repo to examine.
        """
        self.repo_path = repo_path

    def has_git(self) -> bool:
        """Returns whether `repo_path` is inside a git work tree.

        :return: True if git reports that `repo_path` is inside a work tree,
                 False otherwise (including when the git command cannot be run or fails).
        """
        try:
            output = check_output(["git", "rev-parse", "--is-inside-work-tree"], cwd=self.repo_path)
            return output == "true"
        except (FileNotFoundError, subprocess.CalledProcessError):
            return False

    def get_git_root(self) -> str:
        """Gets the root directory of the git repo where the command is run.

        :return: The root directory of the current git repo.
        """
        return check_output(["git", "rev-parse", "--show-toplevel"], cwd=self.repo_path)

    def get_git_version(self) -> tuple[int, ...]:
        """Gets the version of git.

        :return: The version of git, as a tuple of ints. Version components that are not purely numeric are dropped.

        Example:
        >>> get_git_version()
        (2, 17, 1) # for git version 2.17.1
        """
        raw = check_output(["git", "--version"])
        # The version number starts at the first digit of the output
        number_start_index = next(i for i, c in enumerate(raw) if c.isdigit())
        return tuple(int(num) for num in raw[number_start_index:].split(".") if num.isdigit())

    @staticmethod
    def _normalize_remote_url(url: str) -> str:
        """Converts a git remote url to a browsable url without the trailing ".git".

        :param url: The remote url (either https or ssh).
        :return: The url without a ".git" suffix, with the ssh form git@domain:path rewritten to https://domain/path.
        """
        # Remove .git at end, but only when it is actually there
        url = url.removesuffix(".git")

        # Convert ssh url to https url
        m = re.search("git@(.+):", url)
        if m is not None:
            domain = m.group(1)
            path = url[m.end() :]
            url = f"https://{domain}/{path}"

        return url

    def get_git_url(self, commit_hash: bool = True) -> str:
        """Gets the https url of the git repo where the command is run.

        :param commit_hash: If True, the url links to the latest local git commit hash.
                            If False, the url links to the general git url.
        :return: The https url of the current git repo or an empty string for a local repo.
        """
        # Get git url (either https or ssh)
        input_remote = (
            ["git", "remote", "get-url", "origin"]
            if self.get_git_version() >= (2, 0)
            else ["git", "config", "--get", "remote.origin.url"]
        )
        try:
            url = check_output(input_remote, cwd=self.repo_path)
        except subprocess.CalledProcessError as e:
            if e.returncode in {2, 128}:
                # https://git-scm.com/docs/git-remote#_exit_status
                # 2: The remote does not exist.
                # 128: The remote was not found.
                return ""
            raise e

        url = self._normalize_remote_url(url)

        if commit_hash:
            # Add tree and hash of current commit
            url = f"{url}/tree/{self.get_git_hash()}"

        return url

    def get_git_hash(self) -> str:
        """Gets the git hash of HEAD of the git repo where the command is run.

        :return: The git hash of HEAD of the current git repo.
        """
        return check_output(["git", "rev-parse", "HEAD"], cwd=self.repo_path)

    def has_uncommitted_changes(self) -> bool:
        """Returns whether there are uncommitted changes in the git repo where the command is run.

        :return: True if there are uncommitted changes in the current git repo, False otherwise.
        """
        status = check_output(["git", "status"], cwd=self.repo_path)

        # A clean repo is recognized by the final message of `git status`
        return not status.endswith(NO_CHANGES_STATUS)
160 | """ 161 | if "-h" in name_or_flags or "--help" in name_or_flags: 162 | return "help" 163 | 164 | if len(name_or_flags) > 1: 165 | name_or_flags = tuple(n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--")) 166 | 167 | if len(name_or_flags) != 1: 168 | raise ValueError(f"There should only be a single canonical name for argument {name_or_flags}!") 169 | 170 | return name_or_flags[0].lstrip("-") 171 | 172 | 173 | def get_dest(*name_or_flags, **kwargs) -> str: 174 | """Gets the name of the destination of the argument. 175 | 176 | :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo. 177 | :param kwargs: Keyword arguments. 178 | :return: The name of the argument (extracted from name_or_flags). 179 | """ 180 | if "-h" in name_or_flags or "--help" in name_or_flags: 181 | return "help" 182 | 183 | return ArgumentParser().add_argument(*name_or_flags, **kwargs).dest 184 | 185 | 186 | def is_option_arg(*name_or_flags) -> bool: 187 | """Returns whether the argument is an option arg (as opposed to a positional arg). 188 | 189 | :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo. 190 | :return: True if the argument is an option arg, False otherwise. 191 | """ 192 | return any(name_or_flag.startswith("-") for name_or_flag in name_or_flags) 193 | 194 | 195 | def is_positional_arg(*name_or_flags) -> bool: 196 | """Returns whether the argument is a positional arg (as opposed to an optional arg). 197 | 198 | :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo. 199 | :return: True if the argument is a positional arg, False otherwise. 
200 | """ 201 | return not is_option_arg(*name_or_flags) 202 | 203 | 204 | def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]: 205 | """Returns a generator for the tokens of the object's source code, given the source code.""" 206 | return tokenize.generate_tokens(StringIO(source).readline) 207 | 208 | 209 | def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int: 210 | """Determines the column number for class variables in a class, given the tokens of the class.""" 211 | first_line = 1 212 | for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens: 213 | if token.strip() == "@": 214 | first_line += 1 215 | if start_line <= first_line or token.strip() == "": 216 | continue 217 | 218 | return start_column 219 | raise ValueError("Could not find any class variables in the class.") 220 | 221 | 222 | def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> dict[int, list[dict[str, Union[str, int]]]]: 223 | """Extract a map from each line number to list of mappings providing information about each token.""" 224 | line_to_tokens = {} 225 | for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens: 226 | line_to_tokens.setdefault(start_line, []).append( 227 | { 228 | "token_type": token_type, 229 | "token": token, 230 | "start_line": start_line, 231 | "start_column": start_column, 232 | "end_line": end_line, 233 | "end_column": end_column, 234 | "line": line, 235 | } 236 | ) 237 | 238 | return line_to_tokens 239 | 240 | 241 | def get_subsequent_assign_lines(source_cls: str) -> tuple[set[int], set[int]]: 242 | """For all multiline assign statements, get the line numbers after the first line in the assignment. 243 | 244 | :param source_cls: The source code of the class. 245 | :return: A set of intermediate line numbers for multiline assign statements and a set of final line numbers. 
246 | """ 247 | # Parse source code using ast (with an if statement to avoid indentation errors) 248 | source = f"if True:\n{textwrap.indent(source_cls, ' ')}" 249 | body = ast.parse(source).body[0] 250 | 251 | # Set up warning message 252 | parse_warning = ( 253 | "Could not parse class source code to extract comments. Comments in the help string may be incorrect." 254 | ) 255 | 256 | # Check for correct parsing 257 | if not isinstance(body, ast.If): 258 | warnings.warn(parse_warning) 259 | return set(), set() 260 | 261 | # Extract if body 262 | if_body = body.body 263 | 264 | # Check for a single body 265 | if len(if_body) != 1: 266 | warnings.warn(parse_warning) 267 | return set(), set() 268 | 269 | # Extract class body 270 | cls_body = if_body[0] 271 | 272 | # Check for a single class definition 273 | if not isinstance(cls_body, ast.ClassDef): 274 | warnings.warn(parse_warning) 275 | return set(), set() 276 | 277 | # Get line numbers of assign statements 278 | intermediate_assign_lines = set() 279 | final_assign_lines = set() 280 | for node in cls_body.body: 281 | if isinstance(node, (ast.Assign, ast.AnnAssign)): 282 | # Check if the end line number is found 283 | if node.end_lineno is None: 284 | warnings.warn(parse_warning) 285 | continue 286 | 287 | # Only consider multiline assign statements 288 | if node.end_lineno > node.lineno: 289 | # Get intermediate line number of assign statement excluding the first line (and minus 1 for the if statement) 290 | intermediate_assign_lines |= set(range(node.lineno, node.end_lineno - 1)) 291 | 292 | # If multiline assign statement, get the line number of the last line (and minus 1 for the if statement) 293 | final_assign_lines.add(node.end_lineno - 1) 294 | 295 | return intermediate_assign_lines, final_assign_lines 296 | 297 | 298 | def get_class_variables(cls: type) -> dict[str, dict[str, str]]: 299 | """Returns a dictionary mapping class variables to their additional information (currently just comments).""" 300 | # 
Get the source code and tokens of the class 301 | source_cls = inspect.getsource(cls) 302 | tokens = tuple(tokenize_source(source_cls)) 303 | 304 | # Get mapping from line number to tokens 305 | line_to_tokens = source_line_to_tokens(tokens) 306 | 307 | # Get class variable column number 308 | class_variable_column = get_class_column(tokens) 309 | 310 | # For all multiline assign statements, get the line numbers after the first line of the assignment 311 | # This is used to avoid identifying comments in multiline assign statements 312 | intermediate_assign_lines, final_assign_lines = get_subsequent_assign_lines(source_cls) 313 | 314 | # Extract class variables 315 | class_variable = None 316 | variable_to_comment = {} 317 | for line, tokens in line_to_tokens.items(): 318 | # If this is the final line of a multiline assign, extract any potential comments 319 | if line in final_assign_lines: 320 | # Find the comment (if it exists) 321 | for token in tokens: 322 | if token["token_type"] == tokenize.COMMENT: 323 | # Leave out "#" and whitespace from comment 324 | variable_to_comment[class_variable]["comment"] = token["token"][1:].strip() 325 | break 326 | continue 327 | 328 | # Skip assign lines after the first line of multiline assign statements 329 | if line in intermediate_assign_lines: 330 | continue 331 | 332 | for i, token in enumerate(tokens): 333 | # Skip whitespace 334 | if token["token"].strip() == "": 335 | continue 336 | 337 | # Extract multiline comments 338 | if ( 339 | class_variable is not None 340 | and token["token_type"] == tokenize.STRING 341 | and token["token"][:1] in {'"', "'"} 342 | ): 343 | sep = " " if variable_to_comment[class_variable]["comment"] else "" 344 | 345 | # Identify the quote character (single or double) 346 | quote_char = token["token"][:1] 347 | 348 | # Identify the number of quote characters at the start of the string 349 | num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char)) 350 | 351 | # Remove the 
number of quote characters at the start of the string and the end of the string 352 | token["token"] = token["token"][num_quote_chars:-num_quote_chars] 353 | 354 | # Remove the unicode escape sequences (e.g. "\"") 355 | token["token"] = bytes(token["token"], encoding="ascii").decode("unicode-escape") 356 | 357 | # Add the token to the comment, stripping whitespace 358 | variable_to_comment[class_variable]["comment"] += sep + token["token"].strip() 359 | 360 | # Match class variable 361 | class_variable = None 362 | if ( 363 | token["token_type"] == tokenize.NAME 364 | and token["start_column"] == class_variable_column 365 | and len(tokens) > i 366 | and tokens[i + 1]["token"] in ["=", ":"] 367 | ): 368 | 369 | class_variable = token["token"] 370 | variable_to_comment[class_variable] = {"comment": ""} 371 | 372 | # Find the comment (if it exists) 373 | for j in range(i + 1, len(tokens)): 374 | if tokens[j]["token_type"] == tokenize.COMMENT: 375 | # Leave out "#" and whitespace from comment 376 | variable_to_comment[class_variable]["comment"] = tokens[j]["token"][1:].strip() 377 | break 378 | 379 | break 380 | 381 | return variable_to_comment 382 | 383 | 384 | def get_literals(literal: Literal, variable: str) -> tuple[Callable[[str], Any], list[type]]: 385 | """Extracts the values from a Literal type and ensures that the values are all primitive types.""" 386 | literals = list(get_args(literal)) 387 | 388 | if not all(isinstance(literal, PRIMITIVES) for literal in literals): 389 | raise ArgumentTypeError( 390 | f'The type for variable "{variable}" contains a literal' 391 | f"of a non-primitive type e.g. (str, int, float, bool).\n" 392 | f"Currently only primitive-typed literals are supported." 
393 | ) 394 | 395 | str_to_literal = {str(literal): literal for literal in literals} 396 | 397 | if len(literals) != len(str_to_literal): 398 | raise ArgumentTypeError("All literals must have unique string representations") 399 | 400 | def var_type(arg: str) -> Any: 401 | if arg not in str_to_literal: 402 | raise ArgumentTypeError(f'Value for variable "{variable}" must be one of {literals}.') 403 | 404 | return str_to_literal[arg] 405 | 406 | return var_type, literals 407 | 408 | 409 | def boolean_type(flag_value: str) -> bool: 410 | """Convert a string to a boolean if it is a prefix of 'True' or 'False' (case insensitive) or is '1' or '0'.""" 411 | if "true".startswith(flag_value.lower()) or flag_value == "1": 412 | return True 413 | if "false".startswith(flag_value.lower()) or flag_value == "0": 414 | return False 415 | raise ArgumentTypeError('Value has to be a prefix of "True" or "False" (case insensitive) or "1" or "0".') 416 | 417 | 418 | class TupleTypeEnforcer: 419 | """The type argument to argparse for checking and applying types to Tuples.""" 420 | 421 | def __init__(self, types: list[type], loop: bool = False): 422 | self.types = [boolean_type if t == bool else t for t in types] 423 | self.loop = loop 424 | self.index = 0 425 | 426 | def __call__(self, arg: str) -> Any: 427 | arg = self.types[self.index](arg) 428 | self.index += 1 429 | 430 | if self.loop: 431 | self.index %= len(self.types) 432 | 433 | return arg 434 | 435 | 436 | class MockTuple: 437 | """Mock of a tuple needed to prevent JSON encoding tuples as lists.""" 438 | 439 | def __init__(self, _tuple: tuple) -> None: 440 | self.tuple = _tuple 441 | 442 | 443 | def _nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any: 444 | """Replaces any instance (including instances within lists, tuple, dict) of find_type with an instance of replace_type. 445 | 446 | Note: Tuples, lists, and dicts are NOT modified in place. 
447 | Note: Does NOT do a nested search through objects besides tuples, lists, and dicts (e.g. sets). 448 | 449 | :param obj: The object to modify by replacing find_type instances with replace_type instances. 450 | :param find_type: The type to find in obj. 451 | :param replace_type: The type used to replace find_type in obj. 452 | :return: A version of obj with all instances of find_type replaced by replace_type 453 | """ 454 | if isinstance(obj, tuple): 455 | obj = tuple(_nested_replace_type(item, find_type, replace_type) for item in obj) 456 | 457 | elif isinstance(obj, list): 458 | obj = [_nested_replace_type(item, find_type, replace_type) for item in obj] 459 | 460 | elif isinstance(obj, dict): 461 | obj = { 462 | _nested_replace_type(key, find_type, replace_type): _nested_replace_type(value, find_type, replace_type) 463 | for key, value in obj.items() 464 | } 465 | 466 | if isinstance(obj, find_type): 467 | obj = replace_type(obj) 468 | 469 | return obj 470 | 471 | 472 | def define_python_object_encoder(skip_unpicklable: bool = False) -> "PythonObjectEncoder": # noqa F821 473 | class PythonObjectEncoder(JSONEncoder): 474 | """Stores parameters that are not JSON serializable as pickle dumps.
475 | 476 | See: https://stackoverflow.com/a/36252257 477 | """ 478 | 479 | def iterencode(self, o: Any, _one_shot: bool = False) -> Iterator[str]: 480 | o = _nested_replace_type(o, tuple, MockTuple) 481 | return super(PythonObjectEncoder, self).iterencode(o, _one_shot) 482 | 483 | def default(self, obj: Any) -> Any: 484 | if isinstance(obj, set): 485 | return {"_type": "set", "_value": list(obj)} 486 | elif isinstance(obj, MockTuple): 487 | return {"_type": "tuple", "_value": list(obj.tuple)} 488 | 489 | try: 490 | return { 491 | "_type": f"python_object (type = {obj.__class__.__name__})", 492 | "_value": b64encode(pickle.dumps(obj)).decode("utf-8"), 493 | "_string": str(obj), 494 | } 495 | except (pickle.PicklingError, TypeError, AttributeError) as e: 496 | if not skip_unpicklable: 497 | raise ValueError( 498 | f"Could not pickle this object: Failed with exception {e}\n" 499 | f"If you would like to ignore unpicklable attributes set " 500 | f"skip_unpickleable = True in save." 501 | ) 502 | else: 503 | return {"_type": f"unpicklable_object {obj.__class__.__name__}", "_value": None} 504 | 505 | return PythonObjectEncoder 506 | 507 | 508 | class UnpicklableObject: 509 | """A class that serves as a placeholder for an object that could not be pickled.""" 510 | 511 | def __eq__(self, other): 512 | return isinstance(other, UnpicklableObject) 513 | 514 | 515 | def as_python_object(dct: Any) -> Any: 516 | """The hooks that allow a parameter that is not JSON serializable to be loaded. 
517 | 518 | See: https://stackoverflow.com/a/36252257 519 | """ 520 | if "_type" in dct and "_value" in dct: 521 | _type, value = dct["_type"], dct["_value"] 522 | 523 | if _type == "tuple": 524 | return tuple(value) 525 | 526 | elif _type == "set": 527 | return set(value) 528 | 529 | elif _type.startswith("python_object"): 530 | return pickle.loads(b64decode(value.encode("utf-8"))) 531 | 532 | elif _type.startswith("unpicklable_object"): 533 | return UnpicklableObject() 534 | 535 | else: 536 | raise ArgumentTypeError(f'Special type "{_type}" not supported for JSON loading.') 537 | 538 | return dct 539 | 540 | 541 | def enforce_reproducibility( 542 | saved_reproducibility_data: Optional[dict[str, str]], current_reproducibility_data: dict[str, str], path: PathLike 543 | ) -> None: 544 | """Checks if reproducibility has failed and raises the appropriate error. 545 | 546 | :param saved_reproducibility_data: Reproducibility information loaded from a saved file. 547 | :param current_reproducibility_data: Reproducibility information from the current object. 548 | :param path: The path name of the file that is being loaded. 549 | """ 550 | no_reproducibility_message = "Reproducibility not guaranteed" 551 | 552 | if saved_reproducibility_data is None: 553 | raise ValueError( 554 | f"{no_reproducibility_message}: Could not find reproducibility " 555 | f'information in args loaded from "{path}".' 556 | ) 557 | 558 | if "git_url" not in saved_reproducibility_data: 559 | raise ValueError(f"{no_reproducibility_message}: Could not find " f'git url in args loaded from "{path}".') 560 | 561 | if "git_url" not in current_reproducibility_data: 562 | raise ValueError(f"{no_reproducibility_message}: Could not find " f"git url in current args.") 563 | 564 | if saved_reproducibility_data["git_url"] != current_reproducibility_data["git_url"]: 565 | raise ValueError( 566 | f"{no_reproducibility_message}: Differing git url/hash " 567 | f'between current args and args loaded from "{path}".' 
568 | ) 569 | 570 | if saved_reproducibility_data["git_has_uncommitted_changes"]: 571 | raise ValueError(f"{no_reproducibility_message}: Uncommitted changes " f'in args loaded from "{path}".') 572 | 573 | if current_reproducibility_data["git_has_uncommitted_changes"]: 574 | raise ValueError(f"{no_reproducibility_message}: Uncommitted changes " f"in current args.") 575 | 576 | 577 | # TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 and 3.10 578 | # https://github.com/ilevkivskyi/typing_inspect/issues/64 579 | # https://github.com/ilevkivskyi/typing_inspect/issues/65 580 | def get_origin(tp: Any) -> Any: 581 | """Same as typing_inspect.get_origin but fixes unparameterized generic types like Set.""" 582 | origin = typing_inspect_get_origin(tp) 583 | 584 | if origin is None: 585 | origin = tp 586 | 587 | if sys.version_info >= (3, 10) and isinstance(origin, UnionType): 588 | origin = UnionType 589 | 590 | return origin 591 | 592 | 593 | # TODO: remove this once typing_inspect.get_args is fixed for Python 3.10 union types 594 | def get_args(tp: Any) -> tuple[type, ...]: 595 | """Same as typing_inspect.get_args but fixes Python 3.10 union types.""" 596 | if sys.version_info >= (3, 10) and isinstance(tp, UnionType): 597 | return tp.__args__ 598 | 599 | return typing_inspect_get_args(tp) 600 | -------------------------------------------------------------------------------- /tests/test_actions.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal 2 | import unittest 3 | from unittest import TestCase 4 | 5 | from tap import Tap 6 | 7 | 8 | class TestArgparseActions(TestCase): 9 | def test_actions_store_const(self): 10 | class StoreConstTap(Tap): 11 | def configure(self): 12 | self.add_argument("--sum", dest="accumulate", action="store_const", const=sum, default=max) 13 | 14 | args = StoreConstTap().parse_args([]) 15 | self.assertFalse(hasattr(args, "sum")) 16 | 
self.assertEqual(args.accumulate, max) 17 | self.assertEqual(args.as_dict(), {"accumulate": max}) 18 | 19 | args = StoreConstTap().parse_args(["--sum"]) 20 | self.assertFalse(hasattr(args, "sum")) 21 | self.assertEqual(args.accumulate, sum) 22 | self.assertEqual(args.as_dict(), {"accumulate": sum}) 23 | 24 | def test_actions_store_true_default_true(self): 25 | class StoreTrueDefaultTrueTap(Tap): 26 | foobar: bool = True 27 | 28 | def configure(self): 29 | self.add_argument("--foobar", action="store_true") 30 | 31 | args = StoreTrueDefaultTrueTap().parse_args([]) 32 | self.assertTrue(args.foobar) 33 | 34 | args = StoreTrueDefaultTrueTap().parse_args(["--foobar"]) 35 | self.assertTrue(args.foobar) 36 | 37 | def test_actions_store_true_default_false(self): 38 | class StoreTrueDefaultFalseTap(Tap): 39 | foobar: bool = False 40 | 41 | def configure(self): 42 | self.add_argument("--foobar", action="store_true") 43 | 44 | args = StoreTrueDefaultFalseTap().parse_args([]) 45 | self.assertFalse(args.foobar) 46 | 47 | args = StoreTrueDefaultFalseTap().parse_args(["--foobar"]) 48 | self.assertTrue(args.foobar) 49 | 50 | def test_actions_store_false_default_true(self): 51 | class StoreFalseDefaultTrueTap(Tap): 52 | foobar: bool = True 53 | 54 | def configure(self): 55 | self.add_argument("--foobar", action="store_false") 56 | 57 | args = StoreFalseDefaultTrueTap().parse_args([]) 58 | self.assertTrue(args.foobar) 59 | 60 | args = StoreFalseDefaultTrueTap().parse_args(["--foobar"]) 61 | self.assertFalse(args.foobar) 62 | 63 | def test_actions_store_false_default_false(self): 64 | class StoreFalseDefaultFalseTap(Tap): 65 | foobar: bool = False 66 | 67 | def configure(self): 68 | self.add_argument("--foobar", action="store_false") 69 | 70 | args = StoreFalseDefaultFalseTap().parse_args([]) 71 | self.assertFalse(args.foobar) 72 | 73 | args = StoreFalseDefaultFalseTap().parse_args(["--foobar"]) 74 | self.assertFalse(args.foobar) 75 | 76 | def test_actions_append_list(self): 77 | 
class AppendListTap(Tap): 78 | arg: List = ["what", "is"] 79 | 80 | def configure(self): 81 | self.add_argument("--arg", action="append") 82 | 83 | args = AppendListTap().parse_args([]) 84 | self.assertEqual(args.arg, ["what", "is"]) 85 | 86 | args = AppendListTap().parse_args("--arg up --arg today".split()) 87 | self.assertEqual(args.arg, "what is up today".split()) 88 | 89 | def test_actions_append_list_int(self): 90 | class AppendListIntTap(Tap): 91 | arg: List[int] = [1, 2] 92 | 93 | def configure(self): 94 | self.add_argument("--arg", action="append") 95 | 96 | args = AppendListIntTap().parse_args("--arg 3 --arg 4".split()) 97 | self.assertEqual(args.arg, [1, 2, 3, 4]) 98 | 99 | def test_actions_append_list_literal(self): 100 | class AppendListLiteralTap(Tap): 101 | arg: List[Literal["what", "is", "up", "today"]] = ["what", "is"] 102 | 103 | def configure(self): 104 | self.add_argument("--arg", action="append") 105 | 106 | args = AppendListLiteralTap().parse_args("--arg up --arg today".split()) 107 | self.assertEqual(args.arg, "what is up today".split()) 108 | 109 | def test_actions_append_untyped(self): 110 | class AppendListStrTap(Tap): 111 | arg = ["what", "is"] 112 | 113 | def configure(self): 114 | self.add_argument("--arg", action="append") 115 | 116 | args = AppendListStrTap().parse_args([]) 117 | self.assertEqual(args.arg, ["what", "is"]) 118 | 119 | args = AppendListStrTap().parse_args("--arg up --arg today".split()) 120 | self.assertEqual(args.arg, "what is up today".split()) 121 | 122 | def test_actions_append_const(self): 123 | class AppendConstTap(Tap): 124 | arg: List[int] = [1, 2, 3] 125 | 126 | def configure(self): 127 | self.add_argument("--arg", action="append_const", const=7) 128 | 129 | args = AppendConstTap().parse_args([]) 130 | self.assertEqual(args.arg, [1, 2, 3]) 131 | 132 | args = AppendConstTap().parse_args("--arg --arg".split()) 133 | self.assertEqual(args.arg, [1, 2, 3, 7, 7]) 134 | 135 | def test_actions_count(self): 136 | class 
CountTap(Tap): 137 | arg = 7 138 | 139 | def configure(self): 140 | self.add_argument("--arg", "-a", action="count") 141 | 142 | args = CountTap().parse_args([]) 143 | self.assertEqual(args.arg, 7) 144 | 145 | args = CountTap().parse_args("-aaa --arg".split()) 146 | self.assertEqual(args.arg, 11) 147 | 148 | def test_actions_int_count(self): 149 | class CountIntTap(Tap): 150 | arg: int = 7 151 | 152 | def configure(self): 153 | self.add_argument("--arg", "-a", action="count") 154 | 155 | args = CountIntTap().parse_args([]) 156 | self.assertEqual(args.arg, 7) 157 | 158 | args = CountIntTap().parse_args("-aaa --arg".split()) 159 | self.assertEqual(args.arg, 11) 160 | 161 | def test_actions_version(self): 162 | class VersionTap(Tap): 163 | def configure(self): 164 | self.add_argument("--version", action="version", version="2.0") 165 | 166 | # Ensure that nothing breaks without version flag 167 | VersionTap().parse_args([]) 168 | 169 | # TODO: With version flag testing fails, but manual tests work 170 | # tried redirecting stderr using unittest.mock.patch 171 | # VersionTap().parse_args(['--version']) 172 | 173 | def test_actions_extend(self): 174 | class ExtendTap(Tap): 175 | arg = [1, 2] 176 | 177 | def configure(self): 178 | self.add_argument("--arg", nargs="+", action="extend") 179 | 180 | args = ExtendTap().parse_args([]) 181 | self.assertEqual(args.arg, [1, 2]) 182 | 183 | args = ExtendTap().parse_args("--arg a b --arg a --arg c d".split()) 184 | self.assertEqual(args.arg, [1, 2] + "a b a c d".split()) 185 | 186 | def test_actions_extend_list(self): 187 | class ExtendListTap(Tap): 188 | arg: List = ["hi"] 189 | 190 | def configure(self): 191 | self.add_argument("--arg", action="extend") 192 | 193 | args = ExtendListTap().parse_args("--arg yo yo --arg yoyo --arg yo yo".split()) 194 | self.assertEqual(args.arg, "hi yo yo yoyo yo yo".split()) 195 | 196 | def test_actions_extend_list_int(self): 197 | class ExtendListIntTap(Tap): 198 | arg: List[int] = [0] 199 | 200 | 
def configure(self): 201 | self.add_argument("--arg", action="extend") 202 | 203 | args = ExtendListIntTap().parse_args("--arg 1 2 --arg 3 --arg 4 5".split()) 204 | self.assertEqual(args.arg, [0, 1, 2, 3, 4, 5]) 205 | 206 | def test_positional_default(self): 207 | class PositionalDefault(Tap): 208 | arg: str 209 | 210 | def configure(self): 211 | self.add_argument("arg") 212 | 213 | help_regex = r".*positional arguments:\n.*arg\s*\(str, required\).*" 214 | help_text = PositionalDefault().format_help() 215 | self.assertRegex(help_text, help_regex) 216 | 217 | 218 | if __name__ == "__main__": 219 | unittest.main() 220 | -------------------------------------------------------------------------------- /tests/test_load_config_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tempfile import TemporaryDirectory 4 | import unittest 5 | from unittest import TestCase 6 | 7 | from tap import Tap 8 | 9 | 10 | # Suppress prints from SystemExit 11 | class DevNull: 12 | def write(self, msg): 13 | pass 14 | 15 | 16 | sys.stderr = DevNull() 17 | 18 | 19 | class LoadConfigFilesTests(TestCase): 20 | def test_file_does_not_exist(self) -> None: 21 | class EmptyTap(Tap): 22 | pass 23 | 24 | with self.assertRaises(FileNotFoundError): 25 | EmptyTap(config_files=["nope"]).parse_args([]) 26 | 27 | def test_single_config(self) -> None: 28 | class SimpleTap(Tap): 29 | a: int 30 | b: str = "b" 31 | 32 | with TemporaryDirectory() as temp_dir: 33 | fname = os.path.join(temp_dir, "config.txt") 34 | 35 | with open(fname, "w") as f: 36 | f.write("--a 1") 37 | 38 | args = SimpleTap(config_files=[fname]).parse_args([]) 39 | 40 | self.assertEqual(args.a, 1) 41 | self.assertEqual(args.b, "b") 42 | 43 | def test_single_config_overwriting(self) -> None: 44 | class SimpleOverwritingTap(Tap): 45 | a: int 46 | b: str = "b" 47 | 48 | with TemporaryDirectory() as temp_dir: 49 | fname = os.path.join(temp_dir, "config.txt") 50 | 51 | 
with open(fname, "w") as f: 52 | f.write("--a 1 --b two") 53 | 54 | args = SimpleOverwritingTap(config_files=[fname]).parse_args("--a 2".split()) 55 | 56 | self.assertEqual(args.a, 2) 57 | self.assertEqual(args.b, "two") 58 | 59 | def test_single_config_known_only(self) -> None: 60 | class KnownOnlyTap(Tap): 61 | a: int 62 | b: str = "b" 63 | 64 | with TemporaryDirectory() as temp_dir: 65 | fname = os.path.join(temp_dir, "config.txt") 66 | 67 | with open(fname, "w") as f: 68 | f.write("--a 1 --c seeNothing") 69 | 70 | args = KnownOnlyTap(config_files=[fname]).parse_args([], known_only=True) 71 | 72 | self.assertEqual(args.a, 1) 73 | self.assertEqual(args.b, "b") 74 | self.assertEqual(args.extra_args, ["--c", "seeNothing"]) 75 | 76 | def test_single_config_required_still_required(self) -> None: 77 | class KnownOnlyTap(Tap): 78 | a: int 79 | b: str = "b" 80 | 81 | with TemporaryDirectory() as temp_dir, self.assertRaises(SystemExit): 82 | fname = os.path.join(temp_dir, "config.txt") 83 | 84 | with open(fname, "w") as f: 85 | f.write("--b fore") 86 | 87 | KnownOnlyTap(config_files=[fname]).parse_args([]) 88 | 89 | def test_multiple_configs(self) -> None: 90 | class MultipleTap(Tap): 91 | a: int 92 | b: str = "b" 93 | 94 | with TemporaryDirectory() as temp_dir: 95 | fname1, fname2 = os.path.join(temp_dir, "config1.txt"), os.path.join(temp_dir, "config2.txt") 96 | 97 | with open(fname1, "w") as f1, open(fname2, "w") as f2: 98 | f1.write("--b two") 99 | f2.write("--a 1") 100 | 101 | args = MultipleTap(config_files=[fname1, fname2]).parse_args([]) 102 | 103 | self.assertEqual(args.a, 1) 104 | self.assertEqual(args.b, "two") 105 | 106 | def test_multiple_configs_overwriting(self) -> None: 107 | class MultipleOverwritingTap(Tap): 108 | a: int 109 | b: str = "b" 110 | c: str = "c" 111 | 112 | with TemporaryDirectory() as temp_dir: 113 | fname1, fname2 = os.path.join(temp_dir, "config1.txt"), os.path.join(temp_dir, "config2.txt") 114 | 115 | with open(fname1, "w") as f1, 
open(fname2, "w") as f2: 116 | f1.write("--a 1 --b two") 117 | f2.write("--a 2 --c see") 118 | 119 | args = MultipleOverwritingTap(config_files=[fname1, fname2]).parse_args("--b four".split()) 120 | 121 | self.assertEqual(args.a, 2) 122 | self.assertEqual(args.b, "four") 123 | self.assertEqual(args.c, "see") 124 | 125 | def test_junk_config(self) -> None: 126 | class JunkConfigTap(Tap): 127 | a: int 128 | b: str = "b" 129 | 130 | with TemporaryDirectory() as temp_dir, self.assertRaises(SystemExit): 131 | fname = os.path.join(temp_dir, "config.txt") 132 | 133 | with open(fname, "w") as f: 134 | f.write("is not a file that can reasonably be parsed") 135 | 136 | JunkConfigTap(config_files=[fname]).parse_args([]) 137 | 138 | def test_shlex_config(self) -> None: 139 | class ShlexConfigTap(Tap): 140 | a: int 141 | b: str 142 | 143 | with TemporaryDirectory() as temp_dir: 144 | fname = os.path.join(temp_dir, "config.txt") 145 | 146 | with open(fname, "w") as f: 147 | f.write('--a 21 # Important arg value\n\n# Multi-word quoted string\n--b "two three four"') 148 | 149 | args = ShlexConfigTap(config_files=[fname]).parse_args([]) 150 | 151 | self.assertEqual(args.a, 21) 152 | self.assertEqual(args.b, "two three four") 153 | 154 | 155 | if __name__ == "__main__": 156 | unittest.main() 157 | -------------------------------------------------------------------------------- /tests/test_subparser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentError 2 | import sys 3 | from typing import Literal, Union 4 | import unittest 5 | from unittest import TestCase 6 | 7 | from tap import Tap 8 | 9 | 10 | # Suppress prints from SystemExit 11 | class DevNull: 12 | def write(self, msg): 13 | pass 14 | 15 | 16 | sys.stderr = DevNull() 17 | 18 | 19 | class TestSubparser(TestCase): 20 | def test_subparser_documentation_example(self): 21 | class SubparserA(Tap): 22 | bar: int # bar help 23 | 24 | class SubparserB(Tap): 25 | baz: 
Literal["X", "Y", "Z"] # baz help 26 | 27 | class Args(Tap): 28 | foo: bool = False # foo help 29 | 30 | def configure(self): 31 | self.add_subparsers(help="sub-command help") 32 | self.add_subparser("a", SubparserA, help="a help") 33 | self.add_subparser("b", SubparserB, help="b help") 34 | 35 | args = Args().parse_args([]) 36 | self.assertFalse(args.foo) 37 | self.assertFalse(hasattr(args, "bar")) 38 | self.assertFalse(hasattr(args, "baz")) 39 | 40 | args = Args().parse_args(["--foo"]) 41 | self.assertTrue(args.foo) 42 | self.assertFalse(hasattr(args, "bar")) 43 | self.assertFalse(hasattr(args, "baz")) 44 | 45 | args = Args().parse_args("a --bar 1".split()) 46 | self.assertFalse(args.foo) 47 | self.assertEqual(args.bar, 1) 48 | self.assertFalse(hasattr(args, "baz")) 49 | 50 | args = Args().parse_args("--foo b --baz X".split()) 51 | self.assertTrue(args.foo) 52 | self.assertFalse(hasattr(args, "bar")) 53 | self.assertEqual(args.baz, "X") 54 | 55 | with self.assertRaises(SystemExit): 56 | Args().parse_args("--baz X --foo b".split()) 57 | 58 | with self.assertRaises(SystemExit): 59 | Args().parse_args("b --baz X --foo".split()) 60 | 61 | with self.assertRaises(SystemExit): 62 | Args().parse_args("--foo a --bar 1 b --baz X".split()) 63 | 64 | def test_name_collision(self): 65 | class SubparserA(Tap): 66 | a: int 67 | 68 | class Args(Tap): 69 | foo: bool = False 70 | 71 | def configure(self): 72 | self.add_subparsers(help="sub-command help") 73 | self.add_subparser("a", SubparserA, help="a help") 74 | 75 | args = Args().parse_args("a --a 1".split()) 76 | self.assertFalse(args.foo) 77 | self.assertEqual(args.a, 1) 78 | 79 | def test_name_overriding(self): 80 | class SubparserA(Tap): 81 | foo: int 82 | 83 | class Args(Tap): 84 | foo: bool = False 85 | 86 | def configure(self): 87 | self.add_subparsers(help="sub-command help") 88 | self.add_subparser("a", SubparserA) 89 | 90 | args = Args().parse_args(["--foo"]) 91 | self.assertTrue(args.foo) 92 | 93 | args = 
Args().parse_args("a --foo 2".split()) 94 | self.assertEqual(args.foo, 2) 95 | 96 | args = Args().parse_args("--foo a --foo 2".split()) 97 | self.assertEqual(args.foo, 2) 98 | 99 | def test_add_subparser_twice(self): 100 | class SubparserA(Tap): 101 | bar: int 102 | 103 | class SubparserB(Tap): 104 | baz: int 105 | 106 | class Args(Tap): 107 | foo: bool = False 108 | 109 | def configure(self): 110 | self.add_subparser("a", SubparserB) 111 | self.add_subparser("a", SubparserA) 112 | 113 | if sys.version_info >= (3, 11): 114 | with self.assertRaises(ArgumentError): 115 | Args().parse_args([]) 116 | else: 117 | args = Args().parse_args("a --bar 2".split()) 118 | self.assertFalse(args.foo) 119 | self.assertEqual(args.bar, 2) 120 | self.assertFalse(hasattr(args, "baz")) 121 | 122 | with self.assertRaises(SystemExit): 123 | Args().parse_args("a --baz 2".split()) 124 | 125 | def test_add_subparsers_twice(self): 126 | class SubparserA(Tap): 127 | a: int 128 | 129 | class Args(Tap): 130 | foo: bool = False 131 | 132 | def configure(self): 133 | self.add_subparser("a", SubparserA) 134 | self.add_subparsers(help="sub-command1 help") 135 | self.add_subparsers(help="sub-command2 help") 136 | 137 | if sys.version_info >= (3, 12, 5): 138 | with self.assertRaises(ArgumentError): 139 | Args().parse_args([]) 140 | else: 141 | with self.assertRaises(SystemExit): 142 | Args().parse_args([]) 143 | 144 | def test_add_subparsers_with_add_argument(self): 145 | class SubparserA(Tap): 146 | for_sure: bool = False 147 | 148 | class Args(Tap): 149 | foo: bool = False 150 | bar: int = 1 151 | 152 | def configure(self): 153 | self.add_argument("--bar", "-ib") 154 | self.add_subparser("is_terrible", SubparserA) 155 | self.add_argument("--foo", "-m") 156 | 157 | args = Args().parse_args("-ib 0 -m is_terrible --for_sure".split()) 158 | self.assertTrue(args.foo) 159 | self.assertEqual(args.bar, 0) 160 | self.assertTrue(args.for_sure) 161 | 162 | def test_add_subsubparsers(self): 163 | class 
SubSubparserB(Tap): 164 | baz: bool = False 165 | 166 | class SubparserA(Tap): 167 | biz: bool = False 168 | 169 | def configure(self): 170 | self.add_subparser("b", SubSubparserB) 171 | 172 | class SubparserB(Tap): 173 | blaz: bool = False 174 | 175 | class Args(Tap): 176 | foo: bool = False 177 | 178 | def configure(self): 179 | self.add_subparser("a", SubparserA) 180 | self.add_subparser("b", SubparserB) 181 | 182 | args = Args().parse_args("b --blaz".split()) 183 | self.assertFalse(args.foo) 184 | self.assertFalse(hasattr(args, "baz")) 185 | self.assertFalse(hasattr(args, "biz")) 186 | self.assertTrue(args.blaz) 187 | 188 | args = Args().parse_args("a --biz".split()) 189 | self.assertFalse(args.foo) 190 | self.assertTrue(args.biz) 191 | self.assertFalse(hasattr(args, "baz")) 192 | self.assertFalse(hasattr(args, "blaz")) 193 | 194 | args = Args().parse_args("a --biz b --baz".split()) 195 | self.assertFalse(args.foo) 196 | self.assertTrue(args.biz) 197 | self.assertFalse(hasattr(args, "blaz")) 198 | self.assertTrue(args.baz) 199 | 200 | with self.assertRaises(SystemExit): 201 | Args().parse_args("b a".split()) 202 | 203 | def test_subparser_underscores_to_dashes(self): 204 | class AddProposal(Tap): 205 | proposal_id: int 206 | 207 | class Arguments(Tap): 208 | def configure(self) -> None: 209 | self.add_subparsers(dest="subparser_name") 210 | 211 | self.add_subparser( 212 | "add-proposal", AddProposal, help="Add a new proposal", 213 | ) 214 | 215 | args_underscores: Union[Arguments, AddProposal] = Arguments(underscores_to_dashes=False).parse_args( 216 | "add-proposal --proposal_id 1".split() 217 | ) 218 | self.assertEqual(args_underscores.proposal_id, 1) 219 | 220 | args_dashes: Union[Arguments, AddProposal] = Arguments(underscores_to_dashes=True).parse_args( 221 | "add-proposal --proposal-id 1".split() 222 | ) 223 | self.assertEqual(args_dashes.proposal_id, 1) 224 | 225 | with self.assertRaises(SystemExit): 226 | 
Arguments(underscores_to_dashes=False).parse_args("add-proposal --proposal-id 1".split()) 227 | 228 | with self.assertRaises(SystemExit): 229 | Arguments(underscores_to_dashes=True).parse_args("add-proposal --proposal_id 1".split()) 230 | 231 | 232 | if __name__ == "__main__": 233 | unittest.main() 234 | -------------------------------------------------------------------------------- /tests/test_to_tap_class.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests `tap.to_tap_class`. 3 | """ 4 | 5 | from contextlib import redirect_stdout, redirect_stderr 6 | import dataclasses 7 | import io 8 | import re 9 | import sys 10 | from typing import Any, Callable, List, Literal, Optional, Type, Union 11 | 12 | from packaging.version import Version 13 | import pytest 14 | 15 | from tap import to_tap_class, Tap 16 | from tap.utils import type_to_str 17 | 18 | 19 | try: 20 | import pydantic 21 | except ModuleNotFoundError: 22 | _IS_PYDANTIC_V1 = None 23 | else: 24 | _IS_PYDANTIC_V1 = Version(pydantic.__version__) < Version("2.0.0") 25 | 26 | 27 | # To properly test the help message, we need to know how argparse formats it. It changed from 3.9 -> 3.10 -> 3.13 28 | _OPTIONS_TITLE = "options" if not sys.version_info < (3, 10) else "optional arguments" 29 | _ARG_LIST_DOTS = "..." 30 | _ARG_WITH_ALIAS = ( 31 | "-arg, --argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME" 32 | if not sys.version_info < (3, 13) 33 | else "-arg ARGUMENT_WITH_REALLY_LONG_NAME, --argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME" 34 | ) 35 | 36 | 37 | @dataclasses.dataclass 38 | class _Args: 39 | """ 40 | These are the arguments which every type of class or function must contain. 
41 | """ 42 | 43 | arg_int: int = dataclasses.field(metadata=dict(description="some integer")) 44 | arg_bool: bool = True 45 | arg_list: Optional[List[str]] = dataclasses.field(default=None, metadata=dict(description="some list of strings")) 46 | 47 | 48 | def _monkeypatch_eq(cls): 49 | """ 50 | Monkey-patches `cls.__eq__` to check that the attribute values are equal to a dataclass representation of them. 51 | """ 52 | 53 | def _equality(self, other: _Args) -> bool: 54 | return _Args(self.arg_int, arg_bool=self.arg_bool, arg_list=self.arg_list) == other 55 | 56 | cls.__eq__ = _equality 57 | return cls 58 | 59 | 60 | # Define a few different classes or functions which all take the same arguments (same by name, annotation, and default 61 | # if not required) 62 | 63 | 64 | def function(arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None) -> _Args: 65 | """ 66 | :param arg_int: some integer 67 | :param arg_list: some list of strings 68 | """ 69 | return _Args(arg_int, arg_bool=arg_bool, arg_list=arg_list) 70 | 71 | 72 | @_monkeypatch_eq 73 | class Class: 74 | def __init__(self, arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None): 75 | """ 76 | :param arg_int: some integer 77 | :param arg_list: some list of strings 78 | """ 79 | self.arg_int = arg_int 80 | self.arg_bool = arg_bool 81 | self.arg_list = arg_list 82 | 83 | 84 | DataclassBuiltin = _Args 85 | 86 | 87 | if _IS_PYDANTIC_V1 is None: 88 | pass # will raise NameError if attempting to use DataclassPydantic or Model later 89 | elif _IS_PYDANTIC_V1: 90 | # For Pydantic v1 data models, we rely on the docstring to get descriptions 91 | 92 | @_monkeypatch_eq 93 | @pydantic.dataclasses.dataclass 94 | class DataclassPydantic: 95 | """ 96 | Dataclass (pydantic v1) 97 | 98 | :param arg_int: some integer 99 | :param arg_list: some list of strings 100 | """ 101 | 102 | arg_int: int 103 | arg_bool: bool = True 104 | arg_list: Optional[List[str]] = None 105 | 106 | @_monkeypatch_eq 
107 | class Model(pydantic.BaseModel): 108 | """ 109 | Pydantic model (pydantic v1) 110 | 111 | :param arg_int: some integer 112 | :param arg_list: some list of strings 113 | """ 114 | 115 | arg_int: int 116 | arg_bool: bool = True 117 | arg_list: Optional[List[str]] = None 118 | 119 | else: 120 | # For pydantic v2 data models, we check the docstring and Field for the description 121 | 122 | @_monkeypatch_eq 123 | @pydantic.dataclasses.dataclass 124 | class DataclassPydantic: 125 | """ 126 | Dataclass (pydantic) 127 | 128 | :param arg_list: some list of strings 129 | """ 130 | 131 | # Mixing field types should be ok 132 | arg_int: int = pydantic.dataclasses.Field(description="some integer") 133 | arg_bool: bool = dataclasses.field(default=True) 134 | arg_list: Optional[List[str]] = pydantic.Field(default=None) 135 | 136 | @_monkeypatch_eq 137 | class Model(pydantic.BaseModel): 138 | """ 139 | Pydantic model 140 | 141 | :param arg_int: some integer 142 | """ 143 | 144 | # Mixing field types should be ok 145 | arg_int: int 146 | arg_bool: bool = dataclasses.field(default=True) 147 | arg_list: Optional[List[str]] = pydantic.dataclasses.Field(default=None, description="some list of strings") 148 | 149 | 150 | @pytest.fixture( 151 | scope="module", 152 | params=[ 153 | function, 154 | Class, 155 | DataclassBuiltin, 156 | DataclassBuiltin( 157 | 1, arg_bool=False, arg_list=["these", "values", "don't", "matter"] 158 | ), # to_tap_class also works on instances of data models. It ignores the attribute values 159 | ] 160 | + ([] if _IS_PYDANTIC_V1 is None else [DataclassPydantic, Model]), 161 | # NOTE: instances of DataclassPydantic and Model can be tested for pydantic v2 but not v1 162 | ) 163 | def class_or_function_(request: pytest.FixtureRequest): 164 | """ 165 | Parametrized class_or_function. 
166 | """ 167 | return request.param 168 | 169 | 170 | # Define some functions which take a class or function and call `tap.to_tap_class` on it to create a `tap.Tap` 171 | # subclass (class, not instance) 172 | 173 | 174 | def subclasser_simple(class_or_function: Any) -> Type[Tap]: 175 | """ 176 | Plain subclass, does nothing extra. 177 | """ 178 | return to_tap_class(class_or_function) 179 | 180 | 181 | def subclasser_complex(class_or_function): 182 | """ 183 | It's conceivable that someone has a data model, but they want to add more arguments or handling when running a 184 | script. 185 | """ 186 | 187 | def to_number(string: str) -> Union[float, int]: 188 | return float(string) if "." in string else int(string) 189 | 190 | class TapSubclass(to_tap_class(class_or_function)): 191 | # You can supply additional arguments here 192 | argument_with_really_long_name: Union[float, int] = 3 193 | "This argument has a long name and will be aliased with a short one" 194 | 195 | def configure(self) -> None: 196 | # You can still add special argument behavior 197 | self.add_argument("-arg", "--argument_with_really_long_name", type=to_number) 198 | 199 | def process_args(self) -> None: 200 | # You can still validate and modify arguments 201 | if self.argument_with_really_long_name > 4: 202 | raise ValueError("argument_with_really_long_name cannot be > 4") 203 | 204 | # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass.
Sorry 205 | if self.arg_bool and self.arg_list is not None: 206 | self.arg_list.append("processed") 207 | 208 | return TapSubclass 209 | 210 | 211 | def subclasser_subparser(class_or_function): 212 | class SubparserA(Tap): 213 | bar: int # bar help 214 | 215 | class SubparserB(Tap): 216 | baz: Literal["X", "Y", "Z"] # baz help 217 | 218 | class TapSubclass(to_tap_class(class_or_function)): 219 | foo: bool = False # foo help 220 | 221 | def configure(self): 222 | self.add_subparsers(help="sub-command help") 223 | self.add_subparser("a", SubparserA, help="a help", description="Description (a)") 224 | self.add_subparser("b", SubparserB, help="b help") 225 | 226 | return TapSubclass 227 | 228 | 229 | # Test that the subclasser parses the args correctly or raises the correct error. 230 | # The subclassers are tested separately b/c the parametrization of args_string_and_arg_to_expected_value depends on the 231 | # subclasser. 232 | # First, some helper functions. 233 | 234 | 235 | def _test_raises_system_exit(tap: Tap, args_string: str) -> str: 236 | is_help = ( 237 | args_string.endswith("-h") 238 | or args_string.endswith("--help") 239 | or " -h " in args_string 240 | or " --help " in args_string 241 | ) 242 | f = io.StringIO() 243 | with redirect_stdout(f) if is_help else redirect_stderr(f): 244 | with pytest.raises(SystemExit): 245 | tap.parse_args(args_string.split()) 246 | 247 | return f.getvalue() 248 | 249 | 250 | def _test_subclasser( 251 | subclasser: Callable[[Any], Type[Tap]], 252 | class_or_function: Any, 253 | args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]], 254 | test_call: bool = True, 255 | ): 256 | """ 257 | Tests that the `subclasser` converts `class_or_function` to a `Tap` class which parses the argument string 258 | correctly. 259 | 260 | Setting `test_call=True` additionally tests that calling the `class_or_function` on the parsed arguments works.
261 | """ 262 | args_string, arg_to_expected_value = args_string_and_arg_to_expected_value 263 | TapSubclass = subclasser(class_or_function) 264 | assert issubclass(TapSubclass, Tap) 265 | tap = TapSubclass(description="Script description") 266 | 267 | if isinstance(arg_to_expected_value, SystemExit): 268 | stderr = _test_raises_system_exit(tap, args_string) 269 | assert re.search(str(arg_to_expected_value), stderr) 270 | elif isinstance(arg_to_expected_value, BaseException): 271 | expected_exception = arg_to_expected_value.__class__ 272 | expected_error_message = str(arg_to_expected_value) or None 273 | with pytest.raises(expected_exception=expected_exception, match=expected_error_message): 274 | args = tap.parse_args(args_string.split()) 275 | else: 276 | # args_string is a valid argument combo 277 | # Test that parsing works correctly 278 | args = tap.parse_args(args_string.split()) 279 | assert arg_to_expected_value == args.as_dict() 280 | if test_call and callable(class_or_function): 281 | result = class_or_function(**args.as_dict()) 282 | assert result == _Args(**arg_to_expected_value) 283 | 284 | 285 | def _test_subclasser_message( 286 | subclasser: Callable[[Any], Type[Tap]], 287 | class_or_function: Any, 288 | message_expected: str, 289 | description: str = "Script description", 290 | args_string: str = "-h", 291 | ): 292 | """ 293 | Tests that:: 294 | 295 | subclasser(class_or_function)(description=description).parse_args(args_string.split()) 296 | 297 | outputs `message_expected` to stdout, ignoring differences in whitespaces/newlines/tabs. 
298 | """ 299 | 300 | def replace_whitespace(string: str) -> str: 301 | return re.sub(r"\s+", " ", string).strip() # FYI this line was written by an LLM 302 | 303 | TapSubclass = subclasser(class_or_function) 304 | tap = TapSubclass(description=description) 305 | message = _test_raises_system_exit(tap, args_string) 306 | # Standardize to ignore trivial differences due to terminal settings 307 | assert replace_whitespace(message) == replace_whitespace(message_expected) 308 | 309 | 310 | # Test subclasser_simple 311 | 312 | 313 | @pytest.mark.parametrize( 314 | "args_string_and_arg_to_expected_value", 315 | [ 316 | ( 317 | "--arg_int 1 --arg_list x y z", 318 | {"arg_int": 1, "arg_bool": True, "arg_list": ["x", "y", "z"]}, 319 | ), 320 | ( 321 | "--arg_int 1 --arg_bool", 322 | {"arg_int": 1, "arg_bool": False, "arg_list": None}, 323 | ), 324 | # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance 325 | ( 326 | "--arg_list x y z --arg_bool", 327 | SystemExit("error: the following arguments are required: --arg_int"), 328 | ), 329 | ( 330 | "--arg_int not_an_int --arg_list x y z --arg_bool", 331 | SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"), 332 | ), 333 | ], 334 | ) 335 | def test_subclasser_simple( 336 | class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]] 337 | ): 338 | _test_subclasser(subclasser_simple, class_or_function_, args_string_and_arg_to_expected_value) 339 | 340 | 341 | def test_subclasser_simple_help_message(class_or_function_: Any): 342 | description = "Script description" 343 | help_message_expected = f""" 344 | usage: pytest --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] 345 | 346 | {description} 347 | 348 | {_OPTIONS_TITLE}: 349 | --arg_int ARG_INT (int, required) some integer 350 | --arg_bool (bool, default=True) 351 | --arg_list [ARG_LIST {_ARG_LIST_DOTS}] 352 |
({type_to_str(Optional[List[str]])}, default=None) some list of strings 353 | -h, --help show this help message and exit 354 | """ 355 | _test_subclasser_message(subclasser_simple, class_or_function_, help_message_expected, description=description) 356 | 357 | 358 | # Test subclasser_complex 359 | 360 | 361 | @pytest.mark.parametrize( 362 | "args_string_and_arg_to_expected_value", 363 | [ 364 | ( 365 | "--arg_int 1 --arg_list x y z", 366 | { 367 | "arg_int": 1, 368 | "arg_bool": True, 369 | "arg_list": ["x", "y", "z", "processed"], 370 | "argument_with_really_long_name": 3, 371 | }, 372 | ), 373 | ( 374 | "--arg_int 1 --arg_list x y z -arg 2", 375 | { 376 | "arg_int": 1, 377 | "arg_bool": True, 378 | "arg_list": ["x", "y", "z", "processed"], 379 | "argument_with_really_long_name": 2, 380 | }, 381 | ), 382 | ( 383 | "--arg_int 1 --arg_bool --argument_with_really_long_name 2.3", 384 | { 385 | "arg_int": 1, 386 | "arg_bool": False, 387 | "arg_list": None, 388 | "argument_with_really_long_name": 2.3, 389 | }, 390 | ), 391 | # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance 392 | ( 393 | "--arg_list x y z --arg_bool", 394 | SystemExit("error: the following arguments are required: --arg_int"), 395 | ), 396 | ( 397 | "--arg_int 1 --arg_list x y z -arg not_a_float_or_int", 398 | SystemExit( 399 | "error: argument -arg/--argument_with_really_long_name: invalid to_number value: 'not_a_float_or_int'" 400 | ), 401 | ), 402 | ( 403 | "--arg_int 1 --arg_list x y z -arg 5", # Wrong value arg (aliases argument_with_really_long_name) 404 | ValueError("argument_with_really_long_name cannot be > 4"), 405 | ), 406 | ], 407 | ) 408 | def test_subclasser_complex( 409 | class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]] 410 | ): 411 | # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args 412 | 
_test_subclasser(subclasser_complex, class_or_function_, args_string_and_arg_to_expected_value, test_call=False) 413 | 414 | 415 | def test_subclasser_complex_help_message(class_or_function_: Any): 416 | description = "Script description" 417 | help_message_expected = f""" 418 | usage: pytest [-arg ARGUMENT_WITH_REALLY_LONG_NAME] --arg_int ARG_INT [--arg_bool] 419 | [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] 420 | 421 | {description} 422 | 423 | {_OPTIONS_TITLE}: 424 | {_ARG_WITH_ALIAS} 425 | (Union[float, int], default=3) This argument has a long name and will be aliased with a short 426 | one 427 | --arg_int ARG_INT (int, required) some integer 428 | --arg_bool (bool, default=True) 429 | --arg_list [ARG_LIST {_ARG_LIST_DOTS}] 430 | ({type_to_str(Optional[List[str]])}, default=None) some list of strings 431 | -h, --help show this help message and exit 432 | """ 433 | _test_subclasser_message(subclasser_complex, class_or_function_, help_message_expected, description=description) 434 | 435 | 436 | # Test subclasser_subparser 437 | 438 | 439 | @pytest.mark.parametrize( 440 | "args_string_and_arg_to_expected_value", 441 | [ 442 | ( 443 | "--arg_int 1", 444 | {"arg_int": 1, "arg_bool": True, "arg_list": None, "foo": False}, 445 | ), 446 | ( 447 | "--arg_int 1 a --bar 2", 448 | {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": False}, 449 | ), 450 | ( 451 | "--arg_int 1 --foo a --bar 2", 452 | {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": True}, 453 | ), 454 | ( 455 | "--arg_int 1 b --baz X", 456 | {"arg_int": 1, "arg_bool": True, "arg_list": None, "baz": "X", "foo": False}, 457 | ), 458 | ( 459 | "--foo --arg_bool --arg_list x y z --arg_int 1 b --baz Y", 460 | {"arg_int": 1, "arg_bool": False, "arg_list": ["x", "y", "z"], "baz": "Y", "foo": True}, 461 | ), 462 | # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance 463 | ( 464 | "a --bar 1", 465 | SystemExit("error: the following 
arguments are required: --arg_int"), 466 | ), 467 | ( 468 | "--arg_int not_an_int a --bar 1", 469 | SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"), 470 | ), 471 | ( 472 | "--arg_int 1 --baz X --foo b", 473 | SystemExit( 474 | r"error: argument \{a,b}: invalid choice: 'X' \(choose from '?a'?, '?b'?\)" 475 | ), 476 | ), 477 | ( 478 | "--arg_int 1 b --baz X --foo", 479 | SystemExit("error: unrecognized arguments: --foo"), 480 | ), 481 | ( 482 | "--arg_int 1 --foo b --baz A", 483 | SystemExit(r"""error: argument --baz: Value for variable "baz" must be one of \['X', 'Y', 'Z']."""), 484 | ), 485 | ], 486 | ) 487 | def test_subclasser_subparser( 488 | class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]] 489 | ): 490 | # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args 491 | _test_subclasser(subclasser_subparser, class_or_function_, args_string_and_arg_to_expected_value, test_call=False) 492 | 493 | 494 | @pytest.mark.parametrize( 495 | "args_string_and_description_and_expected_message", 496 | [ 497 | ( 498 | "-h", 499 | "Script description", 500 | f""" 501 | usage: pytest [--foo] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] 502 | {{a,b}} ... 
503 | 504 | Script description 505 | 506 | positional arguments: 507 | {{a,b}} sub-command help 508 | a a help 509 | b b help 510 | 511 | {_OPTIONS_TITLE}: 512 | --foo (bool, default=False) foo help 513 | --arg_int ARG_INT (int, required) some integer 514 | --arg_bool (bool, default=True) 515 | --arg_list [ARG_LIST {_ARG_LIST_DOTS}] 516 | ({type_to_str(Optional[List[str]])}, default=None) some list of strings 517 | -h, --help show this help message and exit 518 | """, 519 | ), 520 | ( 521 | "a -h", 522 | "Description (a)", 523 | f""" 524 | usage: pytest a --bar BAR [-h] 525 | 526 | Description (a) 527 | 528 | {_OPTIONS_TITLE}: 529 | --bar BAR (int, required) bar help 530 | -h, --help show this help message and exit 531 | """, 532 | ), 533 | ( 534 | "b -h", 535 | "", # no description 536 | f""" 537 | usage: pytest b --baz {{X,Y,Z}} [-h] 538 | 539 | {_OPTIONS_TITLE}: 540 | --baz {{X,Y,Z}} (Literal['X', 'Y', 'Z'], required) baz help 541 | -h, --help show this help message and exit 542 | """, 543 | ), 544 | ], 545 | ) 546 | def test_subclasser_subparser_help_message( 547 | class_or_function_: Any, args_string_and_description_and_expected_message: tuple[str, str] 548 | ): 549 | args_string, description, expected_message = args_string_and_description_and_expected_message 550 | _test_subclasser_message( 551 | subclasser_subparser, class_or_function_, expected_message, description=description, args_string=args_string 552 | ) 553 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentTypeError 2 | import inspect 3 | import json 4 | import os 5 | import subprocess 6 | from tempfile import TemporaryDirectory 7 | from typing import Any, Callable, List, Literal, Dict, Set, Tuple, Union 8 | import unittest 9 | from unittest import TestCase 10 | 11 | from tap.utils import ( 12 | get_class_column, 13 | 
get_class_variables, 14 | GitInfo, 15 | tokenize_source, 16 | type_to_str, 17 | get_literals, 18 | TupleTypeEnforcer, 19 | _nested_replace_type, 20 | define_python_object_encoder, 21 | UnpicklableObject, 22 | as_python_object, 23 | enforce_reproducibility, 24 | ) 25 | 26 | 27 | class GitTests(TestCase): 28 | def setUp(self) -> None: 29 | self.temp_dir = TemporaryDirectory() 30 | self.prev_dir = os.getcwd() 31 | os.chdir(self.temp_dir.name) 32 | subprocess.check_output(["git", "init"]) 33 | self.url = "https://github.com/test_account/test_repo" 34 | subprocess.check_output(["git", "remote", "add", "origin", f"{self.url}.git"]) 35 | subprocess.check_output(["touch", "README.md"]) 36 | subprocess.check_output(["git", "add", "README.md"]) 37 | subprocess.check_output(["git", "commit", "-m", "Initial commit"]) 38 | self.git_info = GitInfo(repo_path=self.temp_dir.name) 39 | 40 | def tearDown(self) -> None: 41 | os.chdir(self.prev_dir) 42 | 43 | # Add permissions to temporary directory to enable cleanup in Windows 44 | for root, dirs, files in os.walk(self.temp_dir.name): 45 | for name in dirs + files: 46 | os.chmod(os.path.join(root, name), 0o777) 47 | 48 | self.temp_dir.cleanup() 49 | 50 | def test_has_git_true(self) -> None: 51 | self.assertTrue(self.git_info.has_git()) 52 | 53 | def test_has_git_false(self) -> None: 54 | with TemporaryDirectory() as temp_dir_no_git: 55 | os.chdir(temp_dir_no_git) 56 | self.git_info.repo_path = temp_dir_no_git 57 | self.assertFalse(self.git_info.has_git()) 58 | self.git_info.repo_path = self.temp_dir.name 59 | os.chdir(self.temp_dir.name) 60 | 61 | def test_get_git_root(self) -> None: 62 | # Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private 63 | self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace("\\", "/"))) 64 | 65 | def test_get_git_root_subdir(self) -> None: 66 | subdir = os.path.join(self.temp_dir.name, "subdir") 67 | os.makedirs(subdir) 68 | 
os.chdir(subdir) 69 | 70 | # Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private 71 | self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace("\\", "/"))) 72 | 73 | os.chdir(self.temp_dir.name) 74 | 75 | def test_get_git_url_https(self) -> None: 76 | self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url) 77 | 78 | def test_get_git_url_https_hash(self) -> None: 79 | url = f"{self.url}/tree/" 80 | self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url) 81 | 82 | def test_get_git_url_ssh(self) -> None: 83 | subprocess.run(["git", "remote", "set-url", "origin", "git@github.com:test_account/test_repo.git"]) 84 | self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url) 85 | 86 | def test_get_git_url_ssh_hash(self) -> None: 87 | subprocess.run(["git", "remote", "set-url", "origin", "git@github.com:test_account/test_repo.git"]) 88 | url = f"{self.url}/tree/" 89 | self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url) 90 | 91 | def test_get_git_url_https_enterprise(self) -> None: 92 | true_url = "https://github.tap.com/test_account/test_repo" 93 | subprocess.run(["git", "remote", "set-url", "origin", f"{true_url}.git"]) 94 | self.assertEqual(self.git_info.get_git_url(commit_hash=False), true_url) 95 | 96 | def test_get_git_url_https_hash_enterprise(self) -> None: 97 | true_url = "https://github.tap.com/test_account/test_repo" 98 | subprocess.run(["git", "remote", "set-url", "origin", f"{true_url}.git"]) 99 | url = f"{true_url}/tree/" 100 | self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url) 101 | 102 | def test_get_git_url_ssh_enterprise(self) -> None: 103 | true_url = "https://github.tap.com/test_account/test_repo" 104 | subprocess.run(["git", "remote", "set-url", "origin", "git@github.tap.com:test_account/test_repo.git"]) 105 | self.assertEqual(self.git_info.get_git_url(commit_hash=False), 
true_url) 106 | 107 | def test_get_git_url_ssh_hash_enterprise(self) -> None: 108 | true_url = "https://github.tap.com/test_account/test_repo" 109 | subprocess.run(["git", "remote", "set-url", "origin", "git@github.tap.com:test_account/test_repo.git"]) 110 | url = f"{true_url}/tree/" 111 | self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url) 112 | 113 | def test_get_git_url_no_remote(self) -> None: 114 | subprocess.run(["git", "remote", "remove", "origin"]) 115 | self.assertEqual(self.git_info.get_git_url(), "") 116 | 117 | def test_get_git_version(self) -> None: 118 | git_version = self.git_info.get_git_version() 119 | self.assertIsInstance(git_version, tuple) 120 | for v in git_version: 121 | self.assertIsInstance(v, int) 122 | 123 | def test_has_uncommitted_changes_false(self) -> None: 124 | self.assertFalse(self.git_info.has_uncommitted_changes()) 125 | 126 | def test_has_uncommited_changes_true(self) -> None: 127 | subprocess.run(["touch", "main.py"]) 128 | self.assertTrue(self.git_info.has_uncommitted_changes()) 129 | 130 | 131 | class TypeToStrTests(TestCase): 132 | def test_type_to_str(self) -> None: 133 | self.assertEqual(type_to_str(str), "str") 134 | self.assertEqual(type_to_str(int), "int") 135 | self.assertEqual(type_to_str(float), "float") 136 | self.assertEqual(type_to_str(bool), "bool") 137 | self.assertEqual(type_to_str(Any), "Any") 138 | self.assertEqual(type_to_str(Callable[[str], str]), "Callable[[str], str]") 139 | self.assertEqual( 140 | type_to_str(Callable[[str, int], Tuple[float, bool]]), "Callable[[str, int], Tuple[float, bool]]" 141 | ) 142 | self.assertEqual(type_to_str(List[int]), "List[int]") 143 | self.assertEqual(type_to_str(List[str]), "List[str]") 144 | self.assertEqual(type_to_str(List[float]), "List[float]") 145 | self.assertEqual(type_to_str(List[bool]), "List[bool]") 146 | self.assertEqual(type_to_str(Set[int]), "Set[int]") 147 | self.assertEqual(type_to_str(Dict[str, int]), "Dict[str, int]") 148 | 
self.assertEqual(type_to_str(Union[List[int], Dict[float, bool]]), "Union[List[int], Dict[float, bool]]") 149 | 150 | 151 | def class_decorator(cls): 152 | return cls 153 | 154 | 155 | class ClassColumnTests(TestCase): 156 | def test_column_simple(self): 157 | class SimpleColumn: 158 | arg = 2 159 | 160 | tokens = tokenize_source(inspect.getsource(SimpleColumn)) 161 | self.assertEqual(get_class_column(tokens), 12) 162 | 163 | def test_column_comment(self): 164 | class CommentColumn: 165 | """hello 166 | there 167 | 168 | 169 | hi 170 | """ 171 | 172 | arg = 2 173 | 174 | tokens = tokenize_source(inspect.getsource(CommentColumn)) 175 | self.assertEqual(get_class_column(tokens), 12) 176 | 177 | def test_column_space(self): 178 | class SpaceColumn: 179 | 180 | arg = 2 181 | 182 | tokens = tokenize_source(inspect.getsource(SpaceColumn)) 183 | self.assertEqual(get_class_column(tokens), 12) 184 | 185 | def test_column_method(self): 186 | class FuncColumn: 187 | def func(self): 188 | pass 189 | 190 | tokens = tokenize_source(inspect.getsource(FuncColumn)) 191 | self.assertEqual(get_class_column(tokens), 12) 192 | 193 | def test_dataclass(self): 194 | @class_decorator 195 | class DataclassColumn: 196 | arg: int = 5 197 | 198 | tokens = tokenize_source(inspect.getsource(DataclassColumn)) 199 | self.assertEqual(get_class_column(tokens), 12) 200 | 201 | def test_dataclass_method(self): 202 | def wrapper(f): 203 | pass 204 | 205 | @class_decorator 206 | class DataclassColumn: 207 | @wrapper 208 | def func(self): 209 | pass 210 | 211 | tokens = tokenize_source(inspect.getsource(DataclassColumn)) 212 | self.assertEqual(get_class_column(tokens), 12) 213 | 214 | 215 | class ClassVariableTests(TestCase): 216 | def test_no_variables(self): 217 | class NoVariables: 218 | pass 219 | 220 | self.assertEqual(get_class_variables(NoVariables), {}) 221 | 222 | def test_one_variable(self): 223 | class OneVariable: 224 | arg = 2 225 | 226 | class_variables = {"arg": {"comment": ""}} 227 | 
self.assertEqual(get_class_variables(OneVariable), class_variables) 228 | 229 | def test_multiple_variable(self): 230 | class MultiVariable: 231 | arg_1 = 2 232 | arg_2 = 3 233 | 234 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": ""}} 235 | self.assertEqual(get_class_variables(MultiVariable), class_variables) 236 | 237 | def test_typed_variables(self): 238 | class TypedVariable: 239 | arg_1: str 240 | arg_2: int = 3 241 | 242 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": ""}} 243 | self.assertEqual(get_class_variables(TypedVariable), class_variables) 244 | 245 | def test_separated_variables(self): 246 | class SeparatedVariable: 247 | """Comment""" 248 | 249 | arg_1: str 250 | 251 | # Hello 252 | def func(self): 253 | pass 254 | 255 | arg_2: int = 3 256 | """More comment""" 257 | 258 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": "More comment"}} 259 | self.assertEqual(get_class_variables(SeparatedVariable), class_variables) 260 | 261 | def test_commented_variables(self): 262 | class CommentedVariable: 263 | """Comment""" 264 | 265 | arg_1: str # Arg 1 comment 266 | 267 | # Hello 268 | def func(self): 269 | pass 270 | 271 | arg_2: int = 3 # Arg 2 comment 272 | arg_3: Dict[str, int] # noqa E203,E262 Poorly formatted comment 273 | """More comment""" 274 | 275 | class_variables = { 276 | "arg_1": {"comment": "Arg 1 comment"}, 277 | "arg_2": {"comment": "Arg 2 comment"}, 278 | "arg_3": {"comment": "noqa E203,E262 Poorly formatted comment More comment"}, 279 | } 280 | self.assertEqual(get_class_variables(CommentedVariable), class_variables) 281 | 282 | def test_bad_spacing_multiline(self): 283 | class TrickyMultiline: 284 | """This is really difficult 285 | 286 | so 287 | so very difficult 288 | """ 289 | 290 | foo: str = "my" # Header line 291 | 292 | """ Footer 293 | T 294 | A 295 | P 296 | 297 | multi 298 | line!! 
299 | """ 300 | 301 | class_variables = {} 302 | comment = "Header line Footer\nT\n A\n P\n\n multi\n line!!" 303 | class_variables["foo"] = {"comment": comment} 304 | self.assertEqual(get_class_variables(TrickyMultiline), class_variables) 305 | 306 | def test_triple_quote_multiline(self): 307 | class TripleQuoteMultiline: 308 | bar: int = 0 309 | """biz baz""" 310 | 311 | hi: str 312 | """Hello there""" 313 | 314 | class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}} 315 | self.assertEqual(get_class_variables(TripleQuoteMultiline), class_variables) 316 | 317 | def test_comments_with_quotes(self): 318 | class MultiquoteMultiline: 319 | bar: int = 0 320 | "''biz baz'" 321 | 322 | hi: str 323 | '"Hello there""' 324 | 325 | class_variables = {} 326 | class_variables["bar"] = {"comment": "''biz baz'"} 327 | class_variables["hi"] = {"comment": '"Hello there""'} 328 | self.assertEqual(get_class_variables(MultiquoteMultiline), class_variables) 329 | 330 | def test_multiline_argument(self): 331 | class MultilineArgument: 332 | bar: str = "This is a multiline argument" " that should not be included in the docstring" 333 | """biz baz""" 334 | 335 | class_variables = {"bar": {"comment": "biz baz"}} 336 | self.assertEqual(get_class_variables(MultilineArgument), class_variables) 337 | 338 | def test_multiline_argument_with_final_hashtag_comment(self): 339 | class MultilineArgumentWithHashTagComment: 340 | bar: str = "This is a multiline argument" " that should not be included in the docstring" # biz baz 341 | barr: str = "This is a multiline argument" " that should not be included in the docstring" # bar baz 342 | barrr: str = ( # meow 343 | "This is a multiline argument" # blah 344 | " that should not be included in the docstring" # grrrr 345 | ) # yay! 
def test_dataclass(self):
    """A class wrapped with ``class_decorator`` still reports its single annotated attribute."""

    @class_decorator
    class DataclassColumn:
        arg: int = 5

    expected = {"arg": {"comment": ""}}
    actual = get_class_variables(DataclassColumn)
    self.assertEqual(actual, expected)
class TupleTypeEnforcerTests(TestCase):
    """Exercises ``TupleTypeEnforcer`` with zero, one, and several configured types."""

    def test_tuple_type_enforcer_zero_types(self):
        """Calling an enforcer that was built with no types raises ``IndexError``."""
        no_types = TupleTypeEnforcer(types=[])
        self.assertRaises(IndexError, no_types, "hi")

    def test_tuple_type_enforcer_one_type_str(self):
        """A single ``str`` type returns the argument unchanged."""
        self.assertEqual(TupleTypeEnforcer(types=[str])("hi"), "hi")

    def test_tuple_type_enforcer_one_type_int(self):
        """A single ``int`` type parses a decimal string."""
        self.assertEqual(TupleTypeEnforcer(types=[int])("123"), 123)

    def test_tuple_type_enforcer_one_type_float(self):
        """A single ``float`` type parses a decimal string."""
        self.assertEqual(TupleTypeEnforcer(types=[float])("3.14159"), 3.14159)

    def test_tuple_type_enforcer_one_type_bool(self):
        """A single ``bool`` type accepts full words, case variants, prefixes, and 1/0."""
        cases = (
            ("True", True),
            ("true", True),
            ("False", False),
            ("false", False),
            ("tRu", True),
            ("faL", False),
            ("1", True),
            ("0", False),
        )
        for raw, expected in cases:
            # A fresh enforcer is built for every value, exactly as the
            # hand-unrolled version did.
            # NOTE(review): presumably each call advances to the next type slot - confirm.
            single_bool = TupleTypeEnforcer(types=[bool])
            self.assertEqual(single_bool(raw), expected)

    def test_tuple_type_enforcer_multi_types_same(self):
        """Successive calls convert successive arguments when every slot has the same type."""
        scenarios = (
            ([str, str], ["hi", "bye"], ["hi", "bye"]),
            ([int, int, int], [123, 456, -789], [123, 456, -789]),
            ([float, float, float, float], [1.23, 4.56, -7.89, 3.14159], [1.23, 4.56, -7.89, 3.14159]),
            (
                [bool, bool, bool, bool, bool],
                ["True", "False", "1", "0", "tru"],
                [True, False, True, False, True],
            ),
        )
        for slot_types, raw_values, expected in scenarios:
            enforcer = TupleTypeEnforcer(types=slot_types)
            # str() leaves string inputs untouched and renders numbers as their literals.
            converted = [enforcer(str(raw)) for raw in raw_values]
            self.assertEqual(converted, expected)

    def test_tuple_type_enforcer_multi_types_different(self):
        """Each successive call uses the next type from a heterogeneous list."""
        mixed = TupleTypeEnforcer(types=[str, int, float, bool])
        raw_values = ["hello", 77, 0.2, "tru"]
        expected = ["hello", 77, 0.2, True]
        converted = [mixed(str(raw)) for raw in raw_values]
        self.assertEqual(converted, expected)

    def test_tuple_type_enforcer_infinite(self):
        """With ``loop=True`` a single ``int`` type converts more arguments than there are types."""
        looping = TupleTypeEnforcer(types=[int], loop=True)
        numbers = [1, 2, -5, 20]
        converted = [looping(str(number)) for number in numbers]
        self.assertEqual(converted, numbers)
class Person:
    """Small user-defined value object used as a fixture for encoder round-trip tests.

    Two instances are equal when their ``name`` attributes are equal. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, name: str) -> None:
        """Store ``name`` on the instance."""
        self.name = name

    def __eq__(self, other: Any) -> bool:
        """Compare by ``name``; defer to the other operand for non-``Person`` values."""
        if not isinstance(other, Person):
            # Letting Python try the reflected comparison still ends in False
            # for unrelated types, but keeps the protocol well-behaved.
            return NotImplemented
        return self.name == other.name

    def __repr__(self) -> str:
        """Return a readable form so failed equality assertions show the name."""
        return f"{type(self).__name__}(name={self.name!r})"
class EnforceReproducibilityTests(TestCase):
    """Feeds ``enforce_reproducibility`` one invalid configuration per test and expects ``ValueError``."""

    def test_saved_reproducibility_data_is_none(self):
        """Missing saved data is rejected."""
        with self.assertRaises(ValueError):
            enforce_reproducibility(None, {}, "here")

    def test_git_url_not_in_saved_reproducibility_data(self):
        """Saved data without a ``git_url`` key is rejected."""
        with self.assertRaises(ValueError):
            enforce_reproducibility({}, {}, "here")

    def test_git_url_not_in_current_reproducibility_data(self):
        """Current data without a ``git_url`` key is rejected."""
        with self.assertRaises(ValueError):
            enforce_reproducibility({"git_url": "none"}, {}, "here")

    def test_git_urls_disagree(self):
        """Differing ``git_url`` values are rejected."""
        with self.assertRaises(ValueError):
            enforce_reproducibility({"git_url": "none"}, {"git_url": "some"}, "here")

    def test_throw_error_for_saved_uncommitted_changes(self):
        """Uncommitted changes recorded in the saved data are rejected."""
        # Both sides share the same git URL and the current side is clean, so the
        # only invalid aspect of this input is the saved uncommitted-changes flag.
        # With differing URLs this test would pass for the reason already covered
        # by test_git_urls_disagree.
        with self.assertRaises(ValueError):
            enforce_reproducibility(
                {"git_url": "none", "git_has_uncommitted_changes": True},
                {"git_url": "none", "git_has_uncommitted_changes": False},
                "here",
            )

    def test_throw_error_for_uncommitted_changes(self):
        """Uncommitted changes in the current data are rejected."""
        # Same git URL on both sides and a clean saved side, so the current
        # uncommitted-changes flag is the only thing that can trigger the error.
        with self.assertRaises(ValueError):
            enforce_reproducibility(
                {"git_url": "none", "git_has_uncommitted_changes": False},
                {"git_url": "none", "git_has_uncommitted_changes": True},
                "here",
            )