├── .github
├── ISSUE_TEMPLATE
│ └── bug_report.md
└── workflows
│ └── pytest_and_autopublish.yml
├── .gitignore
├── LICENSE
├── README.md
├── docs
├── CONTRIBUTING.md
├── Makefile
├── README.md
├── conf.py
├── config_dict.rst
├── config_dict_examples.rst
├── config_flags.rst
├── config_flags_examples.rst
├── index.rst
├── requirements.txt
└── testing.md
├── ml_collections
├── AUTHORS
├── __init__.py
├── config_dict
│ ├── __init__.py
│ ├── config_dict.py
│ ├── examples
│ │ ├── config.py
│ │ ├── config_dict_advanced.py
│ │ ├── config_dict_basic.py
│ │ ├── config_dict_initialization.py
│ │ ├── config_dict_lock.py
│ │ ├── config_dict_placeholder.py
│ │ ├── examples_test.py
│ │ ├── field_reference.py
│ │ └── frozen_config_dict.py
│ └── tests
│ │ ├── config_dict_test.py
│ │ ├── field_reference_test.py
│ │ └── frozen_config_dict_test.py
├── config_flags
│ ├── __init__.py
│ ├── config_flags.py
│ ├── config_path.py
│ ├── examples
│ │ ├── config.py
│ │ ├── define_config_dataclass_basic.py
│ │ ├── define_config_dict_basic.py
│ │ ├── define_config_file_basic.py
│ │ ├── examples_test.py
│ │ └── parameterised_config.py
│ ├── tests
│ │ ├── config_overriding_test.py
│ │ ├── config_path_test.py
│ │ ├── configdict_config.py
│ │ ├── dataclass_overriding_test.py
│ │ ├── fieldreference_config.py
│ │ ├── ioerror_config.py
│ │ ├── literal_config.py
│ │ ├── mini_config.py
│ │ ├── mock_config.py
│ │ ├── parameterised_config.py
│ │ ├── spork.py
│ │ ├── tuple_parser_test.py
│ │ ├── typeerror_config.py
│ │ └── valueerror_config.py
│ └── tuple_parser.py
└── conftest.py
└── pyproject.toml
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: mohitreddy1996
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 |
15 | ***ConfigDict***
16 | Please consider providing a Colab link reproducing the issue.
17 |
18 | ***ConfigFlags***
19 | Please consider providing a Colab link in case of `config_flags.DEFINE_config_dict`, or a simple repository with a Python script which uses `config_flags.DEFINE_config_file` along with the config file. Ideally, the repository contains a README file describing any setup steps and how to execute the script.
20 |
21 | **Expected behavior**
22 | A clear and concise description of what you expected to happen.
23 |
24 | **Environment:**
25 | - OS: [e.g. MacOS]
26 | - OS Version: [e.g. 22]
27 | - Python: [e.g. Python 3.6]
28 |
29 | **Additional context**
30 | Add any other context about the problem here.
31 |
--------------------------------------------------------------------------------
/.github/workflows/pytest_and_autopublish.yml:
--------------------------------------------------------------------------------
1 | name: Unittests & Auto-publish
2 |
3 | # Allow triggering the workflow manually (e.g. when deps change)
4 | on: [push, workflow_dispatch]
5 |
6 | jobs:
7 | pytest-job:
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ['3.10', '3.11', '3.12']
12 | timeout-minutes: 30
13 |
14 | concurrency:
15 | group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.python-version }}
16 | cancel-in-progress: true
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 |
21 | # Install deps
22 | - uses: actions/setup-python@v5
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | # Uncomment to cache pip dependencies (if tests are too slow)
26 | # cache: pip
27 | # cache-dependency-path: '**/pyproject.toml'
28 |
29 | - run: pip --version
30 | - run: pip install -e .[dev]
31 | - run: pip freeze
32 |
33 | # Run tests (in parallel)
34 | - name: Run core tests
35 | run: |
36 | # TODO(marcenacp): do not ignore tests.
37 | pytest -vv -n auto ml_collections/ \
38 | --ignore=ml_collections/config_dict/examples/examples_test.py
39 |
40 | # Auto-publish when version is increased
41 | publish-job:
42 | # Only try to publish if:
43 | # * Repo is self (prevents running from forks)
44 | # * Branch is `master`
45 | if: |
46 | github.repository == 'google/ml_collections'
47 | && github.ref == 'refs/heads/master'
48 | needs: pytest-job # Only publish after tests are successful
49 | runs-on: ubuntu-latest
50 | permissions:
51 | contents: write
52 | timeout-minutes: 30
53 |
54 | steps:
55 | # Publish the package (if local `__version__` > pip version)
56 | - uses: etils-actions/pypi-auto-publish@v1
57 | with:
58 | pypi-token: ${{ secrets.PYPI_API_TOKEN }}
59 | gh-token: ${{ secrets.GITHUB_TOKEN }}
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | docs/__autosummary
2 | docs/_build
3 | docs/make.bat
4 | *.rej
5 | *~
6 | \#*\#
7 |
8 | # Compiled python modules.
9 | *.pyc
10 |
11 | # Byte-compiled
12 | __pycache__/
13 | .cache/
14 |
15 | # Poetry, setuptools, PyPI distribution artifacts.
16 | /*.egg-info
17 | .eggs/
18 | build/
19 | dist/
20 | poetry.lock
21 |
22 | # Tests
23 | .pytest_cache/
24 |
25 | # Type checking
26 | .pytype/
27 |
28 | # Other
29 | *.DS_Store
30 |
31 | # PyCharm
32 | .idea
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ML Collections
2 |
3 | ML Collections is a library of Python Collections designed for ML use cases.
4 |
5 | [Documentation Status](https://ml-collections.readthedocs.io/en/latest/?badge=latest)
6 | [PyPI version](https://badge.fury.io/py/ml-collections)
7 | [Unittests status](https://github.com/google/ml_collections/actions/workflows/pytest_and_autopublish.yml)
8 |
9 | ## ConfigDict
10 |
11 | The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data
12 | structures with dot access to nested elements. Together, they are supposed to be
13 | used as a main way of expressing configurations of experiments and models.
14 |
15 | This document describes example usage of `ConfigDict`, `FrozenConfigDict`, and
16 | `FieldReference`.
17 |
18 | ### Features
19 |
20 | * Dot-based access to fields.
21 | * Locking mechanism to prevent spelling mistakes.
22 | * Lazy computation.
23 | * FrozenConfigDict() class which is immutable and hashable.
24 | * Type safety.
25 | * "Did you mean" functionality.
26 | * Human readable printing (with valid references and cycles), using valid YAML
27 | format.
28 | * Fields can be passed as keyword arguments using the `**` operator.
29 | * There is one exception to the strong type-safety of the ConfigDict: `int`
30 | values can be passed in to fields of type `float`. In such a case, the value
31 | is type-converted to a `float` before being stored. (Back in the day of
32 | Python 2, there was a similar exception to allow both `str` and `unicode`
33 | values in string fields.)
34 |
35 | ### Basic Usage
36 |
37 | ```python
38 | from ml_collections import config_dict
39 |
40 | cfg = config_dict.ConfigDict()
41 | cfg.float_field = 12.6
42 | cfg.integer_field = 123
43 | cfg.another_integer_field = 234
44 | cfg.nested = config_dict.ConfigDict()
45 | cfg.nested.string_field = 'tom'
46 |
47 | print(cfg.integer_field) # Prints 123.
48 | print(cfg['integer_field']) # Prints 123 as well.
49 |
50 | try:
51 | cfg.integer_field = 'tom' # Raises TypeError as this field is an integer.
52 | except TypeError as e:
53 | print(e)
54 |
55 | cfg.float_field = 12 # Works: `Int` types can be assigned to `Float`.
56 | cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings.
57 |
58 | print(cfg)
59 | ```
60 |
61 | ### FrozenConfigDict
62 |
63 | A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`:
64 |
65 | ```python
66 | from ml_collections import config_dict
67 |
68 | initial_dictionary = {
69 | 'int': 1,
70 | 'list': [1, 2],
71 | 'tuple': (1, 2, 3),
72 | 'set': {1, 2, 3, 4},
73 | 'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
74 | }
75 |
76 | cfg = config_dict.ConfigDict(initial_dictionary)
77 | frozen_dict = config_dict.FrozenConfigDict(initial_dictionary)
78 |
79 | print(frozen_dict.tuple) # Prints tuple (1, 2, 3)
80 | print(frozen_dict.list) # Prints tuple (1, 2)
81 | print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4}
82 | print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2)
83 |
84 | frozen_cfg = config_dict.FrozenConfigDict(cfg)
85 | print(frozen_cfg == frozen_dict) # True
86 | print(hash(frozen_cfg) == hash(frozen_dict)) # True
87 |
88 | try:
89 | frozen_dict.int = 2 # Raises AttributeError as FrozenConfigDict is immutable.
90 | except AttributeError as e:
91 | print(e)
92 |
93 | # Converting between `FrozenConfigDict` and `ConfigDict`:
94 | thawed_frozen_cfg = config_dict.ConfigDict(frozen_dict)
95 | print(thawed_frozen_cfg == cfg) # True
96 | frozen_cfg_to_cfg = frozen_dict.as_configdict()
97 | print(frozen_cfg_to_cfg == cfg) # True
98 | ```
99 |
100 | ### FieldReferences and placeholders
101 |
102 | A `FieldReference` is useful for having multiple fields use the same value. It
103 | can also be used for [lazy computation](#lazy-computation).
104 |
105 | You can use `placeholder()` as a shortcut to create a `FieldReference` (field)
106 | with a `None` default value. This is useful if a program uses optional
107 | configuration fields.
108 |
109 | ```python
110 | from ml_collections import config_dict
111 |
112 | placeholder = config_dict.FieldReference(0)
113 | cfg = config_dict.ConfigDict()
114 | cfg.placeholder = placeholder
115 | cfg.optional = config_dict.placeholder(int)
116 | cfg.nested = config_dict.ConfigDict()
117 | cfg.nested.placeholder = placeholder
118 |
119 | try:
120 | cfg.optional = 'tom' # Raises TypeError as this field is an integer.
121 | except TypeError as e:
122 | print(e)
123 |
124 | cfg.optional = 1555 # Works fine.
125 | cfg.placeholder = 1 # Changes the value of both placeholder and
126 | # nested.placeholder fields.
127 |
128 | print(cfg)
129 | ```
130 |
131 | Note that the indirection provided by `FieldReference`s will be lost if accessed
132 | through a `ConfigDict`.
133 |
134 | ```python
135 | from ml_collections import config_dict
136 |
137 | placeholder = config_dict.FieldReference(0)
138 | cfg.field1 = placeholder
139 | cfg.field2 = placeholder # This field will be tied to cfg.field1.
140 | cfg.field3 = cfg.field1 # This will just be an int field initialized to 0.
141 | ```
142 |
143 | ### Lazy computation
144 |
145 | Using a `FieldReference` in a standard operation (addition, subtraction,
146 | multiplication, etc...) will return another `FieldReference` that points to the
147 | original's value. You can use `FieldReference.get()` to execute the operations
148 | and get the reference's computed value, and `FieldReference.set()` to change the
149 | original reference's value.
150 |
151 | ```python
152 | from ml_collections import config_dict
153 |
154 | ref = config_dict.FieldReference(1)
155 | print(ref.get()) # Prints 1
156 |
157 | add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten
158 | add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer
159 |
160 | print(add_ten) # Prints 11
161 | print(add_ten_lazy.get()) # Prints 11 because ref's value is 1
162 |
163 | # Addition is lazily computed for FieldReferences so changing ref will change
164 | # the value that is used to compute add_ten.
165 | ref.set(5)
166 | print(add_ten) # Prints 11
167 | print(add_ten_lazy.get()) # Prints 15 because ref's value is 5
168 | ```
169 |
170 | If a `FieldReference` has `None` as its original value, or any operation has an
171 | argument of `None`, then the lazy computation will evaluate to `None`.
172 |
173 | We can also use fields in a `ConfigDict` in lazy computation. In this case a
174 | field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.
175 |
176 | ```python
177 | from ml_collections import config_dict
178 |
179 | config = config_dict.ConfigDict()
180 | config.reference_field = config_dict.FieldReference(1)
181 | config.integer_field = 2
182 | config.float_field = 2.5
183 |
184 | # No lazy evaluations because we didn't use get_ref()
185 | config.no_lazy = config.integer_field * config.float_field
186 |
187 | # This will lazily evaluate ONLY config.integer_field
188 | config.lazy_integer = config.get_ref('integer_field') * config.float_field
189 |
190 | # This will lazily evaluate ONLY config.float_field
191 | config.lazy_float = config.integer_field * config.get_ref('float_field')
192 |
193 | # This will lazily evaluate BOTH config.integer_field and config.float_field
194 | config.lazy_both = (config.get_ref('integer_field') *
195 | config.get_ref('float_field'))
196 |
197 | config.integer_field = 3
198 | print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value
199 |
200 | print(config.lazy_integer) # Prints 7.5
201 |
202 | config.float_field = 3.5
203 | print(config.lazy_float) # Prints 7.0
204 | print(config.lazy_both) # Prints 10.5
205 | ```
206 |
207 | #### Changing lazily computed values
208 |
209 | Lazily computed values in a ConfigDict can be overridden in the same way as
210 | regular values. The reference to the `FieldReference` used for the lazy
211 | computation will be lost and all computations downstream in the reference graph
212 | will use the new value.
213 |
214 | ```python
215 | from ml_collections import config_dict
216 |
217 | config = config_dict.ConfigDict()
218 | config.reference = 1
219 | config.reference_0 = config.get_ref('reference') + 10
220 | config.reference_1 = config.get_ref('reference') + 20
221 | config.reference_1_0 = config.get_ref('reference_1') + 100
222 |
223 | print(config.reference) # Prints 1.
224 | print(config.reference_0) # Prints 11.
225 | print(config.reference_1) # Prints 21.
226 | print(config.reference_1_0) # Prints 121.
227 |
228 | config.reference_1 = 30
229 |
230 | print(config.reference) # Prints 1 (unchanged).
231 | print(config.reference_0) # Prints 11 (unchanged).
232 | print(config.reference_1) # Prints 30.
233 | print(config.reference_1_0) # Prints 130.
234 | ```
235 |
236 | #### Cycles
237 |
238 | You cannot create cycles using references. Fortunately
239 | [the only way](#changing-lazily-computed-values) to create a cycle is by
240 | assigning a computed field to one that *is not* the result of computation. This
241 | is forbidden:
242 |
243 | ```python
244 | from ml_collections import config_dict
245 |
246 | config = config_dict.ConfigDict()
247 | config.integer_field = 1
248 | config.bigger_integer_field = config.get_ref('integer_field') + 10
249 |
250 | try:
251 | # Raises a MutabilityError because setting config.integer_field would
252 | # cause a cycle.
253 | config.integer_field = config.get_ref('bigger_integer_field') + 2
254 | except config_dict.MutabilityError as e:
255 | print(e)
256 | ```
257 |
258 | #### One-way references
259 |
260 | One gotcha with `get_ref` is that it creates a bi-directional dependency when no operations are performed on the value.
261 |
262 | ```python
263 | from ml_collections import config_dict
264 |
265 | config = config_dict.ConfigDict()
266 | config.reference = 1
267 | config.reference_0 = config.get_ref('reference')
268 | config.reference_0 = 2
269 | print(config.reference) # Prints 2.
270 | print(config.reference_0) # Prints 2.
271 | ```
272 |
273 | This can be avoided by using `get_oneway_ref` instead of `get_ref`.
274 |
275 | ```python
276 | from ml_collections import config_dict
277 |
278 | config = config_dict.ConfigDict()
279 | config.reference = 1
280 | config.reference_0 = config.get_oneway_ref('reference')
281 | config.reference_0 = 2
282 | print(config.reference) # Prints 1.
283 | print(config.reference_0) # Prints 2.
284 | ```
285 |
286 | ### Advanced usage
287 |
288 | Here are some more advanced examples showing lazy computation with different
289 | operators and data types.
290 |
291 | ```python
292 | from ml_collections import config_dict
293 |
294 | config = config_dict.ConfigDict()
295 | config.float_field = 12.6
296 | config.integer_field = 123
297 | config.list_field = [0, 1, 2]
298 |
299 | config.float_multiply_field = config.get_ref('float_field') * 3
300 | print(config.float_multiply_field) # Prints 37.8
301 |
302 | config.float_field = 10.0
303 | print(config.float_multiply_field) # Prints 30.0
304 |
305 | config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
306 | print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5]
307 |
308 | config.list_field = [-1]
309 | print(config.longer_list_field) # Prints [-1, 3, 4, 5]
310 |
311 | # Both operands can be references
312 | config.ref_subtraction = (
313 | config.get_ref('float_field') - config.get_ref('integer_field'))
314 | print(config.ref_subtraction) # Prints -113.0
315 |
316 | config.integer_field = 10
317 | print(config.ref_subtraction) # Prints 0.0
318 | ```
319 |
320 | ### Equality checking
321 |
322 | You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict`
323 | and `FrozenConfigDict` objects.
324 |
325 | ```python
326 | from ml_collections import config_dict
327 |
328 | dict_1 = {'list': [1, 2]}
329 | dict_2 = {'list': (1, 2)}
330 | cfg_1 = config_dict.ConfigDict(dict_1)
331 | frozen_cfg_1 = config_dict.FrozenConfigDict(dict_1)
332 | frozen_cfg_2 = config_dict.FrozenConfigDict(dict_2)
333 |
334 | # True because FrozenConfigDict converts lists to tuples
335 | print(frozen_cfg_1.items() == frozen_cfg_2.items())
336 | # False because == distinguishes the underlying difference
337 | print(frozen_cfg_1 == frozen_cfg_2)
338 |
339 | # False because == distinguishes these types
340 | print(frozen_cfg_1 == cfg_1)
341 | # But eq_as_configdict() treats both as ConfigDict, so these are True:
342 | print(frozen_cfg_1.eq_as_configdict(cfg_1))
343 | print(cfg_1.eq_as_configdict(frozen_cfg_1))
344 | ```
345 |
346 | ### Equality checking with lazy computation
347 |
348 | Equality checks see if the computed values are the same. Equality is satisfied
349 | even if two sets of computations are different, as long as they result in the
350 | same value.
351 |
352 | ```python
353 | from ml_collections import config_dict
354 |
355 | cfg_1 = config_dict.ConfigDict()
356 | cfg_1.a = 1
357 | cfg_1.b = cfg_1.get_ref('a') + 2
358 |
359 | cfg_2 = config_dict.ConfigDict()
360 | cfg_2.a = 1
361 | cfg_2.b = cfg_2.get_ref('a') * 3
362 |
363 | # True because all computed values are the same
364 | print(cfg_1 == cfg_2)
365 | ```
366 |
367 | ### Locking and copying
368 |
369 | Here is an example with `lock()` and `deepcopy()`:
370 |
371 | ```python
372 | import copy
373 | from ml_collections import config_dict
374 |
375 | cfg = config_dict.ConfigDict()
376 | cfg.integer_field = 123
377 |
378 | # Locking prohibits the addition and deletion of new fields but allows
379 | # modification of existing values.
380 | cfg.lock()
381 | try:
382 | cfg.intagar_field = 124 # Modifies the wrong field
383 | except AttributeError as e: # Raises AttributeError and suggests valid field.
384 | print(e)
385 | with cfg.unlocked():
386 | cfg.intagar_field = 1555 # Works fine.
387 |
388 | # Get a copy of the config dict.
389 | new_cfg = copy.deepcopy(cfg)
390 | new_cfg.integer_field = -123 # Works fine.
391 |
392 | print(cfg)
393 | print(new_cfg)
394 | ```
395 |
396 | Output:
397 |
398 | ```
399 | 'Key "intagar_field" does not exist and cannot be added since the config is locked. Other fields present: "{\'integer_field\': 123}"\nDid you mean "integer_field" instead of "intagar_field"?'
400 | intagar_field: 1555
401 | integer_field: 123
402 |
403 | intagar_field: 1555
404 | integer_field: -123
405 | ```
406 |
407 | ### Dictionary attributes and initialization
408 |
409 | ```python
410 | from ml_collections import config_dict
411 |
412 | referenced_dict = {'inner_float': 3.14}
413 | d = {
414 | 'referenced_dict_1': referenced_dict,
415 | 'referenced_dict_2': referenced_dict,
416 | 'list_containing_dict': [{'key': 'value'}],
417 | }
418 |
419 | # We can initialize on a dictionary
420 | cfg = config_dict.ConfigDict(d)
421 |
422 | # Reference structure is preserved
423 | print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True
424 |
425 | # And the dict attributes have been converted to ConfigDict
426 | print(type(cfg.referenced_dict_1)) # ConfigDict
427 |
428 | # However, the initialization does not look inside of lists, so dicts inside
429 | # lists are not converted to ConfigDict
430 | print(type(cfg.list_containing_dict[0])) # dict
431 | ```
432 |
433 | ### More Examples
434 |
435 | For more examples, take a look at
436 | [`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples)
437 |
438 | For examples and gotchas specifically about initializing a ConfigDict, see
439 | [`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py).
440 |
441 | ## Config Flags
442 |
443 | This library adds flag definitions to `absl.flags` to handle config files. It
444 | does not wrap `absl.flags` so if using any standard flag definitions alongside
445 | config file flags, users must also import `absl.flags`.
446 |
447 | Currently, this module adds two new flag types, namely `DEFINE_config_file`
448 | which accepts a path to a Python file that generates a configuration, and
449 | `DEFINE_config_dict` which accepts a configuration directly. Configurations are
450 | dict-like structures (see [ConfigDict](#configdict)) whose nested elements
451 | can be overridden using special command-line flags. See the examples below
452 | for more details.
453 |
454 | ### Usage
455 |
456 | Use `ml_collections.config_flags` alongside `absl.flags`. For
457 | example:
458 |
459 | `script.py`:
460 |
461 | ```python
462 | from absl import app
463 | from absl import flags
464 |
465 | from ml_collections import config_flags
466 |
467 | _CONFIG = config_flags.DEFINE_config_file('my_config')
468 | _MY_FLAG = flags.DEFINE_integer('my_flag', None)
469 |
470 | def main(_):
471 | print(_CONFIG.value)
472 | print(_MY_FLAG.value)
473 |
474 | if __name__ == '__main__':
475 | app.run(main)
476 | ```
477 |
478 | `config.py`:
479 |
480 | ```python
481 | # Note that this is a valid Python script.
482 | # get_config() can return an arbitrary dict-like object. However, it is advised
483 | # to use ml_collections.config_dict.ConfigDict.
484 | # See ml_collections/config_dict/examples/config_dict_basic.py
485 |
486 | from ml_collections import config_dict
487 |
488 | def get_config():
489 | config = config_dict.ConfigDict()
490 | config.field1 = 1
491 | config.field2 = 'tom'
492 | config.nested = config_dict.ConfigDict()
493 | config.nested.field = 2.23
494 | config.tuple = (1, 2, 3)
495 | return config
496 | ```
497 |
498 | Warning: If you are using a pickle-based distributed programming framework such
499 | as [Launchpad](https://github.com/deepmind/launchpad#readme), be aware of
500 | limitations on the structure of this script that are
501 | [described below](#config_files_and_pickling).
502 |
503 | Now, after running:
504 |
505 | ```bash
506 | python script.py --my_config=config.py \
507 | --my_config.field1=8 \
508 | --my_config.nested.field=2.1 \
509 | --my_config.tuple='(1, 2, (1, 2))'
510 | ```
511 |
512 | we get:
513 |
514 | ```
515 | field1: 8
516 | field2: tom
517 | nested:
518 | field: 2.1
519 | tuple: !!python/tuple
520 | - 1
521 | - 2
522 | - !!python/tuple
523 | - 1
524 | - 2
525 | ```
526 |
527 | Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main
528 | difference is that the configuration is defined in `script.py` instead of in a
529 | separate file.
530 |
531 | `script.py`:
532 |
533 | ```python
534 | from absl import app
535 |
536 | from ml_collections import config_dict
537 | from ml_collections import config_flags
538 |
539 | config = config_dict.ConfigDict()
540 | config.field1 = 1
541 | config.field2 = 'tom'
542 | config.nested = config_dict.ConfigDict()
543 | config.nested.field = 2.23
544 | config.tuple = (1, 2, 3)
545 |
546 | _CONFIG = config_flags.DEFINE_config_dict('my_config', config)
547 |
548 | def main(_):
549 | print(_CONFIG.value)
550 |
551 | if __name__ == '__main__':
552 | app.run(main)
553 | ```
554 |
555 | `config_file` flags are compatible with the command-line flag syntax. All the
556 | following options are supported for non-boolean values in configurations:
557 |
558 | * `-(-)config.field=value`
559 | * `-(-)config.field value`
560 |
561 | Options for boolean values are slightly different:
562 |
563 | * `-(-)config.boolean_field`: set boolean value to True.
564 | * `-(-)noconfig.boolean_field`: set boolean value to False.
565 | * `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or
566 | `False`.
567 |
568 | Note that `-(-)config.boolean_field value` is not supported.
569 |
570 | ### Parameterising the get_config() function
571 |
572 | It's sometimes useful to be able to pass parameters into `get_config`, and
573 | change what is returned based on this configuration. One example is if you are
574 | grid searching over parameters which have a different hierarchical structure -
575 | the flag needs to be present in the resulting ConfigDict. It would be possible
576 | to include the union of all possible leaf values in your ConfigDict,
577 | but this produces a confusing config result as you have to remember which
578 | parameters will actually have an effect and which won't.
579 |
580 | A better system is to pass some configuration, indicating which structure of
581 | ConfigDict should be returned. An example is the following config file:
582 |
583 | ```python
584 | from ml_collections import config_dict
585 |
586 | def get_config(config_string):
587 | possible_structures = {
588 | 'linear': config_dict.ConfigDict({
589 | 'model_constructor': 'snt.Linear',
590 | 'model_config': config_dict.ConfigDict({
591 | 'output_size': 42,
592 | })}),
593 | 'lstm': config_dict.ConfigDict({
594 | 'model_constructor': 'snt.LSTM',
595 | 'model_config': config_dict.ConfigDict({
596 | 'hidden_size': 108,
597 | })
598 | })
599 | }
600 |
601 | return possible_structures[config_string]
602 | ```
603 |
604 | The value of `config_string` will be anything that is to the right of the first
605 | colon in the config file path, if one exists. If no colon exists, no value is
606 | passed to `get_config` (producing a TypeError if `get_config` expects a value).
607 |
608 | The above example can be run like:
609 |
610 | ```bash
611 | python script.py --config=path_to_config.py:linear \
612 | --config.model_config.output_size=256
613 | ```
614 |
615 | or like:
616 |
617 | ```bash
618 | python script.py --config=path_to_config.py:lstm \
619 | --config.model_config.hidden_size=512
620 | ```
621 |
622 | ### Additional features
623 |
624 | * Loads any valid Python script which defines a `get_config()` function
625 | returning any Python object.
626 | * Automatic locking of the loaded object, if the loaded object defines a
627 | callable `.lock()` method.
628 | * Supports command-line overriding of arbitrarily nested values in dict-like
629 | objects (with key/attribute based getters/setters) of the following types:
630 | * `int`
631 | * `float`
632 | * `bool`
633 | * `str`
634 | * `tuple` (but **not** `list`)
635 | * `enum.Enum`
636 | * Overriding is type safe.
637 | * Overriding of a `tuple` can be done by passing in the `tuple` value as a
638 | string (see the example in the [Usage](#usage) section).
639 | * The overriding `tuple` object can be of a different length and have
640 | different item types than the original. Nested tuples are also supported.
641 |
642 | ### Config Files and Pickling {#config_files_and_pickling}
643 |
644 | This is likely to be troublesome:
645 |
646 | ```python {.bad}
647 | @dataclasses.dataclass
648 | class MyRecord:
649 | num_balloons: int
650 | color: str
651 |
652 | def get_config():
653 | return MyRecord(num_balloons=99, color='red')
654 | ```
655 |
656 | This is not:
657 |
658 | ```python {.good}
659 | def get_config():
660 | @dataclasses.dataclass
661 | class MyRecord:
662 | num_balloons: int
663 | color: str
664 |
665 | return MyRecord(num_balloons=99, color='red')
666 | ```
667 |
668 | #### Explanation
669 |
670 | A config file is a Python module but it is not imported through Python's usual
671 | module-importing mechanism.
672 |
673 | Meanwhile, serialization libraries such as [`cloudpickle`](
674 | https://github.com/cloudpipe/cloudpickle#readme) (which is used by [Launchpad](
675 | https://github.com/deepmind/launchpad#readme)) and [Apache Beam](
676 | https://beam.apache.org/) expect to be able to pickle an object without also
677 | pickling every type to which it refers, on the assumption that types defined
678 | at module scope can later be reconstructed simply by re-importing the modules
679 | in which they are defined.
680 |
681 | That assumption does not hold for a type that is defined at module scope in a
682 | config file, because the config file can't be imported the usual way. The
683 | symptom of this will be an `ImportError` when unpickling an object.
684 |
685 | The treatment is to move types from module scope into `get_config()` so that
686 | they will be serialized along with the values that have those types.
687 |
688 | ## Authors
689 | * Sergio Gómez Colmenarejo - sergomez@google.com
690 | * Wojciech Marian Czarnecki - lejlot@google.com
691 | * Nicholas Watters
692 | * Mohit Reddy - mohitreddy@google.com
693 |
--------------------------------------------------------------------------------
/docs/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # ML Collections docs
2 |
3 | The ML Collections documentation can be found here: https://ml-collections.readthedocs.io/en/latest/
4 |
5 | # How to build the docs
6 | 1. Install the requirements in ml_collections/docs/requirements.txt.
7 | 2. Ensure `pandoc` is installed.
8 | 3. Run `make html` to locally generate documentation.
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Configuration file for the Sphinx documentation builder.
16 | #
17 | # This file only contains a selection of the most common options. For a full
18 | # list see the documentation:
19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
20 |
21 | # -- Path setup --------------------------------------------------------------
22 |
23 | # If extensions (or modules to document with autodoc) are in another directory,
24 | # add these directories to sys.path here. If the directory is relative to the
25 | # documentation root, use os.path.abspath to make it absolute, like shown here.
26 | #
27 | # import os
28 | # import sys
29 | # sys.path.insert(0, os.path.abspath('.'))
30 |
31 | import os
32 | import sys
33 | sys.path.insert(0, os.path.abspath('..'))
34 |
35 | # -- Project information -----------------------------------------------------
36 |
37 | project = 'ml_collections'
38 | copyright = '2020, The ML Collection Authors'
39 | author = 'The ML Collection Authors'
40 |
41 | # The full version, including alpha/beta/rc tags
42 | release = '0.1.0'
43 |
44 |
45 | # -- General configuration ---------------------------------------------------
46 |
47 | # Add any Sphinx extension module names here, as strings. They can be
48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
49 | # ones.
50 | extensions = [
51 | 'sphinx.ext.autodoc',
52 | 'sphinx.ext.autosummary',
53 | 'sphinx.ext.intersphinx',
54 | 'sphinx.ext.mathjax',
55 | 'sphinx.ext.napoleon',
56 | 'sphinx.ext.viewcode',
57 | 'nbsphinx',
58 | 'recommonmark',
59 | ]
60 |
61 | # Add any paths that contain templates here, relative to this directory.
62 | templates_path = ['_templates']
63 |
64 | # List of patterns, relative to source directory, that match files and
65 | # directories to ignore when looking for source files.
66 | # This pattern also affects html_static_path and html_extra_path.
67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
68 |
69 | autosummary_generate = True
70 |
71 | master_doc = 'index'
72 |
73 | # -- Options for HTML output -------------------------------------------------
74 |
75 | # The theme to use for HTML and HTML Help pages. See the documentation for
76 | # a list of builtin themes.
77 | #
78 | html_theme = 'sphinx_rtd_theme'
79 |
80 | # Add any paths that contain custom static files (such as style sheets) here,
81 | # relative to this directory. They are copied after the builtin static files,
82 | # so a file named "default.css" will overwrite the builtin "default.css".
83 | html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/config_dict.rst:
--------------------------------------------------------------------------------
1 | ml_collections.config_dict package
2 | ==================================
3 |
4 | .. currentmodule:: ml_collections.config_dict
5 |
6 | .. automodule:: ml_collections.config_dict
7 |
8 | ConfigDict class
9 | ----------------
10 | .. autoclass:: ConfigDict
11 | :members: __init__, is_type_safe, convert_dict, lock, is_locked, unlock,
12 | get, get_oneway_ref, items, iteritems, keys, iterkeys, values, itervalues,
13 | eq_as_configdict, to_yaml, to_json, to_json_best_effort, to_dict,
14 | copy_and_resolve_references, unlocked, ignore_type, get_type, update,
15 | update_from_flattened_dict
16 |
17 | FrozenConfigDict class
18 | ----------------------
19 | .. autoclass:: FrozenConfigDict
20 | :members: __init__
21 |
22 | FieldReference class
23 | --------------------
24 | .. autoclass:: FieldReference
25 | :members: __init__, has_cycle, set, empty, get, get_type, identity, to_int,
26 | to_float, to_str
27 |
28 | Additional Methods
29 | ------------------
30 | .. autosummary::
31 | :toctree: __autosummary
32 |
33 | create
34 | placeholder
35 | required_placeholder
36 | recursive_rename
37 | CustomJSONEncoder
38 | JSONDecodeError
39 | MutabilityError
40 | RequiredValueError
--------------------------------------------------------------------------------
/docs/config_dict_examples.rst:
--------------------------------------------------------------------------------
1 | ml_collections.config_dict examples
2 | ===================================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | ConfigDict Basic
8 | ConfigDict Advanced
9 | ConfigDict Lock
10 | ConfigDict Placeholder
11 | Field Reference
12 | FrozenConfigDict
--------------------------------------------------------------------------------
/docs/config_flags.rst:
--------------------------------------------------------------------------------
1 | ml_collections.config_flags package
2 | ===================================
3 |
4 | .. currentmodule:: ml_collections.config_flags
5 |
6 | .. automodule:: ml_collections.config_flags
7 |
8 | DEFINE_config_dict
9 | ------------------
10 | .. automethod:: ml_collections.config_flags.DEFINE_config_dict
11 |
12 | DEFINE_config_file
13 | ------------------
14 | .. automethod:: ml_collections.config_flags.DEFINE_config_file
15 |
16 | Additional Methods
17 | ------------------
18 | .. autosummary::
19 | :toctree: __autosummary
20 |
21 | ml_collections.config_flags.config_flags.is_config_flag
22 | ml_collections.config_flags.config_flags.GetValue
23 | ml_collections.config_flags.config_flags.GetType
24 | ml_collections.config_flags.config_flags.GetTypes
25 | ml_collections.config_flags.config_flags.SetValue
--------------------------------------------------------------------------------
/docs/config_flags_examples.rst:
--------------------------------------------------------------------------------
1 | ml_collections.config_flags examples
2 | ====================================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | DEFINE_config_dict
8 | DEFINE_config_file
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. ml_collections documentation master file, created by
2 | sphinx-quickstart on Mon Aug 24 04:09:18 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to ml_collections's documentation!
7 | ==========================================
8 |
9 | ML Collections is a library of Python Collections designed for ML use cases.
10 |
11 | .. toctree::
12 | :maxdepth: 1
13 | :caption: Quickstart
14 |
15 | README
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Examples
20 |
21 | config_dict_examples
22 | config_flags_examples
23 |
24 | .. toctree::
25 | :maxdepth: 2
26 | :caption: API Reference
27 |
28 | config_dict
29 | config_flags
30 |
31 | .. toctree::
32 | :maxdepth: 1
33 | :caption: Additional material
34 |
35 | CONTRIBUTING
36 | testing
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx_rtd_theme
3 | nbsphinx
4 | recommonmark
--------------------------------------------------------------------------------
/docs/testing.md:
--------------------------------------------------------------------------------
1 | ## Testing
2 |
3 | TODO(mohitreddy): Add sections to install bazel and run tests.
4 |
--------------------------------------------------------------------------------
/ml_collections/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of ML Collections's significant contributors.
2 | #
3 | # This does not necessarily list everyone who has contributed code,
4 | # especially since many employees of one corporation may be contributing.
5 | # To see the full list of contributors, see the revision history in
6 | # source control.
7 |
8 | DeepMind Technologies Limited
9 | Google LLC
--------------------------------------------------------------------------------
/ml_collections/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ML Collections is a library of Python collections designed for ML use cases."""
16 |
17 | from ml_collections.config_dict import ConfigDict
18 | from ml_collections.config_dict import FieldReference
19 | from ml_collections.config_dict import FrozenConfigDict
20 |
21 | __all__ = ("ConfigDict", "FieldReference", "FrozenConfigDict")  # Public API: the three classes re-exported from ml_collections.config_dict above.
22 |
23 | # A new PyPI release will be pushed every time `__version__` is increased.
24 | __version__ = "1.1.0"
25 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Classes for defining configurations of experiments and models."""
16 |
17 | from .config_dict import _Op
18 | from .config_dict import ConfigDict
19 | from .config_dict import create
20 | from .config_dict import CustomJSONEncoder
21 | from .config_dict import FieldReference
22 | from .config_dict import FrozenConfigDict
23 | from .config_dict import JSONDecodeError
24 | from .config_dict import MutabilityError
25 | from .config_dict import placeholder
26 | from .config_dict import recursive_rename
27 | from .config_dict import required_placeholder
28 | from .config_dict import RequiredValueError
29 |
30 | __all__ = ("_Op", "ConfigDict", "create", "CustomJSONEncoder", "FieldReference",
31 | "FrozenConfigDict", "JSONDecodeError", "MutabilityError",
32 | "placeholder", "recursive_rename", "required_placeholder",
33 | "RequiredValueError")  # Mirrors the names imported from .config_dict above, including the underscore-prefixed `_Op`.
34 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of a config file using ConfigDict.
16 |
17 | The idea of this configuration file is to show a typical use case of ConfigDict,
18 | as well as its limitations. This also exemplifies a self-referencing ConfigDict.
19 | """
20 |
21 | import copy
22 | from ml_collections import config_dict
23 |
24 |
25 | def _get_flat_config():  # Builds and returns a new ConfigDict on every call; get_config() calls it twice.
26 | """Helper to generate simple config without references."""
27 |
28 | # The suggested way to create a ConfigDict() is to call its constructor
29 | # and assign all relevant fields.
30 | config = config_dict.ConfigDict()
31 |
32 | # In order to add new attributes you can just use . notation, like with any
33 | # python object. They will be tracked by ConfigDict, and you get type checking
34 | # etc. for free.
35 | config.integer = 23
36 | config.float = 2.34
37 | config.string = 'james'
38 | config.bool = True
39 |
40 | # It is possible to assign dictionaries to ConfigDict and they will be
41 | # automatically and recursively wrapped with ConfigDict. However, make sure
42 | # that the dict you are assigning does not use internal references/cycles as
43 | # this is not supported. Instead, create the dicts explicitly as demonstrated
44 | # by get_config(). But note that this operation makes an element-by-element
45 | # copy of your original dict.
46 |
47 | # Also note that the recursive wrapping on input dictionaries with ConfigDict
48 | # does not extend through non-dictionary types (including basic Python types
49 | # and custom classes). This causes unexpected behavior most commonly if a
50 | # value is a list of dictionaries, so avoid giving ConfigDict such inputs.
51 | config.dict = {  # A plain nested dict; per the comments above it is wrapped recursively on assignment.
52 | 'integer': 1,
53 | 'float': 3.14,
54 | 'string': 'mark',
55 | 'bool': False,
56 | 'dict': {
57 | 'float': 5  # NOTE(review): the key is named 'float' but the value is the int 5 - confirm this is intended.
58 | }
59 | }
60 | return config
61 |
62 |
63 | def get_config():
64 | """Returns a ConfigDict instance describing a complex config.
65 |
66 | Returns:
67 | A ConfigDict instance with the structure:
68 |
69 | ```
70 | CONFIG-+-- integer
71 | |-- float
72 | |-- string
73 | |-- bool
74 | |-- dict +-- integer
75 | | |-- float
76 | | |-- string
77 | | |-- bool
78 | | |-- dict +-- float
79 | |
80 | |-- object +-- integer
81 | | |-- float
82 | | |-- string
83 | | |-- bool
84 | | |-- dict +-- integer
85 | | |-- float
86 | | |-- string
87 | | |-- bool
88 | | |-- dict +-- float
89 | |
90 | |-- object_copy +-- integer
91 | | |-- float
92 | | |-- string
93 | | |-- bool
94 | | |-- dict +-- integer
95 | | |-- float
96 | | |-- string
97 | | |-- bool
98 | | |-- dict +-- float
99 | |
100 | |-- object_reference [reference pointing to CONFIG-+--object]
101 | ```
102 | """
103 | config = _get_flat_config()  # Top-level scalar fields plus the nested `dict` field.
104 | config.object = _get_flat_config()  # A second, separately built config nested under `object`.
105 |
106 | # References work just fine, so you will be able to override both
107 | # values at the same time. The rule is the same as for python objects,
108 | # everything that is mutable is passed as a reference, thus it will not work
109 | # with assigning integers or strings, but will work just fine with
110 | # ConfigDicts.
111 | # WARNING: Each time you assign a dictionary as a value it will create a new
112 | # instance of ConfigDict in memory, thus it will be a copy of the original
113 | # dict and not a reference to the original.
114 | config.object_reference = config.object
115 |
116 | # ConfigDict supports deepcopying; unlike object_reference, object_copy is independent of config.object.
117 | config.object_copy = copy.deepcopy(config.object)
118 |
119 | return config
120 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config_dict_advanced.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of ConfigDict usage.
16 |
17 | This example includes loading a ConfigDict in FLAGS, locking it, type
18 | safety, iteration over fields, checking for a particular field, unpacking with
19 | `**`, and loading dictionary from string representation.
20 | """
21 |
22 | from absl import app
23 | from ml_collections import config_flags
24 | import yaml
25 |
26 | _CONFIG = config_flags.DEFINE_config_file(
27 | 'my_config',  # Flag name; its value is read in main() via _CONFIG.value.
28 | default='ml_collections/config_dict/examples/config.py')  # Relative path - presumably resolved from the repository root; TODO confirm.
29 |
30 |
31 | def hello_function(string, **unused_kwargs):  # Extra keyword arguments are accepted and ignored so main() can call this with **config.
32 | return 'Hello {}'.format(string)
33 |
34 |
35 | def print_section(name):  # Prints a section banner: two blank lines, `name` upper-cased between two dashed rules, then a blank line.
36 | print()
37 | print()
38 | print('-' * len(name))  # Rule length matches the title length.
39 | print(name.upper())
40 | print('-' * len(name))
41 | print()
42 |
43 |
44 | def main(_):
45 | # Config is already loaded in FLAGS.my_config due to the logic hidden
46 | # in app.run().
47 | config = _CONFIG.value
48 |
49 | print_section('Printing config.')
50 | print(config)
51 |
52 | # Config is of our type ConfigDict.
53 | print('Type of the config {}'.format(type(config)))
54 |
55 | # By default it is locked, thus you cannot add new fields.
56 | # This prevents you from misspelling your attribute name.
57 | print_section('Locking.')
58 | print('config.is_locked={}'.format(config.is_locked))
59 | try:
60 | config.object.new_field = -3
61 | except AttributeError as e:
62 | print(e)
63 |
64 | # There is also a "did you mean" feature: `floet` below is a misspelling of the existing `float` field.
65 | try:
66 | config.object.floet = -3.
67 | except AttributeError as e:
68 | print(e)
69 |
70 | # However if you want to modify it you can always unlock.
71 | print_section('Unlocking.')
72 | with config.unlocked():
73 | config.object.new_field = -3
74 | print('config.object.new_field={}'.format(config.object.new_field))
75 |
76 | # By default config is also type-safe, so you cannot change the type of any
77 | # field.
78 | print_section('Type safety.')
79 | try:
80 | config.float = 'jerry'
81 | except TypeError as e:
82 | print(e)
83 | config.float = -1.2  # Still a float, so this assignment is accepted.
84 | print('config.float={}'.format(config.float))
85 |
86 | # NoneType is ignored by type safety and can both override and be overridden.
87 | config.float = None
88 | config.float = -1.2
89 |
90 | # You can temporarily turn type safety off.
91 | with config.ignore_type():
92 | config.float = 'tom'
93 | print('config.float={}'.format(config.float))
94 | config.float = 2.3
95 | print('config.float={}'.format(config.float))
96 |
97 | # You can use ConfigDict as a regular dict in many typical use-cases:
98 | # Iteration over fields:
99 | print_section('Iteration over fields.')
100 | for field in config:
101 | print('config has field "{}"'.format(field))
102 |
103 | # Checking if it contains a particular field using the "in" operator.
104 | print_section('Checking for a particular field.')
105 | for field in ('float', 'non_existing'):
106 | if field in config:
107 | print('"{}" is in config'.format(field))
108 | else:
109 | print('"{}" is not in config'.format(field))
110 |
111 | # Using ** unrolling to pass the config to a function as named arguments.
112 | print_section('Unpacking with **')
113 | print(hello_function(**config))
114 |
115 | # You can even load a dictionary (notice it is not ConfigDict anymore) from
116 | # a yaml string representation of ConfigDict.
117 | # Note: __repr__ (not __str__) is the recommended representation, as it
118 | # preserves FieldReferences and placeholders. yaml.UnsafeLoader can build arbitrary Python objects, so only use it on trusted input such as this locally produced repr.
119 | print_section('Loading dictionary from string representation.')
120 | dictionary = yaml.load(repr(config), yaml.UnsafeLoader)
121 | print('dict["object_reference"]["dict"]["dict"]["float"]={}'.format(
122 | dictionary['object_reference']['dict']['dict']['float']))
124 |
125 | if __name__ == '__main__':
126 | app.run(main)
127 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config_dict_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of basic ConfigDict usage.
16 |
17 | This example shows the most basic usage of ConfigDict, including type safety.
18 | For examples of more features, see config_dict_advanced.
19 | """
20 |
21 | from absl import app
22 | from ml_collections import config_dict
23 |
24 |
25 | def main(_):
26 | cfg = config_dict.ConfigDict()
27 | cfg.float_field = 12.6
28 | cfg.integer_field = 123
29 | cfg.another_integer_field = 234
30 | cfg.nested = config_dict.ConfigDict()
31 | cfg.nested.string_field = 'tom'
32 |
33 | print(cfg.integer_field) # Prints 123.
34 | print(cfg['integer_field']) # Prints 123 as well.
35 |
36 | try:
37 | cfg.integer_field = 'tom' # Raises TypeError as this field is an integer.
38 | except TypeError as e:
39 | print(e)
40 |
41 | cfg.float_field = 12 # Works: `int` types can be assigned to `float`.
42 | cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings.
43 |
44 | print(cfg)
45 |
46 |
47 | if __name__ == '__main__':
48 | app.run(main)
49 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config_dict_initialization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of initialization features and gotchas in a ConfigDict.
16 | """
17 |
18 | import copy
19 |
20 | from absl import app
21 | from ml_collections import config_dict
22 |
23 |
24 | def print_section(name):
25 | print()
26 | print()
27 | print('-' * len(name))
28 | print(name.upper())
29 | print('-' * len(name))
30 | print()
31 |
32 |
33 | def main(_):
34 |
35 | inner_dict = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))}
36 | example_dict = {
37 | 'string': 'tom',
38 | 'int': 2,
39 | 'list': [1, 2],
40 | 'set': {1, 2},
41 | 'tuple': (1, 2),
42 | 'ref': config_dict.FieldReference({'int': 0}),
43 | 'inner_dict_1': inner_dict,
44 | 'inner_dict_2': inner_dict
45 | }
46 |
47 | print_section('Initializing on dictionary.')
48 | # ConfigDict can be initialized on example_dict
49 | example_cd = config_dict.ConfigDict(example_dict)
50 |
51 | # Dictionary fields are also converted to ConfigDict
52 | print(type(example_cd.inner_dict_1))
53 |
54 | # And the reference structure is preserved
55 | print(id(example_cd.inner_dict_1) == id(example_cd.inner_dict_2))
56 |
57 | print_section('Initializing on ConfigDict.')
58 |
59 | # ConfigDict can also be initialized on a ConfigDict
60 | example_cd_cd = config_dict.ConfigDict(example_cd)
61 |
62 | # Yielding the same result:
63 | print(example_cd == example_cd_cd)
64 |
65 | # Note that the memory addresses are different
66 | print(id(example_cd) == id(example_cd_cd))
67 |
68 | # The memory addresses of the attributes are not the same because of the
69 | # FieldReference, which gets removed on the second initialization
70 | list_to_ids = lambda x: [id(element) for element in x]
71 | print(
72 | set(list_to_ids(list(example_cd.values()))) == set(
73 | list_to_ids(list(example_cd_cd.values()))))
74 |
75 | print_section('Initializing on self-referencing dictionary.')
76 |
77 | # Initialization works on a self-referencing dict
78 | self_ref_dict = copy.deepcopy(example_dict)
79 | self_ref_dict['self'] = self_ref_dict
80 | self_ref_cd = config_dict.ConfigDict(self_ref_dict)
81 |
82 | # And the reference structure is replicated
83 | print(id(self_ref_cd) == id(self_ref_cd.self))
84 |
85 | print_section('Unexpected initialization behavior.')
86 |
87 | # ConfigDict initialization doesn't look inside lists, so doesn't convert a
88 | # dict in a list to ConfigDict
89 | dict_in_list_in_dict = {'list': [{'troublemaker': 0}]}
90 | dict_in_list_in_dict_cd = config_dict.ConfigDict(dict_in_list_in_dict)
91 | print(type(dict_in_list_in_dict_cd.list[0]))
92 |
93 | # This can cause the reference structure to not be replicated
94 | referred_dict = {'key': 'value'}
95 | bad_reference = {'referred_dict': referred_dict, 'list': [referred_dict]}
96 | bad_reference_cd = config_dict.ConfigDict(bad_reference)
97 | print(id(bad_reference_cd.referred_dict) == id(bad_reference_cd.list[0]))
98 |
99 |
100 | if __name__ == '__main__':
101 | app.run()
102 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config_dict_lock.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of ConfigDict usage of lock.
16 |
17 | This example shows the roles and scopes of ConfigDict's lock().
18 | """
19 |
20 | from absl import app
21 | from ml_collections import config_dict
22 |
23 |
24 | def main(_):
25 | cfg = config_dict.ConfigDict()
26 | cfg.integer_field = 123
27 |
28 | # Locking prohibits the addition and deletion of new fields but allows
29 | # modification of existing values. Locking happens automatically during
30 | # loading through flags.
31 | cfg.lock()
32 | try:
33 | cfg.intagar_field = 124 # Raises AttributeError and suggests valid field.
34 | except AttributeError as e:
35 | print(e)
36 | cfg.integer_field = -123 # Works fine.
37 |
38 | with cfg.unlocked():
39 | cfg.intagar_field = 1555 # Works fine too.
40 |
41 | print(cfg)
42 |
43 |
44 | if __name__ == '__main__':
45 | app.run()
46 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/config_dict_placeholder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of placeholder fields in a ConfigDict.
16 |
17 | This example shows how ConfigDict placeholder fields work. For a more complete
18 | example of ConfigDict features, see config_dict_advanced.
19 | """
20 |
21 | from absl import app
22 | from ml_collections import config_dict
23 |
24 |
25 | def main(_):
26 | placeholder = config_dict.FieldReference(0)
27 | cfg = config_dict.ConfigDict()
28 | cfg.placeholder = placeholder
29 | cfg.optional = config_dict.FieldReference(0, field_type=int)
30 | cfg.nested = config_dict.ConfigDict()
31 | cfg.nested.placeholder = placeholder
32 |
33 | try:
34 | cfg.optional = 'tom' # Raises Type error as this field is an integer.
35 | except TypeError as e:
36 | print(e)
37 |
38 | cfg.optional = 1555 # Works fine.
39 | cfg.placeholder = 1 # Changes the value of both placeholder and
40 | # nested.placeholder fields.
41 |
42 | # Note that the indirection provided by FieldReferences will be lost if
43 | # accessed through a ConfigDict:
44 | placeholder = config_dict.FieldReference(0)
45 | cfg.field1 = placeholder
46 | cfg.field2 = placeholder # This field will be tied to cfg.field1.
47 | cfg.field3 = cfg.field1 # This will just be an int field initialized to 0.
48 |
49 | print(cfg)
50 |
51 |
52 | if __name__ == '__main__':
53 | app.run()
54 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/examples_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ConfigDict examples.
16 |
17 | Ensures that config_dict_basic, config_dict_initialization, config_dict_lock,
18 | config_dict_placeholder, field_reference, frozen_config_dict run successfully.
19 | """
20 |
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | from ml_collections.config_dict.examples import config_dict_advanced
24 | from ml_collections.config_dict.examples import config_dict_basic
25 | from ml_collections.config_dict.examples import config_dict_initialization
26 | from ml_collections.config_dict.examples import config_dict_lock
27 | from ml_collections.config_dict.examples import config_dict_placeholder
28 | from ml_collections.config_dict.examples import field_reference
29 | from ml_collections.config_dict.examples import frozen_config_dict
30 |
31 |
32 | class ConfigDictExamplesTest(parameterized.TestCase):
33 |
34 | @parameterized.parameters(config_dict_advanced, config_dict_basic,
35 | config_dict_initialization, config_dict_lock,
36 | config_dict_placeholder, field_reference,
37 | frozen_config_dict)
38 | def testScriptRuns(self, example_name):
39 | example_name.main(None)
40 |
41 |
42 | if __name__ == '__main__':
43 | absltest.main()
44 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/field_reference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of FieldReference usage.
16 |
17 | This shows how to use FieldReferences for lazy computation.
18 | """
19 |
20 | from absl import app
21 | from ml_collections import config_dict
22 |
23 |
24 | def lazy_computation():
25 | """Simple example of lazy computation with `config_dict.FieldReference`."""
26 | ref = config_dict.FieldReference(1)
27 | print(ref.get()) # Prints 1
28 |
29 | add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten
30 | add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer
31 |
32 | print(add_ten) # Prints 11
33 | print(add_ten_lazy.get()) # Prints 11 because ref's value is 1
34 |
35 | # Addition is lazily computed for FieldReferences so changing ref will change
36 | # the value that is used to compute add_ten_lazy.
37 | ref.set(5)
38 | print(add_ten) # Prints 11
39 | print(add_ten_lazy.get()) # Prints 15 because ref's value is 5
40 |
41 |
42 | def change_lazy_computation():
43 | """Overriding lazily computed values."""
44 | config = config_dict.ConfigDict()
45 | config.reference = 1
46 | config.reference_0 = config.get_ref('reference') + 10
47 | config.reference_1 = config.get_ref('reference') + 20
48 | config.reference_1_0 = config.get_ref('reference_1') + 100
49 |
50 | print(config.reference) # Prints 1.
51 | print(config.reference_0) # Prints 11.
52 | print(config.reference_1) # Prints 21.
53 | print(config.reference_1_0) # Prints 121.
54 |
55 | config.reference_1 = 30
56 |
57 | print(config.reference) # Prints 1 (unchanged).
58 | print(config.reference_0) # Prints 11 (unchanged).
59 | print(config.reference_1) # Prints 30.
60 | print(config.reference_1_0) # Prints 130.
61 |
62 |
63 | def create_cycle():
64 | """Creates a cycle within a ConfigDict."""
65 | config = config_dict.ConfigDict()
66 | config.integer_field = 1
67 | config.bigger_integer_field = config.get_ref('integer_field') + 10
68 |
69 | try:
70 | # Raises a MutabilityError because setting config.integer_field would
71 | # cause a cycle.
72 | config.integer_field = config.get_ref('bigger_integer_field') + 2
73 | except config_dict.MutabilityError as e:
74 | print(e)
75 |
76 |
77 | def lazy_configdict():
78 | """Example usage of lazy computation with ConfigDict."""
79 | config = config_dict.ConfigDict()
80 | config.reference_field = config_dict.FieldReference(1)
81 | config.integer_field = 2
82 | config.float_field = 2.5
83 |
84 | # No lazy evaluations because we didn't use get_ref()
85 | config.no_lazy = config.integer_field * config.float_field
86 |
87 | # This will lazily evaluate ONLY config.integer_field
88 | config.lazy_integer = config.get_ref('integer_field') * config.float_field
89 |
90 | # This will lazily evaluate ONLY config.float_field
91 | config.lazy_float = config.integer_field * config.get_ref('float_field')
92 |
93 | # This will lazily evaluate BOTH config.integer_field and config.float_field
94 | config.lazy_both = (config.get_ref('integer_field') *
95 | config.get_ref('float_field'))
96 |
97 | config.integer_field = 3
98 | print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value
99 |
100 | print(config.lazy_integer) # Prints 7.5
101 |
102 | config.float_field = 3.5
103 | print(config.lazy_float) # Prints 7.0
104 | print(config.lazy_both) # Prints 10.5
105 |
106 |
107 | def lazy_configdict_advanced():
108 | """Advanced lazy computation with ConfigDict."""
109 | # FieldReferences can be used with ConfigDict as well
110 | config = config_dict.ConfigDict()
111 | config.float_field = 12.6
112 | config.integer_field = 123
113 | config.list_field = [0, 1, 2]
114 |
115 | config.float_multiply_field = config.get_ref('float_field') * 3
116 | print(config.float_multiply_field) # Prints 37.8
117 |
118 | config.float_field = 10.0
119 | print(config.float_multiply_field) # Prints 30.0
120 |
121 | config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
122 | print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5]
123 |
124 | config.list_field = [-1]
125 | print(config.longer_list_field) # Prints [-1, 3, 4, 5]
126 |
127 | # Both operands can be references
128 | config.ref_subtraction = (
129 | config.get_ref('float_field') - config.get_ref('integer_field'))
130 | print(config.ref_subtraction) # Prints -113.0
131 |
132 | config.integer_field = 10
133 | print(config.ref_subtraction) # Prints 0.0
134 |
135 |
136 | def main(argv=()):
137 | del argv # Unused.
138 | lazy_computation()
139 | lazy_configdict()
140 | change_lazy_computation()
141 | create_cycle()
142 | lazy_configdict_advanced()
143 |
144 |
145 | if __name__ == '__main__':
146 | app.run()
147 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/examples/frozen_config_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of basic FrozenConfigDict usage.
16 |
17 | This example shows the most basic usage of FrozenConfigDict, highlighting
18 | the differences between FrozenConfigDict and ConfigDict and including
19 | converting between the two.
20 | """
21 |
22 | from absl import app
23 | from ml_collections import config_dict
24 |
25 |
26 | def print_section(name):
27 | print()
28 | print()
29 | print('-' * len(name))
30 | print(name.upper())
31 | print('-' * len(name))
32 | print()
33 |
34 |
35 | def main(_):
36 | print_section('Attribute Types.')
37 | cfg = config_dict.ConfigDict()
38 | cfg.int = 1
39 | cfg.list = [1, 2, 3]
40 | cfg.tuple = (1, 2, 3)
41 | cfg.set = {1, 2, 3}
42 | cfg.frozenset = frozenset({1, 2, 3})
43 | cfg.dict = {
44 | 'nested_int': 4,
45 | 'nested_list': [4, 5, 6],
46 | 'nested_tuple': ([4], 5, 6),
47 | }
48 |
49 | print('Types of cfg fields:')
50 | print('list: ', type(cfg.list)) # List
51 | print('set: ', type(cfg.set)) # Set
52 | print('nested_list: ', type(cfg.dict.nested_list)) # List
53 | print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0])) # List
54 |
55 | frozen_cfg = config_dict.FrozenConfigDict(cfg)
56 | print('\nTypes of FrozenConfigDict(cfg) fields:')
57 | print('list: ', type(frozen_cfg.list)) # Tuple
58 | print('set: ', type(frozen_cfg.set)) # Frozenset
59 | print('nested_list: ', type(frozen_cfg.dict.nested_list)) # Tuple
60 | print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0])) # Tuple
61 |
62 | cfg_from_frozen = config_dict.ConfigDict(frozen_cfg)
63 | print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:')
64 | print('list: ', type(cfg_from_frozen.list)) # List
65 | print('set: ', type(cfg_from_frozen.set)) # Set
66 | print('nested_list: ', type(cfg_from_frozen.dict.nested_list)) # List
67 | print('nested_tuple[0]: ', type(cfg_from_frozen.dict.nested_tuple[0])) # List
68 |
69 | print('\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
70 | print(cfg_from_frozen == frozen_cfg.as_configdict()) # True
71 |
72 | print_section('Immutability.')
73 | try:
74 | frozen_cfg.new_field = 1 # Raises AttributeError because of immutability.
75 | except AttributeError as e:
76 | print(e)
77 |
78 | print_section('"==" and eq_as_configdict().')
79 | # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict
80 | print(frozen_cfg == cfg) # False
81 | # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to
82 | # ConfigDict
83 | print(frozen_cfg.eq_as_configdict(cfg)) # True
84 | # .eq_as_configdict() is also a method of ConfigDict
85 | print(cfg.eq_as_configdict(frozen_cfg)) # True
86 |
87 |
88 | if __name__ == '__main__':
89 | app.run(main)
90 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/tests/field_reference_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for config_dict.FieldReference."""
16 |
17 | import operator
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from ml_collections import config_dict
22 |
23 |
24 | class FieldReferenceTest(parameterized.TestCase):
25 |
26 | def _test_binary_operator(self,
27 | initial_value,
28 | other_value,
29 | op,
30 | true_value,
31 | new_initial_value,
32 | new_true_value,
33 | assert_fn=None):
34 | """Helper for testing binary operators.
35 |
36 | Generally speaking this checks that:
37 | 1. `op(initial_value, other_value) COMP true_value`
38 | 2. `op(new_initial_value, other_value) COMP new_true_value`
39 | where `COMP` is the comparison function defined by `assert_fn`.
40 |
41 | Args:
42 | initial_value: Initial value for the `FieldReference`, this is the first
43 | argument for the binary operator.
44 | other_value: The second argument for the binary operator.
45 | op: The binary operator.
46 | true_value: The expected output of the binary operator.
47 | new_initial_value: The value that the `FieldReference` is changed to.
48 | new_true_value: The expected output of the binary operator after the
49 | `FieldReference` has changed.
50 | assert_fn: Function used to check the output values.
51 | """
52 | if assert_fn is None:
53 | assert_fn = self.assertEqual
54 |
55 | ref = config_dict.FieldReference(initial_value)
56 | new_ref = op(ref, other_value)
57 | assert_fn(new_ref.get(), true_value)
58 |
59 | config = config_dict.ConfigDict()
60 | config.a = initial_value
61 | config.b = other_value
62 | config.result = op(config.get_ref('a'), config.b)
63 | assert_fn(config.result, true_value)
64 |
65 | config.a = new_initial_value
66 | assert_fn(config.result, new_true_value)
67 |
68 | def _test_unary_operator(self,
69 | initial_value,
70 | op,
71 | true_value,
72 | new_initial_value,
73 | new_true_value,
74 | assert_fn=None):
75 | """Helper for testing unary operators.
76 |
77 | Generally speaking this checks that:
78 | 1. `op(initial_value) COMP true_value`
79 | 2. `op(new_initial_value) COMP new_true_value`
80 | where `COMP` is the comparison function defined by `assert_fn`.
81 |
82 | Args:
83 | initial_value: Initial value for the `FieldReference`, this is the first
84 | argument for the unary operator.
85 | op: The unary operator.
86 | true_value: The expected output of the unary operator.
87 | new_initial_value: The value that the `FieldReference` is changed to.
88 | new_true_value: The expected output of the unary operator after the
89 | `FieldReference` has changed.
90 | assert_fn: Function used to check the output values.
91 | """
92 | if assert_fn is None:
93 | assert_fn = self.assertEqual
94 |
95 | ref = config_dict.FieldReference(initial_value)
96 | new_ref = op(ref)
97 | assert_fn(new_ref.get(), true_value)
98 |
99 | config = config_dict.ConfigDict()
100 | config.a = initial_value
101 | config.result = op(config.get_ref('a'))
102 | assert_fn(config.result, true_value)
103 |
104 | config.a = new_initial_value
105 | assert_fn(config.result, new_true_value)
106 |
107 | def testBasic(self):
108 | ref = config_dict.FieldReference(1)
109 | self.assertEqual(ref.get(), 1)
110 |
111 | def testGetRef(self):
112 | config = config_dict.ConfigDict()
113 | config.a = 1.
114 | config.b = config.get_ref('a') + 10
115 | config.c = config.get_ref('b') + 10
116 | self.assertEqual(config.c, 21.0)
117 |
118 | def testFunction(self):
119 |
120 | def fn(x):
121 | return x + 5
122 |
123 | config = config_dict.ConfigDict()
124 | config.a = 1
125 | config.b = fn(config.get_ref('a'))
126 | config.c = fn(config.get_ref('b'))
127 |
128 | self.assertEqual(config.b, 6)
129 | self.assertEqual(config.c, 11)
130 | config.a = 2
131 | self.assertEqual(config.b, 7)
132 | self.assertEqual(config.c, 12)
133 |
134 | def testCycles(self):
135 | config = config_dict.ConfigDict()
136 | config.a = 1.
137 | config.b = config.get_ref('a') + 10
138 | config.c = config.get_ref('b') + 10
139 |
140 | self.assertEqual(config.b, 11.0)
141 | self.assertEqual(config.c, 21.0)
142 |
143 | # Introduce a cycle
144 | with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
145 | config.a = config.get_ref('c') - 1.0
146 |
147 | # Introduce a cycle on second operand
148 | with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
149 | config.a = config_dict.FieldReference(5.0) + config.get_ref('c')
150 |
151 | # We can create multiple FieldReferences that all point to the same object
152 | l = [0]
153 | config = config_dict.ConfigDict()
154 | config.a = l
155 | config.b = l
156 | config.c = config.get_ref('a') + ['c']
157 | config.d = config.get_ref('b') + ['d']
158 |
159 | self.assertEqual(config.c, [0, 'c'])
160 | self.assertEqual(config.d, [0, 'd'])
161 |
162 | # Make sure nothing was mutated
163 | self.assertEqual(l, [0])
164 | self.assertEqual(config.c, [0, 'c'])
165 |
166 | config.a = [1]
167 | config.b = [2]
168 | self.assertEqual(l, [0])
169 | self.assertEqual(config.c, [1, 'c'])
170 | self.assertEqual(config.d, [2, 'd'])
171 |
172 | @parameterized.parameters(
173 | {
174 | 'initial_value': 1,
175 | 'other_value': 2,
176 | 'true_value': 3,
177 | 'new_initial_value': 10,
178 | 'new_true_value': 12
179 | }, {
180 | 'initial_value': 2.0,
181 | 'other_value': 2.5,
182 | 'true_value': 4.5,
183 | 'new_initial_value': 3.7,
184 | 'new_true_value': 6.2
185 | }, {
186 | 'initial_value': 'hello, ',
187 | 'other_value': 'world!',
188 | 'true_value': 'hello, world!',
189 | 'new_initial_value': 'foo, ',
190 | 'new_true_value': 'foo, world!'
191 | }, {
192 | 'initial_value': ['hello'],
193 | 'other_value': ['world'],
194 | 'true_value': ['hello', 'world'],
195 | 'new_initial_value': ['foo'],
196 | 'new_true_value': ['foo', 'world']
197 | }, {
198 | 'initial_value': config_dict.FieldReference(10),
199 | 'other_value': config_dict.FieldReference(5.0),
200 | 'true_value': 15.0,
201 | 'new_initial_value': 12,
202 | 'new_true_value': 17.0
203 | }, {
204 | 'initial_value': config_dict.placeholder(float),
205 | 'other_value': 7.0,
206 | 'true_value': None,
207 | 'new_initial_value': 12,
208 | 'new_true_value': 19.0
209 | }, {
210 | 'initial_value': 5.0,
211 | 'other_value': config_dict.placeholder(float),
212 | 'true_value': None,
213 | 'new_initial_value': 8.0,
214 | 'new_true_value': None
215 | }, {
216 | 'initial_value': config_dict.placeholder(str),
217 | 'other_value': 'tail',
218 | 'true_value': None,
219 | 'new_initial_value': 'head',
220 | 'new_true_value': 'headtail'
221 | })
222 | def testAdd(self, initial_value, other_value, true_value, new_initial_value,
223 | new_true_value):
224 | self._test_binary_operator(initial_value, other_value, operator.add,
225 | true_value, new_initial_value, new_true_value)
226 |
227 | @parameterized.parameters(
228 | {
229 | 'initial_value': 5,
230 | 'other_value': 3,
231 | 'true_value': 2,
232 | 'new_initial_value': -1,
233 | 'new_true_value': -4
234 | }, {
235 | 'initial_value': 2.0,
236 | 'other_value': 2.5,
237 | 'true_value': -0.5,
238 | 'new_initial_value': 12.3,
239 | 'new_true_value': 9.8
240 | }, {
241 | 'initial_value': set(['hello', 123, 4.5]),
242 | 'other_value': set([123]),
243 | 'true_value': set(['hello', 4.5]),
244 | 'new_initial_value': set([123]),
245 | 'new_true_value': set([])
246 | }, {
247 | 'initial_value': config_dict.FieldReference(10),
248 | 'other_value': config_dict.FieldReference(5.0),
249 | 'true_value': 5.0,
250 | 'new_initial_value': 12,
251 | 'new_true_value': 7.0
252 | }, {
253 | 'initial_value': config_dict.placeholder(float),
254 | 'other_value': 7.0,
255 | 'true_value': None,
256 | 'new_initial_value': 12,
257 | 'new_true_value': 5.0
258 | })
259 | def testSub(self, initial_value, other_value, true_value, new_initial_value,
260 | new_true_value):
261 | self._test_binary_operator(initial_value, other_value, operator.sub,
262 | true_value, new_initial_value, new_true_value)
263 |
264 | @parameterized.parameters(
265 | {
266 | 'initial_value': 1,
267 | 'other_value': 2,
268 | 'true_value': 2,
269 | 'new_initial_value': 3,
270 | 'new_true_value': 6
271 | }, {
272 | 'initial_value': 2.0,
273 | 'other_value': 2.5,
274 | 'true_value': 5.0,
275 | 'new_initial_value': 3.5,
276 | 'new_true_value': 8.75
277 | }, {
278 | 'initial_value': ['hello'],
279 | 'other_value': 3,
280 | 'true_value': ['hello', 'hello', 'hello'],
281 | 'new_initial_value': ['foo'],
282 | 'new_true_value': ['foo', 'foo', 'foo']
283 | }, {
284 | 'initial_value': config_dict.FieldReference(10),
285 | 'other_value': config_dict.FieldReference(5.0),
286 | 'true_value': 50.0,
287 | 'new_initial_value': 1,
288 | 'new_true_value': 5.0
289 | }, {
290 | 'initial_value': config_dict.placeholder(float),
291 | 'other_value': 7.0,
292 | 'true_value': None,
293 | 'new_initial_value': 12,
294 | 'new_true_value': 84.0
295 | })
296 | def testMul(self, initial_value, other_value, true_value, new_initial_value,
297 | new_true_value):
298 | self._test_binary_operator(initial_value, other_value, operator.mul,
299 | true_value, new_initial_value, new_true_value)
300 |
301 | @parameterized.parameters(
302 | {
303 | 'initial_value': 3,
304 | 'other_value': 2,
305 | 'true_value': 1.5,
306 | 'new_initial_value': 10,
307 | 'new_true_value': 5.0
308 | }, {
309 | 'initial_value': 2.0,
310 | 'other_value': 2.5,
311 | 'true_value': 0.8,
312 | 'new_initial_value': 6.3,
313 | 'new_true_value': 2.52
314 | }, {
315 | 'initial_value': config_dict.FieldReference(10),
316 | 'other_value': config_dict.FieldReference(5.0),
317 | 'true_value': 2.0,
318 | 'new_initial_value': 13,
319 | 'new_true_value': 2.6
320 | }, {
321 | 'initial_value': config_dict.placeholder(float),
322 | 'other_value': 7.0,
323 | 'true_value': None,
324 | 'new_initial_value': 17.5,
325 | 'new_true_value': 2.5
326 | })
327 | def testTrueDiv(self, initial_value, other_value, true_value,
328 | new_initial_value, new_true_value):
329 | self._test_binary_operator(initial_value, other_value, operator.truediv,
330 | true_value, new_initial_value, new_true_value)
331 |
332 | @parameterized.parameters(
333 | {
334 | 'initial_value': 3,
335 | 'other_value': 2,
336 | 'true_value': 1,
337 | 'new_initial_value': 7,
338 | 'new_true_value': 3
339 | }, {
340 | 'initial_value': config_dict.FieldReference(10),
341 | 'other_value': config_dict.FieldReference(5),
342 | 'true_value': 2,
343 | 'new_initial_value': 28,
344 | 'new_true_value': 5
345 | }, {
346 | 'initial_value': config_dict.placeholder(int),
347 | 'other_value': 7,
348 | 'true_value': None,
349 | 'new_initial_value': 25,
350 | 'new_true_value': 3
351 | })
352 | def testFloorDiv(self, initial_value, other_value, true_value,
353 | new_initial_value, new_true_value):
354 | self._test_binary_operator(initial_value, other_value, operator.floordiv,
355 | true_value, new_initial_value, new_true_value)
356 |
357 | @parameterized.parameters(
358 | {
359 | 'initial_value': 3,
360 | 'other_value': 2,
361 | 'true_value': 9,
362 | 'new_initial_value': 10,
363 | 'new_true_value': 100
364 | }, {
365 | 'initial_value': 2.7,
366 | 'other_value': 3.2,
367 | 'true_value': 24.0084457245,
368 | 'new_initial_value': 6.5,
369 | 'new_true_value': 399.321543621
370 | }, {
371 | 'initial_value': config_dict.FieldReference(10),
372 | 'other_value': config_dict.FieldReference(5),
373 | 'true_value': 1e5,
374 | 'new_initial_value': 2,
375 | 'new_true_value': 32
376 | }, {
377 | 'initial_value': config_dict.placeholder(float),
378 | 'other_value': 3.0,
379 | 'true_value': None,
380 | 'new_initial_value': 7.0,
381 | 'new_true_value': 343.0
382 | })
383 | def testPow(self, initial_value, other_value, true_value, new_initial_value,
384 | new_true_value):
385 | self._test_binary_operator(
386 | initial_value,
387 | other_value,
388 | operator.pow,
389 | true_value,
390 | new_initial_value,
391 | new_true_value,
392 | assert_fn=self.assertAlmostEqual)
393 |
394 | @parameterized.parameters(
395 | {
396 | 'initial_value': 3,
397 | 'other_value': 2,
398 | 'true_value': 1,
399 | 'new_initial_value': 10,
400 | 'new_true_value': 0
401 | }, {
402 | 'initial_value': 5.3,
403 | 'other_value': 3.2,
404 | 'true_value': 2.0999999999999996,
405 | 'new_initial_value': 77,
406 | 'new_true_value': 0.2
407 | }, {
408 | 'initial_value': config_dict.FieldReference(10),
409 | 'other_value': config_dict.FieldReference(5),
410 | 'true_value': 0,
411 | 'new_initial_value': 32,
412 | 'new_true_value': 2
413 | }, {
414 | 'initial_value': config_dict.placeholder(int),
415 | 'other_value': 7,
416 | 'true_value': None,
417 | 'new_initial_value': 25,
418 | 'new_true_value': 4
419 | })
420 | def testMod(self, initial_value, other_value, true_value, new_initial_value,
421 | new_true_value):
422 | self._test_binary_operator(
423 | initial_value,
424 | other_value,
425 | operator.mod,
426 | true_value,
427 | new_initial_value,
428 | new_true_value,
429 | assert_fn=self.assertAlmostEqual)
430 |
431 | @parameterized.parameters(
432 | {
433 | 'initial_value': True,
434 | 'other_value': True,
435 | 'true_value': True,
436 | 'new_initial_value': False,
437 | 'new_true_value': False
438 | }, {
439 | 'initial_value': config_dict.FieldReference(False),
440 | 'other_value': config_dict.FieldReference(False),
441 | 'true_value': False,
442 | 'new_initial_value': True,
443 | 'new_true_value': False
444 | }, {
445 | 'initial_value': config_dict.placeholder(bool),
446 | 'other_value': True,
447 | 'true_value': None,
448 | 'new_initial_value': False,
449 | 'new_true_value': False
450 | })
451 | def testAnd(self, initial_value, other_value, true_value, new_initial_value,
452 | new_true_value):
453 | self._test_binary_operator(initial_value, other_value, operator.and_,
454 | true_value, new_initial_value, new_true_value)
455 |
456 | @parameterized.parameters(
457 | {
458 | 'initial_value': False,
459 | 'other_value': False,
460 | 'true_value': False,
461 | 'new_initial_value': True,
462 | 'new_true_value': True
463 | }, {
464 | 'initial_value': config_dict.FieldReference(True),
465 | 'other_value': config_dict.FieldReference(True),
466 | 'true_value': True,
467 | 'new_initial_value': False,
468 | 'new_true_value': True
469 | }, {
470 | 'initial_value': config_dict.placeholder(bool),
471 | 'other_value': False,
472 | 'true_value': None,
473 | 'new_initial_value': True,
474 | 'new_true_value': True
475 | })
476 | def testOr(self, initial_value, other_value, true_value, new_initial_value,
477 | new_true_value):
478 | self._test_binary_operator(initial_value, other_value, operator.or_,
479 | true_value, new_initial_value, new_true_value)
480 |
481 | @parameterized.parameters(
482 | {
483 | 'initial_value': False,
484 | 'other_value': True,
485 | 'true_value': True,
486 | 'new_initial_value': True,
487 | 'new_true_value': False
488 | }, {
489 | 'initial_value': config_dict.FieldReference(True),
490 | 'other_value': config_dict.FieldReference(True),
491 | 'true_value': False,
492 | 'new_initial_value': False,
493 | 'new_true_value': True
494 | }, {
495 | 'initial_value': config_dict.placeholder(bool),
496 | 'other_value': True,
497 | 'true_value': None,
498 | 'new_initial_value': True,
499 | 'new_true_value': False
500 | })
501 | def testXor(self, initial_value, other_value, true_value, new_initial_value,
502 | new_true_value):
503 | self._test_binary_operator(initial_value, other_value, operator.xor,
504 | true_value, new_initial_value, new_true_value)
505 |
506 | @parameterized.parameters(
507 | {
508 | 'initial_value': 3,
509 | 'true_value': -3,
510 | 'new_initial_value': -22,
511 | 'new_true_value': 22
512 | }, {
513 | 'initial_value': 15.3,
514 | 'true_value': -15.3,
515 | 'new_initial_value': -0.2,
516 | 'new_true_value': 0.2
517 | }, {
518 | 'initial_value': config_dict.FieldReference(7),
519 | 'true_value': config_dict.FieldReference(-7),
520 | 'new_initial_value': 123,
521 | 'new_true_value': -123
522 | }, {
523 | 'initial_value': config_dict.placeholder(int),
524 | 'true_value': None,
525 | 'new_initial_value': -6,
526 | 'new_true_value': 6
527 | })
528 | def testNeg(self, initial_value, true_value, new_initial_value,
529 | new_true_value):
530 | self._test_unary_operator(initial_value, operator.neg, true_value,
531 | new_initial_value, new_true_value)
532 |
533 | @parameterized.parameters(
534 | {
535 | 'initial_value': config_dict.create(attribute=2),
536 | 'true_value': 2,
537 | 'new_initial_value': config_dict.create(attribute=3),
538 | 'new_true_value': 3,
539 | },
540 | {
541 | 'initial_value': config_dict.create(attribute={'a': 1}),
542 | 'true_value': config_dict.create(a=1),
543 | 'new_initial_value': config_dict.create(attribute={'b': 1}),
544 | 'new_true_value': config_dict.create(b=1),
545 | },
546 | {
547 | 'initial_value':
548 | config_dict.FieldReference(config_dict.create(attribute=2)),
549 | 'true_value':
550 | config_dict.FieldReference(2),
551 | 'new_initial_value':
552 | config_dict.create(attribute=3),
553 | 'new_true_value':
554 | 3,
555 | },
556 | {
557 | 'initial_value': config_dict.placeholder(config_dict.ConfigDict),
558 | 'true_value': None,
559 | 'new_initial_value': config_dict.create(attribute=3),
560 | 'new_true_value': 3,
561 | },
562 | )
563 | def testAttr(self, initial_value, true_value, new_initial_value,
564 | new_true_value):
565 | self._test_unary_operator(initial_value, lambda x: x.attr('attribute'),
566 | true_value, new_initial_value, new_true_value)
567 |
568 | @parameterized.parameters(
569 | {
570 | 'initial_value': 3,
571 | 'true_value': 3,
572 | 'new_initial_value': -101,
573 | 'new_true_value': 101
574 | }, {
575 | 'initial_value': -15.3,
576 | 'true_value': 15.3,
577 | 'new_initial_value': 7.3,
578 | 'new_true_value': 7.3
579 | }, {
580 | 'initial_value': config_dict.FieldReference(-7),
581 | 'true_value': config_dict.FieldReference(7),
582 | 'new_initial_value': 3,
583 | 'new_true_value': 3
584 | }, {
585 | 'initial_value': config_dict.placeholder(float),
586 | 'true_value': None,
587 | 'new_initial_value': -6.25,
588 | 'new_true_value': 6.25
589 | })
590 | def testAbs(self, initial_value, true_value, new_initial_value,
591 | new_true_value):
592 | self._test_unary_operator(initial_value, operator.abs, true_value,
593 | new_initial_value, new_true_value)
594 |
595 | def testToInt(self):
596 | self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27)
597 | ref = config_dict.FieldReference(64.7)
598 | ref = ref.to_int()
599 | self.assertEqual(ref.get(), 64)
600 | self.assertEqual(ref._field_type, int)
601 |
602 | def testToFloat(self):
603 | self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0)
604 |
605 | ref = config_dict.FieldReference(647)
606 | ref = ref.to_float()
607 | self.assertEqual(ref.get(), 647.0)
608 | self.assertEqual(ref._field_type, float)
609 |
610 | def testToString(self):
611 | self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0')
612 |
613 | ref = config_dict.FieldReference(647)
614 | ref = ref.to_str()
615 | self.assertEqual(ref.get(), '647')
616 | self.assertEqual(ref._field_type, str)
617 |
618 | def testSetValue(self):
619 | ref = config_dict.FieldReference(1.0)
620 | other = config_dict.FieldReference(3)
621 | ref_plus_other = ref + other
622 |
623 | self.assertEqual(ref_plus_other.get(), 4.0)
624 |
625 | ref.set(2.5)
626 | self.assertEqual(ref_plus_other.get(), 5.5)
627 |
628 | other.set(110)
629 | self.assertEqual(ref_plus_other.get(), 112.5)
630 |
631 | # Type checking
632 | with self.assertRaises(TypeError):
633 | other.set('this is a string')
634 |
635 | with self.assertRaises(TypeError):
636 | other.set(config_dict.FieldReference('this is a string'))
637 |
638 | with self.assertRaises(TypeError):
639 | other.set(config_dict.FieldReference(None, field_type=str))
640 |
641 | def testSetValueSubclass(self):
642 | class A:
643 | pass
644 |
645 | class B(A):
646 | pass
647 |
648 | ref = config_dict.FieldReference(A())
649 | ref.set(B()) # Can assign subclasses
650 |
651 | def testSetResult(self):
652 | ref = config_dict.FieldReference(1.0)
653 | result = ref + 1.0
654 | second_result = result + 1.0
655 |
656 | self.assertEqual(ref.get(), 1.0)
657 | self.assertEqual(result.get(), 2.0)
658 | self.assertEqual(second_result.get(), 3.0)
659 |
660 | ref.set(2.0)
661 | self.assertEqual(ref.get(), 2.0)
662 | self.assertEqual(result.get(), 3.0)
663 | self.assertEqual(second_result.get(), 4.0)
664 |
665 | result.set(4.0)
666 | self.assertEqual(ref.get(), 2.0)
667 | self.assertEqual(result.get(), 4.0)
668 | self.assertEqual(second_result.get(), 5.0)
669 |
670 | # All references are broken at this point.
671 | ref.set(1.0)
672 | self.assertEqual(ref.get(), 1.0)
673 | self.assertEqual(result.get(), 4.0)
674 | self.assertEqual(second_result.get(), 5.0)
675 |
676 | def testTypeChecking(self):
677 | ref = config_dict.FieldReference(1)
678 | string_ref = config_dict.FieldReference('a')
679 |
680 | x = ref + string_ref
681 | with self.assertRaises(TypeError):
682 | x.get()
683 |
684 | def testNoType(self):
685 | self.assertRaisesRegex(TypeError, 'field_type should be a type.*',
686 | config_dict.FieldReference, None, 0)
687 |
688 | def testEqual(self):
689 | # Simple case
690 | ref1 = config_dict.FieldReference(1)
691 | ref2 = config_dict.FieldReference(1)
692 | ref3 = config_dict.FieldReference(2)
693 | self.assertEqual(ref1, 1)
694 | self.assertEqual(ref1, ref1)
695 | self.assertEqual(ref1, ref2)
696 | self.assertNotEqual(ref1, 2)
697 | self.assertNotEqual(ref1, ref3)
698 |
699 | # ConfigDict inside FieldReference
700 | ref1 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
701 | ref2 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
702 | ref3 = config_dict.FieldReference(config_dict.ConfigDict({'a': 2}))
703 | self.assertEqual(ref1, config_dict.ConfigDict({'a': 1}))
704 | self.assertEqual(ref1, ref1)
705 | self.assertEqual(ref1, ref2)
706 | self.assertNotEqual(ref1, config_dict.ConfigDict({'a': 2}))
707 | self.assertNotEqual(ref1, ref3)
708 |
709 | def testLessEqual(self):
710 | # Simple case
711 | ref1 = config_dict.FieldReference(1)
712 | ref2 = config_dict.FieldReference(1)
713 | ref3 = config_dict.FieldReference(2)
714 | self.assertLessEqual(ref1, 1)
715 | self.assertLessEqual(ref1, 2)
716 | self.assertLessEqual(0, ref1)
717 | self.assertLessEqual(1, ref1)
718 | self.assertGreater(ref1, 0)
719 |
720 | self.assertLessEqual(ref1, ref1)
721 | self.assertLessEqual(ref1, ref2)
722 | self.assertLessEqual(ref1, ref3)
723 | self.assertGreater(ref3, ref1)
724 |
725 | def testControlFlowError(self):
726 | ref1 = config_dict.FieldReference(True)
727 | ref2 = config_dict.FieldReference(False)
728 |
729 | with self.assertRaises(NotImplementedError):
730 | if ref1:
731 | pass
732 | with self.assertRaises(NotImplementedError):
733 | _ = ref1 and ref2
734 | with self.assertRaises(NotImplementedError):
735 | _ = ref1 or ref2
736 | with self.assertRaises(NotImplementedError):
737 | _ = not ref1
738 |
739 |
740 | if __name__ == '__main__':
741 | absltest.main()
742 |
--------------------------------------------------------------------------------
/ml_collections/config_dict/tests/frozen_config_dict_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for config_dict.FrozenConfigDict."""
16 |
17 | from collections import abc as collections_abc
18 | import copy
19 | import pickle
20 |
21 | from absl.testing import absltest
22 | from ml_collections import config_dict
23 |
24 | _TEST_DICT = {
25 | 'int': 2,
26 | 'list': [1, 2],
27 | 'nested_list': [[1, [2]]],
28 | 'set': {1, 2},
29 | 'tuple': (1, 2),
30 | 'frozenset': frozenset({1, 2}),
31 | 'dict': {
32 | 'float': -1.23,
33 | 'list': [1, 2],
34 | 'dict': {},
35 | 'tuple_containing_list': (1, 2, (3, [4, 5], (6, 7))),
36 | 'list_containing_tuple': [1, 2, [3, 4], (5, 6)],
37 | },
38 | 'ref': config_dict.FieldReference({'int': 0})
39 | }
40 |
41 |
42 | def _test_dict_deepcopy():
43 | return copy.deepcopy(_TEST_DICT)
44 |
45 |
46 | def _test_configdict():
47 | return config_dict.ConfigDict(_TEST_DICT)
48 |
49 |
50 | def _test_frozenconfigdict():
51 | return config_dict.FrozenConfigDict(_TEST_DICT)
52 |
53 |
54 | class FrozenConfigDictTest(absltest.TestCase):
55 | """Tests FrozenConfigDict in the config_dict library."""
56 |
57 | def assertFrozenRaisesValueError(self, input_list):
58 | """Asserts that initialization on every element of input_list raises ValueError."""
59 | for initial_dictionary in input_list:
60 | with self.assertRaises(ValueError):
61 | _ = config_dict.FrozenConfigDict(initial_dictionary)
62 |
63 | def testBasicEquality(self):
64 | """Tests basic equality with different types of initialization."""
65 | fcd = _test_frozenconfigdict()
66 | fcd_cd = config_dict.FrozenConfigDict(_test_configdict())
67 | fcd_fcd = config_dict.FrozenConfigDict(fcd)
68 | self.assertEqual(fcd, fcd_cd)
69 | self.assertEqual(fcd, fcd_fcd)
70 |
71 | def testImmutability(self):
72 | """Tests immutability of frozen config."""
73 | fcd = _test_frozenconfigdict()
74 | self.assertEqual(fcd.list, tuple(_TEST_DICT['list']))
75 | self.assertEqual(fcd.tuple, _TEST_DICT['tuple'])
76 | self.assertEqual(fcd.set, frozenset(_TEST_DICT['set']))
77 | self.assertEqual(fcd.frozenset, _TEST_DICT['frozenset'])
78 | # Must manually check set to frozenset conversion, since Python == does not
79 | self.assertIsInstance(fcd.set, frozenset)
80 |
81 | self.assertEqual(fcd.dict.list, tuple(_TEST_DICT['dict']['list']))
82 | self.assertNotEqual(fcd.dict.tuple_containing_list,
83 | _TEST_DICT['dict']['tuple_containing_list'])
84 | self.assertEqual(fcd.dict.tuple_containing_list[2][1],
85 | tuple(_TEST_DICT['dict']['tuple_containing_list'][2][1]))
86 | self.assertIsInstance(fcd.dict, config_dict.FrozenConfigDict)
87 |
88 | with self.assertRaises(AttributeError):
89 | fcd.newitem = 0
90 | with self.assertRaises(AttributeError):
91 | fcd.dict.int = 0
92 | with self.assertRaises(AttributeError):
93 | fcd['newitem'] = 0
94 | with self.assertRaises(AttributeError):
95 | del fcd.int
96 | with self.assertRaises(AttributeError):
97 | del fcd['int']
98 |
99 | def testLockAndFreeze(self):
100 | """Ensures .lock() and .freeze() raise errors."""
101 | fcd = _test_frozenconfigdict()
102 |
103 | self.assertFalse(fcd.is_locked)
104 | self.assertFalse(fcd.as_configdict().is_locked)
105 |
106 | with self.assertRaises(AttributeError):
107 | fcd.lock()
108 | with self.assertRaises(AttributeError):
109 | fcd.unlock()
110 | with self.assertRaises(AttributeError):
111 | fcd.freeze()
112 | with self.assertRaises(AttributeError):
113 | fcd.unfreeze()
114 |
115 | def testInitConfigDict(self):
116 | """Tests that ConfigDict initialization handles FrozenConfigDict.
117 |
118 | Initializing a ConfigDict on a dictionary with FrozenConfigDict values
119 | should unfreeze these values.
120 | """
121 | dict_without_fcd_node = _test_dict_deepcopy()
122 | dict_without_fcd_node.pop('ref')
123 | dict_with_fcd_node = copy.deepcopy(dict_without_fcd_node)
124 | dict_with_fcd_node['dict'] = config_dict.FrozenConfigDict(
125 | dict_with_fcd_node['dict'])
126 | cd_without_fcd_node = config_dict.ConfigDict(dict_without_fcd_node)
127 | cd_with_fcd_node = config_dict.ConfigDict(dict_with_fcd_node)
128 | fcd_without_fcd_node = config_dict.FrozenConfigDict(
129 | dict_without_fcd_node)
130 | fcd_with_fcd_node = config_dict.FrozenConfigDict(dict_with_fcd_node)
131 |
132 | self.assertEqual(cd_without_fcd_node, cd_with_fcd_node)
133 | self.assertEqual(fcd_without_fcd_node, fcd_with_fcd_node)
134 |
135 | def testInitCopying(self):
136 | """Tests that initialization copies when and only when necessary.
137 |
138 | Ensures copying only occurs when converting mutable type to immutable type,
139 | regardless of whether the FrozenConfigDict is initialized by a dict or a
140 | FrozenConfigDict. Also ensures no copying occurs when converting from
141 | FrozenConfigDict back to ConfigDict.
142 | """
143 | fcd = _test_frozenconfigdict()
144 |
145 | # These should be uncopied when creating fcd
146 | fcd_unchanged_from_test_dict = [
147 | (_TEST_DICT['tuple'], fcd.tuple),
148 | (_TEST_DICT['frozenset'], fcd.frozenset),
149 | (_TEST_DICT['dict']['tuple_containing_list'][2][2],
150 | fcd.dict.tuple_containing_list[2][2]),
151 | (_TEST_DICT['dict']['list_containing_tuple'][3],
152 | fcd.dict.list_containing_tuple[3])
153 | ]
154 |
155 | # These should be copied when creating fcd
156 | fcd_different_from_test_dict = [
157 | (_TEST_DICT['list'], fcd.list),
158 | (_TEST_DICT['dict']['tuple_containing_list'][2][1],
159 | fcd.dict.tuple_containing_list[2][1])
160 | ]
161 |
162 | for (x, y) in fcd_unchanged_from_test_dict:
163 | self.assertEqual(id(x), id(y))
164 | for (x, y) in fcd_different_from_test_dict:
165 | self.assertNotEqual(id(x), id(y))
166 |
167 | # Also make sure that converting back to ConfigDict makes no copies
168 | self.assertEqual(
169 | id(_TEST_DICT['dict']['tuple_containing_list']),
170 | id(config_dict.ConfigDict(fcd).dict.tuple_containing_list))
171 |
172 | def testAsConfigDict(self):
173 | """Tests that converting FrozenConfigDict to ConfigDict works correctly.
174 |
175 | In particular, ensures that FrozenConfigDict does the inverse of ConfigDict
176 | regarding type_safe, lock, and attribute mutability.
177 | """
178 | # First ensure conversion to ConfigDict works on empty FrozenConfigDict
179 | self.assertEqual(
180 | config_dict.ConfigDict(config_dict.FrozenConfigDict()),
181 | config_dict.ConfigDict())
182 |
183 | cd = _test_configdict()
184 | cd_fcd_cd = config_dict.ConfigDict(config_dict.FrozenConfigDict(cd))
185 | self.assertEqual(cd, cd_fcd_cd)
186 |
187 | # Make sure locking is respected
188 | cd.lock()
189 | self.assertEqual(
190 | cd, config_dict.ConfigDict(config_dict.FrozenConfigDict(cd)))
191 |
192 | # Make sure type_safe is respected
193 | cd = config_dict.ConfigDict(_TEST_DICT, type_safe=False)
194 | self.assertEqual(
195 | cd, config_dict.ConfigDict(config_dict.FrozenConfigDict(cd)))
196 |
197 | def testInitSelfReferencing(self):
198 | """Ensure initialization fails on self-referencing dicts."""
199 | self_ref = {}
200 | self_ref['self'] = self_ref
201 | parent_ref = {'dict': {}}
202 | parent_ref['dict']['parent'] = parent_ref
203 | tuple_parent_ref = {'dict': {}}
204 | tuple_parent_ref['dict']['tuple'] = (1, 2, tuple_parent_ref)
205 | attribute_cycle = {'dict': copy.deepcopy(self_ref)}
206 |
207 | self.assertFrozenRaisesValueError(
208 | [self_ref, parent_ref, tuple_parent_ref, attribute_cycle])
209 |
210 | def testInitCycles(self):
211 | """Ensure initialization fails if an attribute of input is cyclic."""
212 | inner_cyclic_list = [1, 2]
213 | cyclic_list = [3, inner_cyclic_list]
214 | inner_cyclic_list.append(cyclic_list)
215 | cyclic_tuple = tuple(cyclic_list)
216 |
217 | test_dict_cyclic_list = _test_dict_deepcopy()
218 | test_dict_cyclic_tuple = _test_dict_deepcopy()
219 |
220 | test_dict_cyclic_list['cyclic_list'] = cyclic_list
221 | test_dict_cyclic_tuple['dict']['cyclic_tuple'] = cyclic_tuple
222 |
223 | self.assertFrozenRaisesValueError(
224 | [test_dict_cyclic_list, test_dict_cyclic_tuple])
225 |
226 | def testInitDictInList(self):
227 | """Ensure initialization fails on dict and ConfigDict in lists/tuples."""
228 | list_containing_dict = {'list': [1, 2, 3, {'a': 4, 'b': 5}]}
229 | tuple_containing_dict = {'tuple': (1, 2, 3, {'a': 4, 'b': 5})}
230 | list_containing_cd = {'list': [1, 2, 3, _test_configdict()]}
231 | tuple_containing_cd = {'tuple': (1, 2, 3, _test_configdict())}
232 | fr_containing_list_containing_dict = {
233 | 'fr': config_dict.FieldReference([1, {
234 | 'a': 2
235 | }])
236 | }
237 |
238 | self.assertFrozenRaisesValueError([
239 | list_containing_dict, tuple_containing_dict, list_containing_cd,
240 | tuple_containing_cd, fr_containing_list_containing_dict
241 | ])
242 |
243 | def testInitFieldReferenceInList(self):
244 | """Ensure initialization fails on FieldReferences in lists/tuples."""
245 | list_containing_fr = {'list': [1, 2, 3, config_dict.FieldReference(4)]}
246 | tuple_containing_fr = {
247 | 'tuple': (1, 2, 3, config_dict.FieldReference('a'))
248 | }
249 |
250 | self.assertFrozenRaisesValueError([list_containing_fr, tuple_containing_fr])
251 |
252 | def testInitInvalidAttributeName(self):
253 | """Ensure initialization fails on attributes with invalid names."""
254 | dot_name = {'dot.name': None}
255 | immutable_name = {'__hash__': None}
256 |
257 | with self.assertRaises(ValueError):
258 | config_dict.FrozenConfigDict(dot_name)
259 |
260 | with self.assertRaises(AttributeError):
261 | config_dict.FrozenConfigDict(immutable_name)
262 |
263 | def testFieldReferenceResolved(self):
264 | """Tests that FieldReferences are resolved."""
265 | cfg = config_dict.ConfigDict({'fr': config_dict.FieldReference(1)})
266 | frozen_cfg = config_dict.FrozenConfigDict(cfg)
267 | self.assertNotIsInstance(frozen_cfg._fields['fr'],
268 | config_dict.FieldReference)
269 | hash(frozen_cfg) # with FieldReference resolved, frozen_cfg is hashable
270 |
271 | def testFieldReferenceCycle(self):
272 | """Tests that FieldReferences may not contain reference cycles."""
273 | frozenset_fr = {'frozenset': frozenset({1, 2})}
274 | frozenset_fr['fr'] = config_dict.FieldReference(
275 | frozenset_fr['frozenset'])
276 | list_fr = {'list': [1, 2]}
277 | list_fr['fr'] = config_dict.FieldReference(list_fr['list'])
278 |
279 | cyclic_fr = {'a': 1}
280 | cyclic_fr['fr'] = config_dict.FieldReference(cyclic_fr)
281 | cyclic_fr_parent = {'dict': {}}
282 | cyclic_fr_parent['dict']['fr'] = config_dict.FieldReference(
283 | cyclic_fr_parent)
284 |
285 | # FieldReference is allowed to point to non-cyclic objects:
286 | _ = config_dict.FrozenConfigDict(frozenset_fr)
287 | _ = config_dict.FrozenConfigDict(list_fr)
288 | # But not cycles:
289 | self.assertFrozenRaisesValueError([cyclic_fr, cyclic_fr_parent])
290 |
291 | def testDeepCopy(self):
292 | """Ensure deepcopy works and does not affect equality."""
293 | fcd = _test_frozenconfigdict()
294 | fcd_deepcopy = copy.deepcopy(fcd)
295 | self.assertEqual(fcd, fcd_deepcopy)
296 |
297 | def testEquals(self):
298 | """Tests that __eq__() respects hidden mutability."""
299 | fcd = _test_frozenconfigdict()
300 |
301 | # First, ensure __eq__() returns False when comparing to other types
302 | self.assertNotEqual(fcd, (1, 2))
303 | self.assertNotEqual(fcd, fcd.as_configdict())
304 |
305 | list_to_tuple = _test_dict_deepcopy()
306 | list_to_tuple['list'] = tuple(list_to_tuple['list'])
307 | fcd_list_to_tuple = config_dict.FrozenConfigDict(list_to_tuple)
308 |
309 | set_to_frozenset = _test_dict_deepcopy()
310 | set_to_frozenset['set'] = frozenset(set_to_frozenset['set'])
311 | fcd_set_to_frozenset = config_dict.FrozenConfigDict(set_to_frozenset)
312 |
313 | self.assertNotEqual(fcd, fcd_list_to_tuple)
314 |
315 | # Because set == frozenset in Python:
316 | self.assertEqual(fcd, fcd_set_to_frozenset)
317 |
318 | # Items are not affected by hidden mutability
319 | self.assertCountEqual(fcd.items(), fcd_list_to_tuple.items())
320 | self.assertCountEqual(fcd.items(), fcd_set_to_frozenset.items())
321 |
322 | def testEqualsAsConfigDict(self):
323 | """Tests that eq_as_configdict respects hidden mutability but not type."""
324 | fcd = _test_frozenconfigdict()
325 |
326 | # First, ensure eq_as_configdict() returns True with an equal ConfigDict but
327 | # False for other types.
328 | self.assertFalse(fcd.eq_as_configdict([1, 2]))
329 | self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict()))
330 | empty_fcd = config_dict.FrozenConfigDict()
331 | self.assertTrue(empty_fcd.eq_as_configdict(config_dict.ConfigDict()))
332 |
333 | # Now, ensure it has the same immutability detection as __eq__().
334 | list_to_tuple = _test_dict_deepcopy()
335 | list_to_tuple['list'] = tuple(list_to_tuple['list'])
336 | fcd_list_to_tuple = config_dict.FrozenConfigDict(list_to_tuple)
337 |
338 | set_to_frozenset = _test_dict_deepcopy()
339 | set_to_frozenset['set'] = frozenset(set_to_frozenset['set'])
340 | fcd_set_to_frozenset = config_dict.FrozenConfigDict(set_to_frozenset)
341 |
342 | self.assertFalse(fcd.eq_as_configdict(fcd_list_to_tuple))
343 | # Because set == frozenset in Python:
344 | self.assertTrue(fcd.eq_as_configdict(fcd_set_to_frozenset))
345 |
346 | def testHash(self):
347 | """Ensures __hash__() respects hidden mutability."""
348 | list_to_tuple = _test_dict_deepcopy()
349 | list_to_tuple['list'] = tuple(list_to_tuple['list'])
350 |
351 | self.assertEqual(
352 | hash(_test_frozenconfigdict()),
353 | hash(config_dict.FrozenConfigDict(_test_dict_deepcopy())))
354 | self.assertNotEqual(
355 | hash(_test_frozenconfigdict()),
356 | hash(config_dict.FrozenConfigDict(list_to_tuple)))
357 |
358 | # Ensure Python realizes FrozenConfigDict is hashable
359 | self.assertIsInstance(_test_frozenconfigdict(), collections_abc.Hashable)
360 |
361 | def testUnhashableType(self):
362 | """Ensures __hash__() fails if FrozenConfigDict has unhashable value."""
363 | unhashable_fcd = config_dict.FrozenConfigDict(
364 | {'unhashable': bytearray()})
365 | with self.assertRaises(TypeError):
366 | hash(unhashable_fcd)
367 |
368 | def testToDict(self):
369 | """Ensure to_dict() does not care about hidden mutability."""
370 | list_to_tuple = _test_dict_deepcopy()
371 | list_to_tuple['list'] = tuple(list_to_tuple['list'])
372 |
373 | self.assertEqual(_test_frozenconfigdict().to_dict(),
374 | config_dict.FrozenConfigDict(list_to_tuple).to_dict())
375 |
376 | def testPickle(self):
377 | """Make sure FrozenConfigDict can be dumped and loaded with pickle."""
378 | fcd = _test_frozenconfigdict()
379 | locked_fcd = config_dict.FrozenConfigDict(_test_configdict().lock())
380 |
381 | unpickled_fcd = pickle.loads(pickle.dumps(fcd))
382 | unpickled_locked_fcd = pickle.loads(pickle.dumps(locked_fcd))
383 |
384 | self.assertEqual(fcd, unpickled_fcd)
385 | self.assertEqual(locked_fcd, unpickled_locked_fcd)
386 |
387 |
388 | if __name__ == '__main__':
389 | absltest.main()
390 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config flags module."""
16 |
17 | from .config_flags import DEFINE_config_dataclass
18 | from .config_flags import DEFINE_config_dict
19 | from .config_flags import DEFINE_config_file
20 | from .config_flags import get_config_filename
21 | from .config_flags import get_override_values
22 | from .config_flags import register_flag_parser
23 | from .config_flags import register_flag_parser_for_type
24 |
25 | __all__ = (
26 | "DEFINE_config_dataclass",
27 | "DEFINE_config_dict",
28 | "DEFINE_config_file",
29 | "get_config_filename",
30 | "get_override_values",
31 | "register_flag_parser",
32 | "register_flag_parser_for_type",
33 | )
34 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/config_path.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module for splitting flag prefixes."""
16 |
17 | import ast
18 | import dataclasses as dc
19 | import functools
20 | import types
21 | import typing
22 | from typing import Any, MutableSequence, Optional, Sequence, Tuple, Union, Type
23 |
24 | from ml_collections import config_dict
25 |
26 |
27 | NoneType = type(None)
28 |
29 |
# Maps an AST node type to a function turning a node of that type into
# config-path parts. Node types that are not listed here are evaluated with
# ast.literal_eval (see _split_node).
_AST_SPLIT_CONFIG_PATH = {
    ast.Attribute: lambda n: (*_split_node(n.value), n.attr),
    ast.Name: lambda n: (n.id,),
    ast.Slice: lambda i: slice(*map(_split_node, (i.lower, i.upper, i.step))),
    ast.Subscript: lambda n: (*_split_node(n.value), _split_node(n.slice)),
    type(None): lambda n: None,
}

# ast.Index has been deprecated since Python 3.9 and is documented as being
# removed in a future release. Register its handler only when the interpreter
# still provides the class, so that importing this module keeps working once it
# is gone.
if hasattr(ast, 'Index'):
  _AST_SPLIT_CONFIG_PATH[ast.Index] = lambda i: _split_node(i.value)


def _split_node(node):
  """Converts a single AST node into config-path parts.

  The handler registered for the node's type in _AST_SPLIT_CONFIG_PATH is used
  when there is one; any other node is passed to ast.literal_eval.
  """
  return _AST_SPLIT_CONFIG_PATH.get(type(node), ast.literal_eval)(node)


def split(config_path: str) -> Tuple[Any, ...]:
  """Returns config_path split into a tuple of parts.

  Example usage:
    >>> assert config_path.split('a.b.cc') == ('a', 'b', 'cc')
    >>> assert config_path.split('a.b["cc.d"]') == ('a', 'b', 'cc.d')
    >>> assert config_path.split('a.b[10]') == ('a', 'b', 10)
    >>> assert config_path.split('a[(1, 2)]') == ('a', (1, 2))
    >>> assert config_path.split('a[:]') == ('a', slice(None))

  Args:
    config_path: Input path to be split - see example usage.

  Returns:
    Tuple of config_path split into parts. Parts are attributes or subscripts.
    Attributes are treated as strings and subscripts are parsed using
    ast.literal_eval. It is up to the caller to ensure all returned types are
    valid.

  Raises:
    ValueError: Failed to parse config_path.
  """
  try:
    node = ast.parse(config_path, mode='eval')
  except SyntaxError as e:
    # `from None` hides the SyntaxError traceback; its repr is already part of
    # the message.
    raise ValueError(f'Could not parse {config_path!r}: {e!r}') from None
  if isinstance(node, ast.Expression):
    result = _split_node(node.body)
    # A bare literal such as '1' evaluates to a non-tuple and is not a path.
    if isinstance(result, tuple):
      return result
  raise ValueError(config_path)
75 |
76 |
77 | def _get_item_or_attribute(config, field,
78 | field_path: Optional[str] = None):
79 | """Returns attribute of member failing that the item."""
80 | if isinstance(field, str) and hasattr(config, field):
81 | return getattr(config, field)
82 | if hasattr(config, '__getitem__'):
83 | return config[field]
84 | if isinstance(field, int):
85 | raise IndexError(
86 | f'{type(config)} does not support integer indexing [{field}]]. '
87 | f'Attempting to lookup: {field_path}')
88 | raise KeyError(
89 | f'Attribute {type(config)}.{field} does not exist '
90 | 'and the type does not support indexing. '
91 | f'Attempting to lookup: {field_path}')
92 |
93 |
def _get_holder_field(config_path: str, config: Any) -> Tuple[Any, Any]:
  """Returns the last part config_path and config to allow assignment.

  Example usage:
    >>> config = {'a': {'b': {'c': 10}}}
    >>> holder, lastfield = _get_holder_field('a.b.c', config)
    >>> assert lastfield == 'c'
    >>> assert holder is config['a']['b']
    >>> assert holder[lastfield] == 10

  Args:
    config_path: Any string that `split` can process.
    config: A nested datastructure that can be accessed via
      _get_item_or_attribute

  Returns:
    The penultimate object when walking config with config_path. And the final
    part of the config path, which is a string for attribute access and
    whatever `split` produced (int, tuple, slice, ...) for a subscript.

  Raises:
    IndexError: Integer field not found in nested structure.
    KeyError: Non-integer field not found in nested structure.
    ValueError: Empty/invalid config_path after parsing.
  """
  fields = split(config_path)
  if not fields:
    raise ValueError('Path cannot be empty')
  get_item = functools.partial(_get_item_or_attribute, field_path=config_path)
  # Walk every part except the last one; the last part is returned unresolved
  # so the caller can read or assign it on the holder.
  holder = functools.reduce(get_item, fields[:-1], config)
  return holder, fields[-1]
124 |
125 |
def get_value(config_path: str, config: Any):
  """Gets value of a single field.

  Example usage:
    >>> config = {'a': {'b': {'c': 10}}}
    >>> assert config_path.get_value('a.b.c', config) == 10

  Args:
    config_path: Any string that `split` can process.
    config: A nested datastructure

  Returns:
    The last object when walking config with config_path.

  Raises:
    IndexError: Integer field not found in nested structure.
    KeyError: Non-integer field not found in nested structure.
    ValueError: Empty/invalid config_path after parsing.
  """
  get_item = functools.partial(_get_item_or_attribute, field_path=config_path)
  # Resolve each part of the path in turn, starting from `config`.
  return functools.reduce(get_item, split(config_path), config)
147 |
148 |
def initialize_missing_parent_fields(
    config: Any, override: str,
    allowed_missing: Sequence[str]):
  """Adds some missing nested holder fields for a particular override.

  For example if override is 'config.a.b.c' and config.a is None, it
  will default initialize config.a, and if config.a.b is None will default
  initialize it as well. Only overrides present in allowed_missing will
  be initialized.

  Args:
    config: config object (typically dataclass)
    override: dot joined override name.
    allowed_missing: list of overrides that are allowed
      to be set. For example, if override is 'a.b.c.d',
      allowed_missing could be ['a.b.c', 'a', 'foo.bar'].

  Raises:
    ValueError: if a parent field is None and either its type is not a
      dataclass, it is not listed in allowed_missing, or its type can not be
      default instantiated.
  """
  fields = split(override)
  # Collect the tree levels at which we are allowed to create override
  allowed_levels = {len(split(x)) for x in allowed_missing if
                    override.startswith(x + '.')}
  child = config
  # Levels are 1-based so they compare directly with the path lengths above.
  for level, f in enumerate(fields[:-1], 1):
    parent = child
    child = _get_item_or_attribute(parent, f, override)
    if child is not None:
      continue
    # Field is not yet present, see if we should create it instead.
    field_type = get_type(f, parent)
    # Note: these two assertions below are mostly guard
    # rails to prevent behaviors that might be confusing/accidental.
    # Specifically we disallow implicit creation of parent fields,
    # creating non dataclass objects. They can be revisited
    # in the future.
    if not dc.is_dataclass(field_type):
      raise ValueError(
          f'Override {override} can not be applied because '
          f'field "{f}" is None, and its type "{field_type}" is not a '
          f'dataclass in the parent of type "{type(parent)}".')

    if level not in allowed_levels:
      raise ValueError(
          f'Flag {override} can not be applied because '
          f'field "{f}" is None by default and it is not explicitly '
          'provided in flags (it can be default initialized by '
          f'providing --.{f}=build flag)')
    try:
      child = field_type()
    except Exception as e:
      raise ValueError(
          f'Override {override} can not be applied because '
          f'field "{f}" of type {field_type} can not be default instantiated:'
          f'{e}') from e
    # Store the new instance on the parent so the next level resolves to it.
    set_value(f, parent, child)
206 |
207 |
def get_origin(type_spec: type) -> Optional[type]:  # pylint: disable=g-bare-generic drop when 3.7 support is not needed
  """Returns typing.get_origin(type_spec), emulating it on Python <= 3.7.

  When the typing module has no `get_origin`, the `__origin__` attribute of
  `type_spec` is returned instead, or None if there is no such attribute.
  """
  native_get_origin = getattr(typing, 'get_origin', None)
  if native_get_origin is None:
    return getattr(type_spec, '__origin__', None)
  return native_get_origin(type_spec)
213 |
214 |
def get_args(type_spec: type) -> Tuple[Any, ...]:  # pylint: disable=g-bare-generic drop when 3.7 support is not needed
  """Call typing.get_args, with fallback for Python 3.7 and below.

  Args:
    type_spec: The type annotation to inspect.

  Returns:
    The type arguments of `type_spec` as a tuple; an empty tuple when there
    are none.
  """
  if hasattr(typing, 'get_args'):
    return typing.get_args(type_spec)
  # Mirror typing.get_args, which returns an empty tuple for types without
  # arguments, so callers can always iterate over the result.
  return getattr(type_spec, '__args__', ())
220 |
221 |
def _is_union_type(type_spec: type) -> bool:  # pylint: disable=g-bare-generic drop when 3.7 support is not needed
  """Checks whether the origin of type_spec is a Union type."""
  # types.UnionType only exists on Python 3.10+; on older versions fall back
  # to typing.Union so the membership test below is unaffected.
  pep604_union = getattr(types, 'UnionType', Union)
  origin = get_origin(type_spec)
  return origin in [Union, pep604_union]
227 |
228 |
def extract_type_from_optional(type_spec: type) -> Optional[type]:  # pylint: disable=g-bare-generic drop when 3.7 support is not needed
  """Returns T if type_spec is Optional[T], otherwise None.

  A union counts as Optional[T] only when exactly one of its arguments is not
  NoneType.
  """
  if _is_union_type(type_spec):
    candidates = [arg for arg in get_args(type_spec) if arg is not NoneType]
    if len(candidates) == 1:
      return candidates[0]
  return None
237 |
238 |
def normalize_type(type_spec: type) -> type:  # pylint: disable=g-bare-generic drop when 3.7 support is not needed
  """Normalizes a type object.

  For a union, the NoneType members are dropped and the single remaining type
  is returned, so Optional[T] normalizes to T. Any type that is not a union is
  returned as is.

  Args:
    type_spec: The type to normalize.

  Returns:
    The normalized type.

  Raises:
    TypeError: If there is not exactly 1 non-None type in the union.
  """
  if not _is_union_type(type_spec):
    return type_spec

  inner_type = extract_type_from_optional(type_spec)
  if inner_type is None:
    raise TypeError(f'Unable to normalize ambiguous type: {type_spec}')
  return inner_type
261 |
262 |
def get_type(
    config_path: str,
    config: Any,
    normalize=True,
    default_type: Optional[Type[Any]] = None,
):
  """Gets type of field in config described by a config_path.

  Example usage:
    >>> config = {'a': {'b': {'c': 10}}}
    >>> assert config_path.get_type('a.b.c', config) is int

  Args:
    config_path: Any string that `split` can process.
    config: A nested datastructure
    normalize: whether to normalize the type (in particular strip Optional
      annotations on dataclass fields)
    default_type: If the `config_path` is not found and `default_type` is set,
      the `default_type` is returned. Only consulted when the holder of the
      final field is a ConfigDict or FieldReference.

  Returns:
    The type of last object when walking config with config_path.

  Raises:
    IndexError: Integer field not found in nested structure.
    KeyError: Non-integer field not found in nested structure.
    ValueError: Empty/invalid config_path after parsing.
    TypeError: Ambiguous type annotation on dataclass field.
  """
  holder, field = _get_holder_field(config_path, config)
  # Check if config is a DM collection and hence has attribute get_type()
  if isinstance(holder,
                (config_dict.ConfigDict, config_dict.FieldReference)):
    if default_type is not None and field not in holder:
      return default_type
    return holder.get_type(field)
  # For dataclasses we can just use the type annotation.
  elif dc.is_dataclass(holder):
    matches = [f.type for f in dc.fields(holder) if f.name == field]
    if not matches:
      raise KeyError(f'Field {field} not found on dataclass {type(holder)}')
    return normalize_type(matches[0]) if normalize else matches[0]
  else:
    # Any other container: report the runtime type of the stored value.
    return type(_get_item_or_attribute(holder, field, config_path))
307 |
308 |
def is_optional(config_path: str, config: Any) -> bool:
  """Returns whether the field at config_path has an Optional[T] type.

  The type is looked up without normalization, so an Optional annotation on a
  dataclass field is still visible here.
  """
  unnormalized = get_type(config_path, config, normalize=False)
  inner_type = extract_type_from_optional(unnormalized)
  return inner_type is not None
312 |
313 |
def set_value(
    config_path: str,
    config: Any,
    value: Any,
    *,
    accept_new_attributes: bool = False,
):
  """Sets value of field described by config_path.

  Example usage:
    >>> config = {'a': {'b': {'c': 10}}}
    >>> config_path.set_value('a.b.c', config, 20)
    >>> assert config['a']['b']['c'] == 20

  Args:
    config_path: Any string that `split` can process.
    config: A nested datastructure
    value: A value to assign to final field.
    accept_new_attributes: If `True`, a final field that is not yet present in
      a holder supporting item assignment is added instead of raising.

  Raises:
    IndexError: Integer field not found in nested structure.
    KeyError: Non-integer field not found in nested structure.
    ValueError: Empty/invalid config_path after parsing.
  """
  holder, field = _get_holder_field(config_path, config)

  if isinstance(field, int) and isinstance(holder, MutableSequence):
    # Sequence index: assignment itself raises IndexError when out of range.
    holder[field] = value
  elif hasattr(holder, '__setitem__') and (
      field in holder or accept_new_attributes
  ):
    holder[field] = value
  elif hasattr(holder, str(field)):
    # Plain objects (e.g. dataclasses): only existing attributes may be set.
    setattr(holder, str(field), value)
  else:
    if isinstance(field, int):
      raise IndexError(
          f'{field} is not a valid index for {type(holder)} '
          f'(in: {config_path})')
    raise KeyError(f'{field} is not a valid key or attribute of {type(holder)} '
                   f'(in: {config_path})')
356 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines a method which returns an instance of ConfigDict."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
def get_config():
  """Returns the example ConfigDict: scalars, a nested dict and a tuple."""
  nested = config_dict.ConfigDict()
  nested.field = 2.23

  cfg = config_dict.ConfigDict()
  cfg.field1 = 1
  cfg.field2 = 'tom'
  cfg.nested = nested
  cfg.tuple = (1, 2, 3)
  return cfg
28 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/define_config_dataclass_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Example of basic DEFINE_config_dataclass usage.
16 |
17 | To run this example:
18 | python define_config_dataclass_basic.py -- --my_config.field1=8 \
19 | --my_config.nested.field=2.1 --my_config.tuple='(1, 2, (1, 2))'
20 | """
21 |
22 | import dataclasses
23 | from typing import Any, Mapping, Sequence
24 |
25 | from absl import app
26 | from ml_collections import config_flags
27 |
28 |
# Example configuration; its fields are overridable from the command line as
# --my_config.<field>=... (see the module docstring).
@dataclasses.dataclass
class MyConfig:
  field1: int
  field2: str
  nested: Mapping[str, Any]
  tuple: Sequence[int]
35 |
36 |
# Default values used for the flag before any command-line override.
config = MyConfig(
    field1=1,
    field2='tom',
    nested={'field': 2.23},
    tuple=(1, 2, 3),
)

# Registers the `my_config` flag; the config is read through `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_dataclass('my_config', config)
45 |
46 |
def main(_):
  # Print the value held by the `my_config` flag.
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
53 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/define_config_dict_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Example of basic DEFINE_config_dict usage.
16 |
17 | To run this example:
18 | python define_config_dict_basic.py -- --my_config_dict.field1=8 \
19 | --my_config_dict.nested.field=2.1 --my_config_dict.tuple='(1, 2, (1, 2))'
20 | """
21 |
22 | from absl import app
23 |
24 | from ml_collections import config_dict
25 | from ml_collections import config_flags
26 |
# Default values used for the flag before any command-line override.
config = config_dict.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = config_dict.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

# Registers the `my_config_dict` flag; read it through `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_dict('my_config_dict', config)
35 |
36 |
def main(_):
  # Print the value held by the `my_config_dict` flag.
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
43 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/define_config_file_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint: disable=line-too-long
16 | r"""Example of basic DEFINE_config_file usage.
17 |
18 | To run this example with basic config file:
19 | python define_config_file_basic.py -- \
20 | --my_config=ml_collections/config_flags/examples/config.py
21 | \
22 | --my_config.field1=8 --my_config.nested.field=2.1 \
23 | --my_config.tuple='(1, 2, (1, 2))'
24 |
25 | To run this example with parameterised config file:
26 | python define_config_file_basic.py -- \
27 | --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear
28 | \
29 | --my_config.model_config.output_size=256
30 | """
31 | # pylint: enable=line-too-long
32 |
33 | from absl import app
34 |
35 | from ml_collections import config_flags
36 |
# Registers the `my_config` flag, whose value names the config file to load
# (see the module docstring for example invocations).
_CONFIG = config_flags.DEFINE_config_file('my_config')


def main(_):
  # Print the value held by the `my_config` flag.
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
46 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/examples_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for config_flags examples.
16 |
17 | Ensures that define_config_dict_basic and define_config_file_basic run
18 | successfully.
19 | """
20 |
21 | from absl import flags
22 | from absl.testing import absltest
23 | from absl.testing import flagsaver
24 | from ml_collections.config_flags.examples import define_config_dict_basic
25 | from ml_collections.config_flags.examples import define_config_file_basic
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
class ConfigDictExamplesTest(absltest.TestCase):
  # Each test calls an example's `main` directly and passes as long as no
  # exception is raised.

  def test_define_config_dict_basic(self):
    define_config_dict_basic.main([])

  # flagsaver restores the flag values changed inside the test.
  @flagsaver.flagsaver
  def test_define_config_file_basic(self):
    FLAGS.my_config = 'ml_collections/config_flags/examples/config.py'
    define_config_file_basic.main([])

  @flagsaver.flagsaver
  def test_define_config_file_parameterised(self):
    # The text after ':' is the string handed to the parameterised get_config.
    FLAGS.my_config = 'ml_collections/config_flags/examples/parameterised_config.py:linear'
    define_config_file_basic.main([])


if __name__ == '__main__':
  absltest.main()
48 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/examples/parameterised_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines a parameterized method which returns a config depending on input."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
def get_config(config_string):
  """Returns the ConfigDict registered under `config_string`.

  Supported names are 'linear' and 'lstm'; any other name raises KeyError.
  """
  linear = config_dict.ConfigDict({
      'model_constructor': 'snt.Linear',
      'model_config': config_dict.ConfigDict({
          'output_size': 42,
      })
  })
  lstm = config_dict.ConfigDict({
      'model_constructor': 'snt.LSTM',
      'model_config': config_dict.ConfigDict({
          'hidden_size': 108,
      })
  })
  structures_by_name = {'linear': linear, 'lstm': lstm}
  return structures_by_name[config_string]
40 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/config_path_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for config_path."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from ml_collections.config_flags import config_path
20 | from ml_collections.config_flags.tests import fieldreference_config
21 | from ml_collections.config_flags.tests import mock_config
22 |
23 |
class ConfigPathTest(parameterized.TestCase):
  """Tests for get_value, set_value and get_type in config_path."""

  def test_list_extra_index(self):
    """Indexing into a list element that is not indexable raises IndexError."""
    cfg = mock_config.get_config()
    with self.assertRaises(IndexError):
      config_path.get_value('dict.list[0][0]', cfg)

  def test_list_out_of_range_get(self):
    """Reading an out-of-range list index raises IndexError."""
    cfg = mock_config.get_config()
    with self.assertRaises(IndexError):
      config_path.get_value('dict.list[2][1]', cfg)

  def test_list_out_of_range_set(self):
    """Overriding an out-of-range list index raises IndexError."""
    cfg = mock_config.get_config()
    with self.assertRaises(IndexError):
      config_path.set_value('dict.list[2][1]', cfg, -1)

  def test_reading_non_existing_key(self):
    """Setting a key that does not exist in a nested dict raises KeyError."""
    cfg = mock_config.get_config()
    with self.assertRaises(KeyError):
      config_path.set_value('dict.not_existing_key', cfg, 1)

  def test_reading_setting_existing_key_in_dict(self):
    """Setting a field below a non-existing dict key raises KeyError."""
    cfg = mock_config.get_config()
    with self.assertRaises(KeyError):
      config_path.set_value('dict.not_existing_key.key', cfg, 1)

  def test_empty_key(self):
    """An empty config path is rejected with ValueError."""
    cfg = mock_config.get_config()
    with self.assertRaises(ValueError):
      config_path.set_value('', cfg, None)

  def test_field_reference_types(self):
    """get_type reports int for both FieldReference fields of the config."""
    cfg = fieldreference_config.get_config()

    paths = ['ref_nodefault', 'ref']
    expected_types = [int, int]

    actual_types = [config_path.get_type(path, cfg) for path in paths]
    self.assertEqual(expected_types, actual_types)

  @parameterized.parameters(
      ('float', float),
      ('integer', int),
      ('string', str),
      ('bool', bool),
      ('dict', dict),
      ('dict.float', float),
      ('dict.list', list),
      ('list', list),
      ('list[0]', int),
      ('object.float', float),
      ('object.integer', int),
      ('object.string', str),
      ('object.bool', bool),
      ('object.dict', dict),
      ('object.dict.float', float),
      ('object.dict.list', list),
      ('object.list', list),
      ('object.list[0]', int),
      ('object.tuple', tuple),
      ('object_reference.float', float),
      ('object_reference.integer', int),
      ('object_reference.string', str),
      ('object_reference.bool', bool),
      ('object_reference.dict', dict),
      ('object_reference.dict.float', float),
      ('object_copy.float', float),
      ('object_copy.integer', int),
      ('object_copy.string', str),
      ('object_copy.bool', bool),
      ('object_copy.dict', dict),
      ('object_copy.dict.float', float),
  )
  def test_types(self, path, path_type):
    """get_type returns the expected type for each path of the mock config."""
    cfg = mock_config.get_config()
    self.assertEqual(path_type, config_path.get_type(path, cfg))

if __name__ == '__main__':
  absltest.main()
118 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/configdict_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ConfigDict config file."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
class UnusableConfig(object):
  """Test against code assuming the semantics of attributes (such as `lock`).

  This class checks that the flags implementation does not assume what an
  attribute means merely because it exists. It guards against code such as:

  ```python
  if hasattr(obj, 'lock'):
    obj.lock()
  ```

  which fails if `obj` has an attribute `lock` that does not behave in the
  expected way.

  Every attribute of this class is unusable, with two exceptions that behave
  normally:
    * Python's special names which start and end with a double underscore.
    * `valid_attribute`, an attribute used to test the class.

  For any other attribute, both `hasattr(obj, attr)` and `callable(obj.attr)`
  are True. Accessing `obj.attr` returns a function which takes no arguments
  and raises an AttributeError when called, so the `lock` example above raises
  an AttributeError. The only valid action on such attributes is assignment,
  e.g.

  ```python
  obj = UnusableConfig()
  obj.attr = 1
  ```

  after which the attribute keeps its assigned value and becomes usable.
  """

  def __init__(self):
    self._valid_attribute = 1

  def __getattr__(self, attribute):
    """Get an arbitrary attribute.

    Only invoked when normal attribute lookup fails, so assigned attributes
    and `valid_attribute` never reach this method.

    Args:
      attribute: A string representing the attribute's name.
    Returns:
      A function which takes no arguments and raises an AttributeError when
      called.
    Raises:
      AttributeError: when the attribute name starts and ends with a double
        underscore.
    """
    is_special_name = attribute.startswith("__") and attribute.endswith("__")
    if is_special_name:
      raise AttributeError(
          "UnusableConfig does not contain entry {}.".format(attribute))

    def raise_attribute_error_fun():
      message = "{} is not a usable attribute of UnusableConfig".format(
          attribute)
      raise AttributeError(message)

    return raise_attribute_error_fun

  @property
  def valid_attribute(self):
    return self._valid_attribute

  @valid_attribute.setter
  def valid_attribute(self, value):
    self._valid_attribute = value
89 |
def get_config():
  """Returns a ConfigDict. Used for tests."""
  config = config_dict.ConfigDict()
  config.integer = 1
  config.reference = config_dict.FieldReference(1)
  config.list = [1, 2, 3]
  config.nested_list = [[1, 2, 3]]

  # Nested ConfigDict holding a single integer field.
  config.nested_configdict = config_dict.ConfigDict()
  config.nested_configdict.integer = 1

  # Value whose attributes raise when used; see UnusableConfig.
  config.unusable_config = UnusableConfig()

  return config
102 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/dataclass_overriding_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for config_flags used in conjunction with DEFINE_config_dataclass."""
16 |
17 | import copy
18 | import dataclasses
19 | import functools
20 | import sys
21 | from typing import Mapping, Optional, Sequence, Tuple, Union
22 | import unittest
23 |
24 | from absl import flags
25 | from absl.testing import absltest
26 | from ml_collections import config_flags
27 | from ml_collections.config_flags import config_flags as config_flag_lib
28 |
29 |
30 | #####
31 | # Simple configuration class for testing.
@dataclasses.dataclass
class MyModelConfig:
  # Required fields.
  foo: int
  bar: Sequence[str]
  # Optional fields; all default to None.
  baz: Optional[Mapping[str, str]] = None
  buz: Optional[Mapping[Tuple[int, int], str]] = None
  qux: Optional[int] = None
  # Annotated as float but given an int default.
  bax: float = 1
  boj: Tuple[int, ...] = ()
41 |
42 |
class ParserForCustomConfig(flags.ArgumentParser):
  """Flag parser building a CustomParserConfig from an integer-like value."""

  def __init__(self, delta=1):
    # Offset added to the parsed integer to produce field `j`.
    self.delta = delta

  def parse(self, value):
    """Returns `value` unchanged if already parsed, else builds a config."""
    if not isinstance(value, CustomParserConfig):
      parsed = int(value)
      return CustomParserConfig(i=parsed, j=parsed + self.delta)
    return value
51 |
52 |
@dataclasses.dataclass
@config_flags.register_flag_parser(parser=ParserForCustomConfig())
class CustomParserConfig:
  """Two-integer dataclass registered with `ParserForCustomConfig`."""

  i: int
  j: int = 1
58 |
59 |
@dataclasses.dataclass
class MyConfig:
  """Top-level test config: two model configs plus a custom-parsed field."""

  my_model: MyModelConfig
  baseline_model: MyModelConfig
  # Built per instance through a factory; the default has `i == 0`.
  custom: CustomParserConfig = dataclasses.field(
      default_factory=functools.partial(CustomParserConfig, 0))
66 |
67 |
@dataclasses.dataclass
class SubConfig:
  """Holds an optional `MyModelConfig` that is populated by default."""

  # The lambda creates a fresh model (and a fresh `bar` list) per instance.
  model: Optional[MyModelConfig] = dataclasses.field(
      default_factory=lambda: MyModelConfig(bar=['1'], foo=0))
72 |
73 |
@dataclasses.dataclass
class ConfigWithOptionalNestedField:
  """Pairs an optional nested dataclass with a non-optional one."""

  # Unset unless a value is assigned.
  sub: Optional[SubConfig] = None
  # Always present; a new `SubConfig` is built for every instance.
  non_optional: SubConfig = dataclasses.field(default_factory=SubConfig)
79 |
# Default configuration shared by the tests below.
_CONFIG = MyConfig(
    my_model=MyModelConfig(
        bar=['a', 'b'],
        baz={'foo.b': 'bar'},
        buz={(0, 0): 'ZeroZero', (0, 1): 'ZeroOne'},
        foo=3,
    ),
    baseline_model=MyModelConfig(bar=['c', 'd'], foo=55),
)

# Flag registered on the global `flags.FLAGS` at import time.
_TEST_FLAG = config_flags.DEFINE_config_dataclass(
    'test_flag', _CONFIG, 'MyConfig data')
95 |
96 |
def _test_flags(default, *flag_args, parse_fn=None):
  """Defines `--test_config` on a fresh `FlagValues` and parses `flag_args`.

  Args:
    default: Default config object handed to `DEFINE_config_dataclass`.
    *flag_args: Suffixes appended to `--test_config`, e.g. `'.foo=1'`.
    parse_fn: Optional parse function forwarded to `DEFINE_config_dataclass`.

  Returns:
    The value of the defined flag after parsing.

  Raises:
    ValueError: If any argument is left over after flag parsing.
  """
  # DEFINE_config_dataclass accesses sys.argv to build flag list!
  saved_argv = list(sys.argv)
  values = flags.FlagValues()
  sys.argv[:] = [''] + ['--test_config' + suffix for suffix in flag_args]
  try:
    holder = config_flags.DEFINE_config_dataclass(
        'test_config', default, flag_values=values, parse_fn=parse_fn)
    _, *leftover = values(sys.argv)
    if leftover:
      raise ValueError(f'{leftover}')
    return holder.value
  finally:
    # Always put the real command line back.
    sys.argv[:] = saved_argv
112 |
113 |
def parse_config_flag(value):
  """Returns a copy of `_CONFIG` whose `my_model.foo` is `int(value)`."""
  updated_model = dataclasses.replace(_CONFIG.my_model, foo=int(value))
  return dataclasses.replace(_CONFIG, my_model=updated_model)
119 |
120 |
class TypedConfigFlagsTest(absltest.TestCase):
  """Tests overriding dataclass fields through `DEFINE_config_dataclass`."""

  def test_types(self):
    """The module-level flag exposes a `MyConfig` equal to `_CONFIG`."""
    self.assertIsInstance(_TEST_FLAG.value, MyConfig)
    self.assertEqual(_TEST_FLAG.value, _CONFIG)
    self.assertIsInstance(flags.FLAGS['test_flag'].value, MyConfig)
    self.assertIsInstance(flags.FLAGS.test_flag, MyConfig)
    self.assertEqual(flags.FLAGS['test_flag'].value, _CONFIG)
    # When run as a script the defining module is reported via sys.argv[0].
    module_name = __name__ if __name__ != '__main__' else sys.argv[0]
    self.assertEqual(
        flags.FLAGS.find_module_defining_flag('test_flag'), module_name)

  def test_instance(self):
    """With no overrides the parsed value equals the default."""
    config = _test_flags(_CONFIG)
    self.assertIsInstance(config, MyConfig)
    self.assertEqual(config.my_model, _CONFIG.my_model)
    self.assertEqual(_CONFIG, config)

  def test_flag_config_dataclass_optional(self):
    result = _test_flags(_CONFIG, '.baseline_model.qux=10')
    self.assertEqual(result.baseline_model.qux, 10)
    self.assertIsInstance(result.baseline_model.qux, int)
    self.assertIsNone(result.my_model.qux)

  def test_flag_config_dataclass_repeated_arg_use_last(self):
    result = _test_flags(
        _CONFIG, '.baseline_model.qux=10', '.baseline_model.qux=12'
    )
    self.assertEqual(result.baseline_model.qux, 12)
    self.assertIsInstance(result.baseline_model.qux, int)
    self.assertIsNone(result.my_model.qux)

  def test_custom_flag_parsing_shared_default(self):
    result = _test_flags(_CONFIG, '.baseline_model.foo=324')
    result1 = _test_flags(_CONFIG, '.baseline_model.foo=123')
    # Here we verify that despite using _CONFIG as shared default for
    # result and result1, the final values are not in fact shared.
    self.assertEqual(result.baseline_model.foo, 324)
    self.assertEqual(result1.baseline_model.foo, 123)
    self.assertEqual(_CONFIG.baseline_model.foo, 55)

  def test_custom_flag_parsing_parser_override(self):
    # Register the restore before overriding, so the default parser comes back
    # even when an assertion below fails. Otherwise the delta=2 parser would
    # leak into later tests that expect the default delta.
    self.addCleanup(
        config_flags.register_flag_parser_for_type,
        CustomParserConfig, ParserForCustomConfig())
    config_flags.register_flag_parser_for_type(
        CustomParserConfig, ParserForCustomConfig(2))
    result = _test_flags(_CONFIG, '.custom=10')
    self.assertEqual(result.custom.i, 10)
    self.assertEqual(result.custom.j, 12)

  @unittest.skipIf(sys.version_info[:2] < (3, 10), 'Need 3.10 to test | syntax')
  def test_pipe_syntax(self):
    @dataclasses.dataclass
    class PipeConfig:
      foo: int | None = None

    result = _test_flags(PipeConfig(), '.foo=32')
    self.assertEqual(result.foo, 32)

  def test_custom_flag_parsing_override_work(self):
    # Overrides still work.
    result = _test_flags(_CONFIG, '.custom.i=10')
    self.assertEqual(result.custom.i, 10)
    self.assertEqual(result.custom.j, 1)

  def test_optional_nested_fields(self):
    with self.assertRaises(ValueError):
      # Implicit creation not allowed.
      _test_flags(ConfigWithOptionalNestedField(), '.sub.model.foo=12')

    # Explicit creation works.
    result = _test_flags(
        ConfigWithOptionalNestedField(), '.sub=build', '.sub.model.foo=12'
    )
    self.assertEqual(result.sub.model.foo, 12)

    # Default initialization support.
    result = _test_flags(ConfigWithOptionalNestedField(), '.sub=build')
    self.assertEqual(result.sub.model.foo, 0)

    # Using default value (None).
    result = _test_flags(ConfigWithOptionalNestedField())
    self.assertIsNone(result.sub)

    with self.assertRaises(config_flag_lib.FlagOrderError):
      # Don't allow accidental overwrites.
      _test_flags(
          ConfigWithOptionalNestedField(), '.sub.model.foo=12', '.sub=build'
      )

  def test_set_to_none_dataclass_fields(self):
    result = _test_flags(
        ConfigWithOptionalNestedField(), '.sub=build', '.sub.model=none'
    )
    # assertIsNone takes (obj, msg=None); no second positional argument needed.
    self.assertIsNone(result.sub.model)

    with self.assertRaises(KeyError):
      # Parent field is set to None (from not None default value),
      # so this is not a valid set of flags.
      _test_flags(
          ConfigWithOptionalNestedField(),
          '.sub=build',
          '.sub.model=none',
          '.sub.model.foo=12',
      )

    with self.assertRaises(KeyError):
      # Parent field is explicitly set to None (with None default value),
      # so this is not a valid set of flags.
      _test_flags(
          ConfigWithOptionalNestedField(), '.sub=none', '.sub.model.foo=12'
      )

  def test_set_none_non_optional_dataclass_fields(self):
    with self.assertRaises(flags.IllegalFlagValueError):
      # Field is not marked as optional so it can't be set to None.
      _test_flags(ConfigWithOptionalNestedField(), '.non_optional=None')

  def test_no_default_initializer(self):
    with self.assertRaises(flags.IllegalFlagValueError):
      _test_flags(ConfigWithOptionalNestedField(), '.sub=1', '.sub.model=1')

  def test_custom_flag_parser_invoked(self):
    # custom parser gets invoked
    result = _test_flags(_CONFIG, '.custom=10')
    self.assertEqual(result.custom.i, 10)
    self.assertEqual(result.custom.j, 11)

  def test_custom_flag_parser_invoked_overrides_applied(self):
    result = _test_flags(_CONFIG, '.custom=15', '.custom.i=11')
    # Override applied successfully
    self.assertEqual(result.custom.i, 11)
    self.assertEqual(result.custom.j, 16)

  def test_custom_flag_application_order(self):
    # Disallow for later value to override the earlier value.
    with self.assertRaises(config_flag_lib.FlagOrderError):
      _test_flags(_CONFIG, '.custom.i=11', '.custom=15')

  def test_flag_config_dataclass_type_mismatch(self):
    result = _test_flags(_CONFIG, '.my_model.bax=10')
    self.assertIsInstance(result.my_model.bax, float)
    # We can't do anything when the value isn't overridden.
    self.assertIsInstance(result.baseline_model.bax, int)
    self.assertRaises(
        flags.IllegalFlagValueError,
        functools.partial(_test_flags, _CONFIG, '.my_model.bax=string'),
    )

  def test_illegal_dataclass_field_type(self):

    @dataclasses.dataclass
    class Config:
      field: Union[int, float] = 3

    self.assertRaises(
        TypeError, functools.partial(_test_flags, Config(), '.field=1')
    )

  def test_spurious_dataclass_field(self):

    @dataclasses.dataclass
    class Config:
      field: int = 3
    cfg = Config()
    # Attribute added after construction, so it is not a dataclass field.
    cfg.extra = 'test'

    self.assertRaises(
        KeyError, functools.partial(_test_flags, cfg, '.extra=hi')
    )

  def test_nested_dataclass(self):

    @dataclasses.dataclass
    class Parent:
      field: int = 3

    @dataclasses.dataclass
    class Child(Parent):
      other: int = 4

    # A field inherited from the parent dataclass can be overridden.
    self.assertEqual(_test_flags(Child(), '.field=1').field, 1)

  def test_flag_config_dataclass(self):
    result = _test_flags(_CONFIG, '.baseline_model.foo=10', '.my_model.foo=7')
    self.assertEqual(result.baseline_model.foo, 10)
    self.assertEqual(result.my_model.foo, 7)

  def test_flag_config_dataclass_string_dict(self):
    result = _test_flags(_CONFIG, '.my_model.baz["foo.b"]=rab')
    self.assertEqual(result.my_model.baz['foo.b'], 'rab')

  def test_flag_config_dataclass_tuple_dict(self):
    result = _test_flags(_CONFIG, '.my_model.buz[(0,1)]=hello')
    self.assertEqual(result.my_model.buz[(0, 1)], 'hello')

  def test_flag_config_dataclass_typed_tuple(self):
    result = _test_flags(_CONFIG, '.my_model.boj=(0, 1)')
    self.assertEqual(result.my_model.boj, (0, 1))
322 |
323 |
class DataClassParseFnTest(absltest.TestCase):
  """Checks how a `parse_fn` value interacts with per-field overrides."""

  def test_parse_no_custom_value(self):
    # No `=value` flag is given, so `my_model.foo` keeps its default.
    config = _test_flags(
        _CONFIG, '.baseline_model.foo=10', parse_fn=parse_config_flag)
    self.assertEqual(config.my_model.foo, 3)
    self.assertEqual(config.baseline_model.foo, 10)

  def test_parse_custom_value_applied(self):
    config = _test_flags(
        _CONFIG, '=75', '.baseline_model.foo=10', parse_fn=parse_config_flag)
    self.assertEqual(config.my_model.foo, 75)
    self.assertEqual(config.baseline_model.foo, 10)

  def test_parse_custom_value_applied_no_explicit_parse_fn(self):
    config = _test_flags(CustomParserConfig(0), '=75', '.i=12')
    self.assertEqual(config.i, 12)
    self.assertEqual(config.j, 76)

  def test_parse_out_of_order(self):
    # Note: If this is ever supported, add verification that overrides are
    # applied correctly.
    self.assertRaises(
        config_flag_lib.FlagOrderError,
        _test_flags,
        _CONFIG,
        '.baseline_model.foo=10',
        '=75',
        parse_fn=parse_config_flag,
    )

  def test_parse_assign_dataclass(self):
    def always_fail(v):
      raise ValueError()

    values = flags.FlagValues()
    holder = config_flags.DEFINE_config_dataclass(
        'test_config', _CONFIG, flag_values=values, parse_fn=always_fail)
    values(['program'])
    # Assign an already-built dataclass straight to the flag.
    values['test_config'].value = parse_config_flag('12')
    self.assertEqual(holder.value.my_model.foo, 12)

  def test_parse_invalid_custom_value(self):
    self.assertRaises(
        flags.IllegalFlagValueError,
        _test_flags,
        _CONFIG,
        '=?',
        '.baseline_model.foo=10',
        parse_fn=parse_config_flag,
    )

  def test_parse_overrides_applied(self):
    config = _test_flags(
        _CONFIG, '=34', '.my_model.foo=10', parse_fn=parse_config_flag)
    self.assertEqual(config.my_model.foo, 10)
376 |
377 | if __name__ == '__main__':
378 | absltest.main()
379 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/fieldreference_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file with field references."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
def get_config():
  """Returns a ConfigDict with one defaulted and one placeholder reference."""
  config = config_dict.ConfigDict()
  config['ref'] = config_dict.FieldReference(123)
  config['ref_nodefault'] = config_dict.placeholder(int)
  return config
25 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/ioerror_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file that raises IOError on import.
16 |
17 | The flags library tries to load configuration files in a few different ways.
18 | For this it relies on catching IOError exceptions of the type "File not
19 | found" and ignoring them to continue trying with a different loading method.
20 | But we need to ensure that other types of IOError exceptions are propagated
21 | correctly (b/63165566). This is tested in `ConfigFlagTest.testIOError` in
22 | `config_overriding_test.py`.
23 | """
24 |
25 | raise IOError('This is an IOError.')
26 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/literal_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file with field references."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
def get_config():
  """Returns a ConfigDict of `object` placeholders, two of them set to 123."""
  config = config_dict.ConfigDict()
  placeholder_names = (
      'integer',
      'string',
      'nested',
      'other_with_default',
      'other_with_default_overitten',
  )
  for name in placeholder_names:
    config[name] = config_dict.placeholder(object)
  # Assign values to the two fields that carry a default.
  for name in ('other_with_default', 'other_with_default_overitten'):
    config[name] = 123
  return config
30 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/mini_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Placeholder Config file."""
16 |
17 |
class MiniConfig:
  """Minimal config object with both item access and plain attributes.

  Item access reads and writes the `dict` attribute; `field` is an ordinary
  attribute initialised to False.
  """

  def __init__(self):
    self.dict = {}
    self.field = False

  def __setitem__(self, key, value):
    self.dict[key] = value

  def __getitem__(self, key):
    return self.dict[key]

  def __contains__(self, key):
    return key in self.dict
33 |
34 |
def get_config():
  """Returns a MiniConfig where one name is both an item and an attribute."""
  colliding_name = 'entry_with_collision'
  mini = MiniConfig()
  # The item holds False while the same-named attribute holds True.
  mini[colliding_name] = False
  setattr(mini, colliding_name, True)
  return mini
40 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/mock_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Placeholder Config file."""
16 |
17 | import copy
18 |
19 | from ml_collections.config_flags.tests import spork
20 |
21 |
class TestConfig:
  """Default Test config value."""

  def __init__(self):
    # Scalars.
    self.integer = 23
    self.float = 2.34
    self.string = 'james'
    self.bool = True
    # Containers; every literal is built separately so nothing is shared.
    self.dict = dict(
        integer=1,
        float=3.14,
        string='mark',
        bool=False,
        dict=dict(float=5.0),
        list=[1, 2, [3]],
    )
    self.list = [1, 2, [3]]
    self.tuple = (1, 2, (3,))
    self.tuple_with_spaces = (1, 2, (3,))
    self.enum = spork.SporkType.SPOON

  @property
  def readonly_field(self):
    """Property without a setter; always 42."""
    return 42

  def __repr__(self):
    return str(vars(self))
51 |
52 |
def get_config():
  """Returns a TestConfig holding a nested config, an alias and a deep copy."""
  root = TestConfig()
  nested = TestConfig()
  root.object = nested
  # Same object as `root.object`, not a copy.
  root.object_reference = nested
  root.object_copy = copy.deepcopy(nested)
  return root
61 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/parameterised_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file where `get_config` takes a string argument."""
16 |
17 | from ml_collections import config_dict
18 |
19 |
def get_config(config_string):
  """Returns the predefined config named by `config_string`.

  Args:
    config_string: Either 'type_a' or 'type_b'.

  Returns:
    The matching `ConfigDict`.

  Raises:
    KeyError: If `config_string` names neither predefined config.
  """
  type_a = config_dict.ConfigDict({'thing_a': 23, 'thing_b': 42})
  type_b = config_dict.ConfigDict({'thing_a': 19, 'thing_c': 65})
  return {'type_a': type_a, 'type_b': type_b}[config_string]
33 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/spork.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """An enum to be used in a mock_config.
16 |
17 | The enum can't be defined directly in the config because a config imported
18 | through a flag is a separate instance of the config module. Therefore it would
19 | define an own instance of the enum class which won't be equal to the same enum
20 | from the config imported directly.
21 | """
22 |
23 | import enum
24 |
25 |
class SporkType(enum.Enum):
  """Utensil kinds used as enum values in the mock config."""

  SPOON = 1
  SPORK = 2
  FORK = 3
30 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/tuple_parser_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ml_collections.config_flags.tuple_parser."""
16 |
17 | from ml_collections.config_flags import tuple_parser
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 |
21 |
class TupleParserTest(parameterized.TestCase):
  """Covers `TupleParser.parse` and the `_convert_str_to_tuple` helper."""

  @parameterized.parameters(
      dict(argument=1, expected=(1,)),
      dict(argument=(1, 2), expected=(1, 2)),
      dict(argument=('abc', 'def'), expected=('abc', 'def')),
      dict(argument='1', expected=(1,)),
      dict(argument='"abc"', expected=('abc',)),
      dict(argument='"abc",', expected=('abc',)),
      dict(argument='1, "a"', expected=(1, 'a')),
      dict(argument='(1, "a")', expected=(1, 'a')),
      dict(argument='(1, "a", (2, 3))', expected=(1, 'a', (2, 3))),
      dict(argument=('abc*', 'def*'), expected=('abc*', 'def*')),
      dict(argument='("abc*", "def*")', expected=('abc*', 'def*')),
      dict(argument='("/abc",)', expected=('/abc',)),
      dict(argument='("/abc*",)', expected=('/abc*',)),
      dict(argument='("/abc/",)', expected=('/abc/',)),
  )
  def test_tuple_parser_parse(self, argument, expected):
    parsed = tuple_parser.TupleParser().parse(argument)
    self.assertEqual(parsed, expected)

  @parameterized.parameters(
      dict(argument='1', expected=(1,)),
      dict(argument='"abc"', expected=('abc',)),
      dict(argument='"abc",', expected=('abc',)),
      dict(argument='abc', expected=('abc',)),
      dict(argument='1, "a"', expected=(1, 'a')),
      dict(argument='(1, "a")', expected=(1, 'a')),
      dict(argument='(1, "a", (2, 3))', expected=(1, 'a', (2, 3))),
      dict(argument='("abc*", "def*")', expected=('abc*', 'def*')),
      dict(argument='"abc*", "def*"', expected=('abc*', 'def*')),
      dict(argument='"abc*",', expected=('abc*',)),
      dict(argument='abc*', expected=('abc*',)),
      dict(argument='"/abc",', expected=('/abc',)),
      dict(argument='/abc*', expected=('/abc*',)),
      dict(argument='/abc/', expected=('/abc/',)),
  )
  def test_convert_str_to_tuple(self, argument, expected):
    converted = tuple_parser._convert_str_to_tuple(argument)
    self.assertEqual(converted, expected)

  @parameterized.parameters(
      'a b',
      'a,b*',
      '/a,b*',
      '/a,b',
      'a b',
  )
  def test_convert_str_to_tuple_bad_inputs(self, argument):
    # Each of these inputs must be rejected with ValueError.
    self.assertRaises(
        ValueError, tuple_parser._convert_str_to_tuple, argument)
73 |
74 |
75 | if __name__ == '__main__':
76 | absltest.main()
77 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/typeerror_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file that raises TypeError on import.
16 |
17 | When trying to load the configuration file as a flag, the flags library catches
18 | TypeError exceptions, recasts them as an IllegalFlagTypeError and rethrows
19 | (b/63877430). The rethrow does not include the stacktrace from the original
20 | exception, so we manually add the stacktrace in configflags.parse(). This is
21 | tested in `ConfigFlagTest.testTypeError` in `config_overriding_test.py`.
22 | """
23 |
24 |
def type_error_function():
  """Unconditionally raises `TypeError`."""
  error = TypeError('This is a TypeError.')
  raise error
27 |
28 |
def get_config():
  """Tries to build the config; `type_error_function` raises first."""
  item = type_error_function()
  return {'item': item}
31 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tests/valueerror_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Config file that raises ValueError on import.
16 |
17 | When trying to load the configuration file as a flag, the flags library catches
18 | ValueError exceptions, recasts them as an IllegalFlagValueError and rethrows
19 | (b/63877430). The rethrow does not include the stacktrace from the original
20 | exception, so we manually add the stacktrace in configflags.parse(). This is
21 | tested in `ConfigFlagTest.testValueError` in `config_overriding_test.py`.
22 | """
23 |
24 |
def value_error_function():
  """Unconditionally raises `ValueError`."""
  error = ValueError('This is a ValueError.')
  raise error
27 |
28 |
def get_config():
  """Tries to build the config; `value_error_function` raises first."""
  item = value_error_function()
  return {'item': item}
31 |
--------------------------------------------------------------------------------
/ml_collections/config_flags/tuple_parser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom parser to override tuples in the config dict."""
16 | import ast
17 | import collections.abc
18 |
19 | from absl import flags
20 |
21 |
class TupleParser(flags.ArgumentParser):
  """Flag parser that produces `tuple` values.

  Accepts an existing `tuple`, a `str` to be converted, `None`, or any other
  single object, and always returns a `tuple`. The override may differ from
  the original tuple in both length and element types. String arguments are
  handled by `_convert_str_to_tuple`, which relies on `ast.literal_eval` from
  the Python Standard Library; see its documentation for the supported forms.
  """

  def parse(self, argument):
    """Converts `argument` into a `tuple`.

    Args:
      argument: The value to parse. A `tuple` is returned unchanged, a `str`
        is parsed into a tuple, `None` yields the empty tuple, and any other
        object is wrapped in a tuple of length 1.

    Returns:
      The `tuple` representing `argument`.
    """
    if argument is None:
      return ()
    if isinstance(argument, str):
      return _convert_str_to_tuple(argument)
    if isinstance(argument, tuple):
      return argument
    return (argument,)

  def flag_type(self):
    """Returns the flag type name, 'tuple'."""
    return 'tuple'
58 |
59 |
60 | def _convert_str_to_tuple(string):
61 | """Converts a Python `str` object to a `tuple` via `ast.literal_eval`.
62 |
63 | Args:
64 | string: The `str` to be converted.
65 |
66 | Returns:
67 | The parsed value as a `tuple`; a scalar or `str` becomes a 1-tuple.
68 |
69 | Raises:
70 | ValueError: If parsing hits a syntax error and `string` has a ',' or ' '.
71 | """
72 | # literal_eval parses Python literals only: strings, bytes, numbers, tuples,
73 | # lists, dicts, sets, booleans and None, including nested containers.
74 | # Anything else (e.g. a bare name or a function call) is rejected.
75 | try:
76 | value = ast.literal_eval(string)
77 | except ValueError:
78 | # A ValueError is raised by literal_eval if the string is not well
79 | # formed. This is probably because string represents a single string
80 | # element, e.g. string='abc', and a tuple of strings field was overridden by
81 | # repeated use of a flag (ie `--flag a --flag b` instead of
82 | # `--flag '("a", "b")'`).
83 | value = string
84 | except SyntaxError as exc:
85 | # `literal_eval` calls the Python built-in `compile`, which raises a
86 | # `SyntaxError` on parsing issues. Per the `ast` docs, `TypeError`,
87 | # `MemoryError` and `RecursionError` are also possible; they propagate.
88 | if ',' not in string and ' ' not in string:
89 | # Most likely passed a single string that contained an operator -- e.g.
90 | # '/path/to/file' or 'file_pattern*'. If a comma isn't in the string, then
91 | # it can't have been a tuple, so assume it's an unquoted string.
92 | # If passed strings containing a comma (user probably expects conversion
93 | # to a tuple) or whitespace (user might expect implicit conversion to
94 | # tuple?) raise an exception.
95 | value = string
96 | else:
97 | msg = (
98 | f'Error while parsing string: {string} as tuple. If you intended to'
99 | ' pass the argument as a single element, use quotes such as `--flag'
100 | f' {repr(repr(string))}, otherwise insert quotes around each element.'
101 | )
102 | if ' ' in string:
103 | msg += (
104 | ' Use commas instead of whitespace as the separator between'
105 | ' elements.'
106 | )
107 | raise ValueError(msg) from exc
108 |
109 | # Always return a tuple: convert other iterables, wrap scalars and strings.
110 | if isinstance(value, tuple):
111 | return value
112 | elif (isinstance(value, collections.abc.Iterable) and
113 | not isinstance(value, str)):
114 | return tuple(value)
115 | else:
116 | return (value,)
117 |
--------------------------------------------------------------------------------
/ml_collections/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The ML Collections Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Configuration file for pytest."""
16 |
17 | from absl import flags
18 | import pytest
19 |
20 |
21 | @pytest.fixture(scope="function", autouse=True)
22 | def mark_flags_as_parsed():  # autouse: applied to every test function.
23 | flags.FLAGS.mark_as_parsed()  # NOTE(review): presumably so tests can read flags without parsing argv -- confirm.
24 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | # Project metadata. Available keys are documented at:
3 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata
4 | name = "ml_collections"
5 | description = "ML Collections is a library of Python collections designed for ML usecases."
6 | readme = "README.md"
7 | requires-python = ">=3.10"
8 | license = {file = "LICENSE"}
9 | authors = [{name = "ML Collections Authors", email="ml-collections@google.com"}]
10 | classifiers = [ # List of https://pypi.org/classifiers/
11 | "Development Status :: 4 - Beta",
12 | "Intended Audience :: Developers",
13 | "Intended Audience :: Science/Research",
14 | "License :: OSI Approved :: Apache Software License",
15 | "Programming Language :: Python",
16 | "Topic :: Scientific/Engineering",
17 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
18 | "Topic :: Software Development :: Libraries",
19 | "Topic :: Software Development :: Libraries :: Python Modules",
20 | ]
21 | keywords = []
22 |
23 | # pip dependencies of the project
24 | # Installed locally with `pip install -e .`
25 | dependencies = [
26 | "absl-py",
27 | "PyYAML", # Could make this optional ?
28 | ]
29 |
30 | # `version` is automatically set by flit to use `ml_collections.__version__`
31 | dynamic = ["version"]
32 |
33 | [project.urls]
34 | homepage = "https://github.com/google/ml_collections"
35 | repository = "https://github.com/google/ml_collections"
36 | documentation = "https://ml-collections.readthedocs.io"
37 |
38 | [project.optional-dependencies]
39 | # Development deps (unit testing, linting, formatting, ...)
40 | # Installed through `pip install -e .[dev]`
41 | dev = [
42 | "pytest",
43 | "pytest-xdist",
44 | "pylint>=2.6.0",
45 | "pyink",
46 | ]
47 |
48 | [tool.pyink]
49 | # Formatting configuration to follow Google style-guide
50 | line-length = 80
51 | unstable = true
52 | pyink-indentation = 2
53 | pyink-use-majority-quotes = true
54 |
55 | [build-system]
56 | # The build system specifies which backend is used to build/install the project
57 | # (flit, poetry, setuptools, ...). All backends are supported by `pip install`.
58 | requires = ["flit_core >=3.8,<4"]
59 | build-backend = "flit_core.buildapi"
60 |
61 | [tool.flit.sdist]
62 | # Flit specific options (files to exclude from the PyPI package).
63 | # If using another build backend (setuptools, poetry), you can remove this
64 | # section.
65 | exclude = [
66 | # Do not release test files on PyPI.
67 | "**/*_test.py",
68 | ]
69 |
--------------------------------------------------------------------------------