├── tests
├── __init__.py
├── test_typecheck.py
├── test_with_jax.py
└── test_tree_util.py
├── setup.py
├── docs
├── source
│ ├── index.rst
│ ├── api.rst
│ ├── _static
│ │ ├── images
│ │ │ ├── coding.svg
│ │ │ ├── books.svg
│ │ │ ├── book.svg
│ │ │ └── light-bulb.svg
│ │ └── css
│ │ │ └── custom.css
│ └── conf.py
├── rtd_environment.yml
├── Makefile
└── make.bat
├── .readthedocs.yml
├── MANIFEST.in
├── codecov.yml
├── pyproject.toml
├── .github
├── PULL_REQUEST_TEMPLATE
│ └── pull_request_template.md
├── ISSUE_TEMPLATE
│ ├── bug-report.md
│ ├── feature_request.md
│ └── enhancement.md
└── workflows
│ ├── publish-to-pypi.yml
│ └── main.yml
├── src
└── pybaum
│ ├── config.py
│ ├── equality.py
│ ├── __init__.py
│ ├── registry.py
│ ├── typecheck.py
│ ├── registry_entries.py
│ └── tree_util.py
├── environment.yml
├── CHANGES.rst
├── setup.cfg
├── LICENSE
├── tox.ini
├── .gitignore
├── README.rst
└── .pre-commit-config.yaml
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
# Package metadata and build options are declared in setup.cfg and pyproject.toml;
# this shim presumably remains only for tooling that still invokes setup.py directly.
if __name__ == "__main__":
    setup()
5 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | pybaum
2 | ======
3 |
4 | pybaum contains tools to work with pytrees.
5 |
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 | api
11 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | image: latest
5 |
6 | python:
7 | version: 3.8
8 |
9 | conda:
10 | environment: docs/rtd_environment.yml
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CITATION
2 | include LICENSE
3 | include CHANGES.rst
4 |
5 | exclude *.yaml
6 | exclude *.yml
7 | exclude tox.ini
8 |
9 | prune docs
10 | prune tests
11 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | The following documents are auto-generated and not carefully edited. They provide direct
5 | access to the source code and the docstrings.
6 |
7 | .. toctree::
8 | :titlesonly:
9 |
10 | /autoapi/pybaum/index
11 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | notify:
3 | require_ci_to_pass: yes
4 |
5 | coverage:
6 | precision: 2
7 | round: down
8 | range: "50...100"
9 | status:
10 | patch:
11 | default:
12 | target: 80%
13 | project:
14 | default:
15 | target: 80%
16 |
17 | ignore:
18 | - ".tox/**/*"
19 | - "setup.py"
20 |
--------------------------------------------------------------------------------
/docs/rtd_environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - nodefaults
4 |
5 | dependencies:
6 | - python=3.9
7 | - pip
8 | - setuptools_scm
9 | - toml
10 |
11 | # Documentation
12 | - sphinx
13 | - sphinx-autoapi
14 | - sphinx-copybutton
15 | - sphinx-panels
16 | - pydata-sphinx-theme>=0.3.0
17 |
18 | - pip:
19 | - ../
20 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.nbqa.config]
2 | isort = "setup.cfg"
3 | black = "pyproject.toml"
4 |
5 | [tool.nbqa.mutate]
6 | isort = 1
7 | black = 1
8 | pyupgrade = 1
9 |
10 |
11 | [tool.nbqa.addopts]
12 | isort = ["--treat-comment-as-code", "# %%", "--profile=black"]
13 | pyupgrade = ["--py37-plus"]
14 |
15 |
16 | [build-system]
17 | requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.0"]
18 |
19 |
20 | [tool.setuptools_scm]
21 | write_to = "src/pybaum/_version.py"
22 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ### What problem do you want to solve?
2 |
3 | Reference the issue or discussion, if there is any. Provide a description of your
4 | proposed solution.
5 |
6 | ### Todo
7 |
8 | - [ ] Target the right branch and pick an appropriate title.
9 | - [ ] Put `Closes #XXXX` in the first PR comment to auto-close the relevant issue once
10 | the PR is accepted. This is not applicable if there is no corresponding issue.
11 | - [ ] Any steps that still need to be done.
12 |
--------------------------------------------------------------------------------
/src/pybaum/config.py:
--------------------------------------------------------------------------------
# Flags describing which optional array libraries can be imported. Other pybaum
# modules read these flags to decide which containers and equality checkers to
# register. The imports themselves are only probes, hence the ``noqa`` markers.

try:
    import numpy as np  # noqa: F401

    IS_NUMPY_INSTALLED = True
except ImportError:
    IS_NUMPY_INSTALLED = False


try:
    import pandas as pd  # noqa: F401

    IS_PANDAS_INSTALLED = True
except ImportError:
    IS_PANDAS_INSTALLED = False


try:
    # jax is only usable when both the frontend and jaxlib import successfully.
    import jax  # noqa: F401
    import jaxlib  # noqa: F401

    IS_JAX_INSTALLED = True
except ImportError:
    IS_JAX_INSTALLED = False
24 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pybaum
2 |
3 | channels:
4 | - conda-forge
5 | - nodefaults
6 |
7 | dependencies:
8 | - python >=3.9
9 | - pip
10 | - setuptools_scm
11 | - toml
12 |
13 | # Testing
14 | - pre-commit
15 | - pytest
16 | - pytest-cov
17 | - pytest-xdist
18 | - tox-conda
19 |
20 | # Documentation
21 | - sphinx >=4
22 | - sphinx-autoapi
23 | - sphinx-copybutton
24 | - sphinx-panels
25 | - pydata-sphinx-theme>=0.3.0
26 |
27 | # Development
28 | - jupyterlab
29 | - nbsphinx
30 | - pdbpp
31 | - numpy
32 | - pandas
33 | - jax
34 | - jaxlib
35 |
--------------------------------------------------------------------------------
/tests/test_typecheck.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import NamedTuple
3 |
4 | from pybaum.typecheck import get_type
5 |
6 |
def test_namedtuple_is_discovered():
    """A collections.namedtuple instance is classified as "namedtuple"."""
    bla_class = namedtuple("bla", ["a", "b"])
    instance = bla_class(1, 2)
    assert get_type(instance) == "namedtuple"
10 |
11 |
def test_typed_namedtuple_is_discovered():
    """A typing.NamedTuple subclass instance is classified as "namedtuple"."""

    class Blubb(NamedTuple):
        a: int
        b: int

    assert get_type(Blubb(a=1, b=2)) == "namedtuple"
19 |
20 |
def test_standard_tuple_is_not_discovered():
    """A plain tuple is not a namedtuple, so its builtin type is returned."""
    result = get_type((1, 2))
    assert result is tuple
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = pybaum
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Bug description
11 |
12 | A clear and concise description of what the bug is.
13 |
14 | ### To Reproduce
15 |
16 | Ideally, provide a minimal code example. If that's not possible, describe steps to reproduce the bug.
17 |
18 | ### Expected behavior
19 |
20 | A clear and concise description of what you expected to happen.
21 |
22 | ### Screenshots/Error messages
23 |
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | ### System
27 |
28 | - OS: [e.g. Ubuntu 18.04]
29 | - Version [e.g. 0.0.1]
30 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: feature-request
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Current situation
11 |
12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]; Currently there is no way of [...]
13 |
14 | ### Desired Situation
15 |
16 | What functionality should become possible or easier?
17 |
18 | ### Proposed implementation
19 |
20 | How would you implement the new feature? Did you consider alternative implementations?
21 | You can start by describing interface changes like a new argument or a new function. There is no need to get too detailed here.
22 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/enhancement.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Enhancement
3 | about: Enhance an existing component.
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | * pybaum version used, if any:
11 | * Python version, if any:
12 | * Operating System:
13 |
14 | ### What would you like to enhance and why? Is it related to an issue/problem?
15 |
16 | A clear and concise description of the current implementation and its limitations.
17 |
18 | ### Describe the solution you'd like
19 |
20 | A clear and concise description of what you want to happen.
21 |
22 | ### Describe alternatives you've considered
23 |
24 | A clear and concise description of any alternative solutions or features you've
25 | considered and why you have discarded them.
26 |
--------------------------------------------------------------------------------
/src/pybaum/equality.py:
--------------------------------------------------------------------------------
1 | """Functions to check equality of pytree leaves."""
2 | from pybaum.config import IS_JAX_INSTALLED
3 | from pybaum.config import IS_NUMPY_INSTALLED
4 | from pybaum.config import IS_PANDAS_INSTALLED
5 |
6 |
7 | if IS_NUMPY_INSTALLED:
8 | import numpy as np
9 |
10 |
11 | if IS_PANDAS_INSTALLED:
12 | import pandas as pd
13 |
# Maps a leaf type (or, for jax arrays, its string identifier) to a function
# ``checker(a, b) -> bool`` that decides whether two leaves of that type are equal.
EQUALITY_CHECKERS = {}


def _arrays_equal(a, b):
    """Check whether two array leaves have the same shape and identical entries.

    ``bool((a == b).all())`` on its own is not a valid equality test for arrays.
    Broadcasting makes arrays of different but compatible shapes (e.g. ``(2,)`` and
    ``(2, 2)``) compare equal, and incompatible shapes can raise instead of returning
    False. Comparing the shapes first avoids both problems. NaNs are compared with
    ``==`` and therefore never count as equal.

    Args:
        a: First array (numpy or jax).
        b: Second array of the same kind.

    Returns:
        bool: True if the shapes match and all entries are equal.

    """
    return a.shape == b.shape and bool((a == b).all())


if IS_NUMPY_INSTALLED:
    EQUALITY_CHECKERS[np.ndarray] = _arrays_equal


if IS_PANDAS_INSTALLED:
    EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b)
    EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b)


