├── tests ├── __init__.py ├── test_typecheck.py ├── test_with_jax.py └── test_tree_util.py ├── setup.py ├── docs ├── source │ ├── index.rst │ ├── api.rst │ ├── _static │ │ ├── images │ │ │ ├── coding.svg │ │ │ ├── books.svg │ │ │ ├── book.svg │ │ │ └── light-bulb.svg │ │ └── css │ │ │ └── custom.css │ └── conf.py ├── rtd_environment.yml ├── Makefile └── make.bat ├── .readthedocs.yml ├── MANIFEST.in ├── codecov.yml ├── pyproject.toml ├── .github ├── PULL_REQUEST_TEMPLATE │ └── pull_request_template.md ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature_request.md │ └── enhancement.md └── workflows │ ├── publish-to-pypi.yml │ └── main.yml ├── src └── pybaum │ ├── config.py │ ├── equality.py │ ├── __init__.py │ ├── registry.py │ ├── typecheck.py │ ├── registry_entries.py │ └── tree_util.py ├── environment.yml ├── CHANGES.rst ├── setup.cfg ├── LICENSE ├── tox.ini ├── .gitignore ├── README.rst └── .pre-commit-config.yaml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | pybaum 2 | ====== 3 | 4 | pybaum contains tools to work with pytrees. 5 | 6 | 7 | .. 
toctree:: 8 | :maxdepth: 1 9 | 10 | api 11 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.8 8 | 9 | conda: 10 | environment: docs/rtd_environment.yml 11 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CITATION 2 | include LICENSE 3 | include CHANGES.rst 4 | 5 | exclude *.yaml 6 | exclude *.yml 7 | exclude tox.ini 8 | 9 | prune docs 10 | prune tests 11 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | The following documents are auto-generated and not carefully edited. They provide direct 5 | access to the source code and the docstrings. 6 | 7 | .. 
toctree:: 8 | :titlesonly: 9 | 10 | /autoapi/pybaum/index 11 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | require_ci_to_pass: yes 4 | 5 | coverage: 6 | precision: 2 7 | round: down 8 | range: "50...100" 9 | status: 10 | patch: 11 | default: 12 | target: 80% 13 | project: 14 | default: 15 | target: 80% 16 | 17 | ignore: 18 | - ".tox/**/*" 19 | - "setup.py" 20 | -------------------------------------------------------------------------------- /docs/rtd_environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - nodefaults 4 | 5 | dependencies: 6 | - python=3.9 7 | - pip 8 | - setuptools_scm 9 | - toml 10 | 11 | # Documentation 12 | - sphinx 13 | - sphinx-autoapi 14 | - sphinx-copybutton 15 | - sphinx-panels 16 | - pydata-sphinx-theme>=0.3.0 17 | 18 | - pip: 19 | - ../ 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.nbqa.config] 2 | isort = "setup.cfg" 3 | black = "pyproject.toml" 4 | 5 | [tool.nbqa.mutate] 6 | isort = 1 7 | black = 1 8 | pyupgrade = 1 9 | 10 | 11 | [tool.nbqa.addopts] 12 | isort = ["--treat-comment-as-code", "# %%", "--profile=black"] 13 | pyupgrade = ["--py37-plus"] 14 | 15 | 16 | [build-system] 17 | requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.0"] 18 | 19 | 20 | [tool.setuptools_scm] 21 | write_to = "src/pybaum/_version.py" 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### What problem do you want to solve? 2 | 3 | Reference the issue or discussion, if there is any. 
Provide a description of your 4 | proposed solution. 5 | 6 | ### Todo 7 | 8 | - [ ] Target the right branch and pick an appropriate title. 9 | - [ ] Put `Closes #XXXX` in the first PR comment to auto-close the relevant issue once 10 | the PR is accepted. This is not applicable if there is no corresponding issue. 11 | - [ ] Any steps that still need to be done. 12 | -------------------------------------------------------------------------------- /src/pybaum/config.py: -------------------------------------------------------------------------------- 1 | try: 2 | import numpy as np # noqa: F401 3 | except ImportError: 4 | IS_NUMPY_INSTALLED = False 5 | else: 6 | IS_NUMPY_INSTALLED = True 7 | 8 | 9 | try: 10 | import pandas as pd # noqa: F401 11 | except ImportError: 12 | IS_PANDAS_INSTALLED = False 13 | else: 14 | IS_PANDAS_INSTALLED = True 15 | 16 | 17 | try: 18 | import jax # noqa: F401 19 | import jaxlib # noqa: F401 20 | except ImportError: 21 | IS_JAX_INSTALLED = False 22 | else: 23 | IS_JAX_INSTALLED = True 24 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pybaum 2 | 3 | channels: 4 | - conda-forge 5 | - nodefaults 6 | 7 | dependencies: 8 | - python >=3.9 9 | - pip 10 | - setuptools_scm 11 | - toml 12 | 13 | # Testing 14 | - pre-commit 15 | - pytest 16 | - pytest-cov 17 | - pytest-xdist 18 | - tox-conda 19 | 20 | # Documentation 21 | - sphinx >=4 22 | - sphinx-autoapi 23 | - sphinx-copybutton 24 | - sphinx-panels 25 | - pydata-sphinx-theme>=0.3.0 26 | 27 | # Development 28 | - jupyterlab 29 | - nbsphinx 30 | - pdbpp 31 | - numpy 32 | - pandas 33 | - jax 34 | - jaxlib 35 | -------------------------------------------------------------------------------- /tests/test_typecheck.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import 
NamedTuple 3 | 4 | from pybaum.typecheck import get_type 5 | 6 | 7 | def test_namedtuple_is_discovered(): 8 | bla = namedtuple("bla", ["a", "b"])(1, 2) 9 | assert get_type(bla) == "namedtuple" 10 | 11 | 12 | def test_typed_namedtuple_is_discovered(): 13 | class Blubb(NamedTuple): 14 | a: int 15 | b: int 16 | 17 | blubb = Blubb(1, 2) 18 | assert get_type(blubb) == "namedtuple" 19 | 20 | 21 | def test_standard_tuple_is_not_discovered(): 22 | assert get_type((1, 2)) == tuple 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = pybaum 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Bug description 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | ### To Reproduce 15 | 16 | Ideally, provide a minimal code example. If that's not possible, describe steps to reproduce the bug. 
17 | 18 | ### Expected behavior 19 | 20 | A clear and concise description of what you expected to happen. 21 | 22 | ### Screenshots/Error messages 23 | 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | ### System 27 | 28 | - OS: [e.g. Ubuntu 18.04] 29 | - Version [e.g. 0.0.1] 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature-request 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Current situation 11 | 12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]; Currently there is no way of [...] 13 | 14 | ### Desired Situation 15 | 16 | What functionality should become possible or easier? 17 | 18 | ### Proposed implementation 19 | 20 | How would you implement the new feature? Did you consider alternative implementations? 21 | You can start by describing interface changes like a new argument or a new function. There is no need to get too detailed here. 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/enhancement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Enhancement 3 | about: Enhance an existing component. 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | * pybaum version used, if any: 11 | * Python version, if any: 12 | * Operating System: 13 | 14 | ### What would you like to enhance and why? Is it related to an issue/problem? 15 | 16 | A clear and concise description of the current implementation and its limitations. 17 | 18 | ### Describe the solution you'd like 19 | 20 | A clear and concise description of what you want to happen. 
21 | 22 | ### Describe alternatives you've considered 23 | 24 | A clear and concise description of any alternative solutions or features you've 25 | considered and why you have discarded them. 26 | -------------------------------------------------------------------------------- /src/pybaum/equality.py: -------------------------------------------------------------------------------- 1 | """Functions to check equality of pytree leaves.""" 2 | from pybaum.config import IS_JAX_INSTALLED 3 | from pybaum.config import IS_NUMPY_INSTALLED 4 | from pybaum.config import IS_PANDAS_INSTALLED 5 | 6 | 7 | if IS_NUMPY_INSTALLED: 8 | import numpy as np 9 | 10 | 11 | if IS_PANDAS_INSTALLED: 12 | import pandas as pd 13 | 14 | EQUALITY_CHECKERS = {} 15 | 16 | 17 | if IS_NUMPY_INSTALLED: 18 | EQUALITY_CHECKERS[np.ndarray] = lambda a, b: bool((a == b).all()) 19 | 20 | 21 | if IS_PANDAS_INSTALLED: 22 | EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b) 23 | EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b) 24 | 25 | 26 | if IS_JAX_INSTALLED: 27 | EQUALITY_CHECKERS["jax.numpy.ndarray"] = lambda a, b: bool((a == b).all()) 28 | -------------------------------------------------------------------------------- /src/pybaum/__init__.py: -------------------------------------------------------------------------------- 1 | from pybaum.registry import get_registry 2 | from pybaum.tree_util import leaf_names 3 | from pybaum.tree_util import tree_equal 4 | from pybaum.tree_util import tree_flatten 5 | from pybaum.tree_util import tree_just_flatten 6 | from pybaum.tree_util import tree_just_yield 7 | from pybaum.tree_util import tree_map 8 | from pybaum.tree_util import tree_multimap 9 | from pybaum.tree_util import tree_unflatten 10 | from pybaum.tree_util import tree_update 11 | from pybaum.tree_util import tree_yield 12 | 13 | 14 | __all__ = [ 15 | "tree_flatten", 16 | "tree_just_flatten", 17 | "tree_just_yield", 18 | "tree_unflatten", 19 | "tree_map", 20 | "tree_multimap", 21 | 
"leaf_names", 22 | "tree_equal", 23 | "tree_update", 24 | "tree_yield", 25 | "get_registry", 26 | ] 27 | -------------------------------------------------------------------------------- /CHANGES.rst: -------------------------------------------------------------------------------- 1 | Changes 2 | ======= 3 | 4 | This is a record of all past pybaum releases and what went into them in reverse 5 | chronological order. We follow `semantic versioning `_ and all 6 | releases are available on `Anaconda.org 7 | `_. 8 | 9 | 10 | 0.1.x - 2022-xx-xx 11 | ------------------ 12 | 13 | - :gh:`2` replaces the pre-commit pipeline step with pre-commit.ci. 14 | - :gh:`11` implement :func:`pybaum.tree_util.tree_yield` and 15 | :func:`pybaum.tree_util.tree_just_yield`. 16 | - :gh:`10` adds default arguments to ``tree_just_flatten``. 17 | - :gh:`12` adds a section about the API to the docs. 18 | - :gh:`13` extends the readme. 19 | 20 | 21 | 0.1.0 - 2022-01-28 22 | ------------------ 23 | 24 | - :gh:`1` releases the initial version of pybaum. 
25 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@master 11 | 12 | - name: Set up Python 3.8 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.8 16 | 17 | - name: Install pypa/build 18 | run: >- 19 | python -m 20 | pip install 21 | build 22 | --user 23 | - name: Build a binary wheel and a source tarball 24 | run: >- 25 | python -m 26 | build 27 | --sdist 28 | --wheel 29 | --outdir dist/ 30 | - name: Publish distribution 📦 to PyPI 31 | if: startsWith(github.ref, 'refs/tags') 32 | uses: pypa/gh-action-pypi-publish@master 33 | with: 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=pybaum 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = pybaum 3 | description = Tools to work with pytrees. 4 | long_description = file: README.rst 5 | long_description_content_type = text/x-rst 6 | url = https://github.com/OpenSourceEconomics/pybaum 7 | author = Janos Gabler, Tobias Raabe 8 | author_email = janos.gabler@gmail.com 9 | license = MIT 10 | license_file = LICENSE 11 | platforms = unix, linux, osx, cygwin, win32 12 | classifiers = 13 | Development Status :: 3 - Alpha 14 | License :: OSI Approved :: MIT License 15 | Operating System :: MacOS :: MacOS X 16 | Operating System :: Microsoft :: Windows 17 | Operating System :: POSIX 18 | Programming Language :: Python :: 3 19 | Programming Language :: Python :: 3 :: Only 20 | Topic :: Scientific/Engineering 21 | Topic :: Utilities 22 | 23 | [options] 24 | packages = find: 25 | python_requires = >=3.7 26 | include_package_data = True 27 | package_dir = 28 | =src 29 | zip_safe = False 30 | 31 | [options.packages.find] 32 | where = src 33 | 34 | [check-manifest] 35 | ignore = 36 | src/pybaum/_version.py 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Janoś Gabler, Tobias Raabe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 6 | software and associated documentation files (the "Software"), to deal in the Software 7 | without restriction, including without 
limitation the rights to use, copy, modify, 8 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to the following 10 | conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all copies or 13 | substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 16 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 17 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 19 | OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 20 | OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /docs/source/_static/images/coding.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 17 | 18 | 19 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /tests/test_with_jax.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pybaum.config import IS_JAX_INSTALLED 3 | from pybaum.registry import get_registry 4 | from pybaum.tree_util import leaf_names 5 | from pybaum.tree_util import tree_equal 6 | from pybaum.tree_util import tree_flatten 7 | from pybaum.tree_util import tree_just_flatten 8 | 9 | if IS_JAX_INSTALLED: 10 | import jax.numpy as jnp 11 | else: 12 | # run the tests with normal numpy instead 13 | import numpy as jnp 14 | 15 | 16 | @pytest.fixture 17 | def tree(): 18 | return {"a": {"b": jnp.arange(4).reshape(2, 2)}, "c": jnp.ones(2)} 19 | 20 | 21 | @pytest.fixture 22 | def flat(): 23 | return [0, 1, 2, 3, 1, 1] 24 
25 | 26 | @pytest.fixture 27 | def registry(): 28 | return get_registry(types=["jax.numpy.ndarray", "numpy.ndarray"]) 29 | 30 | 31 | def test_tree_just_flatten_with_jax(tree, registry, flat): 32 | got = tree_just_flatten(tree, registry=registry) 33 | assert got == flat 34 | 35 | 36 | def test_tree_flatten_with_jax(tree, registry, flat): 37 | got_flat, got_treedef = tree_flatten(tree, registry=registry) 38 | assert got_flat == flat 39 | assert tree_equal(got_treedef, tree) 40 | 41 | 42 | def test_leaf_names_with_jax(tree, registry): 43 | got = leaf_names(tree, registry=registry) 44 | expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"] 45 | assert got == expected 46 | -------------------------------------------------------------------------------- /docs/source/_static/images/books.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/pybaum/registry.py: -------------------------------------------------------------------------------- 1 | from pybaum.registry_entries import FUNC_DICT 2 | 3 | 4 | def get_registry(types=None, include_defaults=True): 5 | """Create a pytree registry. 6 | 7 | Args: 8 | types (list): A list of strings with the names of types that should be 9 | included in the registry, i.e. considered containers and not leaves by the 10 | functions that work with pytrees. Currently we support: 11 | - "tuple" 12 | - "dict" 13 | - "list" 14 | - "namedtuple" (:class:`collections.namedtuple`, :class:`typing.NamedTuple`) 15 | - "None" (:obj:`None`) 16 | - "OrderedDict" (:class:`collections.OrderedDict`) 17 | - "numpy.ndarray" 18 | - "jax.numpy.ndarray" 19 | - "pandas.Series" 20 | - "pandas.DataFrame" 21 | include_defaults (bool): Whether the default pytree containers "tuple", "dict", 22 | "list", "None", "namedtuple" and "OrderedDict" should be included even if 23 | not specified in `types`. 24 | 25 | Returns: 26 | dict: A pytree registry.
27 | 28 | """ 29 | types = [] if types is None else types 30 | 31 | if include_defaults: 32 | default_types = {"list", "tuple", "dict", "None", "namedtuple", "OrderedDict"} 33 | types = list(set(types) | default_types) 34 | 35 | registry = {} 36 | for typ in types: 37 | new_entry = FUNC_DICT[typ]() 38 | registry = {**registry, **new_entry} 39 | 40 | return registry 41 | -------------------------------------------------------------------------------- /docs/source/_static/images/book.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = pytest, sphinx 3 | skipsdist = True 4 | skip_missing_interpreters = True 5 | 6 | [testenv] 7 | basepython = python 8 | 9 | [testenv:pytest] 10 | setenv = 11 | CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 12 | conda_channels = 13 | conda-forge 14 | defaults 15 | conda_deps = 16 | conda-build 17 | numpy 18 | pandas 19 | pytest 20 | pytest-cov 21 | pytest-mock 22 | pytest-xdist 23 | jax 24 | jaxlib 25 | commands = pytest {posargs} 26 | 27 | [testenv:pytest-windows] 28 | setenv = 29 | CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 30 | conda_channels = 31 | conda-forge 32 | defaults 33 | conda_deps = 34 | conda-build 35 | numpy 36 | pandas 37 | pytest 38 | pytest-cov 39 | pytest-mock 40 | pytest-xdist 41 | commands = pytest {posargs} 42 | 43 | [testenv:sphinx] 44 | changedir = docs/source 45 | conda_env = docs/rtd_environment.yml 46 | commands = 47 | sphinx-build -T -b html -d {envtmpdir}/doctrees . {envtmpdir}/html 48 | - sphinx-build -T -b linkcheck -d {envtmpdir}/doctrees . 
{envtmpdir}/linkcheck 49 | 50 | 51 | [doc8] 52 | ignore = 53 | D002, 54 | D004, 55 | max-line-length = 88 56 | 57 | [flake8] 58 | max-line-length = 88 59 | ignore = 60 | D ; ignores docstring style errors, enable if you are nit-picky 61 | E203 ; ignores whitespace around : which is enforced by Black 62 | W503 ; ignores linebreak before binary operator which is enforced by Black 63 | RST304 ; ignores check for valid rst roles because it is too aggressive 64 | T001 ; ignore print statements 65 | RST301 ; ignores unexpected indentations in docstrings because it was not compatible with google style docstrings 66 | RST203 ; gave false positives 67 | RST202 ; gave false positives 68 | RST201 ; gave false positives 69 | W605 ; ignores regex relevant escape sequences 70 | PT001 ; ignores brackets for fixtures. 71 | per-file-ignores = 72 | docs/source/conf.py:E501, E800 73 | warn-symbols = 74 | pytest.mark.wip = Remove 'wip' mark for tests. 75 | 76 | [pytest] 77 | addopts = --doctest-modules 78 | markers = 79 | wip: Tests that are work-in-progress. 80 | slow: Tests that take a long time to run and are skipped in continuous integration. 81 | norecursedirs = 82 | docs 83 | .tox 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | src/pybaum/_version.py 133 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - '*' 9 | 10 | # Automatically cancel a previous run. 11 | concurrency: 12 | group: ${{ github.head_ref || github.run_id }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | 17 | run-tests: 18 | 19 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} 20 | runs-on: ${{ matrix.os }} 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | os: ['ubuntu-latest', 'macos-latest'] 26 | python-version: ['3.7', '3.8', '3.9', '3.10'] 27 | 28 | steps: 29 | - uses: actions/checkout@v2 30 | - uses: conda-incubator/setup-miniconda@v2 31 | with: 32 | auto-update-conda: true 33 | python-version: ${{ matrix.python-version }} 34 | 35 | - name: Install core dependencies. 36 | shell: bash -l {0} 37 | run: conda install -c conda-forge tox-conda 38 | 39 | - name: Run pytest. 40 | shell: bash -l {0} 41 | run: tox -e pytest -- -m "not slow" --cov-report=xml --cov=./ 42 | 43 | - name: Upload coverage report. 
44 | if: runner.os == 'Linux' && matrix.python-version == '3.9' 45 | uses: codecov/codecov-action@v1 46 | with: 47 | token: ${{ secrets.CODECOV_TOKEN }} 48 | 49 | run-tests-windows: 50 | 51 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} 52 | runs-on: ${{ matrix.os }} 53 | 54 | strategy: 55 | fail-fast: false 56 | matrix: 57 | os: ['windows-latest'] 58 | python-version: ['3.7', '3.8', '3.9', '3.10'] 59 | 60 | steps: 61 | - uses: actions/checkout@v2 62 | - uses: conda-incubator/setup-miniconda@v2 63 | with: 64 | auto-update-conda: true 65 | python-version: ${{ matrix.python-version }} 66 | 67 | - name: Install core dependencies. 68 | shell: bash -l {0} 69 | run: conda install -c conda-forge tox-conda 70 | 71 | - name: Run pytest. 72 | shell: bash -l {0} 73 | run: tox -e pytest-windows -- -m "not slow" 74 | 75 | docs: 76 | 77 | name: Run documentation. 78 | runs-on: ubuntu-latest 79 | 80 | steps: 81 | - uses: actions/checkout@v2 82 | - uses: conda-incubator/setup-miniconda@v2 83 | with: 84 | auto-update-conda: true 85 | python-version: 3.9 86 | 87 | - name: Install core dependencies. 88 | shell: bash -l {0} 89 | run: conda install -c conda-forge tox-conda 90 | 91 | - name: Build docs 92 | shell: bash -l {0} 93 | run: tox -e sphinx 94 | -------------------------------------------------------------------------------- /src/pybaum/typecheck.py: -------------------------------------------------------------------------------- 1 | from pybaum.config import IS_JAX_INSTALLED 2 | from pybaum.config import IS_NUMPY_INSTALLED 3 | 4 | if IS_JAX_INSTALLED: 5 | import jax.numpy as jnp 6 | 7 | if IS_NUMPY_INSTALLED: 8 | import numpy as np 9 | 10 | 11 | def get_type(obj): 12 | """Get type of candidate objects in a pytree. 13 | 14 | This function allows us to reliably identify namedtuples, NamedTuples and jax arrays 15 | for which standard ``type`` function does not work. 
16 | 17 | Args: 18 | obj: The object to be checked 19 | 20 | Returns: 21 | type or str: The type of the object or a string with the type name. 22 | 23 | """ 24 | if _is_namedtuple(obj): 25 | out = "namedtuple" 26 | elif _is_jax_array(obj): 27 | out = "jax.numpy.ndarray" 28 | else: 29 | out = type(obj) 30 | return out 31 | 32 | 33 | def _is_namedtuple(obj): 34 | """Check if an object is a namedtuple. 35 | 36 | As in JAX, we treat collections.namedtuple and typing.NamedTuple both as 37 | namedtuple, but the exact type is preserved in the unflatten function. 38 | 39 | namedtuples are discovered by being instances of tuple and having a 40 | ``_fields`` attribute as suggested by Raymond Hettinger 41 | `here `_. 42 | 43 | Moreover, we check for the presence of a ``_replace`` method because we need it 44 | when unflattening pytrees. 45 | 46 | This can produce false positives, but in most cases would still result in the 47 | desired behavior. 48 | 49 | Args: 50 | obj: The object to be checked 51 | 52 | Returns: 53 | bool 54 | 55 | """ 56 | out = ( 57 | isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace") 58 | ) 59 | return out 60 | 61 | 62 | def _is_jax_array(obj): 63 | """Check if an object is a jax array. 64 | 65 | The exact type of jax arrays has changed over time and is an implementation detail. 66 | 67 | Instead we rely on isinstance checks which will likely be more stable in the future. 68 | However, the behavior of isinstance for jax arrays has also changed over time. For 69 | jax versions before 0.2.21, standard numpy arrays were instances of jax arrays; 70 | now they are not.
71 | 72 | Resources: 73 | ---------- 74 | 75 | - https://github.com/google/jax/issues/2115 76 | - https://github.com/google/jax/issues/2014 77 | - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021 78 | - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022 79 | 80 | Args: 81 | obj: The object to be checked 82 | 83 | Returns: 84 | bool 85 | 86 | """ 87 | if not IS_JAX_INSTALLED: 88 | out = False 89 | elif IS_NUMPY_INSTALLED: 90 | out = isinstance(obj, jnp.ndarray) and not isinstance(obj, np.ndarray) 91 | else: 92 | out = isinstance(obj, jnp.ndarray) 93 | return out 94 | -------------------------------------------------------------------------------- /docs/source/_static/images/light-bulb.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 30 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | pybaum 2 | ====== 3 | 4 | .. start-badges 5 | 6 | .. image:: https://img.shields.io/pypi/v/pybaum?color=blue 7 | :alt: PyPI 8 | :target: https://pypi.org/project/pybaum 9 | 10 | .. image:: https://img.shields.io/pypi/pyversions/pybaum 11 | :alt: PyPI - Python Version 12 | :target: https://pypi.org/project/pybaum 13 | 14 | .. image:: https://img.shields.io/conda/vn/conda-forge/pybaum.svg 15 | :target: https://anaconda.org/conda-forge/pybaum 16 | 17 | .. image:: https://img.shields.io/conda/pn/conda-forge/pybaum.svg 18 | :target: https://anaconda.org/conda-forge/pybaum 19 | 20 | .. image:: https://img.shields.io/pypi/l/pybaum 21 | :alt: PyPI - License 22 | :target: https://pypi.org/project/pybaum 23 | 24 | .. image:: https://readthedocs.org/projects/pybaum/badge/?version=latest 25 | :target: https://pybaum.readthedocs.io/en/latest 26 | 27 | .. 
image:: https://img.shields.io/github/actions/workflow/status/OpenSourceEconomics/pybaum/main.yml?branch=main 28 | :target: https://github.com/OpenSourceEconomics/pybaum/actions?query=branch%3Amain 29 | 30 | .. image:: https://codecov.io/gh/OpenSourceEconomics/pybaum/branch/main/graph/badge.svg 31 | :target: https://codecov.io/gh/OpenSourceEconomics/pybaum 32 | 33 | .. image:: https://results.pre-commit.ci/badge/github/OpenSourceEconomics/pybaum/main.svg 34 | :target: https://results.pre-commit.ci/latest/github/OpenSourceEconomics/pybaum/main 35 | :alt: pre-commit.ci status 36 | 37 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg 38 | :target: https://github.com/psf/black 39 | 40 | .. end-badges 41 | 42 | Installation 43 | ------------ 44 | 45 | pybaum is available on `PyPI `_ and `Anaconda.org 46 | `_. Install it with 47 | 48 | .. code-block:: console 49 | 50 | $ pip install pybaum 51 | 52 | # or 53 | 54 | $ conda install -c conda-forge pybaum 55 | 56 | 57 | About 58 | ----- 59 | 60 | pybaum provides tools to work with pytrees which is a concept borrowed from `JAX 61 | `_. 62 | 63 | What is a pytree? 64 | 65 | In pybaum, we use the term pytree to refer to a tree-like structure built out of 66 | container-like Python objects. Classes are considered container-like if they are in the 67 | pytree registry, which by default includes lists, tuples, and dicts. That is: 68 | 69 | 1. Any object whose type is not in the pytree container registry is considered a leaf 70 | pytree. 71 | 72 | 2. Any object whose type is in the pytree container registry, and which contains 73 | pytrees, is considered a pytree. 74 | 75 | For each entry in the pytree container registry, a container-like type is registered 76 | with a pair of functions that specify how to convert an instance of the container type 77 | to a (children, metadata) pair and how to convert such a pair back to an instance of the 78 | container type. 
Using these functions, pybaum can canonicalize any tree of registered 79 | container objects into tuples. 80 | 81 | Example pytrees: 82 | 83 | .. code-block:: python 84 | 85 | [1, "a", object()] # 3 leaves 86 | 87 | (1, (2, 3), ()) # 3 leaves 88 | 89 | [1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves 90 | 91 | pybaum can be extended to consider other container types as pytrees. 92 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: debug-statements 7 | - id: end-of-file-fixer 8 | - repo: https://github.com/asottile/reorder_python_imports 9 | rev: v3.9.0 10 | hooks: 11 | - id: reorder-python-imports 12 | types: [python] 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v4.4.0 15 | hooks: 16 | - id: check-added-large-files 17 | args: ['--maxkb=100'] 18 | - id: check-case-conflict 19 | - id: check-merge-conflict 20 | - id: check-vcs-permalinks 21 | - id: check-yaml 22 | - id: debug-statements 23 | - id: end-of-file-fixer 24 | - id: fix-byte-order-marker 25 | - id: mixed-line-ending 26 | - id: no-commit-to-branch 27 | args: [--branch, main] 28 | - id: trailing-whitespace 29 | - repo: https://github.com/pre-commit/pygrep-hooks 30 | rev: v1.9.0 31 | hooks: 32 | - id: python-check-blanket-noqa 33 | - id: python-check-mock-methods 34 | - id: python-no-eval 35 | - id: python-no-log-warn 36 | - id: python-use-type-annotations 37 | - id: rst-backticks 38 | - id: rst-directive-colons 39 | - id: rst-inline-touching-normal 40 | - id: text-unicode-replacement-char 41 | - repo: https://github.com/asottile/blacken-docs 42 | rev: v1.12.1 43 | hooks: 44 | - id: blacken-docs 45 | additional_dependencies: [black==22.3.0] 46 | types: [rst] 47 | - repo: https://github.com/psf/black 48 | rev: 22.12.0 49 
| hooks: 50 | - id: black 51 | language_version: python3.10 52 | - repo: https://github.com/PyCQA/flake8 53 | rev: 5.0.4 54 | hooks: 55 | - id: flake8 56 | types: [python] 57 | additional_dependencies: [ 58 | flake8-alfred, 59 | flake8-bugbear, 60 | flake8-builtins, 61 | flake8-comprehensions, 62 | flake8-docstrings, 63 | flake8-eradicate, 64 | flake8-print, 65 | flake8-pytest-style, 66 | flake8-todo, 67 | flake8-typing-imports, 68 | flake8-unused-arguments, 69 | pep8-naming, 70 | pydocstyle, 71 | Pygments, 72 | ] 73 | - repo: https://github.com/PyCQA/doc8 74 | rev: v1.0.0 75 | hooks: 76 | - id: doc8 77 | - repo: meta 78 | hooks: 79 | - id: check-hooks-apply 80 | - id: check-useless-excludes 81 | # - id: identity # Prints all files passed to pre-commits. Debugging. 82 | - repo: https://github.com/mgedmin/check-manifest 83 | rev: "0.49" 84 | hooks: 85 | - id: check-manifest 86 | - repo: https://github.com/PyCQA/doc8 87 | rev: v1.0.0 88 | hooks: 89 | - id: doc8 90 | - repo: https://github.com/asottile/setup-cfg-fmt 91 | rev: v2.2.0 92 | hooks: 93 | - id: setup-cfg-fmt 94 | - repo: https://github.com/econchick/interrogate 95 | rev: 1.5.0 96 | hooks: 97 | - id: interrogate 98 | args: [-v, --fail-under=20] 99 | exclude: ^(tests|docs|setup\.py) 100 | - repo: https://github.com/codespell-project/codespell 101 | rev: v2.2.2 102 | hooks: 103 | - id: codespell 104 | - repo: https://github.com/asottile/pyupgrade 105 | rev: v3.3.1 106 | hooks: 107 | - id: pyupgrade 108 | args: [--py37-plus] 109 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Remove execution count for notebook cells. 
*/ 2 | div.prompt { 3 | display: none; 4 | } 5 | 6 | /* Getting started index page */ 7 | 8 | .intro-card { 9 | background: #fff; 10 | border-radius: 0; 11 | padding: 30px 10px 10px 10px; 12 | margin: 10px 0px; 13 | max-height: 85%; 14 | } 15 | 16 | .intro-card .card-text { 17 | margin: 20px 0px; 18 | } 19 | 20 | div#index-container { 21 | padding-bottom: 20px; 22 | } 23 | 24 | a#index-link { 25 | color: #333; 26 | text-decoration: none; 27 | } 28 | 29 | /* reference to user guide */ 30 | .gs-torefguide { 31 | align-items: center; 32 | font-size: 0.9rem; 33 | } 34 | 35 | .gs-torefguide .badge { 36 | background-color: #130654; 37 | margin: 10px 10px 10px 0px; 38 | padding: 5px; 39 | } 40 | 41 | .gs-torefguide a { 42 | margin-left: 5px; 43 | color: #130654; 44 | border-bottom: 1px solid #FFCA00f3; 45 | box-shadow: 0px -10px 0px #FFCA00f3 inset; 46 | } 47 | 48 | .gs-torefguide p { 49 | margin-top: 1rem; 50 | } 51 | 52 | .gs-torefguide a:hover { 53 | margin-left: 5px; 54 | color: grey; 55 | text-decoration: none; 56 | border-bottom: 1px solid #b2ff80f3; 57 | box-shadow: 0px -10px 0px #b2ff80f3 inset; 58 | } 59 | 60 | /* selecting constraints guide */ 61 | .intro-card { 62 | background:#FFF; 63 | border-radius:0; 64 | padding: 30px 10px 10px 10px; 65 | margin: 10px 0px; 66 | } 67 | 68 | .intro-card .card-text { 69 | margin:20px 0px; 70 | /*min-height: 150px; */ 71 | } 72 | 73 | .intro-card .card-img-top { 74 | margin: 10px; 75 | } 76 | 77 | .install-block { 78 | padding-bottom: 30px; 79 | } 80 | 81 | .install-card .card-header { 82 | border: none; 83 | background-color:white; 84 | color: #150458; 85 | font-size: 1.1rem; 86 | font-weight: bold; 87 | padding: 1rem 1rem 0rem 1rem; 88 | } 89 | 90 | .install-card .card-footer { 91 | border: none; 92 | background-color:white; 93 | } 94 | 95 | .install-card pre { 96 | margin: 0 1em 1em 1em; 97 | } 98 | 99 | .custom-button { 100 | background-color:#DCDCDC; 101 | border: none; 102 | color: #484848; 103 | text-align: center; 104 
| text-decoration: none; 105 | display: inline-block; 106 | font-size: 0.9rem; 107 | border-radius: 0.5rem; 108 | max-width: 120px; 109 | padding: 0.5rem 0rem; 110 | } 111 | 112 | .custom-button a { 113 | color: #484848; 114 | } 115 | 116 | .custom-button p { 117 | margin-top: 0; 118 | margin-bottom: 0rem; 119 | color: #484848; 120 | } 121 | 122 | /* selecting constraints guide collapsed cards */ 123 | 124 | .tutorial-accordion { 125 | margin-top: 20px; 126 | margin-bottom: 20px; 127 | } 128 | 129 | .tutorial-card .card-header.card-link .btn { 130 | margin-right: 12px; 131 | } 132 | 133 | .tutorial-card .card-header.card-link .btn:after { 134 | content: "-"; 135 | } 136 | 137 | .tutorial-card .card-header.card-link.collapsed .btn:after { 138 | content: "+"; 139 | } 140 | 141 | .tutorial-card-header-1 { 142 | justify-content: space-between; 143 | align-items: center; 144 | } 145 | 146 | .tutorial-card-header-2 { 147 | justify-content: flex-start; 148 | align-items: center; 149 | font-size: 1.3rem; 150 | } 151 | 152 | .tutorial-card .card-header { 153 | cursor: pointer; 154 | background-color: white; 155 | } 156 | 157 | .tutorial-card .card-body { 158 | background-color: #F0F0F0; 159 | } 160 | 161 | .tutorial-card .badge:hover { 162 | background-color: grey; 163 | } 164 | 165 | /* tables in selecting constraints guide */ 166 | 167 | table.rows th { 168 | background-color: #F0F0F0; 169 | border-style: solid solid solid solid; 170 | border-width: 0px 0px 0px 0px; 171 | border-color: #F0F0F0; 172 | text-align: center; 173 | } 174 | 175 | table.rows tr:nth-child(even) { 176 | background-color: #F0F0F0; 177 | text-align: right; 178 | } 179 | table.rows tr:nth-child(odd) { 180 | background-color: #FFFFFF; 181 | text-align: right; 182 | } 183 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib.metadata import 
version 3 | 4 | 5 | author = "Janos Gabler, Tobias Raabe" 6 | 7 | # Set variable so that todos are shown in local build 8 | on_rtd = os.environ.get("READTHEDOCS") == "True" 9 | 10 | 11 | # -- General configuration ------------------------------------------------ 12 | 13 | # If your documentation needs a minimal Sphinx version, state it here. 14 | 15 | # Add any Sphinx extension module names here, as strings. They can be 16 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 17 | # ones. 18 | extensions = [ 19 | "sphinx.ext.autodoc", 20 | "sphinx.ext.todo", 21 | "sphinx.ext.coverage", 22 | "sphinx.ext.extlinks", 23 | "sphinx.ext.intersphinx", 24 | "sphinx.ext.mathjax", 25 | "sphinx.ext.viewcode", 26 | "sphinx.ext.napoleon", 27 | "sphinx_panels", 28 | "autoapi.extension", 29 | ] 30 | 31 | autodoc_member_order = "bysource" 32 | 33 | autodoc_mock_imports = [ 34 | "pandas", 35 | "pytest", 36 | "numpy", 37 | "jax", 38 | ] 39 | 40 | extlinks = { 41 | "ghuser": ("https://github.com/%s", "@"), 42 | "gh": ("https://github.com/OpenSourceEconomics/pybaum/pulls/%s", "#"), 43 | } 44 | 45 | intersphinx_mapping = { 46 | "numpy": ("https://numpy.org/doc/stable", None), 47 | "np": ("https://numpy.org/doc/stable", None), 48 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 49 | "pd": ("https://pandas.pydata.org/pandas-docs/stable", None), 50 | "python": ("https://docs.python.org/3.9", None), 51 | } 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ["_templates"] 55 | html_static_path = ["_static"] 56 | 57 | # The suffix(es) of source filenames. 58 | # You can specify multiple suffix as a list of string: 59 | source_suffix = ".rst" 60 | 61 | # The master toctree document. 62 | master_doc = "index" 63 | 64 | # General information about the project. 
65 | project = "pybaum" 66 | copyright = f"2022, {author}" # noqa: A001 67 | 68 | # The version info for the project you're documenting, acts as replacement for 69 | # |version| and |release|, also used in various other places throughout the 70 | # built documents. 71 | 72 | # The version, including alpha/beta/rc tags, but not commit hash and datestamps 73 | release = version("pybaum") 74 | # The short X.Y version. 75 | version = ".".join(release.split(".")[:2]) 76 | 77 | # The language for content autogenerated by Sphinx. Refer to documentation 78 | # for a list of supported languages. 79 | 80 | # This is also used if you do content translation via gettext catalogs. 81 | # Usually you set "language" from the command line for these cases. 82 | language = None 83 | 84 | # List of patterns, relative to source directory, that match files and 85 | # directories to ignore when looking for source files. 86 | # This patterns also effect to html_static_path and html_extra_path 87 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] 88 | 89 | # The name of the Pygments (syntax highlighting) style to use. 90 | pygments_style = "sphinx" 91 | 92 | # If true, `todo` and `todoList` produce output, else they produce nothing. 93 | if on_rtd: 94 | pass 95 | else: 96 | todo_include_todos = True 97 | todo_emit_warnings = True 98 | 99 | # Remove prefixed $ for bash, >>> for Python prompts, and In [1]: for IPython prompts. 100 | copybutton_prompt_text = r"\$ |>>> |In \[\d\]: " 101 | copybutton_prompt_is_regexp = True 102 | 103 | # Configuration for autoapi 104 | autoapi_type = "python" 105 | autoapi_dirs = ["../../src"] 106 | autoapi_keep_files = False 107 | autoapi_add_toctree_entry = False 108 | 109 | 110 | # -- Options for HTML output ---------------------------------------------- 111 | 112 | # The theme to use for HTML and HTML Help pages. See the documentation for 113 | # a list of builtin themes. 
114 | html_theme = "pydata_sphinx_theme" 115 | 116 | # html_logo = "_static/images/logo.svg" 117 | 118 | html_theme_options = { 119 | "github_url": "https://github.com/OpenSourceEconomics/pybaum", 120 | } 121 | 122 | html_css_files = ["css/custom.css"] 123 | 124 | 125 | # Add any paths that contain custom static files (such as style sheets) here, 126 | # relative to this directory. They are copied after the builtin static files, 127 | # so a file named "default.css" will overwrite the builtin "default.css". 128 | # html_static_path = ["_static"] # noqa: E800 129 | 130 | # Custom sidebar templates, must be a dictionary that maps document names 131 | # to template names. 132 | 133 | # This is required for the alabaster theme 134 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 135 | html_sidebars = { 136 | "**": [ 137 | "relations.html", # needs 'show_related': True theme option to display 138 | "searchbox.html", 139 | ] 140 | } 141 | -------------------------------------------------------------------------------- /src/pybaum/registry_entries.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import OrderedDict 3 | from itertools import product 4 | 5 | from pybaum.config import IS_JAX_INSTALLED 6 | from pybaum.config import IS_NUMPY_INSTALLED 7 | from pybaum.config import IS_PANDAS_INSTALLED 8 | 9 | if IS_NUMPY_INSTALLED: 10 | import numpy as np 11 | 12 | if IS_PANDAS_INSTALLED: 13 | import pandas as pd 14 | 15 | if IS_JAX_INSTALLED: 16 | import jax 17 | 18 | 19 | def _none(): 20 | """Create registry entry for NoneType.""" 21 | entry = { 22 | type(None): { 23 | "flatten": lambda tree: ([], None), # noqa: U100 24 | "unflatten": lambda aux_data, children: None, # noqa: U100 25 | "names": lambda tree: [], # noqa: U100 26 | } 27 | } 28 | return entry 29 | 30 | 31 | def _list(): 32 | """Create registry entry for list.""" 33 | entry = { 34 | list: { 35 | "flatten": 
lambda tree: (tree, None), 36 | "unflatten": lambda aux_data, children: children, # noqa: U100 37 | "names": lambda tree: [f"{i}" for i in range(len(tree))], 38 | }, 39 | } 40 | return entry 41 | 42 | 43 | def _dict(): 44 | """Create registry entry for dict.""" 45 | entry = { 46 | dict: { 47 | "flatten": lambda tree: (list(tree.values()), list(tree)), 48 | "unflatten": lambda aux_data, children: dict(zip(aux_data, children)), 49 | "names": lambda tree: list(map(str, list(tree))), 50 | }, 51 | } 52 | return entry 53 | 54 | 55 | def _tuple(): 56 | """Create registry entry for tuple.""" 57 | entry = { 58 | tuple: { 59 | "flatten": lambda tree: (list(tree), None), 60 | "unflatten": lambda aux_data, children: tuple(children), # noqa: U100 61 | "names": lambda tree: [f"{i}" for i in range(len(tree))], 62 | }, 63 | } 64 | return entry 65 | 66 | 67 | def _namedtuple(): 68 | """Create registry entry for namedtuple and NamedTuple.""" 69 | entry = { 70 | "namedtuple": { 71 | "flatten": lambda tree: (list(tree), tree), 72 | "unflatten": _unflatten_namedtuple, 73 | "names": lambda tree: list(tree._fields), 74 | }, 75 | } 76 | return entry 77 | 78 | 79 | def _unflatten_namedtuple(aux_data, leaves): 80 | replacements = dict(zip(aux_data._fields, leaves)) 81 | out = aux_data._replace(**replacements) 82 | return out 83 | 84 | 85 | def _ordereddict(): 86 | """Create registry entry for OrderedDict.""" 87 | entry = { 88 | OrderedDict: { 89 | "flatten": lambda tree: (list(tree.values()), list(tree)), 90 | "unflatten": lambda aux_data, children: OrderedDict( 91 | zip(aux_data, children) 92 | ), 93 | "names": lambda tree: list(map(str, list(tree))), 94 | }, 95 | } 96 | return entry 97 | 98 | 99 | def _numpy_array(): 100 | """Create registry entry for numpy.ndarray.""" 101 | 102 | if IS_NUMPY_INSTALLED: 103 | entry = { 104 | np.ndarray: { 105 | "flatten": lambda arr: (arr.flatten().tolist(), arr.shape), 106 | "unflatten": lambda aux_data, leaves: np.array(leaves).reshape( 107 | aux_data 
108 | ), 109 | "names": _array_element_names, 110 | }, 111 | } 112 | else: 113 | entry = {} 114 | return entry 115 | 116 | 117 | def _array_element_names(arr): 118 | dim_names = [map(str, range(n)) for n in arr.shape] 119 | names = list(map("_".join, itertools.product(*dim_names))) 120 | return names 121 | 122 | 123 | def _jax_array(): 124 | if IS_JAX_INSTALLED: 125 | entry = { 126 | "jax.numpy.ndarray": { 127 | "flatten": lambda arr: (arr.flatten().tolist(), arr.shape), 128 | "unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape( 129 | aux_data 130 | ), 131 | "names": _array_element_names, 132 | }, 133 | } 134 | else: 135 | entry = {} 136 | return entry 137 | 138 | 139 | def _pandas_series(): 140 | """Create registry entry for pandas.Series.""" 141 | if IS_PANDAS_INSTALLED: 142 | entry = { 143 | pd.Series: { 144 | "flatten": lambda sr: ( 145 | sr.tolist(), 146 | {"index": sr.index, "name": sr.name}, 147 | ), 148 | "unflatten": lambda aux_data, leaves: pd.Series(leaves, **aux_data), 149 | "names": lambda sr: list(sr.index.map(_index_element_to_string)), 150 | }, 151 | } 152 | else: 153 | entry = {} 154 | return entry 155 | 156 | 157 | def _pandas_dataframe(): 158 | """Create registry entry for pandas.DataFrame.""" 159 | if IS_PANDAS_INSTALLED: 160 | entry = { 161 | pd.DataFrame: { 162 | "flatten": _flatten_pandas_dataframe, 163 | "unflatten": _unflatten_pandas_dataframe, 164 | "names": _get_names_pandas_dataframe, 165 | } 166 | } 167 | else: 168 | entry = {} 169 | return entry 170 | 171 | 172 | def _flatten_pandas_dataframe(df): 173 | flat = df.to_numpy().flatten().tolist() 174 | aux_data = {"columns": df.columns, "index": df.index, "shape": df.shape} 175 | return flat, aux_data 176 | 177 | 178 | def _unflatten_pandas_dataframe(aux_data, leaves): 179 | out = pd.DataFrame( 180 | data=np.array(leaves).reshape(aux_data["shape"]), 181 | columns=aux_data["columns"], 182 | index=aux_data["index"], 183 | ) 184 | return out 185 | 186 | 187 | def 
_get_names_pandas_dataframe(df): 188 | index_strings = list(df.index.map(_index_element_to_string)) 189 | out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] 190 | return out 191 | 192 | 193 | def _index_element_to_string(element): 194 | if isinstance(element, (tuple, list)): 195 | as_strings = [str(entry) for entry in element] 196 | res_string = "_".join(as_strings) 197 | else: 198 | res_string = str(element) 199 | 200 | return res_string 201 | 202 | 203 | FUNC_DICT = { 204 | "list": _list, 205 | "tuple": _tuple, 206 | "dict": _dict, 207 | "numpy.ndarray": _numpy_array, 208 | "jax.numpy.ndarray": _jax_array, 209 | "pandas.Series": _pandas_series, 210 | "pandas.DataFrame": _pandas_dataframe, 211 | "None": _none, 212 | "namedtuple": _namedtuple, 213 | "OrderedDict": _ordereddict, 214 | } 215 | -------------------------------------------------------------------------------- /tests/test_tree_util.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections import namedtuple 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | from numpy.testing import assert_array_almost_equal as aaae 9 | from pybaum.registry import get_registry 10 | from pybaum.tree_util import leaf_names 11 | from pybaum.tree_util import tree_equal 12 | from pybaum.tree_util import tree_flatten 13 | from pybaum.tree_util import tree_map 14 | from pybaum.tree_util import tree_multimap 15 | from pybaum.tree_util import tree_unflatten 16 | from pybaum.tree_util import tree_update 17 | from pybaum.tree_util import tree_yield 18 | 19 | 20 | @pytest.fixture 21 | def example_tree(): 22 | return ( 23 | [0, np.array([1, 2]), {"a": pd.Series([3, 4], index=["c", "d"]), "b": 5}], 24 | 6, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def example_flat(): 30 | return [0, np.array([1, 2]), pd.Series([3, 4], index=["c", "d"]), 5, 6] 31 | 32 | 33 | @pytest.fixture 34 | 
def example_treedef(example_tree): 35 | return example_tree 36 | 37 | 38 | @pytest.fixture 39 | def extended_treedef(example_tree): 40 | return example_tree 41 | 42 | 43 | @pytest.fixture 44 | def extended_registry(): 45 | types = ["pandas.DataFrame", "pandas.Series", "numpy.ndarray"] 46 | return get_registry(types=types) 47 | 48 | 49 | def test_tree_flatten(example_tree, example_flat, example_treedef): 50 | flat, treedef = tree_flatten(example_tree) 51 | assert treedef == example_treedef 52 | _assert_list_with_arrays_is_equal(flat, example_flat) 53 | 54 | 55 | def test_extended_tree_flatten(example_tree, extended_treedef, extended_registry): 56 | flat, treedef = tree_flatten(example_tree, registry=extended_registry) 57 | assert flat == list(range(7)) 58 | assert tree_equal(treedef, extended_treedef) 59 | 60 | 61 | def test_tree_flatten_with_is_leave(example_tree, extended_registry): 62 | flat, _ = tree_flatten( 63 | example_tree, 64 | is_leaf=lambda tree: isinstance(tree, np.ndarray), 65 | registry=extended_registry, 66 | ) 67 | expected_flat = [0, np.array([1, 2]), 3, 4, 5, 6] 68 | _assert_list_with_arrays_is_equal(flat, expected_flat) 69 | 70 | 71 | def test_tree_unflatten(example_flat, example_treedef, example_tree): 72 | unflat = tree_unflatten(example_treedef, example_flat) 73 | 74 | assert tree_equal(unflat, example_tree) 75 | 76 | 77 | def test_extended_tree_unflatten(example_tree, extended_treedef, extended_registry): 78 | unflat = tree_unflatten( 79 | extended_treedef, list(range(7)), registry=extended_registry 80 | ) 81 | assert tree_equal(unflat, example_tree) 82 | 83 | 84 | def test_tree_unflatten_with_is_leaf(example_tree, extended_registry): 85 | unflat = tree_unflatten( 86 | example_tree, 87 | ([0, np.array([1, 2]), 3, 4, 5, 6]), 88 | is_leaf=lambda tree: isinstance(tree, np.ndarray), 89 | registry=extended_registry, 90 | ) 91 | assert tree_equal(unflat, example_tree) 92 | 93 | 94 | def test_tree_map(): 95 | tree = [{"a": 1, "b": 2, "c": {"d": 3, 
"e": 4}}] 96 | calculated = tree_map(lambda x: x * 2, tree) 97 | expected = [{"a": 2, "b": 4, "c": {"d": 6, "e": 8}}] 98 | assert calculated == expected 99 | 100 | 101 | def test_tree_multimap(): 102 | tree = [{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}] 103 | mapped = tree_map(lambda x: x**2, tree) 104 | multimapped = tree_multimap(lambda x, y: x * y, tree, tree) 105 | assert mapped == multimapped 106 | 107 | 108 | def test_leaf_names(example_tree): 109 | names = leaf_names(example_tree, separator="*") 110 | 111 | expected_names = ["0*0", "0*1", "0*2*a", "0*2*b", "1"] 112 | assert names == expected_names 113 | 114 | 115 | def test_extended_leaf_names(example_tree, extended_registry): 116 | names = leaf_names(example_tree, registry=extended_registry) 117 | expected_names = ["0_0", "0_1_0", "0_1_1", "0_2_a_c", "0_2_a_d", "0_2_b", "1"] 118 | assert names == expected_names 119 | 120 | 121 | def test_leaf_names_with_is_leaf(example_tree, extended_registry): 122 | names = leaf_names( 123 | example_tree, 124 | is_leaf=lambda tree: isinstance(tree, np.ndarray), 125 | registry=extended_registry, 126 | ) 127 | expected_names = ["0_0", "0_1", "0_2_a_c", "0_2_a_d", "0_2_b", "1"] 128 | assert names == expected_names 129 | 130 | 131 | def test_iterative_flatten_and_one_step_flatten_and_unflatten( 132 | example_tree, extended_registry 133 | ): 134 | first_step_flat, first_step_treedef = tree_flatten(example_tree) 135 | second_step_flat, second_step_treedef = tree_flatten( 136 | first_step_flat, registry=extended_registry 137 | ) 138 | one_step_flat, one_step_treedef = tree_flatten( 139 | example_tree, registry=extended_registry 140 | ) 141 | 142 | assert second_step_flat == one_step_flat 143 | 144 | one_step_unflat = tree_unflatten( 145 | one_step_treedef, one_step_flat, registry=extended_registry 146 | ) 147 | first_step_unflat = tree_unflatten( 148 | second_step_treedef, second_step_flat, registry=extended_registry 149 | ) 150 | second_step_unflat = 
tree_unflatten(first_step_treedef, first_step_unflat) 151 | 152 | assert tree_equal(one_step_unflat, example_tree) 153 | assert tree_equal(second_step_unflat, example_tree) 154 | 155 | 156 | def test_tree_update(example_tree): 157 | other = ([7, np.array([8, 9]), {"b": 10}], 11) 158 | updated = tree_update(example_tree, other) 159 | 160 | expected = ( 161 | [7, np.array([8, 9]), {"a": pd.Series([3, 4], index=["c", "d"]), "b": 10}], 162 | 11, 163 | ) 164 | assert tree_equal(updated, expected) 165 | 166 | 167 | def _assert_list_with_arrays_is_equal(list1, list2): 168 | for first, second in zip(list1, list2): 169 | if isinstance(first, np.ndarray): 170 | aaae(first, second) 171 | elif isinstance(first, (pd.DataFrame, pd.Series)): 172 | assert first.equals(second) 173 | else: 174 | assert first == second 175 | 176 | 177 | def test_flatten_df_all_columns(): 178 | registry = get_registry(types=["pandas.DataFrame"]) 179 | df = pd.DataFrame(index=["a", "b", "c"]) 180 | df["value"] = [1, 2, 3] 181 | df["bla"] = [4, 5, 6] 182 | 183 | flat, _ = tree_flatten(df, registry=registry) 184 | 185 | assert flat == [1, 4, 2, 5, 3, 6] 186 | 187 | 188 | def test_tree_yield(example_tree, example_treedef, example_flat): 189 | generator, treedef = tree_yield(example_tree) 190 | 191 | assert tree_equal(treedef, example_treedef) 192 | assert inspect.isgenerator(generator) 193 | for a, b in zip(generator, example_flat): 194 | if isinstance(a, (np.ndarray, pd.Series)): 195 | aaae(a, b) 196 | else: 197 | assert a == b 198 | 199 | 200 | def test_flatten_with_none(): 201 | flat, treedef = tree_flatten(None) 202 | assert flat == [] 203 | assert treedef is None 204 | 205 | 206 | def test_leaf_names_with_none(): 207 | names = leaf_names(None) 208 | assert names == [] 209 | 210 | 211 | def test_flatten_with_namedtuple(): 212 | bla = namedtuple("bla", ["a", "b"])(1, 2) 213 | flat, _ = tree_flatten(bla) 214 | assert flat == [1, 2] 215 | 216 | 217 | def test_names_with_namedtuple(): 218 | bla = 
namedtuple("bla", ["a", "b"])(1, 2) 219 | names = leaf_names(bla) 220 | assert names == ["a", "b"] 221 | 222 | 223 | def test_flatten_with_ordered_dict(): 224 | d = OrderedDict({"a": 1, "b": 2}) 225 | flat, _ = tree_flatten(d) 226 | assert flat == [1, 2] 227 | 228 | 229 | def test_names_with_ordered_dict(): 230 | d = OrderedDict({"a": 1, "b": 2}) 231 | names = leaf_names(d) 232 | assert names == ["a", "b"] 233 | -------------------------------------------------------------------------------- /src/pybaum/tree_util.py: -------------------------------------------------------------------------------- 1 | """Implement functionality similar to jax.tree_util in pure Python. 2 | 3 | The functions are not completely identical to jax. The most notable differences are: 4 | 5 | - Instead of a global registry of pytree nodes, most functions have a registry argument. 6 | - The treedef containing information to unflatten pytrees is implemented differently. 7 | 8 | """ 9 | from pybaum.equality import EQUALITY_CHECKERS 10 | from pybaum.registry import get_registry 11 | from pybaum.typecheck import get_type 12 | 13 | 14 | def tree_flatten(tree, is_leaf=None, registry=None): 15 | """Flatten a pytree and create a treedef. 16 | 17 | Args: 18 | tree: a pytree to flatten. 19 | is_leaf (callable or None): An optionally specified function that will be called 20 | at each flattening step. It should return a boolean, which indicates whether 21 | the flattening should traverse the current object, or if it should be 22 | stopped immediately, with the whole subtree being treated as a leaf. 23 | registry (dict or None): A pytree container registry that determines 24 | which types are considered container objects that should be flattened. 25 | ``is_leaf`` can override this in the sense that types that are in the 26 | registry are still considered a leaf but it cannot declare something a 27 | container that is not in the registry. None means that the default registry 28 | is used, i.e. 
that dicts, tuples and lists are considered containers. 29 | "extended" means that in addition numpy arrays and params DataFrames are 30 | considered containers. Passing a dictionary where the keys are types and the 31 | values are dicts with the entries "flatten", "unflatten" and "names" allows 32 | to completely override the default registries. 33 | 34 | Returns: 35 | A pair where the first element is a list of leaf values and the second 36 | element is a treedef representing the structure of the flattened tree. 37 | 38 | """ 39 | registry = _process_pytree_registry(registry) 40 | is_leaf = _process_is_leaf(is_leaf) 41 | 42 | flat = _tree_flatten(tree, is_leaf=is_leaf, registry=registry) 43 | # unflatten the flat tree to make a copy 44 | treedef = tree_unflatten(tree, flat, is_leaf=is_leaf, registry=registry) 45 | return flat, treedef 46 | 47 | 48 | def tree_just_flatten(tree, is_leaf=None, registry=None): 49 | """Flatten a pytree without creating a treedef. 50 | 51 | Args: 52 | tree: a pytree to flatten. 53 | is_leaf (callable or None): An optionally specified function that will be called 54 | at each flattening step. It should return a boolean, which indicates whether 55 | the flattening should traverse the current object, or if it should be 56 | stopped immediately, with the whole subtree being treated as a leaf. 57 | registry (dict or None): A pytree container registry that determines 58 | which types are considered container objects that should be flattened. 59 | ``is_leaf`` can override this in the sense that types that are in the 60 | registry are still considered a leaf but it cannot declare something a 61 | container that is not in the registry. None means that the default registry 62 | is used, i.e. that dicts, tuples and lists are considered containers. 63 | "extended" means that in addition numpy arrays and params DataFrames are 64 | considered containers. 
Passing a dictionary where the keys are types and the 65 | values are dicts with the entries "flatten", "unflatten" and "names" allows 66 | to completely override the default registries. 67 | 68 | Returns: 69 | list: The leaf values of ``tree``, in the order produced by recursively 70 | applying the registry's flatten functions. 71 | 72 | """ 73 | registry = _process_pytree_registry(registry) 74 | is_leaf = _process_is_leaf(is_leaf) 75 | 76 | flat = _tree_flatten(tree, is_leaf=is_leaf, registry=registry) 77 | return flat 78 | 79 | 80 | def _tree_flatten(tree, is_leaf, registry): 81 | out = [] 82 | tree_type = get_type(tree) 83 | 84 | if tree_type not in registry or is_leaf(tree): 85 | out.append(tree) 86 | else: 87 | subtrees, _ = registry[tree_type]["flatten"](tree) 88 | for subtree in subtrees: 89 | if get_type(subtree) in registry: 90 | out += _tree_flatten(subtree, is_leaf, registry) 91 | else: 92 | out.append(subtree) 93 | return out 94 | 95 | 96 | def tree_yield(tree, is_leaf=None, registry=None): 97 | """Yield leaves from a pytree and create the tree definition. 98 | 99 | Args: 100 | tree: a pytree. 101 | is_leaf (callable or None): An optionally specified function that will be called 102 | at each yield step. It should return a boolean, which indicates whether 103 | the generator should traverse the current object, or if it should be 104 | stopped immediately, with the whole subtree being treated as a leaf. 105 | registry (dict or None): A pytree container registry that determines 106 | which types are considered container objects that should be yielded. 107 | ``is_leaf`` can override this in the sense that types that are in the 108 | registry are still considered a leaf but it cannot declare something a 109 | container that is not in the registry. None means that the default registry 110 | is used, i.e. that dicts, tuples and lists are considered containers.
111 | "extended" means that in addition numpy arrays and params DataFrames are 112 | considered containers. Passing a dictionary where the keys are types and the 113 | values are dicts with the entries "flatten", "unflatten" and "names" allows 114 | to completely override the default registries. 115 | 116 | Returns: 117 | A pair where the first element is a generator of leaf values and the second 118 | element is a treedef representing the structure of the flattened tree. 119 | 120 | """ 121 | registry = _process_pytree_registry(registry) 122 | is_leaf = _process_is_leaf(is_leaf) 123 | 124 | flat = _tree_yield(tree, is_leaf=is_leaf, registry=registry) 125 | return flat, tree 126 | 127 | 128 | def tree_just_yield(tree, is_leaf=None, registry=None): 129 | """Yield leafs from a pytree without creating a treedef. 130 | 131 | Args: 132 | tree: a pytree. 133 | is_leaf (callable or None): An optionally specified function that will be called 134 | at each yield step. It should return a boolean, which indicates whether 135 | the generator should traverse the current object, or if it should be 136 | stopped immediately, with the whole subtree being treated as a leaf. 137 | registry (dict or None): A pytree container registry that determines 138 | which types are considered container objects that should be yielded. 139 | ``is_leaf`` can override this in the sense that types that are in the 140 | registry are still considered a leaf but it cannot declare something a 141 | container that is not in the registry. None means that the default registry 142 | is used, i.e. that dicts, tuples and lists are considered containers. 143 | "extended" means that in addition numpy arrays and params DataFrames are 144 | considered containers. Passing a dictionary where the keys are types and the 145 | values are dicts with the entries "flatten", "unflatten" and "names" allows 146 | to completely override the default registries. 147 | 148 | Returns: 149 | A generator of leaf values. 
150 | 151 | """ 152 | registry = _process_pytree_registry(registry) 153 | is_leaf = _process_is_leaf(is_leaf) 154 | 155 | flat = _tree_yield(tree, is_leaf=is_leaf, registry=registry) 156 | return flat 157 | 158 | 159 | def _tree_yield(tree, is_leaf, registry): 160 | out = [] 161 | tree_type = get_type(tree) 162 | 163 | if tree_type not in registry or is_leaf(tree): 164 | yield tree 165 | else: 166 | subtrees, _ = registry[tree_type]["flatten"](tree) 167 | for subtree in subtrees: 168 | if get_type(subtree) in registry: 169 | yield from _tree_yield(subtree, is_leaf, registry) 170 | else: 171 | yield subtree 172 | return out 173 | 174 | 175 | def tree_unflatten(treedef, leaves, is_leaf=None, registry=None): 176 | """Reconstruct a pytree from the treedef and a list of leaves. 177 | 178 | The inverse of :func:`tree_flatten`. 179 | 180 | Args: 181 | treedef: the treedef to with information needed for reconstruction. 182 | leaves (list): the list of leaves to use for reconstruction. The list must match 183 | the leaves of the treedef. 184 | is_leaf (callable or None): An optionally specified function that will be called 185 | at each flattening step. It should return a boolean, which indicates whether 186 | the flattening should traverse the current object, or if it should be 187 | stopped immediately, with the whole subtree being treated as a leaf. 188 | registry (dict or None): A pytree container registry that determines 189 | which types are considered container objects that should be flattened. 190 | `is_leaf` can override this in the sense that types that are in the 191 | registry are still considered a leaf but it cannot declare something a 192 | container that is not in the registry. None means that the default registry 193 | is used, i.e. that dicts, tuples and lists are considered containers. 194 | "extended" means that in addition numpy arrays and params DataFrames are 195 | considered containers. 
Passing a dictionary where the keys are types and the 196 | values are dicts with the entries "flatten", "unflatten" and "names" allows 197 | to completely override the default registries. 198 | 199 | Returns: 200 | The reconstructed pytree, containing the ``leaves`` placed in the structure 201 | described by ``treedef``. 202 | 203 | """ 204 | registry = _process_pytree_registry(registry) 205 | is_leaf = _process_is_leaf(is_leaf) 206 | return _tree_unflatten(treedef, leaves, is_leaf=is_leaf, registry=registry) 207 | 208 | 209 | def _tree_unflatten(treedef, leaves, is_leaf, registry): 210 | leaves = iter(leaves) 211 | tree_type = get_type(treedef) 212 | 213 | if tree_type not in registry or is_leaf(treedef): 214 | return next(leaves) 215 | else: 216 | items, info = registry[tree_type]["flatten"](treedef) 217 | unflattened_items = [] 218 | for item in items: 219 | if get_type(item) in registry: 220 | unflattened_items.append( 221 | _tree_unflatten(item, leaves, is_leaf=is_leaf, registry=registry) 222 | ) 223 | else: 224 | unflattened_items.append(next(leaves)) 225 | return registry[tree_type]["unflatten"](info, unflattened_items) 226 | 227 | 228 | def tree_map(func, tree, is_leaf=None, registry=None): 229 | """Apply func to all leaves in tree. 230 | 231 | Args: 232 | func (callable): Function applied to each leaf in the tree. 233 | tree: A pytree. 234 | is_leaf (callable or None): An optionally specified function that will be called 235 | at each flattening step. It should return a boolean, which indicates whether 236 | the flattening should traverse the current object, or if it should be 237 | stopped immediately, with the whole subtree being treated as a leaf. 238 | registry (dict or None): A pytree container registry that determines 239 | which types are considered container objects that should be flattened. 
240 | `is_leaf` can override this in the sense that types that are in the 241 | registry are still considered a leaf but it cannot declare something a 242 | container that is not in the registry. None means that the default registry 243 | is used, i.e. that dicts, tuples and lists are considered containers. 244 | "extended" means that in addition numpy arrays and params DataFrames are 245 | considered containers. Passing a dictionary where the keys are types and the 246 | values are dicts with the entries "flatten", "unflatten" and "names" allows 247 | to completely override the default registries. 248 | Returns: 249 | modified copy of tree. 250 | 251 | """ 252 | flat, treedef = tree_flatten(tree, is_leaf=is_leaf, registry=registry) 253 | modified = [func(i) for i in flat] 254 | new_tree = tree_unflatten(treedef, modified, is_leaf=is_leaf, registry=registry) 255 | return new_tree 256 | 257 | 258 | def tree_multimap(func, *trees, is_leaf=None, registry=None): 259 | """Apply func to leaves of multiple pytrees. 260 | 261 | Args: 262 | func (callable): Function applied to each leaf corresponding leaves of 263 | multiple py trees. 264 | trees: An arbitrary number of pytrees. All trees need to have the same 265 | structure. 266 | is_leaf (callable or None): An optionally specified function that will be called 267 | at each flattening step. It should return a boolean, which indicates whether 268 | the flattening should traverse the current object, or if it should be 269 | stopped immediately, with the whole subtree being treated as a leaf. 270 | registry (dict or None): A pytree container registry that determines 271 | which types are considered container objects that should be flattened. 272 | `is_leaf` can override this in the sense that types that are in the 273 | registry are still considered a leaf but it cannot declare something a 274 | container that is not in the registry. None means that the default registry 275 | is used, i.e. 
that dicts, tuples and lists are considered containers. 276 | "extended" means that in addition numpy arrays and params DataFrames are 277 | considered containers. Passing a dictionary where the keys are types and the 278 | values are dicts with the entries "flatten", "unflatten" and "names" allows 279 | to completely override the default registries. 280 | Returns: 281 | tree with the same structure as the elements in trees. 282 | 283 | """ 284 | flat_trees, treedefs = [], [] 285 | for tree in trees: 286 | flat, treedef = tree_flatten(tree, is_leaf=is_leaf, registry=registry) 287 | flat_trees.append(flat) 288 | treedefs.append(treedef) 289 | 290 | for treedef in treedefs: 291 | if treedef != treedefs[0]: 292 | raise ValueError("All trees must have the same structure.") 293 | 294 | modified = [func(*item) for item in zip(*flat_trees)] 295 | 296 | new_trees = tree_unflatten( 297 | treedefs[0], modified, is_leaf=is_leaf, registry=registry 298 | ) 299 | return new_trees 300 | 301 | 302 | def leaf_names(tree, is_leaf=None, registry=None, separator="_"): 303 | """Construct names for leaves in a pytree. 304 | 305 | Args: 306 | tree: a pytree to flatten. 307 | is_leaf (callable or None): An optionally specified function that will be called 308 | at each flattening step. It should return a boolean, which indicates whether 309 | the flattening should traverse the current object, or if it should be 310 | stopped immediately, with the whole subtree being treated as a leaf. 311 | registry (dict or None): A pytree container registry that determines 312 | which types are considered container objects that should be flattened. 313 | `is_leaf` can override this in the sense that types that are in the 314 | registry are still considered a leaf but it cannot declare something a 315 | container that is not in the registry. None means that the default registry 316 | is used, i.e. that dicts, tuples and lists are considered containers. 
317 | "extended" means that in addition numpy arrays and params DataFrames are 318 | considered containers. Passing a dictionary where the keys are types and the 319 | values are dicts with the entries "flatten", "unflatten" and "names" allows 320 | to completely override the default registries. 321 | separator (str): String that separates the building blocks of the leaf name. 322 | Returns: 323 | list: List of strings with names for pytree leaves. 324 | 325 | """ 326 | registry = _process_pytree_registry(registry) 327 | is_leaf = _process_is_leaf(is_leaf) 328 | leaf_names = _leaf_names( 329 | tree, is_leaf=is_leaf, registry=registry, separator=separator 330 | ) 331 | return leaf_names 332 | 333 | 334 | def _leaf_names(tree, is_leaf, registry, separator, prefix=None): 335 | out = [] 336 | tree_type = get_type(tree) 337 | 338 | if tree_type not in registry or is_leaf(tree): 339 | out.append(prefix) 340 | else: 341 | subtrees, info = registry[tree_type]["flatten"](tree) 342 | names = registry[tree_type]["names"](tree) 343 | for name, subtree in zip(names, subtrees): 344 | if get_type(subtree) in registry: 345 | out += _leaf_names( 346 | subtree, 347 | is_leaf=is_leaf, 348 | registry=registry, 349 | separator=separator, 350 | prefix=_add_prefix(prefix, name, separator), 351 | ) 352 | else: 353 | out.append(_add_prefix(prefix, name, separator)) 354 | return out 355 | 356 | 357 | def _add_prefix(prefix, string, separator): 358 | if prefix not in (None, ""): 359 | out = separator.join([prefix, string]) 360 | else: 361 | out = string 362 | return out 363 | 364 | 365 | def _process_pytree_registry(registry): 366 | registry = registry if registry is not None else get_registry() 367 | return registry 368 | 369 | 370 | def _process_is_leaf(is_leaf): 371 | if is_leaf is None: 372 | return lambda tree: False # noqa: U100 373 | else: 374 | return is_leaf 375 | 376 | 377 | def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None): 378 | """Determine if two 
pytrees are equal. 379 | 380 | Two pytrees are considered equal if their leaves are equal and the names of their 381 | leaves are equal. While this definition of equality might not always make sense 382 | it makes sense in most cases and can be implemented relatively easily. 383 | 384 | Args: 385 | tree: A pytree. 386 | other: Another pytree. 387 | is_leaf (callable or None): An optionally specified function that will be called 388 | at each flattening step. It should return a boolean, which indicates whether 389 | the flattening should traverse the current object, or if it should be 390 | stopped immediately, with the whole subtree being treated as a leaf. 391 | registry (dict or None): A pytree container registry that determines 392 | which types are considered container objects that should be flattened. 393 | `is_leaf` can override this in the sense that types that are in the 394 | registry are still considered a leaf but it cannot declare something a 395 | container that is not in the registry. None means that the default registry 396 | is used, i.e. that dicts, tuples and lists are considered containers. 397 | "extended" means that in addition numpy arrays and params DataFrames are 398 | considered containers. Passing a dictionary where the keys are types and the 399 | values are dicts with the entries "flatten", "unflatten" and "names" allows 400 | to completely override the default registries. 401 | equality_checkers (dict, None): A dictionary where keys are types and values are 402 | functions which assess equality for the type of object. 
403 | 404 | Returns: 405 | bool 406 | 407 | """ 408 | equality_checkers = ( 409 | EQUALITY_CHECKERS 410 | if equality_checkers is None 411 | else {**EQUALITY_CHECKERS, **equality_checkers} 412 | ) 413 | 414 | first_flat = tree_just_flatten(tree, is_leaf=is_leaf, registry=registry) 415 | second_flat = tree_just_flatten(other, is_leaf=is_leaf, registry=registry) 416 | 417 | first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) 418 | second_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) 419 | 420 | equal = first_names == second_names 421 | 422 | if equal: 423 | for first, second in zip(first_flat, second_flat): 424 | check_func = equality_checkers.get(get_type(first), lambda a, b: a == b) 425 | equal = equal and check_func(first, second) 426 | if not equal: 427 | break 428 | 429 | return equal 430 | 431 | 432 | def tree_update(tree, other, is_leaf=None, registry=None): 433 | """Update leaves in a pytree with leaves from another pytree. 434 | 435 | The second pytree must be compatible with the first one but can be smaller. For 436 | example, lists can be shorter, dictionaries can contain subsets of entries, etc. 437 | 438 | Args: 439 | tree: A pytree. 440 | other: Another pytree. 441 | is_leaf (callable or None): An optionally specified function that will be called 442 | at each flattening step. It should return a boolean, which indicates whether 443 | the flattening should traverse the current object, or if it should be 444 | stopped immediately, with the whole subtree being treated as a leaf. 445 | registry (dict or None): A pytree container registry that determines 446 | which types are considered container objects that should be flattened. 447 | `is_leaf` can override this in the sense that types that are in the 448 | registry are still considered a leaf but it cannot declare something a 449 | container that is not in the registry. None means that the default registry 450 | is used, i.e. 
that dicts, tuples and lists are considered containers. 451 | "extended" means that in addition numpy arrays and params DataFrames are 452 | considered containers. Passing a dictionary where the keys are types and the 453 | values are dicts with the entries "flatten", "unflatten" and "names" allows 454 | to completely override the default registries. 455 | Returns: 456 | Updated pytree. 457 | 458 | """ 459 | first_flat, first_treedef = tree_flatten(tree, is_leaf=is_leaf, registry=registry) 460 | first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry) 461 | first_dict = dict(zip(first_names, first_flat)) 462 | 463 | other_flat, _ = tree_flatten(other, is_leaf=is_leaf, registry=registry) 464 | other_names = leaf_names(other, is_leaf=is_leaf, registry=registry) 465 | other_dict = dict(zip(other_names, other_flat)) 466 | 467 | combined = list({**first_dict, **other_dict}.values()) 468 | 469 | out = tree_unflatten(first_treedef, combined, is_leaf=is_leaf, registry=registry) 470 | return out 471 | --------------------------------------------------------------------------------