├── docs ├── requirements.txt ├── Makefile ├── index.rst ├── recipes.rst ├── api.rst ├── changes.rst └── conf.py ├── MANIFEST.in ├── readthedocs.yml ├── README.md ├── CONTRIBUTING.md ├── tree ├── tree_benchmark.py ├── sequence.py ├── CMakeLists.txt ├── tree.h ├── tree.cc ├── __init__.py └── tree_test.py ├── .github └── workflows │ └── build.yml ├── setup.py └── LICENSE /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx>=2.0.1 2 | sphinx_rtd_theme>=0.4.3 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # metadata 2 | include LICENSE 3 | include WORKSPACE 4 | include CONTRIBUTING.md 5 | include README.md 6 | 7 | # python package requirements 8 | include requirements*.txt 9 | 10 | # tree files 11 | recursive-include . CMakeLists.txt *.cc *.cpp *.h *.sh *.py *.cmake 12 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | sphinx: 7 | builder: html 8 | configuration: docs/conf.py 9 | fail_on_warning: false 10 | 11 | python: 12 | version: 3.7 13 | install: 14 | - requirements: docs/requirements.txt 15 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ################## 2 | Tree Documentation 3 | ################## 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :hidden: 8 | 9 | api 10 | changes 11 | recipes 12 | 13 | ``tree`` is a library for working with nested data structures. In a way, 14 | ``tree`` generalizes the builtin :func:`map` function, which only supports 15 | flat sequences, and lets you apply a function to each "leaf" while preserving 16 | the overall structure. 17 | 18 | Here's a quick example:: 19 | 20 | >>> tree.map_structure(lambda v: v**2, [[1], [[[2, 3]]], [4]]) 21 | [[1], [[[4, 9]]], [16]] 22 | 23 | .. note:: 24 | 25 | ``tree`` was originally part of TensorFlow and is available 26 | as ``tf.nest``. 27 | 28 | Installation 29 | ============ 30 | 31 | Install ``tree`` by running:: 32 | 33 | $ pip install dm-tree 34 | 35 | Support 36 | ======= 37 | 38 | If you are having issues, please let us know by filing an issue on our 39 | `issue tracker <https://github.com/deepmind/tree/issues>`_. 40 | 41 | License 42 | ======= 43 | 44 | Tree is licensed under the Apache 2.0 License. 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tree 2 | 3 | `tree` is a library for working with nested data structures. 
In a way, `tree` 4 | generalizes the builtin `map` function, which only supports flat sequences, 5 | and lets you apply a function to each "leaf" while preserving the overall 6 | structure. 7 | 8 | ```python 9 | >>> import tree 10 | >>> structure = [[1], [[[2, 3]]], [4]] 11 | >>> tree.flatten(structure) 12 | [1, 2, 3, 4] 13 | >>> tree.map_structure(lambda v: v**2, structure) 14 | [[1], [[[4, 9]]], [16]] 15 | ``` 16 | 17 | `tree` is backed by an optimized C++ implementation suitable for use in 18 | demanding applications, such as machine learning models. 19 | 20 | ## Installation 21 | 22 | From PyPI: 23 | 24 | ```shell 25 | $ pip install dm-tree 26 | ``` 27 | 28 | Directly from GitHub using pip: 29 | 30 | ```shell 31 | $ pip install git+https://github.com/deepmind/tree.git 32 | ``` 33 | 34 | Build from source: 35 | 36 | ```shell 37 | $ python setup.py install 38 | ``` 39 | 40 | ## Support 41 | 42 | If you are having issues, please let us know by filing an issue on our 43 | [issue tracker](https://github.com/deepmind/tree/issues). 44 | 45 | ## License 46 | 47 | The project is licensed under the Apache 2.0 license. 48 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 
13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /docs/recipes.rst: -------------------------------------------------------------------------------- 1 | ############ 2 | Recipes 3 | ############ 4 | 5 | 6 | Concatenate nested array structures 7 | =================================== 8 | >>> tree.map_structure(lambda *args: np.concatenate(args, axis=1), 9 | ... {'a': np.ones((2, 1))}, 10 | ... {'a': np.zeros((2, 1))}) 11 | {'a': array([[1., 0.], 12 | [1., 0.]])} 13 | 14 | >>> tree.map_structure(lambda *args: np.concatenate(args, axis=0), 15 | ... {'a': np.ones((2, 1))}, 16 | ... 
{'a': np.zeros((2, 1))}) 17 | {'a': array([[1.], 18 | [1.], 19 | [0.], 20 | [0.]])} 21 | 22 | 23 | Exclude "meta" keys while mapping across structures 24 | =================================================== 25 | >>> d1 = {'key_to_exclude': None, 'a': 1} 26 | >>> d2 = {'key_to_exclude': None, 'a': 2} 27 | >>> d3 = {'a': 3} 28 | >>> tree.map_structure_up_to({'a': True}, lambda x, y, z: x+y+z, d1, d2, d3) 29 | {'a': 6} 30 | 31 | 32 | Broadcast a value across a reference structure 33 | ============================================== 34 | >>> reference_tree = {'a': 1, 'b': (2, 3)} 35 | >>> value = np.inf 36 | >>> tree.map_structure(lambda _: value, reference_tree) 37 | {'a': inf, 'b': (inf, inf)} 38 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | ############# 2 | API Reference 3 | ############# 4 | 5 | All ``tree`` functions operate on nested tree-like structures. A *structure* 6 | is recursively defined as:: 7 | 8 | Structure = Union[ 9 | Any, 10 | Sequence['Structure'], 11 | Mapping[Any, 'Structure'], 12 | 'AnyNamedTuple', 13 | ] 14 | 15 | .. TODO(slebedev): Support @dataclass classes if we make @attr.s 16 | .. support public. 17 | 18 | A single (non-nested) Python object is a perfectly valid structure:: 19 | 20 | >>> tree.map_structure(lambda v: v * 2, 42) 21 | 84 22 | >>> tree.flatten(42) 23 | [42] 24 | 25 | You can check whether a structure is actually nested via 26 | :func:`~tree.is_nested`:: 27 | 28 | >>> tree.is_nested(42) 29 | False 30 | >>> tree.is_nested([42]) 31 | True 32 | 33 | Note that ``tree`` only supports acyclic structures. The behavior for 34 | structures with cyclic references is undefined. 35 | 36 | .. currentmodule:: tree 37 | 38 | .. autofunction:: is_nested 39 | 40 | .. autofunction:: assert_same_structure 41 | 42 | .. autofunction:: unflatten_as 43 | 44 | .. autofunction:: flatten 45 | 46 | .. 
autofunction:: flatten_up_to 47 | 48 | .. autofunction:: flatten_with_path 49 | 50 | .. autofunction:: flatten_with_path_up_to 51 | 52 | .. autofunction:: map_structure 53 | 54 | .. autofunction:: map_structure_up_to 55 | 56 | .. autofunction:: map_structure_with_path 57 | 58 | .. autofunction:: map_structure_with_path_up_to 59 | 60 | .. autofunction:: traverse 61 | 62 | .. autodata:: MAP_TO_NONE 63 | -------------------------------------------------------------------------------- /docs/changes.rst: -------------------------------------------------------------------------------- 1 | ######### 2 | Changelog 3 | ######### 4 | 5 | Version 0.1.9 6 | ============= 7 | 8 | Released 2025-01-30 9 | 10 | * Dropped support for Python <3.10. 11 | 12 | Version 0.1.8 13 | ============= 14 | 15 | Released 2022-12-19 16 | 17 | * Bumped pybind11 to v2.10.1 to support Python 3.11. 18 | * Dropped support for Python 3.6. 19 | 20 | Version 0.1.7 21 | ============= 22 | 23 | Released 2022-04-10 24 | 25 | * The build is now done via CMake instead of Bazel. 26 | 27 | Version 0.1.6 28 | ============= 29 | 30 | Released 2021-04-12 31 | 32 | * Dropped support for Python 2.X. 33 | * Added a generalization of ``tree.traverse`` which keeps track of the 34 | current path during traversal. 35 | 36 | Version 0.1.5 37 | ============= 38 | 39 | Released 2020-04-30 40 | 41 | * Added a new function ``tree.traverse`` which traverses a nested 42 | structure and applies a function to each subtree. 43 | 44 | Version 0.1.4 45 | ============= 46 | 47 | Released 2020-03-27 48 | 49 | * Added support for ``types.MappingProxyType`` on Python 3.X. 50 | 51 | Version 0.1.3 52 | ============= 53 | 54 | Released 2020-01-30 55 | 56 | * Fixed ``ImportError`` when ``wrapt`` was not available. 57 | 58 | Version 0.1.2 59 | ============= 60 | 61 | Released 2020-01-29 62 | 63 | * Added support for ``wrapt.ObjectWrapper`` objects. 64 | * Added ``StructureKV[K, V]`` and ``Structure = Structure[Text, V]`` types. 
65 | 66 | Version 0.1.1 67 | ============= 68 | 69 | Released 2019-11-07 70 | 71 | * Ensured that the produced Linux wheels are manylinux2010-compatible. 72 | 73 | Version 0.1.0 74 | ============= 75 | 76 | Released 2019-11-05 77 | 78 | * Initial public release. 79 | -------------------------------------------------------------------------------- /tree/tree_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmarks for utilities working with arbitrarily nested structures.""" 16 | 17 | import collections 18 | import timeit 19 | 20 | import tree 21 | 22 | 23 | TIME_UNITS = [ 24 | (1, "s"), 25 | (10**-3, "ms"), 26 | (10**-6, "us"), 27 | (10**-9, "ns"), 28 | ] 29 | 30 | 31 | def format_time(time): 32 | for d, unit in TIME_UNITS: 33 | if time > d: 34 | return "{:.2f}{}".format(time / d, unit) 35 | 36 | 37 | def run_benchmark(benchmark_fn, num_iters): 38 | times = timeit.repeat(benchmark_fn, repeat=2, number=num_iters) 39 | return times[-1] / num_iters # Discard the first half for "warmup". 
40 | 41 | 42 | def map_to_list(func, *args): 43 | return list(map(func, *args)) 44 | 45 | 46 | def benchmark_map(map_fn, structure): 47 | def benchmark_fn(): 48 | return map_fn(lambda v: v, structure) 49 | return benchmark_fn 50 | 51 | 52 | BENCHMARKS = collections.OrderedDict([ 53 | ("tree_map_1", benchmark_map(tree.map_structure, [0])), 54 | ("tree_map_8", benchmark_map(tree.map_structure, [0] * 8)), 55 | ("tree_map_64", benchmark_map(tree.map_structure, [0] * 64)), 56 | ("builtin_map_1", benchmark_map(map_to_list, [0])), 57 | ("builtin_map_8", benchmark_map(map_to_list, [0] * 8)), 58 | ("builtin_map_64", benchmark_map(map_to_list, [0] * 64)), 59 | ]) 60 | 61 | 62 | def main(): 63 | for name, benchmark_fn in BENCHMARKS.items(): 64 | print(name, format_time(run_benchmark(benchmark_fn, num_iters=1000))) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | release: 9 | types: [created] 10 | workflow_dispatch: 11 | 12 | jobs: 13 | sdist: 14 | name: sdist 15 | runs-on: ubuntu-24.04 16 | steps: 17 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 18 | - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 19 | with: 20 | python-version: '3.11' 21 | - name: Create sdist 22 | run: | 23 | python -m pip install --upgrade pip setuptools 24 | python setup.py sdist 25 | shell: bash 26 | - name: List output directory 27 | run: ls -lh dist/dm_tree*.tar.gz 28 | shell: bash 29 | - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 30 | if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }} 31 | with: 32 | name: dm-tree-sdist 33 | path: 
dist/dm_tree*.tar.gz 34 | 35 | bdist-wheel: 36 | name: Build wheels on ${{ matrix.os }} 37 | runs-on: ${{ matrix.os }} 38 | strategy: 39 | matrix: 40 | os: [ubuntu-24.04, macos-14, windows-2022] # latest 41 | 42 | steps: 43 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 44 | with: 45 | submodules: true 46 | - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 47 | with: 48 | python-version: "3.11" 49 | - name: Set up QEMU 50 | if: runner.os == 'Linux' 51 | uses: docker/setup-qemu-action@53851d14592bedcffcf25ea515637cff71ef929a # v3.3.0 52 | with: 53 | platforms: all 54 | # This should be temporary 55 | # xref https://github.com/docker/setup-qemu-action/issues/188 56 | # xref https://github.com/tonistiigi/binfmt/issues/215 57 | image: tonistiigi/binfmt:qemu-v8.1.5 58 | - name: Install cibuildwheel 59 | run: python -m pip install cibuildwheel==2.22.0 60 | - name: Build wheels 61 | run: python -m cibuildwheel --output-dir wheelhouse 62 | env: 63 | CIBW_ARCHS_LINUX: auto aarch64 64 | CIBW_ARCHS_MACOS: universal2 65 | CIBW_BUILD: "cp310-* cp311-* cp312-* cp313-* cp313t-*" 66 | CIBW_BUILD_VERBOSITY: 1 67 | CIBW_FREE_THREADED_SUPPORT: True 68 | CIBW_PRERELEASE_PYTHONS: True 69 | CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" 70 | CIBW_TEST_COMMAND: pytest --pyargs tree 71 | CIBW_TEST_REQUIRES: pytest 72 | MAKEFLAGS: "-j$(nproc)" 73 | - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 74 | if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }} 75 | with: 76 | name: dm-tree-bdist-wheel-${{ matrix.os }}-${{ strategy.job-index }} 77 | path: wheelhouse/*.whl 78 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Configuration file for the Sphinx documentation builder.""" 16 | 17 | # This file only contains a selection of the most common options. For a full 18 | # list see the documentation: 19 | # http://www.sphinx-doc.org/en/master/config 20 | 21 | # -- Path setup -------------------------------------------------------------- 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | 27 | # pylint: disable=g-bad-import-order 28 | # pylint: disable=g-import-not-at-top 29 | import datetime 30 | import inspect 31 | import os 32 | import sys 33 | 34 | sys.path.insert(0, os.path.abspath('../')) 35 | 36 | import tree 37 | 38 | # -- Project information ----------------------------------------------------- 39 | 40 | project = 'Tree' 41 | copyright = f'{datetime.date.today().year}, DeepMind' # pylint: disable=redefined-builtin 42 | author = 'DeepMind' 43 | 44 | # -- General configuration --------------------------------------------------- 45 | 46 | master_doc = 'index' 47 | 48 | # Add any Sphinx extension module names here, as strings. 
They can be 49 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 50 | # ones. 51 | extensions = [ 52 | 'sphinx.ext.autodoc', 53 | 'sphinx.ext.autosummary', 54 | 'sphinx.ext.linkcode', 55 | 'sphinx.ext.napoleon', 56 | 'sphinx.ext.doctest' 57 | ] 58 | 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ['_templates'] 61 | 62 | # List of patterns, relative to source directory, that match files and 63 | # directories to ignore when looking for source files. 64 | # This pattern also affects html_static_path and html_extra_path. 65 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 66 | 67 | # -- Options for autodoc ----------------------------------------------------- 68 | 69 | autodoc_default_options = { 70 | 'member-order': 'bysource', 71 | 'special-members': True, 72 | 'exclude-members': '__repr__, __str__, __weakref__', 73 | } 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 
79 | # 80 | html_theme = 'sphinx_rtd_theme' 81 | 82 | html_theme_options = { 83 | # 'collapse_navigation': False, 84 | # 'sticky_navigation': False, 85 | } 86 | 87 | # -- Options for doctest ----------------------------------------------------- 88 | 89 | doctest_global_setup = ''' 90 | import collections 91 | import numpy as np 92 | import tree 93 | ''' 94 | 95 | # -- Source code links ------------------------------------------------------- 96 | 97 | 98 | def linkcode_resolve(domain, info): 99 | """Resolve a GitHub URL corresponding to a Python object.""" 100 | if domain != 'py': 101 | return None 102 | 103 | try: 104 | mod = sys.modules[info['module']] 105 | except KeyError: 106 | return None 107 | 108 | obj = mod 109 | try: 110 | for attr in info['fullname'].split('.'): 111 | obj = getattr(obj, attr) 112 | except AttributeError: 113 | return None 114 | else: 115 | obj = inspect.unwrap(obj) 116 | 117 | try: 118 | filename = inspect.getsourcefile(obj) 119 | except TypeError: 120 | return None 121 | 122 | try: 123 | source, lineno = inspect.getsourcelines(obj) 124 | except OSError: 125 | return None 126 | 127 | # TODO(slebedev): support tags after we release an initial version. 128 | return 'https://github.com/deepmind/tree/blob/master/tree/%s#L%d-L%d' % ( 129 | os.path.relpath(filename, start=os.path.dirname( 130 | tree.__file__)), lineno, lineno + len(source) - 1) 131 | -------------------------------------------------------------------------------- /tree/sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains _sequence_like and helpers for sequence data structures.""" 16 | import collections 17 | from collections import abc as collections_abc 18 | import types 19 | from tree import _tree 20 | 21 | # pylint: disable=g-import-not-at-top 22 | try: 23 | import wrapt 24 | ObjectProxy = wrapt.ObjectProxy 25 | except ImportError: 26 | class ObjectProxy(object): 27 | """Stub-class for `wrapt.ObjectProxy`.""" 28 | 29 | 30 | def _sorted(dictionary): 31 | """Returns a sorted list of the dict keys, with error if keys not sortable.""" 32 | try: 33 | return sorted(dictionary) 34 | except TypeError: 35 | raise TypeError("tree only supports dicts with sortable keys.") 36 | 37 | 38 | def _is_attrs(instance): 39 | return _tree.is_attrs(instance) 40 | 41 | 42 | def _is_namedtuple(instance, strict=False): 43 | """Returns True iff `instance` is a `namedtuple`. 44 | 45 | Args: 46 | instance: An instance of a Python object. 47 | strict: If True, `instance` is considered to be a `namedtuple` only if 48 | it is a "plain" namedtuple. For instance, a class inheriting 49 | from a `namedtuple` will be considered to be a `namedtuple` 50 | iff `strict=False`. 51 | 52 | Returns: 53 | True if `instance` is a `namedtuple`. 54 | """ 55 | return _tree.is_namedtuple(instance, strict) 56 | 57 | 58 | def _sequence_like(instance, args): 59 | """Converts the sequence `args` to the same type as `instance`. 
60 | 61 | Args: 62 | instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or 63 | `collections.OrderedDict`. 64 | args: elements to be converted to the `instance` type. 65 | 66 | Returns: 67 | `args` with the type of `instance`. 68 | """ 69 | if isinstance(instance, (dict, collections_abc.Mapping)): 70 | # Pack dictionaries in a deterministic order by sorting the keys. 71 | # Notice this means that we ignore the original order of `OrderedDict` 72 | # instances. This is intentional, to avoid potential bugs caused by mixing 73 | # ordered and plain dicts (e.g., flattening a dict but using a 74 | # corresponding `OrderedDict` to pack it back). 75 | result = dict(zip(_sorted(instance), args)) 76 | keys_and_values = ((key, result[key]) for key in instance) 77 | if isinstance(instance, collections.defaultdict): 78 | # `defaultdict` requires a default factory as the first argument. 79 | return type(instance)(instance.default_factory, keys_and_values) 80 | elif isinstance(instance, types.MappingProxyType): 81 | # MappingProxyType requires a dict to proxy to. 
82 | return type(instance)(dict(keys_and_values)) 83 | else: 84 | return type(instance)(keys_and_values) 85 | elif isinstance(instance, collections_abc.MappingView): 86 | # We can't directly construct mapping views, so we create a list instead 87 | return list(args) 88 | elif _is_namedtuple(instance) or _is_attrs(instance): 89 | if isinstance(instance, ObjectProxy): 90 | instance_type = type(instance.__wrapped__) 91 | else: 92 | instance_type = type(instance) 93 | try: 94 | if _is_attrs(instance): 95 | return instance_type( 96 | **{ 97 | attr.name: arg 98 | for attr, arg in zip(instance_type.__attrs_attrs__, args) 99 | }) 100 | else: 101 | return instance_type(*args) 102 | except Exception as e: 103 | raise TypeError( 104 | f"Couldn't traverse {instance!r} with arguments {args}") from e 105 | elif isinstance(instance, ObjectProxy): 106 | # For object proxies, first create the underlying type and then re-wrap it 107 | # in the proxy type. 108 | return type(instance)(_sequence_like(instance.__wrapped__, args)) 109 | else: 110 | # Not a namedtuple 111 | return type(instance)(args) 112 | -------------------------------------------------------------------------------- /tree/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Version >= 3.24 required for new `FindPython` module and `FIND_PACKAGE_ARGS` 2 | # keyword of `FetchContent` module. 3 | # https://cmake.org/cmake/help/v3.24/release/3.24.html 4 | cmake_minimum_required(VERSION 3.24) 5 | 6 | cmake_policy(SET CMP0135 NEW) 7 | 8 | project (tree LANGUAGES CXX) 9 | 10 | option(USE_SYSTEM_ABSEIL "Force use of system abseil-cpp" OFF) 11 | option(USE_SYSTEM_PYBIND11 "Force use of system pybind11" OFF) 12 | 13 | # Required for Python.h and python binding. 
14 | find_package(Python3 COMPONENTS Interpreter Development) 15 | include_directories(SYSTEM ${Python3_INCLUDE_DIRS}) 16 | if(Python3_VERSION VERSION_LESS "3.6.0") 17 | message(FATAL_ERROR 18 | "Python found ${Python3_VERSION} < 3.6.0") 19 | endif() 20 | 21 | # Use C++14 standard. 22 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ version selection") 23 | 24 | # Position-independent code is needed for Python extension modules. 25 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 26 | 27 | # Set default build type. 28 | if(NOT CMAKE_BUILD_TYPE) 29 | set(CMAKE_BUILD_TYPE RELEASE 30 | CACHE STRING "Choose the type of build: Debug Release." 31 | FORCE) 32 | endif() 33 | message("Current build type is: ${CMAKE_BUILD_TYPE}") 34 | message("PROJECT_BINARY_DIR is: ${PROJECT_BINARY_DIR}") 35 | 36 | if (NOT (WIN32 OR MSVC)) 37 | if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") 38 | # Basic build for debugging (default). 39 | # -Og enables optimizations that do not interfere with debugging. 40 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Og") 41 | endif() 42 | 43 | if(${CMAKE_BUILD_TYPE} STREQUAL "Release") 44 | # Optimized release build: turn off debug runtime checks 45 | # and turn on highest speed optimizations. 46 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG -O3") 47 | endif() 48 | endif() 49 | 50 | if(APPLE) 51 | # On MacOS: 52 | # -undefined dynamic_lookup is necessary for pybind11 linking 53 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-everything -w -undefined dynamic_lookup") 54 | 55 | # On MacOS, we need this so that CMake will use the right Python if the user 56 | # has a virtual environment active 57 | set (CMAKE_FIND_FRAMEWORK LAST) 58 | endif() 59 | 60 | # Use `FetchContent` module to manage all external dependencies (i.e. 61 | # abseil-cpp and pybind11). 62 | include(FetchContent) 63 | 64 | # Needed to disable Abseil tests. 65 | set(BUILD_TESTING OFF) 66 | 67 | # Try to find abseil-cpp package system-wide first. 
68 | if (USE_SYSTEM_ABSEIL) 69 | message(STATUS "Use system abseil-cpp: ${USE_SYSTEM_ABSEIL}") 70 | set(ABSEIL_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS) 71 | endif() 72 | 73 | # Include abseil-cpp. 74 | set(ABSEIL_REPO https://github.com/abseil/abseil-cpp) 75 | set(ABSEIL_CMAKE_ARGS 76 | "-DCMAKE_INSTALL_PREFIX=${CMAKE_SOURCE_DIR}/abseil-cpp" 77 | "-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}" 78 | "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" 79 | "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" 80 | "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" 81 | "-DCMAKE_POSITION_INDEPENDENT_CODE=${CMAKE_POSITION_INDEPENDENT_CODE}" 82 | "-DLIBRARY_OUTPUT_PATH=${CMAKE_SOURCE_DIR}/abseil-cpp/lib" 83 | "-DABSL_PROPAGATE_CXX_STD=ON") 84 | if(DEFINED CMAKE_OSX_ARCHITECTURES) 85 | set(ABSEIL_CMAKE_ARGS 86 | ${ABSEIL_CMAKE_ARGS} 87 | "-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}") 88 | endif() 89 | 90 | FetchContent_Declare( 91 | absl 92 | URL ${ABSEIL_REPO}/archive/refs/tags/20220623.2.tar.gz 93 | URL_HASH SHA256=773652c0fc276bcd5c461668dc112d0e3b6cde499600bfe3499c5fdda4ed4a5b 94 | CMAKE_ARGS ${ABSEIL_CMAKE_ARGS} 95 | EXCLUDE_FROM_ALL 96 | ${ABSEIL_FIND_PACKAGE_ARGS}) 97 | 98 | # Try to find pybind11 package system-wide first. 99 | if (USE_SYSTEM_PYBIND11) 100 | message(STATUS "Use system pybind11: ${USE_SYSTEM_PYBIND11}") 101 | set(PYBIND11_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS) 102 | endif() 103 | 104 | FetchContent_Declare( 105 | pybind11 106 | URL https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.tar.gz 107 | URL_HASH SHA256=111014b516b625083bef701df7880f78c2243835abdb263065b6b59b960b6bad 108 | ${PYBIND11_FIND_PACKAGE_ARGS}) 109 | 110 | FetchContent_MakeAvailable(absl pybind11) 111 | 112 | # Define pybind11 tree module. 
113 | pybind11_add_module(_tree tree.h tree.cc) 114 | 115 | target_link_libraries( 116 | _tree 117 | PRIVATE 118 | absl::int128 119 | absl::raw_hash_set 120 | absl::raw_logging_internal 121 | absl::strings 122 | absl::throw_delegate) 123 | 124 | # Make the module private to tree package. 125 | set_target_properties(_tree PROPERTIES OUTPUT_NAME tree/_tree) 126 | 127 | 128 | -------------------------------------------------------------------------------- /tree/tree.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TREE_H_ 17 | #define TREE_H_ 18 | 19 | #include 20 | 21 | #include 22 | 23 | namespace tree { 24 | 25 | // Returns true if its input is a collections.Sequence (except strings). 26 | // 27 | // Args: 28 | // seq: an input sequence. 29 | // 30 | // Returns: 31 | // True if the sequence is not a string and is a collections.Sequence or a 32 | // dict. 33 | bool IsSequence(PyObject* o); 34 | 35 | // Returns Py_True iff `instance` should be considered a `namedtuple`. 36 | // 37 | // Args: 38 | // instance: An instance of a Python object. 39 | // strict: If True, `instance` is considered to be a `namedtuple` only if 40 | // it is a "plain" namedtuple. 
For instance, a class inheriting 41 | // from a `namedtuple` will be considered to be a `namedtuple` 42 | // iff `strict=False`. 43 | // 44 | // Returns: 45 | // True if `instance` is a `namedtuple`. 46 | PyObject* IsNamedtuple(PyObject* o, bool strict); 47 | 48 | // Returns true if its input is an instance of an attr.s decorated class. 49 | // 50 | // Args: 51 | // o: the input to be checked. 52 | // 53 | // Returns: 54 | // True if the object is an instance of an attr.s decorated class. 55 | bool IsAttrs(PyObject* o); 56 | 57 | // Returns Py_True iff the two namedtuples have the same name and fields. 58 | // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have 59 | // '_fields' attribute). 60 | PyObject* SameNamedtuples(PyObject* o1, PyObject* o2); 61 | 62 | // Asserts that two structures are nested in the same way. 63 | // 64 | // Note that namedtuples with identical name and fields are always considered 65 | // to have the same shallow structure (even with `check_types=True`). 66 | // For instance, this code will print `True`: 67 | // 68 | // ```python 69 | // def nt(a, b): 70 | // return collections.namedtuple('foo', 'a b')(a, b) 71 | // print(assert_same_structure(nt(0, 1), nt(2, 3))) 72 | // ``` 73 | // 74 | // Args: 75 | // nest1: an arbitrarily nested structure. 76 | // nest2: an arbitrarily nested structure. 77 | // check_types: if `true`, types of sequences are checked as 78 | // well, including the keys of dictionaries. If set to `false`, for example 79 | // a list and a tuple of objects will look the same if they have the same 80 | // size. Note that namedtuples with identical name and fields are always 81 | // considered to have the same shallow structure. 82 | // 83 | // Raises: 84 | // ValueError: If the two structures do not have the same number of elements or 85 | // if the two structures are not nested in the same way. 
86 | // TypeError: If the two structures differ in the type of sequence in any of 87 | // their substructures. Only possible if `check_types` is `True`. 88 | void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types); 89 | 90 | // 91 | // Returns a flat list from a given nested structure. 92 | // 93 | // If `nest` is not a sequence, tuple, or dict, then returns a single-element 94 | // list: `[nest]`. 95 | // 96 | // In the case of dict instances, the sequence consists of the values, sorted by 97 | // key to ensure deterministic behavior. This is true also for `OrderedDict` 98 | // instances: their sequence order is ignored, the sorting order of keys is 99 | // used instead. The same convention is followed in `pack_sequence_as`. This 100 | // correctly repacks dicts and `OrderedDict`s after they have been flattened, 101 | // and also allows flattening an `OrderedDict` and then repacking it back using 102 | // a corresponding plain dict, or vice-versa. 103 | // Dictionaries with non-sortable keys cannot be flattened. 104 | // 105 | // Args: 106 | // nest: an arbitrarily nested structure or a scalar object. Note, numpy 107 | // arrays are considered scalars. 108 | // 109 | // Returns: 110 | // A Python list, the flattened version of the input. 111 | // On error, returns nullptr 112 | // 113 | // Raises: 114 | // TypeError: The nest is or contains a dict with non-sortable keys. 
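//
// Example (a minimal sketch at the Python level; it assumes this function
// backs the package's `tree.flatten` wrapper, which is not shown in this
// header). Dict values come out in sorted-key order, and a non-sequence
// input yields a single-element list:
//
// ```python
// import collections
// import tree
//
// tree.flatten({'b': 2, 'a': [1, 0]})  # [1, 0, 2]
// tree.flatten(collections.OrderedDict([('b', 2), ('a', 1)]))  # [1, 2]
// tree.flatten(7)  # [7]
// ```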
115 | PyObject* Flatten(PyObject* nested); 116 | 117 | struct DecrementsPyRefcount { 118 | void operator()(PyObject* p) const { Py_DECREF(p); } 119 | }; 120 | 121 | // ValueIterator interface 122 | class ValueIterator { 123 | public: 124 | virtual ~ValueIterator() {} 125 | virtual std::unique_ptr next() = 0; 126 | 127 | bool valid() const { return is_valid_; } 128 | 129 | protected: 130 | void invalidate() { is_valid_ = false; } 131 | 132 | private: 133 | bool is_valid_ = true; 134 | }; 135 | 136 | std::unique_ptr GetValueIterator(PyObject* nested); 137 | } // namespace tree 138 | 139 | #endif // TREE_H_ 140 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Setup for pip package.""" 16 | 17 | import os 18 | import platform 19 | import shutil 20 | import subprocess 21 | import sys 22 | import sysconfig 23 | 24 | import setuptools 25 | from setuptools.command import build_ext 26 | 27 | here = os.path.dirname(os.path.abspath(__file__)) 28 | 29 | 30 | def _get_tree_version(): 31 | """Parse the version string from tree/__init__.py.""" 32 | with open(os.path.join(here, 'tree', '__init__.py')) as f: 33 | try: 34 | version_line = next(line for line in f if line.startswith('__version__')) 35 | except StopIteration: 36 | raise ValueError('__version__ not defined in tree/__init__.py') 37 | else: 38 | ns = {} 39 | exec(version_line, ns) # pylint: disable=exec-used 40 | return ns['__version__'] 41 | 42 | 43 | class CMakeExtension(setuptools.Extension): 44 | """An extension with no sources. 45 | 46 | We do not want distutils to handle any of the compilation (instead we rely 47 | on CMake), so we always pass an empty list to the constructor. 48 | """ 49 | 50 | def __init__(self, name, source_dir=''): 51 | super().__init__(name, sources=[]) 52 | self.source_dir = os.path.abspath(source_dir) 53 | 54 | 55 | class BuildCMakeExtension(build_ext.build_ext): 56 | """Our custom build_ext command. 57 | 58 | Uses CMake to build extensions instead of a bare compiler (e.g. gcc, clang). 
59 | """ 60 | 61 | def run(self): 62 | self._check_build_environment() 63 | for ext in self.extensions: 64 | self.build_extension(ext) 65 | 66 | def _check_build_environment(self): 67 | """Check for required build tools: CMake, C++ compiler, and python dev.""" 68 | try: 69 | subprocess.check_call(['cmake', '--version']) 70 | except OSError as e: 71 | ext_names = ', '.join(e.name for e in self.extensions) 72 | raise RuntimeError( 73 | f'CMake must be installed to build the following extensions: {ext_names}' 74 | ) from e 75 | print('Found CMake') 76 | 77 | def build_extension(self, ext): 78 | extension_dir = os.path.abspath( 79 | os.path.dirname(self.get_ext_fullpath(ext.name))) 80 | build_cfg = 'Debug' if self.debug else 'Release' 81 | cmake_args = [ 82 | f'-DPython3_ROOT_DIR={sys.prefix}', 83 | f'-DPython3_EXECUTABLE={sys.executable}', 84 | f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extension_dir}', 85 | f'-DCMAKE_BUILD_TYPE={build_cfg}' 86 | ] 87 | if platform.system() != 'Windows': 88 | cmake_args.extend([ 89 | f'-DPython3_LIBRARY={sysconfig.get_paths()["stdlib"]}', 90 | f'-DPython3_INCLUDE_DIR={sysconfig.get_paths()["include"]}', 91 | ]) 92 | if platform.system() == 'Darwin' and os.environ.get('ARCHFLAGS'): 93 | osx_archs = [] 94 | if '-arch x86_64' in os.environ['ARCHFLAGS']: 95 | osx_archs.append('x86_64') 96 | if '-arch arm64' in os.environ['ARCHFLAGS']: 97 | osx_archs.append('arm64') 98 | cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(osx_archs)}') 99 | os.makedirs(self.build_temp, exist_ok=True) 100 | subprocess.check_call( 101 | ['cmake', '-S', ext.source_dir, '-B', self.build_temp] + cmake_args) 102 | num_jobs = () 103 | if self.parallel: 104 | num_jobs = (f'-j{self.parallel}',) 105 | subprocess.check_call([ 106 | 'cmake', '--build', self.build_temp, *num_jobs, '--config', build_cfg 107 | ]) 108 | 109 | # Force output to /. Amends CMake multigenerator output paths 110 | # on Windows and avoids Debug/ and Release/ subdirs, which is CMake default. 
111 | tree_dir = os.path.join(extension_dir, 'tree') # pylint:disable=unreachable 112 | for cfg in ('Release', 'Debug'): 113 | cfg_dir = os.path.join(extension_dir, cfg) 114 | if os.path.isdir(cfg_dir): 115 | for f in os.listdir(cfg_dir): 116 | shutil.move(os.path.join(cfg_dir, f), tree_dir) 117 | 118 | 119 | setuptools.setup( 120 | name='dm-tree', 121 | version=_get_tree_version(), 122 | url='https://github.com/deepmind/tree', 123 | description='Tree is a library for working with nested data structures.', 124 | author='DeepMind', 125 | author_email='tree-copybara@google.com', 126 | long_description=open(os.path.join(here, 'README.md')).read(), 127 | long_description_content_type='text/markdown', 128 | # Contained modules and scripts. 129 | packages=setuptools.find_packages(), 130 | python_requires='>=3.10', 131 | install_requires=[ 132 | 'absl-py>=0.6.1', 133 | 'attrs>=18.2.0', 134 | 'numpy>=1.21', 135 | "numpy>=1.21.2; python_version>='3.10'", 136 | "numpy>=1.23.3; python_version>='3.11'", 137 | "numpy>=1.26.0; python_version>='3.12'", 138 | "numpy>=2.1.0; python_version>='3.13'", 139 | 'wrapt>=1.11.2', 140 | ], 141 | test_suite='tree', 142 | cmdclass=dict(build_ext=BuildCMakeExtension), 143 | ext_modules=[CMakeExtension('_tree', source_dir='tree')], 144 | zip_safe=False, 145 | # PyPI package information. 
146 | classifiers=[ 147 | 'Development Status :: 4 - Beta', 148 | 'Intended Audience :: Developers', 149 | 'Intended Audience :: Science/Research', 150 | 'License :: OSI Approved :: Apache Software License', 151 | 'Programming Language :: Python :: 3.10', 152 | 'Programming Language :: Python :: 3.11', 153 | 'Programming Language :: Python :: 3.12', 154 | 'Programming Language :: Python :: 3.13', 155 | 'Topic :: Scientific/Engineering :: Mathematics', 156 | 'Topic :: Software Development :: Libraries', 157 | ], 158 | license='Apache 2.0', 159 | keywords='tree nest flatten', 160 | ) 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /tree/tree.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | #include "tree.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | // logging 23 | #include "absl/memory/memory.h" 24 | #include "absl/strings/str_cat.h" 25 | #include "absl/strings/string_view.h" 26 | #include 27 | 28 | #ifdef LOG 29 | #define LOG_WARNING(w) LOG(WARNING) << w; 30 | #else 31 | #include 32 | #define LOG_WARNING(w) std::cerr << w << "\n"; 33 | #endif 34 | 35 | #ifndef DCHECK 36 | #define DCHECK(stmt) 37 | #endif 38 | 39 | namespace py = pybind11; 40 | 41 | namespace tree { 42 | namespace { 43 | 44 | // PyObjectPtr wraps an underlying Python object and decrements the 45 | // reference count in the destructor. 46 | // 47 | // This class does not acquire the GIL in the destructor, so the GIL must be 48 | // held when the destructor is called. 49 | using PyObjectPtr = std::unique_ptr; 50 | 51 | const int kMaxItemsInCache = 1024; 52 | 53 | bool WarnedThatSetIsNotSequence = false; 54 | 55 | bool IsString(PyObject* o) { 56 | return PyBytes_Check(o) || PyByteArray_Check(o) || PyUnicode_Check(o); 57 | } 58 | 59 | // Equivalent to Python's 'o.__class__.__name__' 60 | // Note that '__class__' attribute is set only in new-style classes. 61 | // A lot of tensorflow code uses __class__ without checks, so it seems like 62 | // we only support new-style classes. 63 | absl::string_view GetClassName(PyObject* o) { 64 | // __class__ is equivalent to type() for new style classes. 65 | // type() is equivalent to PyObject_Type() 66 | // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type) 67 | // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which 68 | // we don't need here. 69 | PyTypeObject* type = o->ob_type; 70 | 71 | // __name__ is the value of `tp_name` after the last '.' 
72 | // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name) 73 | absl::string_view name(type->tp_name); 74 | size_t pos = name.rfind('.'); 75 | if (pos != absl::string_view::npos) { 76 | name.remove_prefix(pos + 1); 77 | } 78 | return name; 79 | } 80 | 81 | std::string PyObjectToString(PyObject* o) { 82 | if (o == nullptr) { 83 | return ""; 84 | } 85 | PyObject* str = PyObject_Str(o); 86 | if (str) { 87 | std::string s(PyUnicode_AsUTF8(str)); 88 | Py_DECREF(str); 89 | return absl::StrCat("type=", GetClassName(o), " str=", s); 90 | } else { 91 | return ""; 92 | } 93 | } 94 | 95 | class CachedTypeCheck { 96 | public: 97 | explicit CachedTypeCheck(std::function ternary_predicate) 98 | : ternary_predicate_(std::move(ternary_predicate)) {} 99 | 100 | ~CachedTypeCheck() { 101 | for (const auto& pair : type_to_sequence_map_) { 102 | Py_DECREF(pair.first); 103 | } 104 | } 105 | 106 | // Caches successful executions of the one-argument (PyObject*) callable 107 | // "ternary_predicate" based on the type of "o". -1 from the callable 108 | // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type 109 | // does not match the predicate, and 1 indicates that it does. Used to avoid 110 | // calling back into Python for expensive isinstance checks. 111 | int CachedLookup(PyObject* o) { 112 | // Try not to return to Python - see if the type has already been seen 113 | // before. 114 | auto* type = Py_TYPE(o); 115 | 116 | { 117 | auto it = type_to_sequence_map_.find(type); 118 | if (it != type_to_sequence_map_.end()) { 119 | return it->second; 120 | } 121 | } 122 | 123 | int check_result = ternary_predicate_(o); 124 | 125 | if (check_result == -1) { 126 | return -1; // Type check error, not cached. 127 | } 128 | 129 | // NOTE: This is never decref'd as long as the object lives, which is likely 130 | // forever, but we don't want the type to get deleted as long as it is in 131 | // the map. 
This should not be too much of a leak, as there should only be a 132 | // relatively small number of types in the map, and an even smaller number 133 | // that are eligible for decref. As a precaution, we limit the size of the 134 | // map to 1024. 135 | { 136 | if (type_to_sequence_map_.size() < kMaxItemsInCache) { 137 | Py_INCREF(type); 138 | type_to_sequence_map_.insert({type, check_result}); 139 | } 140 | } 141 | 142 | return check_result; 143 | } 144 | 145 | private: 146 | std::function ternary_predicate_; 147 | std::unordered_map type_to_sequence_map_; 148 | }; 149 | 150 | py::object GetCollectionsSequenceType() { 151 | static py::object type = 152 | py::module::import("collections.abc").attr("Sequence"); 153 | return type; 154 | } 155 | 156 | py::object GetCollectionsMappingType() { 157 | static py::object type = 158 | py::module::import("collections.abc").attr("Mapping"); 159 | return type; 160 | } 161 | 162 | py::object GetCollectionsMappingViewType() { 163 | static py::object type = 164 | py::module::import("collections.abc").attr("MappingView"); 165 | return type; 166 | } 167 | 168 | py::object GetWraptObjectProxyTypeUncached() { 169 | try { 170 | return py::module::import("wrapt").attr("ObjectProxy"); 171 | } catch (const py::error_already_set& e) { 172 | if (e.matches(PyExc_ImportError)) return py::none(); 173 | throw e; 174 | } 175 | } 176 | 177 | py::object GetWraptObjectProxyType() { 178 | // TODO(gregthornton): Restore caching when deadlock issue is fixed. 179 | return GetWraptObjectProxyTypeUncached(); 180 | } 181 | 182 | // Returns 1 if `o` is considered a mapping for the purposes of Flatten(). 183 | // Returns 0 otherwise. 184 | // Returns -1 if an error occurred. 
185 | int IsMappingHelper(PyObject* o) { 186 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { 187 | return PyObject_IsInstance(to_check, GetCollectionsMappingType().ptr()); 188 | }); 189 | if (PyDict_Check(o)) return true; 190 | return check_cache->CachedLookup(o); 191 | } 192 | 193 | // Returns 1 if `o` is considered a mapping view for the purposes of Flatten(). 194 | // Returns 0 otherwise. 195 | // Returns -1 if an error occurred. 196 | int IsMappingViewHelper(PyObject* o) { 197 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { 198 | return PyObject_IsInstance(to_check, GetCollectionsMappingViewType().ptr()); 199 | }); 200 | return check_cache->CachedLookup(o); 201 | } 202 | 203 | // Returns 1 if `o` is considered an object proxy 204 | // Returns 0 otherwise. 205 | // Returns -1 if an error occurred. 206 | int IsObjectProxy(PyObject* o) { 207 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { 208 | auto type = GetWraptObjectProxyType(); 209 | return !type.is_none() && PyObject_IsInstance(to_check, type.ptr()) == 1; 210 | }); 211 | return check_cache->CachedLookup(o); 212 | } 213 | 214 | // Returns 1 if `o` is an instance of attrs-decorated class. 215 | // Returns 0 otherwise. 216 | int IsAttrsHelper(PyObject* o) { 217 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { 218 | PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__")); 219 | if (cls) { 220 | return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); 221 | } 222 | 223 | // PyObject_GetAttrString returns null on error 224 | PyErr_Clear(); 225 | return 0; 226 | }); 227 | return check_cache->CachedLookup(o); 228 | } 229 | 230 | // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). 231 | // Returns 0 otherwise. 232 | // Returns -1 if an error occurred. 
233 | int IsSequenceHelper(PyObject* o) { 234 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { 235 | int is_instance = 236 | PyObject_IsInstance(to_check, GetCollectionsSequenceType().ptr()); 237 | 238 | // Don't cache a failed is_instance check. 239 | if (is_instance == -1) return -1; 240 | 241 | return static_cast(is_instance != 0 && !IsString(to_check)); 242 | }); // We treat dicts and other mappings as special cases of sequences. 243 | if (IsMappingHelper(o)) return true; 244 | if (IsMappingViewHelper(o)) return true; 245 | if (IsAttrsHelper(o)) return true; 246 | if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { 247 | LOG_WARNING( 248 | "Sets are not currently considered sequences, " 249 | "but this may change in the future, " 250 | "so consider avoiding using them."); 251 | WarnedThatSetIsNotSequence = true; 252 | } 253 | return check_cache->CachedLookup(o); 254 | } 255 | 256 | using ValueIteratorPtr = std::unique_ptr; 257 | 258 | // Iterate through dictionaries in a deterministic order by sorting the 259 | // keys. Notice this means that we ignore the original order of 260 | // `OrderedDict` instances. This is intentional, to avoid potential 261 | // bugs caused by mixing ordered and plain dicts (e.g., flattening 262 | // a dict but using a corresponding `OrderedDict` to pack it back). 263 | class DictValueIterator : public ValueIterator { 264 | public: 265 | explicit DictValueIterator(PyObject* dict) 266 | : dict_(dict), keys_(PyDict_Keys(dict)) { 267 | if (PyList_Sort(keys_.get()) == -1) { 268 | invalidate(); 269 | } else { 270 | iter_.reset(PyObject_GetIter(keys_.get())); 271 | } 272 | } 273 | 274 | PyObjectPtr next() override { 275 | PyObjectPtr result; 276 | PyObjectPtr key(PyIter_Next(iter_.get())); 277 | if (key) { 278 | // PyDict_GetItem returns a borrowed reference. 
279 | PyObject* elem = PyDict_GetItem(dict_, key.get()); 280 | if (elem) { 281 | Py_INCREF(elem); 282 | result.reset(elem); 283 | } else { 284 | PyErr_SetString(PyExc_RuntimeError, 285 | "Dictionary was modified during iteration over it"); 286 | } 287 | } 288 | return result; 289 | } 290 | 291 | private: 292 | PyObject* dict_; 293 | PyObjectPtr keys_; 294 | PyObjectPtr iter_; 295 | }; 296 | 297 | // Iterate over mapping objects by sorting the keys first 298 | class MappingValueIterator : public ValueIterator { 299 | public: 300 | explicit MappingValueIterator(PyObject* mapping) 301 | : mapping_(mapping), keys_(PyMapping_Keys(mapping)) { 302 | if (!keys_ || PyList_Sort(keys_.get()) == -1) { 303 | invalidate(); 304 | } else { 305 | iter_.reset(PyObject_GetIter(keys_.get())); 306 | } 307 | } 308 | 309 | PyObjectPtr next() override { 310 | PyObjectPtr result; 311 | PyObjectPtr key(PyIter_Next(iter_.get())); 312 | if (key) { 313 | // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. 314 | PyObject* elem = PyObject_GetItem(mapping_, key.get()); 315 | if (elem) { 316 | result.reset(elem); 317 | } else { 318 | PyErr_SetString(PyExc_RuntimeError, 319 | "Mapping was modified during iteration over it"); 320 | } 321 | } 322 | return result; 323 | } 324 | 325 | private: 326 | PyObject* mapping_; 327 | PyObjectPtr keys_; 328 | PyObjectPtr iter_; 329 | }; 330 | 331 | // Iterate over a sequence, by index. 332 | class SequenceValueIterator : public ValueIterator { 333 | public: 334 | explicit SequenceValueIterator(PyObject* iterable) 335 | : seq_(PySequence_Fast(iterable, "")), 336 | size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0), 337 | index_(0) {} 338 | 339 | PyObjectPtr next() override { 340 | PyObjectPtr result; 341 | if (index_ < size_) { 342 | // PySequence_Fast_GET_ITEM returns a borrowed reference. 
343 | PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_); 344 | ++index_; 345 | if (elem) { 346 | Py_INCREF(elem); 347 | result.reset(elem); 348 | } 349 | } 350 | 351 | return result; 352 | } 353 | 354 | private: 355 | PyObjectPtr seq_; 356 | const Py_ssize_t size_; 357 | Py_ssize_t index_; 358 | }; 359 | 360 | class AttrsValueIterator : public ValueIterator { 361 | public: 362 | explicit AttrsValueIterator(PyObject* nested) : nested_(nested) { 363 | Py_INCREF(nested); 364 | cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__")); 365 | if (cls_) { 366 | attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__")); 367 | if (attrs_) { 368 | iter_.reset(PyObject_GetIter(attrs_.get())); 369 | } 370 | } 371 | if (!iter_ || PyErr_Occurred()) invalidate(); 372 | } 373 | 374 | PyObjectPtr next() override { 375 | PyObjectPtr result; 376 | PyObjectPtr item(PyIter_Next(iter_.get())); 377 | if (item) { 378 | PyObjectPtr name(PyObject_GetAttrString(item.get(), "name")); 379 | result.reset(PyObject_GetAttr(nested_.get(), name.get())); 380 | } 381 | 382 | return result; 383 | } 384 | 385 | private: 386 | PyObjectPtr nested_; 387 | PyObjectPtr cls_; 388 | PyObjectPtr attrs_; 389 | PyObjectPtr iter_; 390 | }; 391 | 392 | 393 | bool FlattenHelper( 394 | PyObject* nested, PyObject* list, 395 | const std::function& is_sequence_helper, 396 | const std::function& value_iterator_getter) { 397 | // if nested is not a sequence, append itself and exit 398 | int is_seq = is_sequence_helper(nested); 399 | if (is_seq == -1) return false; 400 | if (!is_seq) { 401 | return PyList_Append(list, nested) != -1; 402 | } 403 | 404 | ValueIteratorPtr iter = value_iterator_getter(nested); 405 | if (!iter->valid()) return false; 406 | 407 | for (PyObjectPtr item = iter->next(); item; item = iter->next()) { 408 | if (Py_EnterRecursiveCall(" in flatten")) { 409 | return false; 410 | } 411 | const bool success = FlattenHelper(item.get(), list, is_sequence_helper, 412 | 
value_iterator_getter); 413 | Py_LeaveRecursiveCall(); 414 | if (!success) { 415 | return false; 416 | } 417 | } 418 | return true; 419 | } 420 | 421 | // Sets error using keys of 'dict1' and 'dict2'. 422 | // 'dict1' and 'dict2' are assumed to be Python dictionaries. 423 | void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, 424 | std::string* error_msg, bool* is_type_error) { 425 | PyObjectPtr k1(PyMapping_Keys(dict1)); 426 | if (PyErr_Occurred() || k1.get() == nullptr) { 427 | *error_msg = 428 | ("The two dictionaries don't have the same set of keys. Failed to " 429 | "fetch keys."); 430 | return; 431 | } 432 | PyObjectPtr k2(PyMapping_Keys(dict2)); 433 | if (PyErr_Occurred() || k2.get() == nullptr) { 434 | *error_msg = 435 | ("The two dictionaries don't have the same set of keys. Failed to " 436 | "fetch keys."); 437 | return; 438 | } 439 | *is_type_error = false; 440 | *error_msg = absl::StrCat( 441 | "The two dictionaries don't have the same set of keys. " 442 | "First structure has keys ", 443 | PyObjectToString(k1.get()), ", while second structure has keys ", 444 | PyObjectToString(k2.get())); 445 | } 446 | 447 | // Returns true iff there were no "internal" errors. In other words, 448 | // errors that have nothing to do with structure checking. 449 | // If an "internal" error occurred, the appropriate Python error will be 450 | // set and the caller can propagate it directly to the user. 451 | // 452 | // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must 453 | // be empty. 454 | // Leaves `error_msg` empty if structures matched. Else, fills `error_msg` 455 | // with appropriate error and sets `is_type_error` to true iff 456 | // the error to be raised should be TypeError.
457 | bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, 458 | std::string* error_msg, bool* is_type_error) { 459 | DCHECK(error_msg); 460 | DCHECK(is_type_error); 461 | const bool is_seq1 = IsSequence(o1); 462 | const bool is_seq2 = IsSequence(o2); 463 | if (PyErr_Occurred()) return false; 464 | if (is_seq1 != is_seq2) { 465 | std::string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2); 466 | std::string non_seq_str = 467 | is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1); 468 | *is_type_error = false; 469 | *error_msg = absl::StrCat("Substructure \"", seq_str, 470 | "\" is a sequence, while substructure \"", 471 | non_seq_str, "\" is not"); 472 | return true; 473 | } 474 | 475 | // Got to scalars, so finished checking. Structures are the same. 476 | if (!is_seq1) return true; 477 | 478 | if (check_types) { 479 | // Unwrap wrapt.ObjectProxy if needed. 480 | PyObjectPtr o1_wrapped; 481 | if (IsObjectProxy(o1)) { 482 | o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__")); 483 | o1 = o1_wrapped.get(); 484 | } 485 | PyObjectPtr o2_wrapped; 486 | if (IsObjectProxy(o2)) { 487 | o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__")); 488 | o2 = o2_wrapped.get(); 489 | } 490 | 491 | const PyTypeObject* type1 = o1->ob_type; 492 | const PyTypeObject* type2 = o2->ob_type; 493 | 494 | // We treat two different namedtuples with identical name and fields 495 | // as having the same type. 
496 | const PyObject* o1_tuple = IsNamedtuple(o1, true); 497 | if (o1_tuple == nullptr) return false; 498 | const PyObject* o2_tuple = IsNamedtuple(o2, true); 499 | if (o2_tuple == nullptr) { 500 | Py_DECREF(o1_tuple); 501 | return false; 502 | } 503 | bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True; 504 | Py_DECREF(o1_tuple); 505 | Py_DECREF(o2_tuple); 506 | 507 | if (both_tuples) { 508 | const PyObject* same_tuples = SameNamedtuples(o1, o2); 509 | if (same_tuples == nullptr) return false; 510 | bool not_same_tuples = same_tuples != Py_True; 511 | Py_DECREF(same_tuples); 512 | if (not_same_tuples) { 513 | *is_type_error = true; 514 | *error_msg = absl::StrCat( 515 | "The two namedtuples don't have the same sequence type. " 516 | "First structure ", 517 | PyObjectToString(o1), " has type ", type1->tp_name, 518 | ", while second structure ", PyObjectToString(o2), " has type ", 519 | type2->tp_name); 520 | return true; 521 | } 522 | } else if (type1 != type2 523 | /* If both sequences are list types, don't complain. This allows 524 | one to be a list subclass (e.g. _ListWrapper used for 525 | automatic dependency tracking.) */ 526 | && !(PyList_Check(o1) && PyList_Check(o2)) 527 | /* Two mapping types will also compare equal, making _DictWrapper 528 | and dict compare equal. */ 529 | && !(IsMappingHelper(o1) && IsMappingHelper(o2))) { 530 | *is_type_error = true; 531 | *error_msg = absl::StrCat( 532 | "The two namedtuples don't have the same sequence type. 
" 533 | "First structure ", 534 | PyObjectToString(o1), " has type ", type1->tp_name, 535 | ", while second structure ", PyObjectToString(o2), " has type ", 536 | type2->tp_name); 537 | return true; 538 | } 539 | 540 | if (PyDict_Check(o1) && PyDict_Check(o2)) { 541 | if (PyDict_Size(o1) != PyDict_Size(o2)) { 542 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); 543 | return true; 544 | } 545 | 546 | PyObject* key; 547 | Py_ssize_t pos = 0; 548 | while (PyDict_Next(o1, &pos, &key, nullptr)) { 549 | if (PyDict_GetItem(o2, key) == nullptr) { 550 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); 551 | return true; 552 | } 553 | } 554 | } else if (IsMappingHelper(o1)) { 555 | // Fallback for custom mapping types. Instead of using PyDict methods 556 | // which stay in C, we call iter(o1). 557 | if (PyMapping_Size(o1) != PyMapping_Size(o2)) { 558 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); 559 | return true; 560 | } 561 | 562 | PyObjectPtr iter(PyObject_GetIter(o1)); 563 | PyObject* key; 564 | while ((key = PyIter_Next(iter.get())) != nullptr) { 565 | if (!PyMapping_HasKey(o2, key)) { 566 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); 567 | Py_DECREF(key); 568 | return true; 569 | } 570 | Py_DECREF(key); 571 | } 572 | } 573 | } 574 | 575 | ValueIteratorPtr iter1 = GetValueIterator(o1); 576 | ValueIteratorPtr iter2 = GetValueIterator(o2); 577 | 578 | if (!iter1->valid() || !iter2->valid()) return false; 579 | 580 | while (true) { 581 | PyObjectPtr v1 = iter1->next(); 582 | PyObjectPtr v2 = iter2->next(); 583 | if (v1 && v2) { 584 | if (Py_EnterRecursiveCall(" in assert_same_structure")) { 585 | return false; 586 | } 587 | bool no_internal_errors = AssertSameStructureHelper( 588 | v1.get(), v2.get(), check_types, error_msg, is_type_error); 589 | Py_LeaveRecursiveCall(); 590 | if (!no_internal_errors) return false; 591 | if (!error_msg->empty()) return true; 592 | } else if (!v1 && !v2) { 593 | // Done with all recursive calls. 
Structure matched. 594 | return true; 595 | } else { 596 | *is_type_error = false; 597 | *error_msg = absl::StrCat( 598 | "The two structures don't have the same number of elements. ", 599 | "First structure: ", PyObjectToString(o1), 600 | ". Second structure: ", PyObjectToString(o2)); 601 | return true; 602 | } 603 | } 604 | } 605 | 606 | } // namespace 607 | 608 | bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } 609 | bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } 610 | 611 | PyObject* Flatten(PyObject* nested) { 612 | PyObject* list = PyList_New(0); 613 | if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) { 614 | return list; 615 | } else { 616 | Py_DECREF(list); 617 | return nullptr; 618 | } 619 | } 620 | 621 | PyObject* IsNamedtuple(PyObject* o, bool strict) { 622 | // Unwrap wrapt.ObjectProxy if needed. 623 | PyObjectPtr o_wrapped; 624 | if (IsObjectProxy(o)) { 625 | o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__")); 626 | o = o_wrapped.get(); 627 | } 628 | 629 | // Must be subclass of tuple 630 | if (!PyTuple_Check(o)) { 631 | Py_RETURN_FALSE; 632 | } 633 | 634 | // If strict, o.__class__.__base__ must be tuple 635 | if (strict) { 636 | PyObject* klass = PyObject_GetAttrString(o, "__class__"); 637 | if (klass == nullptr) return nullptr; 638 | PyObject* base = PyObject_GetAttrString(klass, "__base__"); 639 | Py_DECREF(klass); 640 | if (base == nullptr) return nullptr; 641 | 642 | const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base); 643 | // built-in object types are singletons 644 | bool tuple_base = base_type == &PyTuple_Type; 645 | Py_DECREF(base); 646 | if (!tuple_base) { 647 | Py_RETURN_FALSE; 648 | } 649 | } 650 | 651 | // o must have attribute '_fields' and every element in 652 | // '_fields' must be a string.
653 | int has_fields = PyObject_HasAttrString(o, "_fields"); 654 | if (!has_fields) { 655 | Py_RETURN_FALSE; 656 | } 657 | 658 | PyObjectPtr fields(PyObject_GetAttrString(o, "_fields")); 659 | int is_instance = 660 | PyObject_IsInstance(fields.get(), GetCollectionsSequenceType().ptr()); 661 | if (is_instance == 0) { 662 | Py_RETURN_FALSE; 663 | } else if (is_instance == -1) { 664 | return nullptr; 665 | } 666 | 667 | PyObjectPtr seq(PySequence_Fast(fields.get(), "")); 668 | const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get()); 669 | for (Py_ssize_t i = 0; i < s; ++i) { 670 | // PySequence_Fast_GET_ITEM returns borrowed ref 671 | PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i); 672 | if (!IsString(elem)) { 673 | Py_RETURN_FALSE; 674 | } 675 | } 676 | 677 | Py_RETURN_TRUE; 678 | } 679 | 680 | PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) { 681 | PyObject* f1 = PyObject_GetAttrString(o1, "_fields"); 682 | PyObject* f2 = PyObject_GetAttrString(o2, "_fields"); 683 | if (f1 == nullptr || f2 == nullptr) { 684 | Py_XDECREF(f1); 685 | Py_XDECREF(f2); 686 | PyErr_SetString( 687 | PyExc_RuntimeError, 688 | "Expected namedtuple-like objects (that have _fields attr)"); 689 | return nullptr; 690 | } 691 | 692 | if (PyObject_RichCompareBool(f1, f2, Py_NE)) { 693 | Py_RETURN_FALSE; 694 | } 695 | 696 | if (GetClassName(o1).compare(GetClassName(o2)) == 0) { 697 | Py_RETURN_TRUE; 698 | } else { 699 | Py_RETURN_FALSE; 700 | } 701 | } 702 | 703 | void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) { 704 | std::string error_msg; 705 | bool is_type_error = false; 706 | AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error); 707 | if (PyErr_Occurred()) { 708 | // Don't hide Python exceptions while checking (e.g. errors fetching keys 709 | // from custom mappings). 710 | return; 711 | } 712 | if (!error_msg.empty()) { 713 | PyErr_SetString( 714 | is_type_error ? 
PyExc_TypeError : PyExc_ValueError, 715 | absl::StrCat( 716 | "The two structures don't have the same nested structure.\n\n", 717 | "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ", 718 | PyObjectToString(o2), "\n\nMore specifically: ", error_msg) 719 | .c_str()); 720 | } 721 | } 722 | 723 | ValueIteratorPtr GetValueIterator(PyObject* nested) { 724 | if (PyDict_Check(nested)) { 725 | return absl::make_unique<DictValueIterator>(nested); 726 | } else if (IsMappingHelper(nested)) { 727 | return absl::make_unique<MappingValueIterator>(nested); 728 | } else if (IsAttrsHelper(nested)) { 729 | return absl::make_unique<AttrsValueIterator>(nested); 730 | } else { 731 | return absl::make_unique<SequenceValueIterator>(nested); 732 | } 733 | } 734 | 735 | namespace { 736 | 737 | inline py::object pyo_or_throw(PyObject* ptr) { 738 | if (PyErr_Occurred() || ptr == nullptr) { 739 | throw py::error_already_set(); 740 | } 741 | return py::reinterpret_steal<py::object>(ptr); 742 | } 743 | 744 | PYBIND11_MODULE(_tree, m) { 745 | // Resolve `wrapt.ObjectProxy` at import time to avoid doing 746 | // imports during function calls.
747 | tree::GetWraptObjectProxyType(); 748 | 749 | m.def("assert_same_structure", 750 | [](py::handle& o1, py::handle& o2, bool check_types) { 751 | tree::AssertSameStructure(o1.ptr(), o2.ptr(), check_types); 752 | if (PyErr_Occurred()) { 753 | throw py::error_already_set(); 754 | } 755 | }); 756 | m.def("is_sequence", [](py::handle& o) { 757 | bool result = tree::IsSequence(o.ptr()); 758 | if (PyErr_Occurred()) { 759 | throw py::error_already_set(); 760 | } 761 | return result; 762 | }); 763 | m.def("is_namedtuple", [](py::handle& o, bool strict) { 764 | return pyo_or_throw(tree::IsNamedtuple(o.ptr(), strict)); 765 | }); 766 | m.def("is_attrs", [](py::handle& o) { 767 | bool result = tree::IsAttrs(o.ptr()); 768 | if (PyErr_Occurred()) { 769 | throw py::error_already_set(); 770 | } 771 | return result; 772 | }); 773 | m.def("same_namedtuples", [](py::handle& o1, py::handle& o2) { 774 | return pyo_or_throw(tree::SameNamedtuples(o1.ptr(), o2.ptr())); 775 | }); 776 | m.def("flatten", [](py::handle& nested) { 777 | return pyo_or_throw(tree::Flatten(nested.ptr())); 778 | }); 779 | } 780 | 781 | } // namespace 782 | } // namespace tree 783 | -------------------------------------------------------------------------------- /tree/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for working with nested data structures.""" 17 | 18 | from collections import abc as collections_abc 19 | import logging 20 | import sys 21 | from typing import Mapping, Sequence, TypeVar, Union 22 | 23 | from .sequence import _is_attrs 24 | from .sequence import _is_namedtuple 25 | from .sequence import _sequence_like 26 | from .sequence import _sorted 27 | 28 | # pylint: disable=g-import-not-at-top 29 | try: 30 | import wrapt 31 | ObjectProxy = wrapt.ObjectProxy 32 | except ImportError: 33 | class ObjectProxy(object): 34 | """Stub-class for ``wrapt.ObjectProxy``.""" 35 | 36 | try: 37 | from tree import _tree 38 | except ImportError: 39 | if "sphinx" not in sys.modules: 40 | raise 41 | 42 | _tree = None 43 | 44 | # pylint: enable=g-import-not-at-top 45 | 46 | __all__ = [ 47 | "is_nested", 48 | "assert_same_structure", 49 | "unflatten_as", 50 | "flatten", 51 | "flatten_up_to", 52 | "flatten_with_path", 53 | "flatten_with_path_up_to", 54 | "map_structure", 55 | "map_structure_up_to", 56 | "map_structure_with_path", 57 | "map_structure_with_path_up_to", 58 | "traverse", 59 | "MAP_TO_NONE", 60 | ] 61 | 62 | __version__ = "0.1.9" 63 | 64 | # Note: this is *not* the same as `six.string_types`, which in Python3 is just 65 | # `(str,)` (i.e. it does not include byte strings). 66 | _TEXT_OR_BYTES = (str, bytes) 67 | 68 | _SHALLOW_TREE_HAS_INVALID_KEYS = ( 69 | "The shallow_tree's keys are not a subset of the input_tree's keys. The " 70 | "shallow_tree has the following keys that are not in the input_tree: {}.") 71 | 72 | _STRUCTURES_HAVE_MISMATCHING_TYPES = ( 73 | "The two structures don't have the same sequence type.
Input structure has " 74 | "type {input_type}, while shallow structure has type {shallow_type}.") 75 | 76 | _STRUCTURES_HAVE_MISMATCHING_LENGTHS = ( 77 | "The two structures don't have the same sequence length. Input " 78 | "structure has length {input_length}, while shallow structure has length " 79 | "{shallow_length}." 80 | ) 81 | 82 | _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = ( 83 | "The input_tree has fewer elements than the shallow_tree. Input structure " 84 | "has length {input_size}, while shallow structure has length " 85 | "{shallow_size}.") 86 | 87 | _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = ( 88 | "If shallow structure is a sequence, input must also be a sequence. " 89 | "Input has type: {}.") 90 | 91 | _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH = ( 92 | "If shallow structure is a sequence, input must also be a sequence. " 93 | "Input at path: {path} has type: {input_type}.") 94 | 95 | K = TypeVar("K") 96 | V = TypeVar("V") 97 | 98 | # A generic monomorphic structure type, e.g. ``StructureKV[str, int]`` 99 | # is an arbitrarily nested structure where keys must be of type ``str`` 100 | # and values are integers. 101 | StructureKV = Union[ 102 | Sequence["StructureKV[K, V]"], 103 | Mapping[K, "StructureKV[K, V]"], 104 | V, 105 | ] 106 | Structure = StructureKV[str, V] 107 | 108 | 109 | def _get_attrs_items(obj): 110 | """Returns a list of (name, value) pairs from an attrs instance. 111 | 112 | The list will be sorted by name. 113 | 114 | Args: 115 | obj: an object. 116 | 117 | Returns: 118 | A list of (attr_name, attr_value) pairs. 119 | """ 120 | return [(attr.name, getattr(obj, attr.name)) 121 | for attr in obj.__class__.__attrs_attrs__] 122 | 123 | 124 | def _yield_value(iterable): 125 | for _, v in _yield_sorted_items(iterable): 126 | yield v 127 | 128 | 129 | def _yield_sorted_items(iterable): 130 | """Yield (key, value) pairs for `iterable` in a deterministic order. 131 | 132 | For Sequences, the key will be an int, the array index of a value. 
133 | For Mappings, the key will be the dictionary key. 134 | For objects (e.g. namedtuples), the key will be the attribute name. 135 | 136 | In all cases, the keys will be iterated in sorted order. 137 | 138 | Args: 139 | iterable: an iterable. 140 | 141 | Yields: 142 | The iterable's (key, value) pairs, in order of sorted keys. 143 | """ 144 | if isinstance(iterable, collections_abc.Mapping): 145 | # Iterate through dictionaries in a deterministic order by sorting the 146 | # keys. Notice this means that we ignore the original order of `OrderedDict` 147 | # instances. This is intentional, to avoid potential bugs caused by mixing 148 | # ordered and plain dicts (e.g., flattening a dict but using a 149 | # corresponding `OrderedDict` to pack it back). 150 | for key in _sorted(iterable): 151 | yield key, iterable[key] 152 | elif _is_attrs(iterable): 153 | for item in _get_attrs_items(iterable): 154 | yield item 155 | elif _is_namedtuple(iterable): 156 | for field in iterable._fields: 157 | yield (field, getattr(iterable, field)) 158 | else: 159 | for item in enumerate(iterable): 160 | yield item 161 | 162 | 163 | def _num_elements(structure): 164 | if _is_attrs(structure): 165 | return len(getattr(structure.__class__, "__attrs_attrs__")) 166 | else: 167 | return len(structure) 168 | 169 | 170 | def is_nested(structure): 171 | """Checks if a given structure is nested. 172 | 173 | >>> tree.is_nested(42) 174 | False 175 | >>> tree.is_nested({"foo": 42}) 176 | True 177 | 178 | Args: 179 | structure: A structure to check. 180 | 181 | Returns: 182 | `True` if a given structure is nested, i.e. is a sequence, a mapping, 183 | or a namedtuple, and `False` otherwise. 184 | """ 185 | return _tree.is_sequence(structure) 186 | 187 | 188 | def flatten(structure): 189 | r"""Flattens a possibly nested structure into a list. 
190 | 191 | >>> tree.flatten([[1, 2, 3], [4, [5], [[6]]]]) 192 | [1, 2, 3, 4, 5, 6] 193 | 194 | If `structure` is not nested, the result is a single-element list. 195 | 196 | >>> tree.flatten(None) 197 | [None] 198 | >>> tree.flatten(1) 199 | [1] 200 | 201 | In the case of dict instances, the sequence consists of the values, 202 | sorted by key to ensure deterministic behavior. This is true also for 203 | :class:`~collections.OrderedDict` instances: their sequence order is 204 | ignored, the sorting order of keys is used instead. The same convention 205 | is followed in :func:`~tree.unflatten_as`. This correctly unflattens dicts 206 | and ``OrderedDict``\ s after they have been flattened, and also allows 207 | flattening an ``OrderedDict`` and then unflattening it back using a 208 | corresponding plain dict, or vice-versa. 209 | 210 | Dictionaries with non-sortable keys cannot be flattened. 211 | 212 | >>> tree.flatten({100: 'world!', 6: 'Hello'}) 213 | ['Hello', 'world!'] 214 | 215 | Args: 216 | structure: An arbitrarily nested structure. 217 | 218 | Returns: 219 | A list, the flattened version of the input `structure`. 220 | 221 | Raises: 222 | TypeError: If `structure` is or contains a mapping with non-sortable keys. 223 | """ 224 | return _tree.flatten(structure) 225 | 226 | 227 | class _DotString(object): 228 | 229 | def __str__(self): 230 | return "." 231 | 232 | def __repr__(self): 233 | return "." 234 | 235 | 236 | _DOT = _DotString() 237 | 238 | 239 | def assert_same_structure(a, b, check_types=True): 240 | """Asserts that two structures are nested in the same way. 241 | 242 | >>> tree.assert_same_structure([(0, 1)], [(2, 3)]) 243 | 244 | Note that namedtuples with identical name and fields are always considered 245 | to have the same shallow structure (even with `check_types=True`).
246 | 247 | >>> Foo = collections.namedtuple('Foo', ['a', 'b']) 248 | >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b']) 249 | >>> tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3)) 250 | 251 | Named tuples with different names are considered to have different shallow 252 | structures: 253 | 254 | >>> Bar = collections.namedtuple('Bar', ['a', 'b']) 255 | >>> tree.assert_same_structure(Foo(0, 1), Bar(2, 3)) 256 | Traceback (most recent call last): 257 | ... 258 | TypeError: The two structures don't have the same nested structure. 259 | ... 260 | 261 | Args: 262 | a: an arbitrarily nested structure. 263 | b: an arbitrarily nested structure. 264 | check_types: if `True` (default) types of sequences are checked as 265 | well, including the keys of dictionaries. If set to `False`, for example 266 | a list and a tuple of objects will look the same if they have the same 267 | size. Note that namedtuples with identical name and fields are always 268 | considered to have the same shallow structure. 269 | 270 | Raises: 271 | ValueError: If the two structures do not have the same number of elements or 272 | if the two structures are not nested in the same way. 273 | TypeError: If the two structures differ in the type of sequence in any of 274 | their substructures. Only possible if `check_types` is `True`. 275 | """ 276 | try: 277 | _tree.assert_same_structure(a, b, check_types) 278 | except (ValueError, TypeError) as e: 279 | str1 = str(map_structure(lambda _: _DOT, a)) 280 | str2 = str(map_structure(lambda _: _DOT, b)) 281 | raise type(e)("%s\n" 282 | "Entire first structure:\n%s\n" 283 | "Entire second structure:\n%s" 284 | % (e, str1, str2)) 285 | 286 | 287 | def _packed_nest_with_indices(structure, flat, index): 288 | """Helper function for ``unflatten_as``. 289 | 290 | Args: 291 | structure: Substructure (list / tuple / dict) to mimic. 292 | flat: Flattened values to output substructure for. 293 | index: Index at which to start reading from flat. 
294 | 295 | Returns: 296 | The tuple (new_index, child), where: 297 | * new_index - the updated index into `flat` having processed `structure`. 298 | * packed - the subset of `flat` corresponding to `structure`, 299 | having started at `index`, and packed into the same nested 300 | format. 301 | 302 | Raises: 303 | ValueError: if `structure` contains more elements than `flat` 304 | (assuming indexing starts from `index`). 305 | """ 306 | packed = [] 307 | for s in _yield_value(structure): 308 | if is_nested(s): 309 | new_index, child = _packed_nest_with_indices(s, flat, index) 310 | packed.append(_sequence_like(s, child)) 311 | index = new_index 312 | else: 313 | packed.append(flat[index]) 314 | index += 1 315 | return index, packed 316 | 317 | 318 | def unflatten_as(structure, flat_sequence): 319 | r"""Unflattens a sequence into a given structure. 320 | 321 | >>> tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8]) 322 | [[5, 6], [[7], [8]]] 323 | 324 | If `structure` is a scalar, `flat_sequence` must be a single-element list; 325 | in this case the return value is ``flat_sequence[0]``. 326 | 327 | >>> tree.unflatten_as(None, [1]) 328 | 1 329 | 330 | If `structure` is or contains a dict instance, the keys will be sorted to 331 | pack the flat sequence in deterministic order. This is true also for 332 | :class:`~collections.OrderedDict` instances: their sequence order is 333 | ignored, the sorting order of keys is used instead. The same convention 334 | is followed in :func:`~tree.flatten`. This correctly unflattens dicts 335 | and ``OrderedDict``\ s after they have been flattened, and also allows 336 | flattening an ``OrderedDict`` and then unflattening it back using a 337 | corresponding plain dict, or vice-versa. 338 | 339 | Dictionaries with non-sortable keys cannot be unflattened. 340 | 341 | >>> tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!']) 342 | {1: 'Hello', 2: 'world!'} 343 | 344 | Args: 345 | structure: Arbitrarily nested structure. 
346 | flat_sequence: Sequence to unflatten. 347 | 348 | Returns: 349 | `flat_sequence` unflattened into `structure`. 350 | 351 | Raises: 352 | ValueError: If `flat_sequence` and `structure` have different 353 | element counts. 354 | TypeError: If `structure` is or contains a mapping with non-sortable keys. 355 | """ 356 | if not is_nested(flat_sequence): 357 | raise TypeError("flat_sequence must be a sequence not a {}:\n{}".format( 358 | type(flat_sequence), flat_sequence)) 359 | 360 | if not is_nested(structure): 361 | if len(flat_sequence) != 1: 362 | raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1" 363 | % len(flat_sequence)) 364 | return flat_sequence[0] 365 | 366 | flat_structure = flatten(structure) 367 | if len(flat_structure) != len(flat_sequence): 368 | raise ValueError( 369 | "Could not pack sequence. Structure had %d elements, but flat_sequence " 370 | "had %d elements. Structure: %s, flat_sequence: %s." 371 | % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 372 | 373 | _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) 374 | return _sequence_like(structure, packed) 375 | 376 | 377 | def map_structure(func, *structures, **kwargs): # pylint: disable=redefined-builtin 378 | """Maps `func` through given structures. 379 | 380 | >>> structure = [[1], [2], [3]] 381 | >>> tree.map_structure(lambda v: v**2, structure) 382 | [[1], [4], [9]] 383 | >>> tree.map_structure(lambda x, y: x * y, structure, structure) 384 | [[1], [4], [9]] 385 | >>> Foo = collections.namedtuple('Foo', ['a', 'b']) 386 | >>> structure = Foo(a=1, b=2) 387 | >>> tree.map_structure(lambda v: v * 2, structure) 388 | Foo(a=2, b=4) 389 | 390 | Args: 391 | func: A callable that accepts as many arguments as there are structures. 392 | *structures: Arbitrarily nested structures of the same layout. 393 | **kwargs: The only valid keyword argument is `check_types`. 
If `True` 394 | (default) the types of components within the structures have 395 | to match, e.g. ``tree.map_structure(func, [1], (1,))`` will raise 396 | a `TypeError`, otherwise this is not enforced. Note that namedtuples 397 | with identical name and fields are considered to be the same type. 398 | 399 | Returns: 400 | A new structure with the same layout as the given ones. If the 401 | `structures` have components of varying types, the resulting structure 402 | will use the same types as ``structures[0]``. 403 | 404 | Raises: 405 | TypeError: If `func` is not callable. 406 | ValueError: If the two structures do not have the same number of elements or 407 | if the two structures are not nested in the same way. 408 | TypeError: If `check_types` is `True` and any two `structures` 409 | differ in the types of their components. 410 | ValueError: If no structures were given or if a keyword argument other 411 | than `check_types` is provided. 412 | """ 413 | if not callable(func): 414 | raise TypeError("func must be callable, got: %s" % func) 415 | 416 | if not structures: 417 | raise ValueError("Must provide at least one structure") 418 | 419 | check_types = kwargs.pop("check_types", True) 420 | if kwargs: 421 | raise ValueError( 422 | "Only valid keyword arguments are `check_types` " 423 | "not: `%s`" % ("`, `".join(kwargs.keys()))) 424 | 425 | for other in structures[1:]: 426 | assert_same_structure(structures[0], other, check_types=check_types) 427 | return unflatten_as(structures[0], 428 | [func(*args) for args in zip(*map(flatten, structures))]) 429 | 430 | 431 | def map_structure_with_path(func, *structures, **kwargs): 432 | """Maps `func` through given structures. 433 | 434 | This is a variant of :func:`~tree.map_structure` which accumulates 435 | a *path* while mapping through the structures. A path is a tuple of 436 | indices and/or keys which uniquely identifies the positions of the 437 | arguments passed to `func`.
438 | 439 | >>> tree.map_structure_with_path( 440 | ... lambda path, v: (path, v**2), 441 | ... [{"foo": 42}]) 442 | [{'foo': ((0, 'foo'), 1764)}] 443 | 444 | Args: 445 | func: A callable that accepts a path and as many arguments as there are 446 | structures. 447 | *structures: Arbitrarily nested structures of the same layout. 448 | **kwargs: The only valid keyword argument is `check_types`. If `True` 449 | (default) the types of components within the structures have to match, 450 | e.g. ``tree.map_structure_with_path(func, [1], (1,))`` will raise a 451 | `TypeError`, otherwise this is not enforced. Note that namedtuples with 452 | identical name and fields are considered to be the same type. 453 | 454 | Returns: 455 | A new structure with the same layout as the given ones. If the 456 | `structures` have components of varying types, the resulting structure 457 | will use the same types as ``structures[0]``. 458 | 459 | Raises: 460 | TypeError: If `func` is not callable or if the `structures` do not 461 | have the same layout. 462 | TypeError: If `check_types` is `True` and any two `structures` 463 | differ in the types of their components. 464 | ValueError: If no structures were given or if a keyword argument other 465 | than `check_types` is provided. 466 | """ 467 | return map_structure_with_path_up_to(structures[0], func, *structures, 468 | **kwargs) 469 | 470 | 471 | def _yield_flat_up_to(shallow_tree, input_tree, path=()): 472 | """Yields (path, value) pairs of input_tree flattened up to shallow_tree. 473 | 474 | Args: 475 | shallow_tree: Nested structure. Traverse no further than its leaf nodes. 476 | input_tree: Nested structure. Return the paths and values from this tree. 477 | Must have the same upper structure as shallow_tree. 478 | path: Tuple. Optional argument, only used when recursing. The path from the 479 | root of the original shallow_tree, down to the root of the shallow_tree 480 | arg of this recursive call.
481 | 482 | Yields: 483 | Pairs of (path, value), where path is the tuple path of a leaf node in 484 | shallow_tree, and value is the value of the corresponding node in 485 | input_tree. 486 | """ 487 | if (isinstance(shallow_tree, _TEXT_OR_BYTES) or 488 | not (isinstance(shallow_tree, (collections_abc.Mapping, 489 | collections_abc.Sequence)) or 490 | _is_namedtuple(shallow_tree) or 491 | _is_attrs(shallow_tree))): 492 | yield (path, input_tree) 493 | else: 494 | input_tree = dict(_yield_sorted_items(input_tree)) 495 | for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): 496 | subpath = path + (shallow_key,) 497 | input_subtree = input_tree[shallow_key] 498 | for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, 499 | input_subtree, 500 | path=subpath): 501 | yield (leaf_path, leaf_value) 502 | 503 | 504 | def _multiyield_flat_up_to(shallow_tree, *input_trees): 505 | """Same as `_yield_flat_up_to`, but takes multiple input trees.""" 506 | zipped_iterators = zip(*[_yield_flat_up_to(shallow_tree, input_tree) 507 | for input_tree in input_trees]) 508 | try: 509 | for paths_and_values in zipped_iterators: 510 | paths, values = zip(*paths_and_values) 511 | yield paths[:1] + values 512 | except KeyError as e: 513 | paths = locals().get("paths", ((),)) 514 | raise ValueError(f"Could not find key '{e.args[0]}' in some `input_trees`. " 515 | "Please ensure the structure of all `input_trees` are " 516 | "compatible with `shallow_tree`. The last valid path " 517 | f"yielded was {paths[0]}.") from e 518 | 519 | 520 | def _assert_shallow_structure(shallow_tree, 521 | input_tree, 522 | path=None, 523 | check_types=True): 524 | """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 525 | 526 | That is, this function recursively tests if each key in shallow_tree has its 527 | corresponding key in input_tree.
528 | 529 | Examples: 530 | 531 | The following code will raise an exception: 532 | 533 | >>> shallow_tree = {"a": "A", "b": "B"} 534 | >>> input_tree = {"a": 1, "c": 2} 535 | >>> _assert_shallow_structure(shallow_tree, input_tree) 536 | Traceback (most recent call last): 537 | ... 538 | ValueError: The shallow_tree's keys are not a subset of the input_tree's ... 539 | 540 | The following code will raise an exception: 541 | 542 | >>> shallow_tree = ["a", "b"] 543 | >>> input_tree = ["c", ["d", "e"], "f"] 544 | >>> _assert_shallow_structure(shallow_tree, input_tree) 545 | Traceback (most recent call last): 546 | ... 547 | ValueError: The two structures don't have the same sequence length. ... 548 | 549 | By setting check_types=False, we drop the requirement that corresponding 550 | nodes in shallow_tree and input_tree have to be the same type. Sequences 551 | are treated equivalently to Mappings that map integer keys (indices) to 552 | values. The following code will therefore not raise an exception: 553 | 554 | >>> _assert_shallow_structure({0: "foo"}, ["foo"], check_types=False) 555 | 556 | Args: 557 | shallow_tree: an arbitrarily nested structure. 558 | input_tree: an arbitrarily nested structure. 559 | path: if not `None`, a tuple containing the current path in the nested 560 | structure. This is only used for more informative error messages. 561 | check_types: if `True` (default) the sequence types of `shallow_tree` and 562 | `input_tree` have to be the same. 563 | 564 | Raises: 565 | TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 566 | TypeError: If the sequence types of `shallow_tree` are different from 567 | `input_tree`. Only raised if `check_types` is `True`. 568 | ValueError: If the sequence lengths of `shallow_tree` are different from 569 | `input_tree`.
570 | """ 571 | if is_nested(shallow_tree): 572 | if not is_nested(input_tree): 573 | if path is not None: 574 | raise TypeError( 575 | _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 576 | path=list(path), input_type=type(input_tree))) 577 | else: 578 | raise TypeError( 579 | _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format( 580 | type(input_tree))) 581 | 582 | if isinstance(shallow_tree, ObjectProxy): 583 | shallow_type = type(shallow_tree.__wrapped__) 584 | else: 585 | shallow_type = type(shallow_tree) 586 | 587 | if check_types and not isinstance(input_tree, shallow_type): 588 | # Duck-typing means that nest should be fine with two different 589 | # namedtuples with identical name and fields. 590 | shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) 591 | input_is_namedtuple = _is_namedtuple(input_tree, False) 592 | if shallow_is_namedtuple and input_is_namedtuple: 593 | # pylint: disable=protected-access 594 | if not _tree.same_namedtuples(shallow_tree, input_tree): 595 | raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 596 | input_type=type(input_tree), 597 | shallow_type=shallow_type)) 598 | # pylint: enable=protected-access 599 | elif not (isinstance(shallow_tree, collections_abc.Mapping) 600 | and isinstance(input_tree, collections_abc.Mapping)): 601 | raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 602 | input_type=type(input_tree), 603 | shallow_type=shallow_type)) 604 | 605 | if _num_elements(input_tree) != _num_elements(shallow_tree): 606 | raise ValueError( 607 | _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( 608 | input_length=_num_elements(input_tree), 609 | shallow_length=_num_elements(shallow_tree))) 610 | elif _num_elements(input_tree) < _num_elements(shallow_tree): 611 | raise ValueError( 612 | _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( 613 | input_size=_num_elements(input_tree), 614 | shallow_size=_num_elements(shallow_tree))) 615 | 616 | shallow_iter = _yield_sorted_items(shallow_tree) 617 | input_iter = 
_yield_sorted_items(input_tree) 618 | 619 | def get_matching_input_branch(shallow_key): 620 | for input_key, input_branch in input_iter: 621 | if input_key == shallow_key: 622 | return input_branch 623 | 624 | raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key])) 625 | 626 | for shallow_key, shallow_branch in shallow_iter: 627 | input_branch = get_matching_input_branch(shallow_key) 628 | _assert_shallow_structure( 629 | shallow_branch, 630 | input_branch, 631 | path + (shallow_key,) if path is not None else None, 632 | check_types=check_types) 633 | 634 | 635 | def flatten_up_to(shallow_structure, input_structure, check_types=True): 636 | """Flattens `input_structure` up to `shallow_structure`. 637 | 638 | All further nested components in `input_structure` are retained as-is. 639 | 640 | >>> structure = [[1, 1], [2, 2]] 641 | >>> tree.flatten_up_to([None, None], structure) 642 | [[1, 1], [2, 2]] 643 | >>> tree.flatten_up_to([None, [None, None]], structure) 644 | [[1, 1], 2, 2] 645 | 646 | If `shallow_structure` is not nested, the 647 | result is a single-element list: 648 | 649 | >>> tree.flatten_up_to(42, 1) 650 | [1] 651 | >>> tree.flatten_up_to(42, [1, 2, 3]) 652 | [[1, 2, 3]] 653 | 654 | Args: 655 | shallow_structure: A structure with the same (but possibly more shallow) 656 | layout as `input_structure`. 657 | input_structure: An arbitrarily nested structure. 658 | check_types: If `True`, check that each node in `shallow_structure` has the 659 | same type as the corresponding node in `input_structure`. 660 | 661 | Returns: 662 | A list, the partially flattened version of `input_structure` with respect to 663 | `shallow_structure`. 664 | 665 | Raises: 666 | TypeError: If the layout of `shallow_structure` does not match that of 667 | `input_structure`. 668 | TypeError: If `check_types` is `True` and `shallow_structure` and 669 | `input_structure` differ in the types of their components.
670 | """ 671 | _assert_shallow_structure( 672 | shallow_structure, input_structure, path=None, check_types=check_types) 673 | # Discard paths returned by _yield_flat_up_to. 674 | return [v for _, v in _yield_flat_up_to(shallow_structure, input_structure)] 675 | 676 | 677 | def flatten_with_path_up_to(shallow_structure, 678 | input_structure, 679 | check_types=True): 680 | """Flattens `input_structure` up to `shallow_structure`. 681 | 682 | This is a combination of :func:`~tree.flatten_up_to` and 683 | :func:`~tree.flatten_with_path`. 684 | 685 | Args: 686 | shallow_structure: A structure with the same (but possibly more shallow) 687 | layout as `input_structure`. 688 | input_structure: An arbitrarily nested structure. 689 | check_types: If `True`, check that each node in `shallow_structure` has the 690 | same type as the corresponding node in `input_structure`. 691 | 692 | Returns: 693 | A list of ``(path, item)`` pairs corresponding to the partially flattened 694 | version of `input_structure` with respect to `shallow_structure`. 695 | 696 | Raises: 697 | TypeError: If the layout of `shallow_structure` does not match that of 698 | `input_structure`. 699 | TypeError: If `input_structure` is or contains a mapping with non-sortable 700 | keys. 701 | TypeError: If `check_types` is `True` and `shallow_structure` and 702 | `input_structure` differ in the types of their components. 703 | """ 704 | _assert_shallow_structure( 705 | shallow_structure, input_structure, path=(), check_types=check_types) 706 | return list(_yield_flat_up_to(shallow_structure, input_structure)) 707 | 708 | 709 | def map_structure_up_to(shallow_structure, func, *structures, **kwargs): 710 | """Maps `func` through given structures up to `shallow_structure`. 711 | 712 | This is a variant of :func:`~tree.map_structure` which only maps 713 | the given structures up to `shallow_structure`. All further nested 714 | components are retained as-is.
715 | 716 | >>> structure = [[1, 1], [2, 2]] 717 | >>> tree.map_structure_up_to([None, None], len, structure) 718 | [2, 2] 719 | >>> tree.map_structure_up_to([None, [None, None]], str, structure) 720 | ['[1, 1]', ['2', '2']] 721 | 722 | Args: 723 | shallow_structure: A structure with layout common to all `structures`. 724 | func: A callable that accepts as many arguments as there are structures. 725 | *structures: Arbitrarily nested structures of the same layout. 726 | **kwargs: No valid keyword arguments. 727 | Raises: 728 | ValueError: If `func` is not callable or if `structures` have different 729 | layout or if the layout of `shallow_structure` does not match that of 730 | `structures` or if no structures were given. 731 | 732 | Returns: 733 | A new structure with the same layout as `shallow_structure`. 734 | """ 735 | return map_structure_with_path_up_to( 736 | shallow_structure, 737 | lambda _, *args: func(*args), # Discards path. 738 | *structures, 739 | **kwargs) 740 | 741 | 742 | def map_structure_with_path_up_to(shallow_structure, func, *structures, 743 | **kwargs): 744 | """Maps `func` through given structures up to `shallow_structure`. 745 | 746 | This is a combination of :func:`~tree.map_structure_up_to` and 747 | :func:`~tree.map_structure_with_path`. 748 | 749 | Args: 750 | shallow_structure: A structure with layout common to all `structures`. 751 | func: A callable that accepts a path and as many arguments as there are 752 | structures. 753 | *structures: Arbitrarily nested structures of the same layout. 754 | **kwargs: No valid keyword arguments. 755 | 756 | Raises: 757 | ValueError: If `func` is not callable or if `structures` have different 758 | layout or if the layout of `shallow_structure` does not match that of 759 | `structures` or if no structures were given. 760 | 761 | Returns: 762 | Result of repeatedly applying `func`. Has the same structure layout 763 | as `shallow_structure`.
764 | """ 765 | if "check_types" in kwargs: 766 | logging.warning("The use of `check_types` is deprecated and does not have " 767 | "any effect.") 768 | del kwargs 769 | results = [] 770 | for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures): 771 | results.append(func(*path_and_values)) 772 | return unflatten_as(shallow_structure, results) 773 | 774 | 775 | def flatten_with_path(structure): 776 | r"""Flattens a possibly nested structure into a list. 777 | 778 | This is a variant of :func:`~tree.flatten` which produces a list of 779 | pairs: ``(path, item)``. A path is a tuple of indices and/or keys 780 | which uniquely identifies the position of the corresponding ``item``. 781 | 782 | >>> tree.flatten_with_path([{"foo": 42}]) 783 | [((0, 'foo'), 42)] 784 | 785 | Args: 786 | structure: An arbitrarily nested structure. 787 | 788 | Returns: 789 | A list of ``(path, item)`` pairs corresponding to the flattened version 790 | of the input `structure`. 791 | 792 | Raises: 793 | TypeError: 794 | If ``structure`` is or contains a mapping with non-sortable keys. 795 | """ 796 | return list(_yield_flat_up_to(structure, structure)) 797 | 798 | 799 | #: Special value for use with :func:`traverse`. 800 | MAP_TO_NONE = object() 801 | 802 | 803 | def traverse(fn, structure, top_down=True): 804 | """Traverses the given nested structure, applying the given function. 805 | 806 | The traversal is depth-first. If ``top_down`` is True (default), parents 807 | are returned before their children (giving the option to avoid traversing 808 | into a sub-tree).
809 | 810 | >>> visited = [] 811 | >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True) 812 | [(1, 2), [3], {'a': 4}] 813 | >>> visited 814 | [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] 815 | 816 | >>> visited = [] 817 | >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False) 818 | [(1, 2), [3], {'a': 4}] 819 | >>> visited 820 | [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] 821 | 822 | Args: 823 | fn: The function to be applied to each sub-nest of the structure. 824 | 825 | When traversing top-down: 826 | If ``fn(subtree) is None`` the traversal continues into the sub-tree. 827 | If ``fn(subtree) is not None`` the traversal does not continue into 828 | the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the 829 | returned structure (to replace the sub-tree with None, use the special 830 | value :data:`MAP_TO_NONE`). 831 | 832 | When traversing bottom-up: 833 | If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered. 834 | If ``fn(subtree) is not None`` the sub-tree will be replaced by 835 | ``fn(subtree)`` in the returned structure (to replace the sub-tree 836 | with None, use the special value :data:`MAP_TO_NONE`). 837 | 838 | structure: The structure to traverse. 839 | top_down: If True, parent structures will be visited before their children. 840 | 841 | Returns: 842 | The structured output from the traversal. 843 | """ 844 | return traverse_with_path(lambda _, x: fn(x), structure, top_down=top_down) 845 | 846 | 847 | def traverse_with_path(fn, structure, top_down=True): 848 | """Traverses the given nested structure, applying the given function. 849 | 850 | The traversal is depth-first. If ``top_down`` is True (default), parents 851 | are returned before their children (giving the option to avoid traversing 852 | into a sub-tree). 853 | 854 | >>> visited = [] 855 | >>> tree.traverse_with_path( 856 | ... 
lambda path, subtree: visited.append((path, subtree)), 857 | ... [(1, 2), [3], {"a": 4}], 858 | ... top_down=True) 859 | [(1, 2), [3], {'a': 4}] 860 | >>> visited == [ 861 | ... ((), [(1, 2), [3], {'a': 4}]), 862 | ... ((0,), (1, 2)), 863 | ... ((0, 0), 1), 864 | ... ((0, 1), 2), 865 | ... ((1,), [3]), 866 | ... ((1, 0), 3), 867 | ... ((2,), {'a': 4}), 868 | ... ((2, 'a'), 4)] 869 | True 870 | 871 | >>> visited = [] 872 | >>> tree.traverse_with_path( 873 | ... lambda path, subtree: visited.append((path, subtree)), 874 | ... [(1, 2), [3], {"a": 4}], 875 | ... top_down=False) 876 | [(1, 2), [3], {'a': 4}] 877 | >>> visited == [ 878 | ... ((0, 0), 1), 879 | ... ((0, 1), 2), 880 | ... ((0,), (1, 2)), 881 | ... ((1, 0), 3), 882 | ... ((1,), [3]), 883 | ... ((2, 'a'), 4), 884 | ... ((2,), {'a': 4}), 885 | ... ((), [(1, 2), [3], {'a': 4}])] 886 | True 887 | 888 | Args: 889 | fn: The function to be applied to the path to each sub-nest of the structure 890 | and the sub-nest value. 891 | When traversing top-down: If ``fn(path, subtree) is None`` the traversal 892 | continues into the sub-tree. If ``fn(path, subtree) is not None`` the 893 | traversal does not continue into the sub-tree. The sub-tree will be 894 | replaced by ``fn(path, subtree)`` in the returned structure (to replace 895 | the sub-tree with None, use the special 896 | value :data:`MAP_TO_NONE`). 897 | When traversing bottom-up: If ``fn(path, subtree) is None`` the traversed 898 | sub-tree is returned unaltered. If ``fn(path, subtree) is not None`` the 899 | sub-tree will be replaced by ``fn(path, subtree)`` in the returned 900 | structure (to replace the sub-tree 901 | with None, use the special value :data:`MAP_TO_NONE`). 902 | structure: The structure to traverse. 903 | top_down: If True, parent structures will be visited before their children. 904 | 905 | Returns: 906 | The structured output from the traversal. 
907 | """ 908 | 909 | def traverse_impl(path, structure): 910 | """Recursive traversal implementation.""" 911 | 912 | def subtree_fn(item): 913 | subtree_path, subtree = item 914 | return traverse_impl(path + (subtree_path,), subtree) 915 | 916 | def traverse_subtrees(): 917 | if is_nested(structure): 918 | return _sequence_like(structure, 919 | map(subtree_fn, _yield_sorted_items(structure))) 920 | else: 921 | return structure 922 | 923 | if top_down: 924 | ret = fn(path, structure) 925 | if ret is None: 926 | return traverse_subtrees() 927 | elif ret is MAP_TO_NONE: 928 | return None 929 | else: 930 | return ret 931 | else: 932 | traversed_structure = traverse_subtrees() 933 | ret = fn(path, traversed_structure) 934 | if ret is None: 935 | return traversed_structure 936 | elif ret is MAP_TO_NONE: 937 | return None 938 | else: 939 | return ret 940 | 941 | return traverse_impl((), structure) 942 | 943 | 944 | -------------------------------------------------------------------------------- /tree/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for utilities working with arbitrarily nested structures.""" 16 | 17 | import collections 18 | import doctest 19 | import types 20 | from typing import Any, Iterator, Mapping 21 | import unittest 22 | 23 | from absl.testing import parameterized 24 | import attr 25 | import numpy as np 26 | import tree 27 | import wrapt 28 | 29 | STRUCTURE1 = (((1, 2), 3), 4, (5, 6)) 30 | STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) 31 | STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs") 32 | STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,)) 33 | 34 | 35 | class DoctestTest(parameterized.TestCase): 36 | 37 | def testDoctest(self): 38 | extraglobs = { 39 | "collections": collections, 40 | "tree": tree, 41 | } 42 | num_failed, num_attempted = doctest.testmod( 43 | tree, extraglobs=extraglobs, optionflags=doctest.ELLIPSIS) 44 | self.assertGreater(num_attempted, 0, "No doctests found.") 45 | self.assertEqual(num_failed, 0, "{} doctests failed".format(num_failed)) 46 | 47 | 48 | class NestTest(parameterized.TestCase): 49 | 50 | def assertAllEquals(self, a, b): 51 | self.assertTrue((np.asarray(a) == b).all()) 52 | 53 | def testAttrsFlattenAndUnflatten(self): 54 | 55 | class BadAttr(object): 56 | """Class that has a non-iterable __attrs_attrs__.""" 57 | __attrs_attrs__ = None 58 | 59 | @attr.s 60 | class SampleAttr(object): 61 | field1 = attr.ib() 62 | field2 = attr.ib() 63 | 64 | field_values = [1, 2] 65 | sample_attr = SampleAttr(*field_values) 66 | self.assertFalse(tree._is_attrs(field_values)) 67 | self.assertTrue(tree._is_attrs(sample_attr)) 68 | flat = tree.flatten(sample_attr) 69 | self.assertEqual(field_values, flat) 70 | restructured_from_flat = tree.unflatten_as(sample_attr, flat) 71 | self.assertIsInstance(restructured_from_flat, SampleAttr) 72 | self.assertEqual(restructured_from_flat, sample_attr) 73 | 74 | # Check that flatten fails if attributes 
are not iterable 75 | with self.assertRaisesRegex(TypeError, "object is not iterable"): 76 | flat = tree.flatten(BadAttr()) 77 | 78 | @parameterized.parameters([ 79 | (1, 2, 3), 80 | ({"B": 10, "A": 20}, [1, 2], 3), 81 | ((1, 2), [3, 4], 5), 82 | (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4), 83 | wrapt.ObjectProxy( 84 | (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4)) 85 | ]) 86 | def testAttrsMapStructure(self, *field_values): 87 | @attr.s 88 | class SampleAttr(object): 89 | field3 = attr.ib() 90 | field1 = attr.ib() 91 | field2 = attr.ib() 92 | 93 | structure = SampleAttr(*field_values) 94 | new_structure = tree.map_structure(lambda x: x, structure) 95 | self.assertEqual(structure, new_structure) 96 | 97 | def testFlattenAndUnflatten(self): 98 | structure = ((3, 4), 5, (6, 7, (9, 10), 8)) 99 | flat = ["a", "b", "c", "d", "e", "f", "g", "h"] 100 | self.assertEqual(tree.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) 101 | self.assertEqual( 102 | tree.unflatten_as(structure, flat), 103 | (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))) 104 | point = collections.namedtuple("Point", ["x", "y"]) 105 | structure = (point(x=4, y=2), ((point(x=1, y=0),),)) 106 | flat = [4, 2, 1, 0] 107 | self.assertEqual(tree.flatten(structure), flat) 108 | restructured_from_flat = tree.unflatten_as(structure, flat) 109 | self.assertEqual(restructured_from_flat, structure) 110 | self.assertEqual(restructured_from_flat[0].x, 4) 111 | self.assertEqual(restructured_from_flat[0].y, 2) 112 | self.assertEqual(restructured_from_flat[1][0][0].x, 1) 113 | self.assertEqual(restructured_from_flat[1][0][0].y, 0) 114 | 115 | self.assertEqual([5], tree.flatten(5)) 116 | self.assertEqual([np.array([5])], tree.flatten(np.array([5]))) 117 | 118 | self.assertEqual("a", tree.unflatten_as(5, ["a"])) 119 | self.assertEqual( 120 | np.array([5]), tree.unflatten_as("scalar", [np.array([5])])) 121 | 122 | with self.assertRaisesRegex(ValueError, "Structure is a scalar"): 123 | 
tree.unflatten_as("scalar", [4, 5]) 124 | 125 | with self.assertRaisesRegex(TypeError, "flat_sequence"): 126 | tree.unflatten_as([4, 5], "bad_sequence") 127 | 128 | with self.assertRaises(ValueError): 129 | tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"]) 130 | 131 | def testFlattenDictOrder(self): 132 | ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) 133 | plain = {"d": 3, "b": 1, "a": 0, "c": 2} 134 | ordered_flat = tree.flatten(ordered) 135 | plain_flat = tree.flatten(plain) 136 | self.assertEqual([0, 1, 2, 3], ordered_flat) 137 | self.assertEqual([0, 1, 2, 3], plain_flat) 138 | 139 | def testUnflattenDictOrder(self): 140 | ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) 141 | plain = {"d": 0, "b": 0, "a": 0, "c": 0} 142 | seq = [0, 1, 2, 3] 143 | ordered_reconstruction = tree.unflatten_as(ordered, seq) 144 | plain_reconstruction = tree.unflatten_as(plain, seq) 145 | self.assertEqual( 146 | collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), 147 | ordered_reconstruction) 148 | self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) 149 | 150 | def testFlattenAndUnflatten_withDicts(self): 151 | # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. 
152 | named_tuple = collections.namedtuple("A", ("b", "c")) 153 | mess = [ 154 | "z", 155 | named_tuple(3, 4), 156 | { 157 | "c": [ 158 | 1, 159 | collections.OrderedDict([ 160 | ("b", 3), 161 | ("a", 2), 162 | ]), 163 | ], 164 | "b": 5 165 | }, 166 | 17 167 | ] 168 | 169 | flattened = tree.flatten(mess) 170 | self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17]) 171 | 172 | structure_of_mess = [ 173 | 14, 174 | named_tuple("a", True), 175 | { 176 | "c": [ 177 | 0, 178 | collections.OrderedDict([ 179 | ("b", 9), 180 | ("a", 8), 181 | ]), 182 | ], 183 | "b": 3 184 | }, 185 | "hi everybody", 186 | ] 187 | 188 | self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened)) 189 | 190 | # Check also that the OrderedDict was created, with the correct key order. 191 | unflattened_ordered_dict = tree.unflatten_as( 192 | structure_of_mess, flattened)[2]["c"][1] 193 | self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) 194 | self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) 195 | 196 | def testFlatten_numpyIsNotFlattened(self): 197 | structure = np.array([1, 2, 3]) 198 | flattened = tree.flatten(structure) 199 | self.assertLen(flattened, 1) 200 | 201 | def testFlatten_stringIsNotFlattened(self): 202 | structure = "lots of letters" 203 | flattened = tree.flatten(structure) 204 | self.assertLen(flattened, 1) 205 | self.assertEqual(structure, tree.unflatten_as("goodbye", flattened)) 206 | 207 | def testFlatten_bytearrayIsNotFlattened(self): 208 | structure = bytearray("bytes in an array", "ascii") 209 | flattened = tree.flatten(structure) 210 | self.assertLen(flattened, 1) 211 | self.assertEqual(flattened, [structure]) 212 | self.assertEqual(structure, 213 | tree.unflatten_as(bytearray("hello", "ascii"), flattened)) 214 | 215 | def testUnflattenSequenceAs_notIterableError(self): 216 | with self.assertRaisesRegex(TypeError, "flat_sequence must be a sequence"): 217 | tree.unflatten_as("hi", "bye") 218 | 219 | def 
testUnflattenSequenceAs_wrongLengthsError(self): 220 | with self.assertRaisesRegex( 221 | ValueError, 222 | "Structure had 2 elements, but flat_sequence had 3 elements."): 223 | tree.unflatten_as(["hello", "world"], ["and", "goodbye", "again"]) 224 | 225 | def testUnflattenSequenceAs_defaultdict(self): 226 | structure = collections.defaultdict( 227 | list, [("a", [None]), ("b", [None, None])]) 228 | sequence = [1, 2, 3] 229 | expected = collections.defaultdict( 230 | list, [("a", [1]), ("b", [2, 3])]) 231 | self.assertEqual(expected, tree.unflatten_as(structure, sequence)) 232 | 233 | def testIsSequence(self): 234 | self.assertFalse(tree.is_nested("1234")) 235 | self.assertFalse(tree.is_nested(b"1234")) 236 | self.assertFalse(tree.is_nested(u"1234")) 237 | self.assertFalse(tree.is_nested(bytearray("1234", "ascii"))) 238 | self.assertTrue(tree.is_nested([1, 3, [4, 5]])) 239 | self.assertTrue(tree.is_nested(((7, 8), (5, 6)))) 240 | self.assertTrue(tree.is_nested([])) 241 | self.assertTrue(tree.is_nested({"a": 1, "b": 2})) 242 | self.assertFalse(tree.is_nested(set([1, 2]))) 243 | ones = np.ones([2, 3]) 244 | self.assertFalse(tree.is_nested(ones)) 245 | self.assertFalse(tree.is_nested(np.tanh(ones))) 246 | self.assertFalse(tree.is_nested(np.ones((4, 5)))) 247 | 248 | # pylint does not correctly recognize these as class names and 249 | # suggests to use variable style under_score naming. 
250 | # pylint: disable=invalid-name 251 | Named0ab = collections.namedtuple("named_0", ("a", "b")) 252 | Named1ab = collections.namedtuple("named_1", ("a", "b")) 253 | SameNameab = collections.namedtuple("same_name", ("a", "b")) 254 | SameNameab2 = collections.namedtuple("same_name", ("a", "b")) 255 | SameNamexy = collections.namedtuple("same_name", ("x", "y")) 256 | SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) 257 | SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) 258 | NotSameName = collections.namedtuple("not_same_name", ("a", "b")) 259 | # pylint: enable=invalid-name 260 | 261 | class SameNamedType1(SameNameab): 262 | pass 263 | 264 | # pylint: disable=g-error-prone-assert-raises 265 | def testAssertSameStructure(self): 266 | tree.assert_same_structure(STRUCTURE1, STRUCTURE2) 267 | tree.assert_same_structure("abc", 1.0) 268 | tree.assert_same_structure(b"abc", 1.0) 269 | tree.assert_same_structure(u"abc", 1.0) 270 | tree.assert_same_structure(bytearray("abc", "ascii"), 1.0) 271 | tree.assert_same_structure("abc", np.array([0, 1])) 272 | 273 | def testAssertSameStructure_differentNumElements(self): 274 | with self.assertRaisesRegex( 275 | ValueError, 276 | ("The two structures don't have the same nested structure\\.\n\n" 277 | "First structure:.*?\n\n" 278 | "Second structure:.*\n\n" 279 | "More specifically: Substructure " 280 | r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' 281 | 'substructure "type=str str=spam" is not\n' 282 | "Entire first structure:\n" 283 | r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" 284 | "Entire second structure:\n" 285 | r"\(\., \.\)")): 286 | tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS) 287 | 288 | def testAssertSameStructure_listVsNdArray(self): 289 | with self.assertRaisesRegex( 290 | ValueError, 291 | ("The two structures don't have the same nested structure\\.\n\n" 292 | "First structure:.*?\n\n" 293 | "Second structure:.*\n\n" 294 | r'More specifically: 
Substructure "type=list str=\[0, 1\]" ' 295 | r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' 296 | "is not")): 297 | tree.assert_same_structure([0, 1], np.array([0, 1])) 298 | 299 | def testAssertSameStructure_intVsList(self): 300 | with self.assertRaisesRegex( 301 | ValueError, 302 | ("The two structures don't have the same nested structure\\.\n\n" 303 | "First structure:.*?\n\n" 304 | "Second structure:.*\n\n" 305 | r'More specifically: Substructure "type=list str=\[0, 1\]" ' 306 | 'is a sequence, while substructure "type=int str=0" ' 307 | "is not")): 308 | tree.assert_same_structure(0, [0, 1]) 309 | 310 | def testAssertSameStructure_tupleVsList(self): 311 | self.assertRaises( 312 | TypeError, tree.assert_same_structure, (0, 1), [0, 1]) 313 | 314 | def testAssertSameStructure_differentNesting(self): 315 | with self.assertRaisesRegex( 316 | ValueError, 317 | ("don't have the same nested structure\\.\n\n" 318 | "First structure: .*?\n\nSecond structure: ")): 319 | tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING) 320 | 321 | def testAssertSameStructure_tupleVsNamedTuple(self): 322 | self.assertRaises(TypeError, tree.assert_same_structure, (0, 1), 323 | NestTest.Named0ab("a", "b")) 324 | 325 | def testAssertSameStructure_sameNamedTupleDifferentContents(self): 326 | tree.assert_same_structure(NestTest.Named0ab(3, 4), 327 | NestTest.Named0ab("a", "b")) 328 | 329 | def testAssertSameStructure_differentNamedTuples(self): 330 | self.assertRaises(TypeError, tree.assert_same_structure, 331 | NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) 332 | 333 | def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(self): 334 | with self.assertRaisesRegex( 335 | ValueError, 336 | ("don't have the same nested structure\\.\n\n" 337 | "First structure: .*?\n\nSecond structure: ")): 338 | tree.assert_same_structure(NestTest.Named0ab(3, 4), 339 | NestTest.Named0ab([3], 4)) 340 | 341 | def 
testAssertSameStructure_differentlyNestedLists(self): 342 | with self.assertRaisesRegex( 343 | ValueError, 344 | ("don't have the same nested structure\\.\n\n" 345 | "First structure: .*?\n\nSecond structure: ")): 346 | tree.assert_same_structure([[3], 4], [3, [4]]) 347 | 348 | def testAssertSameStructure_listStructureWithAndWithoutTypes(self): 349 | structure1_list = [[[1, 2], 3], 4, [5, 6]] 350 | with self.assertRaisesRegex(TypeError, "don't have the same sequence type"): 351 | tree.assert_same_structure(STRUCTURE1, structure1_list) 352 | tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False) 353 | tree.assert_same_structure(STRUCTURE1, structure1_list, check_types=False) 354 | 355 | def testAssertSameStructure_dictionaryDifferentKeys(self): 356 | with self.assertRaisesRegex(ValueError, "don't have the same set of keys"): 357 | tree.assert_same_structure({"a": 1}, {"b": 1}) 358 | 359 | def testAssertSameStructure_sameNameNamedTuples(self): 360 | tree.assert_same_structure(NestTest.SameNameab(0, 1), 361 | NestTest.SameNameab2(2, 3)) 362 | 363 | def testAssertSameStructure_sameNameNamedTuplesNested(self): 364 | # This assertion is expected to pass: two namedtuples with the same 365 | # name and field names are considered to be identical. 
366 | tree.assert_same_structure( 367 | NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), 368 | NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) 369 | 370 | def testAssertSameStructure_sameNameNamedTuplesDifferentStructure(self): 371 | expected_message = "The two structures don't have the same.*" 372 | with self.assertRaisesRegex(ValueError, expected_message): 373 | tree.assert_same_structure( 374 | NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), 375 | NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) 376 | 377 | def testAssertSameStructure_differentNameNamedStructures(self): 378 | self.assertRaises(TypeError, tree.assert_same_structure, 379 | NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) 380 | 381 | def testAssertSameStructure_sameNameDifferentFieldNames(self): 382 | self.assertRaises(TypeError, tree.assert_same_structure, 383 | NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) 384 | 385 | def testAssertSameStructure_classWrappingNamedTuple(self): 386 | self.assertRaises(TypeError, tree.assert_same_structure, 387 | NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) 388 | # pylint: enable=g-error-prone-assert-raises 389 | 390 | def testMapStructure(self): 391 | structure2 = (((7, 8), 9), 10, (11, 12)) 392 | structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1) 393 | tree.assert_same_structure(STRUCTURE1, structure1_plus1) 394 | self.assertAllEquals( 395 | [2, 3, 4, 5, 6, 7], 396 | tree.flatten(structure1_plus1)) 397 | structure1_plus_structure2 = tree.map_structure( 398 | lambda x, y: x + y, STRUCTURE1, structure2) 399 | self.assertEqual( 400 | (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), 401 | structure1_plus_structure2) 402 | 403 | self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4)) 404 | 405 | self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4)) 406 | 407 | # Empty structures 408 | self.assertEqual((), tree.map_structure(lambda x: x + 1, ())) 409 | self.assertEqual([], 
tree.map_structure(lambda x: x + 1, [])) 410 | self.assertEqual({}, tree.map_structure(lambda x: x + 1, {})) 411 | empty_nt = collections.namedtuple("empty_nt", "") 412 | self.assertEqual(empty_nt(), tree.map_structure(lambda x: x + 1, 413 | empty_nt())) 414 | 415 | # This is checking actual equality of types, empty list != empty tuple 416 | self.assertNotEqual((), tree.map_structure(lambda x: x + 1, [])) 417 | 418 | with self.assertRaisesRegex(TypeError, "callable"): 419 | tree.map_structure("bad", structure1_plus1) 420 | 421 | with self.assertRaisesRegex(ValueError, "at least one structure"): 422 | tree.map_structure(lambda x: x) 423 | 424 | with self.assertRaisesRegex(ValueError, "same number of elements"): 425 | tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) 426 | 427 | with self.assertRaisesRegex(ValueError, "same nested structure"): 428 | tree.map_structure(lambda x, y: None, 3, (3,)) 429 | 430 | with self.assertRaisesRegex(TypeError, "same sequence type"): 431 | tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) 432 | 433 | with self.assertRaisesRegex(ValueError, "same nested structure"): 434 | tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) 435 | 436 | structure1_list = [[[1, 2], 3], 4, [5, 6]] 437 | with self.assertRaisesRegex(TypeError, "same sequence type"): 438 | tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list) 439 | 440 | tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list, 441 | check_types=False) 442 | 443 | with self.assertRaisesRegex(ValueError, "same nested structure"): 444 | tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), 445 | check_types=False) 446 | 447 | with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"): 448 | tree.map_structure(lambda x: None, STRUCTURE1, foo="a") 449 | 450 | with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"): 451 | tree.map_structure(lambda x: None, STRUCTURE1, check_types=False, foo="a") 
452 | 453 | def testMapStructureWithStrings(self): 454 | ab_tuple = collections.namedtuple("ab_tuple", "a, b") 455 | inp_a = ab_tuple(a="foo", b=("bar", "baz")) 456 | inp_b = ab_tuple(a=2, b=(1, 3)) 457 | out = tree.map_structure(lambda string, repeats: string * repeats, 458 | inp_a, 459 | inp_b) 460 | self.assertEqual("foofoo", out.a) 461 | self.assertEqual("bar", out.b[0]) 462 | self.assertEqual("bazbazbaz", out.b[1]) 463 | 464 | nt = ab_tuple(a=("something", "something_else"), 465 | b="yet another thing") 466 | rev_nt = tree.map_structure(lambda x: x[::-1], nt) 467 | # Check the output is the correct structure, and all strings are reversed. 468 | tree.assert_same_structure(nt, rev_nt) 469 | self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) 470 | self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) 471 | self.assertEqual(nt.b[::-1], rev_nt.b) 472 | 473 | def testAssertShallowStructure(self): 474 | inp_ab = ["a", "b"] 475 | inp_abc = ["a", "b", "c"] 476 | with self.assertRaisesRegex( 477 | ValueError, 478 | tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( 479 | input_length=len(inp_ab), 480 | shallow_length=len(inp_abc))): 481 | tree._assert_shallow_structure(inp_abc, inp_ab) 482 | 483 | inp_ab1 = [(1, 1), (2, 2)] 484 | inp_ab2 = [[1, 1], [2, 2]] 485 | with self.assertRaisesWithLiteralMatch( 486 | TypeError, 487 | tree._STRUCTURES_HAVE_MISMATCHING_TYPES.format( 488 | shallow_type=type(inp_ab2[0]), 489 | input_type=type(inp_ab1[0]))): 490 | tree._assert_shallow_structure(shallow_tree=inp_ab2, input_tree=inp_ab1) 491 | 492 | tree._assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) 493 | 494 | inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} 495 | inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} 496 | 497 | with self.assertRaisesWithLiteralMatch( 498 | ValueError, 499 | tree._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])): 500 | tree._assert_shallow_structure(inp_ab2, inp_ab1) 501 | 502 | inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) 503 | inp_ba = 
collections.OrderedDict([("b", (2, 3)), ("a", 1)]) 504 | tree._assert_shallow_structure(inp_ab, inp_ba) 505 | 506 | # regression test for b/130633904 507 | tree._assert_shallow_structure({0: "foo"}, ["foo"], check_types=False) 508 | 509 | def testFlattenUpTo(self): 510 | # Shallow tree ends at scalar. 511 | input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 512 | shallow_tree = [[True, True], [False, True]] 513 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 514 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 515 | self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) 516 | self.assertEqual(flattened_shallow_tree, [True, True, False, True]) 517 | 518 | # Shallow tree ends at string. 519 | input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] 520 | shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] 521 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 522 | input_tree) 523 | input_tree_flattened = tree.flatten(input_tree) 524 | self.assertEqual(input_tree_flattened_as_shallow_tree, 525 | [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) 526 | self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) 527 | 528 | # Make sure dicts are correctly flattened, yielding values, not keys. 529 | input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} 530 | shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} 531 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 532 | input_tree) 533 | self.assertEqual(input_tree_flattened_as_shallow_tree, 534 | [1, {"c": 2}, 3, (4, 5)]) 535 | 536 | # Namedtuples. 537 | ab_tuple = collections.namedtuple("ab_tuple", "a, b") 538 | input_tree = ab_tuple(a=[0, 1], b=2) 539 | shallow_tree = ab_tuple(a=0, b=1) 540 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 541 | input_tree) 542 | self.assertEqual(input_tree_flattened_as_shallow_tree, 543 | [[0, 1], 2]) 544 | 545 | # Attrs. 
546 | @attr.s 547 | class ABAttr(object): 548 | a = attr.ib() 549 | b = attr.ib() 550 | input_tree = ABAttr(a=[0, 1], b=2) 551 | shallow_tree = ABAttr(a=0, b=1) 552 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 553 | input_tree) 554 | self.assertEqual(input_tree_flattened_as_shallow_tree, 555 | [[0, 1], 2]) 556 | 557 | # Nested dicts, OrderedDicts and namedtuples. 558 | input_tree = collections.OrderedDict( 559 | [("a", ab_tuple(a=[0, {"b": 1}], b=2)), 560 | ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) 561 | shallow_tree = input_tree 562 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 563 | input_tree) 564 | self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) 565 | shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) 566 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 567 | input_tree) 568 | self.assertEqual(input_tree_flattened_as_shallow_tree, 569 | [ab_tuple(a=[0, {"b": 1}], b=2), 570 | 3, 571 | collections.OrderedDict([("f", 4)])]) 572 | shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) 573 | input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, 574 | input_tree) 575 | self.assertEqual(input_tree_flattened_as_shallow_tree, 576 | [ab_tuple(a=[0, {"b": 1}], b=2), 577 | {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) 578 | 579 | ## Shallow non-list edge-case. 580 | # Using iterable elements. 
581 | input_tree = ["input_tree"] 582 | shallow_tree = "shallow_tree" 583 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 584 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 585 | self.assertEqual(flattened_input_tree, [input_tree]) 586 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 587 | 588 | input_tree = ["input_tree_0", "input_tree_1"] 589 | shallow_tree = "shallow_tree" 590 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 591 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 592 | self.assertEqual(flattened_input_tree, [input_tree]) 593 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 594 | 595 | # Using non-iterable elements. 596 | input_tree = [0] 597 | shallow_tree = 9 598 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 599 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 600 | self.assertEqual(flattened_input_tree, [input_tree]) 601 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 602 | 603 | input_tree = [0, 1] 604 | shallow_tree = 9 605 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 606 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 607 | self.assertEqual(flattened_input_tree, [input_tree]) 608 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 609 | 610 | ## Both non-list edge-case. 611 | # Using iterable elements. 612 | input_tree = "input_tree" 613 | shallow_tree = "shallow_tree" 614 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 615 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 616 | self.assertEqual(flattened_input_tree, [input_tree]) 617 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 618 | 619 | # Using non-iterable elements. 
620 | input_tree = 0 621 | shallow_tree = 0 622 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 623 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 624 | self.assertEqual(flattened_input_tree, [input_tree]) 625 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 626 | 627 | ## Input non-list edge-case. 628 | # Using iterable elements. 629 | input_tree = "input_tree" 630 | shallow_tree = ["shallow_tree"] 631 | with self.assertRaisesWithLiteralMatch( 632 | TypeError, 633 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): 634 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 635 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 636 | self.assertEqual(flattened_shallow_tree, shallow_tree) 637 | 638 | input_tree = "input_tree" 639 | shallow_tree = ["shallow_tree_9", "shallow_tree_8"] 640 | with self.assertRaisesWithLiteralMatch( 641 | TypeError, 642 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): 643 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 644 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 645 | self.assertEqual(flattened_shallow_tree, shallow_tree) 646 | 647 | # Using non-iterable elements. 
648 | input_tree = 0 649 | shallow_tree = [9] 650 | with self.assertRaisesWithLiteralMatch( 651 | TypeError, 652 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): 653 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 654 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 655 | self.assertEqual(flattened_shallow_tree, shallow_tree) 656 | 657 | input_tree = 0 658 | shallow_tree = [9, 8] 659 | with self.assertRaisesWithLiteralMatch( 660 | TypeError, 661 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): 662 | flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) 663 | flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) 664 | self.assertEqual(flattened_shallow_tree, shallow_tree) 665 | 666 | def testByteStringsNotTreatedAsIterable(self): 667 | structure = [u"unicode string", b"byte string"] 668 | flattened_structure = tree.flatten_up_to(structure, structure) 669 | self.assertEqual(structure, flattened_structure) 670 | 671 | def testFlattenWithPathUpTo(self): 672 | 673 | def get_paths_and_values(shallow_tree, input_tree): 674 | path_value_pairs = tree.flatten_with_path_up_to(shallow_tree, input_tree) 675 | paths = [p for p, _ in path_value_pairs] 676 | values = [v for _, v in path_value_pairs] 677 | return paths, values 678 | 679 | # Shallow tree ends at scalar. 
680 | input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 681 | shallow_tree = [[True, True], [False, True]] 682 | (flattened_input_tree_paths, 683 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 684 | (flattened_shallow_tree_paths, 685 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 686 | self.assertEqual(flattened_input_tree_paths, 687 | [(0, 0), (0, 1), (1, 0), (1, 1)]) 688 | self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) 689 | self.assertEqual(flattened_shallow_tree_paths, 690 | [(0, 0), (0, 1), (1, 0), (1, 1)]) 691 | self.assertEqual(flattened_shallow_tree, [True, True, False, True]) 692 | 693 | # Shallow tree ends at string. 694 | input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] 695 | shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] 696 | (input_tree_flattened_as_shallow_tree_paths, 697 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 698 | input_tree) 699 | input_tree_flattened_paths = [ 700 | p for p, _ in tree.flatten_with_path(input_tree) 701 | ] 702 | input_tree_flattened = tree.flatten(input_tree) 703 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 704 | [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)]) 705 | self.assertEqual(input_tree_flattened_as_shallow_tree, 706 | [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) 707 | 708 | self.assertEqual(input_tree_flattened_paths, 709 | [(0, 0, 0), (0, 0, 1), 710 | (0, 1, 0, 0), (0, 1, 0, 1), 711 | (0, 1, 1, 0, 0), (0, 1, 1, 0, 1), 712 | (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)]) 713 | self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) 714 | 715 | # Make sure dicts are correctly flattened, yielding values, not keys. 
716 | input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} 717 | shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} 718 | (input_tree_flattened_as_shallow_tree_paths, 719 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 720 | input_tree) 721 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 722 | [("a",), ("b",), ("d", 0), ("d", 1)]) 723 | self.assertEqual(input_tree_flattened_as_shallow_tree, 724 | [1, {"c": 2}, 3, (4, 5)]) 725 | 726 | # Namedtuples. 727 | ab_tuple = collections.namedtuple("ab_tuple", "a, b") 728 | input_tree = ab_tuple(a=[0, 1], b=2) 729 | shallow_tree = ab_tuple(a=0, b=1) 730 | (input_tree_flattened_as_shallow_tree_paths, 731 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 732 | input_tree) 733 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 734 | [("a",), ("b",)]) 735 | self.assertEqual(input_tree_flattened_as_shallow_tree, 736 | [[0, 1], 2]) 737 | 738 | # Nested dicts, OrderedDicts and namedtuples. 
739 | input_tree = collections.OrderedDict( 740 | [("a", ab_tuple(a=[0, {"b": 1}], b=2)), 741 | ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) 742 | shallow_tree = input_tree 743 | (input_tree_flattened_as_shallow_tree_paths, 744 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 745 | input_tree) 746 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 747 | [("a", "a", 0), 748 | ("a", "a", 1, "b"), 749 | ("a", "b"), 750 | ("c", "d"), 751 | ("c", "e", "f")]) 752 | self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) 753 | shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) 754 | (input_tree_flattened_as_shallow_tree_paths, 755 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 756 | input_tree) 757 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 758 | [("a",), 759 | ("c", "d"), 760 | ("c", "e")]) 761 | self.assertEqual(input_tree_flattened_as_shallow_tree, 762 | [ab_tuple(a=[0, {"b": 1}], b=2), 763 | 3, 764 | collections.OrderedDict([("f", 4)])]) 765 | shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) 766 | (input_tree_flattened_as_shallow_tree_paths, 767 | input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, 768 | input_tree) 769 | self.assertEqual(input_tree_flattened_as_shallow_tree_paths, 770 | [("a",), ("c",)]) 771 | self.assertEqual(input_tree_flattened_as_shallow_tree, 772 | [ab_tuple(a=[0, {"b": 1}], b=2), 773 | {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) 774 | 775 | ## Shallow non-list edge-case. 776 | # Using iterable elements. 
777 | input_tree = ["input_tree"] 778 | shallow_tree = "shallow_tree" 779 | (flattened_input_tree_paths, 780 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 781 | (flattened_shallow_tree_paths, 782 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 783 | self.assertEqual(flattened_input_tree_paths, [()]) 784 | self.assertEqual(flattened_input_tree, [input_tree]) 785 | self.assertEqual(flattened_shallow_tree_paths, [()]) 786 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 787 | 788 | input_tree = ["input_tree_0", "input_tree_1"] 789 | shallow_tree = "shallow_tree" 790 | (flattened_input_tree_paths, 791 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 792 | (flattened_shallow_tree_paths, 793 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 794 | self.assertEqual(flattened_input_tree_paths, [()]) 795 | self.assertEqual(flattened_input_tree, [input_tree]) 796 | self.assertEqual(flattened_shallow_tree_paths, [()]) 797 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 798 | 799 | # Test case where len(shallow_tree) < len(input_tree). Note: these inputs are reassigned below before any assertion uses them. 800 | input_tree = {"a": "A", "b": "B", "c": "C"} 801 | shallow_tree = {"a": 1, "c": 2} 802 | 803 | # Using non-iterable elements. 
804 | input_tree = [0] 805 | shallow_tree = 9 806 | (flattened_input_tree_paths, 807 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 808 | (flattened_shallow_tree_paths, 809 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 810 | self.assertEqual(flattened_input_tree_paths, [()]) 811 | self.assertEqual(flattened_input_tree, [input_tree]) 812 | self.assertEqual(flattened_shallow_tree_paths, [()]) 813 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 814 | 815 | input_tree = [0, 1] 816 | shallow_tree = 9 817 | (flattened_input_tree_paths, 818 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 819 | (flattened_shallow_tree_paths, 820 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 821 | self.assertEqual(flattened_input_tree_paths, [()]) 822 | self.assertEqual(flattened_input_tree, [input_tree]) 823 | self.assertEqual(flattened_shallow_tree_paths, [()]) 824 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 825 | 826 | ## Both non-list edge-case. 827 | # Using iterable elements. 828 | input_tree = "input_tree" 829 | shallow_tree = "shallow_tree" 830 | (flattened_input_tree_paths, 831 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 832 | (flattened_shallow_tree_paths, 833 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 834 | self.assertEqual(flattened_input_tree_paths, [()]) 835 | self.assertEqual(flattened_input_tree, [input_tree]) 836 | self.assertEqual(flattened_shallow_tree_paths, [()]) 837 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 838 | 839 | # Using non-iterable elements. 
840 | input_tree = 0 841 | shallow_tree = 0 842 | (flattened_input_tree_paths, 843 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 844 | (flattened_shallow_tree_paths, 845 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 846 | self.assertEqual(flattened_input_tree_paths, [()]) 847 | self.assertEqual(flattened_input_tree, [input_tree]) 848 | self.assertEqual(flattened_shallow_tree_paths, [()]) 849 | self.assertEqual(flattened_shallow_tree, [shallow_tree]) 850 | 851 | ## Input non-list edge-case. 852 | # Using iterable elements. 853 | input_tree = "input_tree" 854 | shallow_tree = ["shallow_tree"] 855 | with self.assertRaisesWithLiteralMatch( 856 | TypeError, 857 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 858 | path=[], input_type=type(input_tree))): 859 | (flattened_input_tree_paths, 860 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 861 | (flattened_shallow_tree_paths, 862 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 863 | self.assertEqual(flattened_shallow_tree_paths, [(0,)]) 864 | self.assertEqual(flattened_shallow_tree, shallow_tree) 865 | 866 | input_tree = "input_tree" 867 | shallow_tree = ["shallow_tree_9", "shallow_tree_8"] 868 | with self.assertRaisesWithLiteralMatch( 869 | TypeError, 870 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 871 | path=[], input_type=type(input_tree))): 872 | (flattened_input_tree_paths, 873 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 874 | (flattened_shallow_tree_paths, 875 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 876 | self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) 877 | self.assertEqual(flattened_shallow_tree, shallow_tree) 878 | 879 | # Using non-iterable elements. 
880 | input_tree = 0 881 | shallow_tree = [9] 882 | with self.assertRaisesWithLiteralMatch( 883 | TypeError, 884 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 885 | path=[], input_type=type(input_tree))): 886 | (flattened_input_tree_paths, 887 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 888 | (flattened_shallow_tree_paths, 889 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 890 | self.assertEqual(flattened_shallow_tree_paths, [(0,)]) 891 | self.assertEqual(flattened_shallow_tree, shallow_tree) 892 | 893 | input_tree = 0 894 | shallow_tree = [9, 8] 895 | with self.assertRaisesWithLiteralMatch( 896 | TypeError, 897 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 898 | path=[], input_type=type(input_tree))): 899 | (flattened_input_tree_paths, 900 | flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) 901 | (flattened_shallow_tree_paths, 902 | flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) 903 | self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) 904 | self.assertEqual(flattened_shallow_tree, shallow_tree) 905 | 906 | # Test that error messages include paths. 907 | input_tree = {"a": {"b": {0, 1}}} 908 | structure = {"a": {"b": [0, 1]}} 909 | with self.assertRaisesWithLiteralMatch( 910 | TypeError, 911 | tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( 912 | path=["a", "b"], input_type=type(input_tree["a"]["b"]))): 913 | (flattened_input_tree_paths, 914 | flattened_input_tree) = get_paths_and_values(structure, input_tree) 915 | (flattened_tree_paths, 916 | flattened_tree) = get_paths_and_values(structure, structure) 917 | self.assertEqual(flattened_tree_paths, [("a", "b", 0,), ("a", "b", 1,)]) 918 | self.assertEqual(flattened_tree, structure["a"]["b"]) 919 | 920 | def testMapStructureUpTo(self): 921 | # Named tuples. 
922 | ab_tuple = collections.namedtuple("ab_tuple", "a, b") 923 | op_tuple = collections.namedtuple("op_tuple", "add, mul") 924 | inp_val = ab_tuple(a=2, b=3) 925 | inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 926 | out = tree.map_structure_up_to( 927 | inp_val, 928 | lambda val, ops: (val + ops.add) * ops.mul, 929 | inp_val, 930 | inp_ops, 931 | check_types=False) 932 | self.assertEqual(out.a, 6) 933 | self.assertEqual(out.b, 15) 934 | 935 | # Lists. 936 | data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 937 | name_list = ["evens", ["odds", "primes"]] 938 | out = tree.map_structure_up_to( 939 | name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), 940 | name_list, data_list) 941 | self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) 942 | 943 | # We cannot define namedtuples within @parameterized argument lists. 944 | # pylint: disable=invalid-name 945 | Foo = collections.namedtuple("Foo", ["a", "b"]) 946 | Bar = collections.namedtuple("Bar", ["c", "d"]) 947 | # pylint: enable=invalid-name 948 | 949 | @parameterized.parameters([ 950 | dict(inputs=[], expected=[]), 951 | dict(inputs=[23, "42"], expected=[((0,), 23), ((1,), "42")]), 952 | dict(inputs=[[[[108]]]], expected=[((0, 0, 0, 0), 108)]), 953 | dict(inputs=Foo(a=3, b=Bar(c=23, d=42)), 954 | expected=[(("a",), 3), (("b", "c"), 23), (("b", "d"), 42)]), 955 | dict(inputs=Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")), 956 | expected=[(("a", "c"), 23), (("a", "d"), 42), (("b", "c"), 0), 957 | (("b", "d"), "thing")]), 958 | dict(inputs=Bar(c=42, d=43), 959 | expected=[(("c",), 42), (("d",), 43)]), 960 | dict(inputs=Bar(c=[42], d=43), 961 | expected=[(("c", 0), 42), (("d",), 43)]), 962 | dict(inputs=wrapt.ObjectProxy(Bar(c=[42], d=43)), 963 | expected=[(("c", 0), 42), (("d",), 43)]), 964 | ]) 965 | def testFlattenWithPath(self, inputs, expected): 966 | self.assertEqual(tree.flatten_with_path(inputs), expected) 967 | 968 | 
@parameterized.named_parameters([ 969 | dict(testcase_name="Tuples", s1=(1, 2), s2=(3, 4), 970 | check_types=True, expected=(((0,), 4), ((1,), 6))), 971 | dict(testcase_name="Dicts", s1={"a": 1, "b": 2}, s2={"b": 4, "a": 3}, 972 | check_types=True, expected={"a": (("a",), 4), "b": (("b",), 6)}), 973 | dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], 974 | check_types=False, expected=(((0,), 4), ((1,), 6))), 975 | dict(testcase_name="Nested", 976 | s1={"a": [2, 3], "b": [1, 2, 3]}, 977 | s2={"b": [5, 6, 7], "a": [8, 9]}, 978 | check_types=True, 979 | expected={"a": [(("a", 0), 10), (("a", 1), 12)], 980 | "b": [(("b", 0), 6), (("b", 1), 8), (("b", 2), 10)]}), 981 | ]) 982 | def testMapWithPathCompatibleStructures(self, s1, s2, check_types, expected): 983 | def path_and_sum(path, *values): 984 | return path, sum(values) 985 | 986 | result = tree.map_structure_with_path( 987 | path_and_sum, s1, s2, check_types=check_types) 988 | self.assertEqual(expected, result) 989 | 990 | @parameterized.named_parameters([ 991 | dict(testcase_name="Tuples", s1=(1, 2, 3), s2=(4, 5), 992 | error_type=ValueError), 993 | dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2}, 994 | error_type=ValueError), 995 | dict(testcase_name="Nested", 996 | s1={"a": [2, 3, 4], "b": [1, 3]}, 997 | s2={"b": [5, 6], "a": [8, 9]}, 998 | error_type=ValueError) 999 | ]) 1000 | def testMapWithPathIncompatibleStructures(self, s1, s2, error_type): 1001 | with self.assertRaises(error_type): 1002 | tree.map_structure_with_path(lambda path, *s: 0, s1, s2) 1003 | 1004 | def testMappingProxyType(self): 1005 | structure = types.MappingProxyType({"a": 1, "b": (2, 3)}) 1006 | expected = types.MappingProxyType({"a": 4, "b": (5, 6)}) 1007 | self.assertEqual(tree.flatten(structure), [1, 2, 3]) 1008 | self.assertEqual(tree.unflatten_as(structure, [4, 5, 6]), expected) 1009 | self.assertEqual(tree.map_structure(lambda v: v + 3, structure), expected) 1010 | 1011 | def testTraverseListsToTuples(self): 1012 | structure = 
[(1, 2), [3], {"a": [4]}] 1013 | self.assertEqual( 1014 | ((1, 2), (3,), {"a": (4,)}), 1015 | tree.traverse( 1016 | lambda x: tuple(x) if isinstance(x, list) else x, 1017 | structure, 1018 | top_down=False)) 1019 | 1020 | def testTraverseEarlyTermination(self): 1021 | structure = [(1, [2]), [3, (4, 5, 6)]] 1022 | visited = [] 1023 | def visit(x): 1024 | visited.append(x) 1025 | return "X" if isinstance(x, tuple) and len(x) > 2 else None 1026 | 1027 | output = tree.traverse(visit, structure) 1028 | self.assertEqual([(1, [2]), [3, "X"]], output) 1029 | self.assertEqual( 1030 | [[(1, [2]), [3, (4, 5, 6)]], 1031 | (1, [2]), 1, [2], 2, [3, (4, 5, 6)], 3, (4, 5, 6)], 1032 | visited) 1033 | 1034 | def testMapStructureAcrossSubtreesDict(self): 1035 | shallow = {"a": 1, "b": {"c": 2}} 1036 | deep1 = {"a": 2, "b": {"c": 3, "d": 2}, "e": 4} 1037 | deep2 = {"a": 3, "b": {"c": 2, "d": 3}, "e": 1} 1038 | summed = tree.map_structure_up_to( 1039 | shallow, lambda *args: sum(args), deep1, deep2) 1040 | expected = {"a": 5, "b": {"c": 5}} 1041 | self.assertEqual(summed, expected) 1042 | concatenated = tree.map_structure_up_to( 1043 | shallow, lambda *args: args, deep1, deep2) 1044 | expected = {"a": (2, 3), "b": {"c": (3, 2)}} 1045 | self.assertEqual(concatenated, expected) 1046 | 1047 | def testMapStructureAcrossSubtreesNoneValues(self): 1048 | shallow = [1, [None]] 1049 | deep1 = [1, [2, 3]] 1050 | deep2 = [2, [3, 4]] 1051 | summed = tree.map_structure_up_to( 1052 | shallow, lambda *args: sum(args), deep1, deep2) 1053 | expected = [3, [5]] 1054 | self.assertEqual(summed, expected) 1055 | 1056 | def testMapStructureAcrossSubtreesList(self): 1057 | shallow = [1, [1]] 1058 | deep1 = [1, [2, 3]] 1059 | deep2 = [2, [3, 4]] 1060 | summed = tree.map_structure_up_to( 1061 | shallow, lambda *args: sum(args), deep1, deep2) 1062 | expected = [3, [5]] 1063 | self.assertEqual(summed, expected) 1064 | 1065 | def testMapStructureAcrossSubtreesTuple(self): 1066 | shallow = (1, (1,)) 1067 | deep1 = 
(1, (2, 3)) 1068 | deep2 = (2, (3, 4)) 1069 | summed = tree.map_structure_up_to( 1070 | shallow, lambda *args: sum(args), deep1, deep2) 1071 | expected = (3, (5,)) 1072 | self.assertEqual(summed, expected) 1073 | 1074 | def testMapStructureAcrossSubtreesNamedTuple(self): 1075 | Foo = collections.namedtuple("Foo", ["x", "y"]) 1076 | Bar = collections.namedtuple("Bar", ["x"]) 1077 | shallow = Bar(1) 1078 | deep1 = Foo(1, (1, 0)) 1079 | deep2 = Foo(2, (2, 0)) 1080 | summed = tree.map_structure_up_to( 1081 | shallow, lambda *args: sum(args), deep1, deep2) 1082 | expected = Bar(3) 1083 | self.assertEqual(summed, expected) 1084 | 1085 | def testMapStructureAcrossSubtreesListTuple(self): 1086 | # Tuples and lists can be used interchangeably between the shallow structure 1087 | # and the input structures. The output takes on the type of the shallow structure. 1088 | shallow = [1, (1,)] 1089 | deep1 = [1, [2, 3]] 1090 | deep2 = [2, [3, 4]] 1091 | summed = tree.map_structure_up_to(shallow, lambda *args: sum(args), deep1, 1092 | deep2) 1093 | expected = [3, (5,)] 1094 | self.assertEqual(summed, expected) 1095 | 1096 | shallow = [1, [1]] 1097 | deep1 = [1, (2, 3)] 1098 | deep2 = [2, (3, 4)] 1099 | summed = tree.map_structure_up_to(shallow, lambda *args: sum(args), deep1, 1100 | deep2) 1101 | expected = [3, [5]] 1102 | self.assertEqual(summed, expected) 1103 | 1104 | def testNoneNodeIncluded(self): 1105 | structure = ((1, None)) 1106 | self.assertEqual(tree.flatten(structure), [1, None]) 1107 | 1108 | def testCustomClassMapWithPath(self): 1109 | 1110 | class ExampleClass(Mapping[Any, Any]): 1111 | """Small example custom class.""" 1112 | 1113 | def __init__(self, *args, **kwargs): 1114 | self._mapping = dict(*args, **kwargs) 1115 | 1116 | def __getitem__(self, k: Any) -> Any: 1117 | return self._mapping[k] 1118 | 1119 | def __len__(self) -> int: 1120 | return len(self._mapping) 1121 | 1122 | def __iter__(self) -> Iterator[Any]: 1123 | return iter(self._mapping) 1124 | 1125 | def mapper(path, 
value): 1126 | full_path = "/".join(path) 1127 | return f"{full_path}_{value}" 1128 | 1129 | test_input = ExampleClass({"first": 1, "nested": {"second": 2, "third": 3}}) 1130 | output = tree.map_structure_with_path(mapper, test_input) 1131 | expected = ExampleClass({ 1132 | "first": "first_1", 1133 | "nested": { 1134 | "second": "nested/second_2", 1135 | "third": "nested/third_3" 1136 | } 1137 | }) 1138 | self.assertEqual(output, expected) 1139 | 1140 | 1141 | if __name__ == "__main__": 1142 | unittest.main() 1143 | --------------------------------------------------------------------------------