├── .github
└── workflows
│ ├── release.yml
│ └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── pyproject.toml
├── sympy2jax
├── __init__.py
└── sympy_module.py
└── tests
└── test_symbolic_module.py
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Release
13 | uses: patrick-kidger/action_update_python_project@v6
14 | with:
15 | python-version: "3.11"
16 | test-script: |
17 | python -m pip install pytest jax jaxlib sympy equinox
18 | cp -r ${{ github.workspace }}/tests ./tests
19 | pytest
20 | pypi-token: ${{ secrets.pypi_token }}
21 | github-user: patrick-kidger
22 | github-token: ${{ github.token }}
23 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yml:
--------------------------------------------------------------------------------
name: Run tests

on:
  pull_request:

jobs:
  run-tests:
    strategy:
      matrix:
        # Quoted so YAML does not parse 3.10 as the float 3.1.
        python-version: ["3.10", "3.12"]
        os: [ubuntu-latest]
      # Let every matrix entry finish so all failures are reported.
      fail-fast: false
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install pytest wheel jax jaxlib sympy equinox

      # v2.x of this action is a node12 action built on the retired legacy
      # cache service; v3.x is the maintained composite action.
      - name: Checks with pre-commit
        uses: pre-commit/action@v3.0.1

      - name: Test with pytest
        run: |
          python -m pip install .
          python -m pytest --durations=0
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | *.egg-info
3 | build/
4 | dist/
5 |
6 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | rev: v0.2.2
4 | hooks:
5 | - id: ruff-format # formatter
6 | types_or: [ python, pyi, jupyter ]
7 | - id: ruff # linter
8 | types_or: [ python, pyi, jupyter ]
9 | args: [ --fix ]
10 | - repo: https://github.com/RobertCraigie/pyright-python
11 | rev: v1.1.350
12 | hooks:
13 | - id: pyright
14 | additional_dependencies: [equinox, jax, sympy]
15 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions (pull requests) are very welcome! Here's how to get started.
4 |
5 | ---
6 |
7 | First fork the library on GitHub.
8 |
9 | Then clone and install the library in development mode:
10 |
11 | ```bash
12 | git clone https://github.com/your-username-here/sympy2jax.git
13 | cd sympy2jax
14 | pip install -e .
15 | ```
16 |
17 | Then install the pre-commit hook:
18 |
19 | ```bash
20 | pip install pre-commit
21 | pre-commit install
22 | ```
23 |
24 | These hooks use ruff to format and lint the code, and pyright to type-check it.
25 |
26 | Now make your changes. Make sure to include additional tests if necessary.
27 |
28 | Next verify the tests all pass:
29 |
30 | ```bash
31 | pip install pytest
32 | pytest
33 | ```
34 |
35 | Then push your changes back to your fork of the repository:
36 |
37 | ```bash
38 | git push
39 | ```
40 |
41 | Finally, open a pull request on GitHub!
42 |
43 | ## Contributor License Agreement
44 |
45 | Contributions to this project must be accompanied by a Contributor License
46 | Agreement (CLA). You (or your employer) retain the copyright to your
47 | contribution; this simply gives us permission to use and redistribute your
48 | contributions as part of the project. Head over to
49 | <https://cla.developers.google.com/> to see your current agreements on file or
50 | to sign a new one.
51 |
52 | You generally only need to submit a CLA once, so if you've already submitted one
53 | (even if it was for a different project), you probably don't need to do it
54 | again.
55 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
sympy2jax
2 |
3 | Turn SymPy expressions into trainable JAX expressions. The output will be an [Equinox](https://github.com/patrick-kidger/equinox) module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.
4 |
5 | Optimise your symbolic expressions via gradient descent!
6 |
7 | ## Installation
8 |
9 | ```bash
10 | pip install sympy2jax
11 | ```
12 |
13 | Requires:
14 | Python 3.9+
15 | JAX 0.3.4+
16 | Equinox 0.5.3+
17 | SymPy 1.7.1+.
18 |
19 | ## Example
20 |
21 | ```python
22 | import jax
23 | import sympy
24 | import sympy2jax
25 |
26 | x_sym = sympy.symbols("x_sym")
27 | cosx = 1.0 * sympy.cos(x_sym)
28 | sinx = 2.0 * sympy.sin(x_sym)
29 | mod = sympy2jax.SymbolicModule([cosx, sinx]) # PyTree of input expressions
30 |
31 | x = jax.numpy.zeros(3)
32 | out = mod(x_sym=x) # PyTree of results.
33 | params = jax.tree_util.tree_leaves(mod) # 1.0 and 2.0 are parameters.
34 | # (Which may be trained in the usual way for Equinox.)
35 | ```
36 |
37 | ## Documentation
38 |
39 | ```python
40 | sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
41 | ```
42 |
43 | Where:
44 | - `expressions` is a PyTree of SymPy expressions.
45 | - `extra_funcs` is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
46 | - `make_array` controls how integers/floats/rationals are stored: as JAX arrays if `True` (the default), or as plain Python numbers if `False`.
47 |
48 | Instances can be called with key-value pairs of symbol-value, as in the above example.
49 |
50 | Instances have a `.sympy()` method that translates the module back into a PyTree of SymPy expressions. (This raises `NotImplementedError` if the module was built with `extra_funcs`.)
51 |
52 | (That's literally the entire documentation, it's super easy.)
53 |
54 | ## See also: other libraries in the JAX ecosystem
55 |
56 | **Always useful**
57 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
58 | [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
59 |
60 | **Deep learning**
61 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
62 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
63 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
64 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
65 |
66 | **Scientific computing**
67 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
68 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
69 | [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
70 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
71 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
72 |
73 | **Awesome JAX**
74 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
75 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "sympy2jax"
3 | version = "0.0.7"
4 | description = "Turn SymPy expressions into trainable JAX expressions."
5 | readme = "README.md"
6 | requires-python ="~=3.9"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Patrick Kidger", email = "contact@kidger.site"},
10 | ]
11 | keywords = ["jax", "sympy", "equinox"]
12 | classifiers = [
13 | "Development Status :: 3 - Alpha",
14 | "Intended Audience :: Science/Research",
15 | "License :: OSI Approved :: Apache Software License",
16 | "Natural Language :: English",
17 | "Programming Language :: Python :: 3",
18 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
19 | "Topic :: Scientific/Engineering :: Mathematics",
20 | ]
21 | urls = {repository = "https://github.com/google/sympy2jax" }
22 | dependencies = ["equinox>=0.5.3", "jax>=0.3.4", "sympy>=1.7.1"]
23 |
24 | [build-system]
25 | requires = ["hatchling"]
26 | build-backend = "hatchling.build"
27 |
28 | [tool.hatch.build]
29 | include = ["sympy2jax/*"]
30 |
31 | [tool.pytest.ini_options]
32 | addopts = "--jaxtyping-packages=symyp2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"  # NOTE(review): "symyp2jax" looks like a typo for "sympy2jax", so the jaxtyping import hook presumably never instruments this package. Correcting it would also need `beartype` installed for tests (the CI workflows do not install it explicitly) and may surface annotation mismatches — confirm before changing.
33 |
34 | [tool.ruff]
35 | select = ["E", "F", "I001"]
36 | ignore = ["E402", "E721", "E731", "E741", "F722"]
37 | ignore-init-module-imports = true
38 | fixable = ["I001", "F401"]
39 |
40 | [tool.ruff.isort]
41 | combine-as-imports = true
42 | lines-after-imports = 2
43 | extra-standard-library = ["typing_extensions"]
44 | order-by-type = false
45 |
46 | [tool.pyright]
47 | reportIncompatibleMethodOverride = true
48 | include = ["sympy2jax", "tests"]
49 |
--------------------------------------------------------------------------------
/sympy2jax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .sympy_module import (
16 | concatenate as concatenate,
17 | stack as stack,
18 | SymbolicModule as SymbolicModule,
19 | )
20 |
21 |
# Package version; keep in sync with `version` in pyproject.toml.
__version__ = "0.0.7"
23 |
--------------------------------------------------------------------------------
/sympy2jax/sympy_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 | import collections as co
17 | import functools as ft
18 | from collections.abc import Callable, Mapping
19 | from typing import Any, cast, Optional
20 |
21 | import equinox as eqx
22 | import jax
23 | import jax.numpy as jnp
24 | import jax.scipy as jsp
25 | import jax.tree_util as jtu
26 | import sympy
27 |
28 |
# Documentation-only alias for "any JAX PyTree".
PyTree = Any

# Undefined SymPy functions that users can place inside expressions to build
# array outputs. `_lookup` translates them to `jnp.concatenate` / `jnp.stack`.
concatenate: Callable = sympy.Function("concatenate")  # pyright: ignore
stack: Callable = sympy.Function("stack")  # pyright: ignore
33 |
34 |
35 | def _reduce(fn):
36 | def fn_(*args):
37 | return ft.reduce(fn, args)
38 |
39 | return fn_
40 |
41 |
42 | def _single_args(fn):
43 | def fn_(*args):
44 | return fn(args)
45 |
46 | return fn_
47 |
48 |
# Translation table from SymPy heads (`expr.func`) to JAX callables, consulted
# by `_Func` (after any user-supplied `extra_funcs`).
#
# Every value must be a distinct object because `_reverse_lookup` inverts this
# mapping. That is why `_reduce(jnp.add)` is called separately for `Add` and
# `MatAdd`: each call creates a fresh closure.
_lookup = {
    # Array-building markers defined above.
    concatenate: _single_args(jnp.concatenate),
    stack: _single_args(jnp.stack),
    # Arithmetic. SymPy's Add/Mul are n-ary, so the binary JAX op is folded.
    sympy.Mul: _reduce(jnp.multiply),
    sympy.Add: _reduce(jnp.add),
    # NOTE(review): `sympy.div` is SymPy's polynomial-division function rather
    # than an expression head (SymPy represents a/b as Mul(a, Pow(b, -1))), so
    # this entry presumably never matches an expression — confirm.
    sympy.div: jnp.divide,
    # Elementwise unary functions.
    sympy.Abs: jnp.abs,
    sympy.sign: jnp.sign,
    sympy.ceiling: jnp.ceil,
    sympy.floor: jnp.floor,
    sympy.log: jnp.log,
    sympy.exp: jnp.exp,
    sympy.sqrt: jnp.sqrt,
    # Trigonometric and hyperbolic functions.
    sympy.cos: jnp.cos,
    sympy.acos: jnp.arccos,
    sympy.sin: jnp.sin,
    sympy.asin: jnp.arcsin,
    sympy.tan: jnp.tan,
    sympy.atan: jnp.arctan,
    sympy.atan2: jnp.arctan2,
    sympy.cosh: jnp.cosh,
    sympy.acosh: jnp.arccosh,
    sympy.sinh: jnp.sinh,
    sympy.asinh: jnp.arcsinh,
    sympy.tanh: jnp.tanh,
    sympy.atanh: jnp.arctanh,
    # Powers, complex parts and special functions.
    sympy.Pow: jnp.power,
    sympy.re: jnp.real,
    sympy.im: jnp.imag,
    sympy.arg: jnp.angle,
    sympy.erf: jsp.special.erf,
    # Relations.
    sympy.Eq: jnp.equal,
    sympy.Ne: jnp.not_equal,
    sympy.StrictGreaterThan: jnp.greater,
    sympy.StrictLessThan: jnp.less,
    sympy.LessThan: jnp.less_equal,
    sympy.GreaterThan: jnp.greater_equal,
    # Boolean logic.
    sympy.And: jnp.logical_and,
    sympy.Or: jnp.logical_or,
    sympy.Not: jnp.logical_not,
    sympy.Xor: jnp.logical_xor,
    # n-ary extrema.
    sympy.Max: _reduce(jnp.maximum),
    sympy.Min: _reduce(jnp.minimum),
    # Matrix expressions.
    sympy.MatAdd: _reduce(jnp.add),
    sympy.Trace: jnp.trace,
    sympy.Determinant: jnp.linalg.det,
}

# Numeric values for the named SymPy constants handled by `_Constant`.
_constant_lookup = {
    sympy.E: jnp.e,
    sympy.pi: jnp.pi,
    sympy.EulerGamma: jnp.euler_gamma,
    sympy.I: 1j,
}

# Inverse of `_lookup`: maps the stored JAX callables back to SymPy heads for
# `SymbolicModule.sympy()`.
_reverse_lookup = dict(zip(_lookup.values(), _lookup.keys()))
# Two heads sharing one callable would make the inversion lossy.
assert len(_reverse_lookup) == len(_lookup)
106 |
107 |
def _item(x):
    """Unwrap an array (per `eqx.is_array`) into a Python scalar via `.item()`.

    Non-array values are returned as-is. Used when rebuilding SymPy numbers,
    since those constructors are given the unwrapped value.
    """
    return x.item() if eqx.is_array(x) else x
113 |
114 |
class _AbstractNode(eqx.Module):
    """Base class for a node in the translated expression graph.

    Subclasses implement evaluation with JAX (`__call__`) and translation
    back to SymPy (`sympy`). Both receive a `memodict` through which results
    for shared subexpressions can be cached.
    """

    @abc.abstractmethod
    def __call__(self, memodict: dict) -> jax.typing.ArrayLike:
        """Evaluate the node; `memodict` holds symbol values by name (and cached node results)."""

    @abc.abstractmethod
    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        """Rebuild a SymPy expression; `func_lookup` maps JAX callables to SymPy heads."""

    # Hash and compare by object identity instead of Equinox's structural
    # semantics, so a node can act as a `memodict` key and each distinct node
    # object gets its own cache entry.
    __hash__ = object.__hash__
    __eq__ = object.__eq__  # pyright: ignore
127 |
128 |
class _Symbol(_AbstractNode):
    """Leaf node for a SymPy `Symbol`; evaluates to the caller-supplied value.

    The symbol's name is the keyword argument expected by
    `SymbolicModule.__call__`.
    """

    _name: str

    def __init__(self, expr: sympy.Expr):
        self._name = str(expr.name)  # pyright: ignore

    def __call__(self, memodict: dict):
        symbol_name = self._name
        try:
            value = memodict[symbol_name]
        except KeyError as err:
            raise KeyError(f"Missing input for symbol {symbol_name}") from err
        return value

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        # No memoisation: SymPy deduplicates equal symbols itself.
        rebuilt = sympy.Symbol(self._name)
        return rebuilt
144 |
145 |
146 | def _maybe_array(val, make_array):
147 | if make_array:
148 | return jnp.asarray(val)
149 | else:
150 | return val
151 |
152 |
class _Integer(_AbstractNode):
    """Leaf node for a SymPy `Integer`, held as a module leaf.

    The value is a JAX array or a Python `int`, depending on `make_array`.
    """

    _value: jax.typing.ArrayLike

    def __init__(self, expr: sympy.Expr, make_array: bool):
        assert isinstance(expr, sympy.Integer)
        as_int = int(expr)
        self._value = _maybe_array(as_int, make_array)

    def __call__(self, memodict: dict):
        # Independent of the inputs; `memodict` is unused.
        return self._value

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        # No memoisation: SymPy deduplicates numbers itself.
        python_value = _item(self._value)
        return sympy.Integer(python_value)
166 |
167 |
class _Float(_AbstractNode):
    """Leaf node for a SymPy `Float`, held as a module leaf.

    The value is a JAX array or a Python `float`, depending on `make_array`.
    """

    _value: jax.typing.ArrayLike

    def __init__(self, expr: sympy.Expr, make_array: bool):
        assert isinstance(expr, sympy.Float)
        as_float = float(expr)
        self._value = _maybe_array(as_float, make_array)

    def __call__(self, memodict: dict):
        # Independent of the inputs; `memodict` is unused.
        return self._value

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        # No memoisation: SymPy deduplicates numbers itself.
        python_value = _item(self._value)
        return sympy.Float(python_value)
181 |
182 |
class _Rational(_AbstractNode):
    """Leaf node for a SymPy `Rational`, stored as numerator and denominator.

    Both parts are JAX arrays or Python ints depending on `make_array`;
    evaluation returns their true quotient.
    """

    _numerator: jax.typing.ArrayLike
    _denominator: jax.typing.ArrayLike

    def __init__(self, expr: sympy.Expr, make_array: bool):
        assert isinstance(expr, sympy.Rational)
        parts = []
        for part in (expr.numerator, expr.denominator):
            if callable(part):
                # SymPy < 1.10 exposes these as methods instead of properties.
                part = part()
            parts.append(_maybe_array(int(part), make_array))
        self._numerator, self._denominator = parts

    def __call__(self, memodict: dict):
        top, bottom = self._numerator, self._denominator
        return top / bottom

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        # No memoisation: SymPy deduplicates numbers itself.
        top = sympy.Integer(_item(self._numerator))
        bottom = sympy.Integer(_item(self._denominator))
        return top / bottom
207 |
208 |
class _Constant(_AbstractNode):
    """Leaf node for a named SymPy constant (E, pi, EulerGamma, I).

    Evaluates to the numeric value from `_constant_lookup`. Translating back to
    SymPy returns the stored symbolic constant, so any change to `_value` is
    not reflected in `.sympy()`.
    """

    # A JAX array when `make_array` is true, otherwise the raw Python number
    # from `_constant_lookup` (a float, or the complex `1j` for `sympy.I`).
    # Annotated like the other leaf nodes rather than as `jnp.ndarray`.
    _value: jax.typing.ArrayLike
    _expr: sympy.Expr

    def __init__(self, expr: sympy.Expr, make_array: bool):
        assert expr in _constant_lookup
        self._value = _maybe_array(_constant_lookup[expr], make_array)
        self._expr = expr

    def __call__(self, memodict: dict):
        # Independent of the inputs; `memodict` is unused.
        return self._value

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        return self._expr
223 |
224 |
class _Func(_AbstractNode):
    """Interior node: a translated SymPy head applied to child nodes.

    `_func` is the JAX callable found in `func_lookup` for the expression's
    head; `_args` holds the translated children in SymPy argument order.
    """

    _func: Callable
    _args: list

    def __init__(
        self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
    ):
        head = expr.func
        try:
            self._func = func_lookup[head]
        except KeyError as e:
            raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
        children = []
        for sub_expr in expr.args:
            children.append(
                _sympy_to_node(
                    cast(sympy.Expr, sub_expr), memodict, func_lookup, make_array
                )
            )
        self._args = children

    def __call__(self, memodict: dict):
        # Child results are cached in `memodict` under the child node (nodes
        # hash by identity), so a shared child is evaluated only once per call.
        evaluated = []
        for child in self._args:
            if child in memodict:
                result = memodict[child]
            else:
                result = child(memodict)
                memodict[child] = result
            evaluated.append(result)
        return self._func(*evaluated)

    def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
        if self in memodict:
            return memodict[self]
        head = func_lookup[self._func]
        rebuilt_args = [child.sympy(memodict, func_lookup) for child in self._args]
        rebuilt = head(*rebuilt_args)
        memodict[self] = rebuilt
        return rebuilt
261 |
262 |
def _sympy_to_node(
    expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
) -> _AbstractNode:
    """Translate one SymPy expression into an `_AbstractNode`, memoised by `expr`.

    `memodict` maps already-translated SymPy expressions to their nodes, so
    equal subexpressions are translated once and share a node object.
    Anything that is not a symbol, number or known constant is handed to
    `_Func`, which raises `KeyError` when `func_lookup` lacks its head.
    """
    if expr in memodict:
        return memodict[expr]

    if isinstance(expr, sympy.Symbol):
        node = _Symbol(expr)
    # Integer is tested before Rational because SymPy's Integer subclasses
    # Rational.
    elif isinstance(expr, sympy.Integer):
        node = _Integer(expr, make_array)
    elif isinstance(expr, sympy.Float):
        node = _Float(expr, make_array)
    elif isinstance(expr, sympy.Rational):
        node = _Rational(expr, make_array)
    elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I):
        node = _Constant(expr, make_array)
    else:
        node = _Func(expr, memodict, func_lookup, make_array)

    memodict[expr] = node
    return node
283 |
284 |
def _is_node(x):
    """`is_leaf` predicate: stop PyTree traversal at translated expression nodes."""
    node_type = _AbstractNode
    return isinstance(x, node_type)
287 |
288 |
class SymbolicModule(eqx.Module):
    """Equinox module that evaluates a PyTree of SymPy expressions with JAX.

    Numeric literals (integers, floats, rationals and the constants
    E/pi/EulerGamma/I) become leaves of this module; SymPy symbols become
    keyword inputs to `__call__`.

    **Arguments:**

    - `expressions`: a PyTree of SymPy expressions.
    - `extra_funcs`: optional mapping from SymPy heads to JAX callables. It is
      consulted before the built-in translation table.
    - `make_array`: if `True` (the default), numeric literals are stored as JAX
      arrays; otherwise they are kept as Python numbers.
    """

    # Translated nodes, mirroring the PyTree structure of `expressions`.
    nodes: PyTree
    # Static flag recording whether `extra_funcs` was given; `.sympy()` can
    # only invert the built-in table.
    has_extra_funcs: bool = eqx.static_field()

    def __init__(
        self,
        expressions: PyTree,
        extra_funcs: Optional[dict] = None,
        make_array: bool = True,
    ):
        if extra_funcs is None:
            lookup = _lookup
            self.has_extra_funcs = False
        else:
            # User-supplied entries take precedence over the built-in table.
            lookup = co.ChainMap(extra_funcs, _lookup)
            self.has_extra_funcs = True
        # One memo dict shared across the whole PyTree, so equal
        # subexpressions are translated to the same node object.
        _convert = ft.partial(
            _sympy_to_node,
            memodict=dict(),
            func_lookup=lookup,
            make_array=make_array,
        )
        self.nodes = jtu.tree_map(_convert, expressions)

    def sympy(self) -> PyTree:
        """Translate back into a PyTree of SymPy expressions.

        **Returns:** a PyTree with the same structure as `self.nodes`, each
        node replaced by its SymPy expression.

        **Raises:** `NotImplementedError` if the module was built with
        `extra_funcs`, since no inverse mapping is kept for them.
        """
        if self.has_extra_funcs:
            raise NotImplementedError(
                "SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
                "is passed."
            )
        memodict = dict()
        return jtu.tree_map(
            lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
        )

    def __call__(self, **symbols):
        """Evaluate every expression, with symbol values passed by name.

        **Returns:** a PyTree with the same structure as `self.nodes`.

        **Raises:** `KeyError` if an expression uses a symbol not in `symbols`.
        """
        # `symbols` is a fresh dict for each call, so it doubles as the cache
        # that `_Func` nodes fill with intermediate results.
        memodict = symbols
        return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
327 |
--------------------------------------------------------------------------------
/tests/test_symbolic_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import equinox as eqx
16 | import jax
17 | import jax.numpy as jnp
18 | import jax.random as jr
19 | import jax.tree_util as jtu
20 | import sympy
21 |
22 | import sympy2jax
23 |
24 |
def assert_equal(x, y):
    """Assert that two PyTrees are exactly equal.

    The tree structures must match and corresponding leaves must have the same
    type. JAX arrays must also agree in shape, dtype and every element; all
    other leaves are compared with `==`.
    """
    x_leaves, x_treedef = jtu.tree_flatten(x)
    y_leaves, y_treedef = jtu.tree_flatten(y)
    assert x_treedef == y_treedef
    for left, right in zip(x_leaves, y_leaves):
        assert type(left) is type(right)
        if not isinstance(left, jnp.ndarray):
            assert left == right
            continue
        assert left.shape == right.shape
        assert left.dtype == right.dtype
        assert jnp.all(left == right)
37 |
38 |
def assert_sympy_allclose(x, y):
    """Assert that two SymPy expression trees match node for node.

    Float leaves are compared with an absolute tolerance of 1e-5; all other
    leaves must match exactly. Integer is checked before Rational because it
    is a subclass of Rational.
    """
    assert isinstance(x, sympy.Expr)
    assert isinstance(y, sympy.Expr)
    assert x.func is y.func

    if isinstance(x, sympy.Float):
        assert abs(float(x) - float(y)) < 1e-5
        return
    if isinstance(x, sympy.Integer):
        assert x == y
        return
    if isinstance(x, sympy.Rational):
        assert x.numerator == y.numerator  # pyright: ignore
        assert x.denominator == y.denominator  # pyright: ignore
        return
    if isinstance(x, sympy.Symbol):
        assert x.name == y.name  # pyright: ignore
        return

    # Compound expression: same arity, then recurse pairwise.
    assert len(x.args) == len(y.args)
    for sub_x, sub_y in zip(x.args, y.args):
        assert_sympy_allclose(sub_x, sub_y)
56 |
57 |
def test_example():
    # Float coefficients must survive as array leaves of the module.
    x_sym = sympy.symbols("x_sym")
    cosx = 1.0 * sympy.cos(x_sym)  # pyright: ignore[reportOperatorIssue]
    sinx = 2.0 * sympy.sin(x_sym)  # pyright: ignore[reportOperatorIssue]
    mod = sympy2jax.SymbolicModule([cosx, sinx])

    x = jnp.zeros(3)
    out = mod(x_sym=x)
    array_params = [leaf for leaf in jtu.tree_leaves(mod) if eqx.is_array(leaf)]

    assert_equal(out, [jnp.cos(x), 2 * jnp.sin(x)])
    assert_equal(array_params, [jnp.array(1.0), jnp.array(2.0)])
72 |
73 |
def test_grad():
    # Gradients flow both to the module's coefficients and to the inputs.
    x_sym = sympy.symbols("x_sym")
    mod = sympy2jax.SymbolicModule(2.1 * x_sym**2)
    x = jnp.array(1.1)

    grad_wrt_module = eqx.filter_grad(lambda m, z: m(x_sym=z))(mod, x)
    grad_wrt_input = eqx.filter_grad(lambda z, m: m(x_sym=z))(x, mod)

    # d/dc (c * x^2) = x^2 = 1.21; d/dx (2.1 * x^2) = 4.2 * x.
    expected_module_grad = eqx.filter(
        sympy2jax.SymbolicModule(1.21 * x_sym**2), eqx.is_inexact_array
    )
    expected_input_grad = jnp.array(4.2 * x)

    assert_equal(grad_wrt_module, expected_module_grad)
    assert_equal(grad_wrt_input, expected_input_grad)

    # One step of "gradient ascent" turns 2.1 into 2.1 + 1.21 = 3.31.
    updated = eqx.apply_updates(mod, grad_wrt_module)
    assert_sympy_allclose(updated.sympy(), 3.31 * x_sym**2)
95 |
96 |
def test_reduce():
    """n-ary Mul and Add nodes (more than two operands) evaluate correctly."""
    x, y, z = sympy.symbols("x y z")
    values = dict(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6))

    # A single Mul node with four factors.
    product = 2 * x * y * z
    mod = sympy2jax.SymbolicModule(expressions=product)
    assert jnp.isclose(mod(**values), 2 * 0.4 * 0.5 * 0.6)

    # A single Add node with four terms. Use a new name: rebinding `z` to the
    # product above would silently turn this into 2 + x + y + 2*x*y*z.
    total = 2 + x + y + z
    mod = sympy2jax.SymbolicModule(expressions=total)
    assert jnp.isclose(mod(**values), 2 + 0.4 + 0.5 + 0.6)