if IS_JAX_INSTALLED:
    # jax arrays are registered under a string key because their concrete class is
    # an implementation detail; ``pybaum.typecheck.get_type`` returns this name.
    EQUALITY_CHECKERS["jax.numpy.ndarray"] = _arrays_equal
28 |
--------------------------------------------------------------------------------
/src/pybaum/__init__.py:
--------------------------------------------------------------------------------
1 | from pybaum.registry import get_registry
2 | from pybaum.tree_util import leaf_names
3 | from pybaum.tree_util import tree_equal
4 | from pybaum.tree_util import tree_flatten
5 | from pybaum.tree_util import tree_just_flatten
6 | from pybaum.tree_util import tree_just_yield
7 | from pybaum.tree_util import tree_map
8 | from pybaum.tree_util import tree_multimap
9 | from pybaum.tree_util import tree_unflatten
10 | from pybaum.tree_util import tree_update
11 | from pybaum.tree_util import tree_yield
12 |
13 |
# Public API of pybaum. ``from pybaum import *`` binds exactly these names, and
# each of them is re-exported from ``pybaum.tree_util`` or ``pybaum.registry``.
__all__ = [
    "tree_flatten",
    "tree_just_flatten",
    "tree_just_yield",
    "tree_unflatten",
    "tree_map",
    "tree_multimap",
    "leaf_names",
    "tree_equal",
    "tree_update",
    "tree_yield",
    "get_registry",
]
27 |
--------------------------------------------------------------------------------
/CHANGES.rst:
--------------------------------------------------------------------------------
1 | Changes
2 | =======
3 |
4 | This is a record of all past pybaum releases and what went into them in reverse
5 | chronological order. We follow `semantic versioning <https://semver.org/>`_ and all
6 | releases are available on `Anaconda.org
7 | <https://anaconda.org/conda-forge/pybaum>`_.
8 |
9 |
10 | 0.1.x - 2022-xx-xx
11 | ------------------
12 |
13 | - :gh:`2` replaces the pre-commit pipeline step with pre-commit.ci.
14 | - :gh:`11` implements :func:`pybaum.tree_util.tree_yield` and
15 | :func:`pybaum.tree_util.tree_just_yield`.
16 | - :gh:`10` adds default arguments to ``tree_just_flatten``.
17 | - :gh:`12` adds a section about the API to the docs.
18 | - :gh:`13` extends the readme.
19 |
20 |
21 | 0.1.0 - 2022-01-28
22 | ------------------
23 |
24 | - :gh:`1` releases the initial version of pybaum.
25 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python 🐍 distributions 📦 to PyPI
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@master
11 |
12 | - name: Set up Python 3.8
13 | uses: actions/setup-python@v1
14 | with:
15 | python-version: 3.8
16 |
17 | - name: Install pypa/build
18 | run: >-
19 | python -m
20 | pip install
21 | build
22 | --user
23 | - name: Build a binary wheel and a source tarball
24 | run: >-
25 | python -m
26 | build
27 | --sdist
28 | --wheel
29 | --outdir dist/
30 | - name: Publish distribution 📦 to PyPI
31 | if: startsWith(github.ref, 'refs/tags')
32 | uses: pypa/gh-action-pypi-publish@master
33 | with:
34 | password: ${{ secrets.PYPI_API_TOKEN }}
35 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=pybaum
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = pybaum
3 | description = Tools to work with pytrees.
4 | long_description = file: README.rst
5 | long_description_content_type = text/x-rst
6 | url = https://github.com/OpenSourceEconomics/pybaum
7 | author = Janos Gabler, Tobias Raabe
8 | author_email = janos.gabler@gmail.com
9 | license = MIT
10 | license_file = LICENSE
11 | platforms = unix, linux, osx, cygwin, win32
12 | classifiers =
13 | Development Status :: 3 - Alpha
14 | License :: OSI Approved :: MIT License
15 | Operating System :: MacOS :: MacOS X
16 | Operating System :: Microsoft :: Windows
17 | Operating System :: POSIX
18 | Programming Language :: Python :: 3
19 | Programming Language :: Python :: 3 :: Only
20 | Topic :: Scientific/Engineering
21 | Topic :: Utilities
22 |
23 | [options]
24 | packages = find:
25 | python_requires = >=3.7
26 | include_package_data = True
27 | package_dir =
28 | =src
29 | zip_safe = False
30 |
31 | [options.packages.find]
32 | where = src
33 |
34 | [check-manifest]
35 | ignore =
36 | src/pybaum/_version.py
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Janoś Gabler, Tobias Raabe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
6 | software and associated documentation files (the "Software"), to deal in the Software
7 | without restriction, including without limitation the rights to use, copy, modify,
8 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
9 | permit persons to whom the Software is furnished to do so, subject to the following
10 | conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all copies or
13 | substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
16 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
17 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT
19 | OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20 | OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/docs/source/_static/images/coding.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
24 |
--------------------------------------------------------------------------------
/tests/test_with_jax.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pybaum.config import IS_JAX_INSTALLED
3 | from pybaum.registry import get_registry
4 | from pybaum.tree_util import leaf_names
5 | from pybaum.tree_util import tree_equal
6 | from pybaum.tree_util import tree_flatten
7 | from pybaum.tree_util import tree_just_flatten
8 |
9 | if IS_JAX_INSTALLED:
10 | import jax.numpy as jnp
11 | else:
12 | # run the tests with normal numpy instead
13 | import numpy as jnp
14 |
15 |
@pytest.fixture
def tree():
    """Nested dict holding a 2x2 integer array and a length-2 array of ones."""
    inner = {"b": jnp.arange(4).reshape(2, 2)}
    return {"a": inner, "c": jnp.ones(2)}
19 |
20 |
@pytest.fixture
def flat():
    """Expected leaves of the ``tree`` fixture, in flattening order."""
    return [*range(4), 1, 1]
24 |
25 |
@pytest.fixture
def registry():
    """Registry (plus defaults) that treats jax and numpy arrays as containers."""
    array_types = ["jax.numpy.ndarray", "numpy.ndarray"]
    return get_registry(types=array_types)
29 |
30 |
def test_tree_just_flatten_with_jax(tree, registry, flat):
    """Array entries are flattened into individual leaves."""
    assert tree_just_flatten(tree, registry=registry) == flat
34 |
35 |
def test_tree_flatten_with_jax(tree, registry, flat):
    """tree_flatten returns the same leaves plus a treedef equal to the input."""
    leaves, treedef = tree_flatten(tree, registry=registry)
    assert leaves == flat
    assert tree_equal(treedef, tree)
40 |
41 |
def test_leaf_names_with_jax(tree, registry):
    """Leaf names join dict keys and array indices with underscores."""
    expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"]
    names = leaf_names(tree, registry=registry)
    assert names == expected
46 |
--------------------------------------------------------------------------------
/docs/source/_static/images/books.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/pybaum/registry.py:
--------------------------------------------------------------------------------
1 | from pybaum.registry_entries import FUNC_DICT
2 |
3 |
def get_registry(types=None, include_defaults=True):
    """Create a pytree registry.

    Args:
        types (list or str): Names of the types that should be included in the
            registry, i.e. considered containers and not leaves by the functions
            that work with pytrees. A single name may also be passed as a string.
            Currently supported names are:
            - "tuple"
            - "dict"
            - "list"
            - "namedtuple" (:class:`collections.namedtuple` and
              :class:`typing.NamedTuple`)
            - "None" (:obj:`None`)
            - "OrderedDict" (:class:`collections.OrderedDict`)
            - "numpy.ndarray"
            - "jax.numpy.ndarray"
            - "pandas.Series"
            - "pandas.DataFrame"
        include_defaults (bool): Whether the default pytree containers "tuple",
            "dict", "list", "None", "namedtuple" and "OrderedDict" should be
            included even if not specified in `types`.

    Returns:
        dict: A pytree registry.

    Raises:
        KeyError: If `types` contains a name that is not supported.

    """
    if types is None:
        types = []
    elif isinstance(types, str):
        # A bare string would otherwise be iterated character by character.
        types = [types]

    default_types = ["list", "tuple", "dict", "None", "namedtuple", "OrderedDict"]
    # Defaults come first so that explicitly requested types are merged last.
    requested = [*default_types, *types] if include_defaults else list(types)
    # dict.fromkeys removes duplicates while keeping a deterministic order. A set
    # union would make the order depend on string hash randomization, so the
    # registry could be built differently from one interpreter run to the next.
    requested = list(dict.fromkeys(requested))

    unknown = [typ for typ in requested if typ not in FUNC_DICT]
    if unknown:
        raise KeyError(
            f"Unsupported pytree type(s): {unknown}. "
            f"Supported types are: {list(FUNC_DICT)}."
        )

    registry = {}
    for typ in requested:
        # Each FUNC_DICT value is a factory returning the registry entries for
        # one type name. Update in place instead of re-copying the dict each time.
        registry.update(FUNC_DICT[typ]())

    return registry
