├── .github └── workflows │ ├── docs.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Makefile ├── README.md ├── _static │ ├── icon.svg │ └── style.css ├── api │ ├── utils.rst │ └── wrappers.rst ├── conf.py ├── index.rst └── make.bat ├── paramax ├── __init__.py ├── py.typed ├── utils.py └── wrappers.py ├── pyproject.toml └── tests ├── test_utils.py └── test_wrappers.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Test and publish documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | docs: 10 | name: Build and publish documentation 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.x" 17 | 18 | - name: Install pandoc 19 | run: | 20 | sudo apt-get update 21 | sudo apt-get install -y pandoc 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e .[dev] 26 | - name: Test documentation 27 | run: | 28 | make -C docs doctest 29 | - name: Sphinx build 30 | run: | 31 | sphinx-build docs docs/_build 32 | - name: Deploy 33 | uses: peaceiris/actions-gh-pages@v4 34 | with: 35 | publish_branch: gh-pages 36 | github_token: ${{ secrets.GITHUB_TOKEN }} 37 | publish_dir: docs/_build/ 38 | force_orphan: true 39 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | # Triggers the workflow when a release is created or edited. 
5 | release: 6 | types: [created] 7 | 8 | jobs: 9 | build: 10 | name: Build distribution 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: "3.x" 18 | 19 | - name: Install pypa/build 20 | run: python3 -m pip install build --user 21 | - name: Build a binary wheel and a source tarball 22 | run: python3 -m build 23 | - name: Store the distribution packages 24 | uses: actions/upload-artifact@v4 25 | with: 26 | name: python-package-distributions 27 | path: dist/ 28 | 29 | publish-to-pypi: 30 | name: Publish to PyPI 31 | needs: 32 | - build 33 | runs-on: ubuntu-latest 34 | 35 | environment: 36 | name: pypi 37 | url: https://pypi.org/p/paramax 38 | 39 | permissions: 40 | id-token: write # Mandatory for trusted publishing 41 | 42 | steps: 43 | - name: Download all the dists 44 | uses: actions/download-artifact@v4 45 | with: 46 | name: python-package-distributions 47 | path: dist/ 48 | - name: Publish distribution 📦 to PyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | Test: 13 | name: Tests 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 10 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.x" 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -e .[dev] 27 | 28 | - name: Pytest 29 | run: | 30 | pytest 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | .pylintrc 3 | 4 
| # C extensions 5 | *.so 6 | 7 | # Packages 8 | *.egg 9 | *.egg-info 10 | dist 11 | build 12 | eggs 13 | parts 14 | bin 15 | var 16 | sdist 17 | develop-eggs 18 | .installed.cfg 19 | lib 20 | lib64 21 | __pycache__ 22 | 23 | # Installer logs 24 | pip-log.txt 25 | 26 | # Unit test / coverage reports 27 | .coverage 28 | .tox 29 | nosetests.xml 30 | 31 | # Translations 32 | *.mo 33 | 34 | # Mr Developer 35 | .mr.developer.cfg 36 | .project 37 | .pydevproject 38 | 39 | # Personal 40 | .ipynb_checkpoints 41 | .mypy* 42 | .vscode/ 43 | 44 | timing_tests.ipynb 45 | _temp* 46 | .vscode/settings.json 47 | _build 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 Daniel Ward 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Paramax 3 | ============ 4 | Parameterizations and constraints for JAX PyTrees 5 | ----------------------------------------------------------------------- 6 | 7 | Paramax allows applying custom constraints or behaviors to PyTree components, 8 | using unwrappable placeholders. This can be used for 9 | - Enforcing positivity (e.g., scale parameters) 10 | - Structured matrices (triangular, symmetric, etc.) 11 | - Applying tricks like weight normalization 12 | - Marking components as non-trainable 13 | 14 | Some benefits of the unwrappable pattern: 15 | - It allows parameterizations to be computed once for a model (e.g. at the top of the 16 | loss function). 17 | - It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees 18 | from external libraries 19 | - It is concise 20 | 21 | If you found the package useful, please consider giving it a star on github, and if you 22 | create ``AbstractUnwrappable``s that may be of interest to others, a pull request would 23 | be much appreciated! 24 | 25 | ## Documentation 26 | 27 | Documentation available [here](https://danielward27.github.io/paramax/). 28 | 29 | ## Installation 30 | ```bash 31 | pip install paramax 32 | ``` 33 | 34 | ## Example 35 | ```python 36 | >>> import paramax 37 | >>> import jax.numpy as jnp 38 | >>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3))) # Enforce positivity 39 | >>> paramax.unwrap(("abc", 1, scale)) 40 | ('abc', 1, Array([1., 1., 1.], dtype=float32)) 41 | ``` 42 | 43 | ## Alternative parameterization patterns 44 | Using properties to access parameterized model components is common but has drawbacks: 45 | - Parameterizations are tied to class definition, limiting flexibility e.g. 
this 46 | cannot be used on PyTrees from external libraries 47 | - It can become verbose with many parameters 48 | - It often leads to repeatedly computing the parameterization 49 | 50 | ## Related 51 | - We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, to register 52 | the PyTrees used in the package 53 | - This package spawned out of a need for a simple method to apply parameter constraints 54 | in the distributions package [flowjax](https://github.com/danielward27/flowjax) 55 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS = -W 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ### Documentation 2 | 3 | The documentation is supported by [Sphinx](https://www.sphinx-doc.org/en/master/). To build the HTML pages locally, run 4 | 5 | ``` 6 | make -C docs html 7 | ``` 8 | from the top level directory. The documentation can then be viewed by opening `./docs/_build/html/index.html`. 
To test the doctest code blocks in the documentation, run from the top level directory: 9 | ``` 10 | make -C docs doctest 11 | ``` 12 | 13 | Github Actions is used for continuous integration, and the tests will fail if either the documentation does not build, or any doctest examples fail. 14 | -------------------------------------------------------------------------------- /docs/_static/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | /* Clearer distinction between each class. */ 4 | dl.py.class { 5 | border-left: 3px solid #008080; 6 | padding: 10px; 7 | } 8 | 9 | /* Remove unnecessary scroll bar https://github.com/executablebooks/sphinx-book-theme/issues/732 */ 10 | #rtd-footer-container { 11 | margin-top: 0px !important; 12 | margin-bottom: 0px !important; 13 | } -------------------------------------------------------------------------------- /docs/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ============================================= 3 | .. automodule:: paramax.utils 4 | :members: 5 | :undoc-members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/api/wrappers.rst: -------------------------------------------------------------------------------- 1 | Wrappers 2 | ============================================= 3 | .. 
automodule:: paramax.wrappers 4 | :members: 5 | :undoc-members: 6 | :member-order: bysource -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder.""" 2 | 3 | import sys 4 | import typing 5 | from pathlib import Path 6 | 7 | if "doctest" not in sys.argv: # Avoid type checking/isinstance failures. 8 | # Tag used to avoid expanding arraylike alias in docs 9 | typing.GENERATING_DOCUMENTATION = True 10 | 11 | 12 | sys.path.insert(0, Path("..").resolve()) 13 | 14 | project = "Paramax" 15 | copyright = "2022, Daniel Ward" 16 | author = "Daniel Ward" 17 | 18 | extensions = [ 19 | "sphinx.ext.viewcode", 20 | "sphinx.ext.autodoc", 21 | "sphinx.ext.doctest", 22 | "sphinx.ext.intersphinx", 23 | "sphinx_copybutton", 24 | "sphinx.ext.napoleon", 25 | "sphinx_autodoc_typehints", 26 | ] 27 | 28 | intersphinx_mapping = { 29 | "python": ("https://docs.python.org/3/", None), 30 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 31 | } 32 | 33 | templates_path = ["_templates"] 34 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 35 | 36 | 37 | html_theme = "sphinx_book_theme" 38 | html_static_path = ["_static"] 39 | 40 | html_css_files = ["style.css"] 41 | 42 | html_theme_options = { 43 | "use_fullscreen_button": False, 44 | "use_download_button": False, 45 | "use_repository_button": True, 46 | "repository_url": "https://github.com/danielward27/paramax", 47 | "home_page_in_toc": True, 48 | } 49 | 50 | html_title = "Paramax" 51 | html_favicon = "_static/icon.svg" 52 | 53 | pygments_style = "xcode" 54 | 55 | copybutton_prompt_text = r">>> |\.\.\. 
" 56 | copybutton_prompt_is_regexp = True 57 | 58 | napolean_use_rtype = False 59 | napoleon_attr_annotations = True 60 | napoleon_use_ivar = True 61 | 62 | add_module_names = False 63 | autodoc_inherit_docstrings = False 64 | python_maximum_signature_line_length = 88 65 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Paramax 2 | =========== 3 | 4 | A small package for applying parameterizations and constraints to nodes in JAX 5 | PyTrees. 6 | 7 | 8 | Installation 9 | ------------------------ 10 | .. code-block:: bash 11 | 12 | pip install paramax 13 | 14 | 15 | How it works 16 | ------------------ 17 | - :py:class:`~paramax.wrappers.AbstractUnwrappable` objects act as placeholders in the 18 | PyTree, defining the parameterizations. 19 | - :py:func:`~paramax.wrappers.unwrap` applies the parameterizations, replacing the 20 | :py:class:`~paramax.wrappers.AbstractUnwrappable` objects. 21 | 22 | A simple example of an :py:class:`~paramax.wrappers.AbstractUnwrappable` 23 | is :py:class:`~paramax.wrappers.Parameterize`. This class takes a callable and any 24 | positional or keyword arguments, which are stored and passed to the function when 25 | unwrapping. 26 | 27 | 28 | .. doctest:: 29 | 30 | >>> import paramax 31 | >>> import jax.numpy as jnp 32 | >>> scale = jnp.ones(3) # Keep this positive 33 | >>> constrained_scale = paramax.Parameterize(jnp.exp, jnp.log(scale)) 34 | >>> model = ("abc", 1, constrained_scale) # Any PyTree 35 | >>> paramax.unwrap(model) # Unwraps any AbstractUnwrappables 36 | ('abc', 1, Array([1., 1., 1.], dtype=float32)) 37 | 38 | 39 | Many simple parameterizations can be handled with this class, for example, 40 | we can parameterize a lower triangular matrix using 41 | 42 | .. 
doctest:: 43 | 44 | >>> import paramax 45 | >>> import jax.numpy as jnp 46 | >>> tril = jnp.tril(jnp.ones((3,3))) 47 | >>> tril = paramax.Parameterize(jnp.tril, tril) 48 | 49 | 50 | See :doc:`/api/wrappers` for more :py:class:`~paramax.wrappers.AbstractUnwrappable` 51 | objects. 52 | 53 | When to unwrap 54 | ------------------- 55 | - Unwrap whenever necessary, typically at the top of loss functions, functions or 56 | methods requiring the parameterizations to have been applied. 57 | - Unwrapping prior to a gradient computation used for optimization is usually a mistake! 58 | 59 | 60 | .. toctree:: 61 | :caption: API 62 | :maxdepth: 1 63 | :glob: 64 | 65 | api/wrappers 66 | api/utils 67 | 68 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /paramax/__init__.py: -------------------------------------------------------------------------------- 1 | """Paramax - Parameterizations and constraints for PyTrees.""" 2 | 3 | from importlib.metadata import version 4 | 5 | from .wrappers import ( 6 | AbstractUnwrappable, 7 | NonTrainable, 8 | Parameterize, 9 | WeightNormalization, 10 | contains_unwrappables, 11 | non_trainable, 12 | unwrap, 13 | ) 14 | 15 | __version__ = version("paramax") 16 | 17 | __all__ = [ 18 | "AbstractUnwrappable", 19 | "NonTrainable", 20 | "Parameterize", 21 | "WeightNormalization", 22 | "contains_unwrappables", 23 | "non_trainable", 24 | "unwrap", 25 | ] 26 | -------------------------------------------------------------------------------- /paramax/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /paramax/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | import equinox as eqx 4 | import jax.numpy as jnp 5 | from jaxtyping import Array, ArrayLike 6 | 7 | 8 | def inv_softplus(x: ArrayLike) -> Array: 9 | """The inverse of the softplus function, checking for positive inputs.""" 10 | x = eqx.error_if( 11 | x, 12 | x < 0, 13 | "Expected positive inputs to inv_softplus. 
If you are trying to use a negative " 14 | "scale parameter, you may be able to construct with positive scales, and " 15 | "modify the scale attribute post-construction, e.g., using eqx.tree_at.", 16 | ) 17 | return jnp.log(-jnp.expm1(-x)) + x 18 | -------------------------------------------------------------------------------- /paramax/wrappers.py: -------------------------------------------------------------------------------- 1 | """:class:`AbstractUnwrappable` objects and utilities. 2 | 3 | These are placeholder values for specifying custom behaviour for nodes in a pytree, 4 | applied using :func:`unwrap`. 5 | """ 6 | 7 | from abc import abstractmethod 8 | from collections.abc import Callable 9 | from typing import Any, Generic, TypeVar 10 | 11 | import equinox as eqx 12 | import jax 13 | import jax.numpy as jnp 14 | from jax import lax 15 | from jax.nn import softplus 16 | from jax.tree_util import tree_leaves 17 | from jaxtyping import Array, PyTree 18 | 19 | from paramax.utils import inv_softplus 20 | 21 | T = TypeVar("T") 22 | 23 | 24 | class AbstractUnwrappable(eqx.Module, Generic[T]): 25 | """An abstract class representing an unwrappable object. 26 | 27 | Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping. 28 | """ 29 | 30 | @abstractmethod 31 | def unwrap(self) -> T: 32 | """Returns the unwrapped pytree, assuming no wrapped subnodes exist.""" 33 | pass 34 | 35 | 36 | def unwrap(tree: PyTree): 37 | """Map across a PyTree and unwrap all :class:`AbstractUnwrappable` nodes. 38 | 39 | This leaves all other nodes unchanged. If nested, the innermost 40 | ``AbstractUnwrappable`` nodes are unwrapped first. 41 | 42 | Example: 43 | Enforcing positivity. 44 | 45 | .. 
doctest:: 46 | 47 | >>> import paramax 48 | >>> import jax.numpy as jnp 49 | >>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3)) 50 | >>> paramax.unwrap(("abc", 1, params)) 51 | ('abc', 1, Array([1., 1., 1.], dtype=float32)) 52 | """ 53 | 54 | def _unwrap(tree, *, include_self: bool): 55 | def _map_fn(leaf): 56 | if isinstance(leaf, AbstractUnwrappable): 57 | # Unwrap subnodes, then itself 58 | return _unwrap(leaf, include_self=False).unwrap() 59 | return leaf 60 | 61 | def is_leaf(x): 62 | is_unwrappable = isinstance(x, AbstractUnwrappable) 63 | included = include_self or x is not tree 64 | return is_unwrappable and included 65 | 66 | return jax.tree_util.tree_map(f=_map_fn, tree=tree, is_leaf=is_leaf) 67 | 68 | return _unwrap(tree, include_self=True) 69 | 70 | 71 | class Parameterize(AbstractUnwrappable[T]): 72 | """Unwrap an object by calling fn with args and kwargs. 73 | 74 | All of fn, args and kwargs may contain trainable parameters. 75 | 76 | .. note:: 77 | 78 | Unwrapping typically occurs after model initialization. Therefore, if the 79 | ``Parameterize`` object may be created in a vectorized context, we recommend 80 | ensuring that ``fn`` still unwraps correctly, e.g. by supporting broadcasting. 81 | 82 | Example: 83 | .. doctest:: 84 | 85 | >>> from paramax.wrappers import Parameterize, unwrap 86 | >>> import jax.numpy as jnp 87 | >>> positive = Parameterize(jnp.exp, jnp.zeros(3)) 88 | >>> unwrap(positive) # Applies exp on unwrapping 89 | Array([1., 1., 1.], dtype=float32) 90 | 91 | Args: 92 | fn: Callable to call with args and kwargs. 93 | *args: Positional arguments to pass to fn. 94 | **kwargs: Keyword arguments to pass to fn. 95 | """ 96 | 97 | fn: Callable[..., T] 98 | args: tuple[Any, ...] 
99 | kwargs: dict[str, Any] 100 | 101 | def __init__(self, fn: Callable, *args, **kwargs): 102 | self.fn = fn 103 | self.args = tuple(args) 104 | self.kwargs = kwargs 105 | 106 | def unwrap(self) -> T: 107 | return self.fn(*self.args, **self.kwargs) 108 | 109 | 110 | def non_trainable(tree: PyTree): 111 | """Freezes parameters by wrapping inexact array leaves with :class:`NonTrainable`. 112 | 113 | .. note:: 114 | 115 | Regularization is likely to apply before unwrapping. To avoid regularization 116 | impacting non-trainable parameters, they should be filtered out, 117 | for example using: 118 | 119 | .. code-block:: python 120 | 121 | >>> eqx.partition( 122 | ... ..., 123 | ... is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable), 124 | ... ) 125 | 126 | 127 | Wrapping the arrays in a model rather than the entire tree is often preferable, 128 | allowing easier access to attributes compared to wrapping the entire tree. 129 | 130 | Args: 131 | tree: The pytree. 132 | """ 133 | 134 | def _map_fn(leaf): 135 | return NonTrainable(leaf) if eqx.is_inexact_array(leaf) else leaf 136 | 137 | return jax.tree_util.tree_map( 138 | f=_map_fn, 139 | tree=tree, 140 | is_leaf=lambda x: isinstance(x, NonTrainable), 141 | ) 142 | 143 | 144 | class NonTrainable(AbstractUnwrappable[T]): 145 | """Applies stop gradient to all arraylike leaves before unwrapping. 146 | 147 | See also :func:`non_trainable`, which is probably a generally preferable way to 148 | achieve similar behaviour, which wraps the arraylike leaves directly, rather than 149 | the tree. Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. 150 | Note that the underlying parameters may still be impacted by regularization, 151 | so it is generally advised to use this as a suggestively named class 152 | for filtering parameters. 
153 | """ 154 | 155 | tree: T 156 | 157 | def unwrap(self) -> T: 158 | differentiable, static = eqx.partition(self.tree, eqx.is_array_like) 159 | return eqx.combine(lax.stop_gradient(differentiable), static) 160 | 161 | 162 | class WeightNormalization(AbstractUnwrappable[Array]): 163 | """Applies weight normalization (https://arxiv.org/abs/1602.07868). 164 | 165 | Args: 166 | weight: The (possibly wrapped) weight matrix. 167 | """ 168 | 169 | weight: Array | AbstractUnwrappable[Array] 170 | scale: Array | AbstractUnwrappable[Array] 171 | 172 | def __init__(self, weight: Array | AbstractUnwrappable[Array]): 173 | self.weight = weight 174 | scale_init = 1 / jnp.linalg.norm(unwrap(weight), axis=-1, keepdims=True) 175 | self.scale = Parameterize(softplus, inv_softplus(scale_init)) 176 | 177 | def unwrap(self) -> Array: 178 | weight_norms = jnp.linalg.norm(self.weight, axis=-1, keepdims=True) 179 | return self.scale * self.weight / weight_norms 180 | 181 | 182 | def contains_unwrappables(pytree): 183 | """Check if a pytree contains unwrappables.""" 184 | 185 | def _is_unwrappable(leaf): 186 | return isinstance(leaf, AbstractUnwrappable) 187 | 188 | leaves = tree_leaves(pytree, is_leaf=_is_unwrappable) 189 | return any(_is_unwrappable(leaf) for leaf in leaves) 190 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = [{ name = "Daniel Ward", email = "danielward27@outlook.com" }] 3 | classifiers = [ 4 | "Intended Audience :: Science/Research", 5 | "License :: OSI Approved :: MIT License", 6 | "Natural Language :: English", 7 | "Programming Language :: Python :: 3", 8 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 9 | "Topic :: Scientific/Engineering :: Information Analysis", 10 | "Topic :: Scientific/Engineering :: Mathematics", 11 | "Typing :: Typed", 12 | ] 13 | dependencies = ["jax", "equinox", 
"jaxtyping"] 14 | description = "Parameterizations and parameter constraints for JAX PyTrees." 15 | keywords = ["jax", "neural-networks", "equinox"] 16 | license = { file = "LICENSE" } 17 | name = "paramax" 18 | readme = "README.md" 19 | requires-python = ">=3.10" 20 | version = "0.0.3" 21 | 22 | [project.urls] 23 | repository = "https://github.com/danielward27/paramax" 24 | documentation = "https://danielward27.github.io/paramax/index.html" 25 | 26 | [project.optional-dependencies] 27 | dev = [ 28 | "pytest", 29 | "beartype", 30 | "ruff", 31 | "sphinx", 32 | "sphinx-book-theme", 33 | "sphinx-copybutton", 34 | "sphinx-autodoc-typehints", 35 | ] 36 | 37 | [build-system] 38 | requires = ["hatchling"] 39 | build-backend = "hatchling.build" 40 | 41 | [tool.pytest.ini_options] 42 | pythonpath = ["."] 43 | addopts = "--jaxtyping-packages=paramax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" 44 | 45 | [tool.ruff] 46 | include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] 47 | 48 | [tool.ruff.lint] 49 | select = ["E", "F", "B", "D", "COM", "I", "UP", "TRY004", "RET", "PT", "FBT"] 50 | ignore = ["D102", "D105", "D107", "B028", "COM812", "F722"] 51 | 52 | [tool.ruff.lint.pydocstyle] 53 | convention = "google" 54 | 55 | [tool.ruff.lint.per-file-ignores] 56 | "tests/*" = ["D"] 57 | "*.ipynb" = ["D"] 58 | "__init__.py" = ["D"] 59 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | from jax.nn import softplus 4 | 5 | from paramax.utils import inv_softplus 6 | 7 | 8 | def test_inv_softplus(): 9 | x = jnp.arange(3) + 1 10 | y = softplus(x) 11 | x_reconstructed = inv_softplus(y) 12 | assert pytest.approx(x) == x_reconstructed 13 | -------------------------------------------------------------------------------- /tests/test_wrappers.py: 
-------------------------------------------------------------------------------- 1 | from math import prod 2 | 3 | import equinox as eqx 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import pytest 7 | from jax.tree_util import tree_map 8 | 9 | from paramax.wrappers import ( 10 | NonTrainable, 11 | Parameterize, 12 | WeightNormalization, 13 | non_trainable, 14 | unwrap, 15 | ) 16 | 17 | 18 | def test_Parameterize(): 19 | diag = Parameterize(jnp.diag, jnp.ones(3)) 20 | assert pytest.approx(jnp.eye(3)) == unwrap(diag) 21 | 22 | 23 | def test_nested_unwrap(): 24 | param = Parameterize( 25 | jnp.square, 26 | Parameterize(jnp.square, Parameterize(jnp.square, 2)), 27 | ) 28 | assert unwrap(param) == jnp.square(jnp.square(jnp.square(2))) 29 | 30 | 31 | def test_non_trainable(): 32 | 33 | model = (jnp.ones(3), 1) 34 | model = non_trainable(model) 35 | 36 | def loss(model): 37 | model = unwrap(model) 38 | return model[0].sum() 39 | 40 | grad = eqx.filter_grad(loss)(model)[0].tree 41 | assert grad.shape == (3,) 42 | assert pytest.approx(0) == grad 43 | 44 | 45 | def test_WeightNormalization(): 46 | arr = jr.normal(jr.key(0), (10, 3)) 47 | weight_norm = WeightNormalization(arr) 48 | 49 | # Unwrapped norms should match weightnorm scale parameter 50 | expected = unwrap(weight_norm.scale) 51 | assert pytest.approx(expected) == jnp.linalg.norm( 52 | unwrap(weight_norm), axis=-1, keepdims=True 53 | ) 54 | 55 | 56 | test_cases = { 57 | "NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)), 58 | "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)), 59 | "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))), 60 | } 61 | 62 | 63 | @pytest.mark.parametrize("shape", [(), (2,), (5, 2, 4)]) 64 | @pytest.mark.parametrize("wrapper_fn", test_cases.values(), ids=test_cases.keys()) 65 | def test_vectorization_invariance(wrapper_fn, shape): 66 | keys = jr.split(jr.key(0), prod(shape)) 67 | wrapper = wrapper_fn(keys[0]) # 
Standard init 68 | 69 | # Multiple vmap init - should have same result in zero-th index 70 | vmap_wrapper_fn = wrapper_fn 71 | for _ in shape: 72 | vmap_wrapper_fn = eqx.filter_vmap(vmap_wrapper_fn) 73 | 74 | vmap_wrapper = vmap_wrapper_fn(keys.reshape(shape)) 75 | 76 | unwrapped = unwrap(wrapper) 77 | unwrapped_vmap = unwrap(vmap_wrapper) 78 | unwrapped_vmap_zero = tree_map( 79 | lambda leaf: leaf[*([0] * len(shape)), ...], 80 | unwrapped_vmap, 81 | ) 82 | assert eqx.tree_equal(unwrapped, unwrapped_vmap_zero, atol=1e-7) 83 | --------------------------------------------------------------------------------