├── .github
└── workflows
│ └── tests.yml
├── .gitignore
├── LICENSE.txt
├── README.md
├── demo.py
├── demo_data_model.py
├── images
├── logo.png
├── tap.png
└── tap_logo.png
├── pyproject.toml
├── src
└── tap
│ ├── __init__.py
│ ├── py.typed
│ ├── tap.py
│ ├── tapify.py
│ └── utils.py
└── tests
├── test_actions.py
├── test_integration.py
├── test_load_config_files.py
├── test_subparser.py
├── test_tapify.py
├── test_to_tap_class.py
└── test_utils.py
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: tests
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ${{ matrix.os }}
16 | strategy:
17 | matrix:
18 | os: [ubuntu-latest, macos-latest, windows-latest]
19 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
20 |
21 | steps:
22 | - uses: actions/checkout@main
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@main
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Set temp directories on Windows
28 | if: matrix.os == 'windows-latest'
29 | run: |
30 | echo "TMPDIR=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV
31 | echo "TEMP=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV
32 | echo "TMP=$env:USERPROFILE\AppData\Local\Temp" >> $env:GITHUB_ENV
33 | - name: Install dependencies
34 | run: |
35 | git config --global user.email "you@example.com"
36 | git config --global user.name "Your Name"
37 | python -m pip install --upgrade pip
38 | python -m pip install -e ".[dev-no-pydantic]"
39 | - name: Lint with flake8
40 | run: |
41 | # stop the build if there are Python syntax errors or undefined names
42 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
43 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
44 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
45 | - name: Test without pydantic
46 | run: |
47 | pytest --cov=tap
48 | - name: Test with pydantic v1
49 | run: |
50 | python -m pip install "pydantic < 2"
51 | pytest --cov=tap --cov-append
52 | - name: Test with pydantic v2
53 | run: |
54 | python -m pip install "pydantic >= 2"
55 | pytest --cov=tap --cov-append
56 |
57 | - name: Upload coverage reports to Codecov
58 | uses: codecov/codecov-action@v4
59 | with:
60 | token: ${{ secrets.CODECOV_TOKEN }}
61 | verbose: true
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | __pycache__
4 | .DS_Store
5 | *.json
6 | *.egg-info
7 | build
8 | .eggs
9 | .coverage
10 | dist
11 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Jesse Michel and Kyle Swanson
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Typed Argument Parser (Tap)
6 |
7 | [](https://badge.fury.io/py/typed-argument-parser)
8 | [](https://badge.fury.io/py/typed-argument-parser)
9 | [](https://pepy.tech/project/typed-argument-parser)
10 | [](https://github.com/swansonk14/typed-argument-parser)
11 | [](https://codecov.io/gh/swansonk14/typed-argument-parser)
12 | [](https://github.com/swansonk14/typed-argument-parser/blob/main/LICENSE.txt)
13 |
14 | Tap is a typed modernization of Python's [argparse](https://docs.python.org/3/library/argparse.html) library.
15 |
16 | Tap provides the following benefits:
17 | - Static type checking
18 | - Code completion
19 | - Source code navigation (e.g. go to definition and go to implementation)
20 |
21 | 
22 |
23 | See [this poster](https://docs.google.com/presentation/d/1AirN6gpiq4P1L8K003EsXmobVxP3A4AVEIR2KOEQN7Y/edit?usp=sharing), which we presented at [PyCon 2020](https://us.pycon.org/2020/), for a presentation of some of the relevant concepts we used to guide the development of Tap.
24 |
25 | As of version 1.8.0, Tap includes `tapify`, which runs functions or initializes classes with arguments parsed from the command line. We show an example below.
26 |
27 | ```python
28 | # square.py
29 | from tap import tapify
30 |
31 | def square(num: float) -> float:
32 | return num ** 2
33 |
34 | if __name__ == '__main__':
35 | print(f'The square of your number is {tapify(square)}.')
36 | ```
37 |
38 | Running `python square.py --num 2` will print `The square of your number is 4.0.`. Please see [tapify](#tapify) for more details.
39 |
40 | ## Installation
41 |
42 | Tap requires Python 3.9+
43 |
44 | To install Tap from PyPI run:
45 |
46 | ```
47 | pip install typed-argument-parser
48 | ```
49 |
50 |
51 | To install Tap from source, run the following commands:
52 |
53 | ```
54 | git clone https://github.com/swansonk14/typed-argument-parser.git
55 | cd typed-argument-parser
56 | pip install -e .
57 | ```
58 |
59 |
60 |
61 |
62 | To develop this package, install development requirements (in a virtual environment):
63 |
64 | ```
65 | python -m pip install -e ".[dev]"
66 | ```
67 |
68 | Style:
69 | - Please use [`black`](https://github.com/psf/black) formatting
70 | - Set your vertical line ruler to 121
71 | - Use [`flake8`](https://github.com/PyCQA/flake8) linting.
72 |
73 | To run tests, run:
74 |
75 | ```
76 | pytest
77 | ```
78 |
79 |
80 |
81 | ## Table of Contents
82 |
83 | * [Installation](#installation)
84 | * [Table of Contents](#table-of-contents)
85 | * [Tap is Python-native](#tap-is-python-native)
86 | * [Tap features](#tap-features)
87 | + [Arguments](#arguments)
88 | + [Tap help](#tap-help)
89 | + [Configuring arguments](#configuring-arguments)
90 | - [Adding special argument behavior](#adding-special-argument-behavior)
91 | - [Adding subparsers](#adding-subparsers)
92 | + [Types](#types)
93 | + [Argument processing](#argument-processing)
94 | + [Processing known args](#processing-known-args)
95 | + [Subclassing](#subclassing)
96 | + [Printing](#printing)
97 | + [Reproducibility](#reproducibility)
98 | + [Saving and loading arguments](#saving-and-loading-arguments)
99 | + [Loading from configuration files](#loading-from-configuration-files)
100 | * [tapify](#tapify)
101 | + [Examples](#examples)
102 | - [Function](#function)
103 | - [Class](#class)
104 | - [Dataclass](#dataclass)
105 | + [tapify help](#tapify-help)
106 | + [Command line vs explicit arguments](#command-line-vs-explicit-arguments)
107 | + [Known args](#known-args)
108 | * [Convert to a `Tap` class](#convert-to-a-tap-class)
109 | + [`to_tap_class` examples](#to_tap_class-examples)
110 | - [Simple](#simple)
111 | - [Complex](#complex)
112 |
113 | ## Tap is Python-native
114 |
115 | To see this, let's look at an example:
116 |
117 | ```python
118 | """main.py"""
119 |
120 | from tap import Tap
121 |
122 | class SimpleArgumentParser(Tap):
123 | name: str # Your name
124 | language: str = 'Python' # Programming language
125 | package: str = 'Tap' # Package name
126 | stars: int # Number of stars
127 | max_stars: int = 5 # Maximum stars
128 |
129 | args = SimpleArgumentParser().parse_args()
130 |
131 | print(f'My name is {args.name} and I give the {args.language} package '
132 | f'{args.package} {args.stars}/{args.max_stars} stars!')
133 | ```
134 |
135 | You use Tap the same way you use standard argparse.
136 |
137 | ```
138 | >>> python main.py --name Jesse --stars 5
139 | My name is Jesse and I give the Python package Tap 5/5 stars!
140 | ```
141 |
142 | The equivalent argparse code is:
143 | ```python
144 | """main.py"""
145 |
146 | from argparse import ArgumentParser
147 |
148 | parser = ArgumentParser()
149 | parser.add_argument('--name', type=str, required=True,
150 | help='Your name')
151 | parser.add_argument('--language', type=str, default='Python',
152 | help='Programming language')
153 | parser.add_argument('--package', type=str, default='Tap',
154 | help='Package name')
155 | parser.add_argument('--stars', type=int, required=True,
156 | help='Number of stars')
157 | parser.add_argument('--max_stars', type=int, default=5,
158 | help='Maximum stars')
159 | args = parser.parse_args()
160 |
161 | print(f'My name is {args.name} and I give the {args.language} package '
162 | f'{args.package} {args.stars}/{args.max_stars} stars!')
163 | ```
164 |
165 | The advantages of being Python-native include being able to:
166 | - Override convenient built-in methods (e.g. `process_args` ensures consistency among arguments)
167 | - Add custom methods
168 | - Inherit from your own template classes
169 |
170 | ## Tap features
171 |
172 | Now we are going to highlight some of our favorite features and give examples of how they work in practice.
173 |
174 | ### Arguments
175 |
176 | Arguments are specified as class variables defined in a subclass of `Tap`. Variables defined as `name: type` are required arguments while variables defined as `name: type = value` are not required and default to the provided value.
177 |
178 | ```python
179 | class MyTap(Tap):
180 | required_arg: str
181 | default_arg: str = 'default value'
182 | ```
183 |
184 | ### Tap help
185 |
186 | Single line and/or multiline comments which appear after the argument are automatically parsed into the help string provided when running `python main.py -h`. The type and default values of arguments are also provided in the help string.
187 |
188 | ```python
189 | """main.py"""
190 |
191 | from tap import Tap
192 |
193 | class MyTap(Tap):
194 | x: float # What am I?
195 | pi: float = 3.14 # I'm pi!
196 | """Pi is my favorite number!"""
197 |
198 | args = MyTap().parse_args()
199 | ```
200 |
201 | Running `python main.py -h` results in the following:
202 |
203 | ```
204 | >>> python main.py -h
205 | usage: main.py --x X [--pi PI] [-h]
206 |
207 | optional arguments:
208 | --x X (float, required) What am I?
209 | --pi PI (float, default=3.14) I'm pi! Pi is my favorite number.
210 | -h, --help show this help message and exit
211 | ```
212 |
213 | ### Configuring arguments
214 | To specify behavior beyond what can be specified using arguments as class variables, override the `configure` method.
215 | `configure` provides access to advanced argument parsing features such as `add_argument` and `add_subparser`.
216 | Since Tap is a wrapper around argparse, Tap provides all of the same functionality.
217 | We detail these two functions below.
218 |
219 | #### Adding special argument behavior
220 | In the `configure` method, call `self.add_argument` just as you would use argparse's `add_argument`. For example,
221 |
222 | ```python
223 | from tap import Tap
224 |
225 | class MyTap(Tap):
226 | positional_argument: str
227 | list_of_three_things: List[str]
228 | argument_with_really_long_name: int
229 |
230 | def configure(self):
231 | self.add_argument('positional_argument')
232 | self.add_argument('--list_of_three_things', nargs=3)
233 | self.add_argument('-arg', '--argument_with_really_long_name')
234 | ```
235 |
236 | #### Adding subparsers
237 | To add a subparser, override the `configure` method and call `self.add_subparser`. Optionally, to specify keyword arguments (e.g., `help`) to the subparser collection, call `self.add_subparsers`. For example,
238 |
239 | ```python
240 | class SubparserA(Tap):
241 | bar: int # bar help
242 |
243 | class SubparserB(Tap):
244 | baz: Literal['X', 'Y', 'Z'] # baz help
245 |
246 | class Args(Tap):
247 | foo: bool = False # foo help
248 |
249 | def configure(self):
250 | self.add_subparsers(help='sub-command help')
251 | self.add_subparser('a', SubparserA, help='a help')
252 | self.add_subparser('b', SubparserB, help='b help')
253 | ```
254 |
255 | ### Types
256 |
257 | Tap automatically handles all the following types:
258 |
259 | ```python
260 | str, int, float, bool
261 | Optional, Optional[str], Optional[int], Optional[float], Optional[bool]
262 | List, List[str], List[int], List[float], List[bool]
263 | Set, Set[str], Set[int], Set[float], Set[bool]
264 | Tuple, Tuple[Type1, Type2, etc.], Tuple[Type, ...]
265 | Literal
266 | ```
267 |
268 | If you're using Python 3.9+, then you can replace `List` with `list`, `Set` with `set`, and `Tuple` with `tuple`.
269 |
270 | Tap also supports `Union`, but this requires additional specification (see the [Union](#union) section below).
271 |
272 | Additionally, any type that can be instantiated with a string argument can be used. For example, in
273 | ```python
274 | from pathlib import Path
275 | from tap import Tap
276 |
277 | class Args(Tap):
278 | path: Path
279 |
280 | args = Args().parse_args()
281 | ```
282 | `args.path` is a `Path` instance containing the string passed in through the command line.
283 |
284 | #### `str`, `int`, and `float`
285 |
286 | Each is automatically parsed to their respective types, just like argparse.
287 |
288 | #### `bool`
289 |
290 | If an argument `arg` is specified as `arg: bool` or `arg: bool = False`, then adding the `--arg` flag to the command line will set `arg` to `True`. If `arg` is specified as `arg: bool = True`, then adding `--arg` sets `arg` to `False`.
291 |
292 | Note that if the `Tap` instance is created with `explicit_bool=True`, then booleans can be specified on the command line as `--arg True` or `--arg False` rather than `--arg`. Additionally, booleans can be specified by prefixes of `True` and `False` with any capitalization as well as `1` or `0` (e.g. for True, `--arg tRu`, `--arg T`, `--arg 1` all suffice).
293 |
294 | #### `Optional`
295 |
296 | These arguments are parsed in exactly the same way as `str`, `int`, `float`, and `bool`. Note bools can be specified using the same rules as above and that `Optional` is equivalent to `Optional[str]`.
297 |
298 | #### `List`
299 |
300 | If an argument `arg` is a `List`, simply specify the values separated by spaces just as you would with regular argparse. For example, `--arg 1 2 3` parses to `arg = [1, 2, 3]`.
301 |
302 | #### `Set`
303 |
304 | Identical to `List` but parsed into a set rather than a list.
305 |
306 | #### `Tuple`
307 |
308 | Tuples can be used to specify a fixed number of arguments with specified types using the syntax `Tuple[Type1, Type2, etc.]` (e.g. `Tuple[str, int, bool, str]`). Tuples with a variable number of arguments are specified by `Tuple[Type, ...]` (e.g. `Tuple[int, ...]`). Note `Tuple` defaults to `Tuple[str, ...]`.
309 |
310 | #### `Literal`
311 |
312 | Literal is analogous to argparse's [choices](https://docs.python.org/3/library/argparse.html#choices), which specifies the values that an argument can take. For example, if arg can only be one of 'H', 1, False, or 1.0078 then you would specify that `arg: Literal['H', 1, False, 1.0078]`. For instance, `--arg False` assigns arg to False and `--arg True` throws an error.
313 |
314 | #### `Union`
315 |
316 | Union types must include the `type` keyword argument in `add_argument` in order to specify which type to use, as in the example below.
317 |
318 | ```python
319 | def to_number(string: str) -> Union[float, int]:
320 | return float(string) if '.' in string else int(string)
321 |
322 | class MyTap(Tap):
323 | number: Union[float, int]
324 |
325 | def configure(self):
326 | self.add_argument('--number', type=to_number)
327 | ```
328 |
329 | In Python 3.10+, `Union[Type1, Type2, etc.]` can be replaced with `Type1 | Type2 | etc.`, but the `type` keyword argument must still be provided in `add_argument`.
330 |
331 | #### Complex Types
332 |
333 | Tap can also support more complex types than the ones specified above. If the desired type is constructed with a single string as input, then the type can be specified directly without additional modifications. For example,
334 |
335 | ```python
336 | class Person:
337 | def __init__(self, name: str) -> None:
338 | self.name = name
339 |
340 | class Args(Tap):
341 | person: Person
342 |
343 | args = Args().parse_args('--person Tapper'.split())
344 | print(args.person.name) # Tapper
345 | ```
346 |
347 | If the desired type has a more complex constructor, then the `type` keyword argument must be provided in `add_argument`. For example,
348 |
349 | ```python
350 | class AgedPerson:
351 | def __init__(self, name: str, age: int) -> None:
352 | self.name = name
353 | self.age = age
354 |
355 | def to_aged_person(string: str) -> AgedPerson:
356 | name, age = string.split(',')
357 | return AgedPerson(name=name, age=int(age))
358 |
359 | class Args(Tap):
360 | aged_person: AgedPerson
361 |
362 | def configure(self) -> None:
363 | self.add_argument('--aged_person', type=to_aged_person)
364 |
365 | args = Args().parse_args('--aged_person Tapper,27'.split())
366 | print(f'{args.aged_person.name} is {args.aged_person.age}') # Tapper is 27
367 | ```
368 |
369 |
370 | ### Argument processing
371 |
372 | With complex argument parsing, arguments often end up having interdependencies. This means that it may be necessary to disallow certain combinations of arguments or to modify some arguments based on other arguments.
373 |
374 | To handle such cases, simply override `process_args` and add the required logic. `process_args` is automatically called when `parse_args` is called.
375 |
376 | ```python
377 | class MyTap(Tap):
378 | package: str
379 | is_cool: bool
380 | stars: int
381 |
382 | def process_args(self):
383 | # Validate arguments
384 | if self.is_cool and self.stars < 4:
385 | raise ValueError('Cool packages cannot have fewer than 4 stars')
386 |
387 | # Modify arguments
388 | if self.package == 'Tap':
389 | self.is_cool = True
390 | self.stars = 5
391 | ```
392 |
393 | ### Processing known args
394 |
395 | Similar to argparse's `parse_known_args`, Tap is capable of parsing only arguments that it is aware of without raising an error due to additional arguments. This can be done by calling `parse_args` with `known_only=True`. The remaining un-parsed arguments are then available by accessing the `extra_args` field of the Tap object.
396 |
397 | ```python
398 | class MyTap(Tap):
399 | package: str
400 |
401 | args = MyTap().parse_args(['--package', 'Tap', '--other_arg', 'value'], known_only=True)
402 | print(args.extra_args) # ['--other_arg', 'value']
403 | ```
404 |
405 | ### Subclassing
406 |
407 | It is sometimes useful to define a template Tap and then subclass it for different use cases. Since Tap is a native Python class, inheritance is built-in, making it easy to customize from a template Tap.
408 |
409 | In the example below, `StarsTap` and `AwardsTap` inherit the arguments (`package` and `is_cool`) and the methods (`process_args`) from `BaseTap`.
410 |
411 | ```python
412 | class BaseTap(Tap):
413 | package: str
414 | is_cool: bool
415 |
416 | def process_args(self):
417 | if self.package == 'Tap':
418 | self.is_cool = True
419 |
420 |
421 | class StarsTap(BaseTap):
422 | stars: int
423 |
424 |
425 | class AwardsTap(BaseTap):
426 | awards: List[str]
427 | ```
428 |
429 | ### Printing
430 |
431 | Tap uses Python's [pretty printer](https://docs.python.org/3/library/pprint.html) to print out arguments in an easy-to-read format.
432 |
433 | ```python
434 | """main.py"""
435 |
436 | from tap import Tap
437 | from typing import List
438 |
439 | class MyTap(Tap):
440 | package: str
441 | is_cool: bool = True
442 | awards: List[str] = ['amazing', 'wow', 'incredible', 'awesome']
443 |
444 | args = MyTap().parse_args()
445 | print(args)
446 | ```
447 |
448 | Running `python main.py --package Tap` results in:
449 |
450 | ```
451 | >>> python main.py --package Tap
452 | {'awards': ['amazing', 'wow', 'incredible', 'awesome'],
453 | 'is_cool': True,
454 | 'package': 'Tap'}
455 | ```
456 |
457 | ### Reproducibility
458 |
459 | Tap makes reproducibility easy, especially when running code in a git repo.
460 |
461 | #### Reproducibility info
462 |
463 | Specifically, Tap has a method called `get_reproducibility_info` that returns a dictionary containing all the information necessary to replicate the settings under which the code was run. This dictionary includes:
464 | - Python command
465 | - The Python command that was used to run the program
466 | - Ex. `python main.py --package Tap`
467 | - Time
468 | - The time when the command was run
469 | - Ex. `Thu Aug 15 00:09:13 2019`
470 | - Git root
471 | - The root of the git repo containing the code that was run
472 | - Ex. `/Users/swansonk14/typed-argument-parser`
473 | - Git url
474 | - The url to the git repo, specifically pointing to the current git hash (i.e. the hash of HEAD in the local repo)
475 | - Ex. [https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998](https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998)
476 | - Uncommitted changes
477 | - Whether there are any uncommitted changes in the git repo (i.e. whether the code is different from the code at the above git hash)
478 | - Ex. `True` or `False`
479 |
480 | ### Converting Tap to and from dictionaries
481 |
482 | Tap has methods `as_dict` and `from_dict` that convert Tap objects to and from dictionaries.
483 | For example,
484 |
485 | ```python
486 | """main.py"""
487 | from tap import Tap
488 |
489 | class Args(Tap):
490 | package: str
491 | is_cool: bool = True
492 | stars: int = 5
493 |
494 | args = Args().parse_args(["--package", "Tap"])
495 |
496 | args_data = args.as_dict()
497 | print(args_data) # {'package': 'Tap', 'is_cool': True, 'stars': 5}
498 |
499 | args_data['stars'] = 2000
500 | args = args.from_dict(args_data)
501 | print(args.stars) # 2000
502 | ```
503 |
504 | Note that `as_dict` does not include attributes set directly on an instance (e.g., `arg` is not included even after setting `args.arg = "hi"` in the code above because `arg` is not an attribute of the `Args` class).
505 | Also note that `from_dict` ensures that all required arguments are set.
506 |
507 | ### Saving and loading arguments
508 |
509 | #### Save
510 |
511 | Tap has a method called `save` which saves all arguments, along with the reproducibility info, to a JSON file.
512 |
513 | ```python
514 | """main.py"""
515 |
516 | from tap import Tap
517 |
518 | class MyTap(Tap):
519 | package: str
520 | is_cool: bool = True
521 | stars: int = 5
522 |
523 | args = MyTap().parse_args()
524 | args.save('args.json')
525 | ```
526 |
527 | After running `python main.py --package Tap`, the file `args.json` will contain:
528 |
529 | ```
530 | {
531 | "is_cool": true,
532 | "package": "Tap",
533 | "reproducibility": {
534 | "command_line": "python main.py --package Tap",
535 | "git_has_uncommitted_changes": false,
536 | "git_root": "/Users/swansonk14/typed-argument-parser",
537 | "git_url": "https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998",
538 | "time": "Thu Aug 15 00:18:31 2019"
539 | },
540 | "stars": 5
541 | }
542 | ```
543 |
544 | Note: More complex types will be encoded in JSON as a pickle string.
545 |
546 | #### Load
547 | > :exclamation: :warning:
548 | > Never call `args.load('args.json')` on untrusted files. Argument loading uses the `pickle` module to decode complex types automatically. Unpickling of untrusted data is a security risk and can lead to arbitrary code execution. See [the warning in the pickle docs](https://docs.python.org/3/library/pickle.html).
549 | > :exclamation: :warning:
550 |
551 | Arguments can be loaded from a JSON file rather than parsed from the command line.
552 |
553 | ```python
554 | """main.py"""
555 |
556 | from tap import Tap
557 |
558 | class MyTap(Tap):
559 | package: str
560 | is_cool: bool = True
561 | stars: int = 5
562 |
563 | args = MyTap()
564 | args.load('args.json')
565 | ```
566 |
567 | Note: All required arguments (in this case `package`) must be present in the JSON file if not already set in the Tap object.
568 |
569 | #### Load from dict
570 |
571 | Arguments can be loaded from a Python dictionary rather than parsed from the command line.
572 |
573 | ```python
574 | """main.py"""
575 |
576 | from tap import Tap
577 |
578 | class MyTap(Tap):
579 | package: str
580 | is_cool: bool = True
581 | stars: int = 5
582 |
583 | args = MyTap()
584 | args.from_dict({
585 | 'package': 'Tap',
586 | 'stars': 20
587 | })
588 | ```
589 |
590 | Note: As with `load`, all required arguments must be present in the dictionary if not already set in the Tap object. All values in the provided dictionary will overwrite values currently in the Tap object.
591 |
592 | ### Loading from configuration files
593 | Configuration files can be loaded along with arguments by passing the optional `config_files: List[str]` argument when constructing the `Tap` object. Arguments passed in from the command line overwrite arguments from the configuration files. Arguments in configuration files that appear later in the list overwrite the arguments in previous configuration files.
594 |
595 | For example, if you have the config file `my_config.txt`
596 | ```
597 | --arg1 1
598 | --arg2 two
599 | ```
600 | then you can write
601 | ```python
602 | from tap import Tap
603 |
604 | class Args(Tap):
605 | arg1: int
606 | arg2: str
607 |
608 | args = Args(config_files=['my_config.txt']).parse_args()
609 | ```
610 |
611 | Config files are parsed using `shlex.split` from the Python standard library, which supports shell-style string quoting, as well as line-end comments starting with `#`.
612 |
613 | For example, if you have the config file `my_config_shlex.txt`
614 | ```
615 | --arg1 21 # Important arg value
616 |
617 | # Multi-word quoted string
618 | --arg2 "two three four"
619 | ```
620 | then you can write
621 | ```python
622 | from tap import Tap
623 |
624 | class Args(Tap):
625 | arg1: int
626 | arg2: str
627 |
628 | args = Args(config_files=['my_config_shlex.txt']).parse_args()
629 | ```
630 | to get the resulting `args = {'arg1': 21, 'arg2': 'two three four'}`.
631 |
632 | The legacy parsing behavior of using standard string split can be re-enabled by passing `legacy_config_parsing=True` to `parse_args`.
633 |
634 | ## tapify
635 |
636 | `tapify` makes it possible to run functions or initialize objects via command line arguments. This is inspired by Google's [Python Fire](https://github.com/google/python-fire), but `tapify` also automatically casts command line arguments to the appropriate types based on the type hints. Under the hood, `tapify` implicitly creates a Tap object and uses it to parse the command line arguments, which it then uses to run the function or initialize the class. We show a few examples below.
637 |
638 | ### Examples
639 |
640 | #### Function
641 |
642 | ```python
643 | # square_function.py
644 | from tap import tapify
645 |
646 | def square(num: float) -> float:
647 | """Square a number.
648 |
649 | :param num: The number to square.
650 | """
651 | return num ** 2
652 |
653 | if __name__ == '__main__':
654 | squared = tapify(square)
655 | print(f'The square of your number is {squared}.')
656 | ```
657 |
658 | Running `python square_function.py --num 5` prints `The square of your number is 25.0.`.
659 |
660 | #### Class
661 |
662 | ```python
663 | # square_class.py
664 | from tap import tapify
665 |
666 | class Squarer:
667 | def __init__(self, num: float) -> None:
668 | """Initialize the Squarer with a number to square.
669 |
670 | :param num: The number to square.
671 | """
672 | self.num = num
673 |
674 | def get_square(self) -> float:
675 | """Get the square of the number."""
676 | return self.num ** 2
677 |
678 | if __name__ == '__main__':
679 | squarer = tapify(Squarer)
680 | print(f'The square of your number is {squarer.get_square()}.')
681 | ```
682 |
683 | Running `python square_class.py --num 2` prints `The square of your number is 4.0.`.
684 |
685 | #### Dataclass
686 |
687 | ```python
688 | # square_dataclass.py
689 | from dataclasses import dataclass
690 |
691 | from tap import tapify
692 |
693 | @dataclass
694 | class Squarer:
695 | """Squarer with a number to square.
696 |
697 | :param num: The number to square.
698 | """
699 | num: float
700 |
701 | def get_square(self) -> float:
702 | """Get the square of the number."""
703 | return self.num ** 2
704 |
705 | if __name__ == '__main__':
706 | squarer = tapify(Squarer)
707 | print(f'The square of your number is {squarer.get_square()}.')
708 | ```
709 |
710 | Running `python square_dataclass.py --num -1` prints `The square of your number is 1.0.`.
711 |
712 |
713 | Argument descriptions
714 |
715 | For dataclasses, the argument's description (which is displayed in the `-h` help message) can either be specified in the
716 | class docstring or the field's description in `metadata`. If both are specified, the description from the docstring is
717 | used. In the example below, the description is provided in `metadata`.
718 |
719 | ```python
720 | # square_dataclass.py
721 | from dataclasses import dataclass, field
722 |
723 | from tap import tapify
724 |
725 | @dataclass
726 | class Squarer:
727 | """Squarer with a number to square.
728 | """
729 | num: float = field(metadata={"description": "The number to square."})
730 |
731 | def get_square(self) -> float:
732 | """Get the square of the number."""
733 | return self.num ** 2
734 |
735 | if __name__ == '__main__':
736 | squarer = tapify(Squarer)
737 | print(f'The square of your number is {squarer.get_square()}.')
738 | ```
739 |
740 |
741 |
742 | #### Pydantic
743 |
744 | Pydantic [Models](https://docs.pydantic.dev/latest/concepts/models/) and
745 | [dataclasses](https://docs.pydantic.dev/latest/concepts/dataclasses/) can be `tapify`d.
746 |
747 | ```python
748 | # square_pydantic.py
749 | from pydantic import BaseModel, Field
750 |
751 | from tap import tapify
752 |
753 | class Squarer(BaseModel):
754 | """Squarer with a number to square.
755 | """
756 | num: float = Field(description="The number to square.")
757 |
758 | def get_square(self) -> float:
759 | """Get the square of the number."""
760 | return self.num ** 2
761 |
762 | if __name__ == '__main__':
763 | squarer = tapify(Squarer)
764 | print(f'The square of your number is {squarer.get_square()}.')
765 | ```
766 |
767 |
768 | Argument descriptions
769 |
770 | For Pydantic v2 models and dataclasses, the argument's description (which is displayed in the `-h` help message) can
771 | either be specified in the class docstring or the field's `description`. If both are specified, the description from the
772 | docstring is used. In the example below, the description is provided in the docstring.
773 |
774 | For Pydantic v1 models and dataclasses, the argument's description must be provided in the class docstring:
775 |
776 | ```python
777 | # square_pydantic.py
778 | from pydantic import BaseModel
779 |
780 | from tap import tapify
781 |
782 | class Squarer(BaseModel):
783 | """Squarer with a number to square.
784 |
785 | :param num: The number to square.
786 | """
787 | num: float
788 |
789 | def get_square(self) -> float:
790 | """Get the square of the number."""
791 | return self.num ** 2
792 |
793 | if __name__ == '__main__':
794 | squarer = tapify(Squarer)
795 | print(f'The square of your number is {squarer.get_square()}.')
796 | ```
797 |
798 |
799 |
800 | ### tapify help
801 |
802 | The help string on the command line is set based on the docstring for the function or class. For example, running `python square_function.py -h` will print:
803 |
804 | ```
805 | usage: square_function.py [-h] --num NUM
806 |
807 | Square a number.
808 |
809 | options:
810 | -h, --help show this help message and exit
811 | --num NUM (float, required) The number to square.
812 | ```
813 |
814 | Note that for classes, if there is a docstring in the `__init__` method, then `tapify` sets the help string description to that docstring. Otherwise, it uses the docstring from the top of the class.
815 |
816 | ### Command line vs explicit arguments
817 |
818 | `tapify` can simultaneously use both arguments passed from the command line and arguments passed in explicitly in the `tapify` call. Arguments provided in the `tapify` call override function defaults, and arguments provided via the command line override both arguments provided in the `tapify` call and function defaults. We show an example below.
819 |
820 | ```python
821 | # add.py
822 | from tap import tapify
823 |
824 | def add(num_1: float, num_2: float = 0.0, num_3: float = 0.0) -> float:
825 | """Add numbers.
826 |
827 | :param num_1: The first number.
828 | :param num_2: The second number.
829 | :param num_3: The third number.
830 | """
831 | return num_1 + num_2 + num_3
832 |
833 | if __name__ == '__main__':
834 | added = tapify(add, num_2=2.2, num_3=4.1)
835 | print(f'The sum of your numbers is {added}.')
836 | ```
837 |
838 | Running `python add.py --num_1 1.0 --num_2 0.9` prints `The sum of your numbers is 6.0.`. (Note that `add` took `num_1 = 1.0` and `num_2 = 0.9` from the command line and `num_3=4.1` from the `tapify` call due to the order of precedence.)
839 |
840 | ### Known args
841 |
842 | Calling `tapify` with `known_only=True` allows `tapify` to ignore additional arguments from the command line that are not needed for the function or class. If `known_only=False` (the default), then `tapify` will raise an error when additional arguments are provided. We show an example below where `known_only=True` might be useful for running multiple `tapify` calls.
843 |
844 | ```python
845 | # person.py
846 | from tap import tapify
847 |
848 | def print_name(name: str) -> None:
849 | """Print a person's name.
850 |
851 | :param name: A person's name.
852 | """
853 | print(f'My name is {name}.')
854 |
855 | def print_age(age: int) -> None:
856 | """Print a person's age.
857 |
858 | :param age: A person's age.
859 | """
860 | print(f'My age is {age}.')
861 |
862 | if __name__ == '__main__':
863 | tapify(print_name, known_only=True)
864 | tapify(print_age, known_only=True)
865 | ```
866 |
867 | Running `python person.py --name Jesse --age 1` prints `My name is Jesse.` followed by `My age is 1.`. Without `known_only=True`, the `tapify` calls would raise an error due to the extra argument.
868 |
869 | ### Explicit boolean arguments
870 |
871 | Tapify supports explicit specification of boolean arguments (see [bool](#bool) for more details). By default, `explicit_bool=False` and it can be set with `tapify(..., explicit_bool=True)`.
872 |
873 | ## Convert to a `Tap` class
874 |
875 | `to_tap_class` turns a function or class into a `Tap` class. The returned class can be [subclassed](#subclassing) to add
876 | special argument behavior. For example, you can override [`configure`](#configuring-arguments) and
877 | [`process_args`](#argument-processing).
878 |
879 | If the object can be `tapify`d, then it can be `to_tap_class`d, and vice-versa. `to_tap_class` provides full control
880 | over argument parsing.
881 |
882 | ### `to_tap_class` examples
883 |
884 | #### Simple
885 |
886 | ```python
887 | # main.py
888 | """
889 | My script description
890 | """
891 |
892 | from pydantic import BaseModel
893 |
894 | from tap import to_tap_class
895 |
896 | class Project(BaseModel):
897 | package: str
898 | is_cool: bool = True
899 | stars: int = 5
900 |
901 | if __name__ == "__main__":
902 | ProjectTap = to_tap_class(Project)
903 | tap = ProjectTap(description=__doc__) # from the top of this script
904 | args = tap.parse_args()
905 | project = Project(**args.as_dict())
906 | print(f"Project instance: {project}")
907 | ```
908 |
909 | Running `python main.py --package tap` will print `Project instance: package='tap' is_cool=True stars=5`.
910 |
911 | #### Complex
912 |
913 | The general pattern is:
914 |
915 | ```python
916 | from tap import to_tap_class
917 |
918 | class MyCustomTap(to_tap_class(my_class_or_function)):
919 | # Special argument behavior, e.g., override configure and/or process_args
920 | ```
921 |
922 | Please see `demo_data_model.py` for an example of overriding [`configure`](#configuring-arguments) and
923 | [`process_args`](#argument-processing).
924 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from tap import Tap
3 |
4 |
def add_one(num: int) -> int:
    """Return ``num`` incremented by one."""
    incremented = num + 1
    return incremented
7 |
8 |
# ----- ArgumentParser -----
# Baseline for comparison: three options declared with the standard library's ArgumentParser.

parser = ArgumentParser()
parser.add_argument("--rnn", type=str, required=True, help="RNN type")
parser.add_argument("--hidden_size", type=int, default=300, help="Hidden size")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout probability")


args = parser.parse_args()

print(args.hidden_size)  # no autocomplete, no type inference, no source code navigation

# Deliberate mistake: --rnn is parsed as a str but add_one is annotated to take an int.
add_one(args.rnn)  # no static type checking
22 |
23 |
24 | # ----- Tap -----
25 |
26 |
# The same three options as the ArgumentParser section, declared as typed class variables.
# The trailing comments on the fields are intentionally left untouched: Tap collects
# class-variable comments to build the help strings.
class MyTap(Tap):
    rnn: str  # RNN type
    hidden_size: int = 300  # Hidden size
    dropout: float = 0.2  # Dropout probability
31 |
32 |
# parse_args returns the MyTap instance itself, so the attributes below are typed.
args = MyTap().parse_args()

print(args.hidden_size)  # autocomplete, type inference, source code navigation

# Same deliberate mistake as in the ArgumentParser section: rnn is annotated as str, add_one expects an int.
add_one(args.rnn)  # static type checking
38 |
--------------------------------------------------------------------------------
/demo_data_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Works for Pydantic v1 and v2.
3 |
4 | Example commands:
5 |
6 | python demo_data_model.py -h
7 |
8 | python demo_data_model.py \
9 | --arg_int 1 \
10 | --arg_list x y z \
11 | --argument_with_really_long_name 3
12 |
13 | python demo_data_model.py \
14 | --arg_int 1 \
15 | --arg_list x y z \
16 | --arg_bool \
17 | -arg 3.14
18 | """
19 | from typing import List, Literal, Optional, Union
20 |
21 | from pydantic import BaseModel, Field
22 | from tap import tapify, to_tap_class, Tap
23 |
24 |
# Data model holding the script arguments: arg_int has no default, while arg_bool and
# arg_list declare defaults (True and None respectively).
class Model(BaseModel):
    """
    My Pydantic Model which contains script args.
    """

    arg_int: int = Field(description="some integer")
    arg_bool: bool = Field(default=True)
    arg_list: Optional[List[str]] = Field(default=None, description="some list of strings")
33 |
34 |
def main(model: Model) -> None:
    """Print a header line followed by the given model."""
    for item in ("Parsed args into Model:", model):
        print(item)
38 |
39 |
def to_number(string: str) -> Union[float, int]:
    """Parse ``string`` as a float if it contains a decimal point, otherwise as an int."""
    if "." in string:
        return float(string)
    return int(string)
42 |
43 |
# Tap class generated from Model and extended with one extra argument plus the
# configure/process_args hooks.
class ModelTap(to_tap_class(Model)):
    # You can supply additional arguments here
    argument_with_really_long_name: Union[float, int] = 3
    "This argument has a long name and will be aliased with a short one"

    def configure(self) -> None:
        # You can still add special argument behavior
        # Here: register the short flag -arg and pass to_number as the explicit type function
        # for the Union[float, int] annotation.
        self.add_argument("-arg", "--argument_with_really_long_name", type=to_number)

    def process_args(self) -> None:
        # You can still validate and modify arguments
        # (You should do this in the Pydantic Model. I'm just demonstrating that this functionality is still possible)
        if self.argument_with_really_long_name > 4:
            raise ValueError("argument_with_really_long_name cannot be > 4")

        # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry
        if self.arg_bool and self.arg_list is not None:
            self.arg_list.append("processed")
62 |
63 |
64 | # class SubparserA(Tap):
65 | # bar: int # bar help
66 |
67 |
68 | # class SubparserB(Tap):
69 | # baz: Literal["X", "Y", "Z"] # baz help
70 |
71 |
72 | # class ModelTapWithSubparsing(to_tap_class(Model)):
73 | # foo: bool = False # foo help
74 |
75 | # def configure(self):
76 | # self.add_subparsers(help="sub-command help")
77 | # self.add_subparser("a", SubparserA, help="a help", description="Description (a)")
78 | # self.add_subparser("b", SubparserB, help="b help")
79 |
80 |
if __name__ == "__main__":
    # You don't have to subclass to_tap_class(Model) if you just want a plain argument parser:
    # ModelTap = to_tap_class(Model)
    args = ModelTap(description="Script description").parse_args()
    # args = ModelTapWithSubparsing(description="Script description").parse_args()
    print("Parsed args:")
    print(args)
    # Run the main function
    # The parsed arguments are converted to a dict and used to construct the Model instance.
    model = Model(**args.as_dict())
    main(model)
91 |
92 |
93 | # tapify works with Model. It immediately returns a Model instance instead of a Tap class
94 | # if __name__ == "__main__":
95 | # model = tapify(Model)
96 | # print(model)
97 |
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/logo.png
--------------------------------------------------------------------------------
/images/tap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/tap.png
--------------------------------------------------------------------------------
/images/tap_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swansonk14/typed-argument-parser/619ee8bc2ec3798c203e6641149f93402701ca5f/images/tap_logo.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0.0", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "typed-argument-parser"
7 | dynamic = ["version"]
8 | authors = [
9 | {name = "Jesse Michel", email = "jessem.michel@gmail.com" },
10 | {name = "Kyle Swanson", email = "swansonk.14@gmail.com" },
11 | ]
12 | maintainers = [
13 | {name = "Jesse Michel", email = "jessem.michel@gmail.com" },
14 | {name = "Kyle Swanson", email = "swansonk.14@gmail.com" },
15 | ]
16 | description = "Typed Argument Parser"
17 | readme = "README.md"
18 | license = { file = "LICENSE.txt" }
19 | dependencies = [
20 | "docstring-parser >= 0.15",
21 | "packaging",
22 | "typing-inspect >= 0.7.1",
23 | ]
24 | requires-python = ">=3.9"
25 | classifiers = [
26 | "Programming Language :: Python :: 3",
27 | "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | "Programming Language :: Python :: 3.11",
30 | "Programming Language :: Python :: 3.12",
31 | "Programming Language :: Python :: 3.13",
32 | "License :: OSI Approved :: MIT License",
33 | "Operating System :: OS Independent",
34 | "Typing :: Typed",
35 | ]
36 | keywords = [
37 | "typing",
38 | "argument parser",
39 | "python",
40 | ]
41 |
42 | [project.optional-dependencies]
43 | dev-no-pydantic = [
44 | "pytest",
45 | "pytest-cov",
46 | "flake8",
47 | ]
48 | dev = [
49 | "typed-argument-parser[dev-no-pydantic]",
50 | "pydantic >= 2.5.0",
51 | ]
52 |
53 | [tool.setuptools]
54 | package-dir = {"" = "src"}
55 |
56 | [tool.setuptools.packages.find]
57 | where = ["src"]
58 |
59 | [tool.setuptools.dynamic]
60 | version = {attr = "tap.__version__"}
61 |
62 | [tool.setuptools.package-data]
63 | tap = ["py.typed"]
64 |
65 | [project.urls]
66 | Homepage = "https://github.com/swansonk14/typed-argument-parser"
67 | Issues = "https://github.com/swansonk14/typed-argument-parser/issues"
68 |
--------------------------------------------------------------------------------
/src/tap/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Typed Argument Parser
3 | """
4 |
5 | __version__ = "1.10.1"
6 |
7 | from argparse import ArgumentError, ArgumentTypeError
8 | from tap.tap import Tap
9 | from tap.tapify import tapify, to_tap_class
10 |
11 | __all__ = [
12 | "ArgumentError",
13 | "ArgumentTypeError",
14 | "Tap",
15 | "tapify",
16 | "to_tap_class",
17 | "__version__",
18 | ]
19 |
--------------------------------------------------------------------------------
/src/tap/py.typed:
--------------------------------------------------------------------------------
1 | # Marker file for PEP 561.
2 |
--------------------------------------------------------------------------------
/src/tap/tap.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import time
4 | from argparse import ArgumentParser, ArgumentTypeError
5 | from copy import deepcopy
6 | from functools import partial
7 | from pathlib import Path
8 | from pprint import pformat
9 | from shlex import quote, split
10 | from types import MethodType
11 | from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
12 | from typing_inspect import is_literal_type
13 |
14 | from tap.utils import (
15 | get_class_variables,
16 | get_args,
17 | get_argument_name,
18 | get_dest,
19 | get_origin,
20 | GitInfo,
21 | is_option_arg,
22 | is_positional_arg,
23 | type_to_str,
24 | get_literals,
25 | boolean_type,
26 | TupleTypeEnforcer,
27 | define_python_object_encoder,
28 | as_python_object,
29 | enforce_reproducibility,
30 | PathLike,
31 | )
32 |
33 | if sys.version_info >= (3, 10):
34 | from types import UnionType
35 |
36 |
# Constants
# First type argument reported for a bare `List` when there is one, otherwise an empty tuple.
# _add_argument compares boxed type arguments against this value to fall back to str.
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
# Collection annotations (typing and builtin spellings) whose arguments default to nargs="*".
BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple}
# Union origins; UnionType is only imported (from types) on Python 3.10+.
UNION_TYPES = {Union} | ({UnionType} if sys.version_info >= (3, 10) else set())
# Origins that _add_argument unboxes to their first type argument.
OPTIONAL_TYPES = {Optional} | UNION_TYPES
# Every origin that wraps an inner type.
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES


# Type variable bound to Tap; used to annotate parse_args as returning the type of `self`.
TapType = TypeVar("TapType", bound="Tap")
46 |
47 |
48 | class Tap(ArgumentParser):
49 | """Tap is a typed argument parser that wraps Python's built-in ArgumentParser."""
50 |
def __init__(
    self,
    *args,
    underscores_to_dashes: bool = False,
    explicit_bool: bool = False,
    config_files: Optional[list[PathLike]] = None,
    **kwargs,
) -> None:
    """Initializes the Tap instance.

    Sets up the internal bookkeeping attributes, initializes the super class ArgumentParser,
    loads the configuration files, and finally runs the configuration step (``_configure``).

    :param args: Arguments passed to the super class ArgumentParser.
    :param underscores_to_dashes: If True, convert underscores in flags to dashes.
    :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False"
                          rather than "--arg". Additionally, booleans can be specified by prefixes of True and False
                          with any capitalization as well as 1 or 0.
    :param config_files: A list of paths to configuration files containing the command line arguments
                         (e.g., '--arg1 a1 --arg2 a2'). Arguments passed in from the command line
                         overwrite arguments from the configuration files. Arguments in configuration files
                         that appear later in the list overwrite the arguments in previous configuration files.
    :param kwargs: Keyword arguments passed to the super class ArgumentParser.
                   If "description" is not provided, it defaults to the class docstring.
    """
    # Whether the Tap object has been initialized
    # (add_argument refuses to run once this is True)
    self._initialized = False

    # Whether boolean flags have to be explicitly set to True or False
    self._explicit_bool = explicit_bool

    # Whether we convert underscores in the flag names to dashes
    self._underscores_to_dashes = underscores_to_dashes

    # Whether the arguments have been parsed (i.e. if parse_args has been called)
    self._parsed = False

    # Set extra arguments to empty list
    # (filled by parse_args when known_only=True)
    self.extra_args = []

    # Create argument buffer
    # (maps variable name -> (name_or_flags, kwargs) recorded by add_argument)
    self.argument_buffer = {}

    # Create a place to put the subparsers
    # (entries are (flag, subparser_type, kwargs) recorded by add_subparser)
    self._subparser_buffer: list[tuple[str, type, dict[str, Any]]] = []

    # Get class variables help strings from the comments
    self.class_variables = self._get_class_variables()

    # Get annotations from self and all super classes up through tap
    self._annotations = self._get_annotations()

    # Set the default description to be the docstring
    kwargs.setdefault("description", self.__doc__)

    # Initialize the super class, i.e. ArgumentParser
    super(Tap, self).__init__(*args, **kwargs)

    # Stores the subparsers
    self._subparsers = None

    # Load in the configuration files
    self.args_from_configs = self._load_from_config_files(config_files)

    # Perform additional configuration such as adding arguments or adding subparsers
    self._configure()

    # Indicate that initialization is complete
    self._initialized = True
116 |
def _add_argument(self, *name_or_flags, **kwargs) -> None:
    """Adds an argument to self (i.e. the super class ArgumentParser).

    Sets the following attributes of kwargs when not explicitly provided:
    - type: Set to the type annotation of the argument (after unboxing Optional/collection types).
    - default: Set to the value of the attribute with the argument's name on self (if present).
    - required: True if a default value of the argument is not provided, False otherwise.
    - action: Set to "store_true" if the argument is a required bool or a bool with default value False.
              Set to "store_false" if the argument is a bool with default value True.
    - nargs: Set to "*" for List, Set, and variable-length Tuple annotations, and to the number of
             elements for fixed-length Tuple annotations.
    - choices: Set for Literal annotations and for explicit booleans.
    - help: Set to "(<type>, required)" or "(<type>, default=<default>)" followed by the
            class variable's comment, if there is one.

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param kwargs: Keyword arguments.
    :raises ArgumentTypeError: If the annotation is a Union (other than an Optional) and no
                               explicit type function was provided.
    """
    # Set explicit bool
    explicit_bool = self._explicit_bool

    # Get variable name
    variable = get_argument_name(*name_or_flags)

    if self._underscores_to_dashes:
        variable = variable.replace("-", "_")

    # Get default if not specified
    if hasattr(self, variable):
        kwargs["default"] = kwargs.get("default", getattr(self, variable))

    # Set required if option arg
    if (
        is_option_arg(*name_or_flags)
        and variable != "help"
        and "default" not in kwargs
        and kwargs.get("action") != "version"
    ):
        kwargs["required"] = kwargs.get("required", not hasattr(self, variable))

    # Set help if necessary
    if "help" not in kwargs:
        kwargs["help"] = "("

        # Type
        if variable in self._annotations:
            kwargs["help"] += type_to_str(self._annotations[variable]) + ", "

        # Required/default
        if kwargs.get("required", False) or is_positional_arg(*name_or_flags):
            kwargs["help"] += "required"
        else:
            kwargs["help"] += f'default={kwargs.get("default", None)}'

        kwargs["help"] += ")"

        # Description
        if variable in self.class_variables:
            kwargs["help"] += " " + self.class_variables[variable]["comment"]

    # Set other kwargs where not provided
    if variable in self._annotations:
        # Get type annotation
        var_type = self._annotations[variable]

        # If type is not explicitly provided, set it if it's one of our supported default types
        if "type" not in kwargs:
            # Unbox Union[type] (Optional[type]) and set var_type = type
            if get_origin(var_type) in OPTIONAL_TYPES:
                var_args = get_args(var_type)

                # If type is Union or Optional without inner types, set type to equivalent of Optional[str]
                if len(var_args) == 0:
                    var_args = (str, type(None))

                # Raise error if type function is not explicitly provided for Union types (not including Optionals)
                if get_origin(var_type) in UNION_TYPES and not (len(var_args) == 2 and var_args[1] == type(None)):
                    raise ArgumentTypeError(
                        "For Union types, you must include an explicit type function in the configure method. "
                        "For example,\n\n"
                        "def to_number(string: str) -> Union[float, int]:\n"
                        " return float(string) if '.' in string else int(string)\n\n"
                        "class Args(Tap):\n"
                        " arg: Union[float, int]\n"
                        "\n"
                        " def configure(self) -> None:\n"
                        " self.add_argument('--arg', type=to_number)"
                    )

                if len(var_args) > 0:
                    var_type = var_args[0]

                    # If var_type is tuple as in Python 3.6, change to a typing type
                    # (e.g., (typing.List, <class 'bool'>) ==> typing.List[bool])
                    if isinstance(var_type, tuple):
                        var_type = var_type[0][var_type[1:]]

                    # An unboxed Optional[bool] is always parsed as an explicit boolean
                    explicit_bool = True

            # First check whether it is a literal type or a boxed literal type
            if is_literal_type(var_type):
                var_type, kwargs["choices"] = get_literals(var_type, variable)

            elif (
                get_origin(var_type) in (List, list, Set, set)
                and len(get_args(var_type)) > 0
                and is_literal_type(get_args(var_type)[0])
            ):
                var_type, kwargs["choices"] = get_literals(get_args(var_type)[0], variable)
                if kwargs.get("action") not in {"append", "append_const"}:
                    kwargs["nargs"] = kwargs.get("nargs", "*")

            # Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
            elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
                loop = False
                types = list(get_args(var_type))

                # Handle Tuple[type, ...]
                if len(types) == 2 and types[1] == Ellipsis:
                    types = types[0:1]
                    loop = True
                    kwargs["nargs"] = "*"
                # Handle Tuple[()]
                elif len(types) == 1 and types[0] == tuple():
                    types = [str]
                    loop = True
                    kwargs["nargs"] = "*"
                else:
                    kwargs["nargs"] = len(types)

                # Handle Literal types
                types = [get_literals(tp, variable)[0] if is_literal_type(tp) else tp for tp in types]

                var_type = TupleTypeEnforcer(types=types, loop=loop)

            if get_origin(var_type) in BOXED_TYPES:
                # If List or Set or Tuple type, set nargs
                if get_origin(var_type) in BOXED_COLLECTION_TYPES and kwargs.get("action") not in {
                    "append",
                    "append_const",
                }:
                    kwargs["nargs"] = kwargs.get("nargs", "*")

                # Extract boxed type for Optional, List, Set
                arg_types = get_args(var_type)

                # Set defaults type to str for Type and Type[()]
                if len(arg_types) == 0 or arg_types[0] == EMPTY_TYPE:
                    var_type = str
                else:
                    var_type = arg_types[0]

                # Handle the cases of List[bool], Set[bool], Tuple[bool]
                if var_type == bool:
                    var_type = boolean_type

            # If bool then set action, otherwise set type
            if var_type == bool:
                if explicit_bool:
                    kwargs["type"] = boolean_type
                    kwargs["choices"] = [True, False]  # this makes the help message more helpful
                else:
                    # Passing the flag flips the default: store_true unless the default is truthy
                    action_cond = "true" if kwargs.get("required", False) or not kwargs["default"] else "false"
                    kwargs["action"] = kwargs.get("action", f"store_{action_cond}")
            elif kwargs.get("action") not in {"count", "append_const"}:
                kwargs["type"] = var_type

    if self._underscores_to_dashes:
        # Replace "_" with "-" for arguments that aren't positional
        name_or_flags = tuple(
            name_or_flag.replace("_", "-") if name_or_flag.startswith("-") else name_or_flag
            for name_or_flag in name_or_flags
        )

    # Deepcopy default to prevent mutation of values
    if "default" in kwargs:
        kwargs["default"] = deepcopy(kwargs["default"])

    super(Tap, self).add_argument(*name_or_flags, **kwargs)
293 |
def add_argument(self, *name_or_flags, **kwargs) -> None:
    """Buffer an argument definition so that _add_argument can process it later.

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param kwargs: Keyword arguments to store alongside the names/flags.
    :raises ValueError: If called after the Tap instance has finished initializing.
    """
    if self._initialized:
        message = (
            "add_argument cannot be called after initialization. "
            "Arguments must be added either as class variables or by overriding "
            "configure and including a self.add_argument call there."
        )
        raise ValueError(message)

    # Buffer keys always use underscores, even when the flag is spelled with dashes
    buffer_key = get_argument_name(*name_or_flags).replace("-", "_")
    self.argument_buffer[buffer_key] = (name_or_flags, kwargs)
305 |
306 | def _add_arguments(self) -> None:
307 | """Add arguments to self in the order they are defined as class variables (so the help string is in order)."""
308 | # Add class variables (in order)
309 | for variable in self.class_variables:
310 | if variable in self.argument_buffer:
311 | name_or_flags, kwargs = self.argument_buffer[variable]
312 | self._add_argument(*name_or_flags, **kwargs)
313 | else:
314 | self._add_argument(f"--{variable}")
315 |
316 | # Add any arguments that were added manually in configure but aren't class variables (in order)
317 | for variable, (name_or_flags, kwargs) in self.argument_buffer.items():
318 | if variable not in self.class_variables:
319 | self._add_argument(*name_or_flags, **kwargs)
320 |
def process_args(self) -> None:
    """Hook for additional argument processing and/or validation; does nothing by default."""
    return None
324 |
def add_subparser(self, flag: str, subparser_type: type, **kwargs) -> None:
    """Buffer a subparser so that it can be registered later by _add_subparsers.

    :param flag: The name under which the subparser is registered.
    :param subparser_type: The class used to build the subparser.
    :param kwargs: Keyword arguments stored for the subparser. If "help" is not given,
                   it defaults to the docstring of ``subparser_type``.
    """
    kwargs.setdefault("help", subparser_type.__doc__)
    entry = (flag, subparser_type, kwargs)
    self._subparser_buffer.append(entry)
331 |
332 | def _add_subparsers(self) -> None:
333 | """Add each of the subparsers to the Tap object."""
334 | # Initialize the _subparsers object if not already created
335 | if self._subparsers is None and len(self._subparser_buffer) > 0:
336 | self._subparsers = super(Tap, self).add_subparsers()
337 |
338 | # Load each subparser
339 | for flag, subparser_type, kwargs in self._subparser_buffer:
340 | self._subparsers._parser_class = partial(subparser_type, underscores_to_dashes=self._underscores_to_dashes)
341 | self._subparsers.add_parser(flag, **kwargs)
342 |
def add_subparsers(self, **kwargs) -> None:
    """Create the subparsers object via the super class and store it on self.

    :param kwargs: Keyword arguments forwarded to the super class add_subparsers.
    """
    subparsers = super().add_subparsers(**kwargs)
    self._subparsers = subparsers
345 |
346 | def _configure(self) -> None:
347 | """Executes the user-defined configuration."""
348 | # Call the user-defined configuration
349 | self.configure()
350 |
351 | # Add arguments to self
352 | self._add_arguments()
353 |
354 | # Add subparsers to self
355 | self._add_subparsers()
356 |
def configure(self) -> None:
    """Override this method to configure the parser during initialization.

    The default implementation does nothing. For example, an override could contain:
        self.add_argument('--sum',
                          dest='accumulate',
                          action='store_const',
                          const=sum,
                          default=max)
        self.add_subparsers(help='sub-command help')
        self.add_subparser('a', SubparserA, help='a help')
    """
    return None
370 |
@staticmethod
def get_reproducibility_info(repo_path: Optional[PathLike] = None) -> dict[str, str]:
    """Builds a dictionary of reproducibility information.

    The dictionary always contains:
    - command_line: The command line command used to execute the code.
    - time: The current time.

    If git is available (as reported by GitInfo.has_git), it also contains:
    - git_root: The root of the git repo where the command is run.
    - git_url: The url of the current hash of the git repo where the command is run.
               Ex. https://github.com/swansonk14/rationale-alignment/tree/<hash>.
               If it is a local repo, the url is None.
    - git_has_uncommitted_changes: Whether the current git repo has uncommitted changes.

    :param repo_path: Path to the git repo to examine for reproducibility info.
                      If None, uses the git repo of the Python file that is run.
    :return: A dictionary of reproducibility information.
    """
    # Default to the directory of the Python file that is being run
    if repo_path is None:
        script_dir = Path(sys.argv[0]).parent
        repo_path = (Path.cwd() / script_dir).resolve()

    quoted_argv = " ".join(quote(arg) for arg in sys.argv)
    info = {
        "command_line": f"python {quoted_argv}",
        "time": time.strftime("%c"),
    }

    git = GitInfo(repo_path=repo_path)
    if not git.has_git():
        return info

    info["git_root"] = git.get_git_root()
    info["git_url"] = git.get_git_url(commit_hash=True)
    # Stored as a string so that every value in the dictionary is a str
    info["git_has_uncommitted_changes"] = str(git.has_uncommitted_changes())

    return info
407 |
408 | def _log_all(self, repo_path: Optional[PathLike] = None) -> dict[str, Any]:
409 | """Gets all arguments along with reproducibility information.
410 |
411 | :param repo_path: Path to the git repo to examine for reproducibility info.
412 | If None, uses the git repo of the Python file that is run.
413 | :return: A dictionary containing all arguments along with reproducibility information.
414 | """
415 | arg_log = self.as_dict()
416 | arg_log["reproducibility"] = self.get_reproducibility_info(repo_path=repo_path)
417 |
418 | return arg_log
419 |
def parse_args(
    self: TapType,
    args: Optional[Sequence[str]] = None,
    known_only: bool = False,
    legacy_config_parsing: bool = False,
) -> TapType:
    """Parses arguments, sets attributes of self equal to the parsed arguments, and processes arguments.

    Arguments collected from the configuration files are placed before the command line
    arguments, so the command line takes precedence. Parsed list values are converted to
    sets or tuples when the argument is annotated as a Set or Tuple, including when that
    annotation is wrapped in Optional/Union (either ``Union[...]`` or the ``X | Y`` syntax).

    :param args: List of strings to parse. The default is taken from `sys.argv`.
    :param known_only: If true, ignores extra arguments and only parses known arguments.
                       Unparsed arguments are saved to self.extra_args.
    :param legacy_config_parsing: If true, config files are parsed using `str.split` instead of `shlex.split`.
    :return: self, which is a Tap instance containing all of the parsed args.
    :raises ValueError: If parse_args has already been called on this instance.
    """
    # Prevent double parsing
    if self._parsed:
        raise ValueError("parse_args can only be called once.")

    # Collect arguments from all of the configs
    if legacy_config_parsing:
        config_args = [arg for args_from_config in self.args_from_configs for arg in args_from_config.split()]
    else:
        config_args = [
            arg for args_from_config in self.args_from_configs for arg in split(args_from_config, comments=True)
        ]

    # Add config args at lower precedence and extract args from the command line if they are not passed explicitly
    args = config_args + (sys.argv[1:] if args is None else list(args))

    # Parse args using super class ArgumentParser's parse_args or parse_known_args function
    if known_only:
        default_namespace, self.extra_args = super(Tap, self).parse_known_args(args)
    else:
        default_namespace = super(Tap, self).parse_args(args)

    # Copy parsed arguments to self
    for variable, value in vars(default_namespace).items():
        # Conversion from list to set or tuple
        if variable in self._annotations and isinstance(value, list):
            annotation = self._annotations[variable]
            var_type = get_origin(annotation)

            # Unpack nested boxed types such as Optional[List[int]].
            # Every union origin in UNION_TYPES is checked (not only typing.Union) so that
            # this agrees with how _add_argument unboxes unions. Identity comparison avoids
            # hashing arbitrary annotation objects.
            if any(var_type is union_type for union_type in UNION_TYPES):
                union_args = get_args(annotation)
                if len(union_args) > 0:
                    var_type = get_origin(union_args[0])

            # If var_type is tuple as in Python 3.6, change to a typing type
            # (e.g., (typing.Tuple, <class 'bool'>) ==> typing.Tuple)
            if isinstance(var_type, tuple):
                var_type = var_type[0]

            if var_type in (Set, set):
                value = set(value)
            elif var_type in (Tuple, tuple):
                value = tuple(value)

        # Set variable in self
        setattr(self, variable, value)

    # Process args
    self.process_args()

    # Indicate that args have been parsed
    self._parsed = True

    return self
487 |
@classmethod
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union[dict[str, Any], dict]:
    """Returns a dictionary mapping variable names to values.

    Variables and values are extracted from classes using `extract_func`, starting
    with this class and traversing up the super classes up to (but excluding) Tap.

    If a super class and a subclass have the same key, the subclass value is used.

    Super classes are traversed through breadth first search.

    :param extract_func: A function that extracts from a class a dictionary mapping variables to values.
    :return: A dictionary mapping variable names to values from the class dict.
    """
    visited = set()
    super_classes = [cls]
    dictionary = {}

    while super_classes:
        super_class = super_classes.pop(0)

        if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
            # Subclasses are visited before their bases, so only fill in keys that are
            # still missing; values that were already collected must not be overridden.
            for variable, value in extract_func(super_class).items():
                dictionary.setdefault(variable, value)

            super_classes += list(super_class.__bases__)
            visited.add(super_class)

    return dictionary
523 |
524 | def _get_class_dict(self) -> dict[str, Any]:
525 | """Returns a dictionary mapping class variable names to values from the class dict."""
526 | class_dict = self._get_from_self_and_super(
527 | extract_func=lambda super_class: dict(getattr(super_class, "__dict__", dict()))
528 | )
529 | class_dict = {
530 | var: val
531 | for var, val in class_dict.items()
532 | if not (var.startswith("_") or callable(val) or isinstance(val, (staticmethod, classmethod, property)))
533 | }
534 |
535 | return class_dict
536 |
537 | def _get_annotations(self) -> dict[str, Any]:
538 | """Returns a dictionary mapping variable names to their type annotations."""
539 | return self._get_from_self_and_super(extract_func=lambda super_class: dict(get_type_hints(super_class)))
540 |
541 | def _get_class_variables(self) -> dict:
542 | """Returns a dictionary mapping class variables names to their additional information."""
543 | class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()
544 |
545 | try:
546 | class_variables = self._get_from_self_and_super(extract_func=get_class_variables)
547 |
548 | # Handle edge-case of source code modification while code is running
549 | variables_to_add = class_variable_names - class_variables.keys()
550 | variables_to_remove = class_variables.keys() - class_variable_names
551 |
552 | for variable in variables_to_add:
553 | class_variables[variable] = {"comment": ""}
554 |
555 | for variable in variables_to_remove:
556 | class_variables.pop(variable)
557 | # Exception if inspect.getsource fails to extract the source code
558 | except Exception:
559 | class_variables = {}
560 | for variable in class_variable_names:
561 | class_variables[variable] = {"comment": ""}
562 |
563 | return class_variables
564 |
565 | def _get_argument_names(self) -> set[str]:
566 | """Returns a list of variable names corresponding to the arguments."""
567 | return (
568 | {get_dest(*name_or_flags, **kwargs) for name_or_flags, kwargs in self.argument_buffer.values()}
569 | | set(self._get_class_dict().keys())
570 | | set(self._annotations.keys())
571 | ) - {"help"}
572 |
def as_dict(self) -> dict[str, Any]:
    """Returns the member variables corresponding to the parsed arguments.

    Note: This does not include attributes set directly on an instance
    (e.g. arg is not included in MyTap().arg = "hi")

    :return: A dictionary mapping each argument's name to its value.
    :raises ValueError: If the arguments have not been parsed yet.
    """
    if not self._parsed:
        raise ValueError("You should call `parse_args` before retrieving arguments.")

    # Start from the instance attributes and add the class-level entries they do not shadow
    instance_attributes = self.__dict__
    class_attributes = self._get_from_self_and_super(
        extract_func=lambda super_class: dict(getattr(super_class, "__dict__", dict()))
    )
    candidates = dict(instance_attributes)
    for key, val in class_attributes.items():
        if key not in instance_attributes:
            candidates[key] = val

    # Drop private names and methods; read the values through getattr so that descriptors are resolved
    stored_dict = {}
    for var, val in candidates.items():
        if var.startswith("_") or isinstance(val, (MethodType, staticmethod)):
            continue
        stored_dict[var] = getattr(self, var)

    # Anything that a plain Tap instance or the Tap class itself defines is not a user argument
    tap_class_dict_keys = Tap().__dict__.keys() | Tap.__dict__.keys()

    return {key: stored_dict[key] for key in stored_dict.keys() - tap_class_dict_keys}
601 |
def from_dict(self, args_dict: dict[str, Any], skip_unsettable: bool = False) -> TapType:
    """Sets the arguments from a dictionary after checking that no required argument is left unset.

    :param args_dict: A dictionary from argument names to the values of the arguments.
    :param skip_unsettable: When True, silently skips attributes that cannot be set on the Tap object,
                            e.g. properties without setters.
    :return: Returns self.
    :raises ValueError: If a required argument is neither in `args_dict` nor already set on self.
    :raises AttributeError: If an attribute cannot be set and `skip_unsettable` is False.
    """
    # A required argument is fine if it is either provided or already present on self
    required_args = {action.dest for action in self._actions if action.required}
    unprovided_required_args = required_args - args_dict.keys()
    missing_required_args = []
    for arg in unprovided_required_args:
        if not hasattr(self, arg):
            missing_required_args.append(arg)

    if missing_required_args:
        raise ValueError(
            f'Input dictionary "{args_dict}" does not include '
            f'all unset required arguments: "{missing_required_args}".'
        )

    # Copy every entry onto self
    for key, value in args_dict.items():
        try:
            setattr(self, key, value)
        except AttributeError:
            if skip_unsettable:
                continue
            raise AttributeError(
                f'Cannot set attribute "{key}" to "{value}". '
                f"To skip arguments that cannot be set \n"
                f'\t"skip_unsettable = True"'
            )

    # Loading from a dictionary counts as parsing
    self._parsed = True

    return self
636 |
def save(
    self,
    path: PathLike,
    with_reproducibility: bool = True,
    skip_unpicklable: bool = False,
    repo_path: Optional[PathLike] = None,
) -> None:
    """Saves the arguments and reproducibility information in JSON format, pickling what can't be encoded.

    The arguments are collected and serialized before the file is opened, so a failure
    while gathering or encoding them does not truncate an existing file at `path`.

    :param path: Path to the JSON file where the arguments will be saved.
    :param with_reproducibility: If True, adds a "reproducibility" field with information (e.g. git hash)
                                 to the JSON file.
    :param skip_unpicklable: If True, does not save attributes whose values cannot be pickled.
    :param repo_path: Path to the git repo to examine for reproducibility info.
                      If None, uses the git repo of the Python file that is run.
    """
    args = self._log_all(repo_path=repo_path) if with_reproducibility else self.as_dict()
    serialized = json.dumps(args, indent=4, sort_keys=True, cls=define_python_object_encoder(skip_unpicklable))

    # Only touch the file once there is complete, valid content to write
    with open(path, "w") as f:
        f.write(serialized)
656 |
def load(
    self,
    path: PathLike,
    check_reproducibility: bool = False,
    skip_unsettable: bool = False,
    repo_path: Optional[PathLike] = None,
) -> TapType:
    """Loads the arguments from a JSON file. Note: Due to JSON, tuples are loaded as lists.

    :param path: Path to the JSON file the arguments are read from.
    :param check_reproducibility: When True, the reproducibility information stored in the file is
                                  compared against the current reproducibility information via
                                  `enforce_reproducibility`.
    :param skip_unsettable: When True, skips attributes that cannot be set in the Tap object,
                            e.g. properties without setters.
    :param repo_path: Path to the git repo to examine for reproducibility info.
                      If None, uses the git repo of the Python file that is run.
    :return: Returns self.
    """
    with open(path) as f:
        file_contents = f.read()
    args_dict = json.loads(file_contents, object_hook=as_python_object)

    # The stored reproducibility information describes a past run, so it is never loaded as an argument
    saved_reproducibility_data = args_dict.pop("reproducibility", None)

    if check_reproducibility:
        enforce_reproducibility(
            saved_reproducibility_data,
            self.get_reproducibility_info(repo_path=repo_path),
            path,
        )

    self.from_dict(args_dict, skip_unsettable=skip_unsettable)

    return self
687 |
688 | def _load_from_config_files(self, config_files: Optional[list[str]]) -> list[str]:
689 | """Loads arguments from a list of configuration files containing command line arguments.
690 |
691 | :param config_files: A list of paths to configuration files containing the command line arguments
692 | (e.g., '--arg1 a1 --arg2 a2'). Arguments passed in from the command line
693 | overwrite arguments from the configuration files. Arguments in configuration files
694 | that appear later in the list overwrite the arguments in previous configuration files.
695 | :return: A list of the contents of each config file in order of increasing precedence (highest last).
696 | """
697 | args_from_config = []
698 |
699 | if config_files is not None:
700 | # Read arguments from all configs from the lowest precedence config to the highest
701 | for file in config_files:
702 | with open(file) as f:
703 | args_from_config.append(f.read().strip())
704 |
705 | return args_from_config
706 |
707 | def __str__(self) -> str:
708 | """Returns a string representation of self.
709 |
710 | :return: A formatted string representation of the dictionary of all arguments.
711 | """
712 | return pformat(self.as_dict())
713 |
def __deepcopy__(self, memo: dict[int, Any] = None) -> TapType:
    """Creates a deep copy of the Tap object without re-running `__init__`.

    :param memo: The memo dictionary used by `deepcopy`; a fresh one is created when None.
    :return: A new object of the same class whose attributes are deep copies of this object's attributes.
    """
    cls = type(self)
    clone = cls.__new__(cls)

    memo = {} if memo is None else memo

    # Register the clone before copying the attributes so that references back to self resolve to it
    memo[id(self)] = clone

    for attribute, value in self.__dict__.items():
        clone.__dict__[attribute] = deepcopy(value, memo)

    return clone
727 |
728 | def __getstate__(self) -> dict[str, Any]:
729 | """Gets the state of the object for pickling."""
730 | return self.as_dict()
731 |
732 | def __setstate__(self, d: dict[str, Any]) -> None:
733 | """
734 | Initializes the object with the provided dictionary of arguments for unpickling.
735 |
736 | :param d: A dictionary of arguments.
737 | """
738 | self.__init__()
739 | self.from_dict(d)
740 |
--------------------------------------------------------------------------------
/src/tap/tapify.py:
--------------------------------------------------------------------------------
1 | """
2 | `tapify`: initialize a class or run a function by parsing arguments from the command line.
3 |
4 | `to_tap_class`: convert a class or function into a `Tap` class, which can then be subclassed to add special argument
5 | handling
6 | """
7 |
8 | import dataclasses
9 | import inspect
10 | from typing import Any, Callable, Optional, Sequence, TypeVar, Union
11 |
12 | from docstring_parser import Docstring, parse
13 | from packaging.version import Version
14 |
# pydantic is an optional dependency: when it cannot be imported, placeholders are defined
# so that the rest of the module can refer to the same names either way.
try:
    import pydantic
except ModuleNotFoundError:
    # None (rather than False) marks "pydantic is not installed"
    _IS_PYDANTIC_V1 = None
    # These are "empty" types. isinstance and issubclass will always be False
    BaseModel = type("BaseModel", (object,), {})
    _PydanticField = type("_PydanticField", (object,), {})
    _PYDANTIC_FIELD_TYPES = ()
else:
    # True for any installed pydantic version below 2.0.0
    _IS_PYDANTIC_V1 = Version(pydantic.__version__) < Version("2.0.0")
    from pydantic import BaseModel
    from pydantic.fields import FieldInfo as PydanticFieldBaseModel
    from pydantic.dataclasses import FieldInfo as PydanticFieldDataclass

    # Used for type annotations only
    _PydanticField = Union[PydanticFieldBaseModel, PydanticFieldDataclass]
    # typing.get_args(_PydanticField) is an empty tuple for some reason. Just repeat
    # NOTE(review): presumably both imports resolve to the same class, in which case the Union collapses - confirm
    # This tuple is what isinstance checks use at runtime
    _PYDANTIC_FIELD_TYPES = (PydanticFieldBaseModel, PydanticFieldDataclass)

from tap import Tap

# Return type of the class or function that gets tapified
OutputType = TypeVar("OutputType")

# Either a callable returning OutputType or a class whose instances are OutputType
_ClassOrFunction = Union[Callable[..., OutputType], type[OutputType]]
38 |
39 |
40 | @dataclasses.dataclass
41 | class _ArgData:
42 | """
43 | Data about an argument which is sufficient to inform a Tap variable/argument.
44 | """
45 |
46 | name: str
47 |
48 | annotation: type
49 | "The type of values this argument accepts"
50 |
51 | is_required: bool
52 | "Whether or not the argument must be passed in"
53 |
54 | default: Any
55 | "Value of the argument if the argument isn't passed in. This gets ignored if `is_required`"
56 |
57 | description: Optional[str] = ""
58 | "Human-readable description of the argument"
59 |
60 | is_positional_only: bool = False
61 | "Whether or not the argument must be provided positionally"
62 |
63 |
@dataclasses.dataclass(frozen=True)
class _TapData:
    """
    Everything that needs to be known about the arguments of a class or function in order to build a Tap class.
    """

    # List of data about each argument in the class or function
    args_data: list[_ArgData]

    # True if you can pass variable/extra kwargs to the class or function (as in **kwargs), else False
    has_kwargs: bool

    # If true, ignore extra arguments and only parse known arguments
    known_only: bool
78 |
79 |
def _is_pydantic_base_model(obj: Union[type[Any], Any]) -> bool:
    """Returns whether `obj` is a Pydantic BaseModel class or an instance of one."""
    # issubclass requires a class, so instances have to go through isinstance
    check = issubclass if inspect.isclass(obj) else isinstance
    return check(obj, BaseModel)
85 |
86 |
def _is_pydantic_dataclass(obj: Union[type[Any], Any]) -> bool:
    """Returns whether `obj` is a Pydantic dataclass (class or instance).

    Always returns False when pydantic is not installed.
    """
    if _IS_PYDANTIC_V1 is None:
        # pydantic could not be imported, so the name `pydantic` is unbound and nothing can be a pydantic dataclass
        return False

    if _IS_PYDANTIC_V1:
        # There's no public function in v1. This is a somewhat safe but linear check
        return dataclasses.is_dataclass(obj) and any(key.startswith("__pydantic") for key in obj.__dict__)

    return pydantic.dataclasses.is_pydantic_dataclass(obj)
93 |
94 |
def _tap_data_from_data_model(
    data_model: Any, func_kwargs: dict[str, Any], param_to_description: Optional[dict[str, str]] = None
) -> _TapData:
    """
    Currently only works when `data_model` is a:
      - builtin dataclass (class or instance)
      - Pydantic dataclass (class or instance)
      - Pydantic BaseModel (class or instance).

    The advantage of this function over :func:`_tap_data_from_class_or_function` is that field/argument descriptions
    are extracted, b/c this function looks at the fields of the data model.

    :param data_model: The data model whose fields are turned into argument data.
    :param func_kwargs: Default overrides by field name. A field found here becomes optional with that default.
    :param param_to_description: Descriptions by field name. These take precedence over field-level descriptions.
    :return: The argument data of every field plus whether extra kwargs are accepted.
    :raises TypeError: If `data_model` or one of its fields is of an unsupported type.

    Note
    ----
    Deletes redundant keys from `func_kwargs`
    """
    param_to_description = param_to_description or {}

    def arg_data_from_dataclass(name: str, field: dataclasses.Field) -> _ArgData:
        has_default = field.default is not dataclasses.MISSING
        has_default_factory = field.default_factory is not dataclasses.MISSING

        if has_default_factory:
            # field.default is MISSING for factory-based fields, so the default has to be built by the factory.
            # Otherwise the MISSING sentinel itself would end up as the argument's default value.
            default = field.default_factory()
        else:
            # Either a real default, or MISSING for a required field (in which case it is ignored)
            default = field.default

        description = param_to_description.get(name, field.metadata.get("description"))
        return _ArgData(
            name,
            field.type,
            not (has_default or has_default_factory),
            default,
            description,
        )

    def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optional[type] = None) -> _ArgData:
        annotation = field.annotation if annotation is None else annotation
        # Prefer the description from param_to_description (from the data model / class docstring) over the
        # field.description b/c a docstring can be modified on the fly w/o causing real issues
        description = param_to_description.get(name, field.description)
        # NOTE(review): field.default is used as is; fields declared with a pydantic default_factory may need the
        # same treatment as dataclass fields above - confirm against the pydantic FieldInfo API
        return _ArgData(name, annotation, field.is_required(), field.default, description)

    # Determine what type of data model it is and extract fields accordingly
    if dataclasses.is_dataclass(data_model):
        name_to_field = {field.name: field for field in dataclasses.fields(data_model)}
        has_kwargs = False
        known_only = False
    elif _is_pydantic_base_model(data_model):
        name_to_field = data_model.model_fields
        # For backwards compatibility, only allow new kwargs to get assigned if the model is explicitly configured to
        # do so via extra="allow". See https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
        is_extra_ok = data_model.model_config.get("extra", "ignore") == "allow"
        has_kwargs = is_extra_ok
        known_only = is_extra_ok
    else:
        raise TypeError(
            "data_model must be a builtin or Pydantic dataclass (instance or class) or "
            f"a Pydantic BaseModel (instance or class). Got {type(data_model)}"
        )

    # It's possible to mix fields w/ classes, e.g., use pydantic Fields in a (builtin) dataclass, or use (builtin)
    # dataclass fields in a pydantic BaseModel. It's also possible to use (builtin) dataclass fields and pydantic
    # Fields in the same data model. Therefore, the type of the data model doesn't determine the type of each field.
    # The solution is to iterate through the fields and check each type.
    args_data: list[_ArgData] = []
    for name, field in name_to_field.items():
        if isinstance(field, dataclasses.Field):
            # Idiosyncrasy: if a pydantic Field is used in a pydantic dataclass, then field.default is a FieldInfo
            # object instead of the field's default value. Furthermore, field.annotation is always NoneType. Luckily,
            # the actual type of the field is stored in field.type
            if isinstance(field.default, _PYDANTIC_FIELD_TYPES):
                arg_data = arg_data_from_pydantic(name, field.default, annotation=field.type)
            else:
                arg_data = arg_data_from_dataclass(name, field)
        elif isinstance(field, _PYDANTIC_FIELD_TYPES):
            arg_data = arg_data_from_pydantic(name, field)
        else:
            raise TypeError(f"Each field must be a dataclass or Pydantic field. Got {type(field)}")
        # Handle case where func_kwargs is supplied
        if name in func_kwargs:
            arg_data.default = func_kwargs[name]
            arg_data.is_required = False
            del func_kwargs[name]
        args_data.append(arg_data)
    return _TapData(args_data, has_kwargs, known_only)
176 |
177 |
def _tap_data_from_class_or_function(
    class_or_function: _ClassOrFunction, func_kwargs: dict[str, Any], param_to_description: dict[str, str]
) -> _TapData:
    """
    Extract data by inspecting the signature of `class_or_function`.

    :param class_or_function: The class or function whose signature is inspected.
    :param func_kwargs: Default overrides by parameter name. A parameter found here becomes optional with
                        that default.
    :param param_to_description: Descriptions by parameter name.
    :return: The argument data of every parameter (except `**kwargs`) plus whether `**kwargs` is accepted.

    Note
    ----
    Deletes redundant keys from `func_kwargs`
    """
    args_data: list[_ArgData] = []
    has_kwargs = False
    known_only = False

    sig = inspect.signature(class_or_function)
    # Sentinel that inspect uses for "no annotation" / "no default". It is compared by identity so that the
    # comparison never goes through the __eq__/__ne__ of a user-supplied default or annotation
    empty = inspect.Parameter.empty

    for param_name, param in sig.parameters.items():
        # Skip **kwargs
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            has_kwargs = True
            known_only = True
            continue

        annotation = Any if param.annotation is empty else param.annotation

        if param.name in func_kwargs:
            is_required = False
            default = func_kwargs[param.name]
            del func_kwargs[param.name]
        elif param.default is not empty:
            is_required = False
            default = param.default
        else:
            is_required = True
            default = empty  # Can be set to anything. It'll be ignored

        arg_data = _ArgData(
            name=param_name,
            annotation=annotation,
            is_required=is_required,
            default=default,
            description=param_to_description.get(param.name),
            is_positional_only=param.kind == inspect.Parameter.POSITIONAL_ONLY,
        )
        args_data.append(arg_data)
    return _TapData(args_data, has_kwargs, known_only)
227 |
228 |
229 | def _is_data_model(obj: Union[type[Any], Any]) -> bool:
230 | return dataclasses.is_dataclass(obj) or _is_pydantic_base_model(obj)
231 |
232 |
def _docstring(class_or_function) -> Docstring:
    """Parses the docstring that documents the arguments of `class_or_function`.

    For classes other than Pydantic BaseModels, the docstring of `__init__` is preferred and the class
    docstring is the fallback. For everything else, the object's own docstring is used.
    """
    if inspect.isclass(class_or_function) and not _is_pydantic_base_model(class_or_function):
        doc = class_or_function.__init__.__doc__ or class_or_function.__doc__
    else:
        doc = class_or_function.__doc__
    return parse(doc)
240 |
241 |
def _tap_data(class_or_function: _ClassOrFunction, param_to_description: dict[str, str], func_kwargs) -> _TapData:
    """
    Decides which strategy is used to extract :class:`_TapData` from `class_or_function`: looking at the fields
    of a data model or looking at the signature.
    """
    if _IS_PYDANTIC_V1:
        is_pydantic_v1_data_model = _is_pydantic_base_model(class_or_function) or _is_pydantic_dataclass(
            class_or_function
        )
    else:
        is_pydantic_v1_data_model = _IS_PYDANTIC_V1

    # Data models from Pydantic v1 don't lend themselves well to _tap_data_from_data_model.
    # _tap_data_from_data_model looks at the data model's fields. In Pydantic v1, the field.type_ attribute stores
    # the field's annotation/type. But (in Pydantic v1) there's a bug where field.type_ is set to the inner-most
    # type of a subscripted type. For example, annotating a field with list[str] causes field.type_ to be str, not
    # list[str]. To get around this, _TapData is extracted from the signature of such data models instead
    if not _is_data_model(class_or_function) or is_pydantic_v1_data_model:
        # TODO: allow passing func_kwargs to a Pydantic BaseModel
        return _tap_data_from_class_or_function(class_or_function, func_kwargs, param_to_description)

    return _tap_data_from_data_model(class_or_function, func_kwargs, param_to_description)
258 |
259 |
def _tap_class(args_data: Sequence[_ArgData]) -> type[Tap]:
    """
    Builds a :class:`tap.Tap` subclass which declares one `--<name>` argument per entry of `args_data`.
    The arguments are added to the parser when the class is instantiated.
    """

    class ArgParser(Tap):
        # _configure is overwritten instead of configure so that users who subclass this class and overwrite
        # configure do not have to remember to call super().configure
        def _configure(self):
            for arg_data in args_data:
                variable = arg_data.name

                # Arguments annotated with Any are parsed as strings
                if arg_data.annotation is Any:
                    self._annotations[variable] = str
                else:
                    self._annotations[variable] = arg_data.annotation

                self.class_variables[variable] = {"comment": arg_data.description or ""}

                # Only optional arguments carry a default
                kwargs = {} if arg_data.is_required else dict(required=False, default=arg_data.default)
                self.add_argument(f"--{variable}", **kwargs)

            super()._configure()

    return ArgParser
282 |
283 |
def to_tap_class(class_or_function: _ClassOrFunction) -> type[Tap]:
    """Creates a `Tap` class from `class_or_function`. The result can be subclassed to add custom argument
    handling and instantiated to create a typed argument parser.

    :param class_or_function: The class or function whose arguments define the arguments of the `Tap` class.
    :return: A `Tap` subclass with one argument per argument of `class_or_function`.
    """
    # Argument descriptions come from the docstring
    param_to_description = {}
    for param in _docstring(class_or_function).params:
        param_to_description[param.arg_name] = param.description

    # TODO: add func_kwargs
    tap_data = _tap_data(class_or_function, param_to_description, func_kwargs={})

    return _tap_class(tap_data.args_data)
295 |
296 |
def tapify(
    class_or_function: Union[Callable[..., OutputType], type[OutputType]],
    known_only: bool = False,
    command_line_args: Optional[list[str]] = None,
    explicit_bool: bool = False,
    underscores_to_dashes: bool = False,
    description: Optional[str] = None,
    **func_kwargs,
) -> OutputType:
    """Initializes a class or runs a function with arguments that are parsed from the command line.

    :param class_or_function: The class or function to run with the provided arguments.
    :param known_only: If true, ignores extra arguments and only parses known arguments.
    :param command_line_args: A list of command line style arguments to parse (e.g., `['--arg', 'value']`). If None,
                              arguments are parsed from the command line (default behavior).
    :param explicit_bool: Booleans can be specified on the command line as `--arg True` or `--arg False` rather than
                          `--arg`. Additionally, booleans can be specified by prefixes of True and False with any
                          capitalization as well as 1 or 0.
    :param underscores_to_dashes: If True, convert underscores in flag names to dashes.
    :param description: The description displayed in the help message, i.e. the same description passed in
                        `argparse.ArgumentParser(description=...)`. By default, it's extracted from
                        `class_or_function`'s docstring.
    :param func_kwargs: Additional keyword arguments for the function. These act as default values when parsing the
                        command line arguments and overwrite the function defaults but are overwritten by the parsed
                        command line arguments.
    :return: Whatever calling `class_or_function` with the parsed arguments returns.
    :raises ValueError: If `func_kwargs` contains names that are not arguments of `class_or_function`
                        while unknown arguments are not allowed.
    """
    # to_tap_class is not called here b/c tap_data itself is needed later on, not just the tap class
    docstring = _docstring(class_or_function)
    param_to_description = {}
    for param in docstring.params:
        param_to_description[param.arg_name] = param.description
    tap_data = _tap_data(class_or_function, param_to_description, func_kwargs)
    tap_class = _tap_class(tap_data.args_data)

    # Build the Tap object, falling back to the docstring for the description
    if description is None:
        docstring_parts = (docstring.short_description, docstring.long_description)
        description = "\n".join(part for part in docstring_parts if part)
    tap = tap_class(description=description, explicit_bool=explicit_bool, underscores_to_dashes=underscores_to_dashes)

    # Whatever is left in func_kwargs did not match an argument of class_or_function
    known_only = known_only or tap_data.known_only
    if func_kwargs and not known_only:
        raise ValueError(f"Unknown keyword arguments: {func_kwargs}")

    # Parse the command line arguments
    parsed_args: Tap = tap.parse_args(args=command_line_args, known_only=known_only)
    parsed_values = parsed_args.as_dict()

    # Distribute the parsed values, passing positional-only arguments positionally
    class_or_function_args: list[Any] = []
    class_or_function_kwargs: dict[str, Any] = {}
    for arg_data in tap_data.args_data:
        arg_value = parsed_values[arg_data.name]
        if arg_data.is_positional_only:
            class_or_function_args.append(arg_value)
        else:
            class_or_function_kwargs[arg_data.name] = arg_value

    # Extra command line arguments are read as consecutive (flag, value) pairs and forwarded as **kwargs
    if tap_data.has_kwargs:
        extra_args = tap.extra_args
        extra_kwargs = {}
        for flag_index in range(0, len(extra_args), 2):
            extra_kwargs[extra_args[flag_index].lstrip("-")] = extra_args[flag_index + 1]
        class_or_function_kwargs.update(extra_kwargs)

    # Initialize the class or run the function with the parsed arguments
    return class_or_function(*class_or_function_args, **class_or_function_kwargs)
359 |
--------------------------------------------------------------------------------
/src/tap/utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, ArgumentTypeError
2 | import ast
3 | from base64 import b64encode, b64decode
4 | import inspect
5 | from io import StringIO
6 | from json import JSONEncoder
7 | import os
8 | import pickle
9 | import re
10 | import subprocess
11 | import sys
12 | import textwrap
13 | import tokenize
14 | from typing import (
15 | Any,
16 | Callable,
17 | Generator,
18 | Iterable,
19 | Iterator,
20 | Literal,
21 | Optional,
22 | Union,
23 | )
24 | from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
25 | import warnings
26 |
# types.UnionType (the type of `X | Y` expressions) only exists in Python 3.10+
if sys.version_info >= (3, 10):
    from types import UnionType

# Text that `git status` output is expected to end with when there is nothing to commit
NO_CHANGES_STATUS = """nothing to commit, working tree clean"""
# The primitive builtin types
PRIMITIVES = (str, int, float, bool)
# Anything that can be used as a file system path
PathLike = Union[str, os.PathLike]
33 |
34 |
def check_output(command: list[str], suppress_stderr: bool = True, **kwargs) -> str:
    """Runs subprocess.check_output and returns the result as a string.

    :param command: A list of strings representing the command to run on the command line.
    :param suppress_stderr: Whether to suppress anything written to standard error.
    :param kwargs: Additional keyword arguments forwarded to subprocess.check_output (e.g., cwd).
    :return: The output of the command, converted from bytes to string and stripped.
    :raises subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    # subprocess.DEVNULL discards the stream without a manually opened file handle;
    # None leaves standard error connected to the parent process
    stderr = subprocess.DEVNULL if suppress_stderr else None
    return subprocess.check_output(command, stderr=stderr, **kwargs).decode("utf-8").strip()
46 |
47 |
class GitInfo:
    """Class with helper methods for extracting information about a git repo."""

    def __init__(self, repo_path: PathLike):
        """
        :param repo_path: Path inside the git repo to examine. The git commands are run from this directory.
        """
        self.repo_path = repo_path

    def has_git(self) -> bool:
        """Returns whether git is usable and `repo_path` is inside a git work tree.

        :return: True if `git rev-parse --is-inside-work-tree` reports "true".
                 False if it reports anything else, if git cannot be found, or if the command fails.
        """
        try:
            output = check_output(["git", "rev-parse", "--is-inside-work-tree"], cwd=self.repo_path)
            return output == "true"
        except (FileNotFoundError, subprocess.CalledProcessError):
            return False

    def get_git_root(self) -> str:
        """Gets the root directory of the git repo where the command is run.

        :return: The root directory of the current git repo.
        """
        return check_output(["git", "rev-parse", "--show-toplevel"], cwd=self.repo_path)

    def get_git_version(self) -> tuple:
        """Gets the version of git.

        :return: The version of git, as a tuple of ints. Version components that are not purely numeric are dropped.

        Example:
        >>> get_git_version()
        (2, 17, 1)  # for git version 2.17.1
        """
        raw = check_output(["git", "--version"])
        # The version number starts at the first digit of the output
        number_start_index = next(i for i, c in enumerate(raw) if c.isdigit())
        return tuple(int(num) for num in raw[number_start_index:].split(".") if num.isdigit())

    def get_git_url(self, commit_hash: bool = True) -> str:
        """Gets the https url of the git repo where the command is run.

        :param commit_hash: If True, the url links to the latest local git commit hash.
                            If False, the url links to the general git url.
        :return: The https url of the current git repo or an empty string for a local repo.
        :raises subprocess.CalledProcessError: If the remote url cannot be retrieved for a reason
                                               other than a missing remote.
        """
        # Get git url (either https or ssh)
        input_remote = (
            ["git", "remote", "get-url", "origin"]
            if self.get_git_version() >= (2, 0)
            else ["git", "config", "--get", "remote.origin.url"]
        )
        try:
            url = check_output(input_remote, cwd=self.repo_path)
        except subprocess.CalledProcessError as e:
            if e.returncode in {2, 128}:
                # https://git-scm.com/docs/git-remote#_exit_status
                # 2: The remote does not exist.
                # 128: The remote was not found.
                return ""
            raise

        url = self._to_https_url(url)

        if commit_hash:
            # Add tree and hash of current commit
            url = f"{url}/tree/{self.get_git_hash()}"

        return url

    @staticmethod
    def _to_https_url(url: str) -> str:
        """Normalizes a remote url: drops a trailing ".git" and converts an ssh url to an https url.

        :param url: The remote url as reported by git (either https or ssh).
        :return: The corresponding https url without a ".git" suffix.
        """
        # Remove .git at end, but only if it is actually there. Remotes can be configured without the
        # suffix, and blindly cutting four characters would corrupt such urls
        if url.endswith(".git"):
            url = url[: -len(".git")]

        # Convert ssh url to https url
        m = re.search("git@(.+):", url)
        if m is not None:
            domain = m.group(1)
            path = url[m.span()[1] :]
            url = f"https://{domain}/{path}"

        return url

    def get_git_hash(self) -> str:
        """Gets the git hash of HEAD of the git repo where the command is run.

        :return: The git hash of HEAD of the current git repo.
        """
        return check_output(["git", "rev-parse", "HEAD"], cwd=self.repo_path)

    def has_uncommitted_changes(self) -> bool:
        """Returns whether there are uncommitted changes in the git repo where the command is run.

        :return: True if the output of `git status` does not end with NO_CHANGES_STATUS, False otherwise.
        """
        status = check_output(["git", "status"], cwd=self.repo_path)

        return not status.endswith(NO_CHANGES_STATUS)
139 |
140 |
def type_to_str(type_annotation: Union[type, Any]) -> str:
    """Builds a human-readable name for a type annotation.

    :param type_annotation: A type annotation, which is either a built-in type or a typing type.
    :return: ``__name__`` for objects whose type is exactly ``type``; otherwise the ``str()``
             of the annotation with every "typing." occurrence removed.
    """
    is_plain_class = type(type_annotation) is type

    if not is_plain_class:
        # Typing type
        return str(type_annotation).replace("typing.", "")

    # Built-in type
    return type_annotation.__name__
153 |
154 |
def get_argument_name(*name_or_flags) -> str:
    """Derives the canonical argument name from a name or a set of option strings.

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :return: "help" for the help flags; otherwise the single remaining name with leading dashes removed.
    :raises ValueError: If zero or several candidate names remain.
    """
    if any(help_flag in name_or_flags for help_flag in ("-h", "--help")):
        return "help"

    candidates = name_or_flags

    # With several entries, only the long ("--") option strings are candidates
    if len(candidates) > 1:
        candidates = tuple(filter(lambda flag: flag.startswith("--"), candidates))

    if len(candidates) == 1:
        (canonical,) = candidates
        return canonical.lstrip("-")

    raise ValueError(f"There should only be a single canonical name for argument {candidates}!")
171 |
172 |
def get_dest(*name_or_flags, **kwargs) -> str:
    """Determines the destination attribute name argparse would assign to the argument.

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param kwargs: Keyword arguments forwarded to ``ArgumentParser.add_argument``.
    :return: "help" for the help flags; otherwise the ``dest`` of the action argparse creates.
    """
    for help_flag in ("-h", "--help"):
        if help_flag in name_or_flags:
            return "help"

    # Register the argument on a throwaway parser and read back the dest it computed
    scratch_parser = ArgumentParser()
    action = scratch_parser.add_argument(*name_or_flags, **kwargs)

    return action.dest
184 |
185 |
def is_option_arg(*name_or_flags) -> bool:
    """Tells whether the argument is an option arg (as opposed to a positional arg).

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :return: True if at least one entry starts with "-", False otherwise (including when empty).
    """
    for candidate in name_or_flags:
        if candidate.startswith("-"):
            return True

    return False
193 |
194 |
def is_positional_arg(*name_or_flags) -> bool:
    """Tells whether the argument is a positional arg (as opposed to an option arg).

    :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :return: The negation of ``is_option_arg`` for the same inputs.
    """
    looks_like_option = is_option_arg(*name_or_flags)

    return not looks_like_option
202 |
203 |
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
    """Lazily tokenizes the given source code.

    :param source: The source code to tokenize.
    :return: A generator over the tokens produced by ``tokenize.generate_tokens``.
    """
    line_reader = StringIO(source).readline

    return tokenize.generate_tokens(line_reader)
207 |
208 |
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
    """Finds the column at which the class body starts, given the tokens of the class.

    :param tokens: The tokens of the class source code.
    :return: The start column of the first non-blank token located after the header lines.
    :raises ValueError: If no such token exists.
    """
    # Line 1 is treated as the header; every "@" token pushes the header boundary one line further
    header_end_line = 1

    for token in tokens:
        text = token[1].strip()
        start_line, start_column = token[2]

        if text == "@":
            header_end_line += 1

        if text and start_line > header_end_line:
            return start_column

    raise ValueError("Could not find any class variables in the class.")
220 |
221 |
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> dict[int, list[dict[str, Union[str, int]]]]:
    """Groups tokens by the line on which they start.

    :param tokens: The tokens to group.
    :return: A map from start line number to the list of tokens starting on that line, where each
             token is a dict with keys "token_type", "token", "start_line", "start_column",
             "end_line", "end_column", and "line".
    """
    grouped = {}

    for token_type, token, start, end, line in tokens:
        start_line, start_column = start
        end_line, end_column = end

        token_info = dict(
            token_type=token_type,
            token=token,
            start_line=start_line,
            start_column=start_column,
            end_line=end_line,
            end_column=end_column,
            line=line,
        )

        if start_line not in grouped:
            grouped[start_line] = []
        grouped[start_line].append(token_info)

    return grouped
239 |
240 |
def get_subsequent_assign_lines(source_cls: str) -> tuple[set[int], set[int]]:
    """Locates the follow-up lines of every multiline assignment in a class body.

    :param source_cls: The source code of the class.
    :return: A set with the line numbers strictly between the first and last line of each multiline
             assignment, and a set with the last line number of each multiline assignment. Line numbers
             are relative to source_cls. Two empty sets are returned (after a warning) when
             source_cls does not parse to exactly one class definition.
    """
    # Wrap the class in an if statement so that an indented class source still parses
    wrapped_source = "if True:\n" + textwrap.indent(source_cls, " ")
    wrapper = ast.parse(wrapped_source).body[0]

    parse_warning = (
        "Could not parse class source code to extract comments. Comments in the help string may be incorrect."
    )

    # The wrapper must be an if statement holding exactly one statement: the class definition
    holds_single_class = (
        isinstance(wrapper, ast.If) and len(wrapper.body) == 1 and isinstance(wrapper.body[0], ast.ClassDef)
    )
    if not holds_single_class:
        warnings.warn(parse_warning)
        return set(), set()

    intermediate_assign_lines = set()
    final_assign_lines = set()

    for statement in wrapper.body[0].body:
        if not isinstance(statement, (ast.Assign, ast.AnnAssign)):
            continue

        last_line = statement.end_lineno

        # Without an end line number, the extent of the assignment is unknown
        if last_line is None:
            warnings.warn(parse_warning)
            continue

        # Single-line assignments have no subsequent lines
        if last_line <= statement.lineno:
            continue

        # The wrapping if statement shifts every line number by one: statement.lineno is
        # therefore the second line of the assignment in source_cls numbering, and
        # last_line - 1 is its final line.
        intermediate_assign_lines.update(range(statement.lineno, last_line - 1))
        final_assign_lines.add(last_line - 1)

    return intermediate_assign_lines, final_assign_lines
296 |
297 |
def get_class_variables(cls: type) -> dict[str, dict[str, str]]:
    """Returns a dictionary mapping class variables to their additional information (currently just comments).

    A class variable is recognized as a NAME token that starts a line at the class body column and is
    directly followed by "=" or ":". Its comment is the "#" comment on the same line (or on the last
    line of a multiline assignment), extended by a string literal that starts the next processed line.

    :param cls: The class whose source code is inspected via ``inspect.getsource``.
    :return: A map from class variable name to a dict with the single key "comment".
    """
    # Get the source code and tokens of the class
    source_cls = inspect.getsource(cls)
    tokens = tuple(tokenize_source(source_cls))

    # Get mapping from line number to tokens
    line_to_tokens = source_line_to_tokens(tokens)

    # Get class variable column number
    class_variable_column = get_class_column(tokens)

    # For all multiline assign statements, get the line numbers after the first line of the assignment
    # This is used to avoid identifying comments in multiline assign statements
    intermediate_assign_lines, final_assign_lines = get_subsequent_assign_lines(source_cls)

    # Extract class variables
    class_variable = None
    variable_to_comment = {}
    for line, line_tokens in line_to_tokens.items():
        # If this is the final line of a multiline assign, extract any potential comments
        if line in final_assign_lines:
            # Only attach the comment when a class variable is being tracked. A multiline assign
            # whose first line did not match a class variable (e.g. tuple unpacking) leaves
            # class_variable as None, and indexing with it would raise a KeyError.
            if class_variable is not None:
                for token in line_tokens:
                    if token["token_type"] == tokenize.COMMENT:
                        # Leave out "#" and whitespace from comment
                        variable_to_comment[class_variable]["comment"] = token["token"][1:].strip()
                        break
            continue

        # Skip assign lines after the first line of multiline assign statements
        if line in intermediate_assign_lines:
            continue

        for i, token in enumerate(line_tokens):
            # Skip whitespace
            if token["token"].strip() == "":
                continue

            # Extract multiline comments
            if (
                class_variable is not None
                and token["token_type"] == tokenize.STRING
                and token["token"][:1] in {'"', "'"}
            ):
                sep = " " if variable_to_comment[class_variable]["comment"] else ""

                # Identify the quote character (single or double)
                quote_char = token["token"][:1]

                # Identify the number of quote characters at the start of the string
                num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char))

                # Remove the number of quote characters at the start of the string and the end of the string
                docstring = token["token"][num_quote_chars:-num_quote_chars]

                # Resolve the escape sequences (e.g. "\""). Encoding as latin-1 with backslashreplace
                # keeps non-ASCII characters intact through the unicode-escape round trip, whereas an
                # ASCII encoding raises UnicodeEncodeError on them.
                docstring = docstring.encode("latin-1", "backslashreplace").decode("unicode-escape")

                # Add the docstring to the comment, stripping whitespace
                variable_to_comment[class_variable]["comment"] += sep + docstring.strip()

            # Match class variable
            class_variable = None
            if (
                token["token_type"] == tokenize.NAME
                and token["start_column"] == class_variable_column
                # Make sure a following token exists before looking at it
                and i + 1 < len(line_tokens)
                and line_tokens[i + 1]["token"] in ["=", ":"]
            ):

                class_variable = token["token"]
                variable_to_comment[class_variable] = {"comment": ""}

                # Find the comment (if it exists)
                for j in range(i + 1, len(line_tokens)):
                    if line_tokens[j]["token_type"] == tokenize.COMMENT:
                        # Leave out "#" and whitespace from comment
                        variable_to_comment[class_variable]["comment"] = line_tokens[j]["token"][1:].strip()
                        break

            # Only the first non-whitespace token of each line is examined
            break

    return variable_to_comment
382 |
383 |
def get_literals(literal: Literal, variable: str) -> tuple[Callable[[str], Any], list[type]]:
    """Extracts the values from a Literal type and ensures that the values are all primitive types.

    :param literal: The Literal type whose values are extracted with ``get_args``.
    :param variable: The name of the variable with this type, used in error messages.
    :return: A function that maps a string to the literal value with that string representation,
             and the list of literal values.
    :raises ArgumentTypeError: If a literal value is not an instance of ``PRIMITIVES`` or if two
                               literal values share the same string representation.
    """
    literals = list(get_args(literal))

    if not all(isinstance(value, PRIMITIVES) for value in literals):
        raise ArgumentTypeError(
            f'The type for variable "{variable}" contains a literal '
            f"of a non-primitive type e.g. (str, int, float, bool).\n"
            f"Currently only primitive-typed literals are supported."
        )

    # Command line values arrive as strings, so look literals up by their string representation
    str_to_literal = {str(value): value for value in literals}

    if len(literals) != len(str_to_literal):
        raise ArgumentTypeError("All literals must have unique string representations")

    def var_type(arg: str) -> Any:
        """Converts a command line string into the matching literal value."""
        if arg not in str_to_literal:
            raise ArgumentTypeError(f'Value for variable "{variable}" must be one of {literals}.')

        return str_to_literal[arg]

    return var_type, literals
407 |
408 |
def boolean_type(flag_value: str) -> bool:
    """Convert a string to a boolean if it is a prefix of 'True' or 'False' (case insensitive) or is '1' or '0'.

    :param flag_value: The string to convert.
    :return: The boolean that the string represents.
    :raises ArgumentTypeError: If the string is empty or is neither a prefix of "True"/"False" nor "1"/"0".
    """
    lowered = flag_value.lower()

    # The empty string is a prefix of every string, so it must be rejected explicitly;
    # otherwise it would silently be interpreted as True.
    if lowered:
        if "true".startswith(lowered) or flag_value == "1":
            return True
        if "false".startswith(lowered) or flag_value == "0":
            return False

    raise ArgumentTypeError('Value has to be a prefix of "True" or "False" (case insensitive) or "1" or "0".')
416 |
417 |
class TupleTypeEnforcer:
    """Callable for argparse's ``type`` argument that converts successive values with successive types.

    Each call converts its argument with the type at the current position and then advances the
    position. With ``loop=True`` the position wraps around to the first type after the last one.
    """

    def __init__(self, types: list[type], loop: bool = False):
        converters = []
        for element_type in types:
            # bool values are parsed with boolean_type instead of the bool constructor
            converters.append(boolean_type if element_type == bool else element_type)

        self.types = converters
        self.loop = loop
        self.index = 0

    def __call__(self, arg: str) -> Any:
        converter = self.types[self.index]
        converted = converter(arg)

        # Advance only after a successful conversion
        next_index = self.index + 1
        self.index = next_index % len(self.types) if self.loop else next_index

        return converted
434 |
435 |
class MockTuple:
    """Wrapper around a tuple, needed to prevent JSON encoding tuples as lists.

    The wrapped tuple is available as the ``tuple`` attribute.
    """

    def __init__(self, _tuple: tuple) -> None:
        # Keep a reference to the wrapped tuple (no copy is made)
        self.tuple = _tuple
441 |
442 |
443 | def _nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any:
444 | """Replaces any instance (including instances within lists, tuple, dict) of find_type with an instance of replace_type.
445 |
446 | Note: Tuples, lists, and dicts are NOT modified in place.
447 | Note: Does NOT do a nested search through objects besides tuples, lists, and dicts (e.g. sets).
448 |
449 | :param obj: The object to modify by replacing find_type instances with replace_type instances.
450 | :param find_type: The type to find in obj.
451 | :param replace_type: The type to used to replace find_type in obj.
452 | :return: A version of obj with all instances of find_type replaced by replace_type
453 | """
454 | if isinstance(obj, tuple):
455 | obj = tuple(_nested_replace_type(item, find_type, replace_type) for item in obj)
456 |
457 | elif isinstance(obj, list):
458 | obj = [_nested_replace_type(item, find_type, replace_type) for item in obj]
459 |
460 | elif isinstance(obj, dict):
461 | obj = {
462 | _nested_replace_type(key, find_type, replace_type): _nested_replace_type(value, find_type, replace_type)
463 | for key, value in obj.items()
464 | }
465 |
466 | if isinstance(obj, find_type):
467 | obj = replace_type(obj)
468 |
469 | return obj
470 |
471 |
def define_python_object_encoder(skip_unpicklable: bool = False) -> "PythonObjectEncoder":  # noqa F821
    """Creates a JSON encoder class that can handle objects that are not JSON serializable.

    :param skip_unpicklable: If True, objects that cannot be pickled are encoded as a placeholder
                             with a None value. If False, such objects raise a ValueError.
    :return: The PythonObjectEncoder class (not an instance).
    """

    class PythonObjectEncoder(JSONEncoder):
        """Stores parameters that are not JSON serializable as pickle dumps.

        See: https://stackoverflow.com/a/36252257
        """

        def iterencode(self, o: Any, _one_shot: bool = False) -> Iterator[str]:
            # Wrap tuples so that they reach default() instead of being encoded as JSON lists
            o = _nested_replace_type(o, tuple, MockTuple)
            return super().iterencode(o, _one_shot)

        def default(self, obj: Any) -> Any:
            if isinstance(obj, set):
                return {"_type": "set", "_value": list(obj)}

            if isinstance(obj, MockTuple):
                return {"_type": "tuple", "_value": list(obj.tuple)}

            try:
                return {
                    "_type": f"python_object (type = {obj.__class__.__name__})",
                    "_value": b64encode(pickle.dumps(obj)).decode("utf-8"),
                    "_string": str(obj),
                }
            except (pickle.PicklingError, TypeError, AttributeError) as e:
                if skip_unpicklable:
                    return {"_type": f"unpicklable_object {obj.__class__.__name__}", "_value": None}

                # The option named here must match the skip_unpicklable parameter so that
                # users can act on the message.
                raise ValueError(
                    f"Could not pickle this object: Failed with exception {e}\n"
                    f"If you would like to ignore unpicklable attributes set "
                    f"skip_unpicklable = True in save."
                ) from e

    return PythonObjectEncoder
506 |
507 |
class UnpicklableObject:
    """A class that serves as a placeholder for an object that could not be pickled.

    All instances compare equal to each other and share the same hash.
    """

    def __eq__(self, other):
        return isinstance(other, UnpicklableObject)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None and makes instances unhashable.
        # Since all instances are equal, they must all share one hash value.
        return hash(UnpicklableObject)
513 |
514 |
def as_python_object(dct: Any) -> Any:
    """Decodes the special "_type"/"_value" encoding of values that are not JSON serializable.

    See: https://stackoverflow.com/a/36252257

    :param dct: A decoded JSON object.
    :return: The decoded Python object if dct has both a "_type" and a "_value" key, otherwise dct itself.
    :raises ArgumentTypeError: If the "_type" is not one of the supported special types.
    """
    # Anything without both marker keys is a regular JSON object
    if not ("_type" in dct and "_value" in dct):
        return dct

    _type = dct["_type"]
    value = dct["_value"]

    if _type == "tuple":
        return tuple(value)

    if _type == "set":
        return set(value)

    if _type.startswith("python_object"):
        # NOTE(review): pickle.loads can execute arbitrary code, so this must only be used on trusted files
        return pickle.loads(b64decode(value.encode("utf-8")))

    if _type.startswith("unpicklable_object"):
        return UnpicklableObject()

    raise ArgumentTypeError(f'Special type "{_type}" not supported for JSON loading.')
539 |
540 |
def enforce_reproducibility(
    saved_reproducibility_data: Optional[dict[str, str]], current_reproducibility_data: dict[str, str], path: PathLike
) -> None:
    """Raises an error if the saved and current reproducibility information do not guarantee reproducibility.

    The checks run in order and the first one that fails determines the error message.

    :param saved_reproducibility_data: Reproducibility information loaded from a saved file.
    :param current_reproducibility_data: Reproducibility information from the current object.
    :param path: The path name of the file that is being loaded.
    :raises ValueError: If reproducibility information or a git url is missing, if the git urls differ,
                        or if either side reports uncommitted changes.
    """
    if saved_reproducibility_data is None:
        problem = f'Could not find reproducibility information in args loaded from "{path}".'
    elif "git_url" not in saved_reproducibility_data:
        problem = f'Could not find git url in args loaded from "{path}".'
    elif "git_url" not in current_reproducibility_data:
        problem = "Could not find git url in current args."
    elif saved_reproducibility_data["git_url"] != current_reproducibility_data["git_url"]:
        problem = f'Differing git url/hash between current args and args loaded from "{path}".'
    elif saved_reproducibility_data["git_has_uncommitted_changes"]:
        problem = f'Uncommitted changes in args loaded from "{path}".'
    elif current_reproducibility_data["git_has_uncommitted_changes"]:
        problem = "Uncommitted changes in current args."
    else:
        # Every check passed
        return

    raise ValueError(f"Reproducibility not guaranteed: {problem}")
575 |
576 |
577 | # TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 and 3.10
578 | # https://github.com/ilevkivskyi/typing_inspect/issues/64
579 | # https://github.com/ilevkivskyi/typing_inspect/issues/65
def get_origin(tp: Any) -> Any:
    """Same as typing_inspect.get_origin but fixes unparameterized generic types like Set.

    :param tp: The type to inspect.
    :return: The origin reported by typing_inspect, or tp itself when no origin is reported.
             On Python 3.10+, ``UnionType`` is returned when that result is a ``UnionType`` instance.
    """
    resolved = typing_inspect_get_origin(tp)
    resolved = tp if resolved is None else resolved

    # The version check comes first so that UnionType is only referenced on Python 3.10+
    is_union_type = sys.version_info >= (3, 10) and isinstance(resolved, UnionType)

    return UnionType if is_union_type else resolved
591 |
592 |
593 | # TODO: remove this once typing_inspect.get_args is fixed for Python 3.10 union types
def get_args(tp: Any) -> tuple[type, ...]:
    """Same as typing_inspect.get_args but fixes Python 3.10 union types.

    :param tp: The type to inspect.
    :return: ``tp.__args__`` for ``UnionType`` instances on Python 3.10+, otherwise the result
             of typing_inspect.get_args.
    """
    # The version check comes first so that UnionType is only referenced on Python 3.10+
    if sys.version_info < (3, 10) or not isinstance(tp, UnionType):
        return typing_inspect_get_args(tp)

    return tp.__args__
600 |
--------------------------------------------------------------------------------
/tests/test_actions.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal
2 | import unittest
3 | from unittest import TestCase
4 |
5 | from tap import Tap
6 |
7 |
class TestArgparseActions(TestCase):
    """Checks how Tap handles the argparse ``action`` values passed to add_argument in configure."""

    def test_actions_store_const(self):
        class StoreConstTap(Tap):
            def configure(self):
                self.add_argument("--sum", dest="accumulate", action="store_const", const=sum, default=max)

        # Without the flag, the default is stored under dest ("accumulate"), not under "sum"
        args = StoreConstTap().parse_args([])
        self.assertFalse(hasattr(args, "sum"))
        self.assertEqual(args.accumulate, max)
        self.assertEqual(args.as_dict(), {"accumulate": max})

        # With the flag, the const is stored under dest
        args = StoreConstTap().parse_args(["--sum"])
        self.assertFalse(hasattr(args, "sum"))
        self.assertEqual(args.accumulate, sum)
        self.assertEqual(args.as_dict(), {"accumulate": sum})

    def test_actions_store_true_default_true(self):
        class StoreTrueDefaultTrueTap(Tap):
            foobar: bool = True

            def configure(self):
                self.add_argument("--foobar", action="store_true")

        # The value is True with or without the flag
        args = StoreTrueDefaultTrueTap().parse_args([])
        self.assertTrue(args.foobar)

        args = StoreTrueDefaultTrueTap().parse_args(["--foobar"])
        self.assertTrue(args.foobar)

    def test_actions_store_true_default_false(self):
        class StoreTrueDefaultFalseTap(Tap):
            foobar: bool = False

            def configure(self):
                self.add_argument("--foobar", action="store_true")

        # The flag flips the value from False to True
        args = StoreTrueDefaultFalseTap().parse_args([])
        self.assertFalse(args.foobar)

        args = StoreTrueDefaultFalseTap().parse_args(["--foobar"])
        self.assertTrue(args.foobar)

    def test_actions_store_false_default_true(self):
        class StoreFalseDefaultTrueTap(Tap):
            foobar: bool = True

            def configure(self):
                self.add_argument("--foobar", action="store_false")

        # The flag flips the value from True to False
        args = StoreFalseDefaultTrueTap().parse_args([])
        self.assertTrue(args.foobar)

        args = StoreFalseDefaultTrueTap().parse_args(["--foobar"])
        self.assertFalse(args.foobar)

    def test_actions_store_false_default_false(self):
        class StoreFalseDefaultFalseTap(Tap):
            foobar: bool = False

            def configure(self):
                self.add_argument("--foobar", action="store_false")

        # The value is False with or without the flag
        args = StoreFalseDefaultFalseTap().parse_args([])
        self.assertFalse(args.foobar)

        args = StoreFalseDefaultFalseTap().parse_args(["--foobar"])
        self.assertFalse(args.foobar)

    def test_actions_append_list(self):
        class AppendListTap(Tap):
            arg: List = ["what", "is"]

            def configure(self):
                self.add_argument("--arg", action="append")

        args = AppendListTap().parse_args([])
        self.assertEqual(args.arg, ["what", "is"])

        # Appended values come after the default values
        args = AppendListTap().parse_args("--arg up --arg today".split())
        self.assertEqual(args.arg, "what is up today".split())

    def test_actions_append_list_int(self):
        class AppendListIntTap(Tap):
            arg: List[int] = [1, 2]

            def configure(self):
                self.add_argument("--arg", action="append")

        # Appended values are expected as ints, not as the strings given on the command line
        args = AppendListIntTap().parse_args("--arg 3 --arg 4".split())
        self.assertEqual(args.arg, [1, 2, 3, 4])

    def test_actions_append_list_literal(self):
        class AppendListLiteralTap(Tap):
            arg: List[Literal["what", "is", "up", "today"]] = ["what", "is"]

            def configure(self):
                self.add_argument("--arg", action="append")

        args = AppendListLiteralTap().parse_args("--arg up --arg today".split())
        self.assertEqual(args.arg, "what is up today".split())

    def test_actions_append_untyped(self):
        class AppendListStrTap(Tap):
            arg = ["what", "is"]

            def configure(self):
                self.add_argument("--arg", action="append")

        args = AppendListStrTap().parse_args([])
        self.assertEqual(args.arg, ["what", "is"])

        args = AppendListStrTap().parse_args("--arg up --arg today".split())
        self.assertEqual(args.arg, "what is up today".split())

    def test_actions_append_const(self):
        class AppendConstTap(Tap):
            arg: List[int] = [1, 2, 3]

            def configure(self):
                self.add_argument("--arg", action="append_const", const=7)

        args = AppendConstTap().parse_args([])
        self.assertEqual(args.arg, [1, 2, 3])

        # Each occurrence of the flag appends the const once
        args = AppendConstTap().parse_args("--arg --arg".split())
        self.assertEqual(args.arg, [1, 2, 3, 7, 7])

    def test_actions_count(self):
        class CountTap(Tap):
            arg = 7

            def configure(self):
                self.add_argument("--arg", "-a", action="count")

        args = CountTap().parse_args([])
        self.assertEqual(args.arg, 7)

        # Four occurrences of the flag ("-aaa" and "--arg") are counted on top of the default of 7
        args = CountTap().parse_args("-aaa --arg".split())
        self.assertEqual(args.arg, 11)

    def test_actions_int_count(self):
        class CountIntTap(Tap):
            arg: int = 7

            def configure(self):
                self.add_argument("--arg", "-a", action="count")

        args = CountIntTap().parse_args([])
        self.assertEqual(args.arg, 7)

        # Same as test_actions_count, but with an int annotation on the class variable
        args = CountIntTap().parse_args("-aaa --arg".split())
        self.assertEqual(args.arg, 11)

    def test_actions_version(self):
        class VersionTap(Tap):
            def configure(self):
                self.add_argument("--version", action="version", version="2.0")

        # Ensure that nothing breaks without version flag
        VersionTap().parse_args([])

        # TODO: With version flag testing fails, but manual tests work
        # tried redirecting stderr using unittest.mock.patch
        # VersionTap().parse_args(['--version'])

    def test_actions_extend(self):
        class ExtendTap(Tap):
            arg = [1, 2]

            def configure(self):
                self.add_argument("--arg", nargs="+", action="extend")

        args = ExtendTap().parse_args([])
        self.assertEqual(args.arg, [1, 2])

        # Values from every occurrence of the flag are added after the default values, as strings
        args = ExtendTap().parse_args("--arg a b --arg a --arg c d".split())
        self.assertEqual(args.arg, [1, 2] + "a b a c d".split())

    def test_actions_extend_list(self):
        class ExtendListTap(Tap):
            arg: List = ["hi"]

            def configure(self):
                self.add_argument("--arg", action="extend")

        # No nargs is passed to add_argument here, yet several values per flag are accepted
        args = ExtendListTap().parse_args("--arg yo yo --arg yoyo --arg yo yo".split())
        self.assertEqual(args.arg, "hi yo yo yoyo yo yo".split())

    def test_actions_extend_list_int(self):
        class ExtendListIntTap(Tap):
            arg: List[int] = [0]

            def configure(self):
                self.add_argument("--arg", action="extend")

        # Added values are expected as ints, not as the strings given on the command line
        args = ExtendListIntTap().parse_args("--arg 1 2 --arg 3 --arg 4 5".split())
        self.assertEqual(args.arg, [0, 1, 2, 3, 4, 5])

    def test_positional_default(self):
        class PositionalDefault(Tap):
            arg: str

            def configure(self):
                self.add_argument("arg")

        # The help text must list arg under positional arguments as a required str
        help_regex = r".*positional arguments:\n.*arg\s*\(str, required\).*"
        help_text = PositionalDefault().format_help()
        self.assertRegex(help_text, help_regex)
216 |
217 |
# Run the tests in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()
220 |
--------------------------------------------------------------------------------
/tests/test_load_config_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tempfile import TemporaryDirectory
4 | import unittest
5 | from unittest import TestCase
6 |
7 | from tap import Tap
8 |
9 |
10 | # Suppress prints from SystemExit
class DevNull:
    """Minimal stand-in for a text stream that discards everything written to it."""

    def write(self, msg):
        # Discard the message
        pass

    def flush(self):
        # Nothing is buffered; defined so that callers that flush the stream do not raise AttributeError
        pass
14 |
15 |
# Installed at import time for the whole process; this module never restores the original stderr
sys.stderr = DevNull()
17 |
18 |
class LoadConfigFilesTests(TestCase):
    """Checks how Tap combines arguments from ``config_files`` with arguments passed to parse_args."""

    def test_file_does_not_exist(self) -> None:
        class EmptyTap(Tap):
            pass

        # A config file path that does not exist must raise FileNotFoundError
        with self.assertRaises(FileNotFoundError):
            EmptyTap(config_files=["nope"]).parse_args([])

    def test_single_config(self) -> None:
        class SimpleTap(Tap):
            a: int
            b: str = "b"

        with TemporaryDirectory() as temp_dir:
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write("--a 1")

            args = SimpleTap(config_files=[fname]).parse_args([])

        # a comes from the config file, b keeps its default
        self.assertEqual(args.a, 1)
        self.assertEqual(args.b, "b")

    def test_single_config_overwriting(self) -> None:
        class SimpleOverwritingTap(Tap):
            a: int
            b: str = "b"

        with TemporaryDirectory() as temp_dir:
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write("--a 1 --b two")

            args = SimpleOverwritingTap(config_files=[fname]).parse_args("--a 2".split())

        # The value passed to parse_args wins over the value from the config file
        self.assertEqual(args.a, 2)
        self.assertEqual(args.b, "two")

    def test_single_config_known_only(self) -> None:
        class KnownOnlyTap(Tap):
            a: int
            b: str = "b"

        with TemporaryDirectory() as temp_dir:
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write("--a 1 --c seeNothing")

            args = KnownOnlyTap(config_files=[fname]).parse_args([], known_only=True)

        # With known_only=True, the unknown argument from the config file ends up in extra_args
        self.assertEqual(args.a, 1)
        self.assertEqual(args.b, "b")
        self.assertEqual(args.extra_args, ["--c", "seeNothing"])

    def test_single_config_required_still_required(self) -> None:
        class KnownOnlyTap(Tap):
            a: int
            b: str = "b"

        # The config file only provides b, so parsing must exit because a is missing
        with TemporaryDirectory() as temp_dir, self.assertRaises(SystemExit):
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write("--b fore")

            KnownOnlyTap(config_files=[fname]).parse_args([])

    def test_multiple_configs(self) -> None:
        class MultipleTap(Tap):
            a: int
            b: str = "b"

        with TemporaryDirectory() as temp_dir:
            fname1, fname2 = os.path.join(temp_dir, "config1.txt"), os.path.join(temp_dir, "config2.txt")

            with open(fname1, "w") as f1, open(fname2, "w") as f2:
                f1.write("--b two")
                f2.write("--a 1")

            args = MultipleTap(config_files=[fname1, fname2]).parse_args([])

        # Arguments from both config files are applied
        self.assertEqual(args.a, 1)
        self.assertEqual(args.b, "two")

    def test_multiple_configs_overwriting(self) -> None:
        class MultipleOverwritingTap(Tap):
            a: int
            b: str = "b"
            c: str = "c"

        with TemporaryDirectory() as temp_dir:
            fname1, fname2 = os.path.join(temp_dir, "config1.txt"), os.path.join(temp_dir, "config2.txt")

            with open(fname1, "w") as f1, open(fname2, "w") as f2:
                f1.write("--a 1 --b two")
                f2.write("--a 2 --c see")

            args = MultipleOverwritingTap(config_files=[fname1, fname2]).parse_args("--b four".split())

        # The later config file wins for a, and the value passed to parse_args wins for b
        self.assertEqual(args.a, 2)
        self.assertEqual(args.b, "four")
        self.assertEqual(args.c, "see")

    def test_junk_config(self) -> None:
        class JunkConfigTap(Tap):
            a: int
            b: str = "b"

        # A config file whose content is not a list of arguments must make parsing exit
        with TemporaryDirectory() as temp_dir, self.assertRaises(SystemExit):
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write("is not a file that can reasonably be parsed")

            JunkConfigTap(config_files=[fname]).parse_args([])

    def test_shlex_config(self) -> None:
        class ShlexConfigTap(Tap):
            a: int
            b: str

        with TemporaryDirectory() as temp_dir:
            fname = os.path.join(temp_dir, "config.txt")

            with open(fname, "w") as f:
                f.write('--a 21 # Important arg value\n\n# Multi-word quoted string\n--b "two three four"')

            args = ShlexConfigTap(config_files=[fname]).parse_args([])

        # "#" comments in the config file are ignored and the quoted string stays a single value
        self.assertEqual(args.a, 21)
        self.assertEqual(args.b, "two three four")
154 |
# Run the tests in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()
157 |
--------------------------------------------------------------------------------
/tests/test_subparser.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentError
2 | import sys
3 | from typing import Literal, Union
4 | import unittest
5 | from unittest import TestCase
6 |
7 | from tap import Tap
8 |
9 |
10 | # Suppress prints from SystemExit
class DevNull:
    """Minimal stand-in for a text stream that discards everything written to it."""

    def write(self, msg):
        # Discard the message
        pass

    def flush(self):
        # Nothing is buffered; defined so that callers that flush the stream do not raise AttributeError
        pass
14 |
15 |
# Installed at import time for the whole process
# NOTE(review): no restore of the original stderr is visible in this part of the file — confirm
sys.stderr = DevNull()
17 |
18 |
class TestSubparser(TestCase):
    """Behavioural checks for sub-commands registered with ``add_subparsers`` / ``add_subparser``."""

    def test_subparser_documentation_example(self):
        """A top-level flag plus two sub-commands, parsed with accepted and rejected argument orders."""

        class SubparserA(Tap):
            bar: int  # bar help

        class SubparserB(Tap):
            baz: Literal["X", "Y", "Z"]  # baz help

        class Args(Tap):
            foo: bool = False  # foo help

            def configure(self):
                self.add_subparsers(help="sub-command help")
                self.add_subparser("a", SubparserA, help="a help")
                self.add_subparser("b", SubparserB, help="b help")

        parsed = Args().parse_args([])
        self.assertFalse(parsed.foo)
        self.assertFalse(hasattr(parsed, "bar"))
        self.assertFalse(hasattr(parsed, "baz"))

        parsed = Args().parse_args(["--foo"])
        self.assertTrue(parsed.foo)
        self.assertFalse(hasattr(parsed, "bar"))
        self.assertFalse(hasattr(parsed, "baz"))

        parsed = Args().parse_args(["a", "--bar", "1"])
        self.assertFalse(parsed.foo)
        self.assertEqual(parsed.bar, 1)
        self.assertFalse(hasattr(parsed, "baz"))

        parsed = Args().parse_args(["--foo", "b", "--baz", "X"])
        self.assertTrue(parsed.foo)
        self.assertFalse(hasattr(parsed, "bar"))
        self.assertEqual(parsed.baz, "X")

        # Each of these argument orders must make the parser exit
        rejected_argvs = (
            ["--baz", "X", "--foo", "b"],
            ["b", "--baz", "X", "--foo"],
            ["--foo", "a", "--bar", "1", "b", "--baz", "X"],
        )
        for argv in rejected_argvs:
            with self.assertRaises(SystemExit):
                Args().parse_args(argv)

    def test_name_collision(self):
        """A sub-command called ``a`` may itself own an argument called ``a``."""

        class SubparserA(Tap):
            a: int

        class Args(Tap):
            foo: bool = False

            def configure(self):
                self.add_subparsers(help="sub-command help")
                self.add_subparser("a", SubparserA, help="a help")

        parsed = Args().parse_args(["a", "--a", "1"])
        self.assertFalse(parsed.foo)
        self.assertEqual(parsed.a, 1)

    def test_name_overriding(self):
        """When parent and sub-command both define ``foo``, the sub-command's value is the one reported."""

        class SubparserA(Tap):
            foo: int

        class Args(Tap):
            foo: bool = False

            def configure(self):
                self.add_subparsers(help="sub-command help")
                self.add_subparser("a", SubparserA)

        parsed = Args().parse_args(["--foo"])
        self.assertTrue(parsed.foo)

        parsed = Args().parse_args(["a", "--foo", "2"])
        self.assertEqual(parsed.foo, 2)

        parsed = Args().parse_args(["--foo", "a", "--foo", "2"])
        self.assertEqual(parsed.foo, 2)

    def test_add_subparser_twice(self):
        """Registering the same sub-command name twice; the expected outcome depends on the Python version."""

        class SubparserA(Tap):
            bar: int

        class SubparserB(Tap):
            baz: int

        class Args(Tap):
            foo: bool = False

            def configure(self):
                self.add_subparser("a", SubparserB)
                self.add_subparser("a", SubparserA)

        if sys.version_info >= (3, 11):
            with self.assertRaises(ArgumentError):
                Args().parse_args([])
        else:
            # On older interpreters the second registration is the one in effect
            parsed = Args().parse_args(["a", "--bar", "2"])
            self.assertFalse(parsed.foo)
            self.assertEqual(parsed.bar, 2)
            self.assertFalse(hasattr(parsed, "baz"))

            with self.assertRaises(SystemExit):
                Args().parse_args(["a", "--baz", "2"])

    def test_add_subparsers_twice(self):
        """Calling ``add_subparsers`` twice fails; the exception type depends on the Python version."""

        class SubparserA(Tap):
            a: int

        class Args(Tap):
            foo: bool = False

            def configure(self):
                self.add_subparser("a", SubparserA)
                self.add_subparsers(help="sub-command1 help")
                self.add_subparsers(help="sub-command2 help")

        if sys.version_info >= (3, 12, 5):
            expected_failure = ArgumentError
        else:
            expected_failure = SystemExit

        with self.assertRaises(expected_failure):
            Args().parse_args([])

    def test_add_subparsers_with_add_argument(self):
        """``add_argument`` aliases and ``add_subparser`` can be interleaved inside ``configure``."""

        class SubparserA(Tap):
            for_sure: bool = False

        class Args(Tap):
            foo: bool = False
            bar: int = 1

            def configure(self):
                self.add_argument("--bar", "-ib")
                self.add_subparser("is_terrible", SubparserA)
                self.add_argument("--foo", "-m")

        parsed = Args().parse_args(["-ib", "0", "-m", "is_terrible", "--for_sure"])
        self.assertTrue(parsed.foo)
        self.assertEqual(parsed.bar, 0)
        self.assertTrue(parsed.for_sure)

    def test_add_subsubparsers(self):
        """A sub-command may register a sub-command of its own."""

        class SubSubparserB(Tap):
            baz: bool = False

        class SubparserA(Tap):
            biz: bool = False

            def configure(self):
                self.add_subparser("b", SubSubparserB)

        class SubparserB(Tap):
            blaz: bool = False

        class Args(Tap):
            foo: bool = False

            def configure(self):
                self.add_subparser("a", SubparserA)
                self.add_subparser("b", SubparserB)

        parsed = Args().parse_args(["b", "--blaz"])
        self.assertFalse(parsed.foo)
        self.assertFalse(hasattr(parsed, "baz"))
        self.assertFalse(hasattr(parsed, "biz"))
        self.assertTrue(parsed.blaz)

        parsed = Args().parse_args(["a", "--biz"])
        self.assertFalse(parsed.foo)
        self.assertTrue(parsed.biz)
        self.assertFalse(hasattr(parsed, "baz"))
        self.assertFalse(hasattr(parsed, "blaz"))

        parsed = Args().parse_args(["a", "--biz", "b", "--baz"])
        self.assertFalse(parsed.foo)
        self.assertTrue(parsed.biz)
        self.assertFalse(hasattr(parsed, "blaz"))
        self.assertTrue(parsed.baz)

        with self.assertRaises(SystemExit):
            Args().parse_args(["b", "a"])

    def test_subparser_underscores_to_dashes(self):
        """``underscores_to_dashes`` given to the parent also decides the spelling accepted by the sub-command."""

        class AddProposal(Tap):
            proposal_id: int

        class Arguments(Tap):
            def configure(self) -> None:
                self.add_subparsers(dest="subparser_name")

                self.add_subparser(
                    "add-proposal",
                    AddProposal,
                    help="Add a new proposal",
                )

        underscore_argv = ["add-proposal", "--proposal_id", "1"]
        dash_argv = ["add-proposal", "--proposal-id", "1"]

        args_underscores: Union[Arguments, AddProposal] = Arguments(underscores_to_dashes=False).parse_args(
            list(underscore_argv)
        )
        self.assertEqual(args_underscores.proposal_id, 1)

        args_dashes: Union[Arguments, AddProposal] = Arguments(underscores_to_dashes=True).parse_args(
            list(dash_argv)
        )
        self.assertEqual(args_dashes.proposal_id, 1)

        # The spelling that does not match the setting is rejected
        with self.assertRaises(SystemExit):
            Arguments(underscores_to_dashes=False).parse_args(list(dash_argv))

        with self.assertRaises(SystemExit):
            Arguments(underscores_to_dashes=True).parse_args(list(underscore_argv))
230 |
231 |
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
234 |
--------------------------------------------------------------------------------
/tests/test_to_tap_class.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests `tap.to_tap_class`.
3 | """
4 |
5 | from contextlib import redirect_stdout, redirect_stderr
6 | import dataclasses
7 | import io
8 | import re
9 | import sys
10 | from typing import Any, Callable, List, Literal, Optional, Type, Union
11 |
12 | from packaging.version import Version
13 | import pytest
14 |
15 | from tap import to_tap_class, Tap
16 | from tap.utils import type_to_str
17 |
18 |
# Tri-state flag: None when pydantic cannot be imported, otherwise whether its version is below 2.0.0.
_IS_PYDANTIC_V1 = None
try:
    import pydantic
except ModuleNotFoundError:
    pass
else:
    _IS_PYDANTIC_V1 = Version(pydantic.__version__) < Version("2.0.0")
25 |
26 |
# The expected help text has to follow argparse's formatting, which changed from 3.9 -> 3.10 -> 3.13,
# so the version-dependent fragments are chosen here.
if sys.version_info >= (3, 10):
    _OPTIONS_TITLE = "options"
else:
    _OPTIONS_TITLE = "optional arguments"

_ARG_LIST_DOTS = "..."

if sys.version_info >= (3, 13):
    _ARG_WITH_ALIAS = "-arg, --argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME"
else:
    _ARG_WITH_ALIAS = (
        "-arg ARGUMENT_WITH_REALLY_LONG_NAME, "
        "--argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME"
    )
35 |
36 |
# Reference data model: one required int, a bool defaulting to True and an optional list of strings.
# Field descriptions are carried in the dataclass field metadata.
@dataclasses.dataclass
class _Args:
    """
    These are the arguments which every type of class or function must contain.
    """

    arg_int: int = dataclasses.field(metadata={"description": "some integer"})
    arg_bool: bool = True
    arg_list: Optional[List[str]] = dataclasses.field(
        default=None,
        metadata={"description": "some list of strings"},
    )
46 |
47 |
def _monkeypatch_eq(cls):
    """
    Class decorator which replaces `cls.__eq__` so that an instance is compared through an `_Args` built
    from its `arg_int`, `arg_bool` and `arg_list` attributes. The same class object is returned.
    """

    def _equality(self, other: _Args) -> bool:
        as_dataclass = _Args(self.arg_int, arg_bool=self.arg_bool, arg_list=self.arg_list)
        return as_dataclass == other

    setattr(cls, "__eq__", _equality)
    return cls
58 |
59 |
60 | # Define a few different classes or functions which all take the same arguments (same by name, annotation, and default
61 | # if not required)
62 |
63 |
def function(arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None) -> _Args:
    """
    :param arg_int: some integer
    :param arg_list: some list of strings
    """
    # NOTE(review): the docstring above is presumably where the help descriptions for this variant come
    # from, so it is left exactly as it was.
    bundled = _Args(arg_int, arg_bool=arg_bool, arg_list=arg_list)
    return bundled
70 |
71 |
# Plain-class flavour of the data model; equality against `_Args` is supplied by the decorator.
@_monkeypatch_eq
class Class:
    def __init__(self, arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None):
        """
        :param arg_int: some integer
        :param arg_list: some list of strings
        """
        self.arg_int, self.arg_bool, self.arg_list = arg_int, arg_bool, arg_list
82 |
83 |
# The builtin-dataclass flavour of the data model is `_Args` itself, bound to a second name.
DataclassBuiltin = _Args
85 |
86 |
# Define the pydantic flavours of the data model (`DataclassPydantic` and `Model`). Nothing is defined
# when pydantic is unavailable; otherwise the declaration style depends on the detected major version.
if _IS_PYDANTIC_V1 is None:
    pass  # will raise NameError if attempting to use DataclassPydantic or Model later
elif _IS_PYDANTIC_V1:
    # For Pydantic v1 data models, we rely on the docstring to get descriptions

    @_monkeypatch_eq
    @pydantic.dataclasses.dataclass
    class DataclassPydantic:
        """
        Dataclass (pydantic v1)

        :param arg_int: some integer
        :param arg_list: some list of strings
        """

        arg_int: int
        arg_bool: bool = True
        arg_list: Optional[List[str]] = None

    @_monkeypatch_eq
    class Model(pydantic.BaseModel):
        """
        Pydantic model (pydantic v1)

        :param arg_int: some integer
        :param arg_list: some list of strings
        """

        arg_int: int
        arg_bool: bool = True
        arg_list: Optional[List[str]] = None

else:
    # For pydantic v2 data models, we check the docstring and Field for the description

    @_monkeypatch_eq
    @pydantic.dataclasses.dataclass
    class DataclassPydantic:
        """
        Dataclass (pydantic)

        :param arg_list: some list of strings
        """

        # Mixing field types should be ok
        # Here arg_int is described through Field and arg_list through the docstring
        arg_int: int = pydantic.dataclasses.Field(description="some integer")
        arg_bool: bool = dataclasses.field(default=True)
        arg_list: Optional[List[str]] = pydantic.Field(default=None)

    @_monkeypatch_eq
    class Model(pydantic.BaseModel):
        """
        Pydantic model

        :param arg_int: some integer
        """

        # Mixing field types should be ok
        # Here arg_int is described through the docstring and arg_list through Field
        arg_int: int
        arg_bool: bool = dataclasses.field(default=True)
        arg_list: Optional[List[str]] = pydantic.dataclasses.Field(default=None, description="some list of strings")
149 |
@pytest.fixture(
    scope="module",
    params=[
        function,
        Class,
        DataclassBuiltin,
        # to_tap_class also works on instances of data models. It ignores the attribute values
        DataclassBuiltin(1, arg_bool=False, arg_list=["these", "values", "don't", "matter"]),
        # The pydantic flavours only exist when pydantic could be imported
        # NOTE: instances of DataclassPydantic and Model can be tested for pydantic v2 but not v1
        *(() if _IS_PYDANTIC_V1 is None else (DataclassPydantic, Model)),
    ],
)
def class_or_function_(request: pytest.FixtureRequest):
    """
    Module-scoped fixture yielding, one at a time, every object the tests convert with `to_tap_class`.
    """
    current = request.param
    return current
168 |
169 |
170 | # Define some functions which take a class or function and calls `tap.to_tap_class` on it to create a `tap.Tap`
171 | # subclass (class, not instance)
172 |
173 |
def subclasser_simple(class_or_function: Any) -> Type[Tap]:
    """
    Converts `class_or_function` with `to_tap_class` and returns the result untouched.
    """
    tap_class = to_tap_class(class_or_function)
    return tap_class
179 |
180 |
def subclasser_complex(class_or_function):
    """
    Builds a subclass of `to_tap_class(class_or_function)` which adds an argument, aliases it through
    `configure` and post-processes the parsed values in `process_args`; mirrors a script which extends
    an existing data model.
    """

    # The name of this converter shows up in argparse's "invalid to_number value" error message
    def to_number(string: str) -> Union[float, int]:
        if "." in string:
            return float(string)
        return int(string)

    class TapSubclass(to_tap_class(class_or_function)):
        # You can supply additional arguments here
        argument_with_really_long_name: Union[float, int] = 3
        "This argument has a long name and will be aliased with a short one"

        def configure(self) -> None:
            # You can still add special argument behavior
            self.add_argument("-arg", "--argument_with_really_long_name", type=to_number)

        def process_args(self) -> None:
            # You can still validate and modify arguments
            if self.argument_with_really_long_name > 4:
                raise ValueError("argument_with_really_long_name cannot be > 4")

            # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry
            needs_marker = self.arg_bool and self.arg_list is not None
            if needs_marker:
                self.arg_list.append("processed")

    return TapSubclass
209 |
210 |
def subclasser_subparser(class_or_function):
    """
    Builds a subclass of `to_tap_class(class_or_function)` which adds a `foo` flag and registers two
    sub-commands, `a` and `b`, in `configure`.

    NOTE(review): the trailing comments on the fields below match the help text asserted by the
    help-message tests, so they presumably act as data and should not be reworded.
    """

    class SubparserA(Tap):
        bar: int  # bar help

    class SubparserB(Tap):
        baz: Literal["X", "Y", "Z"]  # baz help

    class TapSubclass(to_tap_class(class_or_function)):
        foo: bool = False  # foo help

        def configure(self):
            self.add_subparsers(help="sub-command help")
            self.add_subparser("a", SubparserA, help="a help", description="Description (a)")
            self.add_subparser("b", SubparserB, help="b help")

    return TapSubclass
227 |
228 |
229 | # Test that the subclasser parses the args correctly or raises the correct error.
230 | # The subclassers are tested separately b/c the parametrization of args_string_and_arg_to_expected_value depends on the
231 | # subclasser.
232 | # First, some helper functions.
233 |
234 |
def _test_raises_system_exit(tap: Tap, args_string: str) -> str:
    """
    Parses `args_string` with `tap`, requires that doing so raises `SystemExit`, and returns the text
    which was printed while parsing.

    :param tap: parser instance whose `parse_args` is called
    :param args_string: command line, split on whitespace before parsing
    :return: captured stdout when the command line asks for help, captured stderr otherwise
    """
    # Look at whole tokens: a substring test misses a leading "-h"/"--help" followed by more arguments
    # and wrongly matches any command line which merely ends in the characters "-h"
    is_help = any(token in ("-h", "--help") for token in args_string.split())
    captured = io.StringIO()
    redirect = redirect_stdout if is_help else redirect_stderr
    with redirect(captured), pytest.raises(SystemExit):
        tap.parse_args(args_string.split())

    return captured.getvalue()
248 |
249 |
def _test_subclasser(
    subclasser: Callable[[Any], Type[Tap]],
    class_or_function: Any,
    args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]],
    test_call: bool = True,
):
    """
    Tests that the `subclasser` converts `class_or_function` to a `Tap` class which parses the argument string
    correctly.

    The second element of `args_string_and_arg_to_expected_value` selects what is checked:

    - a `SystemExit` instance: parsing must exit, and the instance's message is searched for (as a regular
      expression) in the captured output
    - any other `BaseException` instance: parsing must raise that exception type, matching its message if it
      has one
    - a dict: parsing must succeed and `as_dict()` of the result must equal it

    Setting `test_call=True` additionally tests that calling the `class_or_function` on the parsed arguments works.
    """
    args_string, arg_to_expected_value = args_string_and_arg_to_expected_value
    TapSubclass = subclasser(class_or_function)
    assert issubclass(TapSubclass, Tap)
    tap = TapSubclass(description="Script description")

    if isinstance(arg_to_expected_value, SystemExit):
        # The helper hands back stdout for help requests and stderr otherwise
        output = _test_raises_system_exit(tap, args_string)
        assert re.search(str(arg_to_expected_value), output)
    elif isinstance(arg_to_expected_value, BaseException):
        expected_exception = arg_to_expected_value.__class__
        # An empty message means "any message": pytest skips matching when match is None
        expected_error_message = str(arg_to_expected_value) or None
        with pytest.raises(expected_exception=expected_exception, match=expected_error_message):
            tap.parse_args(args_string.split())
    else:
        # args_string is a valid argument combo
        # Test that parsing works correctly
        args = tap.parse_args(args_string.split())
        assert arg_to_expected_value == args.as_dict()
        if test_call and callable(class_or_function):
            result = class_or_function(**args.as_dict())
            assert result == _Args(**arg_to_expected_value)
283 |
284 |
def _test_subclasser_message(
    subclasser: Callable[[Any], Type[Tap]],
    class_or_function: Any,
    message_expected: str,
    description: str = "Script description",
    args_string: str = "-h",
):
    """
    Tests that the message printed by::

        subclasser(class_or_function)(description=description).parse_args(args_string.split())

    equals `message_expected` once every run of whitespace (spaces, tabs, newlines) in both texts has been
    collapsed to a single space and the ends have been stripped.
    """
    whitespace_run = r"\s+"

    def replace_whitespace(string: str) -> str:
        collapsed = re.sub(whitespace_run, " ", string)
        return collapsed.strip()

    tap = subclasser(class_or_function)(description=description)
    message = _test_raises_system_exit(tap, args_string)
    # Standardize to ignore trivial differences due to terminal settings
    actual = replace_whitespace(message)
    expected = replace_whitespace(message_expected)
    assert actual == expected
308 |
309 |
310 | # Test subclasser_simple
311 |
312 |
@pytest.mark.parametrize(
    "args_string_and_arg_to_expected_value",
    [
        # Valid argument combos: the 2nd elt is the expected `as_dict()` of the parsed arguments
        ("--arg_int 1 --arg_list x y z", {"arg_int": 1, "arg_bool": True, "arg_list": ["x", "y", "z"]}),
        ("--arg_int 1 --arg_bool", {"arg_int": 1, "arg_bool": False, "arg_list": None}),
        # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
        ("--arg_list x y z --arg_bool", SystemExit("error: the following arguments are required: --arg_int")),
        (
            "--arg_int not_an_int --arg_list x y z --arg_bool",
            SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"),
        ),
    ],
)
def test_subclasser_simple(
    class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]]
):
    """
    Runs every parametrized command line through the plain `to_tap_class` conversion.
    """
    _test_subclasser(
        subclasser=subclasser_simple,
        class_or_function=class_or_function_,
        args_string_and_arg_to_expected_value=args_string_and_arg_to_expected_value,
    )
339 |
340 |
def test_subclasser_simple_help_message(class_or_function_: Any):
    """
    Checks the `-h` output of the plain `to_tap_class` conversion against the expected help text.

    The expected text is assembled from the version-dependent fragments `_OPTIONS_TITLE` and
    `_ARG_LIST_DOTS`; `_test_subclasser_message` compares it ignoring whitespace differences.
    """
    description = "Script description"
    help_message_expected = f"""
usage: pytest --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h]

{description}

{_OPTIONS_TITLE}:
  --arg_int ARG_INT     (int, required) some integer
  --arg_bool            (bool, default=True)
  --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
                        ({type_to_str(Optional[List[str]])}, default=None) some list of strings
  -h, --help            show this help message and exit
"""
    _test_subclasser_message(subclasser_simple, class_or_function_, help_message_expected, description=description)
356 |
357 |
358 | # Test subclasser_complex
359 |
360 |
@pytest.mark.parametrize(
    "args_string_and_arg_to_expected_value",
    [
        # Valid argument combos: the 2nd elt is the expected `as_dict()` of the parsed arguments
        (
            "--arg_int 1 --arg_list x y z",
            {
                "arg_int": 1,
                "arg_bool": True,
                "arg_list": ["x", "y", "z", "processed"],
                "argument_with_really_long_name": 3,
            },
        ),
        (
            "--arg_int 1 --arg_list x y z -arg 2",
            {
                "arg_int": 1,
                "arg_bool": True,
                "arg_list": ["x", "y", "z", "processed"],
                "argument_with_really_long_name": 2,
            },
        ),
        (
            "--arg_int 1 --arg_bool --argument_with_really_long_name 2.3",
            {
                "arg_int": 1,
                "arg_bool": False,
                "arg_list": None,
                "argument_with_really_long_name": 2.3,
            },
        ),
        # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
        ("--arg_list x y z --arg_bool", SystemExit("error: the following arguments are required: --arg_int")),
        (
            "--arg_int 1 --arg_list x y z -arg not_a_float_or_int",
            SystemExit(
                "error: argument -arg/--argument_with_really_long_name: invalid to_number value: 'not_a_float_or_int'"
            ),
        ),
        (
            "--arg_int 1 --arg_list x y z -arg 5",  # Wrong value arg (aliases argument_with_really_long_name)
            ValueError("argument_with_really_long_name cannot be > 4"),
        ),
    ],
)
def test_subclasser_complex(
    class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]]
):
    """
    Runs every parametrized command line through the conversion which adds an aliased argument and
    post-processing.
    """
    # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args
    _test_subclasser(
        subclasser=subclasser_complex,
        class_or_function=class_or_function_,
        args_string_and_arg_to_expected_value=args_string_and_arg_to_expected_value,
        test_call=False,
    )
413 |
414 |
def test_subclasser_complex_help_message(class_or_function_: Any):
    """
    Checks the `-h` output of the conversion which adds the aliased `argument_with_really_long_name`.

    The expected text is assembled from the version-dependent fragments `_OPTIONS_TITLE`,
    `_ARG_LIST_DOTS` and `_ARG_WITH_ALIAS`; `_test_subclasser_message` compares it ignoring whitespace
    differences.
    """
    description = "Script description"
    help_message_expected = f"""
usage: pytest [-arg ARGUMENT_WITH_REALLY_LONG_NAME] --arg_int ARG_INT [--arg_bool]
              [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h]

{description}

{_OPTIONS_TITLE}:
  {_ARG_WITH_ALIAS}
                        (Union[float, int], default=3) This argument has a long name and will be aliased with a short
                        one
  --arg_int ARG_INT     (int, required) some integer
  --arg_bool            (bool, default=True)
  --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
                        ({type_to_str(Optional[List[str]])}, default=None) some list of strings
  -h, --help            show this help message and exit
"""
    _test_subclasser_message(subclasser_complex, class_or_function_, help_message_expected, description=description)
434 |
435 |
436 | # Test subclasser_subparser
437 |
438 |
@pytest.mark.parametrize(
    "args_string_and_arg_to_expected_value",
    [
        # Valid argument combos: the 2nd elt is the expected `as_dict()` of the parsed arguments
        ("--arg_int 1", {"arg_int": 1, "arg_bool": True, "arg_list": None, "foo": False}),
        ("--arg_int 1 a --bar 2", {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": False}),
        ("--arg_int 1 --foo a --bar 2", {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": True}),
        ("--arg_int 1 b --baz X", {"arg_int": 1, "arg_bool": True, "arg_list": None, "baz": "X", "foo": False}),
        (
            "--foo --arg_bool --arg_list x y z --arg_int 1 b --baz Y",
            {"arg_int": 1, "arg_bool": False, "arg_list": ["x", "y", "z"], "baz": "Y", "foo": True},
        ),
        # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
        ("a --bar 1", SystemExit("error: the following arguments are required: --arg_int")),
        (
            "--arg_int not_an_int a --bar 1",
            SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"),
        ),
        (
            "--arg_int 1 --baz X --foo b",
            SystemExit(r"error: argument \{a,b}: invalid choice: 'X' \(choose from '?a'?, '?b'?\)"),
        ),
        ("--arg_int 1 b --baz X --foo", SystemExit("error: unrecognized arguments: --foo")),
        (
            "--arg_int 1 --foo b --baz A",
            SystemExit(r"""error: argument --baz: Value for variable "baz" must be one of \['X', 'Y', 'Z']."""),
        ),
    ],
)
def test_subclasser_subparser(
    class_or_function_: Any, args_string_and_arg_to_expected_value: tuple[str, Union[dict[str, Any], BaseException]]
):
    """
    Runs every parametrized command line through the conversion which registers the `a` and `b`
    sub-commands.
    """
    # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args
    _test_subclasser(
        subclasser=subclasser_subparser,
        class_or_function=class_or_function_,
        args_string_and_arg_to_expected_value=args_string_and_arg_to_expected_value,
        test_call=False,
    )
492 |
493 |
@pytest.mark.parametrize(
    "args_string_and_description_and_expected_message",
    [
        (
            "-h",
            "Script description",
            f"""
usage: pytest [--foo] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h]
              {{a,b}} ...

Script description

positional arguments:
  {{a,b}}               sub-command help
    a                   a help
    b                   b help

{_OPTIONS_TITLE}:
  --foo                 (bool, default=False) foo help
  --arg_int ARG_INT     (int, required) some integer
  --arg_bool            (bool, default=True)
  --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
                        ({type_to_str(Optional[List[str]])}, default=None) some list of strings
  -h, --help            show this help message and exit
""",
        ),
        (
            "a -h",
            "Description (a)",
            f"""
usage: pytest a --bar BAR [-h]

Description (a)

{_OPTIONS_TITLE}:
  --bar BAR   (int, required) bar help
  -h, --help  show this help message and exit
""",
        ),
        (
            "b -h",
            "",  # no description
            f"""
usage: pytest b --baz {{X,Y,Z}} [-h]

{_OPTIONS_TITLE}:
  --baz {{X,Y,Z}}  (Literal['X', 'Y', 'Z'], required) baz help
  -h, --help     show this help message and exit
""",
        ),
    ],
)
def test_subclasser_subparser_help_message(
    class_or_function_: Any, args_string_and_description_and_expected_message: tuple[str, str, str]
):
    """
    Checks the help output of the top-level parser and of each sub-command.

    Each parameter is a triple `(args_string, description, expected_message)`; the comparison done by
    `_test_subclasser_message` ignores whitespace differences.
    """
    args_string, description, expected_message = args_string_and_description_and_expected_message
    _test_subclasser_message(
        subclasser_subparser, class_or_function_, expected_message, description=description, args_string=args_string
    )
553 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentTypeError
2 | import inspect
3 | import json
4 | import os
5 | import subprocess
6 | from tempfile import TemporaryDirectory
7 | from typing import Any, Callable, List, Literal, Dict, Set, Tuple, Union
8 | import unittest
9 | from unittest import TestCase
10 |
11 | from tap.utils import (
12 | get_class_column,
13 | get_class_variables,
14 | GitInfo,
15 | tokenize_source,
16 | type_to_str,
17 | get_literals,
18 | TupleTypeEnforcer,
19 | _nested_replace_type,
20 | define_python_object_encoder,
21 | UnpicklableObject,
22 | as_python_object,
23 | enforce_reproducibility,
24 | )
25 |
26 |
class GitTests(TestCase):
    """Exercises `GitInfo` against a throwaway git repository which is created afresh for every test."""

    @staticmethod
    def _create_empty_file(name: str) -> None:
        """Creates `name` if it does not exist yet, without relying on an external `touch` executable."""
        with open(name, "a"):
            pass

    def setUp(self) -> None:
        """Creates a temporary repository with an `origin` remote and one commit, and changes into it."""
        self.temp_dir = TemporaryDirectory()
        self.prev_dir = os.getcwd()
        os.chdir(self.temp_dir.name)
        subprocess.check_output(["git", "init"])
        self.url = "https://github.com/test_account/test_repo"
        subprocess.check_output(["git", "remote", "add", "origin", f"{self.url}.git"])
        self._create_empty_file("README.md")
        subprocess.check_output(["git", "add", "README.md"])
        # Supply an identity for this one command so the commit does not depend on the machine's git config
        subprocess.check_output(
            [
                "git",
                "-c",
                "user.name=Your Name",
                "-c",
                "user.email=you@example.com",
                "commit",
                "-m",
                "Initial commit",
            ]
        )
        self.git_info = GitInfo(repo_path=self.temp_dir.name)

    def tearDown(self) -> None:
        """Returns to the original working directory and deletes the temporary repository."""
        os.chdir(self.prev_dir)

        # Add permissions to temporary directory to enable cleanup in Windows
        for root, dirs, files in os.walk(self.temp_dir.name):
            for name in dirs + files:
                os.chmod(os.path.join(root, name), 0o777)

        self.temp_dir.cleanup()

    def test_has_git_true(self) -> None:
        self.assertTrue(self.git_info.has_git())

    def test_has_git_false(self) -> None:
        with TemporaryDirectory() as temp_dir_no_git:
            os.chdir(temp_dir_no_git)
            self.git_info.repo_path = temp_dir_no_git
            try:
                self.assertFalse(self.git_info.has_git())
            finally:
                # Leave the directory even when the assertion fails, so that it can still be removed
                self.git_info.repo_path = self.temp_dir.name
                os.chdir(self.temp_dir.name)

    def test_get_git_root(self) -> None:
        # Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private
        self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace("\\", "/")))

    def test_get_git_root_subdir(self) -> None:
        subdir = os.path.join(self.temp_dir.name, "subdir")
        os.makedirs(subdir)
        os.chdir(subdir)

        # Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private
        self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace("\\", "/")))

        os.chdir(self.temp_dir.name)

    def test_get_git_url_https(self) -> None:
        self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url)

    def test_get_git_url_https_hash(self) -> None:
        url = f"{self.url}/tree/"
        self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url)

    def test_get_git_url_ssh(self) -> None:
        subprocess.run(["git", "remote", "set-url", "origin", "git@github.com:test_account/test_repo.git"], check=True)
        self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url)

    def test_get_git_url_ssh_hash(self) -> None:
        subprocess.run(["git", "remote", "set-url", "origin", "git@github.com:test_account/test_repo.git"], check=True)
        url = f"{self.url}/tree/"
        self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url)

    def test_get_git_url_https_enterprise(self) -> None:
        true_url = "https://github.tap.com/test_account/test_repo"
        subprocess.run(["git", "remote", "set-url", "origin", f"{true_url}.git"], check=True)
        self.assertEqual(self.git_info.get_git_url(commit_hash=False), true_url)

    def test_get_git_url_https_hash_enterprise(self) -> None:
        true_url = "https://github.tap.com/test_account/test_repo"
        subprocess.run(["git", "remote", "set-url", "origin", f"{true_url}.git"], check=True)
        url = f"{true_url}/tree/"
        self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url)

    def test_get_git_url_ssh_enterprise(self) -> None:
        true_url = "https://github.tap.com/test_account/test_repo"
        subprocess.run(
            ["git", "remote", "set-url", "origin", "git@github.tap.com:test_account/test_repo.git"], check=True
        )
        self.assertEqual(self.git_info.get_git_url(commit_hash=False), true_url)

    def test_get_git_url_ssh_hash_enterprise(self) -> None:
        true_url = "https://github.tap.com/test_account/test_repo"
        subprocess.run(
            ["git", "remote", "set-url", "origin", "git@github.tap.com:test_account/test_repo.git"], check=True
        )
        url = f"{true_url}/tree/"
        self.assertEqual(self.git_info.get_git_url(commit_hash=True)[: len(url)], url)

    def test_get_git_url_no_remote(self) -> None:
        subprocess.run(["git", "remote", "remove", "origin"], check=True)
        self.assertEqual(self.git_info.get_git_url(), "")

    def test_get_git_version(self) -> None:
        git_version = self.git_info.get_git_version()
        self.assertIsInstance(git_version, tuple)
        for v in git_version:
            self.assertIsInstance(v, int)

    def test_has_uncommitted_changes_false(self) -> None:
        self.assertFalse(self.git_info.has_uncommitted_changes())

    def test_has_uncommited_changes_true(self) -> None:
        # A new untracked file is what makes the working tree count as changed here
        self._create_empty_file("main.py")
        self.assertTrue(self.git_info.has_uncommitted_changes())
129 |
130 |
class TypeToStrTests(TestCase):
    """Checks the textual rendering which `type_to_str` produces for builtin and `typing` types."""

    def test_type_to_str(self) -> None:
        """Each type on the left must render as the string on the right."""
        expectations = [
            (str, "str"),
            (int, "int"),
            (float, "float"),
            (bool, "bool"),
            (Any, "Any"),
            (Callable[[str], str], "Callable[[str], str]"),
            (Callable[[str, int], Tuple[float, bool]], "Callable[[str, int], Tuple[float, bool]]"),
            (List[int], "List[int]"),
            (List[str], "List[str]"),
            (List[float], "List[float]"),
            (List[bool], "List[bool]"),
            (Set[int], "Set[int]"),
            (Dict[str, int], "Dict[str, int]"),
            (Union[List[int], Dict[float, bool]], "Union[List[int], Dict[float, bool]]"),
        ]
        for annotation, rendered in expectations:
            self.assertEqual(type_to_str(annotation), rendered)
149 |
150 |
def class_decorator(cls):
    """Identity decorator: hands back the very object it was given."""
    unchanged = cls
    return unchanged
153 |
154 |
class ClassColumnTests(TestCase):
    """
    Checks `get_class_column` on the tokenized source of classes defined inside the test methods.

    Every test expects column 12; the nested classes must therefore keep their layout exactly as written.
    """

    def test_column_simple(self):
        """Class whose body is a single assignment."""

        class SimpleColumn:
            arg = 2

        tokens = tokenize_source(inspect.getsource(SimpleColumn))
        self.assertEqual(get_class_column(tokens), 12)

    def test_column_comment(self):
        """Class whose body starts with a multi-line docstring."""

        class CommentColumn:
            """hello
            there


            hi
            """

            arg = 2

        tokens = tokenize_source(inspect.getsource(CommentColumn))
        self.assertEqual(get_class_column(tokens), 12)

    def test_column_space(self):
        """Class whose body starts with a blank line."""

        class SpaceColumn:

            arg = 2

        tokens = tokenize_source(inspect.getsource(SpaceColumn))
        self.assertEqual(get_class_column(tokens), 12)

    def test_column_method(self):
        """Class whose body is a single method."""

        class FuncColumn:
            def func(self):
                pass

        tokens = tokenize_source(inspect.getsource(FuncColumn))
        self.assertEqual(get_class_column(tokens), 12)

    def test_dataclass(self):
        """Decorated class whose body is an annotated assignment."""

        @class_decorator
        class DataclassColumn:
            arg: int = 5

        tokens = tokenize_source(inspect.getsource(DataclassColumn))
        self.assertEqual(get_class_column(tokens), 12)

    def test_dataclass_method(self):
        """Decorated class whose body is a decorated method."""

        def wrapper(f):
            pass

        @class_decorator
        class DataclassColumn:
            @wrapper
            def func(self):
                pass

        tokens = tokenize_source(inspect.getsource(DataclassColumn))
        self.assertEqual(get_class_column(tokens), 12)
213 |
214 |
# Tests for get_class_variables. Every test defines a small fixture class and compares the
# result with a dict of the form {variable_name: {"comment": text}}.
# NOTE(review): the fixture classes are the test input -- the expected comment text is taken
# from their "#" comments and string literals, so their layout should not be reformatted.
215 | class ClassVariableTests(TestCase):
216 | def test_no_variables(self):
217 | class NoVariables:
218 | pass
219 |
220 | self.assertEqual(get_class_variables(NoVariables), {})
221 |
222 | def test_one_variable(self):
223 | class OneVariable:
224 | arg = 2
225 |
226 | class_variables = {"arg": {"comment": ""}}
227 | self.assertEqual(get_class_variables(OneVariable), class_variables)
228 |
229 | def test_multiple_variable(self):
230 | class MultiVariable:
231 | arg_1 = 2
232 | arg_2 = 3
233 |
234 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": ""}}
235 | self.assertEqual(get_class_variables(MultiVariable), class_variables)
236 |
# Annotated variables are expected with or without a default value.
237 | def test_typed_variables(self):
238 | class TypedVariable:
239 | arg_1: str
240 | arg_2: int = 3
241 |
242 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": ""}}
243 | self.assertEqual(get_class_variables(TypedVariable), class_variables)
244 |
# Only the string literal that follows arg_2 is expected as a comment; the class docstring
# and the "# Hello" comment above func are not expected to be attached to any variable.
245 | def test_separated_variables(self):
246 | class SeparatedVariable:
247 | """Comment"""
248 |
249 | arg_1: str
250 |
251 | # Hello
252 | def func(self):
253 | pass
254 |
255 | arg_2: int = 3
256 | """More comment"""
257 |
258 | class_variables = {"arg_1": {"comment": ""}, "arg_2": {"comment": "More comment"}}
259 | self.assertEqual(get_class_variables(SeparatedVariable), class_variables)
260 |
# A trailing "#" comment and a string literal after the same variable (arg_3) are expected
# to be joined with a single space.
261 | def test_commented_variables(self):
262 | class CommentedVariable:
263 | """Comment"""
264 |
265 | arg_1: str # Arg 1 comment
266 |
267 | # Hello
268 | def func(self):
269 | pass
270 |
271 | arg_2: int = 3 # Arg 2 comment
272 | arg_3: Dict[str, int] # noqa E203,E262 Poorly formatted comment
273 | """More comment"""
274 |
275 | class_variables = {
276 | "arg_1": {"comment": "Arg 1 comment"},
277 | "arg_2": {"comment": "Arg 2 comment"},
278 | "arg_3": {"comment": "noqa E203,E262 Poorly formatted comment More comment"},
279 | }
280 | self.assertEqual(get_class_variables(CommentedVariable), class_variables)
281 |
# The expected comment keeps the newlines of the multi-line string that follows foo.
# NOTE(review): the expected text depends on the exact leading whitespace inside that string.
282 | def test_bad_spacing_multiline(self):
283 | class TrickyMultiline:
284 | """This is really difficult
285 |
286 | so
287 | so very difficult
288 | """
289 |
290 | foo: str = "my" # Header line
291 |
292 | """ Footer
293 | T
294 | A
295 | P
296 |
297 | multi
298 | line!!
299 | """
300 |
301 | class_variables = {}
302 | comment = "Header line Footer\nT\n A\n P\n\n multi\n line!!"
303 | class_variables["foo"] = {"comment": comment}
304 | self.assertEqual(get_class_variables(TrickyMultiline), class_variables)
305 |
306 | def test_triple_quote_multiline(self):
307 | class TripleQuoteMultiline:
308 | bar: int = 0
309 | """biz baz"""
310 |
311 | hi: str
312 | """Hello there"""
313 |
314 | class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}}
315 | self.assertEqual(get_class_variables(TripleQuoteMultiline), class_variables)
316 |
# Quote characters inside the string literal are expected to be kept in the comment.
317 | def test_comments_with_quotes(self):
318 | class MultiquoteMultiline:
319 | bar: int = 0
320 | "''biz baz'"
321 |
322 | hi: str
323 | '"Hello there""'
324 |
325 | class_variables = {}
326 | class_variables["bar"] = {"comment": "''biz baz'"}
327 | class_variables["hi"] = {"comment": '"Hello there""'}
328 | self.assertEqual(get_class_variables(MultiquoteMultiline), class_variables)
329 |
# The string pieces that make up the value of bar are not expected in the comment.
330 | def test_multiline_argument(self):
331 | class MultilineArgument:
332 | bar: str = "This is a multiline argument" " that should not be included in the docstring"
333 | """biz baz"""
334 |
335 | class_variables = {"bar": {"comment": "biz baz"}}
336 | self.assertEqual(get_class_variables(MultilineArgument), class_variables)
337 |
# For the parenthesized value of barrr, only the comment after the closing parenthesis
# ("yay!") is expected; the comments inside the parentheses are not.
338 | def test_multiline_argument_with_final_hashtag_comment(self):
339 | class MultilineArgumentWithHashTagComment:
340 | bar: str = "This is a multiline argument" " that should not be included in the docstring" # biz baz
341 | barr: str = "This is a multiline argument" " that should not be included in the docstring" # bar baz
342 | barrr: str = ( # meow
343 | "This is a multiline argument" # blah
344 | " that should not be included in the docstring" # grrrr
345 | ) # yay!
346 |
347 | class_variables = {"bar": {"comment": "biz baz"}, "barr": {"comment": "bar baz"}, "barrr": {"comment": "yay!"}}
348 | self.assertEqual(get_class_variables(MultilineArgumentWithHashTagComment), class_variables)
349 |
350 | def test_single_quote_multiline(self):
351 | class SingleQuoteMultiline:
352 | bar: int = 0
353 | "biz baz"
354 |
355 | hi: str
356 | "Hello there"
357 |
358 | class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}}
359 | self.assertEqual(get_class_variables(SingleQuoteMultiline), class_variables)
360 |
# The variable and string literals inside method f are not expected in the result;
# only the class-level variable i is.
361 | def test_functions_with_docs_multiline(self):
362 | class FunctionsWithDocs:
363 | i: int = 0
364 |
365 | def f(self):
366 | """Function"""
367 | a: str = "hello" # noqa F841
368 | """with docs"""
369 |
370 | class_variables = {"i": {"comment": ""}}
371 | self.assertEqual(get_class_variables(FunctionsWithDocs), class_variables)
372 |
# The fixture class carries a decorator line above its class statement.
373 | def test_dataclass(self):
374 | @class_decorator
375 | class DataclassColumn:
376 | arg: int = 5
377 |
378 | class_variables = {"arg": {"comment": ""}}
379 | self.assertEqual(get_class_variables(DataclassColumn), class_variables)
380 |
381 |
class GetLiteralsTests(TestCase):
    """Tests for get_literals, which returns a (converter, choices) pair for a Literal type."""

    def test_get_literals_string(self) -> None:
        expected_shapes = ["square", "triangle", "circle"]
        to_shape, shapes = get_literals(Literal["square", "triangle", "circle"], "shape")
        self.assertEqual(shapes, expected_shapes)
        # The converter should hand every listed string back unchanged.
        for shape in expected_shapes:
            self.assertEqual(to_shape(shape), shape)

    def test_get_literals_primitives(self) -> None:
        expected = [True, "one", 2, 3.14]
        convert, primitives = get_literals(Literal[True, "one", 2, 3.14], "number")
        self.assertEqual(primitives, expected)
        # Converting the string form of each choice should give back the original value.
        converted = [convert(str(primitive)) for primitive in primitives]
        self.assertEqual(converted, expected)

    def test_get_literals_uniqueness(self) -> None:
        # Presumably rejected because 2 and "2" have the same string form.
        self.assertRaises(ArgumentTypeError, get_literals, Literal["two", 2, "2"], "number")

    def test_get_literals_empty(self) -> None:
        _, primitives = get_literals(Literal, "hi")
        self.assertEqual(primitives, [])
403 |
404 |
class TupleTypeEnforcerTests(TestCase):
    """Tests for TupleTypeEnforcer, which is called once per string argument."""

    @staticmethod
    def _convert_each(enforcer, values):
        """Pass ``str(value)`` for each value to ``enforcer``, in order, and collect the results."""
        return [enforcer(str(value)) for value in values]

    def test_tuple_type_enforcer_zero_types(self):
        untyped = TupleTypeEnforcer(types=[])
        self.assertRaises(IndexError, untyped, "hi")

    def test_tuple_type_enforcer_one_type_str(self):
        self.assertEqual(TupleTypeEnforcer(types=[str])("hi"), "hi")

    def test_tuple_type_enforcer_one_type_int(self):
        self.assertEqual(TupleTypeEnforcer(types=[int])("123"), 123)

    def test_tuple_type_enforcer_one_type_float(self):
        self.assertEqual(TupleTypeEnforcer(types=[float])("3.14159"), 3.14159)

    def test_tuple_type_enforcer_one_type_bool(self):
        # Full words in either case, truncated prefixes, and 1/0 are all accepted spellings.
        cases = [
            ("True", True),
            ("true", True),
            ("False", False),
            ("false", False),
            ("tRu", True),
            ("faL", False),
            ("1", True),
            ("0", False),
        ]
        for text, expected in cases:
            # A new single-type enforcer is built for every string.
            single_bool = TupleTypeEnforcer(types=[bool])
            self.assertEqual(single_bool(text), expected)

    def test_tuple_type_enforcer_multi_types_same(self):
        words = ["hi", "bye"]
        str_enforcer = TupleTypeEnforcer(types=[str, str])
        self.assertEqual([str_enforcer(word) for word in words], words)

        ints = [123, 456, -789]
        int_enforcer = TupleTypeEnforcer(types=[int, int, int])
        self.assertEqual(self._convert_each(int_enforcer, ints), ints)

        floats = [1.23, 4.56, -7.89, 3.14159]
        float_enforcer = TupleTypeEnforcer(types=[float, float, float, float])
        self.assertEqual(self._convert_each(float_enforcer, floats), floats)

        bool_texts = ["True", "False", "1", "0", "tru"]
        expected_bools = [True, False, True, False, True]
        bool_enforcer = TupleTypeEnforcer(types=[bool, bool, bool, bool, bool])
        self.assertEqual(self._convert_each(bool_enforcer, bool_texts), expected_bools)

    def test_tuple_type_enforcer_multi_types_different(self):
        mixed_enforcer = TupleTypeEnforcer(types=[str, int, float, bool])
        raw_values = ["hello", 77, 0.2, "tru"]
        expected = ["hello", 77, 0.2, True]
        self.assertEqual(self._convert_each(mixed_enforcer, raw_values), expected)

    def test_tuple_type_enforcer_infinite(self):
        # With loop=True one declared type is applied to more arguments than there are types.
        looping = TupleTypeEnforcer(types=[int], loop=True)
        numbers = [1, 2, -5, 20]
        self.assertEqual(self._convert_each(looping, numbers), numbers)
482 |
483 |
class NestedReplaceTypeTests(TestCase):
    """Tests for _nested_replace_type on flat and nested containers."""

    def test_nested_replace_type_notype(self):
        # The input holds no bool values, so replacing bool with int should leave it equal.
        original = ["123", 4, 5, ("hello", 4.4)]
        result = _nested_replace_type(original, bool, int)
        self.assertEqual(original, result)

    def test_nested_replace_type_unnested(self):
        original = ["123", 4, 5, ("hello", 4.4), True, False, "hi there"]
        expected = ["123", 4, 5, ["hello", 4.4], True, False, "hi there"]
        result = _nested_replace_type(original, tuple, list)
        self.assertNotEqual(original, result)
        self.assertEqual(expected, result)

    def test_nested_replace_type_nested(self):
        # Tuples appear inside lists, inside other tuples, and inside dict values.
        original = [
            "123",
            [4, (1, 2, (3, 4))],
            5,
            ("hello", (4,), 4.4),
            {"1": [2, 3, [{"2": (3, 10)}, " hi "]]},
        ]
        expected = [
            "123",
            [4, [1, 2, [3, 4]]],
            5,
            ["hello", [4], 4.4],
            {"1": [2, 3, [{"2": [3, 10]}, " hi "]]},
        ]
        result = _nested_replace_type(original, tuple, list)
        self.assertNotEqual(original, result)
        self.assertEqual(expected, result)
503 |
504 |
class Person:
    """Small custom object used as a non-JSON value in the encoder round-trip tests.

    Two ``Person`` objects compare equal when their names are equal; a ``Person`` never
    compares equal to an object of another type.
    """

    def __init__(self, name: str) -> None:
        """Store ``name`` on the instance."""
        self.name = name

    def __eq__(self, other: Any) -> bool:
        """Return True only when ``other`` is a ``Person`` with the same name."""
        return isinstance(other, Person) and self.name == other.name

    def __repr__(self) -> str:
        """Return a readable form such as ``Person(name='tappy')``."""
        # Without this, a failing assertEqual prints only the default <... object at 0x...> text.
        return f"{type(self).__name__}(name={self.name!r})"
511 |
512 |
class PythonObjectEncoderTests(TestCase):
    """Round-trip tests: json.dumps with the defined encoder, then json.loads with as_python_object."""

    @staticmethod
    def _round_trip(obj, encoder_cls):
        """Serialize ``obj`` with ``encoder_cls`` and decode the resulting JSON text again."""
        serialized = json.dumps(obj, indent=4, sort_keys=True, cls=encoder_cls)
        return json.loads(serialized, object_hook=as_python_object)

    def _assert_round_trips(self, obj):
        """Check that ``obj`` compares equal to itself after a round trip with the default encoder."""
        self.assertEqual(self._round_trip(obj, define_python_object_encoder()), obj)

    def test_python_object_encoder_simple_types(self):
        self._assert_round_trips([1, 2, "hi", "bye", 7.3, [1, 2, "blarg"], True, False, None])

    def test_python_object_encoder_tuple(self):
        self._assert_round_trips(
            [1, 2, "hi", "bye", 7.3, (1, 2, "blarg"), [("hi", "bye"), 2], {"hi": {"bye": (3, 4)}}, True, False, None]
        )

    def test_python_object_encoder_set(self):
        self._assert_round_trips(
            [1, 2, "hi", "bye", 7.3, {1, 2, "blarg"}, [{"hi", "bye"}, 2], {"hi": {"bye": {3, 4}}}, True, False, None]
        )

    def test_python_object_encoder_complex(self):
        # Mixes sets, tuples, nested dicts and a tuple of custom Person objects.
        mixed = [
            1,
            2,
            "hi",
            "bye",
            7.3,
            {1, 2, "blarg"},
            [("hi", "bye"), 2],
            {"hi": {"bye": {3, 4}}},
            True,
            False,
            None,
            (Person("tappy"), Person("tapper")),
        ]
        self._assert_round_trips(mixed)

    def test_python_object_encoder_unpicklable(self):
        class CannotPickleThis:
            """Da na na na. Can't pickle this."""

            def __init__(self):
                self.x = 1

        payload = [1, CannotPickleThis()]
        expected = [1, UnpicklableObject()]

        # The default encoder is expected to refuse the object with a ValueError.
        with self.assertRaises(ValueError):
            json.dumps(payload, indent=4, sort_keys=True, cls=define_python_object_encoder())

        # Built with True, the encoder is expected to substitute an UnpicklableObject instead.
        recreated = self._round_trip(payload, define_python_object_encoder(True))
        self.assertEqual(recreated, expected)
566 |
567 |
class EnforceReproducibilityTests(TestCase):
    """Each test expects enforce_reproducibility to raise ValueError.

    Going by the test names, the first argument is the saved reproducibility data and the
    second is the current reproducibility data.
    """

    def test_saved_reproducibility_data_is_none(self):
        with self.assertRaises(ValueError):
            enforce_reproducibility(None, {}, "here")

    def test_git_url_not_in_saved_reproducibility_data(self):
        with self.assertRaises(ValueError):
            enforce_reproducibility({}, {}, "here")

    def test_git_url_not_in_current_reproducibility_data(self):
        with self.assertRaises(ValueError):
            enforce_reproducibility({"git_url": "none"}, {}, "here")

    def test_git_urls_disagree(self):
        with self.assertRaises(ValueError):
            enforce_reproducibility({"git_url": "none"}, {"git_url": "some"}, "here")

    def test_throw_error_for_saved_uncommitted_changes(self):
        # The git urls match here. With differing urls (as in test_git_urls_disagree) a
        # ValueError is already expected, so the uncommitted-changes flag would go untested.
        # Only the saved data reports uncommitted changes.
        with self.assertRaises(ValueError):
            enforce_reproducibility(
                {"git_url": "none", "git_has_uncommitted_changes": True},
                {"git_url": "none", "git_has_uncommitted_changes": False},
                "here",
            )

    def test_throw_error_for_uncommitted_changes(self):
        # Matching git urls again; only the current data reports uncommitted changes.
        with self.assertRaises(ValueError):
            enforce_reproducibility(
                {"git_url": "none", "git_has_uncommitted_changes": False},
                {"git_url": "none", "git_has_uncommitted_changes": True},
                "here",
            )
598 |
599 |
# Run unittest.main() when this file is executed directly rather than imported.
600 | if __name__ == "__main__":
601 | unittest.main()
602 |
--------------------------------------------------------------------------------