41 |
--------------------------------------------------------------------------------
/docs/source/_static/images/book.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
19 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = pytest, sphinx
3 | skipsdist = True
4 | skip_missing_interpreters = True
5 |
6 | [testenv]
7 | basepython = python
8 |
9 | [testenv:pytest]
10 | setenv =
11 | CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
12 | conda_channels =
13 | conda-forge
14 | defaults
15 | conda_deps =
16 | conda-build
17 | numpy
18 | pandas
19 | pytest
20 | pytest-cov
21 | pytest-mock
22 | pytest-xdist
23 | jax
24 | jaxlib
25 | commands = pytest {posargs}
26 |
27 | [testenv:pytest-windows]
28 | setenv =
29 | CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
30 | conda_channels =
31 | conda-forge
32 | defaults
33 | conda_deps =
34 | conda-build
35 | numpy
36 | pandas
37 | pytest
38 | pytest-cov
39 | pytest-mock
40 | pytest-xdist
41 | commands = pytest {posargs}
42 |
43 | [testenv:sphinx]
44 | changedir = docs/source
45 | conda_env = docs/rtd_environment.yml
46 | commands =
47 | sphinx-build -T -b html -d {envtmpdir}/doctrees . {envtmpdir}/html
48 | - sphinx-build -T -b linkcheck -d {envtmpdir}/doctrees . {envtmpdir}/linkcheck
49 |
50 |
51 | [doc8]
52 | ignore =
53 | D002,
54 | D004,
55 | max-line-length = 88
56 |
57 | [flake8]
58 | max-line-length = 88
59 | ignore =
60 | D ; ignores docstring style errors, enable if you are nit-picky
61 | E203 ; ignores whitespace around : which is enforced by Black
62 | W503 ; ignores linebreak before binary operator which is enforced by Black
63 | RST304 ; ignores check for valid rst roles because it is too aggressive
64 | T001 ; ignore print statements
65 | RST301 ; ignores unexpected indentations in docstrings because it was not compatible with google style docstrings
66 | RST203 ; gave false positives
67 | RST202 ; gave false positives
68 | RST201 ; gave false positives
69 | W605 ; ignores regex relevant escape sequences
70 | PT001 ; ignores brackets for fixtures.
71 | per-file-ignores =
72 | docs/source/conf.py:E501, E800
73 | warn-symbols =
74 | pytest.mark.wip = Remove 'wip' mark for tests.
75 |
76 | [pytest]
77 | addopts = --doctest-modules
78 | markers =
79 | wip: Tests that are work-in-progress.
80 | slow: Tests that take a long time to run and are skipped in continuous integration.
81 | norecursedirs =
82 | docs
83 | .tox
84 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | src/pybaum/_version.py
133 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: main
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - '*'
9 |
10 | # Automatically cancel a previous run.
11 | concurrency:
12 | group: ${{ github.head_ref || github.run_id }}
13 | cancel-in-progress: true
14 |
15 | jobs:
16 |
17 | run-tests:
18 |
19 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
20 | runs-on: ${{ matrix.os }}
21 |
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | os: ['ubuntu-latest', 'macos-latest']
26 | python-version: ['3.7', '3.8', '3.9', '3.10']
27 |
28 | steps:
29 | - uses: actions/checkout@v2
30 | - uses: conda-incubator/setup-miniconda@v2
31 | with:
32 | auto-update-conda: true
33 | python-version: ${{ matrix.python-version }}
34 |
35 | - name: Install core dependencies.
36 | shell: bash -l {0}
37 | run: conda install -c conda-forge tox-conda
38 |
39 | - name: Run pytest.
40 | shell: bash -l {0}
41 | run: tox -e pytest -- -m "not slow" --cov-report=xml --cov=./
42 |
43 | - name: Upload coverage report.
44 | if: runner.os == 'Linux' && matrix.python-version == '3.9'
45 | uses: codecov/codecov-action@v1
46 | with:
47 | token: ${{ secrets.CODECOV_TOKEN }}
48 |
49 | run-tests-windows:
50 |
51 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
52 | runs-on: ${{ matrix.os }}
53 |
54 | strategy:
55 | fail-fast: false
56 | matrix:
57 | os: ['windows-latest']
58 | python-version: ['3.7', '3.8', '3.9', '3.10']
59 |
60 | steps:
61 | - uses: actions/checkout@v2
62 | - uses: conda-incubator/setup-miniconda@v2
63 | with:
64 | auto-update-conda: true
65 | python-version: ${{ matrix.python-version }}
66 |
67 | - name: Install core dependencies.
68 | shell: bash -l {0}
69 | run: conda install -c conda-forge tox-conda
70 |
71 | - name: Run pytest.
72 | shell: bash -l {0}
73 | run: tox -e pytest-windows -- -m "not slow"
74 |
75 | docs:
76 |
77 | name: Run documentation.
78 | runs-on: ubuntu-latest
79 |
80 | steps:
81 | - uses: actions/checkout@v2
82 | - uses: conda-incubator/setup-miniconda@v2
83 | with:
84 | auto-update-conda: true
85 | python-version: 3.9
86 |
87 | - name: Install core dependencies.
88 | shell: bash -l {0}
89 | run: conda install -c conda-forge tox-conda
90 |
91 | - name: Build docs
92 | shell: bash -l {0}
93 | run: tox -e sphinx
94 |
--------------------------------------------------------------------------------
/src/pybaum/typecheck.py:
--------------------------------------------------------------------------------
1 | from pybaum.config import IS_JAX_INSTALLED
2 | from pybaum.config import IS_NUMPY_INSTALLED
3 |
4 | if IS_JAX_INSTALLED:
5 | import jax.numpy as jnp
6 |
7 | if IS_NUMPY_INSTALLED:
8 | import numpy as np
9 |
10 |
def get_type(obj):
    """Return the pytree type of a candidate object.

    The builtin ``type`` cannot distinguish namedtuples (including
    ``typing.NamedTuple`` classes) from plain tuples. The concrete class of jax
    arrays is also an implementation detail. Such objects are therefore mapped to
    string identifiers instead of their class.

    Args:
        obj: The object to be checked

    Returns:
        type or str: ``"namedtuple"`` for namedtuples, ``"jax.numpy.ndarray"`` for
        jax arrays, and ``type(obj)`` for everything else.

    """
    if _is_namedtuple(obj):
        return "namedtuple"
    if _is_jax_array(obj):
        return "jax.numpy.ndarray"
    return type(obj)
31 |
32 |
33 | def _is_namedtuple(obj):
34 | """Check if an object is a namedtuple.
35 |
36 | As in JAX we treat collections.namedtuple and typing.NamedTuple both as
37 | namedtuple but the exact type is preserved in the unflatten function.
38 |
39 | namedtuples are discovered by being instances of tuple and having a
40 | ``_fields`` attribute as suggested by Raymond Hettinger
41 | `here `_.
42 |
43 | Moreover we check for the presence of a ``_replace`` method because we need when
44 | unflattening pytrees.
45 |
46 | This can produce false positives but in most cases would still result in desired
47 | behavior.
48 |
49 | Args:
50 | obj: The object to be checked
51 |
52 | Returns:
53 | bool
54 |
55 | """
56 | out = (
57 | isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace")
58 | )
59 | return out
60 |
61 |
def _is_jax_array(obj):
    """Return True if ``obj`` is a jax array and not a plain numpy array.

    The concrete class of jax arrays has changed over time and is an implementation
    detail, so the check relies on ``isinstance`` against ``jax.numpy.ndarray``.
    The result of that ``isinstance`` check has also changed: for jax versions
    before 0.2.21, numpy arrays counted as instances of jax arrays, but now they do
    not. Numpy arrays are therefore excluded explicitly.

    Resources:
    ----------

    - https://github.com/google/jax/issues/2115
    - https://github.com/google/jax/issues/2014
    - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021
    - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022

    Args:
        obj: The object to be checked

    Returns:
        bool

    """
    # Short-circuits before touching ``jnp``, which is only imported when jax is.
    if not IS_JAX_INSTALLED or not isinstance(obj, jnp.ndarray):
        return False
    return not (IS_NUMPY_INSTALLED and isinstance(obj, np.ndarray))
