├── .github └── workflows │ ├── release.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pyproject.toml ├── sympy2jax ├── __init__.py └── sympy_module.py └── tests └── test_symbolic_module.py /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Release 13 | uses: patrick-kidger/action_update_python_project@v6 14 | with: 15 | python-version: "3.11" 16 | test-script: | 17 | python -m pip install pytest jax jaxlib sympy equinox 18 | cp -r ${{ github.workspace }}/tests ./tests 19 | pytest 20 | pypi-token: ${{ secrets.pypi_token }} 21 | github-user: patrick-kidger 22 | github-token: ${{ github.token }} 23 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | run-tests: 8 | strategy: 9 | matrix: 10 | python-version: [ "3.10", "3.12" ] 11 | os: [ ubuntu-latest ] 12 | fail-fast: false 13 | runs-on: ${{ matrix.os }} 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install pytest wheel jaxlib sympy equinox 27 | 28 | - name: Checks with pre-commit 29 | uses: pre-commit/action@v2.0.3 30 | 31 | - name: Test with pytest 32 | run: | 33 | python -m pip install . 
34 | python -m pytest --durations=0 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | *.egg-info 3 | build/ 4 | dist/ 5 | 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.2.2 4 | hooks: 5 | - id: ruff-format # formatter 6 | types_or: [ python, pyi, jupyter ] 7 | - id: ruff # linter 8 | types_or: [ python, pyi, jupyter ] 9 | args: [ --fix ] 10 | - repo: https://github.com/RobertCraigie/pyright-python 11 | rev: v1.1.350 12 | hooks: 13 | - id: pyright 14 | additional_dependencies: [equinox, jax, sympy] 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome! Here's how to get started. 4 | 5 | --- 6 | 7 | First fork the library on GitHub. 8 | 9 | Then clone and install the library in development mode: 10 | 11 | ```bash 12 | git clone https://github.com/your-username-here/sympy2jax.git 13 | cd sympy2jax 14 | pip install -e . 15 | ``` 16 | 17 | Then install the pre-commit hook: 18 | 19 | ```bash 20 | pip install pre-commit 21 | pre-commit install 22 | ``` 23 | 24 | These hooks use ruff to format and lint the code, and pyright to type-check it. 25 | 26 | Now make your changes. Make sure to include additional tests if necessary. 27 | 28 | Next verify the tests all pass: 29 | 30 | ```bash 31 | pip install pytest 32 | pytest 33 | ``` 34 | 35 | Then push your changes back to your fork of the repository: 36 | 37 | ```bash 38 | git push 39 | ``` 40 | 41 | Finally, open a pull request on GitHub!
42 | 43 | ## Contributor License Agreement 44 | 45 | Contributions to this project must be accompanied by a Contributor License 46 | Agreement (CLA). You (or your employer) retain the copyright to your 47 | contribution; this simply gives us permission to use and redistribute your 48 | contributions as part of the project. Head over to 49 | https://cla.developers.google.com/ to see your current agreements on file or 50 | to sign a new one. 51 | 52 | You generally only need to submit a CLA once, so if you've already submitted one 53 | (even if it was for a different project), you probably don't need to do it 54 | again. 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License.
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

sympy2jax

