├── .github ├── dependabot.yml └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── noxfile.py ├── pyproject.toml ├── src └── jpu │ ├── __init__.py │ ├── core.py │ ├── jax_numpy_func.py │ ├── monkey.py │ ├── numpy │ ├── __init__.py │ ├── __init__.pyi │ ├── linalg.py │ └── linalg.pyi │ ├── quantity.py │ └── registry.py └── tests ├── test_core.py ├── test_numpy.py ├── test_readme.py └── test_registry.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "*" 9 | pull_request: 10 | workflow_dispatch: 11 | 12 | jobs: 13 | tests: 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | nox-session: 18 | - "tests" 19 | python-version: 20 | - "3.10" 21 | - "3.11" 22 | - "3.12" 23 | 24 | include: 25 | - nox-session: "doctest" 26 | python-version: "3.10" 27 | 28 | runs-on: ubuntu-latest 29 | steps: 30 | - name: "Init: checkout" 31 | uses: actions/checkout@v4 32 | with: 33 | fetch-depth: 0 34 | 35 | - name: "Init: Python" 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | 40 | - name: "Install: dependencies" 41 | run: | 42 | python -m pip install -U pip 43 | python -m pip install -U nox 44 | 45 | - name: "Tests: run" 46 | run: python -m nox --non-interactive -s "${{ matrix.nox-session }}" 47 | 48 | build: 49 | runs-on: ubuntu-latest 50 | steps: 51 | - uses: actions/checkout@v4 52 | with: 53 | fetch-depth: 0 54 | - uses: actions/setup-python@v5 55 | name: Install Python 56 | 
with: 57 | python-version: "3.10" 58 | - name: Install dependencies 59 | run: | 60 | python -m pip install -U pip 61 | python -m pip install -U build twine 62 | - name: Build the distribution 63 | run: python -m build . 64 | - name: Check the distribution 65 | run: python -m twine check --strict dist/* 66 | - uses: actions/upload-artifact@v4 67 | with: 68 | path: dist/* 69 | 70 | publish: 71 | environment: 72 | name: pypi 73 | url: https://pypi.org/p/jpu 74 | permissions: 75 | id-token: write 76 | needs: [tests, build] 77 | runs-on: ubuntu-latest 78 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') 79 | steps: 80 | - uses: actions/download-artifact@v4 81 | with: 82 | name: artifact 83 | path: dist 84 | - uses: pypa/gh-action-pypi-publish@v1.12.4 85 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *_version.py 2 | .coverage 3 | .tox 4 | .coverage* 5 | /*.ipynb 6 | *.DAT 7 | docs/api/summary 8 | __pycache__ 9 | *.egg-info 10 | .*_cache 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v4.4.0" 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | exclude_types: [json, binary] 8 | - repo: https://github.com/psf/black 9 | rev: "23.7.0" 10 | hooks: 11 | - id: black-jupyter 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: "v0.0.285" 14 | hooks: 15 | - id: ruff 16 | args: [--fix, --exit-non-zero-on-fix] 17 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, 
contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility, and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | - Other conduct that could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned with this Code of Conduct and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | foreman.mackey@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][faq]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [mozilla coc]: https://github.com/mozilla/diversity 131 | [faq]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributor Guide 2 | 3 | Thank you for your interest in improving this project. This project is 4 | open-source under the MIT License and welcomes contributions in the form of bug 5 | reports, feature requests, and pull requests. 6 | 7 | Here is a list of important resources for contributors: 8 | 9 | - [Source Code](https://github.com/dfm/jpu) 10 | - [Documentation](https://github.com/dfm/jpu) 11 | - [Issue Tracker](https://github.com/dfm/jpu/issues) 12 | 13 | ## How to report a bug 14 | 15 | Report bugs on the [Issue Tracker](https://github.com/dfm/jpu/issues). 16 | 17 | When filing an issue, make sure to answer these questions: 18 | 19 | - Which operating system and Python version are you using? 20 | - Which version of this project are you using? 21 | - What did you do? 22 | - What did you expect to see? 23 | - What did you see instead? 24 | 25 | The best way to get your bug fixed is to provide a test case, and/or steps to 26 | reproduce the issue. In particular, please include a [Minimal, Reproducible 27 | Example](https://stackoverflow.com/help/minimal-reproducible-example). 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2022, 2023 Simons Foundation, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude .* 2 | prune .github 3 | prune docs 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAX + Units 2 | 3 | **Built with [JAX](https://jax.readthedocs.io) and 4 | [Pint](https://pint.readthedocs.io)!** 5 | 6 | This module provides an interface between [JAX](https://jax.readthedocs.io) and 7 | [Pint](https://pint.readthedocs.io) to allow JAX to support operations with 8 | units. The propagation of units happens at trace time, so jitted functions 9 | should see no runtime cost. This library is experimental so expect some sharp 10 | edges. 
11 | 12 | For example: 13 | 14 | ```python 15 | >>> import jax 16 | >>> import jax.numpy as jnp 17 | >>> import jpu 18 | >>> 19 | >>> u = jpu.UnitRegistry() 20 | >>> 21 | >>> @jax.jit 22 | ... def add_two_lengths(a, b): 23 | ... return a + b 24 | ... 25 | >>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm) 26 | <Quantity([3.045 3.012 3.039], 'meter')> 27 | 28 | ``` 29 | 30 | ## Installation 31 | 32 | To install, use `pip`: 33 | 34 | ```bash 35 | python -m pip install jpu 36 | ``` 37 | 38 | The only dependencies are `jax` and `pint`, and these will also be installed, if 39 | not already in your environment. Take a look at [the JAX docs for more 40 | information about installing JAX on different 41 | systems](https://github.com/google/jax#installation). 42 | 43 | ## Usage 44 | 45 | Here is a slightly more complete example: 46 | 47 | ```python 48 | >>> import jax 49 | >>> import numpy as np 50 | >>> from jpu import UnitRegistry, numpy as jnpu 51 | >>> 52 | >>> u = UnitRegistry() 53 | >>> 54 | >>> @jax.jit 55 | ... def projectile_motion(v_init, theta, time, g=u.standard_gravity): 56 | ... """Compute the motion of a projectile with support for units""" 57 | ... x = v_init * time * jnpu.cos(theta) 58 | ... y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time) 59 | ... return x.to(u.m), y.to(u.m) 60 | ... 61 | >>> x, y = projectile_motion( 62 | ... 5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s 63 | ... ) 64 | 65 | ``` 66 | 67 | ## Technical details & limitations 68 | 69 | The most significant limitation of this library is the fact that users must use 70 | `jpu.numpy` functions when interacting with "quantities" with units instead of 71 | the `jax.numpy` interface. This is because JAX does not (yet?) provide a general 72 | interface for dispatching of ufuncs on custom array classes. I have played 73 | around with the undocumented `__jax_array__` interface, but it's not really 74 | flexible enough, and it isn't currently compatible with Pytree objects. 
75 | 76 | So far, only a subset of the `numpy`/`jax.numpy` interface is implemented. Pull 77 | requests adding broader support (including submodules) would be welcome! 78 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | 3 | 4 | @nox.session 5 | @nox.parametrize("x64", [True, False]) 6 | def tests(session, x64): 7 | session.install(".[test]") 8 | args = session.posargs 9 | if not args: 10 | args = ["-v"] 11 | if x64: 12 | env = {"JAX_ENABLE_X64": "1"} 13 | else: 14 | env = {} 15 | env["PYTHONWARNINGS"] = "error::DeprecationWarning" 16 | session.run("pytest", *args, env=env) 17 | 18 | 19 | @nox.session 20 | def doctest(session): 21 | session.install("jaxlib") 22 | session.install(".") 23 | session.run("python", "-m", "doctest", "-v", "README.md") 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jpu" 3 | description = "JAX + Units" 4 | authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | license = { text = "MIT License" } 8 | classifiers = [ 9 | "Operating System :: OS Independent", 10 | "Programming Language :: Python :: 3", 11 | "Development Status :: 5 - Production/Stable", 12 | "License :: OSI Approved :: MIT License", 13 | ] 14 | dynamic = ["version"] 15 | dependencies = ["jax", "pint"] 16 | 17 | [project.urls] 18 | "Homepage" = "https://github.com/dfm/jpu" 19 | "Source" = "https://github.com/dfm/jpu" 20 | "Bug Tracker" = "https://github.com/dfm/jpu/issues" 21 | 22 | [project.optional-dependencies] 23 | test = ["pytest", "jaxlib"] 24 | 25 | [build-system] 26 | requires = ["hatchling", "hatch-vcs"] 27 | build-backend = "hatchling.build" 28 | 29 | [tool.hatch.version] 30 | 
source = "vcs" 31 | 32 | [tool.hatch.build.hooks.vcs] 33 | version-file = "src/jpu/jpu_version.py" 34 | 35 | [tool.black] 36 | target-version = ["py39"] 37 | line-length = 88 38 | 39 | [tool.ruff] 40 | line-length = 88 41 | target-version = "py39" 42 | select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"] 43 | ignore = ["PLR0912", "PLR0913"] 44 | exclude = [] 45 | 46 | [tool.ruff.isort] 47 | known-first-party = ["jpu"] 48 | combine-as-imports = true 49 | -------------------------------------------------------------------------------- /src/jpu/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "monkey", 3 | "numpy", 4 | "grad", 5 | "value_and_grad", 6 | "__version__", 7 | "UnitRegistry", 8 | ] 9 | 10 | from jpu import monkey, numpy 11 | from jpu.core import grad, value_and_grad 12 | from jpu.jpu_version import __version__ 13 | from jpu.registry import UnitRegistry 14 | -------------------------------------------------------------------------------- /src/jpu/core.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax._src.util import wraps 3 | from jax.tree_util import tree_map 4 | 5 | 6 | def is_quantity(obj): 7 | return hasattr(obj, "_units") and hasattr(obj, "_magnitude") 8 | 9 | 10 | def grad( 11 | fun, 12 | argnums=0, 13 | has_aux=False, 14 | holomorphic=False, 15 | allow_int=False, 16 | reduce_axes=(), 17 | ): 18 | value_and_grad_f = value_and_grad( 19 | fun, 20 | argnums, 21 | has_aux=has_aux, 22 | holomorphic=holomorphic, 23 | allow_int=allow_int, 24 | reduce_axes=reduce_axes, 25 | ) 26 | 27 | docstr = ( 28 | "Gradient of {fun} with respect to positional argument(s) " 29 | "{argnums}. Takes the same arguments as {fun} but returns the " 30 | "gradient, which has the same shape as the arguments at " 31 | "positions {argnums}." 
32 | ) 33 | 34 | @wraps(fun, docstr=docstr, argnums=argnums) 35 | def grad_f(*args, **kwargs): 36 | _, g = value_and_grad_f(*args, **kwargs) 37 | return g 38 | 39 | @wraps(fun, docstr=docstr, argnums=argnums) 40 | def grad_f_aux(*args, **kwargs): 41 | (_, aux), g = value_and_grad_f(*args, **kwargs) 42 | return g, aux 43 | 44 | return grad_f_aux if has_aux else grad_f 45 | 46 | 47 | def value_and_grad( 48 | fun, 49 | argnums=0, 50 | has_aux=False, 51 | holomorphic=False, 52 | allow_int=False, 53 | reduce_axes=(), 54 | ): 55 | # inspired by: https://twitter.com/shoyer/status/1531703890512490499 56 | docstr = ( 57 | "Value and gradient of {fun} with respect to positional " 58 | "argument(s) {argnums}. Takes the same arguments as {fun} but " 59 | "returns a two-element tuple where the first element is the value " 60 | "of {fun} and the second element is the gradient, which has the " 61 | "same shape as the arguments at positions {argnums}." 62 | ) 63 | 64 | def fun_wo_units(*args, **kwargs): 65 | if has_aux: 66 | result, aux = fun(*args, **kwargs) 67 | else: 68 | result = fun(*args, **kwargs) 69 | aux = None 70 | if is_quantity(result): 71 | magnitude = result.magnitude 72 | units = result.units 73 | else: 74 | magnitude = result 75 | units = None 76 | if has_aux: 77 | return magnitude, (units, aux) 78 | else: 79 | return magnitude, (units, None) 80 | 81 | value_and_grad_fun = jax.value_and_grad( 82 | fun_wo_units, 83 | argnums=argnums, 84 | has_aux=True, 85 | holomorphic=holomorphic, 86 | allow_int=allow_int, 87 | reduce_axes=reduce_axes, 88 | ) 89 | 90 | @wraps(fun, docstr=docstr, argnums=argnums) 91 | def wrapped(*args, **kwargs): 92 | (result_wo_units, (result_units, aux)), grad = value_and_grad_fun( 93 | *args, **kwargs 94 | ) 95 | 96 | if result_units is None: 97 | result = result_wo_units 98 | grad = tree_map( 99 | lambda g: (g.magnitude * (1 / g.units) if is_quantity(g) else g), 100 | grad, 101 | is_leaf=is_quantity, 102 | ) 103 | 104 | else: 105 | result = 
result_wo_units * result_units 106 | grad = tree_map( 107 | lambda g: ( 108 | g.magnitude * result_units / g.units 109 | if is_quantity(g) 110 | else g * result_units 111 | ), 112 | grad, 113 | is_leaf=is_quantity, 114 | ) 115 | 116 | if has_aux: 117 | return (result, aux), grad 118 | else: 119 | return result, grad 120 | 121 | return wrapped 122 | -------------------------------------------------------------------------------- /src/jpu/jax_numpy_func.py: -------------------------------------------------------------------------------- 1 | """When imported, this submodule implements all of the functions in 2 | ``jpu.numpy`` closely following the logic implemented in 3 | ``pint.facets.numpy.numpy_func``. These implemented functions are then injected 4 | into the ``jpu.numpy`` namespace in the appropriate places. 5 | 6 | Since JAX doesn't support any sort of array dispatch protocol, you'll need to 7 | use the functions defined in ``jpu.numpy`` instead of ``jax.numpy`` to get 8 | support for units. 
9 | """ 10 | 11 | from functools import wraps 12 | from inspect import signature 13 | from itertools import chain 14 | 15 | import jax.numpy as jnp 16 | from pint import DimensionalityError 17 | from pint.facets.numpy import numpy_func 18 | 19 | from jpu import numpy as jpu_numpy 20 | from jpu.numpy import linalg as linalg 21 | 22 | HANDLED_FUNCTIONS = {} 23 | 24 | 25 | def implements(numpy_func_string): 26 | def decorator(func): 27 | # Unlike in the pint implementation, we assign all functions to the jpu.numpy 28 | # module 29 | func_str_split = numpy_func_string.split(".") 30 | func_name = func_str_split[-1] 31 | 32 | # Get the appropriate jpu.numpy submodule 33 | module = jpu_numpy 34 | for func_str_piece in func_str_split[:-1]: 35 | module = getattr(module, func_str_piece) 36 | 37 | # Extract the jax.numpy function that we can fall back to 38 | jax_func = getattr(jnp, func_str_split[0]) 39 | for func_str_piece in func_str_split[1:]: 40 | jax_func = getattr(jax_func, func_str_piece) 41 | 42 | # Unlike the pint implementation, we fall back to the jnp implementation 43 | # when none of the inputs are Quantities 44 | @wraps(func) 45 | def wrapped(*args, **kwargs): 46 | # TODO(dfm): This could maybe just check args, because we typically 47 | # assume that the args will have units... Are there any cases where a 48 | # quantity in kwargs would be valid? 
49 | if not any(map(numpy_func._is_quantity, chain(args, kwargs.values()))): 50 | return jax_func(*args, **kwargs) 51 | else: 52 | return func(*args, **kwargs) 53 | 54 | # Save this wrapped function to the jpu.numpy module 55 | if hasattr(module, func_name): 56 | print(f"Function {func_name} has already been implemented") 57 | setattr(module, func_name, wrapped) 58 | 59 | # The rest is the same as the pint implementation 60 | HANDLED_FUNCTIONS[numpy_func_string] = wrapped 61 | 62 | return wrapped 63 | 64 | return decorator 65 | 66 | 67 | def implement_func(func_str, input_units=None, output_unit=None): 68 | func_str_split = func_str.split(".") 69 | func = getattr(jnp, func_str_split[0]) 70 | for func_str_piece in func_str_split[1:]: 71 | func = getattr(func, func_str_piece) 72 | 73 | @implements(func_str) 74 | def implementation(*args, **kwargs): 75 | first_input_units = numpy_func._get_first_input_units(args, kwargs) 76 | 77 | if input_units == "all_consistent": 78 | stripped_args, stripped_kwargs = numpy_func.convert_to_consistent_units( 79 | *args, pre_calc_units=first_input_units, **kwargs 80 | ) 81 | else: 82 | if isinstance(input_units, str): 83 | pre_calc_units = first_input_units._REGISTRY.parse_units(input_units) 84 | else: 85 | pre_calc_units = input_units 86 | 87 | stripped_args, stripped_kwargs = numpy_func.convert_to_consistent_units( 88 | *args, pre_calc_units=pre_calc_units, **kwargs 89 | ) 90 | 91 | result_magnitude = func(*stripped_args, **stripped_kwargs) 92 | 93 | if output_unit is None: 94 | return result_magnitude 95 | elif output_unit == "match_input": 96 | result_unit = first_input_units 97 | elif output_unit in ( 98 | "sum", 99 | "mul", 100 | "delta", 101 | "delta,div", 102 | "div", 103 | "invdiv", 104 | "variance", 105 | "square", 106 | "sqrt", 107 | "cbrt", 108 | "reciprocal", 109 | "size", 110 | ): 111 | result_unit = numpy_func.get_op_output_unit( 112 | output_unit, first_input_units, tuple(chain(args, kwargs.values())) 113 | ) 114 | 
else: 115 | result_unit = output_unit 116 | 117 | return first_input_units._REGISTRY.Quantity(result_magnitude, result_unit) 118 | 119 | 120 | # Unlike the pint implementation, we don't explicitly distinguish between ufuncs 121 | # and functions, therefore some of the functions are commented out here because 122 | # they are also implemented below as functions 123 | function_specs = [ 124 | # ** ****** ** 125 | # ** UFUNCS ** 126 | # ** ****** ** 127 | # 128 | # strip input and output 129 | ("isnan", None, None), 130 | ("isinf", None, None), 131 | ("isfinite", None, None), 132 | ("signbit", None, None), 133 | ("sign", None, None), 134 | # bare output 135 | ("equal", "all_consistent", None), 136 | ("greater", "all_consistent", None), 137 | ("greater_equal", "all_consistent", None), 138 | ("less", "all_consistent", None), 139 | ("less_equal", "all_consistent", None), 140 | ("not_equal", "all_consistent", None), 141 | # matching input, set output 142 | ("arctan2", "all_consistent", "radian"), 143 | # set input and output 144 | # ("cumprod", "", ""), 145 | ("arccos", "", "radian"), 146 | ("arcsin", "", "radian"), 147 | ("arctan", "", "radian"), 148 | ("arccosh", "", "radian"), 149 | ("arcsinh", "", "radian"), 150 | ("arctanh", "", "radian"), 151 | ("exp", "", ""), 152 | ("expm1", "", ""), 153 | ("exp2", "", ""), 154 | ("log", "", ""), 155 | ("log10", "", ""), 156 | ("log1p", "", ""), 157 | ("log2", "", ""), 158 | ("sin", "radian", ""), 159 | ("cos", "radian", ""), 160 | ("tan", "radian", ""), 161 | ("sinh", "radian", ""), 162 | ("cosh", "radian", ""), 163 | ("tanh", "radian", ""), 164 | ("radians", "degree", "radian"), 165 | ("degrees", "radian", "degree"), 166 | ("deg2rad", "degree", "radian"), 167 | ("rad2deg", "radian", "degree"), 168 | ("logaddexp", "", ""), 169 | ("logaddexp2", "", ""), 170 | # matching input, copy output 171 | # ("compress", "all_consistent", "match_input"), 172 | ("conj", "all_consistent", "match_input"), 173 | ("conjugate", "all_consistent", 
"match_input"), 174 | # ("copy", "all_consistent", "match_input"), 175 | # ("diagonal", "all_consistent", "match_input"), 176 | # ("max", "all_consistent", "match_input"), 177 | # ("mean", "all_consistent", "match_input"), 178 | # ("min", "all_consistent", "match_input"), 179 | # ("ptp", "all_consistent", "match_input"), 180 | # ("ravel", "all_consistent", "match_input"), 181 | # ("repeat", "all_consistent", "match_input"), 182 | # ("reshape", "all_consistent", "match_input"), 183 | # ("round", "all_consistent", "match_input"), 184 | # ("squeeze", "all_consistent", "match_input"), 185 | # ("swapaxes", "all_consistent", "match_input"), 186 | # ("take", "all_consistent", "match_input"), 187 | ("trace", "all_consistent", "match_input"), 188 | # ("transpose", "all_consistent", "match_input"), 189 | ("ceil", "all_consistent", "match_input"), 190 | ("floor", "all_consistent", "match_input"), 191 | ("hypot", "all_consistent", "match_input"), 192 | ("rint", "all_consistent", "match_input"), 193 | ("copysign", "all_consistent", "match_input"), 194 | ("nextafter", "all_consistent", "match_input"), 195 | ("trunc", "all_consistent", "match_input"), 196 | ("absolute", "all_consistent", "match_input"), 197 | ("positive", "all_consistent", "match_input"), 198 | ("negative", "all_consistent", "match_input"), 199 | ("maximum", "all_consistent", "match_input"), 200 | ("minimum", "all_consistent", "match_input"), 201 | ("fabs", "all_consistent", "match_input"), 202 | # copy input to output 203 | ("ldexp", None, "match_input"), 204 | ("fmod", None, "match_input"), 205 | ("mod", None, "match_input"), 206 | ("remainder", None, "match_input"), 207 | # output operation on input 208 | ("var", None, "square"), 209 | ("multiply", None, "mul"), 210 | ("true_divide", None, "div"), 211 | ("divide", None, "div"), 212 | ("floor_divide", None, "div"), 213 | ("sqrt", None, "sqrt"), 214 | ("cbrt", None, "cbrt"), 215 | ("square", None, "square"), 216 | ("reciprocal", None, "reciprocal"), 217 | 
("std", None, "sum"), 218 | ("sum", None, "sum"), 219 | ("cumsum", None, "sum"), 220 | ("matmul", None, "mul"), 221 | # 222 | # ** ********* ** 223 | # ** FUNCTIONS ** 224 | # ** ********* ** 225 | # 226 | # matching input, copy output 227 | ("block", "all_consistent", "match_input"), 228 | ("hstack", "all_consistent", "match_input"), 229 | ("vstack", "all_consistent", "match_input"), 230 | ("dstack", "all_consistent", "match_input"), 231 | ("column_stack", "all_consistent", "match_input"), 232 | ("broadcast_arrays", "all_consistent", "match_input"), 233 | # strip input and output 234 | ("size", None, None), 235 | ("isreal", None, None), 236 | ("iscomplex", None, None), 237 | ("shape", None, None), 238 | ("ones_like", None, None), 239 | ("zeros_like", None, None), 240 | ("empty_like", None, None), 241 | ("argsort", None, None), 242 | ("argmin", None, None), 243 | ("argmax", None, None), 244 | ("ndim", None, None), 245 | ("nanargmax", None, None), 246 | ("nanargmin", None, None), 247 | ("count_nonzero", None, None), 248 | ("nonzero", None, None), 249 | ("result_type", None, None), 250 | # output operation on input 251 | # ("std", None, "sum"), 252 | ("nanstd", None, "sum"), 253 | # ("sum", None, "sum"), 254 | ("nansum", None, "sum"), 255 | # ("cumsum", None, "sum"), 256 | ("nancumsum", None, "sum"), 257 | ("diff", None, "delta"), 258 | ("ediff1d", None, "delta"), 259 | ("gradient", None, "delta,div"), 260 | ("linalg.solve", None, "invdiv"), 261 | # ("var", None, "variance"), 262 | ("nanvar", None, "variance"), 263 | ] 264 | 265 | for func_str, input_units, output_unit in function_specs: 266 | implement_func(func_str, input_units=input_units, output_unit=output_unit) 267 | 268 | 269 | @implements("modf") 270 | def _modf(x, *args, **kwargs): 271 | (x,), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(x) 272 | return tuple(output_wrap(y) for y in jnp.modf(x, *args, **kwargs)) 273 | 274 | 275 | @implements("frexp") 276 | def _frexp(x, *args, **kwargs): 277 | 
(x,), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(x) 278 | mantissa, exponent = jnp.frexp(x, *args, **kwargs) 279 | return output_wrap(mantissa), exponent 280 | 281 | 282 | @implements("power") 283 | def _power(x1, x2): 284 | if numpy_func._is_quantity(x1): 285 | return x1**x2 286 | 287 | return x2.__rpow__(x1) 288 | 289 | 290 | @implements("add") 291 | def _add(x1, x2, *args, **kwargs): 292 | (x1, x2), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(x1, x2) 293 | return output_wrap(jnp.add(x1, x2, *args, **kwargs)) # type: ignore 294 | 295 | 296 | @implements("subtract") 297 | def _subtract(x1, x2, *args, **kwargs): 298 | (x1, x2), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(x1, x2) 299 | return output_wrap(jnp.subtract(x1, x2, *args, **kwargs)) # type: ignore 300 | 301 | 302 | @implements("meshgrid") 303 | def _meshgrid(*xi, **kwargs): 304 | input_units = (x.units for x in xi) 305 | res = jnp.meshgrid(*(x.m for x in xi), **kwargs) 306 | return [out * unit for out, unit in zip(res, input_units)] 307 | 308 | 309 | @implements("full_like") 310 | def _full_like(a, fill_value, **kwargs): 311 | if hasattr(fill_value, "_REGISTRY"): 312 | return fill_value._REGISTRY.Quantity( 313 | jnp.ones_like(a, **kwargs) * fill_value.m, 314 | fill_value.units, 315 | ) 316 | 317 | return jnp.ones_like(a, **kwargs) * fill_value 318 | 319 | 320 | @implements("interp") 321 | def _interp(x, xp, fp, left=None, right=None, period=None): 322 | (x, xp, period), _ = numpy_func.unwrap_and_wrap_consistent_units(x, xp, period) 323 | (fp, right, left), output_wrap = numpy_func.unwrap_and_wrap_consistent_units( 324 | fp, left, right 325 | ) 326 | res = jnp.interp(x, xp, fp, left=left, right=right, period=period) # type: ignore 327 | return output_wrap(res) 328 | 329 | 330 | @implements("concatenate") 331 | def _concatenate(sequence, *args, **kwargs): 332 | sequence, output_wrap = numpy_func.unwrap_and_wrap_consistent_units(*sequence) 333 | return 
output_wrap(jnp.concatenate(sequence, *args, **kwargs)) # type: ignore 334 | 335 | 336 | @implements("stack") 337 | def _stack(arrays, *args, **kwargs): 338 | arrays, output_wrap = numpy_func.unwrap_and_wrap_consistent_units(*arrays) 339 | return output_wrap(jnp.stack(arrays, *args, **kwargs)) # type: ignore 340 | 341 | 342 | @implements("unwrap") 343 | def _unwrap(p, discont=None, axis=-1): 344 | # np.unwrap only dispatches over p argument, so assume it is a Quantity 345 | discont = jnp.pi if discont is None else discont 346 | return p._REGISTRY.Quantity( 347 | jnp.unwrap(p.m_as("rad"), discont, axis=axis), "rad" 348 | ).to(p.units) 349 | 350 | 351 | @implements("einsum") 352 | def _einsum(subscripts, *operands, **kwargs): 353 | operand_magnitudes, _ = numpy_func.convert_to_consistent_units( 354 | *operands, pre_calc_units=None 355 | ) 356 | output_unit = numpy_func.get_op_output_unit( 357 | "mul", numpy_func._get_first_input_units(operands), operands 358 | ) 359 | return ( 360 | jnp.einsum( 361 | subscripts, 362 | *operand_magnitudes, # type: ignore 363 | **kwargs, 364 | ) 365 | * output_unit 366 | ) 367 | 368 | 369 | @implements("isin") 370 | def _isin(element, test_elements, assume_unique=False, invert=False): 371 | if not numpy_func._is_quantity(element): 372 | raise ValueError( 373 | "Cannot test if unit-aware elements are in not-unit-aware array" 374 | ) 375 | 376 | if numpy_func._is_quantity(test_elements): 377 | try: 378 | test_elements = test_elements.m_as(element.units) 379 | except DimensionalityError: 380 | # Incompatible unit test elements cannot be in element 381 | return jnp.full(element.shape, False) 382 | elif not element.dimensionless: 383 | # Unit do not match, so all false 384 | return jnp.full(element.shape, False) 385 | else: 386 | # Convert to units of element 387 | element._REGISTRY.Quantity(test_elements).m_as(element.units) 388 | 389 | return jnp.isin( 390 | element.m, test_elements, assume_unique=assume_unique, invert=invert 391 | ) 392 
| 393 | 394 | @implements("pad") 395 | def _pad(array, pad_width, mode="constant", **kwargs): 396 | def _recursive_convert(arg, unit): 397 | if not numpy_func._is_quantity(arg): 398 | arg = unit._REGISTRY.Quantity(arg, "dimensionless") 399 | return arg.m_as(unit) 400 | 401 | # pad only dispatches on array argument, so we know it is a Quantity 402 | units = array.units 403 | 404 | # Handle flexible constant_values and end_values, converting to units if Quantity 405 | # and ignoring if not 406 | for key in ("constant_values", "end_values"): 407 | if key in kwargs: 408 | kwargs[key] = _recursive_convert(kwargs[key], units) 409 | 410 | return units._REGISTRY.Quantity( 411 | jnp.pad(array._magnitude, pad_width, mode=mode, **kwargs), units 412 | ) 413 | 414 | 415 | def _require_multiplicative(func): 416 | @wraps(func) 417 | def wrapped(a, *args, **kwargs): 418 | if numpy_func._is_quantity(a) and not a._is_multiplicative: 419 | raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") 420 | 421 | return func(a, *args, **kwargs) 422 | 423 | return wrapped 424 | 425 | 426 | @implements("where") 427 | @_require_multiplicative 428 | def _where(condition, *args): 429 | condition = getattr(condition, "magnitude", condition) 430 | args, output_wrap = numpy_func.unwrap_and_wrap_consistent_units(*args) 431 | return output_wrap(jnp.where(condition, *args)) # type: ignore 432 | 433 | 434 | @implements("any") 435 | @_require_multiplicative 436 | def _any(a, *args, **kwargs): 437 | return jnp.any(a._magnitude, *args, **kwargs) 438 | 439 | 440 | @implements("all") 441 | @_require_multiplicative 442 | def _all(a, *args, **kwargs): 443 | return jnp.all(a._magnitude, *args, **kwargs) 444 | 445 | 446 | def implement_prod_func(name): 447 | func = getattr(jnp, name, None) 448 | 449 | @implements(name) 450 | def _prod(a, *args, **kwargs): 451 | arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where") 452 | all_kwargs = dict(**dict(zip(arg_names, args)), 
**kwargs) 453 | axis = all_kwargs.get("axis", None) 454 | where = all_kwargs.get("where", None) 455 | 456 | registry = a.units._REGISTRY 457 | 458 | if axis is not None and where is not None: 459 | raise NotImplementedError 460 | elif axis is not None: 461 | units = a.units ** a.shape[axis] 462 | elif where is not None: 463 | exponent = jnp.sum(where) 464 | units = a.units**exponent 465 | else: 466 | exponent = ( 467 | jnp.sum(jnp.logical_not(jnp.isnan(a))) if name == "nanprod" else a.size 468 | ) 469 | units = a.units**exponent 470 | 471 | result = func(a._magnitude, *args, **kwargs) # type: ignore 472 | 473 | return registry.Quantity(result, units) 474 | 475 | 476 | for name in ("prod", "nanprod"): 477 | implement_prod_func(name) 478 | 479 | 480 | def implement_mul_func(func): 481 | func = getattr(jnp, func_str) 482 | 483 | @implements(func_str) 484 | def implementation(a, b, **kwargs): 485 | a = numpy_func._base_unit_if_needed(a) 486 | units = a.units 487 | if hasattr(b, "units"): 488 | b = numpy_func._base_unit_if_needed(b) 489 | units *= b.units 490 | b = b._magnitude 491 | 492 | mag = func(a._magnitude, b, **kwargs) 493 | return a.units._REGISTRY.Quantity(mag, units) 494 | 495 | 496 | for func_str in ("cross", "dot", "outer"): 497 | implement_mul_func(func_str) 498 | 499 | 500 | def implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output=True): 501 | if "." 
not in func_str: 502 | func = getattr(jnp, func_str, None) 503 | else: 504 | parts = func_str.split(".") 505 | module = jnp 506 | for part in parts[:-1]: 507 | module = getattr(module, part, None) 508 | func = getattr(module, parts[-1], None) 509 | 510 | @implements(func_str) 511 | def implementation(*args, **kwargs): 512 | bound_args = signature(func).bind(*args, **kwargs) # type: ignore 513 | valid_unit_arguments = [ 514 | label 515 | for label in unit_arguments 516 | if label in bound_args.arguments and bound_args.arguments[label] is not None 517 | ] 518 | unwrapped_unit_args, output_wrap = numpy_func.unwrap_and_wrap_consistent_units( 519 | *(bound_args.arguments[label] for label in valid_unit_arguments) 520 | ) 521 | for i, unwrapped_unit_arg in enumerate(unwrapped_unit_args): 522 | bound_args.arguments[valid_unit_arguments[i]] = unwrapped_unit_arg 523 | ret = func(*bound_args.args, **bound_args.kwargs) # type: ignore 524 | 525 | if wrap_output: 526 | return output_wrap(ret) 527 | return ret 528 | 529 | 530 | for func_str, unit_arguments, wrap_output in ( 531 | ("expand_dims", "a", True), 532 | ("squeeze", "a", True), 533 | ("rollaxis", "a", True), 534 | ("moveaxis", "a", True), 535 | ("around", "a", True), 536 | ("diagonal", "a", True), 537 | ("mean", "a", True), 538 | ("ptp", "a", True), 539 | ("ravel", "a", True), 540 | ("repeat", "a", True), 541 | # ("round_", "a", True), 542 | ("round", "a", True), 543 | ("take", "a", True), 544 | ("sort", "a", True), 545 | ("median", "a", True), 546 | ("nanmedian", "a", True), 547 | ("transpose", "a", True), 548 | ("copy", "a", True), 549 | ("average", "a", True), 550 | ("nanmean", "a", True), 551 | ("swapaxes", "a", True), 552 | ("nanmin", "a", True), 553 | ("nanmax", "a", True), 554 | ("percentile", "a", True), 555 | ("nanpercentile", "a", True), 556 | ("quantile", "a", True), 557 | ("nanquantile", "a", True), 558 | ("flip", "m", True), 559 | ("fix", "x", True), 560 | ("trim_zeros", ["filt"], True), 561 | 
("broadcast_to", ["array"], True), 562 | ("amax", ["a", "initial"], True), 563 | ("amin", ["a", "initial"], True), 564 | ("max", ["a", "initial"], True), 565 | ("min", ["a", "initial"], True), 566 | ("searchsorted", ["a", "v"], False), 567 | ("nan_to_num", ["x", "nan", "posinf", "neginf"], True), 568 | ("clip", ["a", "a_min", "a_max"], True), 569 | ("append", ["arr", "values"], True), 570 | ("compress", "a", True), 571 | ("linspace", ["start", "stop"], True), 572 | ("tile", "A", True), 573 | # ("lib.stride_tricks.sliding_window_view", "x", True), 574 | ("rot90", "m", True), 575 | ("insert", ["arr", "values"], True), 576 | ("delete", ["arr"], True), 577 | ("resize", "a", True), 578 | ("reshape", "a", True), 579 | ("intersect1d", ["ar1", "ar2"], True), 580 | ): 581 | implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output) 582 | 583 | 584 | def implement_close(func_str): 585 | func = getattr(jnp, func_str) 586 | 587 | @implements(func_str) 588 | def implementation(*args, **kwargs): 589 | bound_args = signature(func).bind(*args, **kwargs) 590 | labels = ["a", "b"] 591 | arrays = {label: bound_args.arguments[label] for label in labels} 592 | if "atol" in bound_args.arguments: 593 | atol = bound_args.arguments["atol"] 594 | a = arrays["a"] 595 | if not hasattr(atol, "_REGISTRY") and hasattr(a, "_REGISTRY"): 596 | atol_ = a._REGISTRY.Quantity(atol, a.units) 597 | else: 598 | atol_ = atol 599 | arrays["atol"] = atol_ 600 | 601 | args, _ = numpy_func.unwrap_and_wrap_consistent_units(*arrays.values()) 602 | for label, value in zip(arrays.keys(), args): 603 | bound_args.arguments[label] = value 604 | 605 | return func(*bound_args.args, **bound_args.kwargs) 606 | 607 | 608 | for func_str in ("isclose", "allclose"): 609 | implement_close(func_str) 610 | 611 | 612 | def implement_atleast_nd(func_str): 613 | func = getattr(jnp, func_str) 614 | 615 | @implements(func_str) 616 | def implementation(*arrays): 617 | stripped_arrays, _ = 
numpy_func.convert_to_consistent_units(*arrays) 618 | arrays_magnitude = func(*stripped_arrays) 619 | if len(arrays) > 1: 620 | return [ 621 | array_magnitude 622 | if not hasattr(original, "_REGISTRY") 623 | else original._REGISTRY.Quantity(array_magnitude, original.units) 624 | for array_magnitude, original in zip(arrays_magnitude, arrays) 625 | ] 626 | else: 627 | output_unit = arrays[0].units 628 | return output_unit._REGISTRY.Quantity(arrays_magnitude, output_unit) 629 | 630 | 631 | for func_str in ("atleast_1d", "atleast_2d", "atleast_3d"): 632 | implement_atleast_nd(func_str) 633 | 634 | 635 | def implement_single_dimensionless_argument_func(func_str): 636 | func = getattr(jnp, func_str) 637 | 638 | @implements(func_str) 639 | def implementation(a, *args, **kwargs): 640 | (a_stripped,), _ = numpy_func.convert_to_consistent_units( 641 | a, pre_calc_units=a._REGISTRY.parse_units("dimensionless") 642 | ) 643 | return a._REGISTRY.Quantity(func(a_stripped, *args, **kwargs)) 644 | 645 | 646 | for func_str in ( 647 | "cumprod", 648 | # "cumproduct", # deprecated 649 | "nancumprod", 650 | ): 651 | implement_single_dimensionless_argument_func(func_str) 652 | 653 | 654 | # ** ***************************** ** 655 | # ** FUNCTIONS NOT COVERED BY PINT ** 656 | # ** ***************************** ** 657 | 658 | 659 | @implements("argpartition") 660 | def _argpartition(a, *args, **kwargs): 661 | (a,), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(a) 662 | return output_wrap(jnp.argpartition(a, *args, **kwargs)) 663 | 664 | 665 | @implements("choose") 666 | def _choose(a, *args, **kwargs): 667 | (a,), output_wrap = numpy_func.unwrap_and_wrap_consistent_units(a) 668 | return output_wrap(jnp.choose(a, *args, **kwargs)) # type: ignore 669 | -------------------------------------------------------------------------------- /src/jpu/monkey.py: -------------------------------------------------------------------------------- 1 | import types 2 | from functools import 
singledispatch 3 | 4 | import jax.numpy as jnp 5 | 6 | import jpu.numpy as jpunp 7 | from jpu.registry import UnitRegistry 8 | 9 | 10 | def patch(): 11 | """ 12 | Replace all the supported methods in jax.numpy with the unit-aware versions 13 | """ 14 | funcs = _get_namespace_functions(jpunp) 15 | for name, func in funcs.items(): 16 | if name.startswith("_"): 17 | continue 18 | jfunc = singledispatch(getattr(jnp, name)) 19 | setattr(jnp, name, jfunc) 20 | jfunc.register(UnitRegistry.Quantity)(func) 21 | 22 | 23 | def _get_namespace_functions(module): 24 | module_fns = {} 25 | for key in dir(module): 26 | if key in ("__getattr__", "__dir__"): 27 | continue 28 | try: 29 | attr = getattr(module, key) 30 | except Exception: 31 | continue 32 | if isinstance( 33 | attr, 34 | ( 35 | types.BuiltinFunctionType, 36 | types.FunctionType, 37 | types.BuiltinMethodType, 38 | types.MethodType, 39 | ), 40 | ): 41 | module_fns[key] = attr 42 | return module_fns 43 | -------------------------------------------------------------------------------- /src/jpu/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from jpu import jax_numpy_func as jax_numpy_func 2 | -------------------------------------------------------------------------------- /src/jpu/numpy/__init__.pyi: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | from typing import Any, Callable, Literal, TypeVar, overload 5 | 6 | import numpy as _np 7 | from jax._src.lax.lax import PrecisionLike 8 | from jax._src.typing import DimSize, DType, DTypeLike, DuckTypedArray, Shape 9 | 10 | from jpu.registry import Quantity 11 | 12 | ArrayLike = Quantity 13 | Array = Quantity 14 | 15 | _T = TypeVar("_T") 16 | _Axis = None | int | Sequence[int] 17 | 18 | def abs(x: ArrayLike, /) -> Array: ... 19 | def absolute(x: ArrayLike, /) -> Array: ... 
# Signature stubs ``add`` .. ``argpartition``. ``ArrayLike`` and ``Array`` are both
# aliases of ``Quantity`` in this stub; ``= ...`` marks a parameter that has a
# default whose value is not spelled out here.
def add(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def amax(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def amin(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def all(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    *,
    where: ArrayLike | None = ...,
) -> Array: ...
def allclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: ArrayLike = ...,
    atol: ArrayLike = ...,
    equal_nan: bool = ...,
) -> Array: ...
def any(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    *,
    where: ArrayLike | None = ...,
) -> Array: ...
def append(arr: ArrayLike, values: ArrayLike, axis: int | None = ...) -> Array: ...
def arange(
    start: DimSize,
    stop: DimSize | None = ...,
    step: DimSize | None = ...,
    dtype: DTypeLike | None = ...,
) -> Array: ...
def arccos(x: ArrayLike, /) -> Array: ...
def arccosh(x: ArrayLike, /) -> Array: ...
def arcsin(x: ArrayLike, /) -> Array: ...
def arcsinh(x: ArrayLike, /) -> Array: ...
def arctan(x: ArrayLike, /) -> Array: ...
def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def arctanh(x: ArrayLike, /) -> Array: ...
def argmax(
    a: ArrayLike,
    axis: int | None = ...,
    out: None = ...,
    keepdims: bool | None = ...,
) -> Array: ...
def argmin(
    a: ArrayLike,
    axis: int | None = ...,
    out: None = ...,
    keepdims: bool | None = ...,
) -> Array: ...
def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
# Signature stubs ``argsort`` .. ``average``.
def argsort(
    a: ArrayLike,
    axis: int | None = ...,
    kind: str | None = ...,
    order: None = ...,
    *,
    stable: bool = ...,
    descending: bool = ...,
) -> Array: ...

# Alias: ``around`` shares the signature of ``round``, which is declared further
# down in this stub.
around = round

def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
def astype(a: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = ...) -> Array: ...
def atan(x: ArrayLike, /) -> Array: ...
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def atanh(x: ArrayLike, /) -> Array: ...

# ``atleast_*``: no argument or several arguments give a list, one argument gives
# a single array.
@overload
def atleast_1d() -> list[Array]: ...
@overload
def atleast_1d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
@overload
def atleast_2d() -> list[Array]: ...
@overload
def atleast_2d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
@overload
def atleast_3d() -> list[Array]: ...
@overload
def atleast_3d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...

# ``average``: the return type depends on ``returned`` (array alone, or a pair).
@overload
def average(
    a: ArrayLike,
    axis: _Axis = ...,
    weights: ArrayLike | None = ...,
    returned: Literal[False] = False,
    keepdims: bool = False,
) -> Array: ...
@overload
def average(
    a: ArrayLike,
    axis: _Axis = ...,
    weights: ArrayLike | None = ...,
    *,
    returned: Literal[True],
    keepdims: bool = False,
) -> tuple[Array, Array]: ...
@overload
def average(
    a: ArrayLike,
    axis: _Axis = ...,
    weights: ArrayLike | None = ...,
    returned: bool = False,
    keepdims: bool = False,
) -> Array | tuple[Array, Array]: ...
# Signature stubs ``block`` .. ``count_nonzero``.
def block(
    arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]],
) -> Array: ...
def broadcast_arrays(*args: ArrayLike) -> list[Array]: ...
def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ...
def cbrt(x: ArrayLike, /) -> Array: ...
def ceil(x: ArrayLike, /) -> Array: ...
def choose(
    a: ArrayLike, choices: Sequence[ArrayLike], out: None = ..., mode: str = ...
) -> Array: ...
def clip(
    a: ArrayLike,
    a_min: ArrayLike | None = ...,
    a_max: ArrayLike | None = ...,
    out: None = ...,
) -> Array: ...
def column_stack(tup: _np.ndarray | Array | Sequence[ArrayLike]) -> Array: ...
def compress(
    condition: ArrayLike, a: ArrayLike, axis: int | None = ..., out: None = ...
) -> Array: ...
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ...
def concatenate(
    arrays: _np.ndarray | Array | Sequence[ArrayLike],
    axis: int | None = ...,
    dtype: DTypeLike | None = ...,
) -> Array: ...
def conjugate(x: ArrayLike, /) -> Array: ...

# Alias: ``conj`` shares the signature of ``conjugate``.
conj = conjugate

def convolve(
    a: ArrayLike,
    v: ArrayLike,
    mode: str = ...,
    *,
    precision: PrecisionLike = ...,
    preferred_element_type: DType | None = ...,
) -> Array: ...
def copy(a: ArrayLike, order: str | None = ...) -> Array: ...
def copysign(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def corrcoef(x: ArrayLike, y: ArrayLike | None = ..., rowvar: bool = ...) -> Array: ...
def correlate(
    a: ArrayLike,
    v: ArrayLike,
    mode: str = ...,
    *,
    precision: PrecisionLike = ...,
    preferred_element_type: DType | None = ...,
) -> Array: ...
def cos(x: ArrayLike, /) -> Array: ...
def cosh(x: ArrayLike, /) -> Array: ...
def count_nonzero(a: ArrayLike, axis: _Axis = ..., keepdims: bool = ...) -> Array: ...
# Signature stubs ``cov`` .. ``dstack``.
def cov(
    m: ArrayLike,
    y: ArrayLike | None = ...,
    rowvar: bool = ...,
    bias: bool = ...,
    ddof: int | None = ...,
    fweights: ArrayLike | None = ...,
    aweights: ArrayLike | None = ...,
) -> Array: ...
def cross(
    a: ArrayLike,
    b: ArrayLike,
    axisa: int = -1,
    axisb: int = -1,
    axisc: int = -1,
    axis: int | None = ...,
) -> Array: ...
def cumprod(
    a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...
) -> Array: ...
def cumsum(
    a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...
) -> Array: ...
def deg2rad(x: ArrayLike, /) -> Array: ...
def delete(
    arr: ArrayLike,
    obj: ArrayLike | slice,
    axis: int | None = ...,
    *,
    assume_unique_indices: bool = ...,
) -> Array: ...
def diag(v: ArrayLike, k: int = 0) -> Array: ...
def diag_indices(n: int, ndim: int = ...) -> tuple[Array, ...]: ...
def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: ...
def diagflat(v: ArrayLike, k: int = 0) -> Array: ...

# The return annotation was missing; ``diagonal`` is registered with a wrapped
# output, so it returns the same array type as its siblings.
def diagonal(
    a: ArrayLike, offset: ArrayLike = ..., axis1: int = ..., axis2: int = ...
) -> Array: ...
def diff(
    a: ArrayLike,
    n: int = ...,
    axis: int = ...,
    prepend: ArrayLike | None = ...,
    append: ArrayLike | None = ...,
) -> Array: ...
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = ...) -> Array: ...
def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ...
def dot(
    a: ArrayLike,
    b: ArrayLike,
    *,
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
) -> Array: ...
def dsplit(ary: ArrayLike, indices_or_sections: int | ArrayLike) -> list[Array]: ...
def dstack(
    tup: _np.ndarray | Array | Sequence[ArrayLike],
    dtype: DTypeLike | None = ...,
) -> Array: ...
# Signature stubs ``ediff1d`` .. ``eye``.
def ediff1d(
    ary: ArrayLike,
    to_end: ArrayLike | None = ...,
    to_begin: ArrayLike | None = ...,
) -> Array: ...

# ``einsum`` overloads: subscripts given as a string, subscripts given as
# interleaved (array, axes) pairs, and an untyped catch-all.
@overload
def einsum(
    subscript: str,
    /,
    *operands: ArrayLike,
    out: None = ...,
    optimize: str = "optimal",
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: bool = False,
    _dot_general: Callable[..., Array] = ...,
) -> Array: ...
@overload
def einsum(
    arr: ArrayLike,
    axes: Sequence[Any],
    /,
    *operands: ArrayLike | Sequence[Any],
    out: None = ...,
    optimize: str = "optimal",
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: bool = False,
    _dot_general: Callable[..., Array] = ...,
) -> Array: ...
@overload
def einsum(
    subscripts,
    /,
    *operands,
    out: None = ...,
    optimize: str = ...,
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
    _use_xeinsum: bool = ...,
    _dot_general: Callable[..., Array] = ...,
) -> Array: ...
def einsum_path(subscripts, *operands, optimize=...): ...
def empty(shape: Any, dtype: DTypeLike | None = ...) -> Array: ...
def empty_like(
    prototype: ArrayLike | DuckTypedArray,
    dtype: DTypeLike | None = ...,
    shape: Any = ...,
) -> Array: ...
def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def exp(x: ArrayLike, /) -> Array: ...
def exp2(x: ArrayLike, /) -> Array: ...
def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: ...
def expm1(x: ArrayLike, /) -> Array: ...
def extract(condition: ArrayLike, arr: ArrayLike) -> Array: ...
def eye(
    N: DimSize,
    M: DimSize | None = ...,
    k: int = ...,
    dtype: DTypeLike | None = ...,
) -> Array: ...
# Signature stubs ``fabs`` .. ``insert``.
def fabs(x: ArrayLike, /) -> Array: ...
def fix(x: ArrayLike, out: None = ...) -> Array: ...
def flatnonzero(
    a: ArrayLike,
    *,
    size: int | None = ...,
    fill_value: None | ArrayLike | tuple[ArrayLike] = ...,
) -> Array: ...
def flip(m: ArrayLike, axis: int | Sequence[int] | None = ...) -> Array: ...
def floor(x: ArrayLike, /) -> Array: ...
def floor_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = ...) -> Array: ...
def full_like(
    a: ArrayLike | DuckTypedArray,
    fill_value: ArrayLike,
    dtype: DTypeLike | None = ...,
    shape: Any = ...,
) -> Array: ...
def gradient(
    f: ArrayLike,
    *varargs: ArrayLike,
    axis: int | Sequence[int] | None = ...,
    edge_order: int | None = ...,
) -> Array | list[Array]: ...
def greater(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def hstack(
    tup: _np.ndarray | Array | Sequence[ArrayLike],
    dtype: DTypeLike | None = ...,
) -> Array: ...
def hypot(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def identity(n: DimSize, dtype: DTypeLike | None = ...) -> Array: ...
def imag(x: ArrayLike, /) -> Array: ...
def inner(
    a: ArrayLike,
    b: ArrayLike,
    *,
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
) -> Array: ...
def insert(
    arr: ArrayLike,
    obj: ArrayLike | slice,
    values: ArrayLike,
    axis: int | None = ...,
) -> Array: ...
# Signature stubs ``interp`` .. ``nan_to_num``.
def interp(
    x: ArrayLike,
    xp: ArrayLike,
    fp: ArrayLike,
    left: ArrayLike | str | None = ...,
    right: ArrayLike | str | None = ...,
    period: ArrayLike | None = ...,
) -> Array: ...
def intersect1d(
    ar1: ArrayLike,
    ar2: ArrayLike,
    assume_unique: bool = ...,
    return_indices: bool = ...,
) -> Array | tuple[Array, Array, Array]: ...
def invert(x: ArrayLike, /) -> Array: ...
def isclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: ArrayLike = ...,
    atol: ArrayLike = ...,
    equal_nan: bool = ...,
) -> Array: ...
def iscomplex(m: ArrayLike) -> Array: ...
def isfinite(x: ArrayLike, /) -> Array: ...
def isin(
    element: ArrayLike,
    test_elements: ArrayLike,
    assume_unique: bool = ...,
    invert: bool = ...,
) -> Array: ...
def isinf(x: ArrayLike, /) -> Array: ...
def isnan(x: ArrayLike, /) -> Array: ...
def isreal(m: ArrayLike) -> Array: ...
def ldexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def less(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...

# ``linspace`` overloads: the return type depends on ``retstep`` (array alone, or
# an (array, step) pair); the last overload covers a non-literal ``retstep``.
@overload
def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num: int = 50,
    endpoint: bool = True,
    retstep: Literal[False] = False,
    dtype: DTypeLike | None = ...,
    axis: int = 0,
) -> Array: ...
@overload
def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num: int,
    endpoint: bool,
    retstep: Literal[True],
    dtype: DTypeLike | None = ...,
    axis: int = 0,
) -> tuple[Array, Array]: ...
@overload
def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num: int = 50,
    endpoint: bool = True,
    *,
    retstep: Literal[True],
    dtype: DTypeLike | None = ...,
    axis: int = 0,
) -> tuple[Array, Array]: ...
@overload
def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num: int = 50,
    endpoint: bool = True,
    retstep: bool = False,
    dtype: DTypeLike | None = ...,
    axis: int = 0,
) -> Array | tuple[Array, Array]: ...
def log(x: ArrayLike, /) -> Array: ...
def log10(x: ArrayLike, /) -> Array: ...
def log1p(x: ArrayLike, /) -> Array: ...
def log2(x: ArrayLike, /) -> Array: ...
def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def matmul(
    a: ArrayLike,
    b: ArrayLike,
    *,
    precision: PrecisionLike = ...,
    preferred_element_type: DTypeLike | None = ...,
) -> Array: ...

# Alias: ``max`` shares the signature of ``amax``.
max = amax

def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def mean(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    *,
    where: ArrayLike | None = ...,
) -> Array: ...
def median(
    a: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    keepdims: bool = ...,
) -> Array: ...

# Alias: ``min`` shares the signature of ``amin``.
min = amin

def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ...
def moveaxis(
    a: ArrayLike,
    source: int | Sequence[int],
    destination: int | Sequence[int],
) -> Array: ...
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def nan_to_num(
    x: ArrayLike,
    copy: bool = ...,
    nan: ArrayLike = ...,
    posinf: ArrayLike | None = ...,
    neginf: ArrayLike | None = ...,
) -> Array: ...
# Signature stubs ``nanargmax`` .. ``swapaxes``.
def nanargmax(
    a: ArrayLike,
    axis: int | None = ...,
    out: None = ...,
    keepdims: bool | None = ...,
) -> Array: ...
def nanargmin(
    a: ArrayLike,
    axis: int | None = ...,
    out: None = ...,
    keepdims: bool | None = ...,
) -> Array: ...
def nancumprod(
    a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...
) -> Array: ...
def nancumsum(
    a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...
) -> Array: ...
def nanmax(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nanmean(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nanmedian(
    a: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    keepdims: bool = ...,
) -> Array: ...
def nanmin(
    a: ArrayLike,
    axis: _Axis = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nanpercentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    method: str = ...,
    keepdims: bool = ...,
    interpolation: None = ...,
) -> Array: ...
def nanprod(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nanquantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    method: str = ...,
    keepdims: bool = ...,
    interpolation: None = ...,
) -> Array: ...
def nanstd(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    ddof: int = ...,
    keepdims: bool = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nansum(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
) -> Array: ...
def nanvar(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    ddof: int = 0,
    keepdims: bool = False,
    where: ArrayLike | None = ...,
) -> Array: ...

# Re-uses NumPy's own ``ndim`` signature.
ndim = _np.ndim

def negative(x: ArrayLike, /) -> Array: ...
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def nonzero(
    a: ArrayLike,
    *,
    size: int | None = ...,
    fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ...,
) -> tuple[Array, ...]: ...
def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def ones(shape: Any, dtype: DTypeLike | None = ...) -> Array: ...
def ones_like(
    a: ArrayLike | DuckTypedArray,
    dtype: DTypeLike | None = ...,
    shape: Any = ...,
) -> Array: ...
def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ...

# A pad value may be a scalar, a sequence, or a sequence of sequences.
PadValueLike = _T | Sequence[_T] | Sequence[Sequence[_T]]

def pad(
    array: ArrayLike,
    pad_width: PadValueLike[int | Array | _np.ndarray],
    mode: str | Callable[..., Any] = ...,
    **kwargs,
) -> Array: ...
def partition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    method: str = ...,
    keepdims: bool = ...,
    interpolation: None = ...,
) -> Array: ...
def positive(x: ArrayLike, /) -> Array: ...
def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def power(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def prod(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
    promote_integers: bool = ...,
) -> Array: ...
def ptp(
    a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: bool = ...
) -> Array: ...
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: int | tuple[int, ...] | None = ...,
    out: None = ...,
    overwrite_input: bool = ...,
    method: str = ...,
    keepdims: bool = ...,
    interpolation: None = ...,
) -> Array: ...
def rad2deg(x: ArrayLike, /) -> Array: ...
def ravel(a: ArrayLike, order: str = ...) -> Array: ...
def real(x: ArrayLike, /) -> Array: ...
def reciprocal(x: ArrayLike, /) -> Array: ...
def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(
    a: ArrayLike,
    repeats: ArrayLike,
    axis: int | None = ...,
    *,
    total_repeat_length: int | None = ...,
) -> Array: ...
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = ...) -> Array: ...
def resize(a: ArrayLike, new_shape: Shape) -> Array: ...
def result_type(*args: Any) -> DType: ...
def rint(x: ArrayLike, /) -> Array: ...
def roll(
    a: ArrayLike,
    shift: ArrayLike | Sequence[int],
    axis: int | Sequence[int] | None = ...,
) -> Array: ...
def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: ...
def rot90(m: ArrayLike, k: int = ..., axes: tuple[int, int] = ...) -> Array: ...
def round(a: ArrayLike, decimals: int = ..., out: None = ...) -> Array: ...

# Alias: ``round_`` shares the signature of ``round``.
round_ = round

def searchsorted(
    a: ArrayLike,
    v: ArrayLike,
    side: str = ...,
    sorter: None = ...,
    *,
    method: str = ...,
) -> Array: ...

# Re-uses NumPy's own ``shape`` signature.
shape = _np.shape

def sign(x: ArrayLike, /) -> Array: ...
def signbit(x: ArrayLike, /) -> Array: ...
def sin(x: ArrayLike, /) -> Array: ...
def sinh(x: ArrayLike, /) -> Array: ...

# Re-uses NumPy's own ``size`` signature.
size = _np.size

def sort(
    a: ArrayLike,
    axis: int | None = ...,
    kind: str | None = ...,
    order: None = ...,
    *,
    stable: bool = ...,
    descending: bool = ...,
) -> Array: ...
def sqrt(x: ArrayLike, /) -> Array: ...
def square(x: ArrayLike, /) -> Array: ...
def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = ...) -> Array: ...
def stack(
    arrays: _np.ndarray | Array | Sequence[ArrayLike],
    axis: int = ...,
    out: None = ...,
    dtype: DTypeLike | None = ...,
) -> Array: ...
def std(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    ddof: int = ...,
    keepdims: bool = ...,
    *,
    where: ArrayLike | None = ...,
) -> Array: ...
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def sum(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    initial: ArrayLike | None = ...,
    where: ArrayLike | None = ...,
    promote_integers: bool = ...,
) -> Array: ...
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: ...
# Stub-style declarations: bodies are the `...` placeholder and `= ...`
# defaults only mark a parameter as optional.
def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis: int | None = ...,
    out: None = ...,
    mode: str | None = ...,
    unique_indices: bool = ...,
    indices_are_sorted: bool = ...,
    fill_value: ArrayLike | None = ...,
) -> Array: ...
def tan(x: ArrayLike, /) -> Array: ...
def tanh(x: ArrayLike, /) -> Array: ...
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: ...
def trace(
    a: ArrayLike,
    offset: int = ...,
    axis1: int = ...,
    axis2: int = ...,
    dtype: DTypeLike | None = ...,
    out: None = ...,
) -> Array: ...
def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ...
def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ...
def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def trunc(x: ArrayLike, /) -> Array: ...
def unwrap(
    p: ArrayLike,
    discont: ArrayLike | None = ...,
    axis: int = ...,
    period: ArrayLike = ...,
) -> Array: ...
def var(
    a: ArrayLike,
    axis: _Axis = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    ddof: int = ...,
    keepdims: bool = ...,
    *,
    where: ArrayLike | None = ...,
) -> Array: ...
def vstack(
    tup: _np.ndarray | Array | Sequence[ArrayLike],
    dtype: DTypeLike | None = ...,
) -> Array: ...
# `where` overload: `x` and `y` both omitted or None -> a tuple of arrays.
@overload
def where(
    condition: ArrayLike,
    x: Literal[None] = ...,
    y: Literal[None] = ...,
    /,
    *,
    size: int | None = ...,
    fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ...,
) -> tuple[Array, ...]: ...
# `where` overload: `x` and `y` both given as array-likes -> a single array.
@overload
def where(
    condition: ArrayLike,
    x: ArrayLike,
    y: ArrayLike,
    /,
    *,
    size: int | None = ...,
    fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ...,
) -> Array: ...
819 | @overload 820 | def where( 821 | condition: ArrayLike, 822 | x: ArrayLike | None = ..., 823 | y: ArrayLike | None = ..., 824 | /, 825 | *, 826 | size: int | None = ..., 827 | fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ..., 828 | ) -> Array | tuple[Array, ...]: ... 829 | def zeros(shape: Any, dtype: DTypeLike | None = ...) -> Array: ... 830 | def zeros_like( 831 | a: ArrayLike | DuckTypedArray, 832 | dtype: DTypeLike | None = ..., 833 | shape: Any = ..., 834 | ) -> Array: ... 835 | 836 | cumproduct = cumprod 837 | degrees = rad2deg 838 | divide = true_divide 839 | radians = deg2rad 840 | -------------------------------------------------------------------------------- /src/jpu/numpy/linalg.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dfm/jpu/78ace6f097d0a56505dbe0140211370d088c628d/src/jpu/numpy/linalg.py -------------------------------------------------------------------------------- /src/jpu/numpy/linalg.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | solve = Any 4 | -------------------------------------------------------------------------------- /src/jpu/quantity.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import warnings 3 | from functools import partial 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import pint 8 | 9 | from jpu import numpy as jpu_numpy 10 | 11 | SUPPORTED_NUMPY_METHODS = [ 12 | "all", 13 | "any", 14 | "argmax", 15 | "argmin", 16 | "argpartition", 17 | "argsort", 18 | "choose", 19 | "clip", 20 | "compress", 21 | "conj", 22 | "conjugate", 23 | "copy", 24 | "cumprod", 25 | "cumsum", 26 | "delete", 27 | "diagonal", 28 | "dot", 29 | "max", 30 | "mean", 31 | "min", 32 | "nonzero", 33 | "prod", 34 | "ptp", 35 | "ravel", 36 | "repeat", 37 | "reshape", 38 | "round", 39 | "searchsorted", 40 | "sort", 41 | "squeeze", 42 | 
"std", 43 | "sum", 44 | "swapaxes", 45 | "take", 46 | "trace", 47 | "transpose", 48 | "var", 49 | ] 50 | SUPPORTED_PASSTHROUGH_METHODS = [ 51 | "astype", 52 | "block_until_ready", 53 | "clone", 54 | "flatten", 55 | "item", 56 | "view", 57 | ] 58 | 59 | 60 | class JpuQuantity(pint.UnitRegistry.Quantity): 61 | def __array__(self, *args, **kwargs): 62 | warnings.warn( 63 | "The unit of a Quantity is stripped when downcasted to an array.", 64 | stacklevel=2, 65 | ) 66 | return self._magnitude.__array__(*args, **kwargs) # type: ignore 67 | 68 | @property 69 | def dtype(self): 70 | return jnp.asarray(self._magnitude).dtype 71 | 72 | @property 73 | def ndim(self): 74 | return jnp.ndim(self._magnitude) # type: ignore 75 | 76 | @property 77 | def shape(self): 78 | return jnp.shape(self._magnitude) # type: ignore 79 | 80 | def _maybe_dimensionless(self, other): 81 | if isinstance(other, jax.Array): 82 | return self._REGISTRY.Quantity(other, "dimensionless") 83 | return other 84 | 85 | def __iadd__(self, other): 86 | return self._add_sub(self._maybe_dimensionless(other), operator.add) 87 | 88 | def __add__(self, other): 89 | return self._add_sub(self._maybe_dimensionless(other), operator.add) 90 | 91 | __radd__ = __add__ 92 | 93 | def __isub__(self, other): 94 | return self._add_sub(self._maybe_dimensionless(other), operator.sub) 95 | 96 | def __sub__(self, other): 97 | return self._add_sub(self._maybe_dimensionless(other), operator.sub) 98 | 99 | def __rsub__(self, other): 100 | return -self._add_sub(self._maybe_dimensionless(other), operator.sub) 101 | 102 | def __len__(self): 103 | return len(self._magnitude) # type: ignore 104 | 105 | def _wrap_passthrough_method(self, name, *args, **kwargs): 106 | return self.__class__( 107 | getattr(self._magnitude, name)(*args, **kwargs), self._units 108 | ) 109 | 110 | def __getitem__(self, key): 111 | return self.__class__(self._magnitude[key], self._units) # type: ignore 112 | 113 | def __getattr__(self, item): 114 | if item in 
SUPPORTED_NUMPY_METHODS: 115 | return partial(getattr(jpu_numpy, item), self) 116 | elif item in SUPPORTED_PASSTHROUGH_METHODS: 117 | return partial(self._wrap_passthrough_method, item) 118 | try: 119 | return getattr(self._magnitude, item) 120 | except AttributeError: 121 | raise AttributeError( 122 | f"Neither Quantity object nor its magnitude ({self._magnitude}) " 123 | f"has attribute '{item}'" 124 | ) from None 125 | -------------------------------------------------------------------------------- /src/jpu/registry.py: -------------------------------------------------------------------------------- 1 | import pint 2 | from jax.tree_util import register_pytree_node 3 | from pint.compat import TypeAlias 4 | 5 | from jpu.quantity import JpuQuantity 6 | 7 | 8 | class UnitRegistry(pint.registry.GenericUnitRegistry[JpuQuantity, pint.Unit]): 9 | Quantity: TypeAlias = JpuQuantity 10 | Unit: TypeAlias = pint.Unit 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | def flatten_quantity(q): 16 | return (q.magnitude,), (q.units, q._REGISTRY) 17 | 18 | def unflatten_quantity(aux_data, children): 19 | (magnitude,) = children 20 | units, registry = aux_data 21 | return registry.Quantity(magnitude, units) 22 | 23 | register_pytree_node(self.Quantity, flatten_quantity, unflatten_quantity) 24 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from jax._src.public_test_util import check_close 5 | 6 | import jpu.numpy as jnpu 7 | from jpu import UnitRegistry, core 8 | 9 | ureg = UnitRegistry() 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "params", 14 | [ 15 | ( 16 | lambda x: (jnpu.sin(2 * jnp.pi * x / (10.0 * ureg.m)) * ureg.s), 17 | (), 18 | ureg.m, 19 | ureg.s / ureg.m, 20 | ), 21 | ( 22 | lambda x: (1.0 + x.magnitude), 23 | (), 24 | 
ureg.m, 25 | 1 / ureg.m, 26 | ), 27 | ( 28 | lambda x: jnp.sin(1.0 + x), 29 | (), 30 | None, 31 | None, 32 | ), 33 | ( 34 | lambda x: jnpu.sum(jnpu.sin(2 * jnp.pi * x / (10.0 * ureg.m)) * ureg.s), 35 | (5, 2), 36 | ureg.m, 37 | ureg.s / ureg.m, 38 | ), 39 | ], 40 | ) 41 | def test_grad(params): 42 | func, shape, in_units, grad_units = params 43 | 44 | def inp(x): 45 | return x if in_units is None else x * in_units 46 | 47 | def func_(x): 48 | y = func(inp(x)) 49 | if hasattr(y, "magnitude"): 50 | return y.magnitude 51 | else: 52 | return y 53 | 54 | computed = core.grad(func)(inp(jnp.full(shape, 5.0))) 55 | expected = jax.grad(func_)(jnp.full(shape, 5.0)) 56 | 57 | if grad_units is None: 58 | check_close(computed, expected) 59 | assert not hasattr(computed, "units") 60 | else: 61 | check_close(computed.magnitude, expected) # type: ignore 62 | assert computed.units == grad_units # type: ignore 63 | -------------------------------------------------------------------------------- /tests/test_numpy.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | import operator 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import pint 9 | import pytest 10 | from jax._src.public_test_util import check_close 11 | 12 | import jpu.numpy as jun 13 | from jpu import UnitRegistry 14 | 15 | 16 | def is_quantity(q): 17 | return hasattr(q, "magnitude") and hasattr(q, "units") 18 | 19 | 20 | def assert_quantity_allclose(a, b): 21 | if is_quantity(a): 22 | assert is_quantity(b) 23 | assert str(a.units) == str(b.units) 24 | check_close(a.magnitude, b.magnitude) 25 | else: 26 | assert not is_quantity(b) 27 | check_close(a, b) 28 | 29 | 30 | def test_type_wrapping(): 31 | u = UnitRegistry() 32 | x = jnp.array([1.4, 2.0, -5.9]) 33 | q = x * u.kpc 34 | assert q.units == u.kpc 35 | check_close(q.magnitude, x) 36 | assert type(q.magnitude) == type(x) 37 | 38 | 39 | def test_array_ops(): 40 | u = UnitRegistry() 41 
| x = jnp.array([1.4, 2.0, -5.9]) 42 | q = x * u.kpc 43 | 44 | # Addition 45 | res = q + np.array(0.01) * u.Mpc 46 | assert res.units == u.kpc 47 | check_close(res.magnitude, x + 10) 48 | assert type(res.magnitude) == type(x) 49 | 50 | # Different order 51 | res = np.array(0.01) * u.Mpc + q 52 | assert res.units == u.Mpc 53 | check_close(res.magnitude, 1e-3 * (x + 10)) 54 | assert type(res.magnitude) == type(x) 55 | 56 | # Subtraction 57 | res = q - np.array(0.01) * u.Mpc 58 | assert res.units == u.kpc 59 | check_close(res.magnitude, x - 10) 60 | assert type(res.magnitude) == type(x) 61 | 62 | # Multiplication 63 | res = 2 * q 64 | assert res.units == u.kpc 65 | check_close(res.magnitude, 2 * x) 66 | assert type(res.magnitude) == type(x) 67 | 68 | # Division 69 | res = q / (2 * u.kpc) 70 | assert res.units == u.dimensionless 71 | check_close(res.magnitude, 0.5 * x) 72 | assert type(res.magnitude) == type(x) 73 | 74 | 75 | @pytest.mark.parametrize( 76 | "func,in_unit", 77 | [ 78 | ("exp", [""]), 79 | ("log", [""]), 80 | ("sin", ["degree"]), 81 | ("sin", ["radian"]), 82 | ("arctan2", ["m", "m"]), 83 | ("arctan2", ["m", "foot"]), 84 | ("argsort", ["day"]), 85 | ("std", ["day"]), 86 | ("var", ["m"]), 87 | ("dot", ["m", "s"]), 88 | ("median", ["m"]), 89 | ("cumprod", [""]), 90 | ("any", ["kpc"]), 91 | ], 92 | ) 93 | def test_unary(func, in_unit): 94 | f = (lambda x: x**2) if func == "log" else (lambda x: x) 95 | 96 | pu = pint.UnitRegistry() 97 | np_args = [] 98 | for n, iu in enumerate(in_unit): 99 | x = f(np.array([1.4, 2.0, -5.9]) - n) * pu(iu) 100 | np_args.append(x) 101 | 102 | u = UnitRegistry() 103 | jun_args = [] 104 | for n, iu in enumerate(in_unit): 105 | x = f(jnp.array([1.4, 2.0, -5.9]) - n) * u(iu) 106 | jun_args.append(x) 107 | 108 | np_func = getattr(np, func) 109 | np_res = np_func(*np_args) 110 | 111 | jun_func = getattr(jun, func) 112 | jun_res = jun_func(*jun_args) 113 | assert_quantity_allclose(jun_res, np_res) 114 | jun_res = 
jax.jit(jun_func)(*jun_args) 115 | assert_quantity_allclose(jun_res, np_res) 116 | 117 | np_res_no_units = np_func(*(x.magnitude for x in np_args)) 118 | jun_res_no_units = jun_func(*(x.magnitude for x in jun_args)) 119 | check_close(jun_res_no_units, np_res_no_units) 120 | 121 | 122 | @pytest.mark.parametrize( 123 | "op", [operator.add, operator.sub, operator.mul, operator.truediv] 124 | ) 125 | def test_duck_type(op): 126 | u = UnitRegistry() 127 | 128 | @jax.jit 129 | def func(x): 130 | return op(jnp.array(5.0), x) 131 | 132 | check_close(func(10.0 * u.dimensionless).magnitude, op(5.0, 10.0)) 133 | -------------------------------------------------------------------------------- /tests/test_readme.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | 4 | def test_readme(): 5 | import jax 6 | import numpy as np 7 | 8 | from jpu import UnitRegistry, numpy as jnpu 9 | 10 | u = UnitRegistry() 11 | 12 | @jax.jit 13 | def projectile_motion(v_init, theta, time, g=u.standard_gravity): 14 | """Compute the motion of a projectile with support for units""" 15 | x = v_init * time * jnpu.cos(theta) 16 | y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time) 17 | return x.to(u.m), y.to(u.m) 18 | 19 | x, y = projectile_motion(5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s) 20 | assert x.units == u.m 21 | assert y.units == u.m 22 | -------------------------------------------------------------------------------- /tests/test_registry.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax._src.public_test_util import check_close 6 | 7 | from jpu.registry import UnitRegistry 8 | 9 | 10 | def test_tree_flatten(): 11 | u = UnitRegistry() 12 | x = jnp.array([1.4, 2.0, -5.9]) 13 | q = x * u.m 14 | 15 | val, _ = jax.tree_util.tree_flatten(q) 16 | assert len(val) == 1 17 | check_close(val[0], x) 18 | 19 
| 20 | def test_jittable(): 21 | u = UnitRegistry() 22 | x = jnp.array([1.4, 2.0, -5.9]) 23 | q = x * u.m 24 | 25 | @jax.jit 26 | def func(q): 27 | assert q.u == u.m 28 | return q + 4.5 * u.km 29 | 30 | res = func(q) 31 | assert res.units == u.m 32 | check_close(res.magnitude, x + 4500.0) 33 | 34 | 35 | def test_ducktype(): 36 | u = UnitRegistry() 37 | x = jnp.array([1.4, 2.0, -5.9]) 38 | q = x * u.m 39 | 40 | res = q.sum() 41 | assert res.units == u.m 42 | check_close(res.magnitude, x.sum()) 43 | 44 | @jax.jit 45 | def func(q): 46 | print(q) 47 | return q.sum() 48 | 49 | res = func(q) 50 | print(type(q), type(res)) 51 | assert res.units == u.m 52 | check_close(res.magnitude, x.sum()) 53 | 54 | 55 | def test_unary_ops(): 56 | u = UnitRegistry() 57 | 58 | x = jnp.array([1.4, 2.0, -5.9]) 59 | q = x * u.m 60 | 61 | for func in [ 62 | lambda q: q**2, 63 | lambda q: q.sum(), 64 | lambda q: 2 * q, 65 | ]: 66 | res = func(q) 67 | check_close(res.magnitude, func(x)) 68 | res = jax.jit(func)(q) 69 | check_close(res.magnitude, func(x)) 70 | --------------------------------------------------------------------------------