94 |
--------------------------------------------------------------------------------
/docs/source/_static/images/light-bulb.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
34 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | pybaum
2 | ======
3 |
4 | .. start-badges
5 |
6 | .. image:: https://img.shields.io/pypi/v/pybaum?color=blue
7 | :alt: PyPI
8 | :target: https://pypi.org/project/pybaum
9 |
10 | .. image:: https://img.shields.io/pypi/pyversions/pybaum
11 | :alt: PyPI - Python Version
12 | :target: https://pypi.org/project/pybaum
13 |
14 | .. image:: https://img.shields.io/conda/vn/conda-forge/pybaum.svg
15 | :target: https://anaconda.org/conda-forge/pybaum
16 |
17 | .. image:: https://img.shields.io/conda/pn/conda-forge/pybaum.svg
18 | :target: https://anaconda.org/conda-forge/pybaum
19 |
20 | .. image:: https://img.shields.io/pypi/l/pybaum
21 | :alt: PyPI - License
22 | :target: https://pypi.org/project/pybaum
23 |
24 | .. image:: https://readthedocs.org/projects/pybaum/badge/?version=latest
25 | :target: https://pybaum.readthedocs.io/en/latest
26 |
27 | .. image:: https://img.shields.io/github/actions/workflow/status/OpenSourceEconomics/pybaum/main.yml?branch=main
28 | :target: https://github.com/OpenSourceEconomics/pybaum/actions?query=branch%3Amain
29 |
30 | .. image:: https://codecov.io/gh/OpenSourceEconomics/pybaum/branch/main/graph/badge.svg
31 | :target: https://codecov.io/gh/OpenSourceEconomics/pybaum
32 |
33 | .. image:: https://results.pre-commit.ci/badge/github/OpenSourceEconomics/pybaum/main.svg
34 | :target: https://results.pre-commit.ci/latest/github/OpenSourceEconomics/pybaum/main
35 | :alt: pre-commit.ci status
36 |
37 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
38 | :target: https://github.com/psf/black
39 |
40 | .. end-badges
41 |
42 | Installation
43 | ------------
44 |
45 | pybaum is available on `PyPI <https://pypi.org/project/pybaum>`_ and `Anaconda.org
46 | <https://anaconda.org/conda-forge/pybaum>`_. Install it with
47 |
48 | .. code-block:: console
49 |
50 | $ pip install pybaum
51 |
52 | # or
53 |
54 | $ conda install -c conda-forge pybaum
55 |
56 |
57 | About
58 | -----
59 |
60 | pybaum provides tools to work with pytrees, a concept borrowed from `JAX
61 | <https://github.com/google/jax>`_.
62 |
63 | What is a pytree?
64 |
65 | In pybaum, we use the term pytree to refer to a tree-like structure built out of
66 | container-like Python objects. Classes are considered container-like if they are in the
67 | pytree registry, which by default includes lists, tuples, and dicts. That is:
68 |
69 | 1. Any object whose type is not in the pytree container registry is considered a leaf
70 | pytree.
71 |
72 | 2. Any object whose type is in the pytree container registry, and which contains
73 | pytrees, is considered a pytree.
74 |
75 | For each entry in the pytree container registry, a container-like type is registered
76 | with a pair of functions that specify how to convert an instance of the container type
77 | to a (children, metadata) pair and how to convert such a pair back to an instance of the
78 | container type. Using these functions, pybaum can canonicalize any tree of registered
79 | container objects into tuples.
80 |
81 | Example pytrees:
82 |
83 | .. code-block:: python
84 |
85 | [1, "a", object()] # 3 leaves
86 |
87 | (1, (2, 3), ()) # 3 leaves
88 |
89 | [1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
90 |
91 | pybaum can be extended to consider other container types as pytrees.
92 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: check-merge-conflict
6 | - id: debug-statements
7 | - id: end-of-file-fixer
8 | - repo: https://github.com/asottile/reorder_python_imports
9 | rev: v3.9.0
10 | hooks:
11 | - id: reorder-python-imports
12 | types: [python]
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v4.4.0
15 | hooks:
16 | - id: check-added-large-files
17 | args: ['--maxkb=100']
18 | - id: check-case-conflict
19 | - id: check-merge-conflict
20 | - id: check-vcs-permalinks
21 | - id: check-yaml
22 | - id: debug-statements
23 | - id: end-of-file-fixer
24 | - id: fix-byte-order-marker
25 | - id: mixed-line-ending
26 | - id: no-commit-to-branch
27 | args: [--branch, main]
28 | - id: trailing-whitespace
29 | - repo: https://github.com/pre-commit/pygrep-hooks
30 | rev: v1.9.0
31 | hooks:
32 | - id: python-check-blanket-noqa
33 | - id: python-check-mock-methods
34 | - id: python-no-eval
35 | - id: python-no-log-warn
36 | - id: python-use-type-annotations
37 | - id: rst-backticks
38 | - id: rst-directive-colons
39 | - id: rst-inline-touching-normal
40 | - id: text-unicode-replacement-char
41 | - repo: https://github.com/asottile/blacken-docs
42 | rev: v1.12.1
43 | hooks:
44 | - id: blacken-docs
45 | additional_dependencies: [black==22.3.0]
46 | types: [rst]
47 | - repo: https://github.com/psf/black
48 | rev: 22.12.0
49 | hooks:
50 | - id: black
51 | language_version: python3.10
52 | - repo: https://github.com/PyCQA/flake8
53 | rev: 5.0.4
54 | hooks:
55 | - id: flake8
56 | types: [python]
57 | additional_dependencies: [
58 | flake8-alfred,
59 | flake8-bugbear,
60 | flake8-builtins,
61 | flake8-comprehensions,
62 | flake8-docstrings,
63 | flake8-eradicate,
64 | flake8-print,
65 | flake8-pytest-style,
66 | flake8-todo,
67 | flake8-typing-imports,
68 | flake8-unused-arguments,
69 | pep8-naming,
70 | pydocstyle,
71 | Pygments,
72 | ]
73 | - repo: https://github.com/PyCQA/doc8
74 | rev: v1.0.0
75 | hooks:
76 | - id: doc8
77 | - repo: meta
78 | hooks:
79 | - id: check-hooks-apply
80 | - id: check-useless-excludes
81 | # - id: identity # Prints all files passed to pre-commits. Debugging.
82 | - repo: https://github.com/mgedmin/check-manifest
83 | rev: "0.49"
84 | hooks:
85 | - id: check-manifest
86 | - repo: https://github.com/PyCQA/doc8
87 | rev: v1.0.0
88 | hooks:
89 | - id: doc8
90 | - repo: https://github.com/asottile/setup-cfg-fmt
91 | rev: v2.2.0
92 | hooks:
93 | - id: setup-cfg-fmt
94 | - repo: https://github.com/econchick/interrogate
95 | rev: 1.5.0
96 | hooks:
97 | - id: interrogate
98 | args: [-v, --fail-under=20]
99 | exclude: ^(tests|docs|setup\.py)
100 | - repo: https://github.com/codespell-project/codespell
101 | rev: v2.2.2
102 | hooks:
103 | - id: codespell
104 | - repo: https://github.com/asottile/pyupgrade
105 | rev: v3.3.1
106 | hooks:
107 | - id: pyupgrade
108 | args: [--py37-plus]
109 |
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Remove execution count for notebook cells. */
2 | div.prompt {
3 | display: none;
4 | }
5 |
6 | /* Getting started index page */
7 |
8 | .intro-card {
9 | background: #fff;
10 | border-radius: 0;
11 | padding: 30px 10px 10px 10px;
12 | margin: 10px 0px;
13 | max-height: 85%;
14 | }
15 |
16 | .intro-card .card-text {
17 | margin: 20px 0px;
18 | }
19 |
20 | div#index-container {
21 | padding-bottom: 20px;
22 | }
23 |
24 | a#index-link {
25 | color: #333;
26 | text-decoration: none;
27 | }
28 |
29 | /* reference to user guide */
30 | .gs-torefguide {
31 | align-items: center;
32 | font-size: 0.9rem;
33 | }
34 |
35 | .gs-torefguide .badge {
36 | background-color: #130654;
37 | margin: 10px 10px 10px 0px;
38 | padding: 5px;
39 | }
40 |
41 | .gs-torefguide a {
42 | margin-left: 5px;
43 | color: #130654;
44 | border-bottom: 1px solid #FFCA00f3;
45 | box-shadow: 0px -10px 0px #FFCA00f3 inset;
46 | }
47 |
48 | .gs-torefguide p {
49 | margin-top: 1rem;
50 | }
51 |
52 | .gs-torefguide a:hover {
53 | margin-left: 5px;
54 | color: grey;
55 | text-decoration: none;
56 | border-bottom: 1px solid #b2ff80f3;
57 | box-shadow: 0px -10px 0px #b2ff80f3 inset;
58 | }
59 |
60 | /* selecting constraints guide */
61 | .intro-card {
62 | background:#FFF;
63 | border-radius:0;
64 | padding: 30px 10px 10px 10px;
65 | margin: 10px 0px;
66 | }
67 |
68 | .intro-card .card-text {
69 | margin:20px 0px;
70 | /*min-height: 150px; */
71 | }
72 |
73 | .intro-card .card-img-top {
74 | margin: 10px;
75 | }
76 |
77 | .install-block {
78 | padding-bottom: 30px;
79 | }
80 |
81 | .install-card .card-header {
82 | border: none;
83 | background-color:white;
84 | color: #150458;
85 | font-size: 1.1rem;
86 | font-weight: bold;
87 | padding: 1rem 1rem 0rem 1rem;
88 | }
89 |
90 | .install-card .card-footer {
91 | border: none;
92 | background-color:white;
93 | }
94 |
95 | .install-card pre {
96 | margin: 0 1em 1em 1em;
97 | }
98 |
99 | .custom-button {
100 | background-color:#DCDCDC;
101 | border: none;
102 | color: #484848;
103 | text-align: center;
104 | text-decoration: none;
105 | display: inline-block;
106 | font-size: 0.9rem;
107 | border-radius: 0.5rem;
108 | max-width: 120px;
109 | padding: 0.5rem 0rem;
110 | }
111 |
112 | .custom-button a {
113 | color: #484848;
114 | }
115 |
116 | .custom-button p {
117 | margin-top: 0;
118 | margin-bottom: 0rem;
119 | color: #484848;
120 | }
121 |
122 | /* selecting constraints guide collapsed cards */
123 |
124 | .tutorial-accordion {
125 | margin-top: 20px;
126 | margin-bottom: 20px;
127 | }
128 |
129 | .tutorial-card .card-header.card-link .btn {
130 | margin-right: 12px;
131 | }
132 |
133 | .tutorial-card .card-header.card-link .btn:after {
134 | content: "-";
135 | }
136 |
137 | .tutorial-card .card-header.card-link.collapsed .btn:after {
138 | content: "+";
139 | }
140 |
141 | .tutorial-card-header-1 {
142 | justify-content: space-between;
143 | align-items: center;
144 | }
145 |
146 | .tutorial-card-header-2 {
147 | justify-content: flex-start;
148 | align-items: center;
149 | font-size: 1.3rem;
150 | }
151 |
152 | .tutorial-card .card-header {
153 | cursor: pointer;
154 | background-color: white;
155 | }
156 |
157 | .tutorial-card .card-body {
158 | background-color: #F0F0F0;
159 | }
160 |
161 | .tutorial-card .badge:hover {
162 | background-color: grey;
163 | }
164 |
165 | /* tables in selecting constraints guide */
166 |
167 | table.rows th {
168 | background-color: #F0F0F0;
169 | border-style: solid solid solid solid;
170 | border-width: 0px 0px 0px 0px;
171 | border-color: #F0F0F0;
172 | text-align: center;
173 | }
174 |
175 | table.rows tr:nth-child(even) {
176 | background-color: #F0F0F0;
177 | text-align: right;
178 | }
179 | table.rows tr:nth-child(odd) {
180 | background-color: #FFFFFF;
181 | text-align: right;
182 | }
183 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib.metadata import version
3 |
4 |
author = "Janos Gabler, Tobias Raabe"

# Read the Docs sets READTHEDOCS="True" in its build environment. Todos are only
# shown in local builds (see the ``todo_*`` options below).
on_rtd = os.environ.get("READTHEDOCS") == "True"


# -- General configuration ------------------------------------------------

# Sphinx extension module names: either bundled with Sphinx ("sphinx.ext.*") or
# provided by third-party packages.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.todo",
    "sphinx.ext.coverage",
    "sphinx.ext.extlinks",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx.ext.viewcode",
    "sphinx.ext.napoleon",
    "sphinx_panels",
    "autoapi.extension",
]

