├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── conftest.py
├── fancyflags
    ├── __init__.py
    ├── _argument_parsers.py
    ├── _argument_parsers_test.py
    ├── _auto.py
    ├── _auto_test.py
    ├── _define_auto.py
    ├── _define_auto_test.py
    ├── _definitions.py
    ├── _definitions_test.py
    ├── _flags.py
    ├── _flags_test.py
    ├── _flagsaver_test.py
    ├── _metadata.py
    └── examples
    │   ├── example_module.py
    │   └── override_test.py
├── pytest.ini
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
# Ignore byte-compiled Python code
*.py[cod]

# Ignore directories created during the build/installation process
*.egg-info/
build/
dist/

--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# Changelog

All significant changes to this project will be documented here.

## [Unreleased]

## [1.2]

Release date: 2023-07-04

* Added support for DateTime flag type for `ff.DEFINE_auto`.
* Added option to skip defining flags for a subset of arguments in `ff.auto`.
* Added support for functions or constructors without default arguments in
  `ff.DEFINE_auto`.
* Added a `case_sensitive` argument to `ff.EnumClass` and set it to `False` by
  default, to match the
  [corresponding change](https://github.com/abseil/abseil-py/commit/eb94d9587c6f2eade9617237fb6bba1364226a3b)
  in `DEFINE_enum_class`.
* Added support for variadic tuples in `ff.DEFINE_auto`.
* Added support for `--foo`/`--nofoo` syntax for passing boolean flags, and made
  this the default way of serializing such flags.
* Dropped Python 3.7 support.
* Added Python 3.11 support.

## [1.1]

Release date: 2021-11-27

* Made help strings optional for all `Item`s and `MultiItem`s.
* Added `ff.DEFINE_auto`.
* Dropped Python 3.6 support.
* Added Python 3.10 support.
* Added/improved type hints throughout.

## [1.0]

Release date: 2021-02-08

* Initial release.

[Unreleased]: https://github.com/deepmind/fancyflags/compare/v1.2...HEAD
[1.2]: https://github.com/deepmind/fancyflags/compare/v1.1...v1.2
[1.1]: https://github.com/deepmind/fancyflags/compare/v1.0...v1.1
[1.0]: https://github.com/deepmind/fancyflags/releases/tag/v1.0

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# fancyflags

![PyPI Python version](https://img.shields.io/pypi/pyversions/fancyflags)
![PyPI version](https://badge.fury.io/py/fancyflags.svg)

## Introduction

`fancyflags` is a Python library that extends
[`absl.flags`](https://github.com/abseil/abseil-py) with additional structured
flag types.

`fancyflags` provides flags corresponding to structures such as
[dicts](#dict-flags), [dataclasses, and (somewhat) arbitrary callables](#auto).

These flags are typed and validated like regular `absl` flags, catching errors
before they propagate into your program. To override values, users can access
familiar "dotted" flag names.

TIP: Already a `fancyflags` user? Check out our [usage tips](#tips)!

## A short note on design philosophy

`fancyflags` promotes mixing with regular `absl` flags. In many cases a few
In many cases a few 26 | regular `absl` flags are all you need! 27 | 28 | `fancyflags` does not require you to modify library code: it should only be used 29 | in your "main" file 30 | 31 | `fancyflags` is not a dependency injection framework, and avoids 32 | programming-language-like power features. We prefer that users write regular 33 | Python for wiring up their code, because it's explicit, simple to understand, 34 | and allows static analysis tools to identify problems. 35 | 36 | ## Installation 37 | 38 | `fancyflags` can be installed from PyPI using `pip`: 39 | 40 | ```shell 41 | pip install fancyflags 42 | ``` 43 | 44 | It can also be installed directly from our GitHub repository: 45 | 46 | ```shell 47 | pip install git+git://github.com/deepmind/fancyflags.git 48 | ``` 49 | 50 | or alternatively by checking out a local copy of our repository and running: 51 | 52 | ```shell 53 | pip install /path/to/local/fancyflags/ 54 | ``` 55 | 56 | ## Dict flags 57 | 58 | If we have a class `Replay`, with arguments `capacity`, `priority_exponent` and 59 | others, we can define a corresponding dict flag in our main script 60 | 61 | ```python 62 | import fancyflags as ff 63 | 64 | _REPLAY_FLAG = ff.DEFINE_dict( 65 | "replay", 66 | capacity=ff.Integer(int(1e6)), 67 | priority_exponent=ff.Float(0.6), 68 | importance_sampling_exponent=ff.Float(0.4), 69 | removal_strategy=ff.Enum("fifo", ["rand", "fifo", "max_value"]) 70 | ) 71 | ``` 72 | 73 | and `**unpack` the values directly into the `Replay` constructor 74 | 75 | ```python 76 | replay_lib.Replay(**_REPLAY_FLAG.value) 77 | ``` 78 | 79 | `ff.DEFINE_dict` creates a flag named `replay`, with a default value of 80 | 81 | ```python 82 | { 83 | "capacity": 1000000, 84 | "priority_exponent": 0.6, 85 | "importance_sampling_exponent": 0.4, 86 | "removal_strategy": "fifo", 87 | } 88 | ``` 89 | 90 | For each item in the dict, `ff.DEFINE_dict` also generates a dot-delimited 91 | "item" flag that can be overridden from the command line. 
In this example, the item flags would be

```
replay.capacity
replay.priority_exponent
replay.importance_sampling_exponent
replay.removal_strategy
```

Overriding an item flag from the command line updates the corresponding entry in
the dict flag. The value of the dict flag can be accessed through the return
value of `ff.DEFINE_dict` (`_REPLAY_FLAG.value` in the example above), or via
`flags.FLAGS.replay` from `absl.flags`. For example, the override

```shell
python script_name.py --replay.capacity=2000 --replay.removal_strategy=max_value
```

sets `_REPLAY_FLAG.value` to

```python
{
    "capacity": 2000,  # <-- Overridden
    "priority_exponent": 0.6,
    "importance_sampling_exponent": 0.4,
    "removal_strategy": "max_value",  # <-- Overridden
}
```

## Nested dicts

fancyflags also supports nested dictionaries:

```python
_NESTED_REPLAY_FLAG = ff.DEFINE_dict(
    "replay",
    capacity=ff.Integer(int(1e6)),
    exponents=dict(
        priority=ff.Float(0.6),
        importance_sampling=ff.Float(0.4),
    )
)
```

In this example, `_NESTED_REPLAY_FLAG.value` would be

```python
{
    "capacity": 1000000,
    "exponents": {
        "priority": 0.6,
        "importance_sampling": 0.4,
    }
}
```

and the generated flags would be

```
replay.capacity
replay.exponents.priority
replay.exponents.importance_sampling
```

### Help strings

fancyflags uses the item flag's name as the default help string; however, this
can also be set manually:

```python
_NESTED_REPLAY_FLAG = ff.DEFINE_dict(
    "replay",
    capacity=ff.Integer(int(1e6), "Maximum size of replay buffer."),
    exponents=dict(
        priority=ff.Float(0.6),  # Help string: "replay.exponents.priority"
        importance_sampling=ff.Float(0.4, "Importance sampling coefficient."),
    )
)
```

## "Auto" flags for functions and other structures {#auto}

`fancyflags` also provides `ff.DEFINE_auto`, which automatically generates a flag
declaration corresponding to the signature of a given callable. The return value
also carries the correct type information.

For example, the callable could be a constructor

```python
_REPLAY = ff.DEFINE_auto('replay', replay_lib.Replay)
```

or it could be a container type, such as a `dataclass`

```python
@dataclasses.dataclass
class DataSettings:
  dataset_name: str = 'mnist'
  split: str = 'train'
  batch_size: int = 128

# In main script.
# Exposes flags: --data.dataset_name --data.split and --data.batch_size.
_DATA_SETTINGS = ff.DEFINE_auto('data', datasets.DataSettings)

def main(argv):
  del argv  # Unused.
  dataset = datasets.load(_DATA_SETTINGS.value())
  # ...
```

or any other callable that satisfies the `ff.auto` requirements. It's also
possible to override keyword arguments in the call to `.value()`, e.g.

```python
test_settings = _DATA_SETTINGS.value(split='test')
```

## Defining a dict flag from a function or constructor

The function `ff.auto` returns a dictionary of `ff.Item`s given a function or
constructor. This is used to build `ff.DEFINE_auto` and is also exposed in the
top-level API.
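
To illustrate the general idea of deriving flag definitions from a callable's
signature, here is a hypothetical, standard-library-only sketch. It is not how
fancyflags implements `ff.auto`: `sketch_auto` and `make_writer` are invented
for this example, the sketch returns `(type, default)` pairs rather than
`ff.Item`s, and it only accepts scalar `bool`, `float`, `int` and `str`
annotations (no sequences).

```python
import inspect
from typing import Any, Callable, Dict, Tuple

# Scalar types accepted by this sketch (sequences are deliberately omitted).
_SIMPLE_TYPES = (bool, float, int, str)


def sketch_auto(fn: Callable[..., Any]) -> Dict[str, Tuple[type, Any]]:
  """Returns {parameter name: (annotated type, default)} for `fn`.

  Hypothetical helper for illustration only; not part of fancyflags.
  """
  spec = {}
  for name, param in inspect.signature(fn).parameters.items():
    # Mirror the "Auto requirements": annotation, default, simple type.
    if param.annotation is inspect.Parameter.empty:
      raise TypeError(f"{name!r} has no type annotation.")
    if param.default is inspect.Parameter.empty:
      raise TypeError(f"{name!r} has no default value.")
    if param.annotation not in _SIMPLE_TYPES:
      raise TypeError(f"{name!r} has unsupported type {param.annotation!r}.")
    spec[name] = (param.annotation, param.default)
  return spec


# A hypothetical constructor-like function to inspect.
def make_writer(path: str = "/tmp/logdir", log_every_n: int = 100):
  return {"path": path, "log_every_n": log_every_n}


print(sketch_auto(make_writer))
# {'path': (<class 'str'>, '/tmp/logdir'), 'log_every_n': (<class 'int'>, 100)}
```

Each check mirrors one of the "Auto requirements" listed below. Turning each
pair into an `ff.Item` (for example `int` into `ff.Integer`) is left out of the
sketch.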

`ff.auto` can be used with `ff.DEFINE_dict` as follows:

```python
_WRITER_KWARGS = ff.DEFINE_dict('writer', **ff.auto(logging.Writer))
```

`ff.auto` may be useful for creating kwarg dictionaries in situations where
`ff.DEFINE_auto` is not suitable, for example to pass kwargs into nested
function calls.

## Auto requirements

`ff.DEFINE_auto` and `ff.auto` will work if:

1. The function or class constructor has type annotations.
1. Each argument has a default value.
1. The types of the arguments are relatively simple types (`int`, `str`,
   `bool`, `float`, or sequences thereof).

Since release 1.2, `ff.DEFINE_auto` also supports functions or constructors
without default arguments (see the changelog).

## Notes on using `flagsaver`

abseil-py's [flagsaver](https://github.com/abseil/abseil-py/blob/master/absl/testing/flagsaver.py)
module is useful for safely overriding flag values in test code. Here's how to
make it work well with fancyflags.

### Making dotted names work with `flagsaver` keyword arguments

Since `flagsaver` relies on keyword arguments, overriding a flag with a dot in
its name will result in a `SyntaxError`:

```python
# Invalid Python syntax.
flagsaver.flagsaver(replay.capacity=100, replay.priority_exponent=0.5)
```

To work around this, first create a dictionary and then `**` unpack it:

```python
# Valid Python syntax.
flagsaver.flagsaver(**{'replay.capacity': 100, 'replay.priority_exponent': 0.5})
```

### Be careful when setting flag values inside a `flagsaver` context

If possible, we recommend that you avoid setting the flag values inside the
context altogether, and instead pass the override values directly to the
`flagsaver` function as shown above.
However, if you _do_ need to set values
inside the context, be aware of this gotcha:

This syntax does not work properly:

```python
with flagsaver.flagsaver():
  FLAGS.replay["capacity"] = 100
# The original value will not be restored correctly.
```

This syntax _does_ work properly:

```python
with flagsaver.flagsaver():
  FLAGS["replay.capacity"].value = 100
# The original value *will* be restored correctly.
```

## fancyflags in more detail

### What is an `ff.Float` or `ff.Integer`?

`ff.Float` and `ff.Integer` are both `ff.Item`s. An `ff.Item` is essentially a
mapping from a default value and a help string to a specific type of flag.

The `ff.DEFINE_dict` function traverses its keyword arguments (and any nested
dicts) to determine the name of each flag. It calls the `.define()` method of
each `ff.Item`, passing it the name information, and the `ff.Item` then defines
the appropriate dot-delimited flag.

### What `ff.Item`s are available?

ff.Item             | Corresponding Flag
:------------------ | :------------------------------
`ff.Boolean`        | `flags.DEFINE_boolean`
`ff.Integer`        | `flags.DEFINE_integer`
`ff.Enum`           | `flags.DEFINE_enum`
`ff.EnumClass`      | `flags.DEFINE_enum_class`
`ff.Float`          | `flags.DEFINE_float`
`ff.Sequence`       | `ff.DEFINE_sequence`
`ff.String`         | `flags.DEFINE_string`
`ff.StringList`     | `flags.DEFINE_list`
`ff.MultiEnum`      | `flags.DEFINE_multi_enum`
`ff.MultiEnumClass` | `flags.DEFINE_multi_enum_class`
`ff.MultiString`    | `flags.DEFINE_multi_string`
`ff.DateTime`       | -

### Defining a new `ff.Item`

Given a `flags.ArgumentParser`, we can define an `ff.Item` in a few lines of
code.

For example, if we wanted to define an `ff.Item` that corresponded to
`flags.DEFINE_spaceseplist`, we would look for the parser that this definition
uses, and write:

```python
from absl import flags
import fancyflags as ff


class SpaceSepList(ff.Item):

  def __init__(self, default, help_string):
    parser = flags.WhitespaceSeparatedListParser()
    super().__init__(default, help_string, parser)
```

Note that custom `ff.Item` definitions do not _need_ to be added to the
fancyflags library to work.

### Defining `Item` flags only

We also expose a `define_flags` function, which defines flags from a flat or
nested dictionary that maps names to `ff.Item`s. This function is used as part
of `ff.DEFINE_dict` and `ff.DEFINE_auto`, and may be useful for writing
extensions on top of `fancyflags`.

```python
_writer_items = dict(
    path=ff.String('/path/to/logdir', "Output directory."),
    log_every_n=ff.Integer(100, "Number of calls between writes to disk."),
)

_WRITER_KWARGS = ff.define_flags("writer", _writer_items)
```

This example defines the flags `writer.path` and `writer.log_every_n` only; it
does _not_ define a dict flag. The return value (`_WRITER_KWARGS`) is a
dictionary that contains the default values. Any overrides to the individual
flags will also update the corresponding item in this dictionary.

### Tips

Any direct access, e.g. `_DICT_FLAG.value['item']`, is an indication that you
may want to change your flag structure:

* Try to align dict flags with constructors or functions, so that you always
  `**unpack` the items into their corresponding constructor or function.
* If you need to access an item in a dict directly, e.g. because its value is
  used in multiple places, consider moving that item to its own plain flag.
* Check to see if you should have `**unpacked` somewhere up the call-stack,
  and convert function "config" args to individual items if needed.
* Don't group things under a dict flag just because they're thematically
  related, and don't have one catch-all dict flag. Instead, define individual
  dict flags to match the constructor or function calls as needed.

--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pytest configuration for fancyflags."""

from absl import flags

collect_ignore = [
    'conftest.py',
    'setup.py',
]


def pytest_configure(config):
  del config  # Unused.
  # We need to skip flag parsing when executing under pytest.
  flags.FLAGS.mark_as_parsed()

--------------------------------------------------------------------------------
/fancyflags/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""An extended flags library. The main component is a nested dict flag."""

# pylint: disable=g-bad-import-order,g-import-not-at-top

# Add current module to disclaimed module ids.
from absl import flags

flags.disclaim_key_flags()

from fancyflags._metadata import __version__

# Define flags based on a dictionary or sequence.
from fancyflags._definitions import DEFINE_dict
from fancyflags._definitions import DEFINE_sequence
from fancyflags._definitions import define_flags

# Automatically build fancyflags defs from a callable signature.
from fancyflags._auto import auto
from fancyflags._define_auto import DEFINE_auto

# `Item` definitions for supported types.
from fancyflags._definitions import Boolean
from fancyflags._definitions import DateTime
from fancyflags._definitions import Enum
from fancyflags._definitions import EnumClass
from fancyflags._definitions import Float
from fancyflags._definitions import Integer
from fancyflags._definitions import MultiEnum
from fancyflags._definitions import MultiEnumClass
from fancyflags._definitions import MultiString
from fancyflags._definitions import Sequence
from fancyflags._definitions import String
from fancyflags._definitions import StringList

# Class for adding new flag types.
from fancyflags._definitions import Item

# usage_logging: import

# experimental: import

--------------------------------------------------------------------------------
/fancyflags/_argument_parsers.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Argument parsers."""

import ast
import datetime
import enum

from absl import flags

BASIC_SEQUENCE_TYPES = (list, tuple)

# We assume a sequence contains only these types. Python has no primitive types.
26 | SIMPLE_TYPES = (bool, float, int, str) 27 | 28 | NOT_A_SIMPLE_TYPE_MESSAGE = """ 29 | Input list contains unsupported type {{}}, however each element in a sequence 30 | must be a {} or {}. 31 | """.format( 32 | ", ".join(type_.__name__ for type_ in SIMPLE_TYPES[:-1]), 33 | SIMPLE_TYPES[-1].__name__, 34 | ) 35 | 36 | _EMPTY_STRING_ERROR_MESSAGE = """ 37 | Empty sequences should be given explicitly as [] or () and not as an empty 38 | string""" 39 | 40 | 41 | class SequenceParser(flags.ArgumentParser): 42 | """Parser of simple sequences containing simple Python values.""" 43 | 44 | def parse(self, argument): 45 | """Parses the argument as a string-formatted sequence (list or tuple). 46 | 47 | Essentially reverses the result of `"{}".format(a_sequence)`. 48 | 49 | Args: 50 | argument: The flag value as a string, list, tuple or None. Examples of 51 | valid input strings are `"(1,2,3)"` and `"[0.2, 0.3, 1.0]"`. 52 | 53 | Returns: 54 | The parsed sequence. 55 | 56 | Raises: 57 | TypeError: If the input type is not supported, or if the input is not a 58 | flat sequence that only contains simple Python values. 59 | ValueError: If the input is an empty string or not a valid Python literal. 60 | """ 61 | if argument is None: 62 | return [] 63 | elif isinstance(argument, BASIC_SEQUENCE_TYPES): 64 | result = argument[:] 65 | elif isinstance(argument, str): 66 | if not argument: 67 | raise ValueError(_EMPTY_STRING_ERROR_MESSAGE) 68 | try: 69 | result = ast.literal_eval(argument) 70 | except (ValueError, SyntaxError) as e: 71 | raise ValueError( 72 | f'Failed to parse "{argument}" as a python literal.' 73 | ) from e 74 | 75 | if not isinstance(result, BASIC_SEQUENCE_TYPES): 76 | raise TypeError( 77 | "Input string should represent a list or tuple, however it " 78 | "evaluated as a {}.".format(type(result).__name__) 79 | ) 80 | else: 81 | raise TypeError("Unsupported type {}.".format(type(argument).__name__)) 82 | 83 | # Make sure the result is a flat sequence of simple types.
84 | for value in result: 85 | if not isinstance(value, SIMPLE_TYPES): 86 | raise TypeError(NOT_A_SIMPLE_TYPE_MESSAGE.format(type(value).__name__)) 87 | 88 | return result 89 | 90 | def flag_type(self): 91 | """See base class.""" 92 | return "sequence" 93 | 94 | 95 | class MultiEnumParser(flags.ArgumentParser): 96 | """Parser of multiple enum values. 97 | 98 | This parser allows the flag values to be sequences of any type, unlike 99 | flags.DEFINE_multi_enum which only allows strings. 100 | """ 101 | 102 | def __init__(self, enum_values): 103 | if not enum_values: 104 | raise ValueError("enum_values cannot be empty") 105 | if any(not value for value in enum_values): 106 | raise ValueError("No element of enum_values can be empty") 107 | 108 | super().__init__() 109 | self.enum_values = enum_values 110 | 111 | def parse(self, arguments): 112 | """Determines validity of arguments. 113 | 114 | Args: 115 | arguments: list, tuple, or enum of flag values. Each value may be any type. 116 | 117 | Returns: 118 | The input list, tuple or enum if valid. 119 | 120 | Raises: 121 | TypeError: If the input type is not supported. 122 | ValueError: If an argument element does not match any of the enum values.
123 | """ 124 | if arguments is None: 125 | return [] 126 | elif isinstance(arguments, BASIC_SEQUENCE_TYPES): 127 | result = arguments[:] 128 | elif isinstance(arguments, enum.EnumMeta): 129 | result = arguments 130 | elif isinstance(arguments, str): 131 | result = ast.literal_eval(arguments) 132 | 133 | if not isinstance(result, BASIC_SEQUENCE_TYPES): 134 | raise TypeError( 135 | "Input string should represent a list or tuple, however it " 136 | "evaluated as a {}.".format(type(result).__name__) 137 | ) 138 | else: 139 | raise TypeError("Unsupported type {}.".format(type(arguments).__name__)) 140 | 141 | if not all(arg in self.enum_values for arg in result): 142 | raise ValueError( 143 | "Argument values should be one of <{}>".format( 144 | "|".join(str(value) for value in self.enum_values) 145 | ) 146 | ) 147 | else: 148 | return result 149 | 150 | def flag_type(self): 151 | return "multi enum" 152 | 153 | 154 | class PossiblyNaiveDatetimeParser(flags.ArgumentParser): 155 | """Parses an ISO format datetime string into a datetime.datetime.""" 156 | 157 | def parse(self, value) -> datetime.datetime: 158 | if isinstance(value, datetime.datetime): 159 | return value 160 | 161 | # Handle ambiguous cases such as 2000-01-01+01:00, where the part after the 162 | # '+' sign looks like timezone info but is actually just the time. 163 | if value[10:11] in ("+", "-"): 164 | # plus/minus as separator between date and time (can be any character) 165 | raise ValueError( 166 | f"datetime value {value!r} uses {value[10]!r} as separator " 167 | "between date and time (excluded to avoid confusion between " 168 | "time and offset). Use any other character instead, e.g. 
" 169 | f"{value[:10] + 'T' + value[11:]!r}" 170 | ) 171 | 172 | try: 173 | return datetime.datetime.fromisoformat(value) 174 | except ValueError as e: 175 | raise ValueError(f"invalid datetime value {value!r}: {e}") from None 176 | 177 | def flag_type(self) -> str: 178 | return "datetime.datetime" 179 | -------------------------------------------------------------------------------- /fancyflags/_argument_parsers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for argument_parsers.""" 16 | 17 | import datetime 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from fancyflags import _argument_parsers 22 | 23 | 24 | UTC = datetime.timezone.utc 25 | TZINFO_4HRS = datetime.timezone(datetime.timedelta(hours=4)) 26 | 27 | 28 | class SequenceParserTest(parameterized.TestCase): 29 | 30 | def setUp(self): 31 | super().setUp() 32 | self.parser = _argument_parsers.SequenceParser() 33 | 34 | @parameterized.parameters( 35 | ([1, 2, 3],), 36 | ([],), 37 | ((),), 38 | (["hello", "world"],), 39 | ((3.14, 2.718),), 40 | ((1, -1.0),), 41 | ) 42 | def test_parse_input_sequence(self, input_sequence): 43 | result = self.parser.parse(input_sequence) 44 | self.assertEqual(result, input_sequence) 45 | 46 | @parameterized.parameters( 47 | ("[1, 2, 3]", [1, 2, 3]), 48 | ("[]", []), 49 | ("()", ()), 50 | ("['hello', 'world']", ["hello", "world"]), 51 | ("(3.14, 2.718)", (3.14, 2.718)), 52 | ( 53 | "(1, -1.0)", 54 | (1, -1.0), 55 | ), 56 | ) 57 | def test_parse_input_string(self, input_string, expected): 58 | result = self.parser.parse(input_string) 59 | self.assertEqual(result, expected) 60 | 61 | # Check that the string-formatted expected result also matches the input 62 | self.assertEqual("{}".format(expected), input_string) 63 | 64 | @parameterized.parameters( 65 | ('["hello", u"world"]', ["hello", "world"]), 66 | ("(1,2,3)", (1, 2, 3)), 67 | ) 68 | def test_parse_input_string_different_format(self, input_string, expected): 69 | # The parser/ast result should also work for slightly different formatting. 
70 | result = self.parser.parse(input_string) 71 | self.assertEqual(result, expected) 72 | 73 | def test_parse_none(self): 74 | result = self.parser.parse(None) 75 | self.assertEqual(result, []) 76 | 77 | @parameterized.parameters( 78 | ({1, 2, 3},), 79 | (100,), 80 | ) 81 | def test_parse_invalid_input_type(self, input_item): 82 | with self.assertRaisesRegex(TypeError, "Unsupported type"): 83 | self.parser.parse(input_item) 84 | 85 | @parameterized.parameters( 86 | ("'hello world'",), 87 | ("{1: 3}",), 88 | ) 89 | def test_parse_invalid_evaluated_type(self, input_string): 90 | with self.assertRaisesRegex(TypeError, "evaluated as"): 91 | self.parser.parse(input_string) 92 | 93 | @parameterized.parameters( 94 | # No nested anything. 95 | ([1, [2, 3]],), 96 | ([1, (2, 3)],), 97 | ((1, [2, 3]),), 98 | ((1, (2, 3)),), 99 | # Nothing outside the listed primitive types. 100 | ([1, set()],), 101 | ) 102 | def test_parse_invalid_entries(self, input_item): 103 | with self.assertRaisesRegex(TypeError, "contains unsupported"): 104 | self.parser.parse(input_item) 105 | 106 | def test_empty_string(self): 107 | with self.assertRaisesWithLiteralMatch( 108 | ValueError, _argument_parsers._EMPTY_STRING_ERROR_MESSAGE 109 | ): 110 | self.parser.parse("") 111 | 112 | @parameterized.parameters( 113 | # ValueError from ast.literal_eval 114 | "[foo, bar]", 115 | # SyntaxError from ast.literal_eval 116 | "['foo', 'bar'", 117 | "[1 2]", 118 | ) 119 | def test_parse_string_literal_error(self, input_string): 120 | with self.assertRaisesRegex(ValueError, ".*as a python literal.*"): 121 | self.parser.parse(input_string) 122 | 123 | 124 | class MultiEnumParserTest(parameterized.TestCase): 125 | 126 | def setUp(self): 127 | super().setUp() 128 | self.parser = _argument_parsers.MultiEnumParser( 129 | ["a", "a", ["a"], "b", "c", 1, [2], {"a": "d"}] 130 | ) 131 | 132 | @parameterized.parameters( 133 | ('["a"]', ["a"]), 134 | ('[["a"], "a"]', [["a"], "a"]), 135 | ('[1, "a", {"a": "d"}]', [1, "a", 
{"a": "d"}]), 136 | ) 137 | def test_parse_input(self, inputs, target): 138 | self.assertEqual(self.parser.parse(inputs), target) 139 | 140 | @parameterized.parameters( 141 | ("'a'", "evaluated as a"), 142 | (1, "Unsupported type"), 143 | ({"a"}, "Unsupported type"), 144 | ("''", "evaluated as a"), 145 | ) 146 | def test_invalid_input_type(self, input_item, regex): 147 | with self.assertRaisesRegex(TypeError, regex): 148 | self.parser.parse(input_item) 149 | 150 | @parameterized.parameters("[1, 2]", '["a", ["b"]]') 151 | def test_out_of_enum_values(self, inputs): 152 | with self.assertRaisesRegex(ValueError, "Argument values should be one of"): 153 | self.parser.parse(inputs) 154 | 155 | with self.assertRaisesRegex(ValueError, "Argument values should be one of"): 156 | self.parser.parse(inputs) 157 | 158 | 159 | class PossiblyNaiveDatetimeFlagTest(parameterized.TestCase): 160 | 161 | def test_parser_flag_type(self): 162 | parser = _argument_parsers.PossiblyNaiveDatetimeParser() 163 | self.assertEqual("datetime.datetime", parser.flag_type()) 164 | 165 | @parameterized.named_parameters( 166 | dict( 167 | testcase_name="date_string", 168 | value="2011-11-04", 169 | expected=datetime.datetime(2011, 11, 4, 0, 0), 170 | ), 171 | dict( 172 | testcase_name="second_string", 173 | value="2011-11-04T00:05:23", 174 | expected=datetime.datetime(2011, 11, 4, 0, 5, 23), 175 | ), 176 | dict( 177 | testcase_name="fractions_string", 178 | value="2011-11-04 00:05:23.283", 179 | expected=datetime.datetime(2011, 11, 4, 0, 5, 23, 283000), 180 | ), 181 | dict( 182 | testcase_name="utc_string", 183 | value="2011-11-04 00:05:23.283+00:00", 184 | expected=datetime.datetime(2011, 11, 4, 0, 5, 23, 283000, tzinfo=UTC), 185 | ), 186 | dict( 187 | testcase_name="offset_string", 188 | value="2011-11-04T00:05:23+04:00", 189 | expected=datetime.datetime(2011, 11, 4, 0, 5, 23, tzinfo=TZINFO_4HRS), 190 | ), 191 | dict( 192 | testcase_name="datetime", 193 | value=datetime.datetime(2011, 11, 4, 0, 0), 
194 | expected=datetime.datetime(2011, 11, 4, 0, 0), 195 | ), 196 | ) 197 | def test_parse(self, value, expected): 198 | parser = _argument_parsers.PossiblyNaiveDatetimeParser() 199 | result = parser.parse(value) 200 | 201 | self.assertIsInstance(result, datetime.datetime) 202 | self.assertEqual(expected, result) 203 | 204 | def test_parse_separator_plus_or_minus_raises(self): 205 | parser = _argument_parsers.PossiblyNaiveDatetimeParser() 206 | with self.assertRaisesRegex(ValueError, r"separator between date and time"): 207 | # Avoid confusion of 1970-01-01T08:00:00 vs. 1970-01-01T00:00:00-08:00 208 | parser.parse("1970-01-01-08:00") 209 | 210 | 211 | if __name__ == "__main__": 212 | absltest.main() 213 | -------------------------------------------------------------------------------- /fancyflags/_auto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Automatically builds flags from a callable signature.""" 16 | 17 | import datetime 18 | import enum 19 | import functools 20 | import inspect 21 | import sys 22 | import typing 23 | from typing import Any, Callable, Collection, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple 24 | import warnings 25 | 26 | from fancyflags import _definitions 27 | 28 | # TODO(b/178129474): Improve support for typing.Sequence subtypes. 29 | _TYPE_MAP = { 30 | List[bool]: _definitions.Sequence, # pylint: disable=unhashable-member 31 | List[float]: _definitions.Sequence, # pylint: disable=unhashable-member 32 | List[int]: _definitions.Sequence, # pylint: disable=unhashable-member 33 | List[str]: _definitions.Sequence, # pylint: disable=unhashable-member 34 | Sequence[bool]: _definitions.Sequence, 35 | Sequence[float]: _definitions.Sequence, 36 | Sequence[int]: _definitions.Sequence, 37 | Sequence[str]: _definitions.Sequence, 38 | Tuple[bool, ...]: _definitions.Sequence, 39 | Tuple[bool]: _definitions.Sequence, 40 | Tuple[float, ...]: _definitions.Sequence, 41 | Tuple[float]: _definitions.Sequence, 42 | Tuple[int, ...]: _definitions.Sequence, 43 | Tuple[int]: _definitions.Sequence, 44 | Tuple[str, ...]: _definitions.Sequence, 45 | Tuple[str]: _definitions.Sequence, 46 | bool: _definitions.Boolean, 47 | datetime.datetime: _definitions.DateTime, 48 | float: _definitions.Float, 49 | int: _definitions.Integer, 50 | str: _definitions.String, 51 | } 52 | if sys.version_info >= (3, 9): 53 | # Support PEP 585 type hints. 
54 | _TYPE_MAP.update( 55 | { 56 | list[bool]: _definitions.Sequence, 57 | list[float]: _definitions.Sequence, 58 | list[int]: _definitions.Sequence, 59 | list[str]: _definitions.Sequence, 60 | tuple[bool, ...]: _definitions.Sequence, 61 | tuple[bool]: _definitions.Sequence, 62 | tuple[float, ...]: _definitions.Sequence, 63 | tuple[float]: _definitions.Sequence, 64 | tuple[int, ...]: _definitions.Sequence, 65 | tuple[int]: _definitions.Sequence, 66 | tuple[str, ...]: _definitions.Sequence, 67 | tuple[str]: _definitions.Sequence, 68 | } 69 | ) 70 | 71 | # Add optional versions of all types as well 72 | _TYPE_MAP.update({Optional[tp]: parser for tp, parser in _TYPE_MAP.items()}) 73 | 74 | _MISSING_TYPE_ANNOTATION = "Missing type annotation for argument {name!r}" 75 | _UNSUPPORTED_ARGUMENT_TYPE = ( 76 | "No matching flag type for argument {{name!r}} with type annotation: " 77 | "{{annotation}}\n" 78 | "Supported types:\n{}".format("\n".join(str(t) for t in _TYPE_MAP)) 79 | ) 80 | _MISSING_DEFAULT_VALUE = "Missing default value for argument {name!r}" 81 | _is_enum = lambda type_: inspect.isclass(type_) and issubclass(type_, enum.Enum) 82 | _is_unsupported_type = lambda type_: not (type_ in _TYPE_MAP or _is_enum(type_)) 83 | 84 | 85 | def get_typed_signature(fn: Callable[..., Any]) -> inspect.Signature: 86 | """Returns the signature of a callable with type annotations resolved. 87 | 88 | If postponed evaluation of type annotations (PEP 563) is enabled (e.g. via 89 | `from __future__ import annotations` in Python >= 3.7) then we will need to 90 | resolve the annotations from their string forms in order to access the real 91 | types within the signature. 92 | https://www.python.org/dev/peps/pep-0563/#resolving-type-hints-at-runtime 93 | 94 | Args: 95 | fn: A callable to get the signature of. 96 | 97 | Returns: 98 | An instance of `inspect.Signature`. 
99 | """ 100 | type_hints = typing.get_type_hints(fn) or {} 101 | orig_signature = inspect.signature(fn) 102 | new_params = [] 103 | for key, orig_param in orig_signature.parameters.items(): 104 | new_params.append( 105 | inspect.Parameter( 106 | name=key, 107 | default=orig_param.default, 108 | annotation=type_hints.get(key, orig_param.annotation), 109 | kind=orig_param.kind, 110 | ) 111 | ) 112 | return orig_signature.replace(parameters=new_params) 113 | 114 | 115 | def auto( 116 | callable_fn: Callable[..., Any], 117 | *, 118 | strict: bool = True, 119 | skip_params: Collection[str] = (), 120 | ) -> Mapping[str, _definitions.Item]: 121 | """Automatically builds fancyflag definitions from a callable's signature. 122 | 123 | Example usage: 124 | ```python 125 | # Function 126 | ff.DEFINE_dict('my_function_settings', **ff.auto(my_module.my_function)) 127 | 128 | # Class constructor 129 | ff.DEFINE_dict('my_class_settings', **ff.auto(my_module.MyClass)) 130 | ``` 131 | 132 | Args: 133 | callable_fn: Generates flag definitions from this callable's signature. All 134 | arguments must have type annotations; arguments without default values 135 | become required flags. Supported argument types are `bool`, `float`, `int`, 136 | `str`, `datetime.datetime` and `enum.Enum` subclasses, homogeneous sequences 137 | of the first four, and `Optional` versions of all but the enums. 138 | strict: A bool, whether arguments with missing type annotations or 139 | unsupported types should trigger an error (the default) or be skipped with 140 | a warning. Setting strict=False might silence real errors, but will allow 141 | decorated functions to contain arguments whose types cannot be easily 142 | turned into a flag or overridden on the CLI. 143 | skip_params: Optional parameter names to skip defining flags for. 144 | 145 | Returns: 146 | Mapping from parameter names to fancyflags `Item`s, to be splatted into 147 | `ff.DEFINE_dict`. 148 | 149 | Raises: 150 | TypeError: If `callable_fn` is not callable. 151 | TypeError: If `strict=True` and any argument to `callable_fn` lacks a type 152 | annotation. 153 | TypeError: If `strict=True` and any argument has an unsupported type. 154 | """ 155 | if not callable(callable_fn): 156 | raise TypeError(f"Not a callable: {callable_fn}.") 157 | 158 | # Work around issue with metaclass-wrapped classes, such as Sonnet v2 modules. 159 | if isinstance(callable_fn, type): 160 | signature = get_typed_signature(callable_fn.__init__) 161 | # Remove `self` from start of __init__ signature. 162 | unused_self, *parameters = signature.parameters.values() 163 | else: 164 | signature = get_typed_signature(callable_fn) 165 | parameters = signature.parameters.values() 166 | 167 | items: MutableMapping[str, _definitions.Item] = {} 168 | parameters: Iterable[inspect.Parameter] 169 | for param in parameters: 170 | if param.name in skip_params: 171 | continue 172 | 173 | # Check for potential errors. 174 | if param.annotation is inspect.Signature.empty: 175 | exception = TypeError(_MISSING_TYPE_ANNOTATION.format(name=param.name)) 176 | elif _is_unsupported_type(param.annotation): 177 | exception = TypeError( 178 | _UNSUPPORTED_ARGUMENT_TYPE.format( 179 | name=param.name, annotation=param.annotation 180 | ) 181 | ) 182 | else: 183 | exception = None 184 | 185 | # If we saw an error, decide whether to raise or skip based on strictness. 186 | if exception: 187 | if strict: 188 | raise exception 189 | else: 190 | warnings.warn( 191 | f"Caught an exception ({exception}) when defining flags for " 192 | f"parameter {param}; skipping because strict=False..." 193 | ) 194 | continue 195 | 196 | # Look up the corresponding Item to create.
197 | if _is_enum(param.annotation): 198 | item_constructor = functools.partial( 199 | _definitions.EnumClass, enum_class=param.annotation 200 | ) 201 | else: 202 | item_constructor = _TYPE_MAP[param.annotation] 203 | 204 | # If there is no default argument for this parameter, we set the 205 | # corresponding `Flag` as `required`. 206 | if param.default is inspect.Signature.empty: 207 | default = None 208 | required = True 209 | else: 210 | default = param.default 211 | required = False 212 | 213 | # TODO(b/177673667): Parse the help string from docstring. 214 | items[param.name] = item_constructor( 215 | default, 216 | help_string=param.name, 217 | required=required, 218 | ) 219 | 220 | return items 221 | -------------------------------------------------------------------------------- /fancyflags/_auto_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for fancyflags.auto.""" 16 | 17 | # Test that `auto` can still correctly infer parameter types when postponed 18 | # evaluation of type annotations (PEP 563) is enabled. 
19 | from __future__ import annotations 20 | 21 | import abc 22 | import enum 23 | import sys 24 | from typing import List, Optional, Sequence, Tuple 25 | 26 | from absl import flags 27 | from absl.testing import absltest 28 | import fancyflags as ff 29 | from fancyflags import _auto 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | class MyEnum(enum.Enum): 35 | ZERO = 0 36 | ONE = 1 37 | 38 | 39 | class AutoTest(absltest.TestCase): 40 | 41 | def test_works_fn(self): 42 | # pylint: disable=unused-argument 43 | def my_function( 44 | str_: str = 'foo', 45 | int_: int = 10, 46 | float_: float = 1.0, 47 | bool_: bool = False, 48 | list_int: List[int] = [1, 2, 3], 49 | tuple_str: Tuple[str] = ('foo',), 50 | variadic_tuple_str: Tuple[str, ...] = ('foo', 'bar'), 51 | sequence_bool: Sequence[bool] = [True, False], 52 | optional_int: Optional[int] = None, 53 | optional_float: Optional[float] = None, 54 | optional_list_int: Optional[List[int]] = None, 55 | ): # pylint: disable=dangerous-default-value 56 | pass 57 | 58 | # pylint: enable=unused-argument 59 | expected_settings = { 60 | 'str_': 'foo', 61 | 'int_': 10, 62 | 'float_': 1.0, 63 | 'bool_': False, 64 | 'list_int': [1, 2, 3], 65 | 'tuple_str': ('foo',), 66 | 'variadic_tuple_str': ('foo', 'bar'), 67 | 'sequence_bool': [True, False], 68 | 'optional_int': None, 69 | 'optional_float': None, 70 | 'optional_list_int': None, 71 | } 72 | ff_dict = ff.auto(my_function) 73 | self.assertEqual(expected_settings.keys(), ff_dict.keys()) 74 | flag_values = flags.FlagValues() 75 | flag_holder = ff.DEFINE_dict( 76 | 'my_function_settings', 77 | flag_values, 78 | **ff_dict, 79 | ) 80 | flag_values(('./program', '')) 81 | self.assertEqual(flag_holder.value, expected_settings) 82 | 83 | @absltest.skipIf( 84 | condition=sys.version_info < (3, 9), 85 | reason='Generics syntax for standard collections requires Python >= 3.9', 86 | ) 87 | def test_works_fn_pep585(self): 88 | def my_function( 89 | list_int: list[int] = [1, 2, 3], 90 | tuple_str: 
tuple[str] = ('foo',), 91 | variadic_tuple_str: tuple[str, ...] = ('foo', 'bar'), 92 | optional_list_int: Optional[list[int]] = None, 93 | ): # pylint: disable=dangerous-default-value 94 | del list_int, tuple_str, variadic_tuple_str, optional_list_int # Unused. 95 | 96 | expected_settings = { 97 | 'list_int': [1, 2, 3], 98 | 'tuple_str': ('foo',), 99 | 'variadic_tuple_str': ('foo', 'bar'), 100 | 'optional_list_int': None, 101 | } 102 | ff_dict = ff.auto(my_function) 103 | self.assertEqual(expected_settings.keys(), ff_dict.keys()) 104 | flag_values = flags.FlagValues() 105 | flag_holder = ff.DEFINE_dict( 106 | 'my_function_settings', 107 | flag_values, 108 | **ff_dict, 109 | ) 110 | flag_values(('./program', '')) 111 | self.assertEqual(flag_holder.value, expected_settings) 112 | 113 | def test_works_enum_fn(self): 114 | # pylint: disable=unused-argument 115 | def my_function( 116 | str_: str = 'foo', int_: int = 10, enum_: MyEnum = MyEnum.ZERO 117 | ): 118 | pass 119 | 120 | # pylint: enable=unused-argument 121 | expected_settings = { 122 | 'str_': 'foo', 123 | 'int_': 10, 124 | 'enum_': MyEnum.ZERO, 125 | } 126 | ff_dict = ff.auto(my_function) 127 | self.assertCountEqual(expected_settings, ff_dict) 128 | 129 | def test_works_class(self): 130 | class MyClass: 131 | 132 | # pylint: disable=unused-argument 133 | def __init__( 134 | self, 135 | str_: str = 'foo', 136 | int_: int = 10, 137 | float_: float = 1.0, 138 | bool_: bool = False, 139 | list_int: List[int] = [1, 2, 3], 140 | tuple_str: Tuple[str] = ('foo',), 141 | variadic_tuple_str: Tuple[str, ...] 
= ('foo', 'bar'), 142 | sequence_bool: Sequence[bool] = [True, False], 143 | optional_int: Optional[int] = None, 144 | optional_float: Optional[float] = None, 145 | optional_list_int: Optional[List[int]] = None, 146 | ): # pylint: disable=dangerous-default-value 147 | pass 148 | 149 | # pylint: enable=unused-argument 150 | 151 | expected_settings = { 152 | 'str_': 'foo', 153 | 'int_': 10, 154 | 'float_': 1.0, 155 | 'bool_': False, 156 | 'list_int': [1, 2, 3], 157 | 'tuple_str': ('foo',), 158 | 'variadic_tuple_str': ('foo', 'bar'), 159 | 'sequence_bool': [True, False], 160 | 'optional_int': None, 161 | 'optional_float': None, 162 | 'optional_list_int': None, 163 | } 164 | ff_dict = ff.auto(MyClass) 165 | self.assertEqual(expected_settings.keys(), ff_dict.keys()) 166 | flag_values = flags.FlagValues() 167 | flag_holder = ff.DEFINE_dict( 168 | 'my_class_settings', 169 | flag_values, 170 | **ff_dict, 171 | ) 172 | flag_values(('./program', '')) 173 | self.assertEqual(flag_holder.value, expected_settings) 174 | 175 | @absltest.skipIf( 176 | condition=sys.version_info < (3, 9), 177 | reason='Generics syntax for standard collections requires Python >= 3.9', 178 | ) 179 | def test_works_class_pep585(self): 180 | class MyClass: 181 | 182 | def __init__( 183 | self, 184 | list_int: list[int] = [1, 2, 3], 185 | tuple_str: tuple[str] = ('foo',), 186 | variadic_tuple_str: tuple[str, ...] = ('foo', 'bar'), 187 | optional_list_int: Optional[list[int]] = None, 188 | ): # pylint: disable=dangerous-default-value 189 | # Unused. 
190 | del list_int, tuple_str, variadic_tuple_str, optional_list_int 191 | 192 | expected_settings = { 193 | 'list_int': [1, 2, 3], 194 | 'tuple_str': ('foo',), 195 | 'variadic_tuple_str': ('foo', 'bar'), 196 | 'optional_list_int': None, 197 | } 198 | ff_dict = ff.auto(MyClass) 199 | self.assertEqual(expected_settings.keys(), ff_dict.keys()) 200 | flag_values = flags.FlagValues() 201 | flag_holder = ff.DEFINE_dict( 202 | 'my_class_settings', 203 | flag_values, 204 | **ff_dict, 205 | ) 206 | flag_values(('./program', '')) 207 | self.assertEqual(flag_holder.value, expected_settings) 208 | 209 | def test_works_metaclass(self): 210 | # This replicates an issue with Sonnet v2 modules, where the constructor 211 | # arguments are hidden by the metaclass. 212 | class MyMetaclass(abc.ABCMeta): 213 | 214 | def __call__(cls, *args, **kwargs): 215 | del args, kwargs 216 | 217 | class MyClass(metaclass=MyMetaclass): 218 | 219 | def __init__( 220 | self, a: int = 10, b: float = 1.0, c: Sequence[int] = (1, 2, 3) 221 | ): 222 | del a, b, c 223 | 224 | expected = {'a': 10, 'b': 1.0, 'c': (1, 2, 3)} 225 | ff_dict = ff.auto(MyClass) 226 | self.assertEqual(ff_dict.keys(), expected.keys()) 227 | 228 | flag_values = flags.FlagValues() 229 | flag_holder = ff.DEFINE_dict( 230 | 'my_meta_class_settings', 231 | flag_values, 232 | **ff_dict, 233 | ) 234 | flag_values(('./program', '')) 235 | self.assertEqual(flag_holder.value, expected) 236 | 237 | def test_required_item_with_no_default(self): 238 | def my_function(a: int, b: float = 1.0, c: Sequence[int] = (1, 2, 3)): 239 | del a, b, c 240 | 241 | items = ff.auto(my_function) 242 | self.assertTrue(items['a'].required) 243 | self.assertFalse(items['b'].required) 244 | self.assertFalse(items['c'].required) 245 | 246 | def test_error_if_missing_type_annotation(self): 247 | def my_function(a: int = 10, b=1.0, c: Sequence[int] = (1, 2, 3)): 248 | del a, b, c 249 | 250 | with self.assertRaisesWithLiteralMatch( 251 | TypeError, 
_auto._MISSING_TYPE_ANNOTATION.format(name='b') 252 | ): 253 | ff.auto(my_function) 254 | 255 | def test_error_if_unsupported_type(self): 256 | 257 | def my_function( 258 | a: int = 10, b: float = 1.0, c: Sequence[object] = (1, 2, 3) 259 | ): 260 | del a, b, c 261 | 262 | with self.assertRaisesWithLiteralMatch( 263 | TypeError, 264 | _auto._UNSUPPORTED_ARGUMENT_TYPE.format( 265 | name='c', annotation=Sequence[object] 266 | ), 267 | ): 268 | ff.auto(my_function) 269 | 270 | def test_no_error_if_nonstrict_unsupported_type(self): 271 | 272 | def my_function( 273 | a: int = 10, b: float = 1.0, c: Sequence[object] = (1, 2, 3) 274 | ): 275 | del a, b, c 276 | 277 | items = ff.auto(my_function, strict=False) 278 | self.assertSetEqual(set(items.keys()), {'a', 'b'}) 279 | 280 | def test_no_error_if_nonstrict_no_type_annotation(self): 281 | def my_function(a, b: int = 3): 282 | del a, b 283 | 284 | items = ff.auto(my_function, strict=False) 285 | self.assertSetEqual(set(items.keys()), {'b'}) 286 | 287 | def test_error_if_not_callable(self): 288 | with self.assertRaises(TypeError): 289 | ff.auto(3) # pytype: disable=wrong-arg-types 290 | 291 | # TODO(b/178129474): Improve support for typing.Sequence subtypes. 292 | @absltest.expectedFailure 293 | def test_supports_tuples_with_more_than_one_element(self): 294 | 295 | def my_function( 296 | three_ints: Tuple[int, int, int] = (1, 2, 3), 297 | zero_or_more_strings: Tuple[str, ...] 
= ('foo', 'bar'), 298 | ): 299 | del three_ints, zero_or_more_strings 300 | 301 | expected = { 302 | 'three_ints': (1, 2, 3), 303 | 'zero_or_more_strings': ('foo', 'bar'), 304 | } 305 | ff_dict = ff.auto(my_function) 306 | self.assertEqual(expected.keys(), ff_dict.keys()) 307 | flag_values = flags.FlagValues() 308 | flag_holder = ff.DEFINE_dict( 309 | 'my_function_settings', 310 | flag_values, 311 | **ff_dict, 312 | ) 313 | flag_values(('./program', '')) 314 | self.assertEqual(flag_holder.value, expected) 315 | 316 | def test_skip_params(self): 317 | def my_function(a: int, b: str = 'hi'): 318 | del a, b 319 | 320 | items = ff.auto(my_function, skip_params={'b'}) 321 | self.assertSetEqual(set(items.keys()), {'a'}) 322 | 323 | 324 | if __name__ == '__main__': 325 | absltest.main() 326 | -------------------------------------------------------------------------------- /fancyflags/_define_auto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Automatic flags via ff.auto-compatible callables.""" 16 | 17 | from typing import Callable, Collection, Optional, TypeVar 18 | 19 | from absl import flags 20 | from fancyflags import _auto 21 | from fancyflags import _definitions 22 | from fancyflags import _flags 23 | 24 | _F = TypeVar('_F', bound=Callable) 25 | 26 | # Add current module to disclaimed module ids. 27 | flags.disclaim_key_flags() 28 | 29 | 30 | def DEFINE_auto( # pylint: disable=invalid-name 31 | name: str, 32 | fn: _F, 33 | help_string: Optional[str] = None, 34 | flag_values: flags.FlagValues = flags.FLAGS, 35 | *, 36 | strict: bool = True, 37 | skip_params: Collection[str] = (), 38 | ) -> flags.FlagHolder[_F]: 39 | """Defines a flag for an `ff.auto`-compatible constructor or callable. 40 | 41 | Automatically defines a set of dotted `ff.Item` flags corresponding to the 42 | arguments of `fn` and their default values. 43 | 44 | Overriding the value of a dotted flag will update the arguments used to invoke 45 | `fn`. The value of this flag is a callable that invokes `fn` with these values 46 | as bound arguments. 47 | 48 | Example usage: 49 | 50 | ```python 51 | # Defined in, e.g., a `datasets` library. 52 | 53 | @dataclasses.dataclass 54 | class DataSettings: 55 | dataset_name: str = 'mnist' 56 | split: str = 'train' 57 | batch_size: int = 128 58 | 59 | # In main script. 60 | # Exposes flags: --data.dataset_name, --data.split and --data.batch_size. 61 | DATA_SETTINGS = ff.DEFINE_auto('data', datasets.DataSettings, 'Data config') 62 | 63 | def main(argv): 64 | del argv # Unused. 65 | dataset = datasets.load(DATA_SETTINGS.value()) 66 | # ... 67 | ``` 68 | 69 | Args: 70 | name: The name for the top-level flag. 71 | fn: An `ff.auto`-compatible `Callable`. 72 | help_string: Optional help string for this flag. If not provided, this will 73 | default to '{fn's module}.{fn's name}'.
74 | flag_values: An optional `flags.FlagValues` instance. 75 | strict: If True, raise an error for arguments without type hints or with 76 | unsupported types. If False, skip defining flags for such arguments. 77 | skip_params: Optional parameter names to skip defining flags for. 78 | 79 | Returns: 80 | A `flags.FlagHolder`. 81 | """ 82 | arguments = _auto.auto(fn, strict=strict, skip_params=skip_params) 83 | # Define the individual flags. 84 | defaults = _definitions.define_flags(name, arguments, flag_values=flag_values) 85 | help_string = help_string or f'{fn.__module__}.{fn.__name__}' 86 | # Define a holder flag. 87 | return flags.DEFINE_flag( 88 | flag=_flags.AutoFlag( 89 | fn, 90 | defaults, 91 | name=name, 92 | default=None, 93 | parser=flags.ArgumentParser(), 94 | serializer=None, 95 | help_string=help_string, 96 | ), 97 | flag_values=flag_values, 98 | ) 99 | -------------------------------------------------------------------------------- /fancyflags/_define_auto_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================ 15 | """Tests for fancyflags._define_auto.""" 16 | 17 | import copy 18 | import dataclasses 19 | from typing import Sequence 20 | 21 | from absl import flags 22 | from absl.testing import absltest 23 | from fancyflags import _define_auto 24 | from fancyflags import _flags 25 | 26 | 27 | @dataclasses.dataclass 28 | class Point: 29 | x: float = 0.0 30 | y: float = 0.0 31 | label: str = '' 32 | enable: bool = False 33 | 34 | 35 | def greet(greeting: str = 'Hello', targets: Sequence[str] = ('world',)) -> str: 36 | return greeting + ' ' + ', '.join(targets) # pytype: disable=unsupported-operands 37 | 38 | 39 | class DefineAutoTest(absltest.TestCase): 40 | 41 | def test_dataclass(self): 42 | flag_values = flags.FlagValues() 43 | flag_holder = _define_auto.DEFINE_auto( 44 | 'point', Point, flag_values=flag_values 45 | ) 46 | flag_values(( 47 | './program', 48 | '--point.x=2.0', 49 | '--point.y=-1.5', 50 | '--point.label=p', 51 | '--nopoint.enable', 52 | )) 53 | expected = Point(2.0, -1.5, 'p', False) 54 | self.assertEqual(expected, flag_holder.value()) 55 | 56 | def test_dataclass_nodefaults(self): 57 | # Given a class constructor with non-default (required) argument(s)... 58 | 59 | @dataclasses.dataclass 60 | class MySettings: 61 | foo: str 62 | bar: int = 3 63 | 64 | # If we define auto flags for it... 65 | flag_values = flags.FlagValues() 66 | flag_holder = _define_auto.DEFINE_auto( 67 | 'thing', MySettings, flag_values=flag_values 68 | ) 69 | 70 | # Then the corresponding flag is required: not passing it should error. 71 | with self.assertRaisesRegex(flags.IllegalFlagValueError, 'thing.foo'): 72 | flag_values(('./program', '')) 73 | 74 | # Passing the required flag should work as normal. 
75 | flag_values(('./program', '--thing.foo=hello')) 76 | expected = MySettings('hello', 3) 77 | self.assertEqual(expected, flag_holder.value()) 78 | 79 | def test_function(self): 80 | flag_values = flags.FlagValues() 81 | flag_holder = _define_auto.DEFINE_auto( 82 | 'greet', greet, flag_values=flag_values 83 | ) 84 | flag_values(( 85 | './program', 86 | '--greet.greeting=Hi there', 87 | "--greet.targets=('Alice', 'Bob')", 88 | )) 89 | expected = 'Hi there Alice, Bob' 90 | self.assertEqual(expected, flag_holder.value()) 91 | 92 | def test_override_kwargs(self): 93 | flag_values = flags.FlagValues() 94 | flag_holder = _define_auto.DEFINE_auto( 95 | 'point', Point, flag_values=flag_values 96 | ) 97 | flag_values(( 98 | './program', 99 | '--point.x=2.0', 100 | '--point.y=-1.5', 101 | '--point.label=p', 102 | '--point.enable', 103 | )) 104 | expected = Point(3.0, -1.5, 'p', True) 105 | # Here we override one of the arguments. 106 | self.assertEqual(expected, flag_holder.value(x=3.0)) 107 | 108 | def test_overriding_top_level_auto_flag_fails(self): 109 | flag_values = flags.FlagValues() 110 | _define_auto.DEFINE_auto('point', Point, flag_values=flag_values) 111 | with self.assertRaisesRegex( 112 | flags.IllegalFlagValueError, "Can't override an auto flag directly" 113 | ): 114 | flag_values(('./program', '--point=2.0')) 115 | 116 | def test_basic_serialization(self): 117 | flag_values = flags.FlagValues() 118 | _define_auto.DEFINE_auto('point', Point, flag_values=flag_values) 119 | 120 | # Accessing flag_holder.value would raise an error here, since flags haven't 121 | # been parsed yet. For consistency we access the value via flag_values 122 | # throughout the test, rather than through a returned `FlagHolder`. 123 | initial_point_value = copy.deepcopy(flag_values['point'].value()) 124 | 125 | # Parse flags, then serialize. 
126 | flag_values(( 127 | './program', 128 | '--point.x=1.2', 129 | '--point.y=3.5', 130 | '--point.label=p', 131 | '--point.enable=True', 132 | )) 133 | 134 | self.assertEqual(flag_values['point'].serialize(), _flags._EMPTY) 135 | self.assertEqual(flag_values['point.x'].serialize(), '--point.x=1.2') 136 | self.assertEqual(flag_values['point.label'].serialize(), '--point.label=p') 137 | self.assertEqual(flag_values['point.enable'].serialize(), '--point.enable') 138 | 139 | parsed_point_value = copy.deepcopy(flag_values['point'].value()) 140 | 141 | self.assertEqual( 142 | parsed_point_value, Point(x=1.2, y=3.5, label='p', enable=True) 143 | ) 144 | self.assertNotEqual(parsed_point_value, initial_point_value) 145 | 146 | # Test a round trip. 147 | serialized_args = [ 148 | flag_values[name].serialize() 149 | for name in flag_values 150 | if name.startswith('point.') 151 | ] 152 | 153 | flag_values.unparse_flags() # Reset to defaults. 154 | self.assertEqual(flag_values['point'].value(), initial_point_value) 155 | 156 | flag_values(['./program'] + serialized_args) 157 | self.assertEqual(flag_values['point'].value(), parsed_point_value) 158 | 159 | def test_disclaimed_module(self): 160 | flag_values = flags.FlagValues() 161 | _ = _define_auto.DEFINE_auto( 162 | 'greet', greet, 'help string', flag_values=flag_values 163 | ) 164 | defining_module = flag_values.find_module_defining_flag('greet') 165 | 166 | # The defining module should be the calling module, not `_define_auto`, 167 | # where `DEFINE_flag` is actually called. Otherwise the help for a module's 168 | # flags will not be printed unless the user passes --helpfull. 169 | self.assertIn('_define_auto_test', defining_module) 170 | 171 | def test_help_strings(self): 172 | flag_values = flags.FlagValues() 173 | 174 | # Should default to module.name, since no help string is passed. 175 | _define_auto.DEFINE_auto('greet', greet, flag_values=flag_values) 176 | # Should use the custom help string.
177 | _define_auto.DEFINE_auto( 178 | 'point', Point, help_string='custom', flag_values=flag_values 179 | ) 180 | 181 | self.assertEqual(flag_values['greet'].help, f'{greet.__module__}.greet') 182 | self.assertEqual(flag_values['point'].help, 'custom') 183 | 184 | def test_manual_nostrict_overrides_no_default(self): 185 | # Given a function without type hints... 186 | def my_function(a): 187 | return a + 1 # pytype: disable=unsupported-operands 188 | 189 | # If we define an auto flag using this function in non-strict mode... 190 | flag_values = flags.FlagValues() 191 | flag_holder = _define_auto.DEFINE_auto( 192 | 'foo', my_function, flag_values=flag_values, strict=False 193 | ) 194 | 195 | # Calling the function without arguments should error. 196 | flag_values(('./program', '')) 197 | with self.assertRaises(TypeError): 198 | flag_holder.value() # pytype: disable=missing-parameter 199 | 200 | # Calling with arguments should work fine. 201 | self.assertEqual(flag_holder.value(a=2), 3) # pytype: disable=wrong-arg-types 202 | 203 | def test_skip_params(self): 204 | flag_values = flags.FlagValues() 205 | _define_auto.DEFINE_auto( 206 | 'greet', greet, flag_values=flag_values, skip_params=('targets',) 207 | ) 208 | self.assertNotIn('greet.targets', flag_values) 209 | 210 | 211 | if __name__ == '__main__': 212 | absltest.main() 213 | -------------------------------------------------------------------------------- /fancyflags/_definitions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Functionality for defining `Item`s and dict flags.""" 16 | 17 | import collections 18 | import enum 19 | from typing import Any, Generic, Iterable, Mapping, Optional, Type, TypeVar, Union 20 | 21 | from absl import flags 22 | from fancyflags import _argument_parsers 23 | from fancyflags import _flags 24 | 25 | _T = TypeVar("_T") 26 | _EnumT = TypeVar("_EnumT", bound=enum.Enum) 27 | _MappingT = TypeVar("_MappingT", bound=Mapping[str, Any]) 28 | 29 | SEPARATOR = "." 30 | 31 | _NOT_A_DICT_OR_ITEM = """ 32 | DEFINE_dict only supports flat or nested dictionaries, and these must contain 33 | `ff.Item`s or `ff.MultiItem`s. Found type {} in this definition. 34 | """ 35 | 36 | # Add this module to absl's exclusion set for determining the calling modules. 37 | flags.disclaim_key_flags() 38 | 39 | 40 | def DEFINE_dict(*args, **kwargs): # pylint: disable=invalid-name 41 | """Defines a flat or nested dictionary flag. 42 | 43 | Usage example: 44 | 45 | ```python 46 | import fancyflags as ff 47 | 48 | ff.DEFINE_dict( 49 | "image_settings", 50 | mode=ff.String("pad"), 51 | sizes=dict( 52 | width=ff.Integer(5), 53 | height=ff.Integer(7), 54 | scale=ff.Float(0.5), 55 | ) 56 | ) 57 | ``` 58 | This creates a flag `FLAGS.image_settings` with a default value of 59 | 60 | ```python 61 | { 62 | "mode": "pad", 63 | "sizes": { 64 | "width": 5, 65 | "height": 7, 66 | "scale": 0.5, 67 | } 68 | } 69 | ``` 70 | 71 | Each item in the definition (e.g.
ff.Integer(...)) corresponds to a flag that 72 | can be overridden from the command line using "dot" notation. For example, the 73 | following command overrides the `height` item in the nested dictionary defined 74 | above: 75 | 76 | ``` 77 | python script_name.py --image_settings.sizes.height=10 78 | ``` 79 | 80 | Args: 81 | *args: One or two positional arguments are expected: 1. A string containing 82 | the root name for this flag. This must be set. 2. Optionally, a 83 | `flags.FlagValues` object that will hold the flags. If not set, the usual 84 | global `flags.FLAGS` object will be used. 85 | **kwargs: One or more keyword arguments, where each value is either an 86 | `ff.Item` (such as `ff.String(...)` or `ff.Integer(...)`), an `ff.MultiItem`, 87 | or a dict with the same constraints. 88 | 89 | Returns: 90 | A `FlagHolder` instance. 91 | """ 92 | if not args: 93 | raise ValueError( 94 | "Please supply one positional argument containing the " 95 | "top-level flag name for the dict." 96 | ) 97 | 98 | if not kwargs: 99 | raise ValueError( 100 | "Please supply at least one keyword argument defining a flag." 101 | ) 102 | if len(args) > 2: 103 | raise ValueError( 104 | "Please supply at most two positional arguments, the " 105 | "first containing the top-level flag name for the dict " 106 | "and, optionally and unusually, a second positional " 107 | "argument to override the flags.FlagValues instance to " 108 | "use." 109 | ) 110 | 111 | if not isinstance(args[0], str): 112 | raise ValueError( 113 | "The first positional argument must be a string " 114 | "containing the top-level flag name for the dict. Got a {}.".format( 115 | type(args[0]).__name__ 116 | ) 117 | ) 118 | 119 | if len(args) == 2: 120 | if not isinstance(args[1], flags.FlagValues): 121 | raise ValueError( 122 | "If supplying a second positional argument, this must " 123 | "be a flags.FlagValues instance. Got a {}. If you meant " 124 | "to define a flag, note these must be supplied as " 125 | "keyword arguments.
".format(type(args[1]).__name__) 126 | ) 127 | flag_values = args[1] 128 | else: 129 | flag_values = flags.FLAGS 130 | 131 | flag_name = args[0] 132 | 133 | shared_dict = define_flags(flag_name, kwargs, flag_values=flag_values) 134 | 135 | # TODO(b/177672282): Can we persuade pytype to correctly infer the type of the 136 | # flagholder's .value attribute? 137 | # We register a dummy flag that returns `shared_dict` as a value. 138 | return flags.DEFINE_flag( 139 | _flags.DictFlag( 140 | shared_dict, 141 | name=flag_name, 142 | default=shared_dict, 143 | parser=flags.ArgumentParser(), 144 | serializer=None, 145 | help_string="Unused help string.", 146 | ), 147 | flag_values=flag_values, 148 | ) 149 | 150 | 151 | def define_flags( 152 | name: str, 153 | name_to_item: _MappingT, 154 | flag_values: flags.FlagValues = flags.FLAGS, 155 | ) -> _MappingT: 156 | """Defines dot-delimited flags from a flat or nested dict of `ff.Item`s. 157 | 158 | Args: 159 | name: The top-level name to prepend to each flag. 160 | name_to_item: A flat or nested dictionary, where each final value is an 161 | `ff.Item` such as `ff.String(...)` or `ff.Integer(...)`. 162 | flag_values: The `flags.FlagValues` instance to use. By default this is 163 | `flags.FLAGS`. Most users will not need to override this. 164 | 165 | Returns: 166 | A flat or nested dictionary containing the default values in `name_to_item`. 167 | Overriding any of the flags defined by this function will also update the 168 | corresponding entry in the returned dictionary. 169 | """ 170 | # Each flag that we will define holds a reference to `shared_dict`, which is 171 | # a flat or nested dictionary containing the default values. 172 | 173 | shared_dict = _extract_defaults(name_to_item) 174 | 175 | # We create flags for each leaf item (e.g. ff.Integer(...)). 
176 | 177 | # These are the flags that users will actually interact with when overriding 178 | # flags from the command line, although they will not access them directly in 179 | # their scripts. Each of these flags is also responsible for updating the 180 | # corresponding value in `shared_dict` whenever its own value changes. 181 | 182 | def recursively_define_flags(namespace, maybe_item): 183 | if isinstance(maybe_item, collections.abc.Mapping): 184 | for key, value in maybe_item.items(): 185 | recursively_define_flags(namespace + (key,), value) 186 | else: 187 | assert isinstance(maybe_item, (Item, MultiItem)) 188 | maybe_item.define(namespace, {name: shared_dict}, flag_values) 189 | 190 | for key, value in name_to_item.items(): 191 | recursively_define_flags(namespace=(name, key), maybe_item=value) 192 | 193 | return shared_dict 194 | 195 | 196 | def _extract_defaults(name_to_item): 197 | """Converts a flat or nested dict into a flat or nested dict of defaults.""" 198 | 199 | result = {} 200 | for key, value in name_to_item.items(): 201 | if isinstance(value, (Item, MultiItem)): 202 | result[key] = value.default 203 | elif isinstance(value, dict): 204 | result[key] = _extract_defaults(value) 205 | else: 206 | type_name = type(value).__name__ 207 | raise TypeError(_NOT_A_DICT_OR_ITEM.format(type_name)) 208 | return result 209 | 210 | 211 | class Item(Generic[_T]): 212 | """Defines a flag for leaf items in the dictionary.""" 213 | 214 | def __init__( 215 | self, 216 | default: Optional[_T], 217 | help_string: Optional[str], 218 | parser: flags.ArgumentParser, 219 | serializer: Optional[flags.ArgumentSerializer] = None, 220 | *, 221 | required: bool = False, 222 | ): 223 | """Initializes a new `Item`. 224 | 225 | Args: 226 | default: Default value of the flag that this instance will create. 227 | help_string: Help string for the flag that this instance will create. If 228 | `None`, then the dotted flag name will be used as the help string.
229 | parser: A `flags.ArgumentParser` used to parse command line input. 230 | serializer: An optional custom `flags.ArgumentSerializer`. By default, the 231 | flag defined by this class will use an instance of the base 232 | `flags.ArgumentSerializer`. 233 | required: Whether or not this item is required. If True, the corresponding 234 | abseil flag will be marked as required. 235 | """ 236 | # Flags run the following lines of parsing code during initialization. 237 | # See Flag._set_default in absl/flags/_flag.py 238 | 239 | # It's useful to repeat it here so that users will see any errors when the 240 | # Item is initialized, rather than when define() is called later. 241 | 242 | # The only minor difference is that Flag._set_default calls Flag._parse, 243 | # which also catches and modifies the exception type. 244 | if default is None: 245 | self.default = default 246 | else: 247 | if required: 248 | # Mirror the strict behavior of abseil flags. 249 | raise ValueError( 250 | "If marking an Item as required, the default must be None." 251 | ) 252 | self.default = parser.parse(default) # pytype: disable=wrong-arg-types 253 | 254 | self.required = required 255 | self._help_string = help_string 256 | self._parser = parser 257 | 258 | if serializer is None: 259 | self._serializer = flags.ArgumentSerializer() 260 | else: 261 | self._serializer = serializer 262 | 263 | def define( 264 | self, 265 | namespace: str, 266 | shared_dict, 267 | flag_values: flags.FlagValues, 268 | ) -> flags.FlagHolder[_T]: 269 | """Defines a flag that when parsed will update a shared dictionary. 270 | 271 | Args: 272 | namespace: A sequence of strings that define the name of this flag. For 273 | example, `("foo", "bar")` will correspond to a flag named `foo.bar`. 274 | shared_dict: A dictionary that is shared by the top level dict flag. When 275 | the individual flag created by this method is parsed, it will also write 276 | the parsed value into `shared_dict`. 
The `namespace` determines the flat 277 | or nested key when storing the parsed value. 278 | flag_values: The `flags.FlagValues` instance to use. 279 | 280 | Returns: 281 | A new flags.FlagHolder instance. 282 | """ 283 | name = SEPARATOR.join(namespace) 284 | help_string = name if self._help_string is None else self._help_string 285 | return flags.DEFINE_flag( 286 | _flags.ItemFlag( 287 | shared_dict, 288 | namespace, 289 | parser=self._parser, 290 | serializer=self._serializer, 291 | name=name, 292 | default=self.default, 293 | help_string=help_string, 294 | ), 295 | flag_values=flag_values, 296 | required=self.required, 297 | ) 298 | 299 | 300 | class Boolean(Item[bool]): 301 | """Matches behaviour of flags.DEFINE_boolean.""" 302 | 303 | def __init__( 304 | self, 305 | default: Optional[bool], 306 | help_string: Optional[str] = None, 307 | *, 308 | required: bool = False, 309 | ): 310 | super().__init__( 311 | default, 312 | help_string, 313 | flags.BooleanParser(), 314 | required=required, 315 | ) 316 | 317 | 318 | # TODO(b/177673597) Better document the different enum class options and 319 | # possibly recommend some over others. 
320 | 321 | 322 | class Enum(Item[str]): 323 | """Matches behaviour of flags.DEFINE_enum.""" 324 | 325 | def __init__( 326 | self, 327 | default: Optional[str], 328 | enum_values: Iterable[str], 329 | help_string: Optional[str] = None, 330 | *, 331 | case_sensitive: bool = True, 332 | required: bool = False, 333 | ): 334 | parser = flags.EnumParser(tuple(enum_values), case_sensitive) 335 | super().__init__(default, help_string, parser, required=required) 336 | 337 | 338 | class EnumClass(Item[_EnumT]): 339 | """Matches behaviour of flags.DEFINE_enum_class.""" 340 | 341 | def __init__( 342 | self, 343 | default: Optional[_EnumT], 344 | enum_class: Type[_EnumT], 345 | help_string: Optional[str] = None, 346 | *, 347 | case_sensitive: bool = False, 348 | required: bool = False, 349 | ): 350 | parser = flags.EnumClassParser(enum_class, case_sensitive=case_sensitive) 351 | super().__init__( 352 | default, 353 | help_string, 354 | parser, 355 | flags.EnumClassSerializer(lowercase=False), 356 | required=required, 357 | ) 358 | 359 | 360 | class Float(Item[float]): 361 | """Matches behaviour of flags.DEFINE_float.""" 362 | 363 | def __init__( 364 | self, 365 | default: Optional[float], 366 | help_string: Optional[str] = None, 367 | *, 368 | required: bool = False, 369 | ): 370 | super().__init__( 371 | default, help_string, flags.FloatParser(), required=required 372 | ) 373 | 374 | 375 | class Integer(Item[int]): 376 | """Matches behaviour of flags.DEFINE_integer.""" 377 | 378 | def __init__( 379 | self, 380 | default: Optional[int], 381 | help_string: Optional[str] = None, 382 | required: bool = False, 383 | ): 384 | super().__init__( 385 | default, help_string, flags.IntegerParser(), required=required 386 | ) 387 | 388 | 389 | class Sequence(Item, Generic[_T]): 390 | r"""Defines a flag for a list or tuple of simple numeric types or strings. 
391 | 392 | Here is an example of overriding a Sequence flag within a dict flag named 393 | "settings" from the command line with a list of values: 394 | 395 | ``` 396 | --settings.sequence=[1,2,3] 397 | ``` 398 | 399 | To include spaces, either quote the entire literal or escape each space: 400 | 401 | ``` 402 | --settings.sequence="[1, 2, 3]" 403 | --settings.sequence=[1,\ 2,\ 3] 404 | ``` 405 | """ 406 | 407 | def __init__( 408 | self, 409 | default: Optional[Iterable[_T]], 410 | help_string: Optional[str] = None, 411 | required: bool = False, 412 | ): 413 | super().__init__( 414 | default, 415 | help_string, 416 | _argument_parsers.SequenceParser(), 417 | required=required, 418 | ) 419 | 420 | 421 | class String(Item[str]): 422 | """Matches behaviour of flags.DEFINE_string.""" 423 | 424 | def __init__( 425 | self, 426 | default: Optional[str], 427 | help_string: Optional[str] = None, 428 | *, 429 | required: bool = False, 430 | ): 431 | super().__init__( 432 | default, 433 | help_string, 434 | flags.ArgumentParser(), 435 | required=required, 436 | ) 437 | 438 | 439 | class DateTime(Item): 440 | """Defines a flag for a datetime, which may be timezone-naive.""" 441 | def __init__( 442 | self, 443 | default: Optional[str], 444 | help_string: Optional[str] = None, 445 | *, 446 | required: bool = False, 447 | ): 448 | super().__init__( 449 | default, 450 | help_string, 451 | _argument_parsers.PossiblyNaiveDatetimeParser(), 452 | required=required, 453 | ) 454 | 455 | 456 | class StringList(Item[Iterable[str]]): 457 | """A flag that implements the same behavior as absl.flags.DEFINE_list. 458 | 459 | Can be overridden as --my_flag="a,list,of,commaseparated,strings" 460 | """ 461 | 462 | def __init__( 463 | self, 464 | default: Optional[Iterable[str]], 465 | help_string: Optional[str] = None, 466 | ): 467 | serializer = flags.CsvListSerializer(",") 468 | super().__init__(default, help_string, flags.ListParser(), serializer) 469 | 470 | 471 | # MultiFlag-related functionality.
472 | 473 | 474 | class MultiItem(Generic[_T]): 475 | """Class for items that can appear multiple times on the command line. 476 | 477 | See the `Item` class for more details on methods and usage. 478 | """ 479 | 480 | def __init__( 481 | self, 482 | default: Union[None, _T, Iterable[_T]], 483 | help_string: Optional[str], 484 | parser: flags.ArgumentParser, 485 | serializer: Optional[flags.ArgumentSerializer] = None, 486 | ): 487 | if default is None: 488 | self.default = default 489 | else: 490 | if isinstance(default, collections.abc.Iterable) and not isinstance( 491 | default, (str, bytes) 492 | ): 493 | # Convert all non-string iterables to lists. 494 | default = list(default) 495 | 496 | if not isinstance(default, list): 497 | # Turn single items into single-value lists. 498 | default = [default] 499 | 500 | # Ensure each individual value is well-formed. 501 | self.default = [parser.parse(item) for item in default] 502 | 503 | self._help_string = help_string 504 | self._parser = parser 505 | 506 | if serializer is None: 507 | self._serializer = flags.ArgumentSerializer() 508 | else: 509 | self._serializer = serializer 510 | 511 | def define( 512 | self, 513 | namespace: str, 514 | shared_dict, 515 | flag_values: flags.FlagValues, 516 | ) -> flags.FlagHolder[Iterable[_T]]: 517 | name = SEPARATOR.join(namespace) 518 | help_string = name if self._help_string is None else self._help_string 519 | return flags.DEFINE_flag( 520 | _flags.MultiItemFlag( 521 | shared_dict, 522 | namespace, 523 | parser=self._parser, 524 | serializer=self._serializer, 525 | name=name, 526 | default=self.default, 527 | help_string=help_string, 528 | ), 529 | flag_values=flag_values, 530 | ) 531 | 532 | 533 | class MultiEnum(Item[_T]): 534 | """Defines a flag for lists of values of any type, matched to enum_values.""" 535 | 536 | def __init__( 537 | self, 538 | default: Union[None, _T, Iterable[_T]], 539 | enum_values: Iterable[_T], 540 | help_string: Optional[str] = None, 541 | ): 542 | parser =
_argument_parsers.MultiEnumParser(enum_values) 543 | serializer = flags.ArgumentSerializer() 544 | _ = parser.parse(enum_values) 545 | super().__init__(default, help_string, parser, serializer) 546 | 547 | 548 | class MultiEnumClass(MultiItem): 549 | """Matches behaviour of flags.DEFINE_multi_enum_class.""" 550 | 551 | def __init__( 552 | self, 553 | default: Union[None, _EnumT, Iterable[_EnumT]], 554 | enum_class: Type[_EnumT], 555 | help_string: Optional[str] = None, 556 | ): 557 | parser = flags.EnumClassParser(enum_class) 558 | serializer = flags.EnumClassListSerializer(",", lowercase=False) 559 | super().__init__(default, help_string, parser, serializer) 560 | 561 | 562 | class MultiString(MultiItem): 563 | """Matches behaviour of flags.DEFINE_multi_string.""" 564 | 565 | def __init__(self, default, help_string=None): 566 | parser = flags.ArgumentParser() 567 | serializer = flags.ArgumentSerializer() 568 | super().__init__(default, help_string, parser, serializer) 569 | 570 | 571 | # Misc DEFINE_*s. 572 | 573 | 574 | def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin 575 | name: str, 576 | default: Optional[Iterable[_T]], 577 | enum_values: Iterable[_T], 578 | help: str, 579 | flag_values=flags.FLAGS, 580 | **args, 581 | ) -> flags.FlagHolder[_T]: 582 | """Defines flag for MultiEnum.""" 583 | parser = _argument_parsers.MultiEnumParser(enum_values) 584 | serializer = flags.ArgumentSerializer() 585 | return flags.DEFINE( 586 | parser, 587 | name, 588 | default, 589 | help, 590 | flag_values, 591 | serializer, 592 | **args, 593 | ) 594 | 595 | 596 | def DEFINE_sequence( # pylint: disable=invalid-name,redefined-builtin 597 | name: str, 598 | default: Optional[Iterable[_T]], 599 | help: str, 600 | flag_values=flags.FLAGS, 601 | **args, 602 | ) -> flags.FlagHolder[Iterable[_T]]: 603 | """Defines a flag for a list or tuple of simple types. 
See `Sequence` docs.""" 604 | parser = _argument_parsers.SequenceParser() 605 | serializer = flags.ArgumentSerializer() 606 | return flags.DEFINE( 607 | parser, 608 | name, 609 | default, 610 | help, 611 | flag_values, 612 | serializer, 613 | **args, 614 | ) 615 | -------------------------------------------------------------------------------- /fancyflags/_definitions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for definitions.""" 16 | 17 | import copy 18 | import datetime 19 | import enum 20 | from typing import Any, Callable 21 | 22 | from absl import flags 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | # definitions almost exactly corresponds to the public API, so aliasing the 26 | # import here for better illustrative tests. 
27 | from fancyflags import _definitions as ff 28 | from fancyflags import _flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class MyEnum(enum.Enum): 34 | A = 1 35 | B = 2 36 | 37 | 38 | class DifferentEnum(enum.Enum): 39 | C = 1 40 | D = 2 41 | 42 | 43 | class FancyflagsTest(absltest.TestCase): 44 | 45 | def test_define_with_global_flagvalues(self): 46 | # Since ff.DEFINE_dict uses an optional positional argument to specify a 47 | # custom FlagValues instance, we run nearly the same test as below to make 48 | # sure both global (default) and custom FlagValues work. 49 | unused_flagholder = ff.DEFINE_dict( 50 | "flat_dict", 51 | integer_field=ff.Integer(1, "integer field"), 52 | string_field=ff.String(""), 53 | string_list_field=ff.StringList(["a", "b", "c"], "string list field"), 54 | ) 55 | 56 | expected = { 57 | "integer_field": 1, 58 | "string_field": "", 59 | "string_list_field": ["a", "b", "c"], 60 | } 61 | self.assertEqual(FLAGS.flat_dict, expected) 62 | 63 | # These flags should also exist, although we won't access them in practice. 64 | self.assertEqual(FLAGS["flat_dict.integer_field"].value, 1) 65 | self.assertEqual(FLAGS["flat_dict.string_field"].value, "") 66 | 67 | # Custom help string. 68 | self.assertEqual(FLAGS["flat_dict.integer_field"].help, "integer field") 69 | # Default help string. 70 | self.assertEqual( 71 | FLAGS["flat_dict.string_field"].help, "flat_dict.string_field" 72 | ) 73 | 74 | def test_define_with_custom_flagvalues(self): 75 | # Since ff.DEFINE_dict uses an optional positional argument to specify a 76 | # custom FlagValues instance, we run nearly the same test as above to make 77 | # sure both global (default) and custom FlagValues work. 
78 | flag_values = flags.FlagValues() 79 | unused_flagholder = ff.DEFINE_dict( 80 | "flat_dict", 81 | flag_values, 82 | integer_field=ff.Integer(1, "integer field"), 83 | string_field=ff.String(""), 84 | string_list_field=ff.StringList(["a", "b", "c"], "string list field"), 85 | ) 86 | 87 | expected = { 88 | "integer_field": 1, 89 | "string_field": "", 90 | "string_list_field": ["a", "b", "c"], 91 | } 92 | flag_values(("./program", "")) 93 | self.assertEqual(flag_values.flat_dict, expected) 94 | 95 | # These flags should also exist, although we won't access them in practice. 96 | self.assertEqual(flag_values["flat_dict.integer_field"].value, 1) 97 | self.assertEqual(flag_values["flat_dict.string_field"].value, "") 98 | 99 | # Custom help string. 100 | self.assertEqual( 101 | flag_values["flat_dict.integer_field"].help, "integer field" 102 | ) 103 | # Default help string. 104 | self.assertEqual( 105 | flag_values["flat_dict.string_field"].help, "flat_dict.string_field" 106 | ) 107 | 108 | def test_define_flat(self): 109 | flag_values = flags.FlagValues() 110 | flag_holder = ff.DEFINE_dict( 111 | "flat_dict", 112 | flag_values, 113 | integer_field=ff.Integer(1, "integer field"), 114 | string_field=ff.String(""), 115 | string_list_field=ff.StringList(["a", "b", "c"], "string list field"), 116 | ) 117 | 118 | # This should return a single dict with the default values specified above. 
119 | expected = { 120 | "integer_field": 1, 121 | "string_field": "", 122 | "string_list_field": ["a", "b", "c"], 123 | } 124 | flag_values(("./program", "")) 125 | self.assertEqual(flag_values.flat_dict, expected) 126 | self.assertEqual(flag_holder.value, expected) 127 | 128 | def test_define_nested(self): 129 | flag_values = flags.FlagValues() 130 | flag_holder = ff.DEFINE_dict( 131 | "nested_dict", 132 | flag_values, 133 | integer_field=ff.Integer(1, "integer field"), 134 | sub_dict=dict(string_field=ff.String("", "string field")), 135 | ) 136 | 137 | # This should return a single dict with the default values specified above. 138 | expected = {"integer_field": 1, "sub_dict": {"string_field": ""}} 139 | 140 | flag_values(("./program", "")) 141 | self.assertEqual(flag_values.nested_dict, expected) 142 | self.assertEqual(flag_holder.value, expected) 143 | 144 | # These flags should also exist, although we won't access them in practice. 145 | self.assertEqual(flag_values["nested_dict.integer_field"].value, 1) 146 | self.assertEqual(flag_values["nested_dict.sub_dict.string_field"].value, "") 147 | 148 | def test_no_name_error(self): 149 | with self.assertRaisesRegex(ValueError, "one positional argument"): 150 | ff.DEFINE_dict( 151 | integer_field=ff.Integer(1, "integer field"), 152 | ) 153 | 154 | def test_no_kwargs_error(self): 155 | with self.assertRaisesRegex(ValueError, "one keyword argument"): 156 | ff.DEFINE_dict("no_kwargs") 157 | 158 | def test_too_many_positional_args_error(self): 159 | with self.assertRaisesRegex(ValueError, "at most two positional"): 160 | ff.DEFINE_dict( 161 | "name", 162 | ff.String("foo", "string"), 163 | ff.String("bar", "string"), 164 | integer_field=ff.Integer(1, "integer field"), 165 | ) 166 | 167 | def test_flag_name_error(self): 168 | with self.assertRaisesRegex(ValueError, "must be a string"): 169 | ff.DEFINE_dict( 170 | ff.String("name", "string flag"), 171 | ff.String("stringflag", "string"), 172 | integer_field=ff.Integer(1, 
"integer field"), 173 | ) 174 | 175 | def test_flag_values_error(self): 176 | with self.assertRaisesRegex(ValueError, "FlagValues instance"): 177 | ff.DEFINE_dict( 178 | "name", 179 | ff.String("stringflag", "string"), 180 | integer_field=ff.Integer(1, "integer field"), 181 | ) 182 | 183 | def test_define_valid_enum(self): 184 | flag_values = flags.FlagValues() 185 | flag_holder = ff.DEFINE_dict( 186 | "valid_enum", 187 | flag_values, 188 | padding=ff.Enum("same", ["same", "valid"], "enum field"), 189 | ) 190 | 191 | flag_values(("./program", "")) 192 | self.assertEqual(flag_holder.value, {"padding": "same"}) 193 | 194 | def test_define_valid_case_insensitive_enum(self): 195 | flag_values = flags.FlagValues() 196 | flag_holder = ff.DEFINE_dict( 197 | "valid_case_sensitive", 198 | flag_values, 199 | padding=ff.Enum( 200 | "Same", ["same", "valid"], "enum field", case_sensitive=False 201 | ), 202 | ) 203 | flag_values(("./program", "")) 204 | self.assertEqual(flag_holder.value, {"padding": "same"}) 205 | 206 | def test_define_invalid_enum(self): 207 | with self.assertRaises(ValueError): 208 | ff.Enum("invalid", ["same", "valid"], "enum field") 209 | 210 | def test_define_invalid_case_sensitive_enum(self): 211 | with self.assertRaises(ValueError): 212 | ff.Enum("Same", ["same", "valid"], "enum field") 213 | 214 | def test_define_valid_enum_class(self): 215 | flag_values = flags.FlagValues() 216 | flag_holder = ff.DEFINE_dict( 217 | "valid_enum_class", 218 | flag_values, 219 | my_enum=ff.EnumClass(MyEnum.A, MyEnum, "enum class field"), 220 | ) 221 | flag_values(("./program", "")) 222 | self.assertEqual(flag_holder.value, {"my_enum": MyEnum.A}) 223 | 224 | def test_define_invalid_enum_class(self): 225 | with self.assertRaises(ValueError): 226 | ff.EnumClass(DifferentEnum.C, MyEnum) 227 | 228 | 229 | class ExtractDefaultsTest(absltest.TestCase): 230 | 231 | def test_valid_flat(self): 232 | result = ff._extract_defaults({ 233 | "integer_field": ff.Integer(10, "Integer 
field"), 234 | "string_field": ff.String("default", "String field"), 235 | }) 236 | expected = {"integer_field": 10, "string_field": "default"} 237 | self.assertEqual(result, expected) 238 | 239 | def test_valid_nested(self): 240 | result = ff._extract_defaults({ 241 | "integer_field": ff.Integer(10, "Integer field"), 242 | "string_field": ff.String("default", "String field"), 243 | "nested": { 244 | "float_field": ff.Float(3.1, "Float field"), 245 | }, 246 | }) 247 | expected = { 248 | "integer_field": 10, 249 | "string_field": "default", 250 | "nested": {"float_field": 3.1}, 251 | } 252 | self.assertEqual(result, expected) 253 | 254 | def test_invalid_container(self): 255 | expected_message = ff._NOT_A_DICT_OR_ITEM.format("list") 256 | with self.assertRaisesWithLiteralMatch(TypeError, expected_message): 257 | ff._extract_defaults({ 258 | "integer_field": ff.Integer(10, "Integer field"), 259 | "string_field": ff.String("default", "String field"), 260 | "nested": [ff.Float(3.1, "Float field")], 261 | }) 262 | 263 | def test_invalid_flat_leaf(self): 264 | expected_message = ff._NOT_A_DICT_OR_ITEM.format("int") 265 | with self.assertRaisesWithLiteralMatch(TypeError, expected_message): 266 | ff._extract_defaults({ 267 | "string_field": ff.String("default", "String field"), 268 | "naughty_field": 100, 269 | }) 270 | 271 | def test_invalid_nested_leaf(self): 272 | expected_message = ff._NOT_A_DICT_OR_ITEM.format("bool") 273 | with self.assertRaisesWithLiteralMatch(TypeError, expected_message): 274 | ff._extract_defaults({ 275 | "string_field": ff.String("default", "String field"), 276 | "nested": { 277 | "naughty_field": True, 278 | }, 279 | }) 280 | 281 | def test_overriding_top_level_dict_flag_fails(self): 282 | flag_values = flags.FlagValues() 283 | ff.DEFINE_dict( 284 | "top_level_dict", 285 | flag_values, 286 | integer_field=ff.Integer(1, "integer field"), 287 | ) 288 | # The error type and message get converted in the process. 
289 | with self.assertRaisesRegex( 290 | flags.IllegalFlagValueError, "Can't override a dict flag directly" 291 | ): 292 | flag_values(("./program", "--top_level_dict=3")) 293 | 294 | 295 | class DateTimeTest(parameterized.TestCase): 296 | 297 | @parameterized.named_parameters( 298 | dict( 299 | testcase_name="default_str", 300 | default="2001-01-01", 301 | expected=datetime.datetime(2001, 1, 1), 302 | ), 303 | dict( 304 | testcase_name="default_datetime", 305 | default=datetime.datetime(2001, 1, 1), 306 | expected=datetime.datetime(2001, 1, 1), 307 | ), 308 | dict(testcase_name="no_default", default=None, expected=None), 309 | ) 310 | def test_define_datetime_default(self, default, expected): 311 | flag_values = flags.FlagValues() 312 | flag_holder = ff.DEFINE_dict( 313 | "dict_with_datetime", 314 | flag_values, 315 | my_datetime=ff.DateTime(default, "datetime field"), 316 | ) 317 | flag_values(("./program", "")) 318 | self.assertEqual(flag_holder.value, {"my_datetime": expected}) 319 | 320 | def test_define_datetime_invalid_default_raises(self): 321 | with self.assertRaisesRegex(ValueError, r"invalid"): 322 | ff.DEFINE_dict( 323 | "dict_with_datetime", 324 | my_datetime=ff.DateTime("42", "datetime field"), 325 | ) 326 | 327 | def test_define_and_parse_invalid_value_raises(self): 328 | flag_name = "dict_with_datetime" 329 | flag_values = flags.FlagValues() 330 | ff.DEFINE_dict( 331 | flag_name, 332 | flag_values, 333 | my_datetime=ff.DateTime(None, "datetime field"), 334 | ) 335 | 336 | with self.assertRaisesRegex(flags.IllegalFlagValueError, r"invalid"): 337 | flag_values["dict_with_datetime.my_datetime"].parse("2001") 338 | 339 | 340 | class SequenceTest(absltest.TestCase): 341 | 342 | def test_sequence_defaults(self): 343 | flag_values = flags.FlagValues() 344 | flag_holder = ff.DEFINE_dict( 345 | "dict_with_sequences", 346 | flag_values, 347 | int_sequence=ff.Sequence([1, 2, 3], "integer field"), 348 | float_sequence=ff.Sequence([3.14, 2.718], "float field"), 
349 | mixed_sequence=ff.Sequence([100, "hello", "world"], "mixed field"), 350 | ) 351 | 352 | flag_values(("./program", "")) 353 | self.assertEqual( 354 | flag_holder.value, 355 | { 356 | "int_sequence": [1, 2, 3], 357 | "float_sequence": [3.14, 2.718], 358 | "mixed_sequence": [100, "hello", "world"], 359 | }, 360 | ) 361 | 362 | 363 | class MultiEnumTest(parameterized.TestCase): 364 | 365 | def test_defaults_parsing(self): 366 | flag_values = flags.FlagValues() 367 | enum_values = [1, 2, 3, 3.14, 2.718, 100, "hello", ["world"], {"planets"}] 368 | ff.DEFINE_dict( 369 | "dict_with_multienums", 370 | flag_values, 371 | int_sequence=ff.MultiEnum([1, 2, 3], enum_values, "integer field"), 372 | float_sequence=ff.MultiEnum([3.14, 2.718], enum_values, "float field"), 373 | mixed_sequence=ff.MultiEnum( 374 | [100, "hello", ["world"], {"planets"}], enum_values, "mixed field" 375 | ), 376 | enum_sequence=ff.MultiEnum([MyEnum.A], MyEnum, "enum field"), 377 | ) 378 | 379 | expected = { 380 | "int_sequence": [1, 2, 3], 381 | "float_sequence": [3.14, 2.718], 382 | "mixed_sequence": [100, "hello", ["world"], {"planets"}], 383 | "enum_sequence": [MyEnum.A], 384 | } 385 | flag_values(("./program", "")) 386 | self.assertEqual(flag_values.dict_with_multienums, expected) 387 | 388 | 389 | class DefineSequenceTest(absltest.TestCase): 390 | 391 | # Follows test code in absl/flags/tests/flags_test.py 392 | 393 | def test_definition(self): 394 | flag_values = flags.FlagValues() 395 | flag_holder = ff.DEFINE_sequence( 396 | name="sequence", 397 | default=[1, 2, 3], 398 | help="sequence flag", 399 | flag_values=flag_values, 400 | ) 401 | 402 | flag_values(("./program", "")) 403 | self.assertEqual(flag_holder.value, [1, 2, 3]) 404 | self.assertEqual(flag_values.flag_values_dict()["sequence"], [1, 2, 3]) 405 | self.assertEqual(flag_values["sequence"].default_as_str, "'[1, 2, 3]'") # pytype: disable=attribute-error 406 | 407 | def test_end_to_end_with_default(self): 408 | # There are more 
extensive tests for the parser in _argument_parsers_test.py. 409 | # Here we just include a couple of end-to-end examples. 410 | flag_values = flags.FlagValues() 411 | 412 | flag_holder = ff.DEFINE_sequence( 413 | "sequence", 414 | [1, 2, 3], 415 | "sequence flag", 416 | flag_values=flag_values, 417 | ) 418 | flag_values(("./program", "--sequence=[4,5]")) 419 | self.assertEqual(flag_holder.value, [4, 5]) 420 | 421 | def test_end_to_end_without_default(self): 422 | flag_values = flags.FlagValues() 423 | flag_holder = ff.DEFINE_sequence( 424 | "sequence", 425 | None, 426 | "sequence flag", 427 | flag_values=flag_values, 428 | ) 429 | flag_values(("./program", "--sequence=(4, 5)")) 430 | self.assertEqual(flag_holder.value, (4, 5)) 431 | 432 | 433 | class DefineMultiEnumTest(absltest.TestCase): 434 | 435 | # Follows test code in absl/flags/tests/flags_test.py 436 | 437 | def test_definition(self): 438 | flag_values = flags.FlagValues() 439 | flag_holder = ff.DEFINE_multi_enum( 440 | "multienum", 441 | [1, 2, 3], 442 | [1, 2, 3], 443 | "multienum flag", 444 | flag_values=flag_values, 445 | ) 446 | 447 | flag_values(("./program", "")) 448 | self.assertEqual(flag_holder.value, [1, 2, 3]) 449 | self.assertEqual(flag_values.multienum, [1, 2, 3]) 450 | self.assertEqual(flag_values.flag_values_dict()["multienum"], [1, 2, 3]) 451 | self.assertEqual(flag_values["multienum"].default_as_str, "'[1, 2, 3]'") # pytype: disable=attribute-error 452 | 453 | def test_end_to_end_with_default(self): 454 | # There are more extensive tests for the parser in _argument_parsers_test.py. 455 | # Here we just include a couple of end-to-end examples.
456 | flag_values = flags.FlagValues() 457 | flag_holder = ff.DEFINE_multi_enum( 458 | "multienum0", 459 | [1, 2, 3], 460 | [1, 2, 3, 4, 5], 461 | "multienum flag", 462 | flag_values=flag_values, 463 | ) 464 | flag_values(("./program", "--multienum0=[4,5]")) 465 | self.assertEqual(flag_holder.value, [4, 5]) 466 | 467 | def test_end_to_end_without_default(self): 468 | flag_values = flags.FlagValues() 469 | flag_holder = ff.DEFINE_multi_enum( 470 | "multienum1", 471 | None, 472 | [1, 2, 3, 4, 5], 473 | "multienum flag", 474 | flag_values=flag_values, 475 | ) 476 | flag_values(("./program", "--multienum1=(4, 5)")) 477 | self.assertEqual(flag_holder.value, (4, 5)) 478 | 479 | 480 | class MultiEnumClassTest(parameterized.TestCase): 481 | 482 | def test_multi_enum_class(self): 483 | flag_values = flags.FlagValues() 484 | flag_holder = ff.DEFINE_dict( 485 | "dict_with_multi_enum_class", 486 | flag_values, 487 | item=ff.MultiEnumClass( 488 | [MyEnum.A], 489 | MyEnum, 490 | "multi enum", 491 | ), 492 | ) 493 | flag_values(( 494 | "./program", 495 | "--dict_with_multi_enum_class.item=A", 496 | "--dict_with_multi_enum_class.item=B", 497 | "--dict_with_multi_enum_class.item=A", 498 | )) 499 | expected = {"item": [MyEnum.A, MyEnum.B, MyEnum.A]} 500 | self.assertEqual(flag_holder.value, expected) 501 | 502 | 503 | class MultiStringTest(parameterized.TestCase): 504 | 505 | def test_defaults_parsing(self): 506 | flag_values = flags.FlagValues() 507 | flag_holder = ff.DEFINE_dict( 508 | "dict_with_multistrings", 509 | flag_values, 510 | no_default=ff.MultiString(None, "no default"), 511 | single_entry=ff.MultiString("a", "single entry"), 512 | single_entry_list=ff.MultiString(["a"], "single entry list"), 513 | multiple_entry_list=ff.MultiString(["a", "b"], "multiple entry list"), 514 | ) 515 | flag_values(("./program", "")) 516 | expected = { 517 | "no_default": None, 518 | "single_entry": ["a"], 519 | "single_entry_list": ["a"], 520 | "multiple_entry_list": ["a", "b"], 521 | } 522 
| self.assertEqual(flag_holder.value, expected) 523 | 524 | 525 | class SerializationTest(absltest.TestCase): 526 | 527 | def test_basic_serialization(self): 528 | flag_values = flags.FlagValues() 529 | ff.DEFINE_dict( 530 | "to_serialize", 531 | flag_values, 532 | integer_field=ff.Integer(1, "integer field"), 533 | boolean_field=ff.Boolean(False, "boolean field"), 534 | string_list_field=ff.StringList(["a", "b", "c"], "string list field"), 535 | enum_class_field=ff.EnumClass(MyEnum.A, MyEnum, "my enum field"), 536 | ) 537 | 538 | initial_dict_value = copy.deepcopy(flag_values["to_serialize"].value) 539 | 540 | # Parse flags, then serialize. 541 | flag_values([ 542 | "./program", 543 | "--to_serialize.boolean_field=True", 544 | "--to_serialize.integer_field", 545 | "1337", 546 | "--to_serialize.string_list_field=d,e,f", 547 | "--to_serialize.enum_class_field=B", 548 | ]) 549 | self.assertEqual(flag_values["to_serialize"].serialize(), _flags._EMPTY) 550 | self.assertEqual( 551 | flag_values["to_serialize.boolean_field"].serialize(), 552 | "--to_serialize.boolean_field", 553 | ) 554 | self.assertEqual( 555 | flag_values["to_serialize.string_list_field"].serialize(), 556 | "--to_serialize.string_list_field=d,e,f", 557 | ) 558 | 559 | parsed_dict_value = copy.deepcopy(flag_values["to_serialize"].value) 560 | 561 | self.assertDictEqual( 562 | parsed_dict_value, 563 | { 564 | "boolean_field": True, 565 | "integer_field": 1337, 566 | "string_list_field": ["d", "e", "f"], 567 | "enum_class_field": MyEnum.B, 568 | }, 569 | ) 570 | self.assertNotEqual(flag_values["to_serialize"].value, initial_dict_value) 571 | 572 | # test a round trip 573 | serialized_args = [ 574 | flag_values[name].serialize() 575 | for name in flag_values 576 | if name.startswith("to_serialize.") 577 | ] 578 | 579 | flag_values.unparse_flags() # Reset to defaults 580 | self.assertDictEqual(flag_values["to_serialize"].value, initial_dict_value) 581 | 582 | flag_values(["./program"] + serialized_args) 583 
| self.assertDictEqual(flag_values["to_serialize"].value, parsed_dict_value) 584 | 585 | 586 | NAMES_ITEMS_AND_FLAGS = ( 587 | dict( 588 | testcase_name="boolean", 589 | define_function=flags.DEFINE_boolean, 590 | item_constructor=ff.Boolean, 591 | default=True, 592 | override="false", 593 | ), 594 | dict( 595 | testcase_name="integer", 596 | define_function=flags.DEFINE_integer, 597 | item_constructor=ff.Integer, 598 | default=1, 599 | override="2", 600 | ), 601 | dict( 602 | testcase_name="float", 603 | define_function=flags.DEFINE_float, 604 | item_constructor=ff.Float, 605 | default=1.0, 606 | override="2.0", 607 | ), 608 | dict( 609 | testcase_name="sequence", 610 | define_function=ff.DEFINE_sequence, 611 | item_constructor=ff.Sequence, 612 | default=(1, "x"), 613 | override=(2.0, "y"), 614 | ), 615 | dict( 616 | testcase_name="string", 617 | define_function=flags.DEFINE_string, 618 | item_constructor=ff.String, 619 | default="one", 620 | override="two", 621 | ), 622 | dict( 623 | testcase_name="stringlist", 624 | define_function=flags.DEFINE_list, 625 | item_constructor=ff.StringList, 626 | default=["a", "b"], 627 | override="['c', 'd']", 628 | ), 629 | ) 630 | 631 | 632 | class FlagAndItemEquivalence(parameterized.TestCase): 633 | 634 | @parameterized.named_parameters(*NAMES_ITEMS_AND_FLAGS) 635 | def test_equivalence( 636 | self, 637 | define_function: Callable[..., flags.FlagHolder], 638 | item_constructor: type(ff.Item), 639 | default: Any, 640 | override: str, 641 | ): 642 | flag_values = flags.FlagValues() 643 | flag_holder = define_function( 644 | "name.item", 645 | default, 646 | "help string", 647 | flag_values=flag_values, 648 | ) 649 | 650 | ff_flagvalues = flags.FlagValues() 651 | shared_values = ff.define_flags( 652 | "name", 653 | {"item": item_constructor(default, "help string")}, 654 | flag_values=ff_flagvalues, 655 | ) 656 | 657 | with self.subTest("Check serialisation equivalence before parsing"): 658 | self.assertEqual( 659 | 
flag_values["name.item"].serialize(), 660 | ff_flagvalues["name.item"].serialize(), 661 | ) 662 | self.assertEqual( 663 | flag_values.flags_into_string(), ff_flagvalues.flags_into_string() 664 | ) 665 | 666 | with self.subTest("Apply overrides and check equivalence after parsing"): 667 | # The flag holder gets updated at this point: 668 | flag_values(("./program", f"--name.item={override}")) 669 | # The shared_values dict gets updated at this point: 670 | ff_flagvalues(("./program", f"--name.item={override}")) 671 | self.assertNotEqual(flag_holder.value, default) 672 | self.assertEqual(flag_holder.value, shared_values["item"]) 673 | 674 | 675 | if __name__ == "__main__": 676 | absltest.main() 677 | -------------------------------------------------------------------------------- /fancyflags/_flags.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Flag classes for defining dict, Item, MultiItem and Auto flags.""" 16 | 17 | import copy 18 | import functools 19 | 20 | from absl import flags 21 | 22 | _EMPTY = "" 23 | 24 | 25 | class DictFlag(flags.Flag): 26 | """Implements the shared dict mechanism. 
See also `ItemFlag`.""" 27 | 28 | def __init__(self, shared_dict, *args, **kwargs): 29 | self._shared_dict = shared_dict 30 | super().__init__(*args, **kwargs) 31 | 32 | def _parse(self, value): 33 | # A `DictFlag` should not be overridable from the command line; only the 34 | # dotted `Item` flags should be. However, the _parse() method will still be 35 | # called in two situations: 36 | 37 | # 1. Via the base `Flag`'s constructor, which calls `_parse()` to process 38 | # the default value, which will be the shared dict. 39 | # 2. When processing command line overrides. We don't want to allow this 40 | # normally, however some libraries will serialize and deserialize all 41 | # flags, e.g. to pass values between processes, so we accept a dummy 42 | # empty serialized value for these cases. It's unlikely users will try to 43 | # set the dict flag to an empty string from the command line. 44 | if value is self._shared_dict or value == _EMPTY: 45 | return self._shared_dict 46 | raise flags.IllegalFlagValueError( 47 | "Can't override a dict flag directly. Did you mean to override one of " 48 | "its `Item`s instead?" 49 | ) 50 | 51 | def serialize(self): 52 | # When serializing flags, we return a sentinel value that the `DictFlag` 53 | # will ignore when parsing. The value of this flag is determined by the 54 | # corresponding `Item` flags for serialization and deserialization. 55 | return _EMPTY 56 | 57 | def flag_type(self): 58 | return "dict" 59 | 60 | 61 | # TODO(b/170423907): Pytype doesn't correctly infer that these have type 62 | # `property`. 63 | _flag_value_property = flags.Flag.value # type: property # pytype: disable=annotation-type-mismatch,unbound-type-param 64 | _multi_flag_value_property = flags.MultiFlag.value # type: property # pytype: disable=annotation-type-mismatch 65 | 66 | 67 | class ItemFlag(flags.Flag): 68 | """Updates a shared dict whenever its own value changes. 69 | 70 | See also the `DictFlag` and `ff.Item` classes for usage. 
71 | """ 72 | 73 | def __init__(self, shared_dict, namespace, parser, *args, **kwargs): 74 | self._shared_dict = shared_dict 75 | self._namespace = namespace 76 | super().__init__( 77 | *args, 78 | parser=parser, 79 | # absl treats boolean flags as a special case in order to support the 80 | # alternative `--foo`/`--nofoo` syntax. 81 | boolean=isinstance(parser, flags.BooleanParser), 82 | **kwargs 83 | ) 84 | 85 | # `super().value = value` doesn't work, see https://bugs.python.org/issue14965 86 | @_flag_value_property.setter 87 | def value(self, value): 88 | _flag_value_property.fset(self, value) 89 | self._update_shared_dict() 90 | 91 | def parse(self, argument): 92 | super().parse(argument) 93 | self._update_shared_dict() 94 | 95 | def _update_shared_dict(self): 96 | d = self._shared_dict 97 | for name in self._namespace[:-1]: 98 | d = d[name] 99 | d[self._namespace[-1]] = self.value 100 | 101 | 102 | class MultiItemFlag(flags.MultiFlag): 103 | """Updates a shared dict whenever its own value changes. 104 | 105 | Used for flags that can appear multiple times on the command line. 106 | See also the `DictFlag` and `ff.Item` classes for usage. 
107 | """ 108 | 109 | def __init__(self, shared_dict, namespace, *args, **kwargs): 110 | self._shared_dict = shared_dict 111 | self._namespace = namespace 112 | super().__init__(*args, **kwargs) 113 | 114 | # `super().value = value` doesn't work, see https://bugs.python.org/issue14965 115 | @_multi_flag_value_property.setter 116 | def value(self, value): 117 | _multi_flag_value_property.fset(self, value) 118 | self._update_shared_dict() 119 | 120 | def parse(self, argument): 121 | super().parse(argument) 122 | self._update_shared_dict() 123 | 124 | def _update_shared_dict(self): 125 | d = self._shared_dict 126 | for name in self._namespace[:-1]: 127 | d = d[name] 128 | d[self._namespace[-1]] = self.value 129 | 130 | 131 | class AutoFlag(flags.Flag): 132 | """Implements the shared dict mechanism.""" 133 | 134 | def __init__(self, fn, fn_kwargs, *args, **kwargs): 135 | self._fn = fn 136 | self._fn_kwargs = fn_kwargs 137 | super().__init__(*args, **kwargs) 138 | 139 | @property 140 | def value(self): 141 | kwargs = copy.deepcopy(self._fn_kwargs) 142 | return functools.partial(self._fn, **kwargs) 143 | 144 | @value.setter 145 | def value(self, value): 146 | # The flag's `.value` gets set as part of the `flags.Flag` constructor to a 147 | # default value. However, the default value should be given by the initial 148 | # `fn_kwargs` instead, so a) the semantics of setting the value are unclear 149 | # and b) we may not be able to call `self._fn` at this point in execution. 150 | del value 151 | 152 | def _parse(self, value): 153 | # An `AutoFlag` should not be overridable from the command line; only the 154 | # dotted `Item` flags should be. However, the `_parse()` method will still 155 | # be called in two situations: 156 | 157 | # 1. In the base `Flag`'s constructor, which calls `_parse()` to process the 158 | # default value, which will be None (as set in `DEFINE_auto`). 159 | # 2. When processing command line overrides.
We don't want to allow this 160 | # normally, however some libraries will serialize and deserialize all 161 | # flags, e.g. to pass values between processes, so we accept a dummy 162 | # empty serialized value for these cases. It's unlikely users will try to 163 | # set the auto flag to an empty string from the command line. 164 | if value is None or value == _EMPTY: 165 | return None 166 | raise flags.IllegalFlagValueError( 167 | "Can't override an auto flag directly. Did you mean to override one of " 168 | "its `Item`s instead?" 169 | ) 170 | 171 | def serialize(self): 172 | # When serializing a `FlagHolder` container, we must return *some* value for 173 | # this flag. We return an empty value that the `AutoFlag` will ignore when 174 | # parsing. The value of this flag is instead determined by the 175 | # corresponding `Item` flags for serialization and deserialization. 176 | return _EMPTY 177 | 178 | def flag_type(self): 179 | return "auto" 180 | -------------------------------------------------------------------------------- /fancyflags/_flags_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for fancyflags._flags.""" 16 | 17 | from absl import flags 18 | from absl.testing import absltest 19 | from fancyflags import _flags 20 | 21 | 22 | class FlagsTest(absltest.TestCase): 23 | 24 | def test_update_shared_dict(self): 25 | # Tests that the shared dict is updated when the flag value is updated. 26 | shared_dict = {'a': {'b': 'value'}} 27 | namespace = ('a', 'b') 28 | flag_values = flags.FlagValues() 29 | 30 | flags.DEFINE_flag( 31 | _flags.ItemFlag( 32 | shared_dict, 33 | namespace, 34 | parser=flags.ArgumentParser(), 35 | serializer=flags.ArgumentSerializer(), 36 | name='a.b', 37 | default='bar', 38 | help_string='help string', 39 | ), 40 | flag_values=flag_values, 41 | ) 42 | 43 | flag_values['a.b'].value = 'new_value' 44 | with self.subTest(name='setter'): 45 | self.assertEqual(shared_dict, {'a': {'b': 'new_value'}}) 46 | 47 | flag_values(('./program', '--a.b=override')) 48 | with self.subTest(name='override_parse'): 49 | self.assertEqual(shared_dict, {'a': {'b': 'override'}}) 50 | 51 | def test_update_shared_dict_multi(self): 52 | # Tests that the shared dict is updated when the flag value is updated. 
53 | shared_dict = {'a': {'b': ['value']}} 54 | namespace = ('a', 'b') 55 | flag_values = flags.FlagValues() 56 | 57 | flags.DEFINE_flag( 58 | _flags.MultiItemFlag( 59 | shared_dict, 60 | namespace, 61 | parser=flags.ArgumentParser(), 62 | serializer=flags.ArgumentSerializer(), 63 | name='a.b', 64 | default=['foo', 'bar'], 65 | help_string='help string', 66 | ), 67 | flag_values=flag_values, 68 | ) 69 | 70 | flag_values['a.b'].value = ['new', 'value'] 71 | with self.subTest(name='setter'): 72 | self.assertEqual(shared_dict, {'a': {'b': ['new', 'value']}}) 73 | 74 | flag_values(('./program', '--a.b=override1', '--a.b=override2')) 75 | with self.subTest(name='override_parse'): 76 | self.assertEqual(shared_dict, {'a': {'b': ['override1', 'override2']}}) 77 | 78 | 79 | if __name__ == '__main__': 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /fancyflags/_flagsaver_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for compatibility with absl.testing.flagsaver.""" 16 | 17 | from absl import flags 18 | from absl.testing import absltest 19 | from absl.testing import flagsaver 20 | import fancyflags as ff 21 | 22 | flags.DEFINE_string("string_flag", "unchanged", "flag to test with") 23 | 24 | ff.DEFINE_dict( 25 | "test_dict_flag", 26 | dict=dict( 27 | nested=ff.Float(1.0, "nested flag"), 28 | ), 29 | unnested=ff.Integer(4, "unnested flag"), 30 | ) 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | 35 | class FlagSaverTest(absltest.TestCase): 36 | 37 | def test_flagsaver_with_context_overrides(self): 38 | with flagsaver.flagsaver( 39 | **{ 40 | "string_flag": "new value", 41 | "test_dict_flag.dict.nested": -1.0, 42 | } 43 | ): 44 | self.assertEqual("new value", FLAGS.string_flag) 45 | self.assertEqual(-1.0, FLAGS.test_dict_flag["dict"]["nested"]) 46 | self.assertEqual(4, FLAGS.test_dict_flag["unnested"]) 47 | FLAGS.string_flag = "another value" 48 | 49 | self.assertEqual("unchanged", FLAGS.string_flag) 50 | self.assertEqual(1.0, FLAGS.test_dict_flag["dict"]["nested"]) 51 | 52 | def test_flagsaver_with_decorator_overrides(self): 53 | # Modeled after test_decorator_with_overrides in 54 | # https://github.com/abseil/abseil-py/blob/master/absl/testing/tests/flagsaver_test.py # pylint: disable=line-too-long 55 | 56 | @flagsaver.flagsaver( 57 | **{ 58 | "string_flag": "new value", 59 | "test_dict_flag.dict.nested": -1.0, 60 | } 61 | ) 62 | def mutate_flags(): 63 | return FLAGS.string_flag, FLAGS.test_dict_flag["dict"]["nested"] 64 | 65 | # Values should be overridden in the function. 66 | self.assertEqual(("new value", -1.0), mutate_flags()) 67 | 68 | # But unchanged here. 
69 | self.assertEqual("unchanged", FLAGS.string_flag) 70 | self.assertEqual(1.0, FLAGS.test_dict_flag["dict"]["nested"]) 71 | 72 | def test_flagsaver_with_context_overrides_twice(self): 73 | # Checking that the flat -> dict flag sync works again after restoration. 74 | # This might fail if the underlying absl functions copied the dict as part 75 | # of restoration. 76 | 77 | with flagsaver.flagsaver( 78 | **{ 79 | "string_flag": "new value", 80 | "test_dict_flag.dict.nested": -1.0, 81 | } 82 | ): 83 | self.assertEqual("new value", FLAGS.string_flag) 84 | self.assertEqual(-1.0, FLAGS.test_dict_flag["dict"]["nested"]) 85 | self.assertEqual(4, FLAGS.test_dict_flag["unnested"]) 86 | FLAGS.string_flag = "another value" 87 | 88 | self.assertEqual("unchanged", FLAGS.string_flag) 89 | self.assertEqual(1.0, FLAGS.test_dict_flag["dict"]["nested"]) 90 | 91 | # Same again! 92 | 93 | with flagsaver.flagsaver( 94 | **{ 95 | "string_flag": "new value", 96 | "test_dict_flag.dict.nested": -1.0, 97 | } 98 | ): 99 | self.assertEqual("new value", FLAGS.string_flag) 100 | self.assertEqual(-1.0, FLAGS.test_dict_flag["dict"]["nested"]) 101 | self.assertEqual(4, FLAGS.test_dict_flag["unnested"]) 102 | FLAGS.string_flag = "another value" 103 | 104 | self.assertEqual("unchanged", FLAGS.string_flag) 105 | self.assertEqual(1.0, FLAGS.test_dict_flag["dict"]["nested"]) 106 | 107 | @absltest.skip("This fails because flagsaver does not do deep copies") 108 | def test_flagsaver_with_changes_within_context(self): 109 | """Overrides within a flagsaver context should be correctly restored.""" 110 | with flagsaver.flagsaver(): 111 | FLAGS.string_flag = "new_value" 112 | FLAGS["test_dict_flag.dict.nested"].value = -1.0 113 | FLAGS.test_dict_flag["unnested"] = -1.0 114 | self.assertEqual("unchanged", FLAGS.string_flag) # Works. 115 | self.assertEqual(1.0, FLAGS.test_dict_flag["dict"]["nested"]) # Works. 116 | # TODO(b/177927157) Fix this behaviour. 
117 | self.assertEqual(4, FLAGS.test_dict_flag["unnested"]) # Broken. 118 | 119 | 120 | if __name__ == "__main__": 121 | absltest.main() 122 | -------------------------------------------------------------------------------- /fancyflags/_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Package metadata. 16 | 17 | This is kept in a separate module so that it can be imported from setup.py, at 18 | a time when the package dependencies may not have been installed yet. 19 | """ 20 | 21 | __version__ = '1.2' # https://www.python.org/dev/peps/pep-0440/ 22 | -------------------------------------------------------------------------------- /fancyflags/examples/example_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A module that defines a dict flag, for testing purposes.""" 16 | 17 | import fancyflags as ff 18 | 19 | ff.DEFINE_dict( 20 | "settings", 21 | integer_field=ff.Integer(1, "integer field"), 22 | string_field=ff.String(None, "string field"), 23 | ) 24 | -------------------------------------------------------------------------------- /fancyflags/examples/override_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Test for flag overrides in fancyflags.""" 16 | 17 | from absl import flags 18 | from absl.testing import absltest 19 | import fancyflags as ff 20 | 21 | 22 | SETTINGS = ff.DEFINE_dict( 23 | "settings", 24 | integer_field=ff.Integer(1, "integer field"), 25 | string_field=ff.String(None, "string field"), 26 | nested=dict(float_field=ff.Float(0.0, "float field")), 27 | sequence_field=ff.Sequence([1, 2, 3], "sequence field"), 28 | another_sequence_field=ff.Sequence((3.14, 2.718), "another sequence field"), 29 | string_list_field=ff.StringList(["a"], "string list flag."), 30 | ) 31 | 32 | 33 | class OverrideTest(absltest.TestCase): 34 | 35 | def test_give_me_a_name(self): 36 | expected = dict( 37 | integer_field=5, 38 | string_field=None, # Not overridden in BUILD args. 39 | nested=dict( 40 | float_field=3.2, 41 | ), 42 | sequence_field=[4, 5, 6], 43 | another_sequence_field=[3.0, 2.0], 44 | string_list_field=["a", "bunch", "of", "overrides"], 45 | ) 46 | self.assertEqual(flags.FLAGS.settings, expected) 47 | self.assertEqual(SETTINGS.value, expected) 48 | 49 | 50 | if __name__ == "__main__": 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = build dist examples 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | pytest==6.2.5 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Install script for fancyflags.""" 16 | 17 | from importlib import util 18 | import setuptools 19 | 20 | 21 | def _get_version(): 22 | spec = util.spec_from_file_location('_metadata', 'fancyflags/_metadata.py') 23 | mod = util.module_from_spec(spec) 24 | spec.loader.exec_module(mod) 25 | return mod.__version__ 26 | 27 | 28 | setuptools.setup( 29 | name='fancyflags', 30 | version=_get_version(), 31 | description='A Python library for defining dictionary-valued flags.', 32 | author='DeepMind', 33 | license='Apache License, Version 2.0', 34 | packages=setuptools.find_packages(), 35 | install_requires=['absl-py'], 36 | tests_require=['pytest'], 37 | classifiers=[ 38 | 'Programming Language :: Python', 39 | 'Programming Language :: Python :: 3', 40 | 'Programming Language :: Python :: 3.8', 41 | 'Programming Language :: Python :: 3.9', 42 | 'Programming Language :: Python :: 3.10', 43 | 'Programming Language :: Python :: 3.11', 44 | 'Intended Audience :: Developers', 45 | 'Topic :: Software Development :: Libraries :: Python Modules', 46 | 'License :: OSI Approved :: Apache Software License', 47 | 'Operating System :: OS Independent', 48 | ], 49 | ) 50 | --------------------------------------------------------------------------------
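Two comments in `_flagsaver_test.py` point at the same issue. The skipped `test_flagsaver_with_changes_within_context` attributes its failure to flagsaver "not doing deep copies". `test_flagsaver_with_context_overrides_twice` warns that restoration would break the flat-to-dict sync if it copied the dict.

The sketch below reproduces both effects with plain Python dicts. It is a simplified illustration only. `ToyItemFlag` and `restore_in_place` are hypothetical names that belong to neither fancyflags nor absl, and the snapshot logic is not how `absl.testing.flagsaver` is actually implemented.

```python
import copy
from typing import Any, Dict, Tuple


class ToyItemFlag:
  """Hypothetical flat flag (e.g. 'test_dict_flag.dict.nested') that reads and
  writes one entry of a shared nested dict. Illustrative only."""

  def __init__(self, shared: Dict[str, Any], namespace: Tuple[str, ...]):
    self.shared = shared
    self.namespace = namespace

  def _parent(self) -> Dict[str, Any]:
    # Walk from the root on every access, so replaced sub-dicts are found.
    node = self.shared
    for key in self.namespace[:-1]:
      node = node[key]
    return node

  @property
  def value(self) -> Any:
    return self._parent()[self.namespace[-1]]

  @value.setter
  def value(self, new_value: Any) -> None:
    # Assign in place so every holder of `shared` observes the change.
    self._parent()[self.namespace[-1]] = new_value


def restore_in_place(live: Dict[str, Any], saved: Dict[str, Any]) -> None:
  """Copies saved contents back into `live` without replacing the object,
  so anything holding a reference to `live` stays in sync."""
  live.clear()
  live.update(copy.deepcopy(saved))


# A toy registry: one plain flag and one dict-valued flag.
settings = {"dict": {"nested": 1.0}, "unnested": 4}
registry = {"string_flag": "unchanged", "test_dict_flag": settings}
nested_flag = ToyItemFlag(settings, ("dict", "nested"))

shallow = copy.copy(registry)   # Shares the `settings` object.
deep = copy.deepcopy(registry)  # Fully independent copy.

registry["string_flag"] = "new_value"  # Rebinding a key.
nested_flag.value = -1.0               # Writes through to `settings`.
settings["unnested"] = -1.0            # In-place mutation of the dict value.

# The shallow snapshot is unaffected by rebinding...
assert shallow["string_flag"] == "unchanged"
# ...but it sees in-place mutations, so it cannot restore them.
assert shallow["test_dict_flag"]["unnested"] == -1.0

# Restore from the deep snapshot without replacing the live dict object.
registry["string_flag"] = deep["string_flag"]
restore_in_place(registry["test_dict_flag"], deep["test_dict_flag"])
```

Restoring into the existing dict, rather than assigning a fresh copy, keeps `registry["test_dict_flag"]` and `nested_flag` pointing at the same object. That is the property the "twice" test guards against losing.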