106 |
107 |
def test_special_subclasses():
    """SymPy's singleton number subclasses convert and round-trip correctly."""
    x, y = sympy.symbols("x y")
    z = x - 1  # sympy.core.numbers.NegativeOne
    w = y * 0  # sympy.core.numbers.Zero
    # Not `x + 1 / 2`: Python evaluates `1 / 2` to the float 0.5 before SymPy
    # sees it, which yields a plain sympy.Float instead of the Half singleton.
    v = x + sympy.Rational(1, 2)  # sympy.core.numbers.Half

    mod = sympy2jax.SymbolicModule([z, w, v])
    assert_equal(mod(x=1, y=1), [jnp.array(0), jnp.array(0), jnp.array(1.5)])
    assert mod.sympy() == [z, sympy.Integer(0), v]
117 |
118 |
def test_rational():
    # An exact rational constant evaluates numerically and round-trips exactly.
    x = sympy.symbols("x")
    three_sevenths = sympy.Integer(3) / sympy.Integer(7)
    y = x + three_sevenths
    mod = sympy2jax.SymbolicModule(y)
    assert mod(x=1.0) == 1 + 3 / 7
    assert mod.sympy() == y
125 |
126 |
def test_constants():
    # Named constants map to their numeric values and back to SymPy symbols.
    x = sympy.symbols("x")
    y = x + sympy.pi + sympy.E + sympy.EulerGamma + sympy.I
    mod = sympy2jax.SymbolicModule(y)
    expected = 1 + jnp.pi + jnp.e + jnp.euler_gamma + 1j
    assert jnp.isclose(mod(x=1.0), expected)
    assert mod.sympy() == y