autodoc_member_order = "bysource"

# Mock these packages during autodoc so the docs build without installing them.
autodoc_mock_imports = [
    "pandas",
    "pytest",
    "numpy",
    "jax",
]

extlinks = {
    "ghuser": ("https://github.com/%s", "@"),
    "gh": ("https://github.com/OpenSourceEconomics/pybaum/pulls/%s", "#"),
}

intersphinx_mapping = {
    "numpy": ("https://numpy.org/doc/stable", None),
    "np": ("https://numpy.org/doc/stable", None),
    "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
    "pd": ("https://pandas.pydata.org/pandas-docs/stable", None),
    "python": ("https://docs.python.org/3.9", None),
}

# Paths that contain templates and custom static files (such as style sheets),
# relative to this directory. Static files are copied after the builtin static
# files, so a file named "default.css" overwrites the builtin "default.css".
templates_path = ["_templates"]
html_static_path = ["_static"]

# The suffix of source filenames.
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "pybaum"
copyright = f"2022, {author}"  # noqa: A001

# The full version, including alpha/beta/rc tags, used for |release|.
release = version("pybaum")
# The short X.Y version, used for |version|.
version = ".".join(release.split(".")[:2])

# The language for content autogenerated by Sphinx. Since Sphinx 5.0 the value
# ``None`` is deprecated (Sphinx warns and falls back to English), so set it
# explicitly.
language = "en"

# Patterns, relative to the source directory, of files and directories to ignore
# when looking for source files. These patterns also affect html_static_path and
# html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# `todo` and `todoList` only produce output (and warnings) in local builds.
if not on_rtd:
    todo_include_todos = True
    todo_emit_warnings = True

# Remove prefixed $ for bash, >>> for Python prompts, and In [1]: for IPython prompts.
# NOTE(review): sphinx_copybutton is not listed in ``extensions``, so these two
# options currently have no effect.
copybutton_prompt_text = r"\$ |>>> |In \[\d\]: "
copybutton_prompt_is_regexp = True

# Configuration for autoapi
autoapi_type = "python"
autoapi_dirs = ["../../src"]
autoapi_keep_files = False
autoapi_add_toctree_entry = False


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages.
html_theme = "pydata_sphinx_theme"

# html_logo = "_static/images/logo.svg"

html_theme_options = {
    "github_url": "https://github.com/OpenSourceEconomics/pybaum",
}

html_css_files = ["css/custom.css"]

# Custom sidebar templates: a dictionary that maps document names to template
# names.
# NOTE(review): this setting was written for the alabaster theme
# (http://alabaster.readthedocs.io/en/latest/installation.html#sidebars); confirm
# it is still wanted with pydata_sphinx_theme.
html_sidebars = {
    "**": [
        "relations.html",  # needs 'show_related': True theme option to display
        "searchbox.html",
    ]
}
141 |
--------------------------------------------------------------------------------
/src/pybaum/registry_entries.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from collections import OrderedDict
3 | from itertools import product
4 |
5 | from pybaum.config import IS_JAX_INSTALLED
6 | from pybaum.config import IS_NUMPY_INSTALLED
7 | from pybaum.config import IS_PANDAS_INSTALLED
8 |
9 | if IS_NUMPY_INSTALLED:
10 | import numpy as np
11 |
12 | if IS_PANDAS_INSTALLED:
13 | import pandas as pd
14 |
15 | if IS_JAX_INSTALLED:
16 | import jax
17 |
18 |
def _none():
    """Create registry entry for NoneType.

    ``None`` is an empty container: flattening yields no children and no
    auxiliary data, unflattening always returns ``None``, and there are no leaf
    names.

    """
    none_type = type(None)
    handlers = {
        "flatten": lambda tree: ([], None),  # noqa: U100
        "unflatten": lambda aux_data, children: None,  # noqa: U100
        "names": lambda tree: [],  # noqa: U100
    }
    return {none_type: handlers}
29 |
30 |
def _list():
    """Create registry entry for list.

    A list is its own list of children and needs no auxiliary data. Leaf names
    are the stringified positions.

    """

    def _flatten(tree):
        return tree, None

    def _unflatten(aux_data, children):  # noqa: U100
        return children

    def _names(tree):
        return list(map(str, range(len(tree))))

    return {list: {"flatten": _flatten, "unflatten": _unflatten, "names": _names}}
41 |
42 |
def _dict():
    """Create registry entry for dict.

    The values (in iteration order) are the children and the keys are stored as
    auxiliary data so ``unflatten`` can zip them back into a dict. Leaf names are
    the stringified keys.

    """

    def _names(tree):
        return [str(key) for key in tree]

    return {
        dict: {
            "flatten": lambda tree: (list(tree.values()), list(tree)),
            "unflatten": lambda aux_data, children: dict(zip(aux_data, children)),
            "names": _names,
        },
    }
53 |
54 |
def _tuple():
    """Create registry entry for tuple.

    Children are the elements as a list without auxiliary data; ``unflatten``
    converts them back into a plain tuple. Leaf names are the stringified
    positions.

    """
    handlers = {}
    handlers["flatten"] = lambda tree: (list(tree), None)
    handlers["unflatten"] = lambda aux_data, children: tuple(children)  # noqa: U100
    handlers["names"] = lambda tree: list(map(str, range(len(tree))))
    return {tuple: handlers}
65 |
66 |
def _namedtuple():
    """Create registry entry for namedtuple and NamedTuple.

    The key is the string ``"namedtuple"`` rather than a type. The original
    instance itself is used as auxiliary data so that ``_unflatten_namedtuple``
    can rebuild an object of the exact same class. Leaf names are the field
    names.

    """

    def _flatten(tree):
        return list(tree), tree

    def _names(tree):
        return list(tree._fields)

    handlers = {
        "flatten": _flatten,
        "unflatten": _unflatten_namedtuple,
        "names": _names,
    }
    return {"namedtuple": handlers}
77 |
78 |
def _unflatten_namedtuple(aux_data, leaves):
    """Rebuild a namedtuple from a template instance and new children.

    Args:
        aux_data: The original namedtuple; its class and field names are reused.
        leaves: New values, matched to ``aux_data._fields`` by position.

    Returns:
        A namedtuple of the same class as ``aux_data``. Because ``zip`` stops at
        the shorter input, fields without a matching leaf keep their old value.

    """
    return aux_data._replace(**dict(zip(aux_data._fields, leaves)))
83 |
84 |
def _ordereddict():
    """Create registry entry for OrderedDict.

    Behaves like the ``dict`` entry (values are children, keys are auxiliary
    data, stringified keys are leaf names) but rebuilds an ``OrderedDict``.

    """

    def _flatten(tree):
        return list(tree.values()), list(tree)

    def _unflatten(aux_data, children):
        return OrderedDict(zip(aux_data, children))

    def _names(tree):
        return [str(key) for key in tree]

    return {
        OrderedDict: {"flatten": _flatten, "unflatten": _unflatten, "names": _names},
    }
97 |
98 |
def _numpy_array():
    """Create registry entry for numpy.ndarray.

    Arrays are flattened (C order, numpy's default) into a list of Python
    scalars and the shape is kept as auxiliary data to reshape a new array on
    unflatten. Returns an empty dict when numpy is not installed.

    """
    if not IS_NUMPY_INSTALLED:
        return {}

    def _unflatten(aux_data, leaves):
        return np.array(leaves).reshape(aux_data)

    return {
        np.ndarray: {
            "flatten": lambda arr: (arr.flatten().tolist(), arr.shape),
            "unflatten": _unflatten,
            "names": _array_element_names,
        },
    }
115 |
116 |
def _array_element_names(arr):
    """Name every element of ``arr`` by its underscore-joined index.

    Indices are enumerated with ``itertools.product``, i.e. in row-major order,
    so a ``(2, 2)`` array yields ``["0_0", "0_1", "1_0", "1_1"]``.

    """
    axis_labels = (map(str, range(size)) for size in arr.shape)
    return ["_".join(index) for index in itertools.product(*axis_labels)]
121 |
122 |
def _jax_array():
    """Create registry entry for jax arrays.

    The key is the string ``"jax.numpy.ndarray"`` rather than a type. Arrays are
    flattened into a list of Python scalars, the shape is kept as auxiliary data,
    and unflatten builds a new ``jax.numpy`` array of that shape. Returns an
    empty dict when jax is not installed.

    """
    if not IS_JAX_INSTALLED:
        return {}

    def _unflatten(aux_data, leaves):
        return jax.numpy.array(leaves).reshape(aux_data)

    return {
        "jax.numpy.ndarray": {
            "flatten": lambda arr: (arr.flatten().tolist(), arr.shape),
            "unflatten": _unflatten,
            "names": _array_element_names,
        },
    }
137 |
138 |
def _pandas_series():
    """Create registry entry for pandas.Series.

    The values become the children; index and name are kept as auxiliary data
    and passed back to the ``pd.Series`` constructor on unflatten. Leaf names are
    the stringified index labels. Returns an empty dict when pandas is not
    installed.

    """
    if not IS_PANDAS_INSTALLED:
        return {}

    def _flatten(sr):
        return sr.tolist(), {"index": sr.index, "name": sr.name}

    return {
        pd.Series: {
            "flatten": _flatten,
            "unflatten": lambda aux_data, leaves: pd.Series(leaves, **aux_data),
            "names": lambda sr: list(sr.index.map(_index_element_to_string)),
        },
    }
155 |
156 |
def _pandas_dataframe():
    """Create registry entry for pandas.DataFrame.

    Delegates to the module's DataFrame flatten/unflatten/names helpers. Returns
    an empty dict when pandas is not installed.

    """
    if not IS_PANDAS_INSTALLED:
        return {}
    handlers = {
        "flatten": _flatten_pandas_dataframe,
        "unflatten": _unflatten_pandas_dataframe,
        "names": _get_names_pandas_dataframe,
    }
    return {pd.DataFrame: handlers}
