├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ └── pytest_and_autopublish.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── CONTRIBUTING.md ├── Makefile ├── README.md ├── conf.py ├── config_dict.rst ├── config_dict_examples.rst ├── config_flags.rst ├── config_flags_examples.rst ├── index.rst ├── requirements.txt └── testing.md ├── ml_collections ├── AUTHORS ├── __init__.py ├── config_dict │ ├── __init__.py │ ├── config_dict.py │ ├── examples │ │ ├── config.py │ │ ├── config_dict_advanced.py │ │ ├── config_dict_basic.py │ │ ├── config_dict_initialization.py │ │ ├── config_dict_lock.py │ │ ├── config_dict_placeholder.py │ │ ├── examples_test.py │ │ ├── field_reference.py │ │ └── frozen_config_dict.py │ └── tests │ │ ├── config_dict_test.py │ │ ├── field_reference_test.py │ │ └── frozen_config_dict_test.py ├── config_flags │ ├── __init__.py │ ├── config_flags.py │ ├── config_path.py │ ├── examples │ │ ├── config.py │ │ ├── define_config_dataclass_basic.py │ │ ├── define_config_dict_basic.py │ │ ├── define_config_file_basic.py │ │ ├── examples_test.py │ │ └── parameterised_config.py │ ├── tests │ │ ├── config_overriding_test.py │ │ ├── config_path_test.py │ │ ├── configdict_config.py │ │ ├── dataclass_overriding_test.py │ │ ├── fieldreference_config.py │ │ ├── ioerror_config.py │ │ ├── literal_config.py │ │ ├── mini_config.py │ │ ├── mock_config.py │ │ ├── parameterised_config.py │ │ ├── spork.py │ │ ├── tuple_parser_test.py │ │ ├── typeerror_config.py │ │ └── valueerror_config.py │ └── tuple_parser.py └── conftest.py └── pyproject.toml /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: mohitreddy1996 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | 15 | ***ConfigDict*** 16 | Please consider providing a Colab link reproducing the issue. 17 | 18 | ***ConfigFlags*** 19 | Please consider providing a Colab link in case of `config_flags.DEFINE_config_dict`, or a simple repository with a Python script that uses `config_flags.DEFINE_config_file` together with the config file. Ideally, the repository should contain a README file describing any setup steps and how to execute the script. 20 | 21 | **Expected behavior** 22 | A clear and concise description of what you expected to happen. 23 | 24 | **Environment:** 25 | - OS: [e.g. MacOS] 26 | - OS Version: [e.g. 22] 27 | - Python: [e.g. Python 3.6] 28 | 29 | **Additional context** 30 | Add any other context about the problem here. 31 | -------------------------------------------------------------------------------- /.github/workflows/pytest_and_autopublish.yml: -------------------------------------------------------------------------------- 1 | name: Unittests & Auto-publish 2 | 3 | # Allow triggering the workflow manually (e.g. when deps change) 4 | on: [push, workflow_dispatch] 5 | 6 | jobs: 7 | pytest-job: 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ['3.10', '3.11', '3.12'] 12 | timeout-minutes: 30 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.python-version }} 16 | cancel-in-progress: true 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | # Install deps 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | # Uncomment to cache pip dependencies (if tests are too slow) 26 | # cache: pip 27 | # cache-dependency-path: '**/pyproject.toml' 28 | 29 | - run: pip --version 30 | - run: pip install -e .[dev] 31 | - run: pip freeze 32 | 33 | # Run tests (in parallel) 34 | - name: Run core tests 35 | run: | 36 | # TODO(marcenacp): do not ignore tests.
37 | pytest -vv -n auto ml_collections/ \ 38 | --ignore=ml_collections/config_dict/examples/examples_test.py 39 | 40 | # Auto-publish when version is increased 41 | publish-job: 42 | # Only try to publish if: 43 | # * Repo is self (prevents running from forks) 44 | # * Branch is `master` 45 | if: | 46 | github.repository == 'google/ml_collections' 47 | && github.ref == 'refs/heads/master' 48 | needs: pytest-job # Only publish after tests are successful 49 | runs-on: ubuntu-latest 50 | permissions: 51 | contents: write 52 | timeout-minutes: 30 53 | 54 | steps: 55 | # Publish the package (if local `__version__` > pip version) 56 | - uses: etils-actions/pypi-auto-publish@v1 57 | with: 58 | pypi-token: ${{ secrets.PYPI_API_TOKEN }} 59 | gh-token: ${{ secrets.GITHUB_TOKEN }} 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | docs/__autosummary 2 | docs/_build 3 | docs/make.bat 4 | *.rej 5 | *~ 6 | \#*\# 7 | 8 | # Compiled python modules. 9 | *.pyc 10 | 11 | # Byte-compiled 12 | __pycache__/ 13 | .cache/ 14 | 15 | # Poetry, setuptools, PyPI distribution artifacts. 16 | /*.egg-info 17 | .eggs/ 18 | build/ 19 | dist/ 20 | poetry.lock 21 | 22 | # Tests 23 | .pytest_cache/ 24 | 25 | # Type checking 26 | .pytype/ 27 | 28 | # Other 29 | *.DS_Store 30 | 31 | # PyCharm 32 | .idea 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ML Collections 2 | 3 | ML Collections is a library of Python Collections designed for ML use cases. 4 | 5 | [![Documentation Status](https://readthedocs.org/projects/ml-collections/badge/?version=latest)](https://ml-collections.readthedocs.io/en/latest/?badge=latest) 6 | [![PyPI version](https://badge.fury.io/py/ml-collections.svg)](https://badge.fury.io/py/ml-collections) 7 | [![Build Status](https://github.com/google/ml_collections/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google/ml_collections/actions/workflows/pytest_and_autopublish.yml) 8 | 9 | ## ConfigDict 10 | 11 | The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data 12 | structures with dot access to nested elements. Together, they are supposed to be 13 | used as a main way of expressing configurations of experiments and models. 14 | 15 | This document describes example usage of `ConfigDict`, `FrozenConfigDict`, 16 | `FieldReference`. 
17 | 18 | ### Features 19 | 20 | * Dot-based access to fields. 21 | * Locking mechanism to prevent spelling mistakes. 22 | * Lazy computation. 23 | * FrozenConfigDict() class which is immutable and hashable. 24 | * Type safety. 25 | * "Did you mean" functionality. 26 | * Human readable printing (with valid references and cycles), using valid YAML 27 | format. 28 | * Fields can be passed as keyword arguments using the `**` operator. 29 | * There is one exception to the strong type-safety of the ConfigDict: `int` 30 | values can be passed in to fields of type `float`. In such a case, the value 31 | is type-converted to a `float` before being stored. (Back in the day of 32 | Python 2, there was a similar exception to allow both `str` and `unicode` 33 | values in string fields.) 34 | 35 | ### Basic Usage 36 | 37 | ```python 38 | from ml_collections import config_dict 39 | 40 | cfg = config_dict.ConfigDict() 41 | cfg.float_field = 12.6 42 | cfg.integer_field = 123 43 | cfg.another_integer_field = 234 44 | cfg.nested = config_dict.ConfigDict() 45 | cfg.nested.string_field = 'tom' 46 | 47 | print(cfg.integer_field) # Prints 123. 48 | print(cfg['integer_field']) # Prints 123 as well. 49 | 50 | try: 51 | cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. 52 | except TypeError as e: 53 | print(e) 54 | 55 | cfg.float_field = 12 # Works: `Int` types can be assigned to `Float`. 56 | cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings. 
57 | 58 | print(cfg) 59 | ``` 60 | 61 | ### FrozenConfigDict 62 | 63 | A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`: 64 | 65 | ```python 66 | from ml_collections import config_dict 67 | 68 | initial_dictionary = { 69 | 'int': 1, 70 | 'list': [1, 2], 71 | 'tuple': (1, 2, 3), 72 | 'set': {1, 2, 3, 4}, 73 | 'dict_tuple_list': {'tuple_list': ([1, 2], 3)} 74 | } 75 | 76 | cfg = config_dict.ConfigDict(initial_dictionary) 77 | frozen_dict = config_dict.FrozenConfigDict(initial_dictionary) 78 | 79 | print(frozen_dict.tuple) # Prints tuple (1, 2, 3) 80 | print(frozen_dict.list) # Prints tuple (1, 2) 81 | print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4} 82 | print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2) 83 | 84 | frozen_cfg = config_dict.FrozenConfigDict(cfg) 85 | print(frozen_cfg == frozen_dict) # True 86 | print(hash(frozen_cfg) == hash(frozen_dict)) # True 87 | 88 | try: 89 | frozen_dict.int = 2 # Raises AttributeError as FrozenConfigDict is immutable. 90 | except AttributeError as e: 91 | print(e) 92 | 93 | # Converting between `FrozenConfigDict` and `ConfigDict`: 94 | thawed_frozen_cfg = config_dict.ConfigDict(frozen_dict) 95 | print(thawed_frozen_cfg == cfg) # True 96 | frozen_cfg_to_cfg = frozen_dict.as_configdict() 97 | print(frozen_cfg_to_cfg == cfg) # True 98 | ``` 99 | 100 | ### FieldReferences and placeholders 101 | 102 | A `FieldReference` is useful for having multiple fields use the same value. It 103 | can also be used for [lazy computation](#lazy-computation). 104 | 105 | You can use `placeholder()` as a shortcut to create a `FieldReference` (field) 106 | with a `None` default value. This is useful if a program uses optional 107 | configuration fields.
108 | 109 | ```python 110 | from ml_collections import config_dict 111 | 112 | placeholder = config_dict.FieldReference(0) 113 | cfg = config_dict.ConfigDict() 114 | cfg.placeholder = placeholder 115 | cfg.optional = config_dict.placeholder(int) 116 | cfg.nested = config_dict.ConfigDict() 117 | cfg.nested.placeholder = placeholder 118 | 119 | try: 120 | cfg.optional = 'tom' # Raises TypeError as this field is an integer. 121 | except TypeError as e: 122 | print(e) 123 | 124 | cfg.optional = 1555 # Works fine. 125 | cfg.placeholder = 1 # Changes the value of both placeholder and 126 | # nested.placeholder fields. 127 | 128 | print(cfg) 129 | ``` 130 | 131 | Note that the indirection provided by `FieldReference`s will be lost if accessed 132 | through a `ConfigDict`. 133 | 134 | ```python 135 | from ml_collections import config_dict 136 | cfg = config_dict.ConfigDict() 137 | placeholder = config_dict.FieldReference(0) 138 | cfg.field1 = placeholder 139 | cfg.field2 = placeholder # This field will be tied to cfg.field1. 140 | cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. 141 | ``` 142 | 143 | ### Lazy computation 144 | 145 | Using a `FieldReference` in a standard operation (addition, subtraction, 146 | multiplication, etc.) will return another `FieldReference` that points to the 147 | original's value. You can use `FieldReference.get()` to execute the operations 148 | and get the reference's computed value, and `FieldReference.set()` to change the 149 | original reference's value.
150 | 151 | ```python 152 | from ml_collections import config_dict 153 | 154 | ref = config_dict.FieldReference(1) 155 | print(ref.get()) # Prints 1 156 | 157 | add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten 158 | add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer 159 | 160 | print(add_ten) # Prints 11 161 | print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 162 | 163 | # Addition is lazily computed for FieldReferences, so changing ref will change 164 | # the value that is used to compute add_ten_lazy. 165 | ref.set(5) 166 | print(add_ten) # Prints 11 167 | print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 168 | ``` 169 | 170 | If a `FieldReference` has `None` as its original value, or any operation has an 171 | argument of `None`, then the lazy computation will evaluate to `None`. 172 | 173 | We can also use fields in a `ConfigDict` in lazy computation. In this case, a 174 | field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.
175 | 176 | ```python 177 | from ml_collections import config_dict 178 | 179 | config = config_dict.ConfigDict() 180 | config.reference_field = config_dict.FieldReference(1) 181 | config.integer_field = 2 182 | config.float_field = 2.5 183 | 184 | # No lazy evaluations because we didn't use get_ref() 185 | config.no_lazy = config.integer_field * config.float_field 186 | 187 | # This will lazily evaluate ONLY config.integer_field 188 | config.lazy_integer = config.get_ref('integer_field') * config.float_field 189 | 190 | # This will lazily evaluate ONLY config.float_field 191 | config.lazy_float = config.integer_field * config.get_ref('float_field') 192 | 193 | # This will lazily evaluate BOTH config.integer_field and config.float_field 194 | config.lazy_both = (config.get_ref('integer_field') * 195 | config.get_ref('float_field')) 196 | 197 | config.integer_field = 3 198 | print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value 199 | 200 | print(config.lazy_integer) # Prints 7.5 201 | 202 | config.float_field = 3.5 203 | print(config.lazy_float) # Prints 7.0 204 | print(config.lazy_both) # Prints 10.5 205 | ``` 206 | 207 | #### Changing lazily computed values 208 | 209 | Lazily computed values in a ConfigDict can be overridden in the same way as 210 | regular values. The reference to the `FieldReference` used for the lazy 211 | computation will be lost and all computations downstream in the reference graph 212 | will use the new value. 213 | 214 | ```python 215 | from ml_collections import config_dict 216 | 217 | config = config_dict.ConfigDict() 218 | config.reference = 1 219 | config.reference_0 = config.get_ref('reference') + 10 220 | config.reference_1 = config.get_ref('reference') + 20 221 | config.reference_1_0 = config.get_ref('reference_1') + 100 222 | 223 | print(config.reference) # Prints 1. 224 | print(config.reference_0) # Prints 11. 225 | print(config.reference_1) # Prints 21. 226 | print(config.reference_1_0) # Prints 121.
227 | 228 | config.reference_1 = 30 229 | 230 | print(config.reference) # Prints 1 (unchanged). 231 | print(config.reference_0) # Prints 11 (unchanged). 232 | print(config.reference_1) # Prints 30. 233 | print(config.reference_1_0) # Prints 130. 234 | ``` 235 | 236 | #### Cycles 237 | 238 | You cannot create cycles using references. Fortunately 239 | [the only way](#changing-lazily-computed-values) to create a cycle is by 240 | assigning a computed field to one that *is not* the result of computation. This 241 | is forbidden: 242 | 243 | ```python 244 | from ml_collections import config_dict 245 | 246 | config = config_dict.ConfigDict() 247 | config.integer_field = 1 248 | config.bigger_integer_field = config.get_ref('integer_field') + 10 249 | 250 | try: 251 | # Raises a MutabilityError because setting config.integer_field would 252 | # cause a cycle. 253 | config.integer_field = config.get_ref('bigger_integer_field') + 2 254 | except config_dict.MutabilityError as e: 255 | print(e) 256 | ``` 257 | 258 | #### One-way references 259 | 260 | One gotcha with `get_ref` is that it creates a bi-directional dependency when no operations are performed on the value. 261 | 262 | ```python 263 | from ml_collections import config_dict 264 | 265 | config = config_dict.ConfigDict() 266 | config.reference = 1 267 | config.reference_0 = config.get_ref('reference') 268 | config.reference_0 = 2 269 | print(config.reference) # Prints 2. 270 | print(config.reference_0) # Prints 2. 271 | ``` 272 | 273 | This can be avoided by using `get_oneway_ref` instead of `get_ref`. 274 | 275 | ```python 276 | from ml_collections import config_dict 277 | 278 | config = config_dict.ConfigDict() 279 | config.reference = 1 280 | config.reference_0 = config.get_oneway_ref('reference') 281 | config.reference_0 = 2 282 | print(config.reference) # Prints 1. 283 | print(config.reference_0) # Prints 2. 
284 | ``` 285 | 286 | ### Advanced usage 287 | 288 | Here are some more advanced examples showing lazy computation with different 289 | operators and data types. 290 | 291 | ```python 292 | from ml_collections import config_dict 293 | 294 | config = config_dict.ConfigDict() 295 | config.float_field = 12.6 296 | config.integer_field = 123 297 | config.list_field = [0, 1, 2] 298 | 299 | config.float_multiply_field = config.get_ref('float_field') * 3 300 | print(config.float_multiply_field) # Prints 37.8 301 | 302 | config.float_field = 10.0 303 | print(config.float_multiply_field) # Prints 30.0 304 | 305 | config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] 306 | print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] 307 | 308 | config.list_field = [-1] 309 | print(config.longer_list_field) # Prints [-1, 3, 4, 5] 310 | 311 | # Both operands can be references 312 | config.ref_subtraction = ( 313 | config.get_ref('float_field') - config.get_ref('integer_field')) 314 | print(config.ref_subtraction) # Prints -113.0 315 | 316 | config.integer_field = 10 317 | print(config.ref_subtraction) # Prints 0.0 318 | ``` 319 | 320 | ### Equality checking 321 | 322 | You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` 323 | and `FrozenConfigDict` objects. 
324 | 325 | ```python 326 | from ml_collections import config_dict 327 | 328 | dict_1 = {'list': [1, 2]} 329 | dict_2 = {'list': (1, 2)} 330 | cfg_1 = config_dict.ConfigDict(dict_1) 331 | frozen_cfg_1 = config_dict.FrozenConfigDict(dict_1) 332 | frozen_cfg_2 = config_dict.FrozenConfigDict(dict_2) 333 | 334 | # True because FrozenConfigDict converts lists to tuples 335 | print(frozen_cfg_1.items() == frozen_cfg_2.items()) 336 | # False because == distinguishes the underlying difference 337 | print(frozen_cfg_1 == frozen_cfg_2) 338 | 339 | # False because == distinguishes these types 340 | print(frozen_cfg_1 == cfg_1) 341 | # But eq_as_configdict() treats both as ConfigDict, so these are True: 342 | print(frozen_cfg_1.eq_as_configdict(cfg_1)) 343 | print(cfg_1.eq_as_configdict(frozen_cfg_1)) 344 | ``` 345 | 346 | ### Equality checking with lazy computation 347 | 348 | Equality checks see if the computed values are the same. Equality is satisfied 349 | even if two sets of computations are different, as long as they result in the 350 | same value. 351 | 352 | ```python 353 | from ml_collections import config_dict 354 | 355 | cfg_1 = config_dict.ConfigDict() 356 | cfg_1.a = 1 357 | cfg_1.b = cfg_1.get_ref('a') + 2 358 | 359 | cfg_2 = config_dict.ConfigDict() 360 | cfg_2.a = 1 361 | cfg_2.b = cfg_2.get_ref('a') * 3 362 | 363 | # True because all computed values are the same 364 | print(cfg_1 == cfg_2) 365 | ``` 366 | 367 | ### Locking and copying 368 | 369 | Here is an example with `lock()` and `deepcopy()`: 370 | 371 | ```python 372 | import copy 373 | from ml_collections import config_dict 374 | 375 | cfg = config_dict.ConfigDict() 376 | cfg.integer_field = 123 377 | 378 | # Locking prohibits the addition and deletion of new fields but allows 379 | # modification of existing values. 380 | cfg.lock() 381 | try: 382 | cfg.intagar_field = 124 # Modifies the wrong field 383 | except AttributeError as e: # Raises AttributeError and suggests valid field.
384 | print(e) 385 | with cfg.unlocked(): 386 | cfg.intagar_field = 1555 # Works fine. 387 | 388 | # Get a copy of the config dict. 389 | new_cfg = copy.deepcopy(cfg) 390 | new_cfg.integer_field = -123 # Works fine. 391 | 392 | print(cfg) 393 | print(new_cfg) 394 | ``` 395 | 396 | Output: 397 | 398 | ``` 399 | 'Key "intagar_field" does not exist and cannot be added since the config is locked. Other fields present: "{\'integer_field\': 123}"\nDid you mean "integer_field" instead of "intagar_field"?' 400 | intagar_field: 1555 401 | integer_field: 123 402 | 403 | intagar_field: 1555 404 | integer_field: -123 405 | ``` 406 | 407 | ### Dictionary attributes and initialization 408 | 409 | ```python 410 | from ml_collections import config_dict 411 | 412 | referenced_dict = {'inner_float': 3.14} 413 | d = { 414 | 'referenced_dict_1': referenced_dict, 415 | 'referenced_dict_2': referenced_dict, 416 | 'list_containing_dict': [{'key': 'value'}], 417 | } 418 | 419 | # We can initialize on a dictionary 420 | cfg = config_dict.ConfigDict(d) 421 | 422 | # Reference structure is preserved 423 | print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True 424 | 425 | # And the dict attributes have been converted to ConfigDict 426 | print(type(cfg.referenced_dict_1)) # ConfigDict 427 | 428 | # However, the initialization does not look inside of lists, so dicts inside 429 | # lists are not converted to ConfigDict 430 | print(type(cfg.list_containing_dict[0])) # dict 431 | ``` 432 | 433 | ### More Examples 434 | 435 | For more examples, take a look at 436 | [`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples) 437 | 438 | For examples and gotchas specifically about initializing a ConfigDict, see 439 | [`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py). 
440 | 441 | ## Config Flags 442 | 443 | This library adds flag definitions to `absl.flags` to handle config files. It 444 | does not wrap `absl.flags` so if using any standard flag definitions alongside 445 | config file flags, users must also import `absl.flags`. 446 | 447 | Currently, this module adds two new flag types, namely `DEFINE_config_file` 448 | which accepts a path to a Python file that generates a configuration, and 449 | `DEFINE_config_dict` which accepts a configuration directly. Configurations are 450 | dict-like structures (see [ConfigDict](#configdict)) whose nested elements 451 | can be overridden using special command-line flags. See the examples below 452 | for more details. 453 | 454 | ### Usage 455 | 456 | Use `ml_collections.config_flags` alongside `absl.flags`. For 457 | example: 458 | 459 | `script.py`: 460 | 461 | ```python 462 | from absl import app 463 | from absl import flags 464 | 465 | from ml_collections import config_flags 466 | 467 | _CONFIG = config_flags.DEFINE_config_file('my_config') 468 | _MY_FLAG = flags.DEFINE_integer('my_flag', None) 469 | 470 | def main(_): 471 | print(_CONFIG.value) 472 | print(_MY_FLAG.value) 473 | 474 | if __name__ == '__main__': 475 | app.run(main) 476 | ``` 477 | 478 | `config.py`: 479 | 480 | ```python 481 | # Note that this is a valid Python script. 482 | # get_config() can return an arbitrary dict-like object. However, it is advised 483 | # to use ml_collections.config_dict.ConfigDict. 
484 | # See ml_collections/config_dict/examples/config_dict_basic.py 485 | 486 | from ml_collections import config_dict 487 | 488 | def get_config(): 489 | config = config_dict.ConfigDict() 490 | config.field1 = 1 491 | config.field2 = 'tom' 492 | config.nested = config_dict.ConfigDict() 493 | config.nested.field = 2.23 494 | config.tuple = (1, 2, 3) 495 | return config 496 | ``` 497 | 498 | Warning: If you are using a pickle-based distributed programming framework such 499 | as [Launchpad](https://github.com/deepmind/launchpad#readme), be aware of 500 | limitations on the structure of this script that are 501 | [described below](#config_files_and_pickling). 502 | 503 | Now, after running: 504 | 505 | ```bash 506 | python script.py --my_config=config.py \ 507 | --my_config.field1=8 \ 508 | --my_config.nested.field=2.1 \ 509 | --my_config.tuple='(1, 2, (1, 2))' 510 | ``` 511 | 512 | we get: 513 | 514 | ``` 515 | field1: 8 516 | field2: tom 517 | nested: 518 | field: 2.1 519 | tuple: !!python/tuple 520 | - 1 521 | - 2 522 | - !!python/tuple 523 | - 1 524 | - 2 525 | ``` 526 | 527 | Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main 528 | difference is that the configuration is defined in `script.py` instead of in a 529 | separate file. 530 | 531 | `script.py`: 532 | 533 | ```python 534 | from absl import app 535 | 536 | from ml_collections import config_dict 537 | from ml_collections import config_flags 538 | 539 | config = config_dict.ConfigDict() 540 | config.field1 = 1 541 | config.field2 = 'tom' 542 | config.nested = config_dict.ConfigDict() 543 | config.nested.field = 2.23 544 | config.tuple = (1, 2, 3) 545 | 546 | _CONFIG = config_flags.DEFINE_config_dict('my_config', config) 547 | 548 | def main(_): 549 | print(_CONFIG.value) 550 | 551 | if __name__ == '__main__': 552 | app.run(main) 553 | ``` 554 | 555 | `config_file` flags are compatible with the command-line flag syntax.
All the 556 | following options are supported for non-boolean values in configurations: 557 | 558 | * `-(-)config.field=value` 559 | * `-(-)config.field value` 560 | 561 | Options for boolean values are slightly different: 562 | 563 | * `-(-)config.boolean_field`: set boolean value to True. 564 | * `-(-)noconfig.boolean_field`: set boolean value to False. 565 | * `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or 566 | `False`. 567 | 568 | Note that `-(-)config.boolean_field value` is not supported. 569 | 570 | ### Parameterising the get_config() function 571 | 572 | It's sometimes useful to be able to pass parameters into `get_config`, and 573 | change what is returned based on this configuration. One example is if you are 574 | grid searching over parameters which have a different hierarchical structure - 575 | the flag needs to be present in the resulting ConfigDict. It would be possible 576 | to include the union of all possible leaf values in your ConfigDict, 577 | but this produces a confusing config result as you have to remember which 578 | parameters will actually have an effect and which won't. 579 | 580 | A better system is to pass some configuration, indicating which structure of 581 | ConfigDict should be returned. An example is the following config file: 582 | 583 | ```python 584 | from ml_collections import config_dict 585 | 586 | def get_config(config_string): 587 | possible_structures = { 588 | 'linear': config_dict.ConfigDict({ 589 | 'model_constructor': 'snt.Linear', 590 | 'model_config': config_dict.ConfigDict({ 591 | 'output_size': 42, 592 | }), }), 593 | 'lstm': config_dict.ConfigDict({ 594 | 'model_constructor': 'snt.LSTM', 595 | 'model_config': config_dict.ConfigDict({ 596 | 'hidden_size': 108, 597 | }) 598 | }) 599 | } 600 | 601 | return possible_structures[config_string] 602 | ``` 603 | 604 | The value of `config_string` will be anything that is to the right of the first 605 | colon in the config file path, if one exists.
If no colon exists, no value is 606 | passed to `get_config` (producing a TypeError if `get_config` expects a value). 607 | 608 | The above example can be run like: 609 | 610 | ```bash 611 | python script.py --config=path_to_config.py:linear \ 612 | --config.model_config.output_size=256 613 | ``` 614 | 615 | or like: 616 | 617 | ```bash 618 | python script.py --config=path_to_config.py:lstm \ 619 | --config.model_config.hidden_size=512 620 | ``` 621 | 622 | ### Additional features 623 | 624 | * Loads any valid python script which defines a `get_config()` function 625 | returning any python object. 626 | * Automatic locking of the loaded object, if the loaded object defines a 627 | callable `.lock()` method. 628 | * Supports command-line overriding of arbitrarily nested values in dict-like 629 | objects (with key/attribute based getters/setters) of the following types: 630 | * `int` 631 | * `float` 632 | * `bool` 633 | * `str` 634 | * `tuple` (but **not** `list`) 635 | * `enum.Enum` 636 | * Overriding is type safe. 637 | * Overriding of a `tuple` can be done by passing in the `tuple` value as a 638 | string (see the example in the [Usage](#usage) section). 639 | * The overriding `tuple` object can be of a different length and have 640 | different item types than the original. Nested tuples are also supported.
641 | 642 | ### Config Files and Pickling {#config_files_and_pickling} 643 | 644 | This is likely to be troublesome: 645 | 646 | ```python {.bad} 647 | @dataclasses.dataclass 648 | class MyRecord: 649 | num_balloons: int 650 | color: str 651 | 652 | def get_config(): 653 | return MyRecord(num_balloons=99, color='red') 654 | ``` 655 | 656 | This is not: 657 | 658 | ```python {.good} 659 | def get_config(): 660 | @dataclasses.dataclass 661 | class MyRecord: 662 | num_balloons: int 663 | color: str 664 | 665 | return MyRecord(num_balloons=99, color='red') 666 | ``` 667 | 668 | #### Explanation 669 | 670 | A config file is a Python module but it is not imported through Python's usual 671 | module-importing mechanism. 672 | 673 | Meanwhile, serialization libraries such as [`cloudpickle`]( 674 | https://github.com/cloudpipe/cloudpickle#readme) (which is used by [Launchpad]( 675 | https://github.com/deepmind/launchpad#readme)) and [Apache Beam]( 676 | https://beam.apache.org/) expect to be able to pickle an object without also 677 | pickling every type to which it refers, on the assumption that types defined 678 | at module scope can later be reconstructed simply by re-importing the modules 679 | in which they are defined. 680 | 681 | That assumption does not hold for a type that is defined at module scope in a 682 | config file, because the config file can't be imported the usual way. The 683 | symptom of this will be an `ImportError` when unpickling an object. 684 | 685 | The treatment is to move types from module scope into `get_config()` so that 686 | they will be serialized along with the values that have those types. 
687 | 688 | ## Authors 689 | * Sergio Gómez Colmenarejo - sergomez@google.com 690 | * Wojciech Marian Czarnecki - lejlot@google.com 691 | * Nicholas Watters 692 | * Mohit Reddy - mohitreddy@google.com 693 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # ML Collections docs 2 | 3 | The ML Collections documentation can be found here: https://ml-collections.readthedocs.io/en/latest/ 4 | 5 | # How to build the docs 6 | 1. Install the requirements in ml_collections/docs/requirements.txt. 7 | 2. Ensure `pandoc` is installed. 8 | 3. Run `make html` to locally generate documentation. -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Configuration file for the Sphinx documentation builder. 
16 | # 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | # 27 | # import os 28 | # import sys 29 | # sys.path.insert(0, os.path.abspath('.')) 30 | 31 | import os 32 | import sys 33 | sys.path.insert(0, os.path.abspath('..')) 34 | 35 | # -- Project information ----------------------------------------------------- 36 | 37 | project = 'ml_collections' 38 | copyright = '2020, The ML Collection Authors' 39 | author = 'The ML Collection Authors' 40 | 41 | # The full version, including alpha/beta/rc tags 42 | release = '0.1.0' 43 | 44 | 45 | # -- General configuration --------------------------------------------------- 46 | 47 | # Add any Sphinx extension module names here, as strings. They can be 48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 49 | # ones. 50 | extensions = [ 51 | 'sphinx.ext.autodoc', 52 | 'sphinx.ext.autosummary', 53 | 'sphinx.ext.intersphinx', 54 | 'sphinx.ext.mathjax', 55 | 'sphinx.ext.napoleon', 56 | 'sphinx.ext.viewcode', 57 | 'nbsphinx', 58 | 'recommonmark', 59 | ] 60 | 61 | # Add any paths that contain templates here, relative to this directory. 62 | templates_path = ['_templates'] 63 | 64 | # List of patterns, relative to source directory, that match files and 65 | # directories to ignore when looking for source files. 66 | # This pattern also affects html_static_path and html_extra_path. 
67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 68 | 69 | autosummary_generate = True 70 | 71 | master_doc = 'index' 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | 75 | # The theme to use for HTML and HTML Help pages. See the documentation for 76 | # a list of builtin themes. 77 | # 78 | html_theme = 'sphinx_rtd_theme' 79 | 80 | # Add any paths that contain custom static files (such as style sheets) here, 81 | # relative to this directory. They are copied after the builtin static files, 82 | # so a file named "default.css" will overwrite the builtin "default.css". 83 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/config_dict.rst: -------------------------------------------------------------------------------- 1 | ml_collections.config_dict package 2 | ================================== 3 | 4 | .. currentmodule:: ml_collections.config_dict 5 | 6 | .. automodule:: ml_collections.config_dict 7 | 8 | ConfigDict class 9 | ---------------- 10 | .. autoclass:: ConfigDict 11 | :members: __init__, is_type_safe, convert_dict, lock, is_locked, unlock, 12 | get, get_oneway_ref, items, iteritems, keys, iterkeys, values, itervalues, 13 | eq_as_configdict, to_yaml, to_json, to_json_best_effort, to_dict, 14 | copy_and_resolve_references, unlocked, ignore_type, get_type, update, 15 | update_from_flattened_dict 16 | 17 | FrozenConfigDict class 18 | ---------------------- 19 | .. autoclass:: FrozenConfigDict 20 | :members: __init__ 21 | 22 | FieldReference class 23 | -------------------- 24 | .. autoclass:: FieldReference 25 | :members: __init__, has_cycle, set, empty, get, get_type, identity, to_int, 26 | to_float, to_str 27 | 28 | Additional Methods 29 | ------------------ 30 | .. 
autosummary:: 31 | :toctree: __autosummary 32 | 33 | create 34 | placeholder 35 | required_placeholder 36 | recursive_rename 37 | CustomJSONEncoder 38 | JSONDecodeError 39 | MutabilityError 40 | RequiredValueError -------------------------------------------------------------------------------- /docs/config_dict_examples.rst: -------------------------------------------------------------------------------- 1 | ml_collections.config_dict examples 2 | =================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | ConfigDict Basic 8 | ConfigDict Advanced 9 | ConfigDict Lock 10 | ConfigDict Placeholder 11 | Field Reference 12 | FrozenConfigDict -------------------------------------------------------------------------------- /docs/config_flags.rst: -------------------------------------------------------------------------------- 1 | ml_collections.config_flags package 2 | =================================== 3 | 4 | .. currentmodule:: ml_collections.config_flags 5 | 6 | .. automodule:: ml_collections.config_flags 7 | 8 | DEFINE_config_dict 9 | ------------------ 10 | .. automethod:: ml_collections.config_flags.DEFINE_config_dict 11 | 12 | DEFINE_config_file 13 | ------------------ 14 | .. automethod:: ml_collections.config_flags.DEFINE_config_file 15 | 16 | Additional Methods 17 | ------------------ 18 | .. autosummary:: 19 | :toctree: __autosummary 20 | 21 | ml_collections.config_flags.config_flags.is_config_flag 22 | ml_collections.config_flags.config_flags.GetValue 23 | ml_collections.config_flags.config_flags.GetType 24 | ml_collections.config_flags.config_flags.GetTypes 25 | ml_collections.config_flags.config_flags.SetValue -------------------------------------------------------------------------------- /docs/config_flags_examples.rst: -------------------------------------------------------------------------------- 1 | ml_collections.config_flags examples 2 | ==================================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | DEFINE_config_dict 8 | DEFINE_config_file -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. ml_collections documentation master file, created by 2 | sphinx-quickstart on Mon Aug 24 04:09:18 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to ml_collections's documentation! 7 | ========================================== 8 | 9 | ML Collections is a library of Python Collections designed for ML use cases. 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | :caption: Quickstart 14 | 15 | README 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Examples 20 | 21 | config_dict_examples 22 | config_flags_examples 23 | 24 | .. toctree:: 25 | :maxdepth: 2 26 | :caption: API Reference 27 | 28 | config_dict 29 | config_flags 30 | 31 | .. toctree:: 32 | :maxdepth: 1 33 | :caption: Additional material 34 | 35 | CONTRIBUTING 36 | testing -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | nbsphinx 4 | recommonmark -------------------------------------------------------------------------------- /docs/testing.md: -------------------------------------------------------------------------------- 1 | ## Testing 2 | 3 | TODO(mohitreddy): Add sections to install bazel and run tests. 4 | -------------------------------------------------------------------------------- /ml_collections/AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of ML Collections's significant contributors. 
2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | 8 | DeepMind Technologies Limited 9 | Google LLC -------------------------------------------------------------------------------- /ml_collections/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ML Collections is a library of Python collections designed for ML usecases.""" 16 | 17 | from ml_collections.config_dict import ConfigDict 18 | from ml_collections.config_dict import FieldReference 19 | from ml_collections.config_dict import FrozenConfigDict 20 | 21 | __all__ = ("ConfigDict", "FieldReference", "FrozenConfigDict") 22 | 23 | # A new PyPI release will be pushed every time `__version__` is increased. 24 | __version__ = "1.1.0" 25 | -------------------------------------------------------------------------------- /ml_collections/config_dict/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Classes for defining configurations of experiments and models.""" 16 | 17 | from .config_dict import _Op 18 | from .config_dict import ConfigDict 19 | from .config_dict import create 20 | from .config_dict import CustomJSONEncoder 21 | from .config_dict import FieldReference 22 | from .config_dict import FrozenConfigDict 23 | from .config_dict import JSONDecodeError 24 | from .config_dict import MutabilityError 25 | from .config_dict import placeholder 26 | from .config_dict import recursive_rename 27 | from .config_dict import required_placeholder 28 | from .config_dict import RequiredValueError 29 | 30 | __all__ = ("_Op", "ConfigDict", "create", "CustomJSONEncoder", "FieldReference", 31 | "FrozenConfigDict", "JSONDecodeError", "MutabilityError", 32 | "placeholder", "recursive_rename", "required_placeholder", 33 | "RequiredValueError") 34 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of a config file using ConfigDict. 16 | 17 | The idea of this configuration file is to show a typical use case of ConfigDict, 18 | as well as its limitations. This also exemplifies a self-referencing ConfigDict. 19 | """ 20 | 21 | import copy 22 | from ml_collections import config_dict 23 | 24 | 25 | def _get_flat_config(): 26 | """Helper to generate simple config without references.""" 27 | 28 | # The suggested way to create a ConfigDict() is to call its constructor 29 | # and assign all relevant fields. 30 | config = config_dict.ConfigDict() 31 | 32 | # In order to add new attributes you can just use . notation, like with any 33 | # python object. They will be tracked by ConfigDict, and you get type checking 34 | # etc. for free. 35 | config.integer = 23 36 | config.float = 2.34 37 | config.string = 'james' 38 | config.bool = True 39 | 40 | # It is possible to assign dictionaries to ConfigDict and they will be 41 | # automatically and recursively wrapped with ConfigDict. However, make sure 42 | # that the dict you are assigning does not use internal references/cycles as 43 | # this is not supported. Instead, create the dicts explicitly as demonstrated 44 | # by get_config(). But note that this operation makes an element-by-element 45 | # copy of your original dict. 46 | 47 | # Also note that the recursive wrapping on input dictionaries with ConfigDict 48 | # does not extend through non-dictionary types (including basic Python types 49 | # and custom classes). 
This causes unexpected behavior most commonly if a 50 | # value is a list of dictionaries, so avoid giving ConfigDict such inputs. 51 | config.dict = { 52 | 'integer': 1, 53 | 'float': 3.14, 54 | 'string': 'mark', 55 | 'bool': False, 56 | 'dict': { 57 | 'float': 5 58 | } 59 | } 60 | return config 61 | 62 | 63 | def get_config(): 64 | """Returns a ConfigDict instance describing a complex config. 65 | 66 | Returns: 67 | A ConfigDict instance with the structure: 68 | 69 | ``` 70 | CONFIG-+-- integer 71 | |-- float 72 | |-- string 73 | |-- bool 74 | |-- dict +-- integer 75 | | |-- float 76 | | |-- string 77 | | |-- bool 78 | | |-- dict +-- float 79 | | 80 | |-- object +-- integer 81 | | |-- float 82 | | |-- string 83 | | |-- bool 84 | | |-- dict +-- integer 85 | | |-- float 86 | | |-- string 87 | | |-- bool 88 | | |-- dict +-- float 89 | | 90 | |-- object_copy +-- integer 91 | | |-- float 92 | | |-- string 93 | | |-- bool 94 | | |-- dict +-- integer 95 | | |-- float 96 | | |-- string 97 | | |-- bool 98 | | |-- dict +-- float 99 | | 100 | |-- object_reference [reference pointing to CONFIG-+--object] 101 | ``` 102 | """ 103 | config = _get_flat_config() 104 | config.object = _get_flat_config() 105 | 106 | # References work just fine, so you will be able to override both 107 | # values at the same time. The rule is the same as for python objects, 108 | # everything that is mutable is passed as a reference, thus it will not work 109 | # with assigning integers or strings, but will work just fine with 110 | # ConfigDicts. 111 | # WARNING: Each time you assign a dictionary as a value it will create a new 112 | # instance of ConfigDict in memory, thus it will be a copy of the original 113 | # dict and not a reference to the original. 114 | config.object_reference = config.object 115 | 116 | # ConfigDict supports deepcopying. 
117 | config.object_copy = copy.deepcopy(config.object) 118 | 119 | return config 120 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config_dict_advanced.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of ConfigDict usage. 16 | 17 | This example includes loading a ConfigDict in FLAGS, locking it, type 18 | safety, iteration over fields, checking for a particular field, unpacking with 19 | `**`, and loading dictionary from string representation. 20 | """ 21 | 22 | from absl import app 23 | from ml_collections import config_flags 24 | import yaml 25 | 26 | _CONFIG = config_flags.DEFINE_config_file( 27 | 'my_config', 28 | default='ml_collections/config_dict/examples/config.py') 29 | 30 | 31 | def hello_function(string, **unused_kwargs): 32 | return 'Hello {}'.format(string) 33 | 34 | 35 | def print_section(name): 36 | print() 37 | print() 38 | print('-' * len(name)) 39 | print(name.upper()) 40 | print('-' * len(name)) 41 | print() 42 | 43 | 44 | def main(_): 45 | # Config is already loaded in FLAGS.my_config due to the logic hidden 46 | # in app.run(). 47 | config = _CONFIG.value 48 | 49 | print_section('Printing config.') 50 | print(config) 51 | 52 | # Config is of our type ConfigDict. 
53 | print('Type of the config {}'.format(type(config))) 54 | 55 | # By default it is locked, thus you cannot add new fields. 56 | # This prevents you from misspelling your attribute name. 57 | print_section('Locking.') 58 | print('config.is_locked={}'.format(config.is_locked)) 59 | try: 60 | config.object.new_field = -3 61 | except AttributeError as e: 62 | print(e) 63 | 64 | # There is also "did you mean" feature! 65 | try: 66 | config.object.floet = -3. 67 | except AttributeError as e: 68 | print(e) 69 | 70 | # However if you want to modify it you can always unlock. 71 | print_section('Unlocking.') 72 | with config.unlocked(): 73 | config.object.new_field = -3 74 | print('config.object.new_field={}'.format(config.object.new_field)) 75 | 76 | # By default config is also type-safe, so you cannot change the type of any 77 | # field. 78 | print_section('Type safety.') 79 | try: 80 | config.float = 'jerry' 81 | except TypeError as e: 82 | print(e) 83 | config.float = -1.2 84 | print('config.float={}'.format(config.float)) 85 | 86 | # NoneType is ignored by type safety and can both override and be overridden. 87 | config.float = None 88 | config.float = -1.2 89 | 90 | # You can temporarly turn type safety off. 91 | with config.ignore_type(): 92 | config.float = 'tom' 93 | print('config.float={}'.format(config.float)) 94 | config.float = 2.3 95 | print('config.float={}'.format(config.float)) 96 | 97 | # You can use ConfigDict as a regular dict in many typical use-cases: 98 | # Iteration over fields: 99 | print_section('Iteration over fields.') 100 | for field in config: 101 | print('config has field "{}"'.format(field)) 102 | 103 | # Checking if it contains a particular field using the "in" command. 
104 | print_section('Checking for a particular field.') 105 | for field in ('float', 'non_existing'): 106 | if field in config: 107 | print('"{}" is in config'.format(field)) 108 | else: 109 | print('"{}" is not in config'.format(field)) 110 | 111 | # Using ** unrolling to pass the config to a function as named arguments. 112 | print_section('Unpacking with **') 113 | print(hello_function(**config)) 114 | 115 | # You can even load a dictionary (notice it is not ConfigDict anymore) from 116 | # a yaml string representation of ConfigDict. 117 | # Note: __repr__ (not __str__) is the recommended representation, as it 118 | # preserves FieldReferences and placeholders. 119 | print_section('Loading dictionary from string representation.') 120 | dictionary = yaml.load(repr(config), yaml.UnsafeLoader) 121 | print('dict["object_reference"]["dict"]["dict"]["float"]={}'.format( 122 | dictionary['object_reference']['dict']['dict']['float'])) 123 | 124 | 125 | if __name__ == '__main__': 126 | app.run(main) 127 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config_dict_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of basic ConfigDict usage. 
16 | 17 | This example shows the most basic usage of ConfigDict, including type safety. 18 | For examples of more features, see config_dict_advanced. 19 | """ 20 | 21 | from absl import app 22 | from ml_collections import config_dict 23 | 24 | 25 | def main(_): 26 | cfg = config_dict.ConfigDict() 27 | cfg.float_field = 12.6 28 | cfg.integer_field = 123 29 | cfg.another_integer_field = 234 30 | cfg.nested = config_dict.ConfigDict() 31 | cfg.nested.string_field = 'tom' 32 | 33 | print(cfg.integer_field) # Prints 123. 34 | print(cfg['integer_field']) # Prints 123 as well. 35 | 36 | try: 37 | cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. 38 | except TypeError as e: 39 | print(e) 40 | 41 | cfg.float_field = 12 # Works: `int` types can be assigned to `float`. 42 | cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings. 43 | 44 | print(cfg) 45 | 46 | 47 | if __name__ == '__main__': 48 | app.run(main) 49 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config_dict_initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of initialization features and gotchas in a ConfigDict.
16 | """ 17 | 18 | import copy 19 | 20 | from absl import app 21 | from ml_collections import config_dict 22 | 23 | 24 | def print_section(name): 25 | print() 26 | print() 27 | print('-' * len(name)) 28 | print(name.upper()) 29 | print('-' * len(name)) 30 | print() 31 | 32 | 33 | def main(_): 34 | 35 | inner_dict = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))} 36 | example_dict = { 37 | 'string': 'tom', 38 | 'int': 2, 39 | 'list': [1, 2], 40 | 'set': {1, 2}, 41 | 'tuple': (1, 2), 42 | 'ref': config_dict.FieldReference({'int': 0}), 43 | 'inner_dict_1': inner_dict, 44 | 'inner_dict_2': inner_dict 45 | } 46 | 47 | print_section('Initializing on dictionary.') 48 | # ConfigDict can be initialized on example_dict 49 | example_cd = config_dict.ConfigDict(example_dict) 50 | 51 | # Dictionary fields are also converted to ConfigDict 52 | print(type(example_cd.inner_dict_1)) 53 | 54 | # And the reference structure is preserved 55 | print(id(example_cd.inner_dict_1) == id(example_cd.inner_dict_2)) 56 | 57 | print_section('Initializing on ConfigDict.') 58 | 59 | # ConfigDict can also be initialized on a ConfigDict 60 | example_cd_cd = config_dict.ConfigDict(example_cd) 61 | 62 | # Yielding the same result: 63 | print(example_cd == example_cd_cd) 64 | 65 | # Note that the memory addresses are different 66 | print(id(example_cd) == id(example_cd_cd)) 67 | 68 | # The memory addresses of the attributes are not the same because of the 69 | # FieldReference, which gets removed on the second initialization 70 | list_to_ids = lambda x: [id(element) for element in x] 71 | print( 72 | set(list_to_ids(list(example_cd.values()))) == set( 73 | list_to_ids(list(example_cd_cd.values())))) 74 | 75 | print_section('Initializing on self-referencing dictionary.') 76 | 77 | # Initialization works on a self-referencing dict 78 | self_ref_dict = copy.deepcopy(example_dict) 79 | self_ref_dict['self'] = self_ref_dict 80 | self_ref_cd = config_dict.ConfigDict(self_ref_dict) 81 | 82 | # And the 
reference structure is replicated 83 | print(id(self_ref_cd) == id(self_ref_cd.self)) 84 | 85 | print_section('Unexpected initialization behavior.') 86 | 87 | # ConfigDict initialization doesn't look inside lists, so doesn't convert a 88 | # dict in a list to ConfigDict 89 | dict_in_list_in_dict = {'list': [{'troublemaker': 0}]} 90 | dict_in_list_in_dict_cd = config_dict.ConfigDict(dict_in_list_in_dict) 91 | print(type(dict_in_list_in_dict_cd.list[0])) 92 | 93 | # This can cause the reference structure to not be replicated 94 | referred_dict = {'key': 'value'} 95 | bad_reference = {'referred_dict': referred_dict, 'list': [referred_dict]} 96 | bad_reference_cd = config_dict.ConfigDict(bad_reference) 97 | print(id(bad_reference_cd.referred_dict) == id(bad_reference_cd.list[0])) 98 | 99 | 100 | if __name__ == '__main__': 101 | app.run() 102 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config_dict_lock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of ConfigDict usage of lock. 16 | 17 | This example shows the roles and scopes of ConfigDict's lock(). 
18 | """ 19 | 20 | from absl import app 21 | from ml_collections import config_dict 22 | 23 | 24 | def main(_): 25 | cfg = config_dict.ConfigDict() 26 | cfg.integer_field = 123 27 | 28 | # Locking prohibits the addition and deletion of new fields but allows 29 | # modification of existing values. Locking happens automatically during 30 | # loading through flags. 31 | cfg.lock() 32 | try: 33 | cfg.intagar_field = 124 # Raises AttributeError and suggests valid field. 34 | except AttributeError as e: 35 | print(e) 36 | cfg.integer_field = -123 # Works fine. 37 | 38 | with cfg.unlocked(): 39 | cfg.intagar_field = 1555 # Works fine too. 40 | 41 | print(cfg) 42 | 43 | 44 | if __name__ == '__main__': 45 | app.run() 46 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/config_dict_placeholder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of placeholder fields in a ConfigDict. 16 | 17 | This example shows how ConfigDict placeholder fields work. For a more complete 18 | example of ConfigDict features, see example_advanced. 
19 | """ 20 | 21 | from absl import app 22 | from ml_collections import config_dict 23 | 24 | 25 | def main(_): 26 | placeholder = config_dict.FieldReference(0) 27 | cfg = config_dict.ConfigDict() 28 | cfg.placeholder = placeholder 29 | cfg.optional = config_dict.FieldReference(0, field_type=int) 30 | cfg.nested = config_dict.ConfigDict() 31 | cfg.nested.placeholder = placeholder 32 | 33 | try: 34 | cfg.optional = 'tom' # Raises Type error as this field is an integer. 35 | except TypeError as e: 36 | print(e) 37 | 38 | cfg.optional = 1555 # Works fine. 39 | cfg.placeholder = 1 # Changes the value of both placeholder and 40 | # nested.placeholder fields. 41 | 42 | # Note that the indirection provided by FieldReferences will be lost if 43 | # accessed through a ConfigDict: 44 | placeholder = config_dict.FieldReference(0) 45 | cfg.field1 = placeholder 46 | cfg.field2 = placeholder # This field will be tied to cfg.field1. 47 | cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. 48 | 49 | print(cfg) 50 | 51 | 52 | if __name__ == '__main__': 53 | app.run() 54 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/examples_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ConfigDict examples. 
16 | 17 | Ensures that config_dict_advanced, config_dict_basic, config_dict_initialization, 18 | config_dict_lock, config_dict_placeholder, field_reference, frozen_config_dict run successfully. 19 | """ 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from ml_collections.config_dict.examples import config_dict_advanced 24 | from ml_collections.config_dict.examples import config_dict_basic 25 | from ml_collections.config_dict.examples import config_dict_initialization 26 | from ml_collections.config_dict.examples import config_dict_lock 27 | from ml_collections.config_dict.examples import config_dict_placeholder 28 | from ml_collections.config_dict.examples import field_reference 29 | from ml_collections.config_dict.examples import frozen_config_dict 30 | 31 | 32 | class ConfigDictExamplesTest(parameterized.TestCase): 33 | 34 | @parameterized.parameters(config_dict_advanced, config_dict_basic, 35 | config_dict_initialization, config_dict_lock, 36 | config_dict_placeholder, field_reference, 37 | frozen_config_dict) 38 | def testScriptRuns(self, example_name): 39 | example_name.main(None) 40 | 41 | 42 | if __name__ == '__main__': 43 | absltest.main() 44 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/field_reference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of FieldReference usage. 16 | 17 | This shows how to use FieldReferences for lazy computation. 18 | """ 19 | 20 | from absl import app 21 | from ml_collections import config_dict 22 | 23 | 24 | def lazy_computation(): 25 | """Simple example of lazy computation with `configdict.FieldReference`.""" 26 | ref = config_dict.FieldReference(1) 27 | print(ref.get()) # Prints 1 28 | 29 | add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten 30 | add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer 31 | 32 | print(add_ten) # Prints 11 33 | print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 34 | 35 | # Addition is lazily computed for FieldReferences so changing ref will change 36 | # the value that is used to compute add_ten. 37 | ref.set(5) 38 | print(add_ten) # Prints 11 39 | print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 40 | 41 | 42 | def change_lazy_computation(): 43 | """Overriding lazily computed values.""" 44 | config = config_dict.ConfigDict() 45 | config.reference = 1 46 | config.reference_0 = config.get_ref('reference') + 10 47 | config.reference_1 = config.get_ref('reference') + 20 48 | config.reference_1_0 = config.get_ref('reference_1') + 100 49 | 50 | print(config.reference) # Prints 1. 51 | print(config.reference_0) # Prints 11. 52 | print(config.reference_1) # Prints 21. 53 | print(config.reference_1_0) # Prints 121. 54 | 55 | config.reference_1 = 30 56 | 57 | print(config.reference) # Prints 1 (unchanged). 58 | print(config.reference_0) # Prints 11 (unchanged). 59 | print(config.reference_1) # Prints 30. 60 | print(config.reference_1_0) # Prints 130. 
61 | 62 | 63 | def create_cycle(): 64 | """Creates a cycle within a ConfigDict.""" 65 | config = config_dict.ConfigDict() 66 | config.integer_field = 1 67 | config.bigger_integer_field = config.get_ref('integer_field') + 10 68 | 69 | try: 70 | # Raises a MutabilityError because setting config.integer_field would 71 | # cause a cycle. 72 | config.integer_field = config.get_ref('bigger_integer_field') + 2 73 | except config_dict.MutabilityError as e: 74 | print(e) 75 | 76 | 77 | def lazy_configdict(): 78 | """Example usage of lazy computation with ConfigDict.""" 79 | config = config_dict.ConfigDict() 80 | config.reference_field = config_dict.FieldReference(1) 81 | config.integer_field = 2 82 | config.float_field = 2.5 83 | 84 | # No lazy evaluatuations because we didn't use get_ref() 85 | config.no_lazy = config.integer_field * config.float_field 86 | 87 | # This will lazily evaluate ONLY config.integer_field 88 | config.lazy_integer = config.get_ref('integer_field') * config.float_field 89 | 90 | # This will lazily evaluate ONLY config.float_field 91 | config.lazy_float = config.integer_field * config.get_ref('float_field') 92 | 93 | # This will lazily evaluate BOTH config.integer_field and config.float_Field 94 | config.lazy_both = (config.get_ref('integer_field') * 95 | config.get_ref('float_field')) 96 | 97 | config.integer_field = 3 98 | print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value 99 | 100 | print(config.lazy_integer) # Prints 7.5 101 | 102 | config.float_field = 3.5 103 | print(config.lazy_float) # Prints 7.0 104 | print(config.lazy_both) # Prints 10.5 105 | 106 | 107 | def lazy_configdict_advanced(): 108 | """Advanced lazy computation with ConfigDict.""" 109 | # FieldReferences can be used with ConfigDict as well 110 | config = config_dict.ConfigDict() 111 | config.float_field = 12.6 112 | config.integer_field = 123 113 | config.list_field = [0, 1, 2] 114 | 115 | config.float_multiply_field = config.get_ref('float_field') * 3 
116 | print(config.float_multiply_field) # Prints 37.8 117 | 118 | config.float_field = 10.0 119 | print(config.float_multiply_field) # Prints 30.0 120 | 121 | config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] 122 | print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] 123 | 124 | config.list_field = [-1] 125 | print(config.longer_list_field) # Prints [-1, 3, 4, 5] 126 | 127 | # Both operands can be references 128 | config.ref_subtraction = ( 129 | config.get_ref('float_field') - config.get_ref('integer_field')) 130 | print(config.ref_subtraction) # Prints -113.0 131 | 132 | config.integer_field = 10 133 | print(config.ref_subtraction) # Prints 0.0 134 | 135 | 136 | def main(argv=()): 137 | del argv # Unused. 138 | lazy_computation() 139 | lazy_configdict() 140 | change_lazy_computation() 141 | create_cycle() 142 | lazy_configdict_advanced() 143 | 144 | 145 | if __name__ == '__main__': 146 | app.run() 147 | -------------------------------------------------------------------------------- /ml_collections/config_dict/examples/frozen_config_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of basic FrozenConfigDict usage. 
16 | 17 | This example shows the most basic usage of FrozenConfigDict, highlighting 18 | the differences between FrozenConfigDict and ConfigDict and including 19 | converting between the two. 20 | """ 21 | 22 | from absl import app 23 | from ml_collections import config_dict 24 | 25 | 26 | def print_section(name): 27 | print() 28 | print() 29 | print('-' * len(name)) 30 | print(name.upper()) 31 | print('-' * len(name)) 32 | print() 33 | 34 | 35 | def main(_): 36 | print_section('Attribute Types.') 37 | cfg = config_dict.ConfigDict() 38 | cfg.int = 1 39 | cfg.list = [1, 2, 3] 40 | cfg.tuple = (1, 2, 3) 41 | cfg.set = {1, 2, 3} 42 | cfg.frozenset = frozenset({1, 2, 3}) 43 | cfg.dict = { 44 | 'nested_int': 4, 45 | 'nested_list': [4, 5, 6], 46 | 'nested_tuple': ([4], 5, 6), 47 | } 48 | 49 | print('Types of cfg fields:') 50 | print('list: ', type(cfg.list)) # List 51 | print('set: ', type(cfg.set)) # Set 52 | print('nested_list: ', type(cfg.dict.nested_list)) # List 53 | print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0])) # List 54 | 55 | frozen_cfg = config_dict.FrozenConfigDict(cfg) 56 | print('\nTypes of FrozenConfigDict(cfg) fields:') 57 | print('list: ', type(frozen_cfg.list)) # Tuple 58 | print('set: ', type(frozen_cfg.set)) # Frozenset 59 | print('nested_list: ', type(frozen_cfg.dict.nested_list)) # Tuple 60 | print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0])) # Tuple 61 | 62 | cfg_from_frozen = config_dict.ConfigDict(frozen_cfg) 63 | print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:') 64 | print('list: ', type(cfg_from_frozen.list)) # List 65 | print('set: ', type(cfg_from_frozen.set)) # Set 66 | print('nested_list: ', type(cfg_from_frozen.dict.nested_list)) # List 67 | print('nested_tuple[0]: ', type(cfg_from_frozen.dict.nested_tuple[0])) # List 68 | 69 | print('\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:') 70 | print(cfg_from_frozen == frozen_cfg.as_configdict()) # True 71 | 72 |
print_section('Immutability.') 73 | try: 74 | frozen_cfg.new_field = 1 # Raises AttributeError because of immutability. 75 | except AttributeError as e: 76 | print(e) 77 | 78 | print_section('"==" and eq_as_configdict().') 79 | # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict 80 | print(frozen_cfg == cfg) # False 81 | # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to 82 | # ConfigDict 83 | print(frozen_cfg.eq_as_configdict(cfg)) # True 84 | # .eq_as_configdict() is also a method of ConfigDict 85 | print(cfg.eq_as_configdict(frozen_cfg)) # True 86 | 87 | 88 | if __name__ == '__main__': 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /ml_collections/config_dict/tests/field_reference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for config_dict.FieldReference.""" 16 | 17 | import operator 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from ml_collections import config_dict 22 | 23 | 24 | class FieldReferenceTest(parameterized.TestCase): 25 | 26 | def _test_binary_operator(self, 27 | initial_value, 28 | other_value, 29 | op, 30 | true_value, 31 | new_initial_value, 32 | new_true_value, 33 | assert_fn=None): 34 | """Helper for testing binary operators. 35 | 36 | Generally speaking this checks that: 37 | 1. `op(initial_value, other_value) COMP true_value` 38 | 2. `op(new_initial_value, other_value) COMP new_true_value` 39 | where `COMP` is the comparison function defined by `assert_fn`. 40 | 41 | Args: 42 | initial_value: Initial value for the `FieldReference`, this is the first 43 | argument for the binary operator. 44 | other_value: The second argument for the binary operator. 45 | op: The binary operator. 46 | true_value: The expected output of the binary operator. 47 | new_initial_value: The value that the `FieldReference` is changed to. 48 | new_true_value: The expected output of the binary operator after the 49 | `FieldReference` has changed. 50 | assert_fn: Function used to check the output values. 51 | """ 52 | if assert_fn is None: 53 | assert_fn = self.assertEqual 54 | 55 | ref = config_dict.FieldReference(initial_value) 56 | new_ref = op(ref, other_value) 57 | assert_fn(new_ref.get(), true_value) 58 | 59 | config = config_dict.ConfigDict() 60 | config.a = initial_value 61 | config.b = other_value 62 | config.result = op(config.get_ref('a'), config.b) 63 | assert_fn(config.result, true_value) 64 | 65 | config.a = new_initial_value 66 | assert_fn(config.result, new_true_value) 67 | 68 | def _test_unary_operator(self, 69 | initial_value, 70 | op, 71 | true_value, 72 | new_initial_value, 73 | new_true_value, 74 | assert_fn=None): 75 | """Helper for testing unary operators.
76 | 77 | Generally speaking this checks that: 78 | 1. `op(initial_value) COMP true_value` 79 | 2. `op(new_initial_value) COMP new_true_value` 80 | where `COMP` is the comparison function defined by `assert_fn`. 81 | 82 | Args: 83 | initial_value: Initial value for the `FieldReference`, this is the first 84 | argument for the unary operator. 85 | op: The unary operator. 86 | true_value: The expected output of the unary operator. 87 | new_initial_value: The value that the `FieldReference` is changed to. 88 | new_true_value: The expected output of the unary operator after the 89 | `FieldReference` has changed. 90 | assert_fn: Function used to check the output values. 91 | """ 92 | if assert_fn is None: 93 | assert_fn = self.assertEqual 94 | 95 | ref = config_dict.FieldReference(initial_value) 96 | new_ref = op(ref) 97 | assert_fn(new_ref.get(), true_value) 98 | 99 | config = config_dict.ConfigDict() 100 | config.a = initial_value 101 | config.result = op(config.get_ref('a')) 102 | assert_fn(config.result, true_value) 103 | 104 | config.a = new_initial_value 105 | assert_fn(config.result, new_true_value) 106 | 107 | def testBasic(self): 108 | ref = config_dict.FieldReference(1) 109 | self.assertEqual(ref.get(), 1) 110 | 111 | def testGetRef(self): 112 | config = config_dict.ConfigDict() 113 | config.a = 1. 114 | config.b = config.get_ref('a') + 10 115 | config.c = config.get_ref('b') + 10 116 | self.assertEqual(config.c, 21.0) 117 | 118 | def testFunction(self): 119 | 120 | def fn(x): 121 | return x + 5 122 | 123 | config = config_dict.ConfigDict() 124 | config.a = 1 125 | config.b = fn(config.get_ref('a')) 126 | config.c = fn(config.get_ref('b')) 127 | 128 | self.assertEqual(config.b, 6) 129 | self.assertEqual(config.c, 11) 130 | config.a = 2 131 | self.assertEqual(config.b, 7) 132 | self.assertEqual(config.c, 12) 133 | 134 | def testCycles(self): 135 | config = config_dict.ConfigDict() 136 | config.a = 1.
137 | config.b = config.get_ref('a') + 10 138 | config.c = config.get_ref('b') + 10 139 | 140 | self.assertEqual(config.b, 11.0) 141 | self.assertEqual(config.c, 21.0) 142 | 143 | # Introduce a cycle 144 | with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): 145 | config.a = config.get_ref('c') - 1.0 146 | 147 | # Introduce a cycle on second operand 148 | with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): 149 | config.a = config_dict.FieldReference(5.0) + config.get_ref('c') 150 | 151 | # We can create multiple FieldReferences that all point to the same object 152 | l = [0] 153 | config = config_dict.ConfigDict() 154 | config.a = l 155 | config.b = l 156 | config.c = config.get_ref('a') + ['c'] 157 | config.d = config.get_ref('b') + ['d'] 158 | 159 | self.assertEqual(config.c, [0, 'c']) 160 | self.assertEqual(config.d, [0, 'd']) 161 | 162 | # Make sure nothing was mutated 163 | self.assertEqual(l, [0]) 164 | self.assertEqual(config.c, [0, 'c']) 165 | 166 | config.a = [1] 167 | config.b = [2] 168 | self.assertEqual(l, [0]) 169 | self.assertEqual(config.c, [1, 'c']) 170 | self.assertEqual(config.d, [2, 'd']) 171 | 172 | @parameterized.parameters( 173 | { 174 | 'initial_value': 1, 175 | 'other_value': 2, 176 | 'true_value': 3, 177 | 'new_initial_value': 10, 178 | 'new_true_value': 12 179 | }, { 180 | 'initial_value': 2.0, 181 | 'other_value': 2.5, 182 | 'true_value': 4.5, 183 | 'new_initial_value': 3.7, 184 | 'new_true_value': 6.2 185 | }, { 186 | 'initial_value': 'hello, ', 187 | 'other_value': 'world!', 188 | 'true_value': 'hello, world!', 189 | 'new_initial_value': 'foo, ', 190 | 'new_true_value': 'foo, world!'
191 | }, { 192 | 'initial_value': ['hello'], 193 | 'other_value': ['world'], 194 | 'true_value': ['hello', 'world'], 195 | 'new_initial_value': ['foo'], 196 | 'new_true_value': ['foo', 'world'] 197 | }, { 198 | 'initial_value': config_dict.FieldReference(10), 199 | 'other_value': config_dict.FieldReference(5.0), 200 | 'true_value': 15.0, 201 | 'new_initial_value': 12, 202 | 'new_true_value': 17.0 203 | }, { 204 | 'initial_value': config_dict.placeholder(float), 205 | 'other_value': 7.0, 206 | 'true_value': None, 207 | 'new_initial_value': 12, 208 | 'new_true_value': 19.0 209 | }, { 210 | 'initial_value': 5.0, 211 | 'other_value': config_dict.placeholder(float), 212 | 'true_value': None, 213 | 'new_initial_value': 8.0, 214 | 'new_true_value': None 215 | }, { 216 | 'initial_value': config_dict.placeholder(str), 217 | 'other_value': 'tail', 218 | 'true_value': None, 219 | 'new_initial_value': 'head', 220 | 'new_true_value': 'headtail' 221 | }) 222 | def testAdd(self, initial_value, other_value, true_value, new_initial_value, 223 | new_true_value): 224 | self._test_binary_operator(initial_value, other_value, operator.add, 225 | true_value, new_initial_value, new_true_value) 226 | 227 | @parameterized.parameters( 228 | { 229 | 'initial_value': 5, 230 | 'other_value': 3, 231 | 'true_value': 2, 232 | 'new_initial_value': -1, 233 | 'new_true_value': -4 234 | }, { 235 | 'initial_value': 2.0, 236 | 'other_value': 2.5, 237 | 'true_value': -0.5, 238 | 'new_initial_value': 12.3, 239 | 'new_true_value': 9.8 240 | }, { 241 | 'initial_value': set(['hello', 123, 4.5]), 242 | 'other_value': set([123]), 243 | 'true_value': set(['hello', 4.5]), 244 | 'new_initial_value': set([123]), 245 | 'new_true_value': set([]) 246 | }, { 247 | 'initial_value': config_dict.FieldReference(10), 248 | 'other_value': config_dict.FieldReference(5.0), 249 | 'true_value': 5.0, 250 | 'new_initial_value': 12, 251 | 'new_true_value': 7.0 252 | }, { 253 | 'initial_value': config_dict.placeholder(float), 254 |
'other_value': 7.0, 255 | 'true_value': None, 256 | 'new_initial_value': 12, 257 | 'new_true_value': 5.0 258 | }) 259 | def testSub(self, initial_value, other_value, true_value, new_initial_value, 260 | new_true_value): 261 | self._test_binary_operator(initial_value, other_value, operator.sub, 262 | true_value, new_initial_value, new_true_value) 263 | 264 | @parameterized.parameters( 265 | { 266 | 'initial_value': 1, 267 | 'other_value': 2, 268 | 'true_value': 2, 269 | 'new_initial_value': 3, 270 | 'new_true_value': 6 271 | }, { 272 | 'initial_value': 2.0, 273 | 'other_value': 2.5, 274 | 'true_value': 5.0, 275 | 'new_initial_value': 3.5, 276 | 'new_true_value': 8.75 277 | }, { 278 | 'initial_value': ['hello'], 279 | 'other_value': 3, 280 | 'true_value': ['hello', 'hello', 'hello'], 281 | 'new_initial_value': ['foo'], 282 | 'new_true_value': ['foo', 'foo', 'foo'] 283 | }, { 284 | 'initial_value': config_dict.FieldReference(10), 285 | 'other_value': config_dict.FieldReference(5.0), 286 | 'true_value': 50.0, 287 | 'new_initial_value': 1, 288 | 'new_true_value': 5.0 289 | }, { 290 | 'initial_value': config_dict.placeholder(float), 291 | 'other_value': 7.0, 292 | 'true_value': None, 293 | 'new_initial_value': 12, 294 | 'new_true_value': 84.0 295 | }) 296 | def testMul(self, initial_value, other_value, true_value, new_initial_value, 297 | new_true_value): 298 | self._test_binary_operator(initial_value, other_value, operator.mul, 299 | true_value, new_initial_value, new_true_value) 300 | 301 | @parameterized.parameters( 302 | { 303 | 'initial_value': 3, 304 | 'other_value': 2, 305 | 'true_value': 1.5, 306 | 'new_initial_value': 10, 307 | 'new_true_value': 5.0 308 | }, { 309 | 'initial_value': 2.0, 310 | 'other_value': 2.5, 311 | 'true_value': 0.8, 312 | 'new_initial_value': 6.3, 313 | 'new_true_value': 2.52 314 | }, { 315 | 'initial_value': config_dict.FieldReference(10), 316 | 'other_value': config_dict.FieldReference(5.0), 317 | 'true_value': 2.0, 318 |
'new_initial_value': 13, 319 | 'new_true_value': 2.6 320 | }, { 321 | 'initial_value': config_dict.placeholder(float), 322 | 'other_value': 7.0, 323 | 'true_value': None, 324 | 'new_initial_value': 17.5, 325 | 'new_true_value': 2.5 326 | }) 327 | def testTrueDiv(self, initial_value, other_value, true_value, 328 | new_initial_value, new_true_value): 329 | self._test_binary_operator(initial_value, other_value, operator.truediv, 330 | true_value, new_initial_value, new_true_value) 331 | 332 | @parameterized.parameters( 333 | { 334 | 'initial_value': 3, 335 | 'other_value': 2, 336 | 'true_value': 1, 337 | 'new_initial_value': 7, 338 | 'new_true_value': 3 339 | }, { 340 | 'initial_value': config_dict.FieldReference(10), 341 | 'other_value': config_dict.FieldReference(5), 342 | 'true_value': 2, 343 | 'new_initial_value': 28, 344 | 'new_true_value': 5 345 | }, { 346 | 'initial_value': config_dict.placeholder(int), 347 | 'other_value': 7, 348 | 'true_value': None, 349 | 'new_initial_value': 25, 350 | 'new_true_value': 3 351 | }) 352 | def testFloorDiv(self, initial_value, other_value, true_value, 353 | new_initial_value, new_true_value): 354 | self._test_binary_operator(initial_value, other_value, operator.floordiv, 355 | true_value, new_initial_value, new_true_value) 356 | 357 | @parameterized.parameters( 358 | { 359 | 'initial_value': 3, 360 | 'other_value': 2, 361 | 'true_value': 9, 362 | 'new_initial_value': 10, 363 | 'new_true_value': 100 364 | }, { 365 | 'initial_value': 2.7, 366 | 'other_value': 3.2, 367 | 'true_value': 24.0084457245, 368 | 'new_initial_value': 6.5, 369 | 'new_true_value': 399.321543621 370 | }, { 371 | 'initial_value': config_dict.FieldReference(10), 372 | 'other_value': config_dict.FieldReference(5), 373 | 'true_value': 1e5, 374 | 'new_initial_value': 2, 375 | 'new_true_value': 32 376 | }, { 377 | 'initial_value': config_dict.placeholder(float), 378 | 'other_value': 3.0, 379 | 'true_value': None, 380 | 'new_initial_value': 7.0, 381 |
'new_true_value': 343.0 382 | }) 383 | def testPow(self, initial_value, other_value, true_value, new_initial_value, 384 | new_true_value): 385 | self._test_binary_operator( 386 | initial_value, 387 | other_value, 388 | operator.pow, 389 | true_value, 390 | new_initial_value, 391 | new_true_value, 392 | assert_fn=self.assertAlmostEqual) 393 | 394 | @parameterized.parameters( 395 | { 396 | 'initial_value': 3, 397 | 'other_value': 2, 398 | 'true_value': 1, 399 | 'new_initial_value': 10, 400 | 'new_true_value': 0 401 | }, { 402 | 'initial_value': 5.3, 403 | 'other_value': 3.2, 404 | 'true_value': 2.0999999999999996, 405 | 'new_initial_value': 77, 406 | 'new_true_value': 0.2 407 | }, { 408 | 'initial_value': config_dict.FieldReference(10), 409 | 'other_value': config_dict.FieldReference(5), 410 | 'true_value': 0, 411 | 'new_initial_value': 32, 412 | 'new_true_value': 2 413 | }, { 414 | 'initial_value': config_dict.placeholder(int), 415 | 'other_value': 7, 416 | 'true_value': None, 417 | 'new_initial_value': 25, 418 | 'new_true_value': 4 419 | }) 420 | def testMod(self, initial_value, other_value, true_value, new_initial_value, 421 | new_true_value): 422 | self._test_binary_operator( 423 | initial_value, 424 | other_value, 425 | operator.mod, 426 | true_value, 427 | new_initial_value, 428 | new_true_value, 429 | assert_fn=self.assertAlmostEqual) 430 | 431 | @parameterized.parameters( 432 | { 433 | 'initial_value': True, 434 | 'other_value': True, 435 | 'true_value': True, 436 | 'new_initial_value': False, 437 | 'new_true_value': False 438 | }, { 439 | 'initial_value': config_dict.FieldReference(False), 440 | 'other_value': config_dict.FieldReference(False), 441 | 'true_value': False, 442 | 'new_initial_value': True, 443 | 'new_true_value': False 444 | }, { 445 | 'initial_value': config_dict.placeholder(bool), 446 | 'other_value': True, 447 | 'true_value': None, 448 | 'new_initial_value': False, 449 | 'new_true_value': False 450 | }) 451 | def testAnd(self, initial_value,
other_value, true_value, new_initial_value, 452 | new_true_value): 453 | self._test_binary_operator(initial_value, other_value, operator.and_, 454 | true_value, new_initial_value, new_true_value) 455 | 456 | @parameterized.parameters( 457 | { 458 | 'initial_value': False, 459 | 'other_value': False, 460 | 'true_value': False, 461 | 'new_initial_value': True, 462 | 'new_true_value': True 463 | }, { 464 | 'initial_value': config_dict.FieldReference(True), 465 | 'other_value': config_dict.FieldReference(True), 466 | 'true_value': True, 467 | 'new_initial_value': False, 468 | 'new_true_value': True 469 | }, { 470 | 'initial_value': config_dict.placeholder(bool), 471 | 'other_value': False, 472 | 'true_value': None, 473 | 'new_initial_value': True, 474 | 'new_true_value': True 475 | }) 476 | def testOr(self, initial_value, other_value, true_value, new_initial_value, 477 | new_true_value): 478 | self._test_binary_operator(initial_value, other_value, operator.or_, 479 | true_value, new_initial_value, new_true_value) 480 | 481 | @parameterized.parameters( 482 | { 483 | 'initial_value': False, 484 | 'other_value': True, 485 | 'true_value': True, 486 | 'new_initial_value': True, 487 | 'new_true_value': False 488 | }, { 489 | 'initial_value': config_dict.FieldReference(True), 490 | 'other_value': config_dict.FieldReference(True), 491 | 'true_value': False, 492 | 'new_initial_value': False, 493 | 'new_true_value': True 494 | }, { 495 | 'initial_value': config_dict.placeholder(bool), 496 | 'other_value': True, 497 | 'true_value': None, 498 | 'new_initial_value': True, 499 | 'new_true_value': False 500 | }) 501 | def testXor(self, initial_value, other_value, true_value, new_initial_value, 502 | new_true_value): 503 | self._test_binary_operator(initial_value, other_value, operator.xor, 504 | true_value, new_initial_value, new_true_value) 505 | 506 | @parameterized.parameters( 507 | { 508 | 'initial_value': 3, 509 | 'true_value': -3, 510 | 'new_initial_value': -22, 511 | 
'new_true_value': 22 512 | }, { 513 | 'initial_value': 15.3, 514 | 'true_value': -15.3, 515 | 'new_initial_value': -0.2, 516 | 'new_true_value': 0.2 517 | }, { 518 | 'initial_value': config_dict.FieldReference(7), 519 | 'true_value': config_dict.FieldReference(-7), 520 | 'new_initial_value': 123, 521 | 'new_true_value': -123 522 | }, { 523 | 'initial_value': config_dict.placeholder(int), 524 | 'true_value': None, 525 | 'new_initial_value': -6, 526 | 'new_true_value': 6 527 | }) 528 | def testNeg(self, initial_value, true_value, new_initial_value, 529 | new_true_value): 530 | self._test_unary_operator(initial_value, operator.neg, true_value, 531 | new_initial_value, new_true_value) 532 | 533 | @parameterized.parameters( 534 | { 535 | 'initial_value': config_dict.create(attribute=2), 536 | 'true_value': 2, 537 | 'new_initial_value': config_dict.create(attribute=3), 538 | 'new_true_value': 3, 539 | }, 540 | { 541 | 'initial_value': config_dict.create(attribute={'a': 1}), 542 | 'true_value': config_dict.create(a=1), 543 | 'new_initial_value': config_dict.create(attribute={'b': 1}), 544 | 'new_true_value': config_dict.create(b=1), 545 | }, 546 | { 547 | 'initial_value': 548 | config_dict.FieldReference(config_dict.create(attribute=2)), 549 | 'true_value': 550 | config_dict.FieldReference(2), 551 | 'new_initial_value': 552 | config_dict.create(attribute=3), 553 | 'new_true_value': 554 | 3, 555 | }, 556 | { 557 | 'initial_value': config_dict.placeholder(config_dict.ConfigDict), 558 | 'true_value': None, 559 | 'new_initial_value': config_dict.create(attribute=3), 560 | 'new_true_value': 3, 561 | }, 562 | ) 563 | def testAttr(self, initial_value, true_value, new_initial_value, 564 | new_true_value): 565 | self._test_unary_operator(initial_value, lambda x: x.attr('attribute'), 566 | true_value, new_initial_value, new_true_value) 567 | 568 | @parameterized.parameters( 569 | { 570 | 'initial_value': 3, 571 | 'true_value': 3, 572 | 'new_initial_value': -101, 573 | 
'new_true_value': 101 574 | }, { 575 | 'initial_value': -15.3, 576 | 'true_value': 15.3, 577 | 'new_initial_value': 7.3, 578 | 'new_true_value': 7.3 579 | }, { 580 | 'initial_value': config_dict.FieldReference(-7), 581 | 'true_value': config_dict.FieldReference(7), 582 | 'new_initial_value': 3, 583 | 'new_true_value': 3 584 | }, { 585 | 'initial_value': config_dict.placeholder(float), 586 | 'true_value': None, 587 | 'new_initial_value': -6.25, 588 | 'new_true_value': 6.25 589 | }) 590 | def testAbs(self, initial_value, true_value, new_initial_value, 591 | new_true_value): 592 | self._test_unary_operator(initial_value, operator.abs, true_value, 593 | new_initial_value, new_true_value) 594 | 595 | def testToInt(self): 596 | self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27) 597 | ref = config_dict.FieldReference(64.7) 598 | ref = ref.to_int() 599 | self.assertEqual(ref.get(), 64) 600 | self.assertEqual(ref._field_type, int) 601 | 602 | def testToFloat(self): 603 | self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0) 604 | 605 | ref = config_dict.FieldReference(647) 606 | ref = ref.to_float() 607 | self.assertEqual(ref.get(), 647.0) 608 | self.assertEqual(ref._field_type, float) 609 | 610 | def testToString(self): 611 | self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0') 612 | 613 | ref = config_dict.FieldReference(647) 614 | ref = ref.to_str() 615 | self.assertEqual(ref.get(), '647') 616 | self.assertEqual(ref._field_type, str) 617 | 618 | def testSetValue(self): 619 | ref = config_dict.FieldReference(1.0) 620 | other = config_dict.FieldReference(3) 621 | ref_plus_other = ref + other 622 | 623 | self.assertEqual(ref_plus_other.get(), 4.0) 624 | 625 | ref.set(2.5) 626 | self.assertEqual(ref_plus_other.get(), 5.5) 627 | 628 | other.set(110) 629 | self.assertEqual(ref_plus_other.get(), 112.5) 630 | 631 | # Type checking 632 | with self.assertRaises(TypeError): 633 | other.set('this is a string') 634 | 635 | with 
self.assertRaises(TypeError): 636 | other.set(config_dict.FieldReference('this is a string')) 637 | 638 | with self.assertRaises(TypeError): 639 | other.set(config_dict.FieldReference(None, field_type=str)) 640 | 641 | def testSetValueSubclass(self): 642 | class A: 643 | pass 644 | 645 | class B(A): 646 | pass 647 | 648 | ref = config_dict.FieldReference(A()) 649 | ref.set(B()) # Can assign subclasses 650 | 651 | def testSetResult(self): 652 | ref = config_dict.FieldReference(1.0) 653 | result = ref + 1.0 654 | second_result = result + 1.0 655 | 656 | self.assertEqual(ref.get(), 1.0) 657 | self.assertEqual(result.get(), 2.0) 658 | self.assertEqual(second_result.get(), 3.0) 659 | 660 | ref.set(2.0) 661 | self.assertEqual(ref.get(), 2.0) 662 | self.assertEqual(result.get(), 3.0) 663 | self.assertEqual(second_result.get(), 4.0) 664 | 665 | result.set(4.0) 666 | self.assertEqual(ref.get(), 2.0) 667 | self.assertEqual(result.get(), 4.0) 668 | self.assertEqual(second_result.get(), 5.0) 669 | 670 | # All references are broken at this point. 
671 | ref.set(1.0) 672 | self.assertEqual(ref.get(), 1.0) 673 | self.assertEqual(result.get(), 4.0) 674 | self.assertEqual(second_result.get(), 5.0) 675 | 676 | def testTypeChecking(self): 677 | ref = config_dict.FieldReference(1) 678 | string_ref = config_dict.FieldReference('a') 679 | 680 | x = ref + string_ref 681 | with self.assertRaises(TypeError): 682 | x.get() 683 | 684 | def testNoType(self): 685 | self.assertRaisesRegex(TypeError, 'field_type should be a type.*', 686 | config_dict.FieldReference, None, 0) 687 | 688 | def testEqual(self): 689 | # Simple case 690 | ref1 = config_dict.FieldReference(1) 691 | ref2 = config_dict.FieldReference(1) 692 | ref3 = config_dict.FieldReference(2) 693 | self.assertEqual(ref1, 1) 694 | self.assertEqual(ref1, ref1) 695 | self.assertEqual(ref1, ref2) 696 | self.assertNotEqual(ref1, 2) 697 | self.assertNotEqual(ref1, ref3) 698 | 699 | # ConfigDict inside FieldReference 700 | ref1 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1})) 701 | ref2 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1})) 702 | ref3 = config_dict.FieldReference(config_dict.ConfigDict({'a': 2})) 703 | self.assertEqual(ref1, config_dict.ConfigDict({'a': 1})) 704 | self.assertEqual(ref1, ref1) 705 | self.assertEqual(ref1, ref2) 706 | self.assertNotEqual(ref1, config_dict.ConfigDict({'a': 2})) 707 | self.assertNotEqual(ref1, ref3) 708 | 709 | def testLessEqual(self): 710 | # Simple case 711 | ref1 = config_dict.FieldReference(1) 712 | ref2 = config_dict.FieldReference(1) 713 | ref3 = config_dict.FieldReference(2) 714 | self.assertLessEqual(ref1, 1) 715 | self.assertLessEqual(ref1, 2) 716 | self.assertLessEqual(0, ref1) 717 | self.assertLessEqual(1, ref1) 718 | self.assertGreater(ref1, 0) 719 | 720 | self.assertLessEqual(ref1, ref1) 721 | self.assertLessEqual(ref1, ref2) 722 | self.assertLessEqual(ref1, ref3) 723 | self.assertGreater(ref3, ref1) 724 | 725 | def testControlFlowError(self): 726 | ref1 = config_dict.FieldReference(True) 
727 | ref2 = config_dict.FieldReference(False) 728 | 729 | with self.assertRaises(NotImplementedError): 730 | if ref1: 731 | pass 732 | with self.assertRaises(NotImplementedError): 733 | _ = ref1 and ref2 734 | with self.assertRaises(NotImplementedError): 735 | _ = ref1 or ref2 736 | with self.assertRaises(NotImplementedError): 737 | _ = not ref1 738 | 739 | 740 | if __name__ == '__main__': 741 | absltest.main() 742 | -------------------------------------------------------------------------------- /ml_collections/config_dict/tests/frozen_config_dict_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for config_dict.FrozenConfigDict.""" 16 | 17 | from collections import abc as collections_abc 18 | import copy 19 | import pickle 20 | 21 | from absl.testing import absltest 22 | from ml_collections import config_dict 23 | 24 | _TEST_DICT = { 25 | 'int': 2, 26 | 'list': [1, 2], 27 | 'nested_list': [[1, [2]]], 28 | 'set': {1, 2}, 29 | 'tuple': (1, 2), 30 | 'frozenset': frozenset({1, 2}), 31 | 'dict': { 32 | 'float': -1.23, 33 | 'list': [1, 2], 34 | 'dict': {}, 35 | 'tuple_containing_list': (1, 2, (3, [4, 5], (6, 7))), 36 | 'list_containing_tuple': [1, 2, [3, 4], (5, 6)], 37 | }, 38 | 'ref': config_dict.FieldReference({'int': 0}) 39 | } 40 | 41 | 42 | def _test_dict_deepcopy(): 43 | return copy.deepcopy(_TEST_DICT)  # Independent copy, safe to mutate. 44 | 45 | 46 | def _test_configdict(): 47 | return config_dict.ConfigDict(_TEST_DICT)  # Built from _TEST_DICT itself, not a copy. 48 | 49 | 50 | def _test_frozenconfigdict(): 51 | return config_dict.FrozenConfigDict(_TEST_DICT)  # Not a copy: testInitCopying compares ids against _TEST_DICT. 52 | 53 | 54 | class FrozenConfigDictTest(absltest.TestCase): 55 | """Tests FrozenConfigDict in the config_dict library.""" 56 | 57 | def assertFrozenRaisesValueError(self, input_list): 58 | """Asserts FrozenConfigDict(x) raises ValueError for every x in input_list.""" 59 | for initial_dictionary in input_list: 60 | with self.assertRaises(ValueError): 61 | _ = config_dict.FrozenConfigDict(initial_dictionary) 62 | 63 | def testBasicEquality(self): 64 | """Tests basic equality with different types of initialization.""" 65 | fcd = _test_frozenconfigdict() 66 | fcd_cd = config_dict.FrozenConfigDict(_test_configdict()) 67 | fcd_fcd = config_dict.FrozenConfigDict(fcd) 68 | self.assertEqual(fcd, fcd_cd) 69 | self.assertEqual(fcd, fcd_fcd) 70 | 71 | def testImmutability(self): 72 | """Tests immutability of frozen config.""" 73 | fcd = _test_frozenconfigdict() 74 | self.assertEqual(fcd.list, tuple(_TEST_DICT['list'])) 75 | self.assertEqual(fcd.tuple, _TEST_DICT['tuple']) 76 | self.assertEqual(fcd.set, frozenset(_TEST_DICT['set'])) 77 | self.assertEqual(fcd.frozenset,
_TEST_DICT['frozenset']) 78 | # Must manually check set to frozenset conversion, since set == frozenset in Python. 79 | self.assertIsInstance(fcd.set, frozenset) 80 | 81 | self.assertEqual(fcd.dict.list, tuple(_TEST_DICT['dict']['list'])) 82 | self.assertNotEqual(fcd.dict.tuple_containing_list, 83 | _TEST_DICT['dict']['tuple_containing_list']) 84 | self.assertEqual(fcd.dict.tuple_containing_list[2][1], 85 | tuple(_TEST_DICT['dict']['tuple_containing_list'][2][1])) 86 | self.assertIsInstance(fcd.dict, config_dict.FrozenConfigDict) 87 | 88 | with self.assertRaises(AttributeError): 89 | fcd.newitem = 0 90 | with self.assertRaises(AttributeError): 91 | fcd.dict.int = 0 92 | with self.assertRaises(AttributeError): 93 | fcd['newitem'] = 0 94 | with self.assertRaises(AttributeError): 95 | del fcd.int 96 | with self.assertRaises(AttributeError): 97 | del fcd['int'] 98 | 99 | def testLockAndFreeze(self): 100 | """Ensures fcd is unlocked and lock/unlock/freeze/unfreeze raise errors.""" 101 | fcd = _test_frozenconfigdict() 102 | 103 | self.assertFalse(fcd.is_locked) 104 | self.assertFalse(fcd.as_configdict().is_locked) 105 | 106 | with self.assertRaises(AttributeError): 107 | fcd.lock() 108 | with self.assertRaises(AttributeError): 109 | fcd.unlock() 110 | with self.assertRaises(AttributeError): 111 | fcd.freeze() 112 | with self.assertRaises(AttributeError): 113 | fcd.unfreeze() 114 | 115 | def testInitConfigDict(self): 116 | """Tests that ConfigDict initialization handles FrozenConfigDict. 117 | 118 | Initializing a ConfigDict on a dictionary with FrozenConfigDict values 119 | should unfreeze these values.
120 | """ 121 | dict_without_fcd_node = _test_dict_deepcopy() 122 | dict_without_fcd_node.pop('ref') 123 | dict_with_fcd_node = copy.deepcopy(dict_without_fcd_node) 124 | dict_with_fcd_node['dict'] = config_dict.FrozenConfigDict( 125 | dict_with_fcd_node['dict']) 126 | cd_without_fcd_node = config_dict.ConfigDict(dict_without_fcd_node) 127 | cd_with_fcd_node = config_dict.ConfigDict(dict_with_fcd_node) 128 | fcd_without_fcd_node = config_dict.FrozenConfigDict( 129 | dict_without_fcd_node) 130 | fcd_with_fcd_node = config_dict.FrozenConfigDict(dict_with_fcd_node) 131 | 132 | self.assertEqual(cd_without_fcd_node, cd_with_fcd_node) 133 | self.assertEqual(fcd_without_fcd_node, fcd_with_fcd_node) 134 | 135 | def testInitCopying(self): 136 | """Tests that initialization copies when and only when necessary. 137 | 138 | Ensures copying only occurs when converting mutable type to immutable type, 139 | regardless of whether the FrozenConfigDict is initialized by a dict or a 140 | FrozenConfigDict. Also ensures no copying occurs when converting from 141 | FrozenConfigDict back to ConfigDict.
142 | """ 143 | fcd = _test_frozenconfigdict() 144 | 145 | # These should be uncopied when creating fcd 146 | fcd_unchanged_from_test_dict = [ 147 | (_TEST_DICT['tuple'], fcd.tuple), 148 | (_TEST_DICT['frozenset'], fcd.frozenset), 149 | (_TEST_DICT['dict']['tuple_containing_list'][2][2], 150 | fcd.dict.tuple_containing_list[2][2]), 151 | (_TEST_DICT['dict']['list_containing_tuple'][3], 152 | fcd.dict.list_containing_tuple[3]) 153 | ] 154 | 155 | # These should be copied when creating fcd 156 | fcd_different_from_test_dict = [ 157 | (_TEST_DICT['list'], fcd.list), 158 | (_TEST_DICT['dict']['tuple_containing_list'][2][1], 159 | fcd.dict.tuple_containing_list[2][1]) 160 | ] 161 | 162 | for (x, y) in fcd_unchanged_from_test_dict: 163 | self.assertEqual(id(x), id(y)) 164 | for (x, y) in fcd_different_from_test_dict: 165 | self.assertNotEqual(id(x), id(y)) 166 | 167 | # Also make sure that converting back to ConfigDict makes no copies 168 | self.assertEqual( 169 | id(_TEST_DICT['dict']['tuple_containing_list']), 170 | id(config_dict.ConfigDict(fcd).dict.tuple_containing_list)) 171 | 172 | def testAsConfigDict(self): 173 | """Tests that converting FrozenConfigDict to ConfigDict works correctly. 174 | 175 | In particular, ensures that FrozenConfigDict does the inverse of ConfigDict 176 | regarding type_safe, lock, and attribute mutability.
177 | """ 178 | # First ensure conversion to ConfigDict works on empty FrozenConfigDict 179 | self.assertEqual( 180 | config_dict.ConfigDict(config_dict.FrozenConfigDict()), 181 | config_dict.ConfigDict()) 182 | 183 | cd = _test_configdict() 184 | cd_fcd_cd = config_dict.ConfigDict(config_dict.FrozenConfigDict(cd)) 185 | self.assertEqual(cd, cd_fcd_cd) 186 | 187 | # Make sure locking is respected 188 | cd.lock() 189 | self.assertEqual( 190 | cd, config_dict.ConfigDict(config_dict.FrozenConfigDict(cd))) 191 | 192 | # Make sure type_safe is respected 193 | cd = config_dict.ConfigDict(_TEST_DICT, type_safe=False) 194 | self.assertEqual( 195 | cd, config_dict.ConfigDict(config_dict.FrozenConfigDict(cd))) 196 | 197 | def testInitSelfReferencing(self): 198 | """Ensure initialization fails on self-referencing dicts.""" 199 | self_ref = {} 200 | self_ref['self'] = self_ref 201 | parent_ref = {'dict': {}} 202 | parent_ref['dict']['parent'] = parent_ref 203 | tuple_parent_ref = {'dict': {}} 204 | tuple_parent_ref['dict']['tuple'] = (1, 2, tuple_parent_ref) 205 | attribute_cycle = {'dict': copy.deepcopy(self_ref)} 206 | 207 | self.assertFrozenRaisesValueError( 208 | [self_ref, parent_ref, tuple_parent_ref, attribute_cycle]) 209 | 210 | def testInitCycles(self): 211 | """Ensure initialization fails if an attribute of input is cyclic.""" 212 | inner_cyclic_list = [1, 2] 213 | cyclic_list = [3, inner_cyclic_list] 214 | inner_cyclic_list.append(cyclic_list) 215 | cyclic_tuple = tuple(cyclic_list) 216 | 217 | test_dict_cyclic_list = _test_dict_deepcopy() 218 | test_dict_cyclic_tuple = _test_dict_deepcopy() 219 | 220 | test_dict_cyclic_list['cyclic_list'] = cyclic_list 221 | test_dict_cyclic_tuple['dict']['cyclic_tuple'] = cyclic_tuple 222 | 223 | self.assertFrozenRaisesValueError( 224 | [test_dict_cyclic_list, test_dict_cyclic_tuple]) 225 | 226 | def testInitDictInList(self): 227 | """Ensure initialization fails on dict and ConfigDict in lists/tuples.""" 228 | list_containing_dict
= {'list': [1, 2, 3, {'a': 4, 'b': 5}]} 229 | tuple_containing_dict = {'tuple': (1, 2, 3, {'a': 4, 'b': 5})} 230 | list_containing_cd = {'list': [1, 2, 3, _test_configdict()]} 231 | tuple_containing_cd = {'tuple': (1, 2, 3, _test_configdict())} 232 | fr_containing_list_containing_dict = { 233 | 'fr': config_dict.FieldReference([1, { 234 | 'a': 2 235 | }]) 236 | } 237 | 238 | self.assertFrozenRaisesValueError([ 239 | list_containing_dict, tuple_containing_dict, list_containing_cd, 240 | tuple_containing_cd, fr_containing_list_containing_dict 241 | ]) 242 | 243 | def testInitFieldReferenceInList(self): 244 | """Ensure initialization fails on FieldReferences in lists/tuples.""" 245 | list_containing_fr = {'list': [1, 2, 3, config_dict.FieldReference(4)]} 246 | tuple_containing_fr = { 247 | 'tuple': (1, 2, 3, config_dict.FieldReference('a')) 248 | } 249 | 250 | self.assertFrozenRaisesValueError([list_containing_fr, tuple_containing_fr]) 251 | 252 | def testInitInvalidAttributeName(self): 253 | """Ensure initialization fails on attributes with invalid names.""" 254 | dot_name = {'dot.name': None} 255 | immutable_name = {'__hash__': None} 256 | 257 | with self.assertRaises(ValueError): 258 | config_dict.FrozenConfigDict(dot_name) 259 | 260 | with self.assertRaises(AttributeError): 261 | config_dict.FrozenConfigDict(immutable_name) 262 | 263 | def testFieldReferenceResolved(self): 264 | """Tests that FieldReferences are resolved.""" 265 | cfg = config_dict.ConfigDict({'fr': config_dict.FieldReference(1)}) 266 | frozen_cfg = config_dict.FrozenConfigDict(cfg) 267 | self.assertNotIsInstance(frozen_cfg._fields['fr'], 268 | config_dict.FieldReference) 269 | hash(frozen_cfg) # with FieldReference resolved, frozen_cfg is hashable 270 | 271 | def testFieldReferenceCycle(self): 272 | """Tests that FieldReferences may not contain reference cycles.""" 273 | frozenset_fr = {'frozenset': frozenset({1, 2})} 274 | frozenset_fr['fr'] = config_dict.FieldReference( 275 |
frozenset_fr['frozenset']) 276 | list_fr = {'list': [1, 2]} 277 | list_fr['fr'] = config_dict.FieldReference(list_fr['list']) 278 | 279 | cyclic_fr = {'a': 1} 280 | cyclic_fr['fr'] = config_dict.FieldReference(cyclic_fr) 281 | cyclic_fr_parent = {'dict': {}} 282 | cyclic_fr_parent['dict']['fr'] = config_dict.FieldReference( 283 | cyclic_fr_parent) 284 | 285 | # FieldReference is allowed to point to non-cyclic objects: 286 | _ = config_dict.FrozenConfigDict(frozenset_fr) 287 | _ = config_dict.FrozenConfigDict(list_fr) 288 | # But not cycles: 289 | self.assertFrozenRaisesValueError([cyclic_fr, cyclic_fr_parent]) 290 | 291 | def testDeepCopy(self): 292 | """Ensure deepcopy works and does not affect equality.""" 293 | fcd = _test_frozenconfigdict() 294 | fcd_deepcopy = copy.deepcopy(fcd) 295 | self.assertEqual(fcd, fcd_deepcopy) 296 | 297 | def testEquals(self): 298 | """Tests that __eq__() respects hidden mutability.""" 299 | fcd = _test_frozenconfigdict() 300 | 301 | # First, ensure __eq__() returns False when comparing to other types 302 | self.assertNotEqual(fcd, (1, 2)) 303 | self.assertNotEqual(fcd, fcd.as_configdict()) 304 | 305 | list_to_tuple = _test_dict_deepcopy() 306 | list_to_tuple['list'] = tuple(list_to_tuple['list']) 307 | fcd_list_to_tuple = config_dict.FrozenConfigDict(list_to_tuple) 308 | 309 | set_to_frozenset = _test_dict_deepcopy() 310 | set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) 311 | fcd_set_to_frozenset = config_dict.FrozenConfigDict(set_to_frozenset) 312 | 313 | self.assertNotEqual(fcd, fcd_list_to_tuple) 314 | 315 | # Because set == frozenset in Python: 316 | self.assertEqual(fcd, fcd_set_to_frozenset) 317 | 318 | # Items are not affected by hidden mutability 319 | self.assertCountEqual(fcd.items(), fcd_list_to_tuple.items()) 320 | self.assertCountEqual(fcd.items(), fcd_set_to_frozenset.items()) 321 | 322 | def testEqualsAsConfigDict(self): 323 | """Tests that eq_as_configdict respects hidden mutability but not type.""" 324
| fcd = _test_frozenconfigdict() 325 | 326 | # First, ensure eq_as_configdict() returns True with an equal ConfigDict but 327 | # False for other types. 328 | self.assertFalse(fcd.eq_as_configdict([1, 2])) 329 | self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict())) 330 | empty_fcd = config_dict.FrozenConfigDict() 331 | self.assertTrue(empty_fcd.eq_as_configdict(config_dict.ConfigDict())) 332 | 333 | # Now, ensure it has the same immutability detection as __eq__(). 334 | list_to_tuple = _test_dict_deepcopy() 335 | list_to_tuple['list'] = tuple(list_to_tuple['list']) 336 | fcd_list_to_tuple = config_dict.FrozenConfigDict(list_to_tuple) 337 | 338 | set_to_frozenset = _test_dict_deepcopy() 339 | set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) 340 | fcd_set_to_frozenset = config_dict.FrozenConfigDict(set_to_frozenset) 341 | 342 | self.assertFalse(fcd.eq_as_configdict(fcd_list_to_tuple)) 343 | # Because set == frozenset in Python: 344 | self.assertTrue(fcd.eq_as_configdict(fcd_set_to_frozenset)) 345 | 346 | def testHash(self): 347 | """Ensures __hash__() respects hidden mutability.""" 348 | list_to_tuple = _test_dict_deepcopy() 349 | list_to_tuple['list'] = tuple(list_to_tuple['list']) 350 | 351 | self.assertEqual( 352 | hash(_test_frozenconfigdict()), 353 | hash(config_dict.FrozenConfigDict(_test_dict_deepcopy()))) 354 | self.assertNotEqual( 355 | hash(_test_frozenconfigdict()), 356 | hash(config_dict.FrozenConfigDict(list_to_tuple))) 357 | 358 | # Ensure Python realizes FrozenConfigDict is hashable 359 | self.assertIsInstance(_test_frozenconfigdict(), collections_abc.Hashable) 360 | 361 | def testUnhashableType(self): 362 | """Ensures __hash__() fails if FrozenConfigDict has unhashable value.""" 363 | unhashable_fcd = config_dict.FrozenConfigDict( 364 | {'unhashable': bytearray()}) 365 | with self.assertRaises(TypeError): 366 | hash(unhashable_fcd) 367 | 368 | def testToDict(self): 369 | """Ensure to_dict() does not care about hidden mutability."""
370 | list_to_tuple = _test_dict_deepcopy() 371 | list_to_tuple['list'] = tuple(list_to_tuple['list']) 372 | 373 | self.assertEqual(_test_frozenconfigdict().to_dict(), 374 | config_dict.FrozenConfigDict(list_to_tuple).to_dict()) 375 | 376 | def testPickle(self): 377 | """Make sure FrozenConfigDict can be dumped and loaded with pickle.""" 378 | fcd = _test_frozenconfigdict() 379 | locked_fcd = config_dict.FrozenConfigDict(_test_configdict().lock()) 380 | 381 | unpickled_fcd = pickle.loads(pickle.dumps(fcd)) 382 | unpickled_locked_fcd = pickle.loads(pickle.dumps(locked_fcd)) 383 | 384 | self.assertEqual(fcd, unpickled_fcd) 385 | self.assertEqual(locked_fcd, unpickled_locked_fcd) 386 | 387 | 388 | if __name__ == '__main__': 389 | absltest.main() 390 | -------------------------------------------------------------------------------- /ml_collections/config_flags/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Config flags module.""" 16 | 17 | from .config_flags import DEFINE_config_dataclass 18 | from .config_flags import DEFINE_config_dict 19 | from .config_flags import DEFINE_config_file 20 | from .config_flags import get_config_filename 21 | from .config_flags import get_override_values 22 | from .config_flags import register_flag_parser 23 | from .config_flags import register_flag_parser_for_type 24 | 25 | __all__ = ( 26 | "DEFINE_config_dataclass", 27 | "DEFINE_config_dict", 28 | "DEFINE_config_file", 29 | "get_config_filename", 30 | "get_override_values", 31 | "register_flag_parser", 32 | "register_flag_parser_for_type", 33 | ) 34 | -------------------------------------------------------------------------------- /ml_collections/config_flags/config_path.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Module for spliting flag prefixes.""" 16 | 17 | import ast 18 | import dataclasses as dc 19 | import functools 20 | import types 21 | import typing 22 | from typing import Any, MutableSequence, Optional, Sequence, Tuple, Union, Type 23 | 24 | from ml_collections import config_dict 25 | 26 | 27 | NoneType = type(None) 28 | 29 | 30 | _AST_SPLIT_CONFIG_PATH = { 31 | ast.Attribute: lambda n: (*_split_node(n.value), n.attr), 32 | ast.Index: lambda i: _split_node(i.value), 33 | ast.Name: lambda n: (n.id,), 34 | ast.Slice: lambda i: slice(*map(_split_node, (i.lower, i.upper, i.step))), 35 | ast.Subscript: lambda n: (*_split_node(n.value), _split_node(n.slice)), 36 | type(None): lambda n: None 37 | } 38 | 39 | 40 | def _split_node(node): 41 | return _AST_SPLIT_CONFIG_PATH.get(type(node), ast.literal_eval)(node) 42 | 43 | 44 | def split(config_path: str) -> Tuple[Any]: 45 | """Returns config_path split into a tuple of parts. 46 | 47 | Example usage: 48 | >>> assert config_path.split('a.b.cc') == ('a', 'b', 'cc') 49 | >>> assert config_path.split('a.b["cc.d"]') == ('a', 'b', 'cc.d') 50 | >>> assert config_path.split('a.b[10]') == ('a', 'b', 10) 51 | >>> assert config_path.split('a[(1, 2)]') == ('a', (1, 2)) 52 | >>> assert config_path.split('a[:]') == ('a', slice(None)) 53 | 54 | Args: 55 | config_path: Input path to be split - see example usage. 56 | 57 | Returns: 58 | Tuple of config_path split into parts. Parts are attributes or subscripts. 59 | Attrributes are treated as strings and subscripts are parsed using 60 | ast.literal_eval. It is up to the caller to ensure all returned types are 61 | valid. 62 | 63 | Raises: 64 | ValueError: Failed to parse config_path. 
65 | """ 66 | try: 67 | node = ast.parse(config_path, mode='eval') 68 | except SyntaxError as e: 69 | raise ValueError(f'Could not parse {config_path!r}: {e!r}') from None 70 | if isinstance(node, ast.Expression): 71 | result = _split_node(node.body) 72 | if isinstance(result, tuple): 73 | return result 74 | raise ValueError(config_path) 75 | 76 | 77 | def _get_item_or_attribute(config, field, 78 | field_path: Optional[str] = None): 79 | """Returns attribute of member failing that the item.""" 80 | if isinstance(field, str) and hasattr(config, field): 81 | return getattr(config, field) 82 | if hasattr(config, '__getitem__'): 83 | return config[field] 84 | if isinstance(field, int): 85 | raise IndexError( 86 | f'{type(config)} does not support integer indexing [{field}]]. ' 87 | f'Attempting to lookup: {field_path}') 88 | raise KeyError( 89 | f'Attribute {type(config)}.{field} does not exist ' 90 | 'and the type does not support indexing. ' 91 | f'Attempting to lookup: {field_path}') 92 | 93 | 94 | def _get_holder_field(config_path: str, config: Any) -> Tuple[Any, str]: 95 | """Returns the last part config_path and config to allow assignment. 96 | 97 | Example usage: 98 | >>> config = {'a': {'b', {'c', 10}}} 99 | >>> holder, lastfield = _get_holder_field('a.b.c', config) 100 | >>> assert lastfield == 'c' 101 | >>> assert holder is config['a']['b'] 102 | >>> assert holder[lastfield] == 10 103 | 104 | Args: 105 | config_path: Any string that `split` can process. 106 | config: A nested datastructure that can be accessed via 107 | _get_item_or_attribute 108 | 109 | Returns: 110 | The penultimate object when walking config with config_path. And the final 111 | part of the config path. 112 | 113 | Raises: 114 | IndexError: Integer field not found in nested structure. 115 | KeyError: Non-integer field not found in nested structure. 116 | ValueError: Empty/invalid config_path after parsing. 
117 | """ 118 | fields = split(config_path) 119 | if not fields: 120 | raise ValueError('Path cannot be empty') 121 | get_item = functools.partial(_get_item_or_attribute, field_path=config_path) 122 | holder = functools.reduce(get_item, fields[:-1], config) 123 | return holder, fields[-1] 124 | 125 | 126 | def get_value(config_path: str, config: Any): 127 | """Gets value of a single field. 128 | 129 | Example usage: 130 | >>> config = {'a': {'b', {'c', 10}}} 131 | >>> assert config_path.get_value('a.b.c', config) == 10 132 | 133 | Args: 134 | config_path: Any string that `split` can process. 135 | config: A nested datastructure 136 | 137 | Returns: 138 | The last object when walking config with config_path. 139 | 140 | Raises: 141 | IndexError: Integer field not found in nested structure. 142 | KeyError: Non-integer field not found in nested structure. 143 | ValueError: Empty/invalid config_path after parsing. 144 | """ 145 | get_item = functools.partial(_get_item_or_attribute, field_path=config_path) 146 | return functools.reduce(get_item, split(config_path), config) 147 | 148 | 149 | def initialize_missing_parent_fields( 150 | config: Any, override: str, 151 | allowed_missing: Sequence[str]): 152 | """Adds some missing nested holder fields for a particular override. 153 | 154 | For example if override is 'config.a.b.c' and config.a is None, it 155 | will default initialize config.a, and if config.a.b is None will default 156 | initialize it as well. Only overrides present in allowed_missing will 157 | be initialized. 158 | 159 | Args: 160 | config: config object (typically dataclass) 161 | override: dot joined override name. 162 | allowed_missing: list of overrides that are allowed 163 | to be set. For example, if override is 'a.b.c.d', 164 | allowed_missing could be ['a.b.c', 'a', 'foo.bar']. 165 | 166 | Raises: 167 | ValueError: if parent field is not of dataclass type. 
168 | """ 169 | fields = split(override) 170 | # Collect the tree levels at which we are alloed to create override 171 | allowed_levels = {len(split(x)) for x in allowed_missing if 172 | override.startswith(x + '.')} 173 | child = config 174 | for level, f in enumerate(fields[:-1], 1): 175 | parent = child 176 | child = _get_item_or_attribute(parent, f, override) 177 | if child is not None: 178 | continue 179 | # Field is not yet present, see if we should create it instead. 180 | field_type = get_type(f, parent) 181 | # Note: these two assertions below are mostly guard 182 | # rails to prevent behaviors that might be confusing/accidental. 183 | # Specifically we disallow implicit creation of parent fields, 184 | # creating non dataclass objects. They can be revisited 185 | # in the future. 186 | if not dc.is_dataclass(field_type): 187 | raise ValueError( 188 | f'Override {override} can not be applied because ' 189 | f'field "{f}" is None, and its type "{field_type}" is not a ' 190 | f'dataclass in the parent of type "{type(parent)}".') 191 | 192 | if level not in allowed_levels: 193 | raise ValueError( 194 | f'Flag {override} can not be applied because ' 195 | f'field "{f}" is None by default and it is not explicitly ' 196 | 'provided in flags (it can be default intialized by ' 197 | f'providing --.{f}=build flag') 198 | try: 199 | child = field_type() 200 | except Exception as e: 201 | raise ValueError( 202 | f'Override {override} can not be applied because ' 203 | f'field "{f}" of type {field_type} can not be default instantiated:' 204 | f'{e}') from e 205 | set_value(f, parent, child) 206 | 207 | 208 | def get_origin(type_spec: type) -> Optional[type]: # pylint: disable=g-bare-generic drop when 3.7 support is not needed 209 | """Call typing.get_origin, with a fallback for Python 3.7 and below.""" 210 | if hasattr(typing, 'get_origin'): 211 | return typing.get_origin(type_spec) 212 | return getattr(type_spec, '__origin__', None) 213 | 214 | 215 | def 
get_args(type_spec: type) -> Union[NoneType, Tuple[type, ...]]: # pylint: disable=g-bare-generic drop when 3.7 support is not needed 216 | """Call typing.get_args, with fallback for Python 3.7 and below.""" 217 | if hasattr(typing, 'get_args'): 218 | return typing.get_args(type_spec) 219 | return getattr(type_spec, '__args__', NoneType) 220 | 221 | 222 | def _is_union_type(type_spec: type) -> bool: # pylint: disable=g-bare-generic drop when 3.7 support is not needed 223 | """Cheeck if a type_spec is a Union type or not.""" 224 | # UnionType was only introduced in python 3.10. We need getattr for 225 | # backward compatibility. 226 | return get_origin(type_spec) in [Union, getattr(types, 'UnionType', Union)] 227 | 228 | 229 | def extract_type_from_optional(type_spec: type) -> Optional[type]: # pylint: disable=g-bare-generic drop when 3.7 support is not needed 230 | """If type_spec is of type Optional[T], returns T object, otherwise None""" 231 | if not _is_union_type(type_spec): 232 | return None 233 | non_none = [t for t in get_args(type_spec) if t is not NoneType] 234 | if len(non_none) != 1: 235 | return None 236 | return non_none[0] 237 | 238 | 239 | def normalize_type(type_spec: type) -> type: # pylint: disable=g-bare-generic drop when 3.7 support is not needed 240 | """Normalizes a type object. 241 | 242 | Strips all None types from the type specification and returns the remaining 243 | single type. This is primarily useful for Optional type annotations in which 244 | case it will strip out the NoneType and return the inner type. 245 | 246 | Args: 247 | type_spec: The type to normalize. 248 | 249 | Raises: 250 | TypeError: If there is not exactly 1 non-None type in the union. 251 | Returns: 252 | The normalized type. 
253 | """ 254 | if _is_union_type(type_spec): 255 | subtype = extract_type_from_optional(type_spec) 256 | if subtype is None: 257 | raise TypeError(f'Unable to normalize ambiguous type: {type_spec}') 258 | return subtype 259 | 260 | return type_spec 261 | 262 | 263 | def get_type( 264 | config_path: str, 265 | config: Any, 266 | normalize=True, 267 | default_type: Optional[Type[Any]] = None, 268 | ): 269 | """Gets type of field in config described by a config_path. 270 | 271 | Example usage: 272 | >>> config = {'a': {'b', {'c', 10}}} 273 | >>> assert config_path.get_type('a.b.c', config) is int 274 | 275 | Args: 276 | config_path: Any string that `split` can process. 277 | config: A nested datastructure 278 | normalize: whether to normalize the type (in particular strip Optional 279 | annotations on dataclass fields) 280 | default_type: If the `config_path` is not found and `default_type` is set, 281 | the `default_type` is returned. 282 | 283 | Returns: 284 | The type of last object when walking config with config_path. 285 | 286 | Raises: 287 | IndexError: Integer field not found in nested structure. 288 | KeyError: Non-integer field not found in nested structure. 289 | ValueError: Empty/invalid config_path after parsing. 290 | TypeError: Ambiguous type annotation on dataclass field. 291 | """ 292 | holder, field = _get_holder_field(config_path, config) 293 | # Check if config is a DM collection and hence has attribute get_type() 294 | if isinstance(holder, 295 | (config_dict.ConfigDict, config_dict.FieldReference)): 296 | if default_type is not None and field not in holder: 297 | return default_type 298 | return holder.get_type(field) 299 | # For dataclasses we can just use the type annotation. 
300 | elif dc.is_dataclass(holder): 301 | matches = [f.type for f in dc.fields(holder) if f.name == field] 302 | if not matches: 303 | raise KeyError(f'Field {field} not found on dataclass {type(holder)}') 304 | return normalize_type(matches[0]) if normalize else matches[0] 305 | else: 306 | return type(_get_item_or_attribute(holder, field, config_path)) 307 | 308 | 309 | def is_optional(config_path: str, config: Any) -> bool: 310 | raw_type = get_type(config_path, config, normalize=False) 311 | return extract_type_from_optional(raw_type) is not None 312 | 313 | 314 | def set_value( 315 | config_path: str, 316 | config: Any, 317 | value: Any, 318 | *, 319 | accept_new_attributes: bool = False, 320 | ): 321 | """Sets value of field described by config_path. 322 | 323 | Example usage: 324 | >>> config = {'a': {'b', {'c', 10}}} 325 | >>> config_path.set_value('a.b.c', config, 20) 326 | >>> assert config['a']['b']['c'] == 20 327 | 328 | Args: 329 | config_path: Any string that `split` can process. 330 | config: A nested datastructure 331 | value: A value to assign to final field. 332 | accept_new_attributes: If `True`, the new config attributes can be added 333 | 334 | Raises: 335 | IndexError: Integer field not found in nested structure. 336 | KeyError: Non-integer field not found in nested structure. 337 | ValueError: Empty/invalid config_path after parsing. 
338 | """ 339 | holder, field = _get_holder_field(config_path, config) 340 | 341 | if isinstance(field, int) and isinstance(holder, MutableSequence): 342 | holder[field] = value 343 | elif hasattr(holder, '__setitem__') and ( 344 | field in holder or accept_new_attributes 345 | ): 346 | holder[field] = value 347 | elif hasattr(holder, str(field)): 348 | setattr(holder, str(field), value) 349 | else: 350 | if isinstance(field, int): 351 | raise IndexError( 352 | f'{field} is not a valid index for {type(holder)} ' 353 | f'(in: {config_path})') 354 | raise KeyError(f'{field} is not a valid key or attribute of {type(holder)} ' 355 | f'(in: {config_path})') 356 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Defines a method which returns an instance of ConfigDict.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | def get_config(): 21 | config = config_dict.ConfigDict() 22 | config.field1 = 1 23 | config.field2 = 'tom' 24 | config.nested = config_dict.ConfigDict() 25 | config.nested.field = 2.23 26 | config.tuple = (1, 2, 3) 27 | return config 28 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/define_config_dataclass_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Example of basic DEFINE_config_dataclass usage. 
16 | 17 | To run this example: 18 | python define_config_dataclass_basic.py -- --my_config.field1=8 \ 19 | --my_config.nested.field=2.1 --my_config.tuple='(1, 2, (1, 2))' 20 | """ 21 | 22 | import dataclasses 23 | from typing import Any, Mapping, Sequence 24 | 25 | from absl import app 26 | from ml_collections import config_flags 27 | 28 | 29 | @dataclasses.dataclass 30 | class MyConfig: 31 | field1: int 32 | field2: str 33 | nested: Mapping[str, Any] 34 | tuple: Sequence[int] 35 | 36 | 37 | config = MyConfig( 38 | field1=1, 39 | field2='tom', 40 | nested={'field': 2.23}, 41 | tuple=(1, 2, 3), 42 | ) 43 | 44 | _CONFIG = config_flags.DEFINE_config_dataclass('my_config', config) 45 | 46 | 47 | def main(_): 48 | print(_CONFIG.value) 49 | 50 | 51 | if __name__ == '__main__': 52 | app.run(main) 53 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/define_config_dict_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Example of basic DEFINE_config_dict usage. 
16 | 17 | To run this example: 18 | python define_config_dict_basic.py -- --my_config_dict.field1=8 \ 19 | --my_config_dict.nested.field=2.1 --my_config_dict.tuple='(1, 2, (1, 2))' 20 | """ 21 | 22 | from absl import app 23 | 24 | from ml_collections import config_dict 25 | from ml_collections import config_flags 26 | 27 | config = config_dict.ConfigDict() 28 | config.field1 = 1 29 | config.field2 = 'tom' 30 | config.nested = config_dict.ConfigDict() 31 | config.nested.field = 2.23 32 | config.tuple = (1, 2, 3) 33 | 34 | _CONFIG = config_flags.DEFINE_config_dict('my_config_dict', config) 35 | 36 | 37 | def main(_): 38 | print(_CONFIG.value) 39 | 40 | 41 | if __name__ == '__main__': 42 | app.run(main) 43 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/define_config_file_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=line-too-long 16 | r"""Example of basic DEFINE_flag_dict usage. 
17 | 18 | To run this example with basic config file: 19 | python define_config_dict_basic.py -- \ 20 | --my_config=ml_collections/config_flags/examples/config.py 21 | \ 22 | --my_config.field1=8 --my_config.nested.field=2.1 \ 23 | --my_config.tuple='(1, 2, (1, 2))' 24 | 25 | To run this example with parameterised config file: 26 | python define_config_dict_basic.py -- \ 27 | --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear 28 | \ 29 | --my_config.model_config.output_size=256' 30 | """ 31 | # pylint: enable=line-too-long 32 | 33 | from absl import app 34 | 35 | from ml_collections import config_flags 36 | 37 | _CONFIG = config_flags.DEFINE_config_file('my_config') 38 | 39 | 40 | def main(_): 41 | print(_CONFIG.value) 42 | 43 | 44 | if __name__ == '__main__': 45 | app.run(main) 46 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/examples_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for config_flags examples. 16 | 17 | Ensures that from define_config_dict_basic, define_config_file_basic run 18 | successfully. 
19 | """ 20 | 21 | from absl import flags 22 | from absl.testing import absltest 23 | from absl.testing import flagsaver 24 | from ml_collections.config_flags.examples import define_config_dict_basic 25 | from ml_collections.config_flags.examples import define_config_file_basic 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class ConfigDictExamplesTest(absltest.TestCase): 31 | 32 | def test_define_config_dict_basic(self): 33 | define_config_dict_basic.main([]) 34 | 35 | @flagsaver.flagsaver 36 | def test_define_config_file_basic(self): 37 | FLAGS.my_config = 'ml_collections/config_flags/examples/config.py' 38 | define_config_file_basic.main([]) 39 | 40 | @flagsaver.flagsaver 41 | def test_define_config_file_parameterised(self): 42 | FLAGS.my_config = 'ml_collections/config_flags/examples/parameterised_config.py:linear' 43 | define_config_file_basic.main([]) 44 | 45 | 46 | if __name__ == '__main__': 47 | absltest.main() 48 | -------------------------------------------------------------------------------- /ml_collections/config_flags/examples/parameterised_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Defines a parameterized method which returns a config depending on input.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | def get_config(config_string): 21 | """Return an instance of ConfigDict depending on `config_string`.""" 22 | possible_structures = { 23 | 'linear': 24 | config_dict.ConfigDict({ 25 | 'model_constructor': 'snt.Linear', 26 | 'model_config': config_dict.ConfigDict({ 27 | 'output_size': 42, 28 | }) 29 | }), 30 | 'lstm': 31 | config_dict.ConfigDict({ 32 | 'model_constructor': 'snt.LSTM', 33 | 'model_config': config_dict.ConfigDict({ 34 | 'hidden_size': 108, 35 | }) 36 | }) 37 | } 38 | 39 | return possible_structures[config_string] 40 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/config_path_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for config_path.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ml_collections.config_flags import config_path 20 | from ml_collections.config_flags.tests import fieldreference_config 21 | from ml_collections.config_flags.tests import mock_config 22 | 23 | 24 | class ConfigPathTest(parameterized.TestCase): 25 | 26 | def test_list_extra_index(self): 27 | """Tries to index a non-indexable list element.""" 28 | 29 | test_config = mock_config.get_config() 30 | with self.assertRaises(IndexError): 31 | config_path.get_value('dict.list[0][0]', test_config) 32 | 33 | def test_list_out_of_range_get(self): 34 | """Tries to access out-of-range value in list.""" 35 | 36 | test_config = mock_config.get_config() 37 | with self.assertRaises(IndexError): 38 | config_path.get_value('dict.list[2][1]', test_config) 39 | 40 | def test_list_out_of_range_set(self): 41 | """Tries to override out-of-range value in list.""" 42 | 43 | test_config = mock_config.get_config() 44 | with self.assertRaises(IndexError): 45 | config_path.set_value('dict.list[2][1]', test_config, -1) 46 | 47 | def test_reading_non_existing_key(self): 48 | """Tests reading non existing key from config.""" 49 | 50 | test_config = mock_config.get_config() 51 | with self.assertRaises(KeyError): 52 | config_path.set_value('dict.not_existing_key', test_config, 1) 53 | 54 | def test_reading_setting_existing_key_in_dict(self): 55 | """Tests setting non existing key from dict inside config.""" 56 | 57 | test_config = mock_config.get_config() 58 | with self.assertRaises(KeyError): 59 | config_path.set_value('dict.not_existing_key.key', test_config, 1) 60 | 61 | def test_empty_key(self): 62 | """Tests calling an empty key update.""" 63 | 64 | test_config = mock_config.get_config() 65 | with self.assertRaises(ValueError): 66 | config_path.set_value('', test_config, None) 67 | 68 | def test_field_reference_types(self): 69 | """Tests whether types of 
FieldReference fields are valid.""" 70 | test_config = fieldreference_config.get_config() 71 | 72 | paths = ['ref_nodefault', 'ref'] 73 | paths_types = [int, int] 74 | 75 | config_types = [config_path.get_type(path, test_config) for path in paths] 76 | self.assertEqual(paths_types, config_types) 77 | 78 | @parameterized.parameters( 79 | ('float', float), 80 | ('integer', int), 81 | ('string', str), 82 | ('bool', bool), 83 | ('dict', dict), 84 | ('dict.float', float), 85 | ('dict.list', list), 86 | ('list', list), 87 | ('list[0]', int), 88 | ('object.float', float), 89 | ('object.integer', int), 90 | ('object.string', str), 91 | ('object.bool', bool), 92 | ('object.dict', dict), 93 | ('object.dict.float', float), 94 | ('object.dict.list', list), 95 | ('object.list', list), 96 | ('object.list[0]', int), 97 | ('object.tuple', tuple), 98 | ('object_reference.float', float), 99 | ('object_reference.integer', int), 100 | ('object_reference.string', str), 101 | ('object_reference.bool', bool), 102 | ('object_reference.dict', dict), 103 | ('object_reference.dict.float', float), 104 | ('object_copy.float', float), 105 | ('object_copy.integer', int), 106 | ('object_copy.string', str), 107 | ('object_copy.bool', bool), 108 | ('object_copy.dict', dict), 109 | ('object_copy.dict.float', float), 110 | ) 111 | def test_types(self, path, path_type): 112 | """Tests whether various types of objects are valid.""" 113 | test_config = mock_config.get_config() 114 | self.assertEqual(path_type, config_path.get_type(path, test_config)) 115 | 116 | if __name__ == '__main__': 117 | absltest.main() 118 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/configdict_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ConfigDict config file.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | class UnusableConfig(object): 21 | """Test against code assuming the semantics of attributes (such as `lock`). 22 | 23 | This class is to test that the flags implementation does not assume the 24 | semantics of attributes. This is to avoid code such as: 25 | 26 | ```python 27 | if hasattr(obj, lock): 28 | obj.lock() 29 | ``` 30 | 31 | which will fail if `obj` has an attribute `lock` that does not behave in the 32 | way we expect. 33 | 34 | This class only has unusable attributes. There are two 35 | exceptions for which this class behaves normally: 36 | * Python's special functions which start and end with a double underscore. 37 | * `valid_attribute`, an attribute used to test the class. 38 | 39 | For other attributes, both `hasttr(obj, attr)` and `callable(obj, attr)` will 40 | return True. Calling `obj.attr` will return a function which takes no 41 | arguments and raises an AttributeError when called. For example, the `lock` 42 | example above will raise an AttributeError. The only valid action on 43 | attributes is assignment, e.g. 44 | 45 | ```python 46 | obj = UnusableConfig() 47 | obj.attr = 1 48 | ``` 49 | 50 | In which case the attribute will keep its assigned value and become usable. 
51 | """ 52 | 53 | def __init__(self): 54 | self._valid_attribute = 1 55 | 56 | def __getattr__(self, attribute): 57 | """Get an arbitrary attribute. 58 | 59 | Returns a function which takes no arguments and raises an AttributeError, 60 | except for Python special functions in which case an AttributeError is 61 | directly raised. 62 | 63 | Args: 64 | attribute: A string representing the attribute's name. 65 | Returns: 66 | A function which raises an AttributeError when called. 67 | Raises: 68 | AttributeError: when the attribute is a Python special function starting 69 | and ending with a double underscore. 70 | """ 71 | if attribute.startswith("__") and attribute.endswith("__"): 72 | raise AttributeError("UnusableConfig does not contain entry {}.". 73 | format(attribute)) 74 | 75 | def raise_attribute_error_fun(): 76 | raise AttributeError( 77 | "{} is not a usable attribute of UnusableConfig".format( 78 | attribute)) 79 | 80 | return raise_attribute_error_fun 81 | 82 | @property 83 | def valid_attribute(self): 84 | return self._valid_attribute 85 | 86 | @valid_attribute.setter 87 | def valid_attribute(self, value): 88 | self._valid_attribute = value 89 | 90 | def get_config(): 91 | """Returns a ConfigDict. Used for tests.""" 92 | cfg = config_dict.ConfigDict() 93 | cfg.integer = 1 94 | cfg.reference = config_dict.FieldReference(1) 95 | cfg.list = [1, 2, 3] 96 | cfg.nested_list = [[1, 2, 3]] 97 | cfg.nested_configdict = config_dict.ConfigDict() 98 | cfg.nested_configdict.integer = 1 99 | cfg.unusable_config = UnusableConfig() 100 | 101 | return cfg 102 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/dataclass_overriding_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for config_flags used in conjunction with DEFINE_config_dataclass.""" 16 | 17 | import copy 18 | import dataclasses 19 | import functools 20 | import sys 21 | from typing import Mapping, Optional, Sequence, Tuple, Union 22 | import unittest 23 | 24 | from absl import flags 25 | from absl.testing import absltest 26 | from ml_collections import config_flags 27 | from ml_collections.config_flags import config_flags as config_flag_lib 28 | 29 | 30 | ##### 31 | # Simple configuration class for testing. 32 | @dataclasses.dataclass 33 | class MyModelConfig: 34 | foo: int 35 | bar: Sequence[str] 36 | baz: Optional[Mapping[str, str]] = None 37 | buz: Optional[Mapping[Tuple[int, int], str]] = None 38 | qux: Optional[int] = None 39 | bax: float = 1 40 | boj: Tuple[int, ...] 
= () 41 | 42 | 43 | class ParserForCustomConfig(flags.ArgumentParser): 44 | def __init__(self, delta=1): 45 | self.delta = delta 46 | 47 | def parse(self, value): 48 | if isinstance(value, CustomParserConfig): 49 | return value 50 | return CustomParserConfig(i=int(value), j=int(value) + self.delta) 51 | 52 | 53 | @dataclasses.dataclass 54 | @config_flags.register_flag_parser(parser=ParserForCustomConfig()) 55 | class CustomParserConfig(): 56 | i: int 57 | j: int = 1 58 | 59 | 60 | @dataclasses.dataclass 61 | class MyConfig: 62 | my_model: MyModelConfig 63 | baseline_model: MyModelConfig 64 | custom: CustomParserConfig = dataclasses.field( 65 | default_factory=lambda: CustomParserConfig(0)) 66 | 67 | 68 | @dataclasses.dataclass 69 | class SubConfig: 70 | model: Optional[MyModelConfig] = dataclasses.field( 71 | default_factory=lambda: MyModelConfig(foo=0, bar=['1'])) 72 | 73 | 74 | @dataclasses.dataclass 75 | class ConfigWithOptionalNestedField: 76 | sub: Optional[SubConfig] = None 77 | non_optional: SubConfig = dataclasses.field( 78 | default_factory=SubConfig) 79 | 80 | _CONFIG = MyConfig( 81 | my_model=MyModelConfig( 82 | foo=3, 83 | bar=['a', 'b'], 84 | baz={'foo.b': 'bar'}, 85 | buz={(0, 0): 'ZeroZero', (0, 1): 'ZeroOne'} 86 | ), 87 | baseline_model=MyModelConfig( 88 | foo=55, 89 | bar=['c', 'd'], 90 | ), 91 | ) 92 | 93 | _TEST_FLAG = config_flags.DEFINE_config_dataclass('test_flag', _CONFIG, 94 | 'MyConfig data') 95 | 96 | 97 | def _test_flags(default, *flag_args, parse_fn=None): 98 | flag_values = flags.FlagValues() 99 | # DEFINE_config_dataclass accesses sys.argv to build flag list! 
100 | old_args = list(sys.argv) 101 | sys.argv[:] = ['', *['--test_config' + f for f in flag_args]] 102 | try: 103 | result = config_flags.DEFINE_config_dataclass( 104 | 'test_config', default, flag_values=flag_values, parse_fn=parse_fn) 105 | _, *remaining = flag_values(sys.argv) 106 | if remaining: 107 | raise ValueError(f'{remaining}') 108 | # assert not remaining 109 | return result.value 110 | finally: 111 | sys.argv[:] = old_args 112 | 113 | 114 | def parse_config_flag(value): 115 | cfg = _CONFIG 116 | return dataclasses.replace( 117 | cfg, 118 | my_model=dataclasses.replace(cfg.my_model, foo=int(value))) 119 | 120 | 121 | class TypedConfigFlagsTest(absltest.TestCase): 122 | 123 | def test_types(self): 124 | self.assertIsInstance(_TEST_FLAG.value, MyConfig) 125 | self.assertEqual(_TEST_FLAG.value, _CONFIG) 126 | self.assertIsInstance(flags.FLAGS['test_flag'].value, MyConfig) 127 | self.assertIsInstance(flags.FLAGS.test_flag, MyConfig) 128 | self.assertEqual(flags.FLAGS['test_flag'].value, _CONFIG) 129 | module_name = __name__ if __name__ != '__main__' else sys.argv[0] 130 | self.assertEqual( 131 | flags.FLAGS.find_module_defining_flag('test_flag'), module_name) 132 | 133 | def test_instance(self): 134 | config = _test_flags(_CONFIG) 135 | self.assertIsInstance(config, MyConfig) 136 | self.assertEqual(config.my_model, _CONFIG.my_model) 137 | self.assertEqual(_CONFIG, config) 138 | 139 | def test_flag_config_dataclass_optional(self): 140 | result = _test_flags(_CONFIG, '.baseline_model.qux=10') 141 | self.assertEqual(result.baseline_model.qux, 10) 142 | self.assertIsInstance(result.baseline_model.qux, int) 143 | self.assertIsNone(result.my_model.qux) 144 | 145 | def test_flag_config_dataclass_repeated_arg_use_last(self): 146 | result = _test_flags( 147 | _CONFIG, '.baseline_model.qux=10', '.baseline_model.qux=12' 148 | ) 149 | self.assertEqual(result.baseline_model.qux, 12) 150 | self.assertIsInstance(result.baseline_model.qux, int) 151 | 
self.assertIsNone(result.my_model.qux) 152 | 153 | def test_custom_flag_parsing_shared_default(self): 154 | result = _test_flags(_CONFIG, '.baseline_model.foo=324') 155 | result1 = _test_flags(_CONFIG, '.baseline_model.foo=123') 156 | # Here we verify that despite using _CONFIG as shared default for 157 | # result and result1, the final values are not in fact shared. 158 | self.assertEqual(result.baseline_model.foo, 324) 159 | self.assertEqual(result1.baseline_model.foo, 123) 160 | self.assertEqual(_CONFIG.baseline_model.foo, 55) 161 | 162 | def test_custom_flag_parsing_parser_override(self): 163 | config_flags.register_flag_parser_for_type( 164 | CustomParserConfig, ParserForCustomConfig(2)) 165 | result = _test_flags(_CONFIG, '.custom=10') 166 | self.assertEqual(result.custom.i, 10) 167 | self.assertEqual(result.custom.j, 12) 168 | 169 | # Restore old parser. 170 | config_flags.register_flag_parser_for_type( 171 | CustomParserConfig, ParserForCustomConfig()) 172 | 173 | @unittest.skipIf(sys.version_info[:2] < (3, 10), 'Need 3.10 to test | syntax') 174 | def test_pipe_syntax(self): 175 | @dataclasses.dataclass 176 | class PipeConfig: 177 | foo: int | None = None 178 | 179 | result = _test_flags(PipeConfig(), '.foo=32') 180 | self.assertEqual(result.foo, 32) 181 | 182 | def test_custom_flag_parsing_override_work(self): 183 | # Overrides still work. 184 | result = _test_flags(_CONFIG, '.custom.i=10') 185 | self.assertEqual(result.custom.i, 10) 186 | self.assertEqual(result.custom.j, 1) 187 | 188 | def test_optional_nested_fields(self): 189 | with self.assertRaises(ValueError): 190 | # Implicit creation not allowed. 191 | _test_flags(ConfigWithOptionalNestedField(), '.sub.model.foo=12') 192 | 193 | # Explicit creation works. 194 | result = _test_flags( 195 | ConfigWithOptionalNestedField(), '.sub=build', '.sub.model.foo=12' 196 | ) 197 | self.assertEqual(result.sub.model.foo, 12) 198 | 199 | # Default initialization support. 
200 | result = _test_flags(ConfigWithOptionalNestedField(), '.sub=build') 201 | self.assertEqual(result.sub.model.foo, 0) 202 | 203 | # Using default value (None). 204 | result = _test_flags(ConfigWithOptionalNestedField()) 205 | self.assertIsNone(result.sub) 206 | 207 | with self.assertRaises(config_flag_lib.FlagOrderError): 208 | # Don't allow accidental overwrites. 209 | _test_flags( 210 | ConfigWithOptionalNestedField(), '.sub.model.foo=12', '.sub=build' 211 | ) 212 | 213 | def test_set_to_none_dataclass_fields(self): 214 | result = _test_flags( 215 | ConfigWithOptionalNestedField(), '.sub=build', '.sub.model=none' 216 | ) 217 | self.assertIsNone(result.sub.model, None) 218 | 219 | with self.assertRaises(KeyError): 220 | # Parent field is set to None (from not None default value), 221 | # so this is not a valid set of flags. 222 | _test_flags( 223 | ConfigWithOptionalNestedField(), 224 | '.sub=build', 225 | '.sub.model=none', 226 | '.sub.model.foo=12', 227 | ) 228 | 229 | with self.assertRaises(KeyError): 230 | # Parent field is explicitly set to None (with None default value), 231 | # so this is not a valid set of flags. 232 | _test_flags( 233 | ConfigWithOptionalNestedField(), '.sub=none', '.sub.model.foo=12' 234 | ) 235 | 236 | def test_set_none_non_optional_dataclass_fields(self): 237 | with self.assertRaises(flags.IllegalFlagValueError): 238 | # Field is not marked as optional so it can't be set to None. 
239 | _test_flags(ConfigWithOptionalNestedField(), '.non_optional=None') 240 | 241 | def test_no_default_initializer(self): 242 | with self.assertRaises(flags.IllegalFlagValueError): 243 | _test_flags(ConfigWithOptionalNestedField(), '.sub=1', '.sub.model=1') 244 | 245 | def test_custom_flag_parser_invoked(self): 246 | # custom parser gets invoked 247 | result = _test_flags(_CONFIG, '.custom=10') 248 | self.assertEqual(result.custom.i, 10) 249 | self.assertEqual(result.custom.j, 11) 250 | 251 | def test_custom_flag_parser_invoked_overrides_applied(self): 252 | result = _test_flags(_CONFIG, '.custom=15', '.custom.i=11') 253 | # Override applied successfully 254 | self.assertEqual(result.custom.i, 11) 255 | self.assertEqual(result.custom.j, 16) 256 | 257 | def test_custom_flag_application_order(self): 258 | # Disallow for later value to override the earlier value. 259 | with self.assertRaises(config_flag_lib.FlagOrderError): 260 | _test_flags(_CONFIG, '.custom.i=11', '.custom=15') 261 | 262 | def test_flag_config_dataclass_type_mismatch(self): 263 | result = _test_flags(_CONFIG, '.my_model.bax=10') 264 | self.assertIsInstance(result.my_model.bax, float) 265 | # We can't do anything when the value isn't overridden. 
266 | self.assertIsInstance(result.baseline_model.bax, int) 267 | self.assertRaises( 268 | flags.IllegalFlagValueError, 269 | functools.partial(_test_flags, _CONFIG, '.my_model.bax=string'), 270 | ) 271 | 272 | def test_illegal_dataclass_field_type(self): 273 | 274 | @dataclasses.dataclass 275 | class Config: 276 | field: Union[int, float] = 3 277 | 278 | self.assertRaises( 279 | TypeError, functools.partial(_test_flags, Config(), '.field=1') 280 | ) 281 | 282 | def test_spurious_dataclass_field(self): 283 | 284 | @dataclasses.dataclass 285 | class Config: 286 | field: int = 3 287 | cfg = Config() 288 | cfg.extra = 'test' 289 | 290 | self.assertRaises( 291 | KeyError, functools.partial(_test_flags, cfg, '.extra=hi') 292 | ) 293 | 294 | def test_nested_dataclass(self): 295 | 296 | @dataclasses.dataclass 297 | class Parent: 298 | field: int = 3 299 | 300 | @dataclasses.dataclass 301 | class Child(Parent): 302 | other: int = 4 303 | 304 | self.assertEqual(_test_flags(Child(), '.field=1').field, 1) 305 | 306 | def test_flag_config_dataclass(self): 307 | result = _test_flags(_CONFIG, '.baseline_model.foo=10', '.my_model.foo=7') 308 | self.assertEqual(result.baseline_model.foo, 10) 309 | self.assertEqual(result.my_model.foo, 7) 310 | 311 | def test_flag_config_dataclass_string_dict(self): 312 | result = _test_flags(_CONFIG, '.my_model.baz["foo.b"]=rab') 313 | self.assertEqual(result.my_model.baz['foo.b'], 'rab') 314 | 315 | def test_flag_config_dataclass_tuple_dict(self): 316 | result = _test_flags(_CONFIG, '.my_model.buz[(0,1)]=hello') 317 | self.assertEqual(result.my_model.buz[(0, 1)], 'hello') 318 | 319 | def test_flag_config_dataclass_typed_tuple(self): 320 | result = _test_flags(_CONFIG, '.my_model.boj=(0, 1)') 321 | self.assertEqual(result.my_model.boj, (0, 1)) 322 | 323 | 324 | class DataClassParseFnTest(absltest.TestCase): 325 | 326 | def test_parse_no_custom_value(self): 327 | result = _test_flags( 328 | _CONFIG, '.baseline_model.foo=10', 
parse_fn=parse_config_flag 329 | ) 330 | self.assertEqual(result.my_model.foo, 3) 331 | self.assertEqual(result.baseline_model.foo, 10) 332 | 333 | def test_parse_custom_value_applied(self): 334 | result = _test_flags( 335 | _CONFIG, '=75', '.baseline_model.foo=10', parse_fn=parse_config_flag 336 | ) 337 | self.assertEqual(result.my_model.foo, 75) 338 | self.assertEqual(result.baseline_model.foo, 10) 339 | 340 | def test_parse_custom_value_applied_no_explicit_parse_fn(self): 341 | result = _test_flags(CustomParserConfig(0), '=75', '.i=12') 342 | self.assertEqual(result.i, 12) 343 | self.assertEqual(result.j, 76) 344 | 345 | def test_parse_out_of_order(self): 346 | with self.assertRaises(config_flag_lib.FlagOrderError): 347 | _ = _test_flags( 348 | _CONFIG, '.baseline_model.foo=10', '=75', parse_fn=parse_config_flag 349 | ) 350 | # Note: If this is ever supported, add verification that overrides are 351 | # applied correctly. 352 | 353 | def test_parse_assign_dataclass(self): 354 | flag_values = flags.FlagValues() 355 | 356 | def always_fail(v): 357 | raise ValueError() 358 | 359 | result = config_flags.DEFINE_config_dataclass( 360 | 'test_config', _CONFIG, flag_values=flag_values, parse_fn=always_fail) 361 | flag_values(['program']) 362 | flag_values['test_config'].value = parse_config_flag('12') 363 | self.assertEqual(result.value.my_model.foo, 12) 364 | 365 | def test_parse_invalid_custom_value(self): 366 | with self.assertRaises(flags.IllegalFlagValueError): 367 | _ = _test_flags( 368 | _CONFIG, '=?', '.baseline_model.foo=10', parse_fn=parse_config_flag 369 | ) 370 | 371 | def test_parse_overrides_applied(self): 372 | result = _test_flags( 373 | _CONFIG, '=34', '.my_model.foo=10', parse_fn=parse_config_flag 374 | ) 375 | self.assertEqual(result.my_model.foo, 10) 376 | 377 | if __name__ == '__main__': 378 | absltest.main() 379 | -------------------------------------------------------------------------------- 
/ml_collections/config_flags/tests/fieldreference_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config file with field references.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | def get_config(): # ConfigDict with a FieldReference (`ref`, value 123) and an int placeholder (`ref_nodefault`). 21 | cfg = config_dict.ConfigDict() 22 | cfg.ref = config_dict.FieldReference(123) 23 | cfg.ref_nodefault = config_dict.placeholder(int) 24 | return cfg 25 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/ioerror_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config file that raises IOError on import.
16 | 17 | The flags library tries to load configuration files in a few different ways. 18 | For this it relies on catching IOError exceptions of the type "File not 19 | found" and ignoring them to continue trying with a different loading method. 20 | But we need to ensure that other types of IOError exceptions are propagated 21 | correctly (b/63165566). This is tested in `ConfigFlagTest.testIOError` in 22 | `config_overriding_test.py`. 23 | """ 24 | 25 | raise IOError('This is an IOError.') # Not a "file not found" error, so loaders must let it propagate. 26 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/literal_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Config file with `object`-typed placeholders, two of them pre-set to 123.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | def get_config(): 21 | cfg = config_dict.ConfigDict() 22 | cfg.integer = config_dict.placeholder(object) 23 | cfg.string = config_dict.placeholder(object) 24 | cfg.nested = config_dict.placeholder(object) 25 | cfg.other_with_default = config_dict.placeholder(object) 26 | cfg.other_with_default = 123 27 | cfg.other_with_default_overitten = config_dict.placeholder(object) 28 | cfg.other_with_default_overitten = 123 # NOTE(review): `overitten` (sic) is presumably referenced by name in the flag tests; confirm before renaming. 29 | return cfg 30 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/mini_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Placeholder Config file.""" 16 | 17 | 18 | class MiniConfig(object): 19 | """Minimal mapping-like config: items live in `self.dict`, separate from attributes.""" 20 | 21 | def __init__(self): 22 | self.dict = {} 23 | self.field = False 24 | 25 | def __getitem__(self, key): 26 | return self.dict[key] 27 | 28 | def __contains__(self, key): 29 | return key in self.dict 30 | 31 | def __setitem__(self, key, value): 32 | self.dict[key] = value 33 | 34 | 35 | def get_config(): # Fixture with an item and an attribute sharing one name. 36 | cfg = MiniConfig() 37 | cfg['entry_with_collision'] = False 38 | cfg.entry_with_collision = True # Same name as the item above, but stored as an attribute, not in `dict`. 39 | return cfg 40 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/mock_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Placeholder Config file.""" 16 | 17 | import copy 18 | 19 | from ml_collections.config_flags.tests import spork 20 | 21 | 22 | class TestConfig(object): 23 | """Default test config with scalar, dict, list, tuple and enum fields.""" 24 | 25 | def __init__(self): 26 | self.integer = 23 27 | self.float = 2.34 28 | self.string = 'james' 29 | self.bool = True 30 | self.dict = { 31 | 'integer': 1, 32 | 'float': 3.14, 33 | 'string': 'mark', 34 | 'bool': False, 35 | 'dict': { 36 | 'float': 5.
37 | }, 38 | 'list': [1, 2, [3]] 39 | } 40 | self.list = [1, 2, [3]] 41 | self.tuple = (1, 2, (3,)) 42 | self.tuple_with_spaces = (1, 2, (3,)) 43 | self.enum = spork.SporkType.SPOON 44 | 45 | @property 46 | def readonly_field(self): # Getter-only property; no setter is defined. 47 | return 42 48 | 49 | def __repr__(self): 50 | return str(self.__dict__) 51 | 52 | 53 | def get_config(): 54 | # `object_reference` aliases the same instance as `object`; `object_copy` is an independent deep copy. 55 | config = TestConfig() 56 | config.object = TestConfig() 57 | config.object_reference = config.object 58 | config.object_copy = copy.deepcopy(config.object) 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/parameterised_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Config file where `get_config` takes a string argument.""" 16 | 17 | from ml_collections import config_dict 18 | 19 | 20 | def get_config(config_string): 21 | """Returns the ConfigDict selected by `config_string` ('type_a' or 'type_b').""" 22 | possible_configs = { 23 | 'type_a': config_dict.ConfigDict({ 24 | 'thing_a': 23, 25 | 'thing_b': 42, 26 | }), 27 | 'type_b': config_dict.ConfigDict({ 28 | 'thing_a': 19, 29 | 'thing_c': 65, 30 | }), 31 | } 32 | return possible_configs[config_string] # Any other string raises KeyError. 33 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/spork.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """An enum to be used in a mock_config. 16 | 17 | The enum can't be defined directly in the config because a config imported 18 | through a flag is a separate instance of the config module. Therefore it would 19 | define its own instance of the enum class which won't be equal to the same enum 20 | from the config imported directly.
21 | """ 22 | 23 | import enum 24 | 25 | 26 | class SporkType(enum.Enum): # Imported by mock_config through the `spork` module (see docstring). 27 | SPOON = 1 28 | SPORK = 2 29 | FORK = 3 30 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/tuple_parser_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ml_collections.config_flags.tuple_parser.""" 16 | 17 | from ml_collections.config_flags import tuple_parser 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | 22 | class TupleParserTest(parameterized.TestCase): 23 | 24 | @parameterized.parameters( 25 | {'argument': 1, 'expected': (1,)}, 26 | {'argument': (1, 2), 'expected': (1, 2)}, 27 | {'argument': ('abc', 'def'), 'expected': ('abc', 'def')}, 28 | {'argument': '1', 'expected': (1,)}, 29 | {'argument': '"abc"', 'expected': ('abc',)}, 30 | {'argument': '"abc",', 'expected': ('abc',)}, 31 | {'argument': '1, "a"', 'expected': (1, 'a')}, 32 | {'argument': '(1, "a")', 'expected': (1, 'a')}, 33 | {'argument': '(1, "a", (2, 3))', 'expected': (1, 'a', (2, 3))}, 34 | {'argument': ('abc*', 'def*'), 'expected': ('abc*', 'def*')}, 35 | {'argument': '("abc*", "def*")', 'expected': ('abc*', 'def*')}, 36 | {'argument': '("/abc",)', 'expected': ('/abc',)}, 37 | {'argument': '("/abc*",)', 'expected':
('/abc*',)}, 38 | {'argument': '("/abc/",)', 'expected': ('/abc/',)}, 39 | ) 40 | def test_tuple_parser_parse(self, argument, expected): # Tuples pass through, scalars are wrapped, strings are parsed. 41 | parser = tuple_parser.TupleParser() 42 | self.assertEqual(parser.parse(argument), expected) 43 | 44 | @parameterized.parameters( 45 | {'argument': '1', 'expected': (1,)}, 46 | {'argument': '"abc"', 'expected': ('abc',)}, 47 | {'argument': '"abc",', 'expected': ('abc',)}, 48 | {'argument': 'abc', 'expected': ('abc',)}, 49 | {'argument': '1, "a"', 'expected': (1, 'a')}, 50 | {'argument': '(1, "a")', 'expected': (1, 'a')}, 51 | {'argument': '(1, "a", (2, 3))', 'expected': (1, 'a', (2, 3))}, 52 | {'argument': '("abc*", "def*")', 'expected': ('abc*', 'def*')}, 53 | {'argument': '"abc*", "def*"', 'expected': ('abc*', 'def*')}, 54 | {'argument': '"abc*",', 'expected': ('abc*',)}, 55 | {'argument': 'abc*', 'expected': ('abc*',)}, 56 | {'argument': '"/abc",', 'expected': ('/abc',)}, 57 | {'argument': '/abc*', 'expected': ('/abc*',)}, 58 | {'argument': '/abc/', 'expected': ('/abc/',)}, 59 | ) 60 | def test_convert_str_to_tuple(self, argument, expected): # Unquoted strings without commas/whitespace become 1-tuples. 61 | self.assertEqual(tuple_parser._convert_str_to_tuple(argument), expected) 62 | 63 | @parameterized.parameters( 64 | 'a b', 65 | 'a,b*', 66 | '/a,b*', 67 | '/a,b', 68 | 'a b', # NOTE(review): duplicates the first case; confirm whether a different whitespace variant was intended. 69 | ) 70 | def test_convert_str_to_tuple_bad_inputs(self, argument): # Unparseable strings containing a comma or whitespace raise ValueError. 71 | with self.assertRaises(ValueError): 72 | tuple_parser._convert_str_to_tuple(argument) 73 | 74 | 75 | if __name__ == '__main__': 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/typeerror_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config file whose `get_config()` raises TypeError. 16 | 17 | When loading the configuration file as a flag, the flags library catches 18 | TypeError exceptions, then recasts them as an IllegalFlagTypeError and rethrows 19 | (b/63877430). The rethrow does not include the stacktrace from the original 20 | exception, so we manually add the stacktrace in configflags.parse(). This is 21 | tested in `ConfigFlagTest.testTypeError` in `config_overriding_test.py`. 22 | """ 23 | 24 | 25 | def type_error_function(): 26 | raise TypeError('This is a TypeError.') 27 | 28 | 29 | def get_config(): # Calling this (not importing the module) raises TypeError. 30 | return {'item': type_error_function()} 31 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tests/valueerror_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Config file that raises ValueError on import. 16 | 17 | When trying loading the configuration file as a flag, the flags library catches 18 | ValueError exceptions then recasts them as a IllegalFlagValueError and rethrows 19 | (b/63877430). The rethrow does not include the stacktrace from the original 20 | exception, so we manually add the stracktrace in configflags.parse(). This is 21 | tested in `ConfigFlagTest.testValueError` in `config_overriding_test.py`. 22 | """ 23 | 24 | 25 | def value_error_function(): 26 | raise ValueError('This is a ValueError.') 27 | 28 | 29 | def get_config(): 30 | return {'item': value_error_function()} 31 | -------------------------------------------------------------------------------- /ml_collections/config_flags/tuple_parser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom parser to override tuples in the config dict.""" 16 | import ast 17 | import collections.abc 18 | 19 | from absl import flags 20 | 21 | 22 | class TupleParser(flags.ArgumentParser): 23 | """Parser for tuple arguments. 24 | 25 | Custom flag parser for Tuple objects that is based on the existing parsers 26 | in `absl.flags`. This parser can be used to read in and override tuple 27 | arguments. It outputs a `tuple` object from an existing `tuple` or `str`. 
28 | The ony requirement is that the overriding parameter should be a `tuple` 29 | as well. The overriding parameter can have a different number of elements 30 | of different type than the original. For a detailed list of what `str` 31 | arguments are supported for overriding, look at `ast.literal_eval` from the 32 | Python Standard Library. 33 | """ 34 | 35 | def parse(self, argument): 36 | """Returns a `tuple` representing the input `argument`. 37 | 38 | Args: 39 | argument: The argument to be parsed. Valid types are `tuple` and 40 | `str` or a single object. `str` arguments are parsed and converted to a 41 | tuple, a single object is converted to a tuple of length 1, and an empty 42 | `tuple` is returned for arguments `NoneType`. 43 | 44 | Returns: 45 | A `TupleType` representing the input argument as a `tuple`. 46 | """ 47 | if isinstance(argument, tuple): 48 | return argument 49 | elif isinstance(argument, str): 50 | return _convert_str_to_tuple(argument) 51 | elif argument is None: 52 | return () 53 | else: 54 | return (argument,) 55 | 56 | def flag_type(self): 57 | return 'tuple' 58 | 59 | 60 | def _convert_str_to_tuple(string): 61 | """Function to convert a Python `str` object to a `tuple`. 62 | 63 | Args: 64 | string: The `str` to be converted. 65 | 66 | Returns: 67 | A `tuple` version of the string. 68 | 69 | Raises: 70 | ValueError: If the string is not a well formed `tuple`. 71 | """ 72 | # literal_eval converts strings to int, tuple, list, float and dict, 73 | # booleans and None. It can also handle nested tuples. 74 | # It does not, however, handle elements of type set. 75 | try: 76 | value = ast.literal_eval(string) 77 | except ValueError: 78 | # A ValueError is raised by literal_eval if the string is not well 79 | # formed. This is probably because string represents a single string 80 | # element, e.g. 
string='abc', and a tuple of strings field was overridden by 81 | # repeated use of a flag (ie `--flag a --flag b` instead of 82 | # `--flag '("a", "b")'`). 83 | value = string 84 | except SyntaxError as exc: 85 | # The only other error that may be raised is a `SyntaxError` because 86 | # `literal_eval` calls the Python in-built `compile`. This error is 87 | # caused by parsing issues. 88 | if ',' not in string and ' ' not in string: 89 | # Most likely passed a single string that contained an operator -- e.g. 90 | # '/path/to/file' or 'file_pattern*'. If a comma isn't in the string, then 91 | # it can't have been a tuple, so assume it's an unquoted string. 92 | # If passed strings containing a comma (user probably expects conversion 93 | # to a tuple) or whitespace (user might expect implicit conversion to 94 | # tuple?) raise an exception. 95 | value = string 96 | else: 97 | msg = ( 98 | f'Error while parsing string: {string} as tuple. If you intended to' 99 | ' pass the argument as a single element, use quotes such as `--flag' 100 | f' {repr(repr(string))}, otherwise insert quotes around each element.' 101 | ) 102 | if ' ' in string: 103 | msg += ( 104 | ' Use commas instead of whitespace as the separator between' 105 | ' elements.' 106 | ) 107 | raise ValueError(msg) from exc 108 | 109 | # Make sure we return a tuple. 110 | if isinstance(value, tuple): 111 | return value 112 | elif (isinstance(value, collections.abc.Iterable) and 113 | not isinstance(value, str)): 114 | return tuple(value) 115 | else: 116 | return (value,) 117 | -------------------------------------------------------------------------------- /ml_collections/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The ML Collections Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration file for pytest.""" 16 | 17 | from absl import flags 18 | import pytest 19 | 20 | 21 | @pytest.fixture(scope="function", autouse=True) 22 | def mark_flags_as_parsed(): # Runs before every test; marks absl's global FLAGS as parsed, presumably so tests can read flags without absl argv parsing — confirm. 23 | flags.FLAGS.mark_as_parsed() 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | # Project metadata. Available keys are documented at: 3 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata 4 | name = "ml_collections" 5 | description = "ML Collections is a library of Python collections designed for ML usecases."
6 | readme = "README.md" 7 | requires-python = ">=3.10" 8 | license = {file = "LICENSE"} 9 | authors = [{name = "ML Collections Authors", email="ml-collections@google.com"}] 10 | classifiers = [ # List of https://pypi.org/classifiers/ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "License :: OSI Approved :: Apache Software License", 15 | "Programming Language :: Python", 16 | "Topic :: Scientific/Engineering", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 18 | "Topic :: Software Development :: Libraries", 19 | "Topic :: Software Development :: Libraries :: Python Modules", 20 | ] 21 | keywords = [] 22 | 23 | # pip dependencies of the project 24 | # Installed locally with `pip install -e .` 25 | dependencies = [ 26 | "absl-py", 27 | "PyYAML", # Could this be made optional? 28 | ] 29 | 30 | # `version` is automatically set by flit to use `ml_collections.__version__` 31 | dynamic = ["version"] 32 | 33 | [project.urls] 34 | homepage = "https://github.com/google/ml_collections" 35 | repository = "https://github.com/google/ml_collections" 36 | documentation = "https://ml-collections.readthedocs.io" 37 | 38 | [project.optional-dependencies] 39 | # Development deps (unittest, linting, formatting,...) 40 | # Installed through `pip install -e .[dev]` 41 | dev = [ 42 | "pytest", 43 | "pytest-xdist", 44 | "pylint>=2.6.0", 45 | "pyink", 46 | ] 47 | 48 | [tool.pyink] 49 | # Formatting configuration to follow the Google style guide 50 | line-length = 80 51 | unstable = true 52 | pyink-indentation = 2 53 | pyink-use-majority-quotes = true 54 | 55 | [build-system] 56 | # The build system specifies which backend is used to build/install the project (flit, 57 | # poetry, setuptools,...).
All backends are supported by `pip install` 58 | requires = ["flit_core >=3.8,<4"] 59 | build-backend = "flit_core.buildapi" 60 | 61 | [tool.flit.sdist] 62 | # Flit-specific options (files to exclude from the PyPI package). 63 | # If using another build backend (setuptools, poetry), you can remove this 64 | # section. 65 | exclude = [ 66 | # Do not release test files on PyPI 67 | "**/*_test.py", 68 | ] 69 | --------------------------------------------------------------------------------