2 | 3 | Turn SymPy expressions into trainable JAX expressions. The output will be an [Equinox](https://github.com/patrick-kidger/equinox) module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs. 4 | 5 | Optimise your symbolic expressions via gradient descent! 6 | 7 | ## Installation 8 | 9 | ```bash 10 | pip install sympy2jax 11 | ``` 12 | 13 | Requires: 14 | Python 3.9+ 15 | JAX 0.3.4+ 16 | Equinox 0.5.3+ 17 | SymPy 1.7.1+. 18 | 19 | ## Example 20 | 21 | ```python 22 | import jax 23 | import sympy 24 | import sympy2jax 25 | 26 | x_sym = sympy.symbols("x_sym") 27 | cosx = 1.0 * sympy.cos(x_sym) 28 | sinx = 2.0 * sympy.sin(x_sym) 29 | mod = sympy2jax.SymbolicModule([cosx, sinx]) # PyTree of input expressions 30 | 31 | x = jax.numpy.zeros(3) 32 | out = mod(x_sym=x) # PyTree of results. 33 | params = jax.tree_util.tree_leaves(mod) # 1.0 and 2.0 are parameters. 34 | # (Which may be trained in the usual way for Equinox.) 35 | ``` 36 | 37 | ## Documentation 38 | 39 | ```python 40 | sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True) 41 | ``` 42 | 43 | Where: 44 | - `expressions` is a PyTree of SymPy expressions. 45 | - `extra_funcs` is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules. 46 | - `make_array` controls how integers/floats/rationals are stored: as JAX arrays if `True` (the default), or as Python integers/etc. if `False`. 47 | 48 | Instances can be called with key-value pairs of symbol-value, as in the above example. 49 | 50 | Instances have a `.sympy()` method that translates the module back into a PyTree of SymPy expressions. (It raises `NotImplementedError` if `extra_funcs` was passed.) 51 | 52 | (That's literally the entire documentation, it's super easy.) 53 | 54 | ## See also: other libraries in the JAX ecosystem 55 | 56 | **Always useful** 57 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
58 | [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. 59 | 60 | **Deep learning** 61 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. 62 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). 63 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). 64 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. 65 | 66 | **Scientific computing** 67 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. 68 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. 69 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers. 70 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. 71 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) 72 | 73 | **Awesome JAX** 74 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. 75 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sympy2jax" 3 | version = "0.0.7" 4 | description = "Turn SymPy expressions into trainable JAX expressions." 
5 | readme = "README.md" 6 | requires-python ="~=3.9" 7 | license = {file = "LICENSE"} 8 | authors = [ 9 | {name = "Patrick Kidger", email = "contact@kidger.site"}, 10 | ] 11 | keywords = ["jax", "sympy", "equinox"] 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Science/Research", 15 | "License :: OSI Approved :: Apache Software License", 16 | "Natural Language :: English", 17 | "Programming Language :: Python :: 3", 18 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 19 | "Topic :: Scientific/Engineering :: Mathematics", 20 | ] 21 | urls = {repository = "https://github.com/google/sympy2jax" } 22 | dependencies = ["equinox>=0.5.3", "jax>=0.3.4", "sympy>=1.7.1"] 23 | 24 | [build-system] 25 | requires = ["hatchling"] 26 | build-backend = "hatchling.build" 27 | 28 | [tool.hatch.build] 29 | include = ["sympy2jax/*"] 30 | 31 | [tool.pytest.ini_options] 32 | addopts = "--jaxtyping-packages=symyp2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" 33 | 34 | [tool.ruff] 35 | select = ["E", "F", "I001"] 36 | ignore = ["E402", "E721", "E731", "E741", "F722"] 37 | ignore-init-module-imports = true 38 | fixable = ["I001", "F401"] 39 | 40 | [tool.ruff.isort] 41 | combine-as-imports = true 42 | lines-after-imports = 2 43 | extra-standard-library = ["typing_extensions"] 44 | order-by-type = false 45 | 46 | [tool.pyright] 47 | reportIncompatibleMethodOverride = true 48 | include = ["sympy2jax", "tests"] 49 | -------------------------------------------------------------------------------- /sympy2jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sympy_module import ( 16 | concatenate as concatenate, 17 | stack as stack, 18 | SymbolicModule as SymbolicModule, 19 | ) 20 | 21 | 22 | __version__ = "0.0.4" 23 | -------------------------------------------------------------------------------- /sympy2jax/sympy_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import abc 16 | import collections as co 17 | import functools as ft 18 | from collections.abc import Callable, Mapping 19 | from typing import Any, cast, Optional 20 | 21 | import equinox as eqx 22 | import jax 23 | import jax.numpy as jnp 24 | import jax.scipy as jsp 25 | import jax.tree_util as jtu 26 | import sympy 27 | 28 | 29 | PyTree = Any 30 | 31 | concatenate: Callable = sympy.Function("concatenate") # pyright: ignore 32 | stack: Callable = sympy.Function("stack") # pyright: ignore 33 | 34 | 35 | def _reduce(fn): # returns a wrapper that folds `fn` over all of its positional args 36 | def fn_(*args): 37 | return ft.reduce(fn, args) 38 | 39 | return fn_ 40 | 41 | 42 | def _single_args(fn): # returns a wrapper that passes its positional args to `fn` as one tuple 43 | def fn_(*args): 44 | return fn(args) 45 | 46 | return fn_ 47 | 48 | 49 | _lookup = { # SymPy function/class -> JAX callable; `_Func` indexes it with `expr.func` 50 | concatenate: _single_args(jnp.concatenate), 51 | stack: _single_args(jnp.stack), 52 | sympy.Mul: _reduce(jnp.multiply), 53 | sympy.Add: _reduce(jnp.add), 54 | sympy.div: jnp.divide, 55 | sympy.Abs: jnp.abs, 56 | sympy.sign: jnp.sign, 57 | sympy.ceiling: jnp.ceil, 58 | sympy.floor: jnp.floor, 59 | sympy.log: jnp.log, 60 | sympy.exp: jnp.exp, 61 | sympy.sqrt: jnp.sqrt, 62 | sympy.cos: jnp.cos, 63 | sympy.acos: jnp.arccos, 64 | sympy.sin: jnp.sin, 65 | sympy.asin: jnp.arcsin, 66 | sympy.tan: jnp.tan, 67 | sympy.atan: jnp.arctan, 68 | sympy.atan2: jnp.arctan2, 69 | sympy.cosh: jnp.cosh, 70 | sympy.acosh: jnp.arccosh, 71 | sympy.sinh: jnp.sinh, 72 | sympy.asinh: jnp.arcsinh, 73 | sympy.tanh: jnp.tanh, 74 | sympy.atanh: jnp.arctanh, 75 | sympy.Pow: jnp.power, 76 | sympy.re: jnp.real, 77 | sympy.im: jnp.imag, 78 | sympy.arg: jnp.angle, 79 | sympy.erf: jsp.special.erf, 80 | sympy.Eq: jnp.equal, 81 | sympy.Ne: jnp.not_equal, 82 | sympy.StrictGreaterThan: jnp.greater, 83 | sympy.StrictLessThan: jnp.less, 84 | sympy.LessThan: jnp.less_equal, 85 | sympy.GreaterThan: jnp.greater_equal, 86 | sympy.And: jnp.logical_and, 87 | sympy.Or: jnp.logical_or, 88 | sympy.Not: jnp.logical_not, 89 | sympy.Xor: jnp.logical_xor, 90 | sympy.Max:
_reduce(jnp.maximum), 91 | sympy.Min: _reduce(jnp.minimum), 92 | sympy.MatAdd: _reduce(jnp.add), 93 | sympy.Trace: jnp.trace, 94 | sympy.Determinant: jnp.linalg.det, 95 | } 96 | 97 | _constant_lookup = { # SymPy constant -> numeric value stored by `_Constant` 98 | sympy.E: jnp.e, 99 | sympy.pi: jnp.pi, 100 | sympy.EulerGamma: jnp.euler_gamma, 101 | sympy.I: 1j, 102 | } 103 | 104 | _reverse_lookup = {v: k for k, v in _lookup.items()} # JAX callable -> SymPy function; passed to the nodes' `.sympy()` 105 | assert len(_reverse_lookup) == len(_lookup) # fails if two SymPy entries shared one callable, which the inversion would merge 106 | 107 | 108 | def _item(x): # unwrap an array via `.item()`; anything else is returned unchanged 109 | if eqx.is_array(x): 110 | return x.item() 111 | else: 112 | return x 113 | 114 | 115 | class _AbstractNode(eqx.Module): # interface of every expression-graph node: evaluate with `__call__`, convert back with `sympy` 116 | @abc.abstractmethod 117 | def __call__(self, memodict: dict) -> jax.typing.ArrayLike: 118 | ... 119 | 120 | @abc.abstractmethod 121 | def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: 122 | ... 123 | 124 | # Comparisons based on identity 125 | __hash__ = object.__hash__ 126 | __eq__ = object.__eq__ # pyright: ignore 127 | 128 | 129 | class _Symbol(_AbstractNode): # leaf whose value is looked up in `memodict` under the symbol's name 130 | _name: str 131 | 132 | def __init__(self, expr: sympy.Expr): 133 | self._name = str(expr.name) # pyright: ignore 134 | 135 | def __call__(self, memodict: dict): 136 | try: 137 | return memodict[self._name] 138 | except KeyError as e: 139 | raise KeyError(f"Missing input for symbol {self._name}") from e 140 | 141 | def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: 142 | # memodict not needed as sympy deduplicates internally 143 | return sympy.Symbol(self._name) 144 | 145 | 146 | def _maybe_array(val, make_array): # `jnp.asarray(val)` when `make_array` is truthy, otherwise `val` as given 147 | if make_array: 148 | return jnp.asarray(val) 149 | else: 150 | return val 151 | 152 | 153 | class _Integer(_AbstractNode): # leaf holding a `sympy.Integer` converted with `int(...)` 154 | _value: jax.typing.ArrayLike 155 | 156 | def __init__(self, expr: sympy.Expr, make_array: bool): 157 | assert isinstance(expr, sympy.Integer) 158 | self._value = _maybe_array(int(expr), make_array) 159 | 160 | def __call__(self, memodict: dict): 161 | return self._value 162 | 163 | def sympy(self, memodict: dict, func_lookup: dict) ->
sympy.Expr: 164 | # memodict not needed as sympy deduplicates internally 165 | return sympy.Integer(_item(self._value)) 166 | 167 | 168 | class _Float(_AbstractNode): # leaf holding a `sympy.Float` converted with `float(...)` 169 | _value: jax.typing.ArrayLike 170 | 171 | def __init__(self, expr: sympy.Expr, make_array: bool): 172 | assert isinstance(expr, sympy.Float) 173 | self._value = _maybe_array(float(expr), make_array) 174 | 175 | def __call__(self, memodict: dict): 176 | return self._value 177 | 178 | def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: 179 | # memodict not needed as sympy deduplicates internally 180 | return sympy.Float(_item(self._value)) 181 | 182 | 183 | class _Rational(_AbstractNode): # leaf storing numerator and denominator separately; evaluation divides them 184 | _numerator: jax.typing.ArrayLike 185 | _denominator: jax.typing.ArrayLike 186 | 187 | def __init__(self, expr: sympy.Expr, make_array: bool): 188 | assert isinstance(expr, sympy.Rational) 189 | numerator = expr.numerator 190 | denominator = expr.denominator 191 | if callable(numerator): 192 | # Support SymPy < 1.10 193 | numerator = numerator() 194 | if callable(denominator): 195 | denominator = denominator() 196 | self._numerator = _maybe_array(int(numerator), make_array) 197 | self._denominator = _maybe_array(int(denominator), make_array) 198 | 199 | def __call__(self, memodict: dict): 200 | return self._numerator / self._denominator 201 | 202 | def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: 203 | # memodict not needed as sympy deduplicates internally 204 | return sympy.Integer(_item(self._numerator)) / sympy.Integer( 205 | _item(self._denominator) 206 | ) 207 | 208 | 209 | class _Constant(_AbstractNode): # leaf for the constants in `_constant_lookup`; keeps the original SymPy object for `.sympy()` 210 | _value: jax.typing.ArrayLike # a plain Python number when make_array=False, as for the other leaves 211 | _expr: sympy.Expr 212 | 213 | def __init__(self, expr: sympy.Expr, make_array: bool): 214 | assert expr in _constant_lookup 215 | self._value = _maybe_array(_constant_lookup[expr], make_array) 216 | self._expr = expr 217 | 218 | def __call__(self, memodict: dict): 219 | return self._value 220 | 221 | def sympy(self, memodict: dict,
func_lookup: dict) -> sympy.Expr: 222 | return self._expr 223 | 224 | 225 | class _Func(_AbstractNode): # interior node: a callable from `func_lookup` applied to child nodes 226 | _func: Callable 227 | _args: list 228 | 229 | def __init__( 230 | self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool 231 | ): 232 | try: 233 | self._func = func_lookup[expr.func] 234 | except KeyError as e: 235 | raise KeyError(f"Unsupported Sympy type {type(expr)}") from e 236 | self._args = [ 237 | _sympy_to_node(cast(sympy.Expr, arg), memodict, func_lookup, make_array) 238 | for arg in expr.args 239 | ] 240 | 241 | def __call__(self, memodict: dict): 242 | args = [] 243 | for arg in self._args: 244 | try: 245 | arg_call = memodict[arg] # reuse the result if this child was already evaluated during this call 246 | except KeyError: 247 | arg_call = arg(memodict) 248 | memodict[arg] = arg_call 249 | args.append(arg_call) 250 | return self._func(*args) 251 | 252 | def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: 253 | try: 254 | return memodict[self] 255 | except KeyError: 256 | func = func_lookup[self._func] 257 | args = [arg.sympy(memodict, func_lookup) for arg in self._args] 258 | out = func(*args) 259 | memodict[self] = out 260 | return out 261 | 262 | 263 | def _sympy_to_node( 264 | expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool 265 | ) -> _AbstractNode: # convert one SymPy expression to a node, memoised in `memodict` by expression 266 | try: 267 | return memodict[expr] 268 | except KeyError: 269 | if isinstance(expr, sympy.Symbol): 270 | out = _Symbol(expr) 271 | elif isinstance(expr, sympy.Integer): 272 | out = _Integer(expr, make_array) 273 | elif isinstance(expr, sympy.Float): 274 | out = _Float(expr, make_array) 275 | elif isinstance(expr, sympy.Rational): # Integer is matched first above 276 | out = _Rational(expr, make_array) 277 | elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I): # each of these must be a key of `_constant_lookup` (asserted in `_Constant`) 278 | out = _Constant(expr, make_array) 279 | else: 280 | out = _Func(expr, memodict, func_lookup, make_array) 281 | memodict[expr] = out 282 | return out 283 | 284 | 285 | def _is_node(x): # `is_leaf` predicate for the tree_map calls in `SymbolicModule` 286 | return isinstance(x, _AbstractNode) 287 | 288 | 289 | class
SymbolicModule(eqx.Module): 290 | nodes: PyTree 291 | has_extra_funcs: bool = eqx.static_field() 292 | 293 | def __init__( 294 | self, 295 | expressions: PyTree, 296 | extra_funcs: Optional[dict] = None, 297 | make_array: bool = True, 298 | ): 299 | if extra_funcs is None: 300 | lookup = _lookup 301 | self.has_extra_funcs = False 302 | else: 303 | lookup = co.ChainMap(extra_funcs, _lookup) # ChainMap searches `extra_funcs` first, then the built-in table 304 | self.has_extra_funcs = True 305 | _convert = ft.partial( 306 | _sympy_to_node, 307 | memodict=dict(), # one memo shared by every expression in the PyTree 308 | func_lookup=lookup, 309 | make_array=make_array, 310 | ) 311 | self.nodes = jtu.tree_map(_convert, expressions) 312 | 313 | def sympy(self) -> PyTree: # tree_map over `self.nodes`, so a PyTree of SymPy expressions is returned 314 | if self.has_extra_funcs: 315 | raise NotImplementedError( 316 | "SymbolicModule cannot be converted back to SymPy if `extra_funcs` " 317 | "is passed." 318 | ) 319 | memodict = dict() 320 | return jtu.tree_map( 321 | lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node 322 | ) 323 | 324 | def __call__(self, **symbols): 325 | memodict = symbols # starts as the keyword inputs by name; `_Func` adds evaluated child nodes to it 326 | return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node) 327 | -------------------------------------------------------------------------------- /tests/test_symbolic_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import equinox as eqx 16 | import jax 17 | import jax.numpy as jnp 18 | import jax.random as jr 19 | import jax.tree_util as jtu 20 | import sympy 21 | 22 | import sympy2jax 23 | 24 | 25 | def assert_equal(x, y): 26 | x_leaves, x_tree = jtu.tree_flatten(x) 27 | y_leaves, y_tree = jtu.tree_flatten(y) 28 | assert x_tree == y_tree 29 | for xi, yi in zip(x_leaves, y_leaves): 30 | assert type(xi) is type(yi) 31 | if isinstance(xi, jnp.ndarray): 32 | assert xi.shape == yi.shape 33 | assert xi.dtype == yi.dtype 34 | assert jnp.all(xi == yi) 35 | else: 36 | assert xi == yi 37 | 38 | 39 | def assert_sympy_allclose(x, y): 40 | assert isinstance(x, sympy.Expr) 41 | assert isinstance(y, sympy.Expr) 42 | assert x.func is y.func 43 | if isinstance(x, sympy.Float): 44 | assert abs(float(x) - float(y)) < 1e-5 45 | elif isinstance(x, sympy.Integer): 46 | assert x == y 47 | elif isinstance(x, sympy.Rational): 48 | assert x.numerator == y.numerator # pyright: ignore 49 | assert x.denominator == y.denominator # pyright: ignore 50 | elif isinstance(x, sympy.Symbol): 51 | assert x.name == y.name # pyright: ignore 52 | else: 53 | assert len(x.args) == len(y.args) 54 | for xarg, yarg in zip(x.args, y.args): 55 | assert_sympy_allclose(xarg, yarg) 56 | 57 | 58 | def test_example(): 59 | x_sym = sympy.symbols("x_sym") 60 | cosx = 1.0 * sympy.cos(x_sym) # pyright: ignore[reportOperatorIssue] 61 | sinx = 2.0 * sympy.sin(x_sym) # pyright: ignore[reportOperatorIssue] 62 | mod = sympy2jax.SymbolicModule([cosx, sinx]) 63 | 64 | x = jax.numpy.zeros(3) 65 | out = mod(x_sym=x) 66 | params = jtu.tree_leaves(mod) 67 | 68 | assert_equal(out, [jnp.cos(x), 2 * jnp.sin(x)]) 69 | assert_equal( 70 | [x for x in params if eqx.is_array(x)], [jnp.array(1.0), jnp.array(2.0)] 71 | ) 72 | 73 | 74 | def test_grad(): 75 | x_sym = sympy.symbols("x_sym") 76 | y = 2.1 * x_sym**2 77 | mod = sympy2jax.SymbolicModule(y) 78 | x = jnp.array(1.1) 79 | 80 | grad_m = eqx.filter_grad(lambda m, z: m(x_sym=z))(mod, x) 
81 | grad_z = eqx.filter_grad(lambda z, m: m(x_sym=z))(x, mod) 82 | 83 | true_grad_m = eqx.filter( 84 | sympy2jax.SymbolicModule(1.21 * x_sym**2), eqx.is_inexact_array 85 | ) 86 | true_grad_z = jnp.array(4.2 * x) 87 | 88 | assert_equal(grad_m, true_grad_m) 89 | assert_equal(grad_z, true_grad_z) 90 | 91 | mod2 = eqx.apply_updates(mod, grad_m) 92 | expr = mod2.sympy() 93 | 94 | assert_sympy_allclose(expr, 3.31 * x_sym**2) 95 | 96 | 97 | def test_reduce(): 98 | x, y, z = sympy.symbols("x y z") 99 | z = 2 * x * y * z 100 | mod = sympy2jax.SymbolicModule(expressions=z) 101 | mod(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6)) 102 | 103 | z = 2 + x + y + z 104 | mod = sympy2jax.SymbolicModule(expressions=z) 105 | mod(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6)) 106 | 107 | 108 | def test_special_subclasses(): 109 | x, y = sympy.symbols("x y") 110 | z = x - 1 # sympy.core.numbers.NegativeOne 111 | w = y * 0 # sympy.core.numbers.Zero 112 | v = x + 1 / 2 # sympy.core.numbers.OneHalf 113 | 114 | mod = sympy2jax.SymbolicModule([z, w, v]) 115 | assert_equal(mod(x=1, y=1), [jnp.array(0), jnp.array(0), jnp.array(1.5)]) 116 | assert mod.sympy() == [z, sympy.Integer(0), v] 117 | 118 | 119 | def test_rational(): 120 | x = sympy.symbols("x") 121 | y = x + sympy.Integer(3) / sympy.Integer(7) 122 | mod = sympy2jax.SymbolicModule(y) 123 | assert mod(x=1.0) == 1 + 3 / 7 124 | assert mod.sympy() == y 125 | 126 | 127 | def test_constants(): 128 | x = sympy.symbols("x") 129 | y = x + sympy.pi + sympy.E + sympy.EulerGamma + sympy.I 130 | mod = sympy2jax.SymbolicModule(y) 131 | assert jnp.isclose(mod(x=1.0), 1 + jnp.pi + jnp.e + jnp.euler_gamma + 1j) 132 | assert mod.sympy() == y 133 | 134 | 135 | def test_extra_funcs(): 136 | class _MLP(eqx.Module): 137 | mlp: eqx.nn.MLP 138 | 139 | def __init__(self): 140 | self.mlp = eqx.nn.MLP(1, 1, 2, 2, key=jr.PRNGKey(0)) 141 | 142 | def __call__(self, x): 143 | x = jnp.asarray(x) 144 | return self.mlp(x[None])[0] 145 | 146 | expr = 
sympy.parsing.sympy_parser.parse_expr("f(x) + y") 147 | mlp = _MLP() 148 | mod = sympy2jax.SymbolicModule(expr, {sympy.Function("f"): mlp}) 149 | mod(x=1.0, y=2.0) 150 | 151 | def _get_params(module): 152 | return {id(x) for x in jtu.tree_leaves(module) if eqx.is_array(x)} 153 | 154 | assert _get_params(mod).issuperset(_get_params(mlp)) 155 | 156 | 157 | def test_concatenate(): 158 | x, y, z = sympy.symbols("x y z") 159 | cat = sympy2jax.concatenate(x, y, z) 160 | mod = sympy2jax.SymbolicModule(expressions=cat) 161 | assert_equal( 162 | mod(x=jnp.array([0.4, 0.5]), y=jnp.array([0.6, 0.7]), z=jnp.array([0.8, 0.9])), 163 | jnp.array([0.4, 0.5, 0.6, 0.7, 0.8, 0.9]), 164 | ) 165 | 166 | 167 | def test_stack(): 168 | x, y, z = sympy.symbols("x y z") 169 | stack = sympy2jax.stack(x, y, z) 170 | mod = sympy2jax.SymbolicModule(expressions=stack) 171 | assert_equal( 172 | mod(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6)), 173 | jnp.array([0.4, 0.5, 0.6]), 174 | ) 175 | 176 | 177 | def test_non_array_to_sympy(): 178 | mod = sympy2jax.SymbolicModule(expressions=[sympy.Integer(1)], make_array=False) 179 | assert mod.sympy() == [sympy.Integer(1)] 180 | --------------------------------------------------------------------------------