170 |
171 |
def _flatten_pandas_dataframe(df):
    """Flatten a DataFrame row by row into a list of Python scalars.

    Returns:
        tuple: The cell values in row-major order and a dict with ``columns``,
        ``index`` and ``shape`` needed to rebuild the frame.

    """
    aux_data = {"columns": df.columns, "index": df.index, "shape": df.shape}
    values = df.to_numpy().flatten().tolist()
    return values, aux_data
176 |
177 |
def _unflatten_pandas_dataframe(aux_data, leaves):
    """Rebuild a DataFrame from row-major leaves and the stored metadata.

    Args:
        aux_data (dict): ``columns``, ``index`` and ``shape`` of the original frame.
        leaves (list): Cell values in row-major order.

    Returns:
        pandas.DataFrame

    """
    values = np.array(leaves).reshape(aux_data["shape"])
    return pd.DataFrame(
        data=values,
        columns=aux_data["columns"],
        index=aux_data["index"],
    )
185 |
186 |
def _get_names_pandas_dataframe(df):
    """Create one leaf name per DataFrame cell, formatted as ``"{row}_{column}"``.

    Names follow the same row-major order as ``_flatten_pandas_dataframe``.
    Row and column labels are both converted with ``_index_element_to_string``.
    Non-string column labels (e.g. the default integer columns) and tuple labels
    from a MultiIndex therefore work; previously the columns were joined as-is,
    which raised ``TypeError`` for any non-string column label.

    Args:
        df (pandas.DataFrame): The frame whose cells are named.

    Returns:
        list: One string per cell.

    """
    index_strings = list(df.index.map(_index_element_to_string))
    column_strings = list(df.columns.map(_index_element_to_string))
    return [f"{loc}_{col}" for loc, col in product(index_strings, column_strings)]
191 |
192 |
def _index_element_to_string(element):
    """Convert a single index label to a string.

    Tuple or list labels (such as MultiIndex entries) are converted element-wise
    and joined with underscores; any other label is passed through ``str``.

    """
    if not isinstance(element, (tuple, list)):
        return str(element)
    return "_".join(map(str, element))
201 |
202 |
# Map type names to the factory functions that build their registry entries.
# Each factory returns a dict of the form {key: {"flatten", "unflatten", "names"}};
# the numpy, jax and pandas factories return an empty dict when the respective
# package is not installed. Presumably consumed by ``pybaum.registry.get_registry``
# (the tests call ``get_registry(types=[...])`` with these names) — verify there.
FUNC_DICT = {
    "list": _list,
    "tuple": _tuple,
    "dict": _dict,
    "numpy.ndarray": _numpy_array,
    "jax.numpy.ndarray": _jax_array,
    "pandas.Series": _pandas_series,
    "pandas.DataFrame": _pandas_dataframe,
    "None": _none,
    "namedtuple": _namedtuple,
    "OrderedDict": _ordereddict,
}
215 |
--------------------------------------------------------------------------------
/tests/test_tree_util.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from collections import namedtuple
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pytest
8 | from numpy.testing import assert_array_almost_equal as aaae
9 | from pybaum.registry import get_registry
10 | from pybaum.tree_util import leaf_names
11 | from pybaum.tree_util import tree_equal
12 | from pybaum.tree_util import tree_flatten
13 | from pybaum.tree_util import tree_map
14 | from pybaum.tree_util import tree_multimap
15 | from pybaum.tree_util import tree_unflatten
16 | from pybaum.tree_util import tree_update
17 | from pybaum.tree_util import tree_yield
18 |
19 |
@pytest.fixture
def example_tree():
    """Tuple of a list (int, array, dict with a Series) and a trailing int."""
    series = pd.Series([3, 4], index=["c", "d"])
    inner = [0, np.array([1, 2]), {"a": series, "b": 5}]
    return inner, 6
26 |
27 |
@pytest.fixture
def example_flat():
    """Leaves of ``example_tree`` under the default registry."""
    array = np.array([1, 2])
    series = pd.Series([3, 4], index=["c", "d"])
    return [0, array, series, 5, 6]
31 |
32 |
@pytest.fixture
def example_treedef(example_tree):
    """Expected treedef for the default registry.

    In pybaum the treedef is a copy of the tree itself, so the tree doubles as
    the expected treedef.

    """
    treedef = example_tree
    return treedef
36 |
37 |
@pytest.fixture
def extended_treedef(example_tree):
    """Expected treedef for the extended registry.

    The treedef is a copy of the tree regardless of which registry is used.

    """
    treedef = example_tree
    return treedef
41 |
42 |
@pytest.fixture
def extended_registry():
    """Registry that also treats DataFrames, Series and arrays as containers."""
    return get_registry(types=["pandas.DataFrame", "pandas.Series", "numpy.ndarray"])
47 |
48 |
def test_tree_flatten(example_tree, example_flat, example_treedef):
    flat, treedef = tree_flatten(example_tree)
    # Use tree_equal like the other tests: plain ``==`` on trees holding arrays or
    # Series only passes when the elements happen to be the very same objects.
    assert tree_equal(treedef, example_treedef)
    _assert_list_with_arrays_is_equal(flat, example_flat)
53 |
54 |
def test_extended_tree_flatten(example_tree, extended_treedef, extended_registry):
    leaves, treedef = tree_flatten(example_tree, registry=extended_registry)
    # Arrays and Series are split into scalars, so the leaves are 0, 1, ..., 6.
    assert leaves == [0, 1, 2, 3, 4, 5, 6]
    assert tree_equal(treedef, extended_treedef)
59 |
60 |
def test_tree_flatten_with_is_leave(example_tree, extended_registry):
    def array_is_leaf(tree):
        return isinstance(tree, np.ndarray)

    leaves, _ = tree_flatten(
        example_tree, is_leaf=array_is_leaf, registry=extended_registry
    )
    # The array stays whole while the Series is still split into scalars.
    _assert_list_with_arrays_is_equal(leaves, [0, np.array([1, 2]), 3, 4, 5, 6])
69 |
70 |
def test_tree_unflatten(example_flat, example_treedef, example_tree):
    rebuilt = tree_unflatten(example_treedef, example_flat)
    assert tree_equal(rebuilt, example_tree)
75 |
76 |
def test_extended_tree_unflatten(example_tree, extended_treedef, extended_registry):
    leaves = list(range(7))
    rebuilt = tree_unflatten(extended_treedef, leaves, registry=extended_registry)
    assert tree_equal(rebuilt, example_tree)
82 |
83 |
def test_tree_unflatten_with_is_leaf(example_tree, extended_registry):
    leaves = [0, np.array([1, 2]), 3, 4, 5, 6]
    rebuilt = tree_unflatten(
        example_tree,
        leaves,
        is_leaf=lambda tree: isinstance(tree, np.ndarray),
        registry=extended_registry,
    )
    assert tree_equal(rebuilt, example_tree)
92 |
93 |
def test_tree_map():
    tree = [{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}]

    def double(x):
        return x * 2

    assert tree_map(double, tree) == [{"a": 2, "b": 4, "c": {"d": 6, "e": 8}}]
99 |
100 |
def test_tree_multimap():
    tree = [{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}]
    # Multiplying a tree with itself leaf-wise must equal squaring every leaf.
    products = tree_multimap(lambda x, y: x * y, tree, tree)
    squares = tree_map(lambda x: x**2, tree)
    assert products == squares
106 |
107 |
def test_leaf_names(example_tree):
    expected = ["0*0", "0*1", "0*2*a", "0*2*b", "1"]
    assert leaf_names(example_tree, separator="*") == expected
113 |
114 |
def test_extended_leaf_names(example_tree, extended_registry):
    # Array elements are named by position, Series elements by index label.
    expected = ["0_0", "0_1_0", "0_1_1", "0_2_a_c", "0_2_a_d", "0_2_b", "1"]
    assert leaf_names(example_tree, registry=extended_registry) == expected
119 |
120 |
def test_leaf_names_with_is_leaf(example_tree, extended_registry):
    names = leaf_names(
        example_tree,
        registry=extended_registry,
        is_leaf=lambda tree: isinstance(tree, np.ndarray),
    )
    # The array stays a single leaf ("0_1") while the Series is still split.
    assert names == ["0_0", "0_1", "0_2_a_c", "0_2_a_d", "0_2_b", "1"]
129 |
130 |
def test_iterative_flatten_and_one_step_flatten_and_unflatten(
    example_tree, extended_registry
):
    # Two-step flattening: default registry first, then the extended registry.
    default_flat, default_treedef = tree_flatten(example_tree)
    two_step_flat, second_treedef = tree_flatten(
        default_flat, registry=extended_registry
    )
    # One-step flattening directly with the extended registry.
    one_step_flat, one_step_treedef = tree_flatten(
        example_tree, registry=extended_registry
    )
    assert two_step_flat == one_step_flat

    one_step_unflat = tree_unflatten(
        one_step_treedef, one_step_flat, registry=extended_registry
    )
    # Undo the two flattening steps in reverse order.
    restored_default_flat = tree_unflatten(
        second_treedef, two_step_flat, registry=extended_registry
    )
    two_step_unflat = tree_unflatten(default_treedef, restored_default_flat)

    assert tree_equal(one_step_unflat, example_tree)
    assert tree_equal(two_step_unflat, example_tree)
154 |
155 |
def test_tree_update(example_tree):
    other = ([7, np.array([8, 9]), {"b": 10}], 11)
    # Entries present in ``other`` replace those of ``example_tree``; the Series
    # under key "a" is absent from ``other`` and therefore kept.
    series = pd.Series([3, 4], index=["c", "d"])
    expected = ([7, np.array([8, 9]), {"a": series, "b": 10}], 11)
    assert tree_equal(tree_update(example_tree, other), expected)