133 |
134 |
def test_extra_funcs():
    """Undefined SymPy functions can be bound to arbitrary Equinox modules."""

    class _MLP(eqx.Module):
        mlp: eqx.nn.MLP

        def __init__(self):
            self.mlp = eqx.nn.MLP(1, 1, 2, 2, key=jr.PRNGKey(0))

        def __call__(self, x):
            x = jnp.asarray(x)
            return self.mlp(x[None])[0]

    expr = sympy.parsing.sympy_parser.parse_expr("f(x) + y")
    mlp = _MLP()
    mod = sympy2jax.SymbolicModule(expr, {sympy.Function("f"): mlp})

    # `f` must actually be dispatched to the MLP, so the result is the MLP's
    # output plus `y`.
    out = mod(x=1.0, y=2.0)
    assert jnp.isclose(out, mlp(1.0) + 2.0)

    def _get_params(module):
        return {id(x) for x in jtu.tree_leaves(module) if eqx.is_array(x)}

    # The MLP's parameter arrays (by identity) are among the module's leaves.
    assert _get_params(mod).issuperset(_get_params(mlp))
155 |
156 |
def test_concatenate():
    # `concatenate` joins 1-D inputs end to end.
    x, y, z = sympy.symbols("x y z")
    mod = sympy2jax.SymbolicModule(expressions=sympy2jax.concatenate(x, y, z))
    inputs = dict(
        x=jnp.array([0.4, 0.5]), y=jnp.array([0.6, 0.7]), z=jnp.array([0.8, 0.9])
    )
    expected = jnp.array([0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
    assert_equal(mod(**inputs), expected)
165 |
166 |
def test_stack():
    # `stack` gathers scalar inputs into a new leading axis.
    x, y, z = sympy.symbols("x y z")
    mod = sympy2jax.SymbolicModule(expressions=sympy2jax.stack(x, y, z))
    inputs = dict(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6))
    expected = jnp.array([0.4, 0.5, 0.6])
    assert_equal(mod(**inputs), expected)
175 |
176 |
def test_non_array_to_sympy():
    # With make_array=False, conversion back to SymPy still works.
    one = sympy.Integer(1)
    mod = sympy2jax.SymbolicModule(expressions=[one], make_array=False)
    assert mod.sympy() == [one]
180 |
--------------------------------------------------------------------------------