165 |
166 |
def _assert_list_with_arrays_is_equal(list1, list2):
    """Assert that two lists of leaves are element-wise equal.

    numpy arrays are compared with ``assert_array_almost_equal``, pandas objects
    with ``.equals`` and everything else with ``==``. The lengths are checked
    first because ``zip`` would otherwise silently ignore trailing elements.

    """
    assert len(list1) == len(list2), f"length mismatch: {len(list1)} != {len(list2)}"
    for first, second in zip(list1, list2):
        if isinstance(first, np.ndarray):
            aaae(first, second)
        elif isinstance(first, (pd.DataFrame, pd.Series)):
            assert first.equals(second)
        else:
            assert first == second
175 |
176 |
def test_flatten_df_all_columns():
    registry = get_registry(types=["pandas.DataFrame"])
    df = pd.DataFrame({"value": [1, 2, 3], "bla": [4, 5, 6]}, index=["a", "b", "c"])
    flat, _ = tree_flatten(df, registry=registry)
    # Values are read row by row across all columns.
    assert flat == [1, 4, 2, 5, 3, 6]
186 |
187 |
def test_tree_yield(example_tree, example_treedef, example_flat):
    generator, treedef = tree_yield(example_tree)

    assert tree_equal(treedef, example_treedef)
    assert inspect.isgenerator(generator)
    # Exhaust the generator so missing or extra leaves are detected; zipping the
    # generator directly would silently stop at the shorter sequence.
    yielded = list(generator)
    assert len(yielded) == len(example_flat)
    for a, b in zip(yielded, example_flat):
        if isinstance(a, (np.ndarray, pd.Series)):
            aaae(a, b)
        else:
            assert a == b
198 |
199 |
def test_flatten_with_none():
    leaves, treedef = tree_flatten(None)
    assert treedef is None
    assert leaves == []
204 |
205 |
def test_leaf_names_with_none():
    assert leaf_names(None) == []
209 |
210 |
def test_flatten_with_namedtuple():
    bla_type = namedtuple("bla", ["a", "b"])
    leaves, _ = tree_flatten(bla_type(a=1, b=2))
    assert leaves == [1, 2]
215 |
216 |
def test_names_with_namedtuple():
    bla_type = namedtuple("bla", ["a", "b"])
    assert leaf_names(bla_type(a=1, b=2)) == ["a", "b"]
221 |
222 |
def test_flatten_with_ordered_dict():
    ordered = OrderedDict([("a", 1), ("b", 2)])
    leaves, _ = tree_flatten(ordered)
    assert leaves == [1, 2]
227 |
228 |
def test_names_with_ordered_dict():
    ordered = OrderedDict([("a", 1), ("b", 2)])
    assert leaf_names(ordered) == ["a", "b"]
233 |
--------------------------------------------------------------------------------
/src/pybaum/tree_util.py:
--------------------------------------------------------------------------------
1 | """Implement functionality similar to jax.tree_util in pure Python.
2 |
3 | The functions are not completely identical to jax. The most notable differences are:
4 |
5 | - Instead of a global registry of pytree nodes, most functions have a registry argument.
6 | - The treedef containing information to unflatten pytrees is implemented differently.
7 |
8 | """
9 | from pybaum.equality import EQUALITY_CHECKERS
10 | from pybaum.registry import get_registry
11 | from pybaum.typecheck import get_type
12 |
13 |
def tree_flatten(tree, is_leaf=None, registry=None):
    """Flatten a pytree into its leaves and create a treedef.

    Args:
        tree: A pytree to flatten.
        is_leaf (callable or None): Optional predicate evaluated at every
            flattening step. If it returns True for an object, that object is not
            traversed further and the whole subtree is treated as a single leaf.
        registry (dict or None): Pytree container registry that determines which
            types are containers to be flattened. ``is_leaf`` can only override it
            in one direction: registered types can be declared leaves, but
            unregistered types can never become containers. None means that the
            default registry is used, i.e. dicts, tuples and lists are containers.
            "extended" means that in addition numpy arrays and params DataFrames
            are containers. A dictionary mapping types to dicts with the entries
            "flatten", "unflatten" and "names" completely overrides the defaults.

    Returns:
        A pair ``(leaves, treedef)``: the list of leaf values and a treedef
        describing the structure of the flattened tree.

    """
    registry = _process_pytree_registry(registry)
    is_leaf = _process_is_leaf(is_leaf)

    leaves = _tree_flatten(tree, is_leaf=is_leaf, registry=registry)
    # The treedef is a structural copy of ``tree``, obtained by unflattening the
    # leaves that were just extracted.
    return leaves, tree_unflatten(tree, leaves, is_leaf=is_leaf, registry=registry)
46 |
47 |
def tree_just_flatten(tree, is_leaf=None, registry=None):
    """Flatten a pytree without creating a treedef.

    Unlike :func:`tree_flatten`, the leaves are not unflattened again to build a
    treedef, which makes this cheaper when only the leaves are needed.

    Args:
        tree: a pytree to flatten.
        is_leaf (callable or None): An optionally specified function that will be called
            at each flattening step. It should return a boolean, which indicates whether
            the flattening should traverse the current object, or if it should be
            stopped immediately, with the whole subtree being treated as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default registry
            is used, i.e. that dicts, tuples and lists are considered containers.
            "extended" means that in addition numpy arrays and params DataFrames are
            considered containers. Passing a dictionary where the keys are types and the
            values are dicts with the entries "flatten", "unflatten" and "names" allows
            to completely override the default registries.

    Returns:
        list: The leaf values of ``tree`` in depth-first order. No treedef is
        returned.

    """
    registry = _process_pytree_registry(registry)
    is_leaf = _process_is_leaf(is_leaf)

    return _tree_flatten(tree, is_leaf=is_leaf, registry=registry)
78 |
79 |
def _tree_flatten(tree, is_leaf, registry):
    """Recursively collect the leaves of ``tree`` in depth-first order.

    A node is returned as a leaf when its type is not a key of ``registry`` or
    when ``is_leaf`` returns True for it. Otherwise the registry's ``"flatten"``
    function splits it into children, which are processed in order.

    Args:
        tree: A pytree.
        is_leaf (callable): Leaf predicate (already processed, never None).
        registry (dict): Pytree registry (already processed, never None).

    Returns:
        list: The leaves of ``tree``.

    """
    tree_type = get_type(tree)
    if tree_type in registry and not is_leaf(tree):
        children, _ = registry[tree_type]["flatten"](tree)
        leaves = []
        for child in children:
            if get_type(child) not in registry:
                # Non-container children are leaves; is_leaf is not consulted.
                leaves.append(child)
            else:
                leaves.extend(_tree_flatten(child, is_leaf, registry))
        return leaves
    return [tree]
94 |
95 |
def tree_yield(tree, is_leaf=None, registry=None):
    """Lazily yield the leaves of a pytree and return its tree definition.

    Args:
        tree: a pytree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each yield step. It should return a boolean, which indicates
            whether the generator should traverse the current object, or if it
            should be stopped immediately, with the whole subtree being treated as a
            leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be yielded.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        tuple: ``(leaves, treedef)`` where ``leaves`` is a generator producing the
        leaf values in depth-first order and ``treedef`` is ``tree`` itself (not a
        copy), usable with :func:`tree_unflatten`.

    """
    leaf_generator = _tree_yield(
        tree,
        registry=_process_pytree_registry(registry),
        is_leaf=_process_is_leaf(is_leaf),
    )
    return leaf_generator, tree
126 |
127 |
def tree_just_yield(tree, is_leaf=None, registry=None):
    """Lazily yield the leaves of a pytree without creating a treedef.

    Args:
        tree: a pytree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each yield step. It should return a boolean, which indicates
            whether the generator should traverse the current object, or if it
            should be stopped immediately, with the whole subtree being treated as a
            leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be yielded.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        generator: Produces the leaf values of ``tree`` in depth-first order.

    """
    return _tree_yield(
        tree,
        registry=_process_pytree_registry(registry),
        is_leaf=_process_is_leaf(is_leaf),
    )
157 |
158 |
def _tree_yield(tree, is_leaf, registry):
    """Lazily yield the leaves of ``tree`` in depth-first order.

    Generator counterpart of ``_tree_flatten``: a node is yielded as a leaf when
    its type is not a key of ``registry`` or ``is_leaf`` returns True for it.
    Otherwise it is split with the registry's ``"flatten"`` function and its
    children are processed in order.

    Args:
        tree: A pytree.
        is_leaf (callable): Leaf predicate (already processed, never None).
        registry (dict): Pytree registry (already processed, never None).

    Yields:
        The leaves of ``tree``.

    """
    # The former unused ``out`` list and ``return out`` were dropped: in a
    # generator, ``return`` only sets StopIteration.value, which nobody reads.
    tree_type = get_type(tree)

    if tree_type not in registry or is_leaf(tree):
        yield tree
    else:
        subtrees, _ = registry[tree_type]["flatten"](tree)
        for subtree in subtrees:
            if get_type(subtree) in registry:
                yield from _tree_yield(subtree, is_leaf, registry)
            else:
                yield subtree
173 |
174 |
def tree_unflatten(treedef, leaves, is_leaf=None, registry=None):
    """Rebuild a pytree from a treedef and a sequence of leaves.

    This is the inverse of :func:`tree_flatten`.

    Args:
        treedef: The treedef holding the information needed for reconstruction.
        leaves (list): The leaves used for reconstruction. They must match the
            leaves of the treedef.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        A pytree with the structure described by ``treedef`` whose leaves are
        taken, in order, from ``leaves``.

    """
    return _tree_unflatten(
        treedef,
        leaves,
        registry=_process_pytree_registry(registry),
        is_leaf=_process_is_leaf(is_leaf),
    )
207 |
208 |
def _tree_unflatten(treedef, leaves, is_leaf, registry):
    """Recursively rebuild ``treedef`` with values drawn from ``leaves``.

    ``leaves`` is turned into an iterator once. Calling ``iter`` on an iterator
    returns the iterator itself, so the recursive calls share and advance the
    same stream, and leaves are consumed in depth-first order. Surplus leaves
    are left unconsumed.

    Args:
        treedef: A pytree whose structure is reproduced.
        leaves (iterable): Replacement values for the leaves of ``treedef``.
        is_leaf (callable): Leaf predicate (already processed, never None).
        registry (dict): Pytree registry (already processed, never None).

    Returns:
        The reconstructed pytree.

    Raises:
        ValueError: If ``leaves`` is exhausted before every leaf position of
            ``treedef`` has been filled.

    """
    leaves = iter(leaves)
    tree_type = get_type(treedef)

    if tree_type not in registry or is_leaf(treedef):
        return _next_leaf(leaves)

    items, info = registry[tree_type]["flatten"](treedef)
    unflattened_items = []
    for item in items:
        if get_type(item) in registry:
            unflattened_items.append(
                _tree_unflatten(item, leaves, is_leaf=is_leaf, registry=registry)
            )
        else:
            unflattened_items.append(_next_leaf(leaves))
    return registry[tree_type]["unflatten"](info, unflattened_items)


def _next_leaf(leaves):
    """Return the next leaf, raising ``ValueError`` instead of ``StopIteration``."""
    try:
        return next(leaves)
    except StopIteration:
        # A bare StopIteration escaping here would silently terminate an
        # enclosing ``map``/generator (or become RuntimeError under PEP 479)
        # instead of reporting the actual problem.
        raise ValueError(
            "Not enough leaves to fill the structure described by treedef."
        ) from None
226 |
227 |
def tree_map(func, tree, is_leaf=None, registry=None):
    """Return a copy of ``tree`` with ``func`` applied to every leaf.

    Args:
        func (callable): Function applied to each leaf in the tree.
        tree: A pytree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        A modified copy of ``tree``.

    """
    leaves, treedef = tree_flatten(tree, is_leaf=is_leaf, registry=registry)
    mapped_leaves = list(map(func, leaves))
    return tree_unflatten(treedef, mapped_leaves, is_leaf=is_leaf, registry=registry)
256 |
257 |
def tree_multimap(func, *trees, is_leaf=None, registry=None):
    """Apply func to leaves of multiple pytrees.

    Args:
        func (callable): Function applied to corresponding leaves of multiple
            pytrees. It receives one positional argument per tree.
        trees: An arbitrary number of pytrees (at least one). All trees need to have
            the same structure, i.e. the same leaf names as returned by
            :func:`leaf_names`. Because names are compared, containers of different
            types with identical names (e.g. a list and a tuple of equal length) are
            accepted; the result uses the containers of the first tree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        A tree with the same structure as the first element of ``trees``.

    Raises:
        ValueError: If no trees are passed or if the trees do not have the same
            structure.

    """
    if not trees:
        raise ValueError("At least one pytree is required.")

    first, *others = trees
    first_flat, treedef = tree_flatten(first, is_leaf=is_leaf, registry=registry)
    first_names = leaf_names(first, is_leaf=is_leaf, registry=registry)

    flat_trees = [first_flat]
    for other in others:
        # Compare leaf names, not treedefs. A treedef still contains the leaf
        # values, so ``treedef != other_treedef`` rejected trees with equal
        # structure but different leaves, and is ambiguous for array leaves.
        other_names = leaf_names(other, is_leaf=is_leaf, registry=registry)
        if other_names != first_names:
            raise ValueError("All trees must have the same structure.")
        flat_trees.append(tree_just_flatten(other, is_leaf=is_leaf, registry=registry))

    modified = [func(*item) for item in zip(*flat_trees)]

    return tree_unflatten(treedef, modified, is_leaf=is_leaf, registry=registry)
300 |
301 |
def leaf_names(tree, is_leaf=None, registry=None, separator="_"):
    """Build one name per leaf of a pytree.

    A leaf's name is made by joining the names of the containers on the path
    to it with ``separator``.

    Args:
        tree: a pytree to flatten.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.
        separator (str): String placed between the building blocks of a leaf name.

    Returns:
        list: The leaf names, in the same order as the flattened leaves.

    """
    return _leaf_names(
        tree,
        registry=_process_pytree_registry(registry),
        is_leaf=_process_is_leaf(is_leaf),
        separator=separator,
    )
332 |
333 |
def _leaf_names(tree, is_leaf, registry, separator, prefix=None):
    """Recursively build leaf names for ``tree``, prefixing them with ``prefix``.

    Args:
        tree: A pytree.
        is_leaf (callable): Leaf predicate (already processed, never None).
        registry (dict): Pytree registry (already processed, never None).
        separator (str): String placed between name building blocks.
        prefix (str or None): Name accumulated on the path to ``tree``.

    Returns:
        list: One name per leaf. When ``tree`` itself is a leaf, the result is
        ``[prefix]``, which is ``[None]`` at the top level.

    """
    tree_type = get_type(tree)
    if tree_type not in registry or is_leaf(tree):
        return [prefix]

    entry = registry[tree_type]
    children, _ = entry["flatten"](tree)
    child_names = entry["names"](tree)

    names = []
    for child_name, child in zip(child_names, children):
        is_container = get_type(child) in registry
        full_name = _add_prefix(prefix, child_name, separator)
        if is_container:
            names.extend(
                _leaf_names(
                    child,
                    is_leaf=is_leaf,
                    registry=registry,
                    separator=separator,
                    prefix=full_name,
                )
            )
        else:
            names.append(full_name)
    return names
355 |
356 |
357 | def _add_prefix(prefix, string, separator):
358 | if prefix not in (None, ""):
359 | out = separator.join([prefix, string])
360 | else:
361 | out = string
362 | return out
363 |
364 |
365 | def _process_pytree_registry(registry):
366 | registry = registry if registry is not None else get_registry()
367 | return registry
368 |
369 |
370 | def _process_is_leaf(is_leaf):
371 | if is_leaf is None:
372 | return lambda tree: False # noqa: U100
373 | else:
374 | return is_leaf
375 |
376 |
def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None):
    """Determine if two pytrees are equal.

    Two pytrees are considered equal if their leaves are equal and the names of their
    leaves are equal. While this definition of equality might not always make sense
    it makes sense in most cases and can be implemented relatively easily.

    Args:
        tree: A pytree.
        other: Another pytree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.
        equality_checkers (dict, None): A dictionary where keys are types and values
            are functions which assess equality for the type of object. Entries
            override those in ``EQUALITY_CHECKERS``. Types without a checker are
            compared with ``==``.

    Returns:
        bool: False as soon as the leaf names or a pair of leaves differ. Otherwise
        the result of the last leaf comparison (True for trees without leaves).

    """
    equality_checkers = (
        EQUALITY_CHECKERS
        if equality_checkers is None
        else {**EQUALITY_CHECKERS, **equality_checkers}
    )

    first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry)
    # The names of ``other`` must come from ``other``. Computing them from
    # ``tree`` made this check always pass.
    second_names = leaf_names(other, is_leaf=is_leaf, registry=registry)
    if first_names != second_names:
        return False

    # Equal names imply the same number of leaves, so zip does not truncate.
    first_flat = tree_just_flatten(tree, is_leaf=is_leaf, registry=registry)
    second_flat = tree_just_flatten(other, is_leaf=is_leaf, registry=registry)

    equal = True
    for first, second in zip(first_flat, second_flat):
        check_func = equality_checkers.get(get_type(first), lambda a, b: a == b)
        equal = check_func(first, second)
        if not equal:
            break

    return equal
430 |
431 |
def tree_update(tree, other, is_leaf=None, registry=None):
    """Update leaves in a pytree with leaves from another pytree.

    The second pytree must be compatible with the first one but can be smaller. For
    example, lists can be shorter, dictionaries can contain subsets of entries, etc.
    Leaves are matched by their names (see :func:`leaf_names`). Leaves of ``other``
    whose names do not occur in ``tree`` are ignored.

    Args:
        tree: A pytree.
        other: Another pytree.
        is_leaf (callable or None): An optionally specified function that will be
            called at each flattening step. It should return a boolean, which
            indicates whether the flattening should traverse the current object, or
            if it should be stopped immediately, with the whole subtree being treated
            as a leaf.
        registry (dict or None): A pytree container registry that determines
            which types are considered container objects that should be flattened.
            ``is_leaf`` can override this in the sense that types that are in the
            registry are still considered a leaf but it cannot declare something a
            container that is not in the registry. None means that the default
            registry is used, i.e. that dicts, tuples and lists are considered
            containers. "extended" means that in addition numpy arrays and params
            DataFrames are considered containers. Passing a dictionary where the keys
            are types and the values are dicts with the entries "flatten",
            "unflatten" and "names" allows to completely override the default
            registries.

    Returns:
        Updated pytree with the structure of ``tree``.

    """
    first_flat, first_treedef = tree_flatten(tree, is_leaf=is_leaf, registry=registry)
    first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry)

    # Only the leaves of ``other`` are needed, so no treedef is built for it.
    other_flat = tree_just_flatten(other, is_leaf=is_leaf, registry=registry)
    other_names = leaf_names(other, is_leaf=is_leaf, registry=registry)
    other_dict = dict(zip(other_names, other_flat))

    # Produce exactly one value per leaf position of ``tree``. A name-keyed dict
    # merge would collapse leaves whose names collide (e.g. {"a_b": 1,
    # "a": {"b": 2}}), misaligning values and leaf positions.
    combined = [
        other_dict.get(name, value) for name, value in zip(first_names, first_flat)
    ]

    return tree_unflatten(first_treedef, combined, is_leaf=is_leaf, registry=registry)
471 |
--------------------------------------------------------------------------------