├── .gitattributes ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── BUG-REPORT.yml │ ├── FEATURE-REQUEST.yml │ └── config.yml ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── build_docs.yml │ ├── bump_lockfile.yml │ ├── conventional-pr-linter.yml │ ├── pre-commit-cron-updater.yml │ ├── pre_commit.yml │ ├── publish.yml │ ├── run_tests.yml │ ├── test_examples.yml │ └── test_pip_install.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── cliff.toml ├── codecov.yml ├── docs ├── Makefile ├── conf.py ├── content │ ├── api.md │ └── get-started.md ├── demo_notebooks │ └── .gitignore ├── index.md ├── make.bat └── static │ ├── custom.css │ ├── logo-dark.png │ ├── logo-light.png │ └── logo-transparent.png ├── examples ├── README.md ├── cfd │ ├── cfd-tesseract │ │ ├── tesseract_api.py │ │ ├── tesseract_config.yaml │ │ └── tesseract_requirements.txt │ ├── demo.ipynb │ ├── pl.png │ └── requirements.txt └── simple │ ├── demo.ipynb │ └── vectoradd_jax │ ├── tesseract_api.py │ ├── tesseract_config.yaml │ └── tesseract_requirements.txt ├── production.uv.lock ├── pyproject.toml ├── requirements.txt ├── ruff.toml ├── setup.py ├── tesseract_jax ├── __init__.py ├── _version.py ├── primitive.py └── tesseract_compat.py └── tests ├── conftest.py ├── nested_tesseract ├── tesseract_api.py ├── tesseract_config.yaml └── tesseract_requirements.txt ├── test_endtoend.py └── univariate_tesseract ├── tesseract_api.py ├── tesseract_config.yaml └── tesseract_requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | tesseract_jax/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything 
in 2 | # the repo. Unless a later match takes precedence, 3 | # global owners will be requested for 4 | # review when someone opens a pull request. 5 | * @dionhaefner @xalelax @apaleyes 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/BUG-REPORT.yml: -------------------------------------------------------------------------------- 1 | name: "Bug Report" 2 | description: Report a new bug encountered while using Tesseract-JAX. 3 | type: "Bug" 4 | 5 | body: 6 | - type: textarea 7 | id: description 8 | attributes: 9 | label: "Description" 10 | description: Please describe your issue in as much detail as necessary. Include relevant information that you think will help us understand the problem. 11 | placeholder: | 12 | A clear and concise description of the bug and what you expected to happen. 13 | validations: 14 | required: true 15 | 16 | - type: textarea 17 | id: reprod 18 | attributes: 19 | label: "Steps to reproduce" 20 | description: Please provide detailed steps for reproducing the issue. Include any code snippets or commands that you used when the issue occurred. 21 | placeholder: | 22 | ```python 23 | from tesseract_jax import apply_tesseract 24 | ... 25 | ``` 26 | validations: 27 | required: true 28 | 29 | - type: textarea 30 | id: logs 31 | attributes: 32 | label: "Logs" 33 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 34 | render: bash 35 | validations: 36 | required: false 37 | 38 | - type: dropdown 39 | id: os 40 | attributes: 41 | label: "OS" 42 | description: What is the impacted environment? 
43 | multiple: true 44 | options: 45 | - Windows 46 | - Linux 47 | - Mac 48 | validations: 49 | required: true 50 | 51 | - type: input 52 | id: tesseractVersion 53 | attributes: 54 | label: Tesseract + Tesseract-JAX version 55 | description: Paste the output of `tesseract --version` and `python -c "import tesseract_jax; print(tesseract_jax.__version__)"` here. 56 | validations: 57 | required: true 58 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml: -------------------------------------------------------------------------------- 1 | name: "Feature Request" 2 | description: Request a new feature. 3 | type: "Feature" 4 | 5 | body: 6 | - type: textarea 7 | id: summary 8 | attributes: 9 | label: "Summary" 10 | description: Provide a brief summary of the requested feature. 11 | placeholder: | 12 | A clear and concise description of what the feature is. 13 | validations: 14 | required: true 15 | 16 | - type: textarea 17 | id: neccesity 18 | attributes: 19 | label: "Why is this needed?" 20 | description: Provide an explanation of why this feature is needed. Who is it for? What problem does it solve? Why do none of the existing features solve this problem? 21 | placeholder: | 22 | A clear and concise description of what itch this feature scratches. 23 | validations: 24 | required: true 25 | 26 | - type: textarea 27 | id: basic_example 28 | attributes: 29 | label: "Usage example" 30 | description: Please describe how end users would interact with the proposed feature. Include any code snippets or output examples that you think will help us understand. 
31 | placeholder: | 32 | ```python 33 | >>> from tesseract_jax import myfeature 34 | >>> myfeature.do_something() 35 | 36 | ``` 37 | validations: 38 | required: true 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Talk to us 4 | url: https://si-tesseract.discourse.group/ 5 | about: Support requests, inquiries, or general chatter related to Tesseracts. 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: github-actions 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | day: "monday" 9 | timezone: "America/New_York" 10 | open-pull-requests-limit: 10 11 | 12 | groups: 13 | actions: 14 | patterns: 15 | - "*" 16 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | #### Relevant issue or PR 7 | 8 | 9 | #### Description of changes 10 | 11 | 12 | #### Testing done 13 | 14 | 15 | #### License 16 | 17 | - [ ] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract-jax/LICENSE). 18 | - [ ] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. 19 | 20 |
21 | Developer Certificate of Origin 22 | 23 | ```text 24 | Developer Certificate of Origin 25 | Version 1.1 26 | 27 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 28 | 29 | Everyone is permitted to copy and distribute verbatim copies of this 30 | license document, but changing it is not allowed. 31 | 32 | 33 | Developer's Certificate of Origin 1.1 34 | 35 | By making a contribution to this project, I certify that: 36 | 37 | (a) The contribution was created in whole or in part by me and I 38 | have the right to submit it under the open source license 39 | indicated in the file; or 40 | 41 | (b) The contribution is based upon previous work that, to the best 42 | of my knowledge, is covered under an appropriate open source 43 | license and I have the right under that license to submit that 44 | work with modifications, whether created in whole or in part 45 | by me, under the same open source license (unless I am 46 | permitted to submit under a different license), as indicated 47 | in the file; or 48 | 49 | (c) The contribution was provided directly to me by some other 50 | person who certified (a), (b) or (c) and I have not modified 51 | it. 52 | 53 | (d) I understand and agree that this project and the contribution 54 | are public and that a record of the contribution (including all 55 | personal information I submit with it, including my sign-off) is 56 | maintained indefinitely and may be redistributed consistent with 57 | this project or the open source license(s) involved. 58 | ``` 59 | 60 |
61 | 62 | Signed-off-by: [YOUR NAME] <[YOUR EMAIL]> 63 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | # run on PRs for validation 5 | pull_request: 6 | 7 | # this is used by deploy_pages.yml to do docs build on main 8 | workflow_call: 9 | inputs: 10 | artifact_name: 11 | description: "Name of the artifact to upload" 12 | required: false 13 | type: string 14 | 15 | jobs: 16 | test-docs: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Set up Git repository 21 | uses: actions/checkout@v4 22 | 23 | - name: Install system requirements 24 | run: | 25 | sudo apt-get update 26 | sudo apt-get install -y pandoc 27 | 28 | - name: Install uv 29 | uses: astral-sh/setup-uv@v6 30 | with: 31 | enable-cache: true 32 | 33 | - name: Set up Python 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version-file: "pyproject.toml" 37 | 38 | - name: Restore UV environment 39 | run: cp production.uv.lock uv.lock 40 | 41 | - name: Install doc requirements 42 | run: | 43 | uv sync --extra docs --frozen 44 | 45 | - name: Build docs 46 | working-directory: docs 47 | run: | 48 | export SPHINXOPTS="-W" # treat warnings as errors 49 | uv run --no-sync make html 50 | 51 | - name: Upload HTML files 52 | uses: actions/upload-artifact@v4 53 | if: ${{ inputs.artifact_name }} 54 | with: 55 | name: ${{ inputs.artifact_name }} 56 | path: docs/build/html 57 | if-no-files-found: error 58 | -------------------------------------------------------------------------------- /.github/workflows/bump_lockfile.yml: -------------------------------------------------------------------------------- 1 | name: Bump UV lockfile 2 | 3 | on: 4 | workflow_dispatch: # Allows manual trigger 5 | 6 | schedule: 7 | - cron: '0 0 * * 1' # 12AM only on Mondays 8 | 9 | jobs: 10 | bump-lockfile: 11 | runs-on: ubuntu-latest 
12 | 13 | steps: 14 | - name: Set up Git repository 15 | uses: actions/checkout@v4 16 | 17 | - name: Install uv 18 | uses: astral-sh/setup-uv@v6 19 | with: 20 | enable-cache: true 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version-file: "pyproject.toml" 26 | 27 | - name: Update lockfile 28 | run: | 29 | mv production.uv.lock uv.lock 30 | uv lock --upgrade 31 | mv uv.lock production.uv.lock 32 | 33 | - name: Generate new requirements.txt 34 | run : | 35 | pip install pre-commit 36 | pre-commit run update-requirements --all-files || true 37 | 38 | - name: Detect if changes were made 39 | id: git-diff 40 | run: | 41 | changes=false 42 | git diff --exit-code || changes=true 43 | echo "update_done=$changes" >> $GITHUB_OUTPUT 44 | 45 | - name: Create Pull Request 46 | if: steps.git-diff.outputs.update_done == 'true' 47 | uses: peter-evans/create-pull-request@v7 48 | with: 49 | token: ${{ secrets.GITHUB_TOKEN }} 50 | commit-message: Update dependencies 51 | title: "chore: 📦 Update dependencies" 52 | branch: _bot/update-deps 53 | draft: false 54 | base: main 55 | body: | 56 | This PR updates the lockfile to the latest versions of the dependencies. 57 | Please review the changes and merge when ready. 58 | 59 | To trigger CI checks, please close and reopen this PR. 
60 | -------------------------------------------------------------------------------- /.github/workflows/conventional-pr-linter.yml: -------------------------------------------------------------------------------- 1 | name: "Lint PR" 2 | 3 | on: 4 | pull_request_target: 5 | types: 6 | - opened 7 | - edited 8 | - synchronize 9 | - reopened 10 | 11 | permissions: 12 | pull-requests: read 13 | 14 | jobs: 15 | main: 16 | name: Validate PR title 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: amannn/action-semantic-pull-request@v5 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | with: 23 | types: | 24 | fix 25 | feat 26 | doc 27 | perf 28 | refactor 29 | test 30 | chore 31 | ci 32 | security 33 | scopes: | 34 | cli 35 | engine 36 | sdk 37 | example 38 | runtime 39 | deps 40 | requireScope: false 41 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-cron-updater.yml: -------------------------------------------------------------------------------- 1 | name: Auto update pre-commit hooks 2 | 3 | on: 4 | workflow_dispatch: # Allows manual trigger 5 | 6 | schedule: 7 | - cron: '0 0 * * 1' # 12AM only on Mondays 8 | 9 | jobs: 10 | auto-update-hooks: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version-file: "pyproject.toml" 19 | 20 | - name: Install pre-commit 21 | run: pip install pre-commit 22 | 23 | - name: Run pre-commit autoupdate 24 | run: pre-commit autoupdate 25 | 26 | - name: Detect if changes were made 27 | id: git-diff 28 | run: | 29 | changes=false 30 | git diff --exit-code || changes=true 31 | echo "update_done=$changes" >> $GITHUB_OUTPUT 32 | 33 | - name: Run pre-commit 34 | id: pre-commit 35 | if: steps.git-diff.outputs.update_done == 'true' 36 | run: | 37 | # Run twice so we only fail if there are non-fixable issues 38 | rc=0 39 | pre-commit run --all-files || true 
40 | pre-commit run --all-files > /tmp/pre-commit.log || rc=$? 41 | 42 | # Add log as step output 43 | echo "pre-commit-log<<EOF" >> $GITHUB_OUTPUT 44 | cat /tmp/pre-commit.log >> $GITHUB_OUTPUT 45 | echo "EOF" >> $GITHUB_OUTPUT 46 | 47 | # Add linting outcome as step output 48 | if [ $rc -eq 0 ]; then 49 | echo "pre-commit-outcome=success" >> $GITHUB_OUTPUT 50 | else 51 | echo "pre-commit-outcome=failure" >> $GITHUB_OUTPUT 52 | fi 53 | 54 | # Distinguish 3 cases: 55 | # 1. No changes were made -> do nothing (steps below are skipped) 56 | # 2. Changes were made and pre-commit ran successfully -> create PR 57 | # 3. Changes were made but pre-commit failed -> create PR with draft status 58 | 59 | - name: Create Pull Request (all good) 60 | if: steps.pre-commit.outputs.pre-commit-outcome == 'success' 61 | uses: peter-evans/create-pull-request@v7 62 | with: 63 | token: ${{ secrets.GITHUB_TOKEN }} 64 | commit-message: Update pre-commit hooks 65 | title: "chore: ✅ Update pre-commit hooks" 66 | branch: _bot/update-precommit 67 | draft: false 68 | body: | 69 | Pre-commit hooks have been updated successfully without conflicts. 70 | 71 | - name: Create Pull Request (conflicts) 72 | if: steps.pre-commit.outputs.pre-commit-outcome == 'failure' 73 | uses: peter-evans/create-pull-request@v7 74 | with: 75 | token: ${{ secrets.GITHUB_TOKEN }} 76 | commit-message: Update pre-commit hooks 77 | title: "chore: ⚠️ Update pre-commit hooks [review required]" 78 | branch: _bot/update-precommit 79 | draft: true 80 | body: | 81 | Pre-commit is unable to automatically update the hooks due to unresolvable conflicts. 82 | Please review the changes and merge manually.
83 | 84 | Log: 85 | ``` 86 | ${{ steps.pre-commit.outputs.pre-commit-log }} 87 | ``` 88 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit.yml: -------------------------------------------------------------------------------- 1 | name: Code linting 2 | 3 | on: 4 | pull_request: 5 | 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version-file: "pyproject.toml" 20 | 21 | - uses: pre-commit/action@v3.0.1 22 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: # Allows manual trigger 5 | 6 | release: 7 | types: 8 | - published 9 | 10 | jobs: 11 | build: 12 | name: "Build distribution" 13 | runs-on: '${{ matrix.os }}' 14 | strategy: 15 | matrix: 16 | # TODO: Comment in additional platforms if using C extensions / platform-specific wheels 17 | os: 18 | - ubuntu-latest 19 | # - macos-latest 20 | python-version: 21 | # - "3.10" 22 | # - "3.11" 23 | # - "3.12" 24 | - "3.13" 25 | 26 | steps: 27 | - name: Checkout 28 | uses: actions/checkout@v4 29 | 30 | # make sure tags are fetched so we can get a version 31 | - name: Fetch Tags 32 | run: | 33 | git fetch --prune --unshallow --tags 34 | 35 | - name: Set up uv 36 | uses: astral-sh/setup-uv@v6 37 | with: 38 | enable-cache: true 39 | 40 | - name: Set up Python 41 | uses: actions/setup-python@v5 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | 45 | - name: Restore UV environment 46 | run: cp production.uv.lock uv.lock 47 | 48 | - name: Build Package 49 | run: | 50 | uv build 51 | 52 | - name: Store the distribution packages 53 | uses: actions/upload-artifact@v4 54 | with: 55 | name: 
python-package-distributions 56 | path: dist/ 57 | 58 | publish: 59 | name: "Publish distribution" 60 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 61 | runs-on: ubuntu-latest 62 | needs: build 63 | 64 | environment: 65 | name: pypi 66 | url: https://pypi.org/p/tesseract-jax 67 | 68 | permissions: 69 | id-token: write 70 | 71 | steps: 72 | - name: Download all the dists 73 | uses: actions/download-artifact@v4 74 | with: 75 | name: python-package-distributions 76 | path: dist/ 77 | 78 | - name: Publish distribution to PyPI 79 | uses: pypa/gh-action-pypi-publish@release/v1 80 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Run test suite 2 | 3 | on: 4 | pull_request: 5 | 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | tests: 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | # NOTE: If bumping the minimum Python version here, also do it in 16 | # ruff.toml, setup.py and other CI files as well. 
17 | 18 | # test with oldest and latest supported Python versions 19 | python-version: ["3.10", "3.13"] 20 | # test with oldest supported Python version only (for slow tests) 21 | # python-version: ["3.10"] 22 | # test with *all* supported Python versions 23 | # python-version: ["3.10", "3.11", "3.12", "3.13"] 24 | 25 | fail-fast: false 26 | 27 | runs-on: ${{ matrix.os }} 28 | 29 | steps: 30 | - name: Set up Git repository 31 | uses: actions/checkout@v4 32 | 33 | - name: Install uv 34 | uses: astral-sh/setup-uv@v6 35 | with: 36 | enable-cache: true 37 | 38 | - name: Set up Python 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | 43 | - name: Restore UV environment 44 | run: cp production.uv.lock uv.lock 45 | 46 | - name: Install dev requirements 47 | run: | 48 | uv sync --extra dev --frozen 49 | 50 | - name: Run test suite 51 | run: | 52 | set -o pipefail 53 | uv run --no-sync pytest \ 54 | --cov-report=term-missing:skip-covered \ 55 | --cov-report=xml:coverage.xml \ 56 | --cov=tesseract_jax 57 | 58 | - name: Upload coverage reports to Codecov 59 | uses: codecov/codecov-action@v5.4.3 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | slug: pasteurlabs/tesseract-jax 63 | files: coverage*.xml 64 | fail_ci_if_error: true 65 | -------------------------------------------------------------------------------- /.github/workflows/test_examples.yml: -------------------------------------------------------------------------------- 1 | name: Run example notebooks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | tests: 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest] 13 | # NOTE: If bumping the minimum Python version here, also do it in 14 | # ruff.toml, setup.py and other CI files as well. 
15 | 16 | # test with oldest supported Python version only (for slow tests) 17 | python-version: ["3.10"] 18 | 19 | example: 20 | - simple 21 | - cfd 22 | 23 | fail-fast: false 24 | 25 | runs-on: ${{ matrix.os }} 26 | 27 | steps: 28 | - name: Set up Git repository 29 | uses: actions/checkout@v4 30 | 31 | - name: Install uv 32 | uses: astral-sh/setup-uv@v6 33 | with: 34 | enable-cache: true 35 | 36 | - name: Set up Python 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | 41 | - name: Restore UV environment 42 | run: cp production.uv.lock uv.lock 43 | 44 | - name: Install dev requirements 45 | run: | 46 | uv sync --extra dev --frozen 47 | 48 | - name: Run example 49 | working-directory: examples/${{matrix.example}} 50 | run: | 51 | uv pip install jupyter 52 | uv run --no-sync jupyter nbconvert --to notebook --execute demo.ipynb 53 | -------------------------------------------------------------------------------- /.github/workflows/test_pip_install.yml: -------------------------------------------------------------------------------- 1 | name: Test installation via pip 2 | 3 | on: 4 | # run on PRs for validation 5 | pull_request: 6 | 7 | 8 | jobs: 9 | test-pip-install: 10 | name: "Test pip install" 11 | runs-on: '${{ matrix.os }}' 12 | strategy: 13 | matrix: 14 | os: 15 | - ubuntu-latest 16 | - macos-latest 17 | python-version: 18 | - "3.10" 19 | - "3.13" 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Install package 30 | run: | 31 | pip install -r requirements.txt 32 | pip install --no-deps . 
33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # OSX stuff 132 | .DS_Store 133 | 134 | # Vim 135 | *.swp 136 | 137 | # IDEs 138 | .idea/ 139 | .vscode/ 140 | 141 | # Ignore UV dev environments 142 | uv.lock 143 | uv.lock.bak 144 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-added-large-files 6 | additional_dependencies: [--isolated] 7 | args: ["--maxkb=2000"] 8 | # Add exceptions here, as a regex 9 | exclude: "" 10 | 11 | - id: check-json 12 | additional_dependencies: [--isolated] 13 | 14 | - id: check-toml 15 | additional_dependencies: [--isolated] 16 | 17 | - id: check-yaml 18 | additional_dependencies: [--isolated] 19 | 20 | - id: detect-private-key 21 | additional_dependencies: [--isolated] 22 | 23 | - id: end-of-file-fixer 24 | additional_dependencies: [--isolated] 25 | 26 | - id: trailing-whitespace 27 | additional_dependencies: [--isolated] 28 | 29 | - repo: https://github.com/astral-sh/ruff-pre-commit 30 | # Ruff version. 31 | rev: v0.11.12 32 | hooks: 33 | # Run the linter. 
34 | - id: ruff 35 | args: [--fix] 36 | types_or: [pyi, python, jupyter] 37 | # Ignore global python configuration for private registry and install hooks from public index 38 | # Add for each hook 39 | # Reference: https://github.com/pre-commit/pre-commit/issues/1454#issuecomment-1816328894 40 | additional_dependencies: [--isolated] 41 | # Run the formatter. 42 | - id: ruff-format 43 | types_or: [pyi, python, jupyter] 44 | additional_dependencies: [--isolated] 45 | 46 | - repo: local 47 | hooks: 48 | 49 | # Update production.uv.lock after pyproject.toml changes 50 | - id: update-uv-env 51 | name: update-uv-env 52 | files: ^pyproject.toml$ 53 | stages: [pre-commit] 54 | language: python 55 | entry: | 56 | bash -c ' \ 57 | if [ -z "$DOWNLOAD_TOKEN" ]; then \ 58 | DOWNLOAD_TOKEN="VssSessionToken"; \ 59 | uv tool install keyring --with artifacts-keyring; \ 60 | fi; \ 61 | cp uv.lock uv.lock.bak; \ 62 | cp production.uv.lock uv.lock; \ 63 | uv lock --index https://$DOWNLOAD_TOKEN@pkgs.dev.azure.com/pasteur-labs/d89796ea-4a5b-48aa-9930-4cebcdc9d64a/_packaging/internal/pypi/simple/ --keyring-provider subprocess; \ 64 | rc=$?; \ 65 | cp uv.lock production.uv.lock; \ 66 | mv uv.lock.bak uv.lock; \ 67 | exit $rc; 68 | ' 69 | additional_dependencies: [uv==0.6.11,--isolated] 70 | 71 | # Update requirements.txt after production.uv.lock changes 72 | - id: update-requirements 73 | name: update-requirements 74 | files: ^production.uv.lock$ 75 | stages: [pre-commit] 76 | language: python 77 | entry: | 78 | bash -c ' \ 79 | cp uv.lock uv.lock.bak; \ 80 | cp production.uv.lock uv.lock; \ 81 | uv export --frozen --color never --no-emit-project --no-hashes > requirements.txt; \ 82 | rc=$?; \ 83 | cp uv.lock production.uv.lock; \ 84 | mv uv.lock.bak uv.lock; \ 85 | exit $rc; 86 | ' 87 | additional_dependencies: [uv==0.6.11,--isolated] 88 | -------------------------------------------------------------------------------- /.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.13" 12 | 13 | jobs: 14 | create_environment: 15 | - asdf plugin add uv 16 | - asdf install uv latest 17 | - asdf global uv latest 18 | - uv venv 19 | install: 20 | - cp production.uv.lock uv.lock 21 | - uv sync --extra dev --frozen 22 | build: 23 | html: 24 | - uv run sphinx-build -T -b html docs $READTHEDOCS_OUTPUT/html 25 | 26 | # Build documentation in the "docs/" directory with Sphinx 27 | sphinx: 28 | configuration: docs/conf.py 29 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 
4 | 5 | ## [0.2.1] - 2025-04-25 6 | 7 | ### Bug Fixes 8 | 9 | - Jvp in cfd Tesseract should only return derivative (#12) 10 | - Fixed typo in vectoradd Tesseract demo (#11) 11 | 12 | ### Refactor 13 | 14 | - Applied jax recipe to cfd tesseract (#10) 15 | 16 | ### Documentation 17 | 18 | - Add rendered demos to docs (#9) 19 | - Fix quickstart snippet (#13) 20 | - More tweaks based on beta feedback (#15) 21 | - Ensure gradients of example tesseract are well-behaved (#14) 22 | - Improve example notebook presentation (#16) 23 | 24 | 25 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in the 6 | Tesseract project and our community a harassment-free experience 7 | for everyone, regardless of age, body size, visible or invisible disability, 8 | ethnicity, sex characteristics, gender identity and expression, level of 9 | experience, education, socio-economic status, nationality, personal appearance, 10 | race, caste, color, religion, or sexual identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official email address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Attribution 60 | 61 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), 62 | version 2.1, available at 63 | https://www.contributor-covenant.org/version/2/1/code_of_conduct.html. 64 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Tesseract-JAX 2 | 3 | Tesseract-JAX is an open-source project and, as such, we welcome contributions 4 | from developers, engineers, scientists, and end-users in general. Contributions 5 | are what make the open-source community such an amazing place to learn, 6 | inspire, and create. Any contributions you make are greatly appreciated. 7 | 8 | 9 | ## Code of Conduct 10 | 11 | Ensure your contributions adhere to the [Code of Conduct](CODE_OF_CONDUCT.md). 12 | 13 | 14 | ## Feedback 15 | 16 | Constructive feedback is very welcome. We are interested in hearing from you! 17 | 18 | In case things aren't working as expected, or the documentation is lacking, 19 | please [file a bug 20 | report](https://github.com/pasteurlabs/tesseract-jax/issues/new?template=BUG-REPORT.yml). 21 | 22 | If you want to suggest a new feature, please file a new [feature 23 | request](https://github.com/pasteurlabs/tesseract-jax/issues/new?template=FEATURE-REQUEST.yml). 24 | In particular, we recommend you open an issue before contributing code in a 25 | pull request. This allows all parties to talk things over before jumping into 26 | action, and increases the likelihood of pull requests getting merged.
27 | 28 | In case you have general questions or feedback, need support from the 29 | community, or have a cool demo to share, start a thread in our [Discourse 30 | Forum](https://si-tesseract.discourse.group/). We use GitHub Issues for bug 31 | reports and feature requests only. 32 | 33 | 34 | ## Documentation 35 | 36 | Tesseract-JAX documentation is kept under the `docs/` directory of the repository, 37 | written in Markdown and using Sphinx to generate the final HTMLs. Fixes and 38 | enhancements to the documentation should be submitted as pull requests; we 39 | treat them the same as code contributions. 40 | 41 | To build the documentation locally, install the documentation dependencies in 42 | addition to the project itself, then run `make html`: 43 | 44 | ```console 45 | $ . venv/bin/activate 46 | $ pip install -e .[dev] 47 | $ pip install -r docs/requirements.txt 48 | $ cd docs 49 | $ make html 50 | ``` 51 | 52 | The resulting HTMLs are in `docs/build/html/`. 53 | 54 | Contributions in the form of tutorials, examples, demos, blog posts (including 55 | those posted elsewhere already) are best highlighted and celebrated in the 56 | [Discourse Forum](https://si-tesseract.discourse.group/). 57 | 58 | 59 | ## Code 60 | 61 | Tesseract is developed under the [Apache 2.0](LICENSE) license. By contributing 62 | to the Tesseract project you agree that your code contributions are governed by 63 | this license. 64 | 65 | 66 | ### Local development setup 67 | 68 | Make sure you have [Docker installed](https://docs.docker.com/engine/install/) 69 | on your machine and that you can run `docker` commands via your user. After that, 70 | clone the repository, install the dependencies, and set up pre-commit hooks: 71 | 72 | ```console 73 | $ git clone git@github.com:pasteurlabs/tesseract-jax.git 74 | $ cd tesseract-jax 75 | $ python -m venv venv 76 | $ .
venv/bin/activate 77 | $ pip install -e .[dev] 78 | $ pre-commit install 79 | ``` 80 | 81 | ### Tests 82 | 83 | This project uses the pytest framework for all tests. New code should be 84 | covered by new or existing tests. 85 | 86 | To run the tests, simply run `pytest` in the root of the project: 87 | 88 | ```console 89 | $ pytest 90 | ``` 91 | 92 | ### GitHub workflow 93 | 94 | This project uses Git for version control and follows a GitHub workflow. To 95 | contribute, follow these steps: 96 | 97 | 1. Fork the project via the GitHub UI. 98 | 1. Clone your fork to your machine. 99 | 1. Add an upstream remote: `git remote add upstream git@github.com:pasteurlabs/tesseract-jax.git`. 100 | 1. Create a new branch for your code contribution: `git switch --create my_branch`. 101 | 1. Implement your changes. 102 | 1. Commit and push to your fork: `git push --set-upstream origin my_branch`. 103 | 1. [Open a Pull Request](https://github.com/pasteurlabs/tesseract-jax/pulls) with 104 | your changes. 105 | 106 | It is good practice to rebase often on top of `main` to keep your code up to 107 | date with the latest development and minimize merge conflicts: 108 | 109 | ```console 110 | $ git fetch upstream 111 | $ git switch main 112 | $ git merge upstream/main 113 | $ git switch my_branch 114 | $ git rebase main 115 | $ git push --force 116 | ``` 117 | 118 | ### Commit and pull request message guidelines 119 | 120 | We follow the [Conventional 121 | Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification for all 122 | commits that reach the `main` branch. Each commit is crafted from a pull 123 | request that is squash-merged. The commit title and message come from the pull 124 | request title and message, respectively. As such, they should be structured 125 | following the specification. 126 | 127 | The title consists of a _type_, an optional _scope_, and a short 128 | _description_: `type[(scope)]: description`.
The types we use are: 129 | - `chore`: for changes that affect the build system, external dependencies, or 130 | general housekeeping. 131 | - `ci`: for changes in the CI. 132 | - `doc`: for documentation-only changes. 133 | - `feat`: for a new feature. 134 | - `fix`: for fixing a bug. 135 | - `perf`: for a code change that improves performance. 136 | - `refactor`: for a code change that neither adds a feature nor fixes a bug. 137 | - `security`: for a change that fixes a security issue. 138 | - `test`: for adding new tests or fixing existing ones. 139 | 140 | The scopes we use are: 141 | - `cli`: for changes that affect the `tesseract` CLI. 142 | - `engine`: for changes that affect the CLI engine. 143 | - `sdk`: for changes that affect the Tesseract Python API. 144 | - `example`: for changes in the examples. 145 | - `runtime`: for changes in the Tesseract Runtime. 146 | - `deps`: for changes in the dependencies. 147 | 148 | In case there are breaking changes in your code, this should be indicated in 149 | the message either by appending an exclamation mark (`!`) after the type/scope 150 | or by adding a `BREAKING CHANGE:` trailer to the message. 151 | 152 | 153 | ## Versioning 154 | 155 | The Tesseract project follows [semantic versioning](https://semver.org). 156 | 157 | 158 | ## Changelog 159 | 160 | This project's [changelog](CHANGELOG.md) is generated by 161 | [git-cliff](https://git-cliff.org), by running 162 | `git cliff --output CHANGELOG.md`. To bump the semantic version of the project 163 | when generating the changelog, run `git cliff --output CHANGELOG.md --bump`. 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025- Pasteur Labs 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include all files here that should be part of source distributions 2 | include requirements.txt 3 | include README.md 4 | 5 | # Include Cython source files, and exclude generated C files 6 | # (required to build Cython extensions during wheel creation) 7 | recursive-include tesseract_jax *.pyx *.pxd 8 | recursive-exclude tesseract_jax *.c 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### Tesseract-JAX 4 | 5 | Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable. 
6 | 7 | [Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) | 8 | [Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) | 9 | [Report an issue](https://github.com/pasteurlabs/tesseract-jax/issues) | 10 | [Talk to the community](https://si-tesseract.discourse.group/) | 11 | [Contribute](CONTRIBUTING.md) 12 | 13 | --- 14 | 15 | The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines: 16 | 17 | ```python 18 | @jax.jit 19 | def vector_sum(x, y): 20 | res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}}) 21 | return res["vector_add"]["result"].sum() 22 | 23 | jax.grad(vector_sum)(x, y) # 🎉 24 | ``` 25 | 26 | ## Quick start 27 | 28 | > [!NOTE] 29 | > Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+). 30 | 31 | > [!IMPORTANT] 32 | > For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html). 33 | 34 | 1. Install Tesseract-JAX: 35 | 36 | ```bash 37 | $ pip install tesseract-jax 38 | ``` 39 | 40 | 2. Build an example Tesseract: 41 | 42 | ```bash 43 | $ git clone https://github.com/pasteurlabs/tesseract-jax 44 | $ tesseract build tesseract-jax/examples/simple/vectoradd_jax 45 | ``` 46 | 47 | 3. 
Use it as part of a JAX program via the JAX-native `apply_tesseract` function: 48 | 49 | ```python 50 | import jax 51 | import jax.numpy as jnp 52 | from tesseract_core import Tesseract 53 | from tesseract_jax import apply_tesseract 54 | 55 | # Load the Tesseract 56 | t = Tesseract.from_image("vectoradd_jax") 57 | t.serve() 58 | 59 | # Run it with JAX 60 | x = jnp.ones((1000,)) 61 | y = jnp.ones((1000,)) 62 | 63 | def vector_sum(x, y): 64 | res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}}) 65 | return res["vector_add"]["result"].sum() 66 | 67 | vector_sum(x, y) # success! 68 | 69 | # You can also use it with JAX transformations like JIT and grad 70 | vector_sum_jit = jax.jit(vector_sum) 71 | vector_sum_jit(x, y) 72 | 73 | vector_sum_grad = jax.grad(vector_sum) 74 | vector_sum_grad(x, y) 75 | ``` 76 | 77 | > [!TIP] 78 | > Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX. 79 | 80 | ## Sharp edges 81 | 82 | - **Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values. 
83 | 84 | ```python 85 | import jax.numpy as jnp 86 | from tesseract_core import Tesseract 87 | from tesseract_jax import apply_tesseract 88 | 89 | with Tesseract.from_image("vectoradd_jax") as tess: 90 | apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}}) # ❌ raises an error 91 | apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works 92 | ``` 93 | - **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined. 94 | 95 | > [!TIP] 96 | > When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`. 97 | 98 | ## License 99 | 100 | Tesseract-JAX is licensed under the [Apache License 2.0](LICENSE) and is free to use, modify, and distribute (under the terms of the license). 101 | 102 | Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission. 103 | -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | # git-cliff ~ default configuration file 2 | # https://git-cliff.org/docs/configuration 3 | # 4 | # Lines starting with "#" are comments. 5 | # Configuration options are organized into tables and keys. 6 | # See documentation for more information on available options.
7 | 8 | [changelog] 9 | # template for the changelog header 10 | header = """ 11 | # Changelog\n 12 | All notable changes to this project will be documented in this file.\n 13 | """ 14 | # template for the changelog body 15 | # https://keats.github.io/tera/docs/#introduction 16 | body = """ 17 | {% if version %}\ 18 | ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} 19 | {% else %}\ 20 | ## [unreleased] 21 | {% endif %}\ 22 | {% for group, commits in commits | group_by(attribute="group") %} 23 | ### {{ group | striptags | trim | upper_first }} 24 | {% for commit in commits %} 25 | - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ 26 | {% if commit.breaking %}[**breaking**] {% endif %}\ 27 | {{ commit.message | upper_first }}\ 28 | {% endfor %} 29 | {% endfor %}\n 30 | """ 31 | # template for the changelog footer 32 | footer = """ 33 | 34 | """ 35 | # remove the leading and trailing s 36 | trim = true 37 | # postprocessors 38 | postprocessors = [ 39 | # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL 40 | ] 41 | # render body even when there are no releases to process 42 | # render_always = true 43 | # output file path 44 | # output = "test.md" 45 | 46 | [git] 47 | # parse the commits based on https://www.conventionalcommits.org 48 | conventional_commits = true 49 | # filter out the commits that are not conventional 50 | filter_unconventional = true 51 | # process each line of a commit as an individual commit 52 | split_commits = false 53 | # regex for preprocessing the commit messages 54 | commit_preprocessors = [ 55 | # Remove gitmoji, both actual UTF emoji and :emoji: 56 | { pattern = ' *(:\w+:|[\p{Emoji_Presentation}\p{Extended_Pictographic}](?:\u{FE0F})?\u{200D}?) 
*', replace = "" }, 57 | ] 58 | # regex for parsing and grouping commits 59 | commit_parsers = [ 60 | { message = "^feat", group = "Features" }, 61 | { message = "^fix", group = "Bug Fixes" }, 62 | { message = "^doc", group = "Documentation" }, 63 | { message = "^perf", group = "Performance" }, 64 | { message = "^refactor", group = "Refactor" }, 65 | { message = "^test", group = "Testing" }, 66 | { message = "^chore\\(release\\): prepare for", skip = true }, 67 | { message = "^chore\\(deps.*\\)", skip = true }, 68 | { message = "^chore\\(pr\\)", skip = true }, 69 | { message = "^chore\\(pull\\)", skip = true }, 70 | { message = "^chore|^ci", skip = true }, 71 | { body = ".*security", group = "Security" }, 72 | { message = ".*", skip = true }, 73 | ] 74 | # filter out the commits that are not matched by commit parsers 75 | filter_commits = false 76 | # sort the tags topologically 77 | topo_order = false 78 | # sort the commits inside sections by oldest/newest order 79 | sort_commits = "oldest" 80 | 81 | [bump] 82 | features_always_bump_minor = false 83 | breaking_always_bump_major = false 84 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | informational: true 6 | only_pulls: false 7 | patch: off 8 | comment: 9 | layout: "condensed_header, diff, condensed_files, condensed_footer" 10 | hide_project_coverage: false 11 | ignore: 12 | - "tesseract_jax/_version.py" 13 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -c . 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import re 10 | import shutil 11 | from pathlib import Path 12 | 13 | from tesseract_jax import __version__ 14 | 15 | project = "Tesseract-JAX" 16 | copyright = "2025, Pasteur Labs" 17 | author = "The Tesseract-JAX Team @ Pasteur Labs + OSS contributors" 18 | 19 | # The short X.Y version 20 | parsed_version = re.match(r"(\d+\.\d+\.\d+)", __version__) 21 | if parsed_version: 22 | version = parsed_version.group(1) 23 | else: 24 | version = "0.0.0" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = __version__ 28 | 29 | # -- General configuration --------------------------------------------------- 30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 31 | 32 | extensions = [ 33 | "myst_nb", # This is myst-parser + jupyter notebook support 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.napoleon", 37 | "sphinx.ext.viewcode", 38 | 
"sphinx_autodoc_typehints", 39 | # Copy button for code blocks 40 | "sphinx_copybutton", 41 | # OpenGraph metadata for social media sharing 42 | "sphinxext.opengraph", 43 | ] 44 | 45 | myst_enable_extensions = [ 46 | "dollarmath", 47 | "colon_fence", 48 | ] 49 | 50 | intersphinx_mapping = { 51 | "python": ("https://docs.python.org/3", None), 52 | "numpy": ("http://docs.scipy.org/doc/numpy/", None), 53 | "tesseract_core": ( 54 | "https://docs.pasteurlabs.ai/projects/tesseract-core/latest/", 55 | None, 56 | ), 57 | } 58 | 59 | templates_path = [] 60 | exclude_patterns = ["build", "Thumbs.db", ".DS_Store"] 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 65 | 66 | html_theme = "furo" 67 | html_static_path = ["static"] 68 | html_theme_options = { 69 | "light_logo": "logo-light.png", 70 | "dark_logo": "logo-dark.png", 71 | "sidebar_hide_name": True, 72 | } 73 | html_css_files = ["custom.css"] 74 | 75 | 76 | # -- Handle Jupyter notebooks ------------------------------------------------ 77 | 78 | # Do not execute notebooks during build (just take existing output) 79 | nb_execution_mode = "off" 80 | 81 | # Copy example notebooks to demo_notebooks folder on every build 82 | for example_notebook in Path("../examples").glob("*/demo.ipynb"): 83 | # Copy the example notebook to the docs folder 84 | dest = (Path("demo_notebooks") / example_notebook.parent.name).with_suffix(".ipynb") 85 | shutil.copyfile(example_notebook, dest) 86 | -------------------------------------------------------------------------------- /docs/content/api.md: -------------------------------------------------------------------------------- 1 | # API reference 2 | 3 | ```{eval-rst} 4 | .. 
automodule:: tesseract_jax 5 | :members: 6 | :undoc-members: 7 | ``` 8 | -------------------------------------------------------------------------------- /docs/content/get-started.md: -------------------------------------------------------------------------------- 1 | # Get started 2 | 3 | ## Quick start 4 | 5 | ```{note} 6 | Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+). 7 | ``` 8 | 9 | ```{seealso} 10 | For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html). 11 | ``` 12 | 13 | 1. Install Tesseract-JAX: 14 | 15 | ```bash 16 | $ pip install tesseract-jax 17 | ``` 18 | 19 | 2. Build an example Tesseract: 20 | 21 | ```bash 22 | $ git clone https://github.com/pasteurlabs/tesseract-jax 23 | $ tesseract build tesseract-jax/examples/simple/vectoradd_jax 24 | ``` 25 | 26 | 3. Use it as part of a JAX program: 27 | 28 | ```python 29 | import jax 30 | import jax.numpy as jnp 31 | from tesseract_core import Tesseract 32 | from tesseract_jax import apply_tesseract 33 | 34 | # Load the Tesseract 35 | t = Tesseract.from_image("vectoradd_jax") 36 | t.serve() 37 | 38 | # Run it with JAX 39 | x = jnp.ones((1000,)) 40 | y = jnp.ones((1000,)) 41 | 42 | def vector_sum(x, y): 43 | res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}}) 44 | return res["vector_add"]["result"].sum() 45 | 46 | vector_sum(x, y) # success! 47 | 48 | # You can also use it with JAX transformations like JIT and grad 49 | vector_sum_jit = jax.jit(vector_sum) 50 | vector_sum_jit(x, y) 51 | 52 | vector_sum_grad = jax.grad(vector_sum) 53 | vector_sum_grad(x, y) 54 | ``` 55 | 56 | ```{tip} 57 | Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for ways to use Tesseract-JAX. 
58 | ``` 59 | 60 | ## Sharp edges 61 | 62 | - **Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values. 63 | 64 | ```python 65 | import jax.numpy as jnp 66 | from tesseract_core import Tesseract 67 | from tesseract_jax import apply_tesseract 68 | 69 | with Tesseract.from_image("vectoradd_jax") as tess: 70 | apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}}) # ❌ raises an error 71 | apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works 72 | ``` 73 | - **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined. 74 | 75 | ```{tip} 76 | When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
77 | ``` 78 | -------------------------------------------------------------------------------- /docs/demo_notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Tesseract-JAX 2 | 3 | Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable. 4 | 5 | The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines: 6 | 7 | ```python 8 | @jax.jit 9 | def vector_sum(x, y): 10 | res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}}) 11 | return res["vector_add"]["result"].sum() 12 | 13 | jax.grad(vector_sum)(x, y) # 🎉 14 | ``` 15 | 16 | Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](demo_notebooks/simple.ipynb). 17 | 18 | ## License 19 | 20 | Tesseract-JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/blob/main/LICENSE) and is free to use, modify, and distribute (under the terms of the license). 21 | 22 | Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
23 | 24 | 25 | ```{toctree} 26 | :caption: Usage 27 | :maxdepth: 2 28 | :hidden: 29 | 30 | content/get-started 31 | content/api 32 | ``` 33 | 34 | ```{toctree} 35 | :caption: Examples 36 | :maxdepth: 2 37 | :hidden: 38 | 39 | demo_notebooks/simple.ipynb 40 | demo_notebooks/cfd.ipynb 41 | ``` 42 | 43 | ```{toctree} 44 | :caption: See also 45 | :maxdepth: 2 46 | :hidden: 47 | 48 | Tesseract Core docs 49 | Tesseract User Forums 50 | ``` 51 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/static/custom.css: -------------------------------------------------------------------------------- 1 | .sidebar-logo { 2 | width: 180px; 3 | } 4 | 5 | .content h2 { 6 | font-size: 1.5em; 7 | } 8 | 9 | .content h3 { 10 | font-size: 1.25em; 11 | } 12 | 13 | .content h4 { 14 | font-size: 1.1em; 15 | } 16 | -------------------------------------------------------------------------------- /docs/static/logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasteurlabs/tesseract-jax/35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce/docs/static/logo-dark.png -------------------------------------------------------------------------------- /docs/static/logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasteurlabs/tesseract-jax/35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce/docs/static/logo-light.png -------------------------------------------------------------------------------- /docs/static/logo-transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasteurlabs/tesseract-jax/35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce/docs/static/logo-transparent.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Tesseract-JAX examples 2 | 3 | This directory contains example Tesseract configurations, notebooks,
and scripts demonstrating how to use Tesseract-JAX in various contexts. Each example is self-contained and can be run independently. 4 | 5 | ## Examples 6 | 7 | - [Simple](simple/demo.ipynb): A basic example of using Tesseract-JAX with a simple vector addition task. It demonstrates how to build a Tesseract and execute it within JAX. 8 | - [CFD](cfd/demo.ipynb): A more complex example demonstrating how to use Tesseract-JAX to differentiate through a computational fluid dynamics (CFD) simulation in an optimization context. 9 | -------------------------------------------------------------------------------- /examples/cfd/cfd-tesseract/tesseract_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import Any 5 | 6 | import equinox as eqx 7 | import jax 8 | import jax.numpy as jnp 9 | import jax_cfd.base as cfd 10 | from pydantic import BaseModel, Field 11 | from tesseract_core.runtime import Array, Differentiable, Float32 12 | from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths 13 | 14 | 15 | class InputSchema(BaseModel): 16 | v0: Differentiable[ 17 | Array[ 18 | ( 19 | None, 20 | None, 21 | None, 22 | ), 23 | Float32, 24 | ] 25 | ] = Field(description="3D Array defining the initial velocity field") 26 | density: float = Field(description="Density of the fluid") 27 | viscosity: float = Field(description="Viscosity of the fluid") 28 | inner_steps: int = Field( 29 | description="Number of solver steps for each timestep", default=25 30 | ) 31 | outer_steps: int = Field(description="Number of timesteps steps", default=10) 32 | max_velocity: float = Field(description="Maximum velocity", default=2.0) 33 | cfl_safety_factor: float = Field(description="CFL safety factor", default=0.5) 34 | domain_size_x: float = Field(description="Domain size x", default=1.0) 35 | domain_size_y: float = 
Field(description="Domain size y", default=1.0) 36 | 37 | 38 | class OutputSchema(BaseModel): 39 | result: Differentiable[Array[(None, None, None), Float32]] = Field( 40 | description="3D Array defining the final velocity field" 41 | ) 42 | 43 | 44 | def cfd_fwd( 45 | v0: jnp.ndarray, 46 | density: float, 47 | viscosity: float, 48 | inner_steps: int, 49 | outer_steps: int, 50 | max_velocity: float, 51 | cfl_safety_factor: float, 52 | domain_size_x: float, 53 | domain_size_y: float, 54 | ) -> tuple[jax.Array, jax.Array]: 55 | """Compute the final velocity field using the semi-implicit Navier-Stokes equations. 56 | 57 | Args: 58 | v0: Initial velocity field. 59 | density: Density of the fluid. 60 | viscosity: Viscosity of the fluid. 61 | inner_steps: Number of solver steps for each timestep. 62 | outer_steps: Number of timesteps steps. 63 | max_velocity: Maximum velocity. 64 | cfl_safety_factor: CFL safety factor. 65 | domain_size_x: Domain size in x direction. 66 | domain_size_y: Domain size in y direction. 67 | 68 | Returns: 69 | Final velocity field. 70 | """ 71 | vx0 = v0[..., 0] 72 | vy0 = v0[..., 1] 73 | bc = cfd.boundaries.HomogeneousBoundaryConditions( 74 | ( 75 | (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC), 76 | (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC), 77 | ) 78 | ) 79 | 80 | # reconstruct grid from input 81 | grid = cfd.grids.Grid( 82 | vx0.shape, domain=((0.0, domain_size_x), (0.0, domain_size_y)) 83 | ) 84 | 85 | vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5)) 86 | vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0)) 87 | 88 | # reconstruct GridVariable from input 89 | vx0 = cfd.grids.GridVariable(vx0, bc) 90 | vy0 = cfd.grids.GridVariable(vy0, bc) 91 | v0 = (vx0, vy0) 92 | 93 | # Choose a time step. 94 | dt = cfd.equations.stable_time_step( 95 | max_velocity, cfl_safety_factor, viscosity, grid 96 | ) 97 | 98 | # Define a step function and use it to compute a trajectory. 
99 | step_fn = cfd.funcutils.repeated( 100 | cfd.equations.semi_implicit_navier_stokes( 101 | density=density, viscosity=viscosity, dt=dt, grid=grid 102 | ), 103 | steps=inner_steps, 104 | ) 105 | rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps) 106 | _, trajectory = jax.device_get(rollout_fn(v0)) 107 | vxn = trajectory[0].array.data[-1] 108 | vyn = trajectory[1].array.data[-1] 109 | return jnp.stack([vxn, vyn], axis=-1) 110 | 111 | 112 | @eqx.filter_jit 113 | def apply_jit(inputs: dict) -> dict: 114 | vn = cfd_fwd(**inputs) 115 | return dict(result=vn) 116 | 117 | 118 | def apply(inputs: InputSchema) -> OutputSchema: 119 | return apply_jit(inputs.model_dump()) 120 | 121 | 122 | def jacobian( 123 | inputs: InputSchema, 124 | jac_inputs: set[str], 125 | jac_outputs: set[str], 126 | ): 127 | return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs)) 128 | 129 | 130 | def jacobian_vector_product( 131 | inputs: InputSchema, 132 | jvp_inputs: set[str], 133 | jvp_outputs: set[str], 134 | tangent_vector: dict[str, Any], 135 | ): 136 | return jvp_jit( 137 | inputs.model_dump(), 138 | tuple(jvp_inputs), 139 | tuple(jvp_outputs), 140 | tangent_vector, 141 | ) 142 | 143 | 144 | def vector_jacobian_product( 145 | inputs: InputSchema, 146 | vjp_inputs: set[str], 147 | vjp_outputs: set[str], 148 | cotangent_vector: dict[str, Any], 149 | ): 150 | return vjp_jit( 151 | inputs.model_dump(), 152 | tuple(vjp_inputs), 153 | tuple(vjp_outputs), 154 | cotangent_vector, 155 | ) 156 | 157 | 158 | def abstract_eval(abstract_inputs): 159 | """Calculate output shape of apply from the shape of its inputs.""" 160 | is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) 161 | is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) 162 | 163 | jaxified_inputs = jax.tree.map( 164 | lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x, 165 | abstract_inputs.model_dump(), 166 | is_leaf=is_shapedtype_dict, 167 | ) 
168 | dynamic_inputs, static_inputs = eqx.partition( 169 | jaxified_inputs, filter_spec=is_shapedtype_struct 170 | ) 171 | 172 | def wrapped_apply(dynamic_inputs): 173 | inputs = eqx.combine(static_inputs, dynamic_inputs) 174 | return apply_jit(inputs) 175 | 176 | jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs) 177 | return jax.tree.map( 178 | lambda x: ( 179 | {"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x 180 | ), 181 | jax_shapes, 182 | is_leaf=is_shapedtype_struct, 183 | ) 184 | 185 | 186 | @eqx.filter_jit 187 | def jac_jit( 188 | inputs: dict, 189 | jac_inputs: tuple[str], 190 | jac_outputs: tuple[str], 191 | ): 192 | filtered_apply = filter_func(apply_jit, inputs, jac_outputs) 193 | return jax.jacrev(filtered_apply)( 194 | flatten_with_paths(inputs, include_paths=jac_inputs) 195 | ) 196 | 197 | 198 | @eqx.filter_jit 199 | def jvp_jit( 200 | inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict 201 | ): 202 | filtered_apply = filter_func(apply_jit, inputs, jvp_outputs) 203 | return jax.jvp( 204 | filtered_apply, 205 | [flatten_with_paths(inputs, include_paths=jvp_inputs)], 206 | [tangent_vector], 207 | )[1] 208 | 209 | 210 | @eqx.filter_jit 211 | def vjp_jit( 212 | inputs: dict, 213 | vjp_inputs: tuple[str], 214 | vjp_outputs: tuple[str], 215 | cotangent_vector: dict, 216 | ): 217 | filtered_apply = filter_func(apply_jit, inputs, vjp_outputs) 218 | _, vjp_func = jax.vjp( 219 | filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs) 220 | ) 221 | return vjp_func(cotangent_vector)[0] 222 | -------------------------------------------------------------------------------- /examples/cfd/cfd-tesseract/tesseract_config.yaml: -------------------------------------------------------------------------------- 1 | name: jax-cfd 2 | version: "0.1.0" 3 | description: | 4 | Tesseract that runs a differentiable 2D Navier Stokes simulation on a 2D grid. 
5 | 6 | build_config: 7 | package_data: [] 8 | custom_build_steps: [] 9 | -------------------------------------------------------------------------------- /examples/cfd/cfd-tesseract/tesseract_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | jax-cfd==0.2.1 3 | jax[cpu]==0.6.0 4 | equinox 5 | -------------------------------------------------------------------------------- /examples/cfd/pl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasteurlabs/tesseract-jax/35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce/examples/cfd/pl.png -------------------------------------------------------------------------------- /examples/cfd/requirements.txt: -------------------------------------------------------------------------------- 1 | jax_cfd 2 | matplotlib 3 | scipy 4 | tqdm 5 | pillow 6 | -------------------------------------------------------------------------------- /examples/simple/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Basic example: vector addition\n", 8 | "\n", 9 | "
\n", 10 | "

Note

\n", 11 | "\n", 12 | "All examples are expected to run from the `examples/` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).\n", 13 | "
\n", 14 | "\n", 15 | "Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.\n", 16 | "\n", 17 | "In this example, you will learn how to:\n", 18 | "1. Build a Tesseract that performs vector addition.\n", 19 | "1. Access its endpoints via Tesseract-JAX's `apply_tesseract()` function.\n", 20 | "1. Compose Tesseracts into more complex functions, blending multiple Tesseract applications with local operations.\n", 21 | "1. Apply `jax.jit` to the resulting pipeline to perform JIT compilation, and / or autodifferentiate the function (via `jax.grad`, `jax.jvp`, `jax.vjp`, ...)." 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Step 1: Build + serve example Tesseract\n", 29 | "\n", 30 | "In this example, we build and use a Tesseract that performs vector addition. The example Tesseract takes two vectors and scalars as input and returns some statistics as output.
Here is the functionality that's implemented in the Tesseract (see `vectoradd_jax/tesseract_api.py`):\n", 31 | "\n", 32 | "```python\n", 33 | "def apply_jit(inputs: dict) -> dict:\n", 34 | " a_scaled = inputs[\"a\"][\"s\"] * inputs[\"a\"][\"v\"]\n", 35 | " b_scaled = inputs[\"b\"][\"s\"] * inputs[\"b\"][\"v\"]\n", 36 | " add_result = a_scaled + b_scaled\n", 37 | " min_result = a_scaled - b_scaled\n", 38 | "\n", 39 | " def safe_norm(x, ord):\n", 40 | " # Compute the norm of a vector, adding a small epsilon to ensure\n", 41 | " # differentiability and avoid division by zero\n", 42 | " return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)\n", 43 | "\n", 44 | " return {\n", 45 | " \"vector_add\": {\n", 46 | " \"result\": add_result,\n", 47 | " \"normed_result\": add_result / safe_norm(add_result, ord=inputs[\"norm_ord\"]),\n", 48 | " },\n", 49 | " \"vector_min\": {\n", 50 | " \"result\": min_result,\n", 51 | " \"normed_result\": min_result / safe_norm(min_result, ord=inputs[\"norm_ord\"]),\n", 52 | " },\n", 53 | " }\n", 54 | "```\n", 55 | "\n", 56 | "You may build the example Tesseract either via the command line, or running the cell below (you can skip running this if already built)." 
57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 1, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "\u001b[2K \u001b[1;2m[\u001b[0m\u001b[34mi\u001b[0m\u001b[1;2m]\u001b[0m Building image \u001b[33m...\u001b[0m\n", 69 | "\u001b[2K\u001b[37m⠙\u001b[0m \u001b[37mProcessing\u001b[0m\n", 70 | "\u001b[1A\u001b[2K \u001b[1;2m[\u001b[0m\u001b[34mi\u001b[0m\u001b[1;2m]\u001b[0m Built image sh\u001b[1;92ma256:7ae8\u001b[0m5ba85970, \u001b[1m[\u001b[0m\u001b[32m'vectoradd_jax:latest'\u001b[0m\u001b[1m]\u001b[0m\n" 71 | ] 72 | }, 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "[\"vectoradd_jax:latest\"]\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "%%bash\n", 83 | "# Build vectoradd_jax Tesseract so we can use it below\n", 84 | "tesseract build vectoradd_jax/" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "To interact with the Tesseract, we use the Python SDK from `tesseract_core` to load the built image and start a server container." 
92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 2, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "from tesseract_core import Tesseract\n", 101 | "\n", 102 | "vectoradd = Tesseract.from_image(\"vectoradd_jax\")\n", 103 | "vectoradd.serve()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Step 2: Invoke the Tesseract via Tesseract-JAX\n", 111 | "\n", 112 | "Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together, representing the following operation:\n", 113 | "\n", 114 | "$$\\begin{pmatrix} 1 \\\\ 2 \\\\ 3 \\end{pmatrix} + 2 \\cdot \\begin{pmatrix} 4 \\\\ 5 \\\\ 6 \\end{pmatrix} = \\begin{pmatrix} 9 \\\\ 12 \\\\ 15 \\end{pmatrix}$$\n", 115 | "\n", 116 | "We can perform this calculation using the function `tesseract_jax.apply_tesseract()`, by passing the `Tesseract` object and the input data as a PyTree (nested dictionary) of JAX arrays as inputs." 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 3, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),\n", 129 | " 'result': Array([ 9., 12., 15.], dtype=float32)},\n", 130 | " 'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),\n", 131 | " 'result': Array([-7., -8., -9.], dtype=float32)}}\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "from pprint import pprint\n", 137 | "\n", 138 | "import jax\n", 139 | "import jax.numpy as jnp\n", 140 | "\n", 141 | "from tesseract_jax import apply_tesseract\n", 142 | "\n", 143 | "a = {\"v\": jnp.array([1.0, 2.0, 3.0], dtype=\"float32\")}\n", 144 | "b = {\n", 145 | " \"v\": jnp.array([4.0, 5.0, 6.0], dtype=\"float32\"),\n", 146 | " \"s\": jnp.array(2.0, dtype=\"float32\"),\n", 147 | "}\n", 148 | "\n", 149 | 
"outputs = apply_tesseract(vectoradd, inputs={\"a\": a, \"b\": b})\n", 150 | "pprint(outputs)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "As expected, `outputs['vector_add']` gives a value of $(9, 12, 15)$." 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "## Step 3: Function composition via Tesseracts\n", 165 | "\n", 166 | "Tesseract-JAX enables you to compose chains of Tesseract evaluations, blended with local operations, while retaining all the benefits of JAX.\n", 167 | "\n", 168 | "The function below applies `vectoradd` twice, *ie.* $(\\mathbf{a} + \\mathbf{b}) + \\mathbf{a}$, then performs local arithmetic on the outputs, applies `vectoradd` once more, and finally returns a single element of the result. The resulting function is still a valid JAX function, and is fully jittable and auto-differentiable." 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 4, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "Array(16.135319, dtype=float32)" 180 | ] 181 | }, 182 | "execution_count": 4, 183 | "metadata": {}, 184 | "output_type": "execute_result" 185 | } 186 | ], 187 | "source": [ 188 | "def fancy_operation(a: jax.Array, b: jax.Array) -> jnp.float32:\n", 189 | " \"\"\"Fancy operation.\"\"\"\n", 190 | " result = apply_tesseract(vectoradd, inputs={\"a\": a, \"b\": b})\n", 191 | " result = apply_tesseract(\n", 192 | " vectoradd, inputs={\"a\": {\"v\": result[\"vector_add\"][\"result\"]}, \"b\": b}\n", 193 | " )\n", 194 | " # We can mix and match with local JAX operations\n", 195 | " result = 2.0 * result[\"vector_add\"][\"normed_result\"] + b[\"v\"]\n", 196 | " result = apply_tesseract(vectoradd, inputs={\"a\": {\"v\": result}, \"b\": b})\n", 197 | " return result[\"vector_add\"][\"result\"][1]\n", 198 | "\n", 199 | "\n", 200 | "fancy_operation(a, b)" 201 | ] 202 | }, 203 | { 204 | 
"cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "This is compatible with `jax.jit()`:" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 5, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "Array(16.135319, dtype=float32)" 219 | ] 220 | }, 221 | "execution_count": 5, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "jitted_op = jax.jit(fancy_operation)\n", 228 | "jitted_op(a, b)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "Autodifferentiation is automatically dispatched to the underlying Tesseract's `jacobian_vector_product` and `vector_jacobian_product` endpoints, and works as expected:" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 6, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "jax.grad result:\n", 248 | "({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n", 249 | " {'s': Array(5.002062, dtype=float32),\n", 250 | " 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n", 251 | "\n", 252 | "jax.jvp result:\n", 253 | "Array(25.004124, dtype=float32)\n", 254 | "\n", 255 | "jax.vjp result:\n", 256 | "({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n", 257 | " {'s': Array(5.002062, dtype=float32),\n", 258 | " 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "# jax.grad for reverse-mode autodiff (scalar outputs only)\n", 264 | "grad_res = jax.grad(fancy_operation, argnums=[0, 1])(a, b)\n", 265 | "print(\"jax.grad result:\")\n", 266 | "pprint(grad_res)\n", 267 | "\n", 268 | "# jax.jvp for general forward-mode autodiff\n", 269 | "_, jvp = jax.jvp(fancy_operation, (a, b), (a, b))\n", 270 | "print(\"\\njax.jvp result:\")\n", 271 | 
"pprint(jvp)\n", 272 | "\n", 273 | "# jax.vjp for general reverse-mode autodiff\n", 274 | "_, vjp_fn = jax.vjp(fancy_operation, a, b)\n", 275 | "vjp = vjp_fn(1.0)\n", 276 | "print(\"\\njax.vjp result:\")\n", 277 | "pprint(vjp)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "All the above also works when combining with `jit`:" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 7, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "jax.grad result:\n", 297 | "({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n", 298 | " {'s': Array(5.002062, dtype=float32),\n", 299 | " 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "# jax.grad for reverse-mode autodiff (scalar output)\n", 305 | "grad_res = jax.jit(jax.grad(fancy_operation, argnums=[0, 1]))(a, b)\n", 306 | "print(\"jax.grad result:\")\n", 307 | "pprint(grad_res)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "## Step N+1: Clean-up and conclusions" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "Since we kept the Tesseract alive using `.serve()`, we need to manually stop it using `.teardown()` to avoid leaking resources. \n", 322 | "\n", 323 | "This is not necessary when using `Tesseract` in a `with` statement, as it will automatically clean up when the context is exited." 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 8, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "vectoradd.teardown()" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "And that's it! 
\n", 340 | "You've learned how to build up differentiable pipelines with Tesseracts that blend seamlessly with JAX's APIs and transformations." 341 | ] 342 | } 343 | ], 344 | "metadata": { 345 | "kernelspec": { 346 | "display_name": "science", 347 | "language": "python", 348 | "name": "python3" 349 | }, 350 | "language_info": { 351 | "codemirror_mode": { 352 | "name": "ipython", 353 | "version": 3 354 | }, 355 | "file_extension": ".py", 356 | "mimetype": "text/x-python", 357 | "name": "python", 358 | "nbconvert_exporter": "python", 359 | "pygments_lexer": "ipython3", 360 | "version": "3.12.7" 361 | } 362 | }, 363 | "nbformat": 4, 364 | "nbformat_minor": 4 365 | } 366 | -------------------------------------------------------------------------------- /examples/simple/vectoradd_jax/tesseract_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import Any 5 | 6 | import equinox as eqx 7 | import jax 8 | import jax.numpy as jnp 9 | from pydantic import BaseModel, Field, model_validator 10 | from tesseract_core.runtime import Array, Differentiable, Float32 11 | from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths 12 | from typing_extensions import Self 13 | 14 | 15 | class Vector_and_Scalar(BaseModel): 16 | v: Differentiable[Array[(None,), Float32]] = Field( 17 | description="An arbitrary vector" 18 | ) 19 | s: Differentiable[Float32] = Field(description="A scalar", default=1.0) 20 | 21 | def scale(self) -> Differentiable[Array[(None,), Float32]]: 22 | return self.s * self.v 23 | 24 | 25 | class InputSchema(BaseModel): 26 | a: Vector_and_Scalar = Field( 27 | description="An arbitrary vector and a scalar to multiply it by" 28 | ) 29 | b: Vector_and_Scalar = Field( 30 | description="An arbitrary vector and a scalar to multiply it by " 31 | "must be of same shape as b" 32 | ) 33 | norm_ord: 
int = Field( 34 | description="Order of norm (see numpy.linalg.norm)", 35 | default=2, 36 | ) 37 | 38 | @model_validator(mode="after") 39 | def validate_shape_inputs(self) -> Self: 40 | if self.a.v.shape != self.b.v.shape: 41 | raise ValueError( 42 | f"a.v and b.v must have the same shape. " 43 | f"Got {self.a.v.shape} and {self.b.v.shape} instead." 44 | ) 45 | return self 46 | 47 | 48 | class Result_and_Norm(BaseModel): 49 | result: Differentiable[Array[(None,), Float32]] = Field( 50 | description="Vector s_a·a ± s_b·b (+ for vector_add, - for vector_min)" 51 | ) 52 | normed_result: Differentiable[Array[(None,), Float32]] = Field( 53 | description="Normalized vector: result / |result| (norm order from norm_ord)" 54 | ) 55 | 56 | 57 | class OutputSchema(BaseModel): 58 | vector_add: Result_and_Norm 59 | vector_min: Result_and_Norm 60 | 61 | 62 | @eqx.filter_jit 63 | def apply_jit(inputs: dict) -> dict: 64 | a_scaled = inputs["a"]["s"] * inputs["a"]["v"] 65 | b_scaled = inputs["b"]["s"] * inputs["b"]["v"] 66 | add_result = a_scaled + b_scaled 67 | min_result = a_scaled - b_scaled 68 | 69 | def safe_norm(x, ord): 70 | # Compute the norm of a vector, adding a small epsilon to ensure 71 | # differentiability and avoid division by zero 72 | return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord) 73 | 74 | return { 75 | "vector_add": { 76 | "result": add_result, 77 | "normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]), 78 | }, 79 | "vector_min": { 80 | "result": min_result, 81 | "normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]), 82 | }, 83 | } 84 | 85 | 86 | def apply(inputs: InputSchema) -> OutputSchema: 87 | """Return s_a·a + s_b·b and s_a·a - s_b·b, each with a normalized copy.""" 88 | return apply_jit(inputs.model_dump()) 89 | 90 | 91 | def abstract_eval(abstract_inputs): 92 | """Calculate output shape of apply from the shape of its inputs.""" 93 | is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) 94 | is_shapedtype_struct = lambda x: 
isinstance(x, jax.ShapeDtypeStruct) 95 | 96 | jaxified_inputs = jax.tree.map( 97 | lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x, 98 | abstract_inputs.model_dump(), 99 | is_leaf=is_shapedtype_dict, 100 | ) 101 | dynamic_inputs, static_inputs = eqx.partition( 102 | jaxified_inputs, filter_spec=is_shapedtype_struct 103 | ) 104 | 105 | def wrapped_apply(dynamic_inputs): 106 | inputs = eqx.combine(static_inputs, dynamic_inputs) 107 | return apply_jit(inputs) 108 | 109 | jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs) 110 | return jax.tree.map( 111 | lambda x: {"shape": x.shape, "dtype": str(x.dtype)} 112 | if is_shapedtype_struct(x) 113 | else x, 114 | jax_shapes, 115 | is_leaf=is_shapedtype_struct, 116 | ) 117 | 118 | 119 | def jacobian_vector_product( 120 | inputs: InputSchema, 121 | jvp_inputs: set[str], 122 | jvp_outputs: set[str], 123 | tangent_vector: dict[str, Any], 124 | ): 125 | return jvp_jit( 126 | inputs.model_dump(), 127 | tuple(jvp_inputs), 128 | tuple(jvp_outputs), 129 | tangent_vector, 130 | ) 131 | 132 | 133 | def vector_jacobian_product( 134 | inputs: InputSchema, 135 | vjp_inputs: set[str], 136 | vjp_outputs: set[str], 137 | cotangent_vector: dict[str, Any], 138 | ): 139 | return vjp_jit( 140 | inputs.model_dump(), 141 | tuple(vjp_inputs), 142 | tuple(vjp_outputs), 143 | cotangent_vector, 144 | ) 145 | 146 | 147 | def jacobian( 148 | inputs: InputSchema, 149 | jac_inputs: set[str], 150 | jac_outputs: set[str], 151 | ): 152 | return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs)) 153 | 154 | 155 | @eqx.filter_jit 156 | def jvp_jit( 157 | inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict 158 | ): 159 | filtered_apply = filter_func(apply_jit, inputs, jvp_outputs) 160 | return jax.jvp( 161 | filtered_apply, 162 | [flatten_with_paths(inputs, include_paths=jvp_inputs)], 163 | [tangent_vector], 164 | )[1] 165 | 166 | 167 | @eqx.filter_jit 168 | def vjp_jit( 169 | inputs: 
dict, 170 | vjp_inputs: tuple[str], 171 | vjp_outputs: tuple[str], 172 | cotangent_vector: dict, 173 | ): 174 | filtered_apply = filter_func(apply_jit, inputs, vjp_outputs) 175 | _, vjp_func = jax.vjp( 176 | filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs) 177 | ) 178 | return vjp_func(cotangent_vector)[0] 179 | 180 | 181 | @eqx.filter_jit 182 | def jac_jit( 183 | inputs: dict, 184 | jac_inputs: tuple[str], 185 | jac_outputs: tuple[str], 186 | ): 187 | filtered_apply = filter_func(apply_jit, inputs, jac_outputs) 188 | return jax.jacrev(filtered_apply)( 189 | flatten_with_paths(inputs, include_paths=jac_inputs) 190 | ) 191 | -------------------------------------------------------------------------------- /examples/simple/vectoradd_jax/tesseract_config.yaml: -------------------------------------------------------------------------------- 1 | name: vectoradd_jax 2 | version: "2025/02/05" 3 | description: | 4 | Tesseract that adds/subtracts two vectors. Uses jax internally. 5 | 6 | build_config: 7 | target_platform: "native" 8 | # package_data: [] 9 | # custom_build_steps: [] 10 | -------------------------------------------------------------------------------- /examples/simple/vectoradd_jax/tesseract_requirements.txt: -------------------------------------------------------------------------------- 1 | jax[cpu] 2 | equinox 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tesseract-jax" 3 | description = "Tesseract JAX executes Tesseracts as part of JAX programs, with full support for function transformations like JIT, `grad`, and more." 
4 | readme = "README.md" 5 | authors = [ 6 | {name = "The tesseract-jax team @ Pasteur Labs", email = "info@simulation.science"}, 7 | ] 8 | requires-python = ">=3.10" 9 | # TODO: add your dependencies here *and* in requirements.txt 10 | dependencies = [ 11 | "jax", 12 | "tesseract-core", 13 | ] 14 | dynamic = ["version"] 15 | 16 | [project.urls] 17 | Homepage = "https://github.com/pasteurlabs/tesseract-jax" 18 | 19 | [project.optional-dependencies] 20 | docs = [ 21 | "sphinx", 22 | "sphinx_autodoc_typehints", 23 | "furo", 24 | "myst-nb", 25 | "sphinx_copybutton", 26 | "sphinxext_opengraph", 27 | ] 28 | # TODO: add dev dependencies here *and* in requirements-dev.txt 29 | dev = [ 30 | "pre-commit", 31 | "pytest", 32 | "pytest-cov", 33 | "typeguard", 34 | "requests", 35 | "tesseract-core[runtime]", 36 | "tesseract-jax[docs]", 37 | ] 38 | 39 | [build-system] 40 | requires = ["setuptools", "versioneer[toml]==0.29"] 41 | build-backend = "setuptools.build_meta" 42 | 43 | [tool.setuptools.packages.find] 44 | include = ["tesseract_jax", "tesseract_jax.*"] 45 | 46 | [tool.versioneer] 47 | VCS = "git" 48 | style = "pep440" 49 | versionfile_source = "tesseract_jax/_version.py" 50 | versionfile_build = "tesseract_jax/_version.py" 51 | tag_prefix = "v" 52 | parentdir_prefix = "tesseract_jax-" 53 | 54 | [tool.pytest.ini_options] 55 | addopts = ["--typeguard-packages=tesseract_jax"] 56 | testpaths = ["tests"] 57 | filterwarnings = [ 58 | "error", 59 | # ignored by default 60 | "ignore::DeprecationWarning", 61 | "ignore::PendingDeprecationWarning", 62 | "ignore::ImportWarning", 63 | # raised by Cython, usually harmless 64 | "ignore:numpy.dtype size changed", 65 | "ignore:numpy.ufunc size changed", 66 | # sometimes, dependencies leak resources 67 | "ignore:.*socket\\.socket.*:pytest.PytestUnraisableExceptionWarning", 68 | ] 69 | 70 | [tool.coverage.run] 71 | branch = true 72 | source = ["tesseract_jax"] 73 | 74 | [tool.coverage.report] 75 | exclude_lines = [ 76 | "pragma: no cover", 
77 | "raise NotImplementedError", 78 | "if __name__ == .__main__.:", 79 | "pass", 80 | ] 81 | ignore_errors = true 82 | omit = [ 83 | "tesseract_jax/_version.py", 84 | ] 85 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv export --frozen --color never --no-emit-project --no-hashes 3 | annotated-types==0.7.0 4 | # via pydantic 5 | certifi==2025.4.26 6 | # via requests 7 | charset-normalizer==3.4.2 8 | # via requests 9 | click==8.1.8 10 | # via typer 11 | colorama==0.4.6 ; sys_platform == 'win32' 12 | # via click 13 | idna==3.10 14 | # via requests 15 | jax==0.6.1 16 | # via tesseract-jax 17 | jaxlib==0.6.1 18 | # via jax 19 | jinja2==3.1.6 20 | # via tesseract-core 21 | markdown-it-py==3.0.0 22 | # via rich 23 | markupsafe==3.0.2 24 | # via jinja2 25 | mdurl==0.1.2 26 | # via markdown-it-py 27 | ml-dtypes==0.5.1 28 | # via 29 | # jax 30 | # jaxlib 31 | numpy==2.2.4 32 | # via 33 | # jax 34 | # jaxlib 35 | # ml-dtypes 36 | # scipy 37 | # tesseract-core 38 | opt-einsum==3.4.0 39 | # via jax 40 | pip==25.1.1 41 | # via tesseract-core 42 | pydantic==2.11.0 43 | # via tesseract-core 44 | pydantic-core==2.33.0 45 | # via pydantic 46 | pygments==2.19.1 47 | # via rich 48 | pyyaml==6.0.2 49 | # via tesseract-core 50 | requests==2.32.3 51 | # via tesseract-core 52 | rich==14.0.0 53 | # via 54 | # tesseract-core 55 | # typer 56 | scipy==1.15.3 57 | # via 58 | # jax 59 | # jaxlib 60 | shellingham==1.5.4 61 | # via typer 62 | tesseract-core==0.9.0 63 | # via tesseract-jax 64 | typer==0.15.2 65 | # via tesseract-core 66 | typing-extensions==4.13.2 67 | # via 68 | # pydantic 69 | # pydantic-core 70 | # rich 71 | # typer 72 | # typing-inspection 73 | typing-inspection==0.4.1 74 | # via pydantic 75 | urllib3==2.4.0 76 | # via requests 77 | 
-------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Set to the lowest supported Python version. 2 | target-version = "py310" 3 | 4 | # Set the target line length for formatting. 5 | line-length = 88 6 | 7 | # Exclude a variety of commonly ignored directories. 8 | extend-exclude = [ 9 | ".venv", 10 | "_version.py", 11 | ] 12 | 13 | # Also lint/format Jupyter Notebooks. 14 | extend-include = [ "*.ipynb" ] 15 | 16 | [lint] 17 | # Select and/or ignore rules for linting. 18 | # Full list of available rules: https://docs.astral.sh/ruff/rules/ 19 | extend-select = [ 20 | "ANN", # Type annotations 21 | "B", # Flake8 bugbear 22 | "D", # Pydocstyle 23 | "E", # Pycodestyle errors 24 | "F", # Pyflakes 25 | "I", # Isort 26 | "NPY", # Numpy 27 | "RUF", # Ruff-specific rules 28 | "UP", # Pyupgrade 29 | "W", # Pycodestyle warnings 30 | ] 31 | ignore = [ 32 | "E731", # Do not assign a lambda expression, use a def 33 | "D100", # Pydocstyle: missing docstring in public module 34 | "D104", # Pydocstyle: missing docstring in public package 35 | "D105", # Pydocstyle: missing docstring in magic method 36 | "D107", # Pydocstyle: missing docstring in __init__ 37 | "D203", # Pydocstyle: one blank line before class' docstring. Conflicts with D211 38 | "D213", # Pydocstyle: multiline docstring summary start on 2nd line. 
Conflicts with D212 39 | "ANN202", # Type annotations: missing return type for private functions 40 | "ANN401", # Type annotations: Any 41 | "F722", # Pyflakes: syntax error in type annotations 42 | ] 43 | 44 | [lint.extend-per-file-ignores] 45 | # Ignore missing docstrings and type annotations for selected directories 46 | "tests/*" = ["D101", "D102", "D103", "ANN"] 47 | "examples/*" = ["D101", "D102", "D103", "ANN"] 48 | 49 | [lint.pydocstyle] 50 | convention = "google" 51 | 52 | [lint.pycodestyle] 53 | max-line-length = 120 # Allow some flexibility in line lengths: up to 120 cols 54 | max-doc-length = 120 55 | 56 | [format] 57 | # Enable reformatting of code snippets in docstrings. 58 | docstring-code-format = true 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import versioneer 4 | from setuptools import setup 5 | 6 | setup( 7 | version=versioneer.get_version(), 8 | cmdclass=versioneer.get_cmdclass(), 9 | ) 10 | -------------------------------------------------------------------------------- /tesseract_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from . 
import _version 5 | 6 | __version__ = _version.get_versions()["version"] 7 | 8 | # import public API of the package 9 | # SIDE EFFECT: Register Tesseract as a pytree node 10 | import jax 11 | from tesseract_core import Tesseract 12 | 13 | from tesseract_jax.primitive import apply_tesseract 14 | 15 | jax.tree_util.register_pytree_node( 16 | Tesseract, 17 | lambda x: ((), x), 18 | lambda x, _: x, 19 | ) 20 | del jax 21 | del Tesseract 22 | 23 | # add public API as strings here, for example __all__ = ["obj"] 24 | __all__ = [ 25 | "apply_tesseract", 26 | ] 27 | -------------------------------------------------------------------------------- /tesseract_jax/_version.py: -------------------------------------------------------------------------------- 1 | 2 | # This file helps to compute a version number in source trees obtained from 3 | # git-archive tarball (such as those provided by githubs download-from-tag 4 | # feature). Distribution tarballs (built by setup.py sdist) and build 5 | # directories (produced by setup.py build) will contain a much shorter file 6 | # that just contains the computed version number. 7 | 8 | # This file is released into the public domain. 9 | # Generated by versioneer-0.29 10 | # https://github.com/python-versioneer/python-versioneer 11 | 12 | """Git implementation of _version.py.""" 13 | 14 | import errno 15 | import os 16 | import re 17 | import subprocess 18 | import sys 19 | from typing import Any, Callable, Dict, List, Optional, Tuple 20 | import functools 21 | 22 | 23 | def get_keywords() -> Dict[str, str]: 24 | """Get the keywords needed to look up the version information.""" 25 | # these strings will be replaced by git during git-archive. 26 | # setup.py/versioneer.py will grep for the variable names, so they must 27 | # each be defined on a line of their own. _version.py will just call 28 | # get_keywords(). 
29 | git_refnames = " (HEAD -> main)" 30 | git_full = "35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce" 31 | git_date = "2025-06-02 10:37:02 +0200" 32 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} 33 | return keywords 34 | 35 | 36 | class VersioneerConfig: 37 | """Container for Versioneer configuration parameters.""" 38 | 39 | VCS: str 40 | style: str 41 | tag_prefix: str 42 | parentdir_prefix: str 43 | versionfile_source: str 44 | verbose: bool 45 | 46 | 47 | def get_config() -> VersioneerConfig: 48 | """Create, populate and return the VersioneerConfig() object.""" 49 | # these strings are filled in when 'setup.py versioneer' creates 50 | # _version.py 51 | cfg = VersioneerConfig() 52 | cfg.VCS = "git" 53 | cfg.style = "pep440" 54 | cfg.tag_prefix = "v" 55 | cfg.parentdir_prefix = "tesseract_jax-" 56 | cfg.versionfile_source = "tesseract_jax/_version.py" 57 | cfg.verbose = False 58 | return cfg 59 | 60 | 61 | class NotThisMethod(Exception): 62 | """Exception raised if a method is not valid for the current scenario.""" 63 | 64 | 65 | LONG_VERSION_PY: Dict[str, str] = {} 66 | HANDLERS: Dict[str, Dict[str, Callable]] = {} 67 | 68 | 69 | def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator 70 | """Create decorator to mark a method as the handler of a VCS.""" 71 | def decorate(f: Callable) -> Callable: 72 | """Store f in HANDLERS[vcs][method].""" 73 | if vcs not in HANDLERS: 74 | HANDLERS[vcs] = {} 75 | HANDLERS[vcs][method] = f 76 | return f 77 | return decorate 78 | 79 | 80 | def run_command( 81 | commands: List[str], 82 | args: List[str], 83 | cwd: Optional[str] = None, 84 | verbose: bool = False, 85 | hide_stderr: bool = False, 86 | env: Optional[Dict[str, str]] = None, 87 | ) -> Tuple[Optional[str], Optional[int]]: 88 | """Call the given command(s).""" 89 | assert isinstance(commands, list) 90 | process = None 91 | 92 | popen_kwargs: Dict[str, Any] = {} 93 | if sys.platform == "win32": 94 | # This hides the console window 
if pythonw.exe is used 95 | startupinfo = subprocess.STARTUPINFO() 96 | startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW 97 | popen_kwargs["startupinfo"] = startupinfo 98 | 99 | for command in commands: 100 | try: 101 | dispcmd = str([command] + args) 102 | # remember shell=False, so use git.cmd on windows, not just git 103 | process = subprocess.Popen([command] + args, cwd=cwd, env=env, 104 | stdout=subprocess.PIPE, 105 | stderr=(subprocess.PIPE if hide_stderr 106 | else None), **popen_kwargs) 107 | break 108 | except OSError as e: 109 | if e.errno == errno.ENOENT: 110 | continue 111 | if verbose: 112 | print("unable to run %s" % dispcmd) 113 | print(e) 114 | return None, None 115 | else: 116 | if verbose: 117 | print("unable to find command, tried %s" % (commands,)) 118 | return None, None 119 | stdout = process.communicate()[0].strip().decode() 120 | if process.returncode != 0: 121 | if verbose: 122 | print("unable to run %s (error)" % dispcmd) 123 | print("stdout was %s" % stdout) 124 | return None, process.returncode 125 | return stdout, process.returncode 126 | 127 | 128 | def versions_from_parentdir( 129 | parentdir_prefix: str, 130 | root: str, 131 | verbose: bool, 132 | ) -> Dict[str, Any]: 133 | """Try to determine the version from the parent directory name. 134 | 135 | Source tarballs conventionally unpack into a directory that includes both 136 | the project name and a version string. 
We will also support searching up 137 | two directory levels for an appropriately named parent directory 138 | """ 139 | rootdirs = [] 140 | 141 | for _ in range(3): 142 | dirname = os.path.basename(root) 143 | if dirname.startswith(parentdir_prefix): 144 | return {"version": dirname[len(parentdir_prefix):], 145 | "full-revisionid": None, 146 | "dirty": False, "error": None, "date": None} 147 | rootdirs.append(root) 148 | root = os.path.dirname(root) # up a level 149 | 150 | if verbose: 151 | print("Tried directories %s but none started with prefix %s" % 152 | (str(rootdirs), parentdir_prefix)) 153 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 154 | 155 | 156 | @register_vcs_handler("git", "get_keywords") 157 | def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: 158 | """Extract version information from the given file.""" 159 | # the code embedded in _version.py can just fetch the value of these 160 | # keywords. When used from setup.py, we don't want to import _version.py, 161 | # so we do it with a regexp instead. This function is not used from 162 | # _version.py. 
163 | keywords: Dict[str, str] = {} 164 | try: 165 | with open(versionfile_abs, "r") as fobj: 166 | for line in fobj: 167 | if line.strip().startswith("git_refnames ="): 168 | mo = re.search(r'=\s*"(.*)"', line) 169 | if mo: 170 | keywords["refnames"] = mo.group(1) 171 | if line.strip().startswith("git_full ="): 172 | mo = re.search(r'=\s*"(.*)"', line) 173 | if mo: 174 | keywords["full"] = mo.group(1) 175 | if line.strip().startswith("git_date ="): 176 | mo = re.search(r'=\s*"(.*)"', line) 177 | if mo: 178 | keywords["date"] = mo.group(1) 179 | except OSError: 180 | pass 181 | return keywords 182 | 183 | 184 | @register_vcs_handler("git", "keywords") 185 | def git_versions_from_keywords( 186 | keywords: Dict[str, str], 187 | tag_prefix: str, 188 | verbose: bool, 189 | ) -> Dict[str, Any]: 190 | """Get version information from git keywords.""" 191 | if "refnames" not in keywords: 192 | raise NotThisMethod("Short version file found") 193 | date = keywords.get("date") 194 | if date is not None: 195 | # Use only the last line. Previous lines may contain GPG signature 196 | # information. 197 | date = date.splitlines()[-1] 198 | 199 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant 200 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 201 | # -like" string, which we must then edit to make compliant), because 202 | # it's been around since git-1.5.3, and it's too difficult to 203 | # discover which version we're using, or to work around using an 204 | # older one. 205 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 206 | refnames = keywords["refnames"].strip() 207 | if refnames.startswith("$Format"): 208 | if verbose: 209 | print("keywords are unexpanded, not using") 210 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 211 | refs = {r.strip() for r in refnames.strip("()").split(",")} 212 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 213 | # just "foo-1.0". 
If we see a "tag: " prefix, prefer those. 214 | TAG = "tag: " 215 | tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} 216 | if not tags: 217 | # Either we're using git < 1.8.3, or there really are no tags. We use 218 | # a heuristic: assume all version tags have a digit. The old git %d 219 | # expansion behaves like git log --decorate=short and strips out the 220 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 221 | # between branches and tags. By ignoring refnames without digits, we 222 | # filter out many common branch names like "release" and 223 | # "stabilization", as well as "HEAD" and "master". 224 | tags = {r for r in refs if re.search(r'\d', r)} 225 | if verbose: 226 | print("discarding '%s', no digits" % ",".join(refs - tags)) 227 | if verbose: 228 | print("likely tags: %s" % ",".join(sorted(tags))) 229 | for ref in sorted(tags): 230 | # sorting will prefer e.g. "2.0" over "2.0rc1" 231 | if ref.startswith(tag_prefix): 232 | r = ref[len(tag_prefix):] 233 | # Filter out refs that exactly match prefix or that don't start 234 | # with a number once the prefix is stripped (mostly a concern 235 | # when prefix is '') 236 | if not re.match(r'\d', r): 237 | continue 238 | if verbose: 239 | print("picking %s" % r) 240 | return {"version": r, 241 | "full-revisionid": keywords["full"].strip(), 242 | "dirty": False, "error": None, 243 | "date": date} 244 | # no suitable tags, so version is "0+unknown", but full hex is still there 245 | if verbose: 246 | print("no suitable tags, using unknown + full revision id") 247 | return {"version": "0+unknown", 248 | "full-revisionid": keywords["full"].strip(), 249 | "dirty": False, "error": "no suitable tags", "date": None} 250 | 251 | 252 | @register_vcs_handler("git", "pieces_from_vcs") 253 | def git_pieces_from_vcs( 254 | tag_prefix: str, 255 | root: str, 256 | verbose: bool, 257 | runner: Callable = run_command 258 | ) -> Dict[str, Any]: 259 | """Get version from 'git describe' in the root of 
the source tree. 260 | 261 | This only gets called if the git-archive 'subst' keywords were *not* 262 | expanded, and _version.py hasn't already been rewritten with a short 263 | version string, meaning we're inside a checked out source tree. 264 | """ 265 | GITS = ["git"] 266 | if sys.platform == "win32": 267 | GITS = ["git.cmd", "git.exe"] 268 | 269 | # GIT_DIR can interfere with correct operation of Versioneer. 270 | # It may be intended to be passed to the Versioneer-versioned project, 271 | # but that should not change where we get our version from. 272 | env = os.environ.copy() 273 | env.pop("GIT_DIR", None) 274 | runner = functools.partial(runner, env=env) 275 | 276 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, 277 | hide_stderr=not verbose) 278 | if rc != 0: 279 | if verbose: 280 | print("Directory %s not under git control" % root) 281 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 282 | 283 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 284 | # if there isn't one, this yields HEX[-dirty] (no NUM) 285 | describe_out, rc = runner(GITS, [ 286 | "describe", "--tags", "--dirty", "--always", "--long", 287 | "--match", f"{tag_prefix}[[:digit:]]*" 288 | ], cwd=root) 289 | # --long was added in git-1.5.5 290 | if describe_out is None: 291 | raise NotThisMethod("'git describe' failed") 292 | describe_out = describe_out.strip() 293 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) 294 | if full_out is None: 295 | raise NotThisMethod("'git rev-parse' failed") 296 | full_out = full_out.strip() 297 | 298 | pieces: Dict[str, Any] = {} 299 | pieces["long"] = full_out 300 | pieces["short"] = full_out[:7] # maybe improved later 301 | pieces["error"] = None 302 | 303 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], 304 | cwd=root) 305 | # --abbrev-ref was added in git-1.6.3 306 | if rc != 0 or branch_name is None: 307 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned 
error") 308 | branch_name = branch_name.strip() 309 | 310 | if branch_name == "HEAD": 311 | # If we aren't exactly on a branch, pick a branch which represents 312 | # the current commit. If all else fails, we are on a branchless 313 | # commit. 314 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) 315 | # --contains was added in git-1.5.4 316 | if rc != 0 or branches is None: 317 | raise NotThisMethod("'git branch --contains' returned error") 318 | branches = branches.split("\n") 319 | 320 | # Remove the first line if we're running detached 321 | if "(" in branches[0]: 322 | branches.pop(0) 323 | 324 | # Strip off the leading "* " from the list of branches. 325 | branches = [branch[2:] for branch in branches] 326 | if "master" in branches: 327 | branch_name = "master" 328 | elif not branches: 329 | branch_name = None 330 | else: 331 | # Pick the first branch that is returned. Good or bad. 332 | branch_name = branches[0] 333 | 334 | pieces["branch"] = branch_name 335 | 336 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 337 | # TAG might have hyphens. 338 | git_describe = describe_out 339 | 340 | # look for -dirty suffix 341 | dirty = git_describe.endswith("-dirty") 342 | pieces["dirty"] = dirty 343 | if dirty: 344 | git_describe = git_describe[:git_describe.rindex("-dirty")] 345 | 346 | # now we have TAG-NUM-gHEX or HEX 347 | 348 | if "-" in git_describe: 349 | # TAG-NUM-gHEX 350 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 351 | if not mo: 352 | # unparsable. Maybe git-describe is misbehaving? 
353 | pieces["error"] = ("unable to parse git-describe output: '%s'" 354 | % describe_out) 355 | return pieces 356 | 357 | # tag 358 | full_tag = mo.group(1) 359 | if not full_tag.startswith(tag_prefix): 360 | if verbose: 361 | fmt = "tag '%s' doesn't start with prefix '%s'" 362 | print(fmt % (full_tag, tag_prefix)) 363 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 364 | % (full_tag, tag_prefix)) 365 | return pieces 366 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 367 | 368 | # distance: number of commits since tag 369 | pieces["distance"] = int(mo.group(2)) 370 | 371 | # commit: short hex revision ID 372 | pieces["short"] = mo.group(3) 373 | 374 | else: 375 | # HEX: no tags 376 | pieces["closest-tag"] = None 377 | out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) 378 | pieces["distance"] = len(out.split()) # total number of commits 379 | 380 | # commit date: see ISO-8601 comment in git_versions_from_keywords() 381 | date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() 382 | # Use only the last line. Previous lines may contain GPG signature 383 | # information. 384 | date = date.splitlines()[-1] 385 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 386 | 387 | return pieces 388 | 389 | 390 | def plus_or_dot(pieces: Dict[str, Any]) -> str: 391 | """Return a + if we don't already have one, else return a .""" 392 | if "+" in pieces.get("closest-tag", ""): 393 | return "." 394 | return "+" 395 | 396 | 397 | def render_pep440(pieces: Dict[str, Any]) -> str: 398 | """Build up version string, with post-release "local version identifier". 399 | 400 | Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 401 | get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 402 | 403 | Exceptions: 404 | 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] 405 | """ 406 | if pieces["closest-tag"]: 407 | rendered = pieces["closest-tag"] 408 | if pieces["distance"] or pieces["dirty"]: 409 | rendered += plus_or_dot(pieces) 410 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 411 | if pieces["dirty"]: 412 | rendered += ".dirty" 413 | else: 414 | # exception #1 415 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 416 | pieces["short"]) 417 | if pieces["dirty"]: 418 | rendered += ".dirty" 419 | return rendered 420 | 421 | 422 | def render_pep440_branch(pieces: Dict[str, Any]) -> str: 423 | """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . 424 | 425 | The ".dev0" means not master branch. Note that .dev0 sorts backwards 426 | (a feature branch will appear "older" than the master branch). 427 | 428 | Exceptions: 429 | 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] 430 | """ 431 | if pieces["closest-tag"]: 432 | rendered = pieces["closest-tag"] 433 | if pieces["distance"] or pieces["dirty"]: 434 | if pieces["branch"] != "master": 435 | rendered += ".dev0" 436 | rendered += plus_or_dot(pieces) 437 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 438 | if pieces["dirty"]: 439 | rendered += ".dirty" 440 | else: 441 | # exception #1 442 | rendered = "0" 443 | if pieces["branch"] != "master": 444 | rendered += ".dev0" 445 | rendered += "+untagged.%d.g%s" % (pieces["distance"], 446 | pieces["short"]) 447 | if pieces["dirty"]: 448 | rendered += ".dirty" 449 | return rendered 450 | 451 | 452 | def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: 453 | """Split pep440 version string at the post-release segment. 454 | 455 | Returns the release segments before the post-release and the 456 | post-release version number (or None if no post-release segment is present).
457 | """ 458 | vc = str.split(ver, ".post") 459 | return vc[0], int(vc[1] or 0) if len(vc) == 2 else None 460 | 461 | 462 | def render_pep440_pre(pieces: Dict[str, Any]) -> str: 463 | """TAG[.postN.devDISTANCE] -- No -dirty. 464 | 465 | Exceptions: 466 | 1: no tags. 0.post0.devDISTANCE 467 | """ 468 | if pieces["closest-tag"]: 469 | if pieces["distance"]: 470 | # update the post release segment 471 | tag_version, post_version = pep440_split_post(pieces["closest-tag"]) 472 | rendered = tag_version 473 | if post_version is not None: 474 | rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) 475 | else: 476 | rendered += ".post0.dev%d" % (pieces["distance"]) 477 | else: 478 | # no commits, use the tag as the version 479 | rendered = pieces["closest-tag"] 480 | else: 481 | # exception #1 482 | rendered = "0.post0.dev%d" % pieces["distance"] 483 | return rendered 484 | 485 | 486 | def render_pep440_post(pieces: Dict[str, Any]) -> str: 487 | """TAG[.postDISTANCE[.dev0]+gHEX] . 488 | 489 | The ".dev0" means dirty. Note that .dev0 sorts backwards 490 | (a dirty tree will appear "older" than the corresponding clean one), 491 | but you shouldn't be releasing software with -dirty anyways. 492 | 493 | Exceptions: 494 | 1: no tags. 0.postDISTANCE[.dev0] 495 | """ 496 | if pieces["closest-tag"]: 497 | rendered = pieces["closest-tag"] 498 | if pieces["distance"] or pieces["dirty"]: 499 | rendered += ".post%d" % pieces["distance"] 500 | if pieces["dirty"]: 501 | rendered += ".dev0" 502 | rendered += plus_or_dot(pieces) 503 | rendered += "g%s" % pieces["short"] 504 | else: 505 | # exception #1 506 | rendered = "0.post%d" % pieces["distance"] 507 | if pieces["dirty"]: 508 | rendered += ".dev0" 509 | rendered += "+g%s" % pieces["short"] 510 | return rendered 511 | 512 | 513 | def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: 514 | """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . 515 | 516 | The ".dev0" means not master branch.
517 | 518 | Exceptions: 519 | 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] 520 | """ 521 | if pieces["closest-tag"]: 522 | rendered = pieces["closest-tag"] 523 | if pieces["distance"] or pieces["dirty"]: 524 | rendered += ".post%d" % pieces["distance"] 525 | if pieces["branch"] != "master": 526 | rendered += ".dev0" 527 | rendered += plus_or_dot(pieces) 528 | rendered += "g%s" % pieces["short"] 529 | if pieces["dirty"]: 530 | rendered += ".dirty" 531 | else: 532 | # exception #1 533 | rendered = "0.post%d" % pieces["distance"] 534 | if pieces["branch"] != "master": 535 | rendered += ".dev0" 536 | rendered += "+g%s" % pieces["short"] 537 | if pieces["dirty"]: 538 | rendered += ".dirty" 539 | return rendered 540 | 541 | 542 | def render_pep440_old(pieces: Dict[str, Any]) -> str: 543 | """TAG[.postDISTANCE[.dev0]] . 544 | 545 | The ".dev0" means dirty. 546 | 547 | Exceptions: 548 | 1: no tags. 0.postDISTANCE[.dev0] 549 | """ 550 | if pieces["closest-tag"]: 551 | rendered = pieces["closest-tag"] 552 | if pieces["distance"] or pieces["dirty"]: 553 | rendered += ".post%d" % pieces["distance"] 554 | if pieces["dirty"]: 555 | rendered += ".dev0" 556 | else: 557 | # exception #1 558 | rendered = "0.post%d" % pieces["distance"] 559 | if pieces["dirty"]: 560 | rendered += ".dev0" 561 | return rendered 562 | 563 | 564 | def render_git_describe(pieces: Dict[str, Any]) -> str: 565 | """TAG[-DISTANCE-gHEX][-dirty]. 566 | 567 | Like 'git describe --tags --dirty --always'. 568 | 569 | Exceptions: 570 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 571 | """ 572 | if pieces["closest-tag"]: 573 | rendered = pieces["closest-tag"] 574 | if pieces["distance"]: 575 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 576 | else: 577 | # exception #1 578 | rendered = pieces["short"] 579 | if pieces["dirty"]: 580 | rendered += "-dirty" 581 | return rendered 582 | 583 | 584 | def render_git_describe_long(pieces: Dict[str, Any]) -> str: 585 | """TAG-DISTANCE-gHEX[-dirty].
586 | 587 | Like 'git describe --tags --dirty --always --long'. 588 | The distance/hash is unconditional. 589 | 590 | Exceptions: 591 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 592 | """ 593 | if pieces["closest-tag"]: 594 | rendered = pieces["closest-tag"] 595 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 596 | else: 597 | # exception #1 598 | rendered = pieces["short"] 599 | if pieces["dirty"]: 600 | rendered += "-dirty" 601 | return rendered 602 | 603 | 604 | def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: 605 | """Render the given version pieces into the requested style.""" 606 | if pieces["error"]: 607 | return {"version": "unknown", 608 | "full-revisionid": pieces.get("long"), 609 | "dirty": None, 610 | "error": pieces["error"], 611 | "date": None} 612 | 613 | if not style or style == "default": 614 | style = "pep440" # the default 615 | 616 | if style == "pep440": 617 | rendered = render_pep440(pieces) 618 | elif style == "pep440-branch": 619 | rendered = render_pep440_branch(pieces) 620 | elif style == "pep440-pre": 621 | rendered = render_pep440_pre(pieces) 622 | elif style == "pep440-post": 623 | rendered = render_pep440_post(pieces) 624 | elif style == "pep440-post-branch": 625 | rendered = render_pep440_post_branch(pieces) 626 | elif style == "pep440-old": 627 | rendered = render_pep440_old(pieces) 628 | elif style == "git-describe": 629 | rendered = render_git_describe(pieces) 630 | elif style == "git-describe-long": 631 | rendered = render_git_describe_long(pieces) 632 | else: 633 | raise ValueError("unknown style '%s'" % style) 634 | 635 | return {"version": rendered, "full-revisionid": pieces["long"], 636 | "dirty": pieces["dirty"], "error": None, 637 | "date": pieces.get("date")} 638 | 639 | 640 | def get_versions() -> Dict[str, Any]: 641 | """Get version information or return default if unable to do so.""" 642 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE.
If we have 643 | # __file__, we can work backwards from there to the root. Some 644 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 645 | # case we can only use expanded keywords. 646 | 647 | cfg = get_config() 648 | verbose = cfg.verbose 649 | 650 | try: 651 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 652 | verbose) 653 | except NotThisMethod: 654 | pass 655 | 656 | try: 657 | root = os.path.realpath(__file__) 658 | # versionfile_source is the relative path from the top of the source 659 | # tree (where the .git directory might live) to this file. Invert 660 | # this to find the root from __file__. 661 | for _ in cfg.versionfile_source.split('/'): 662 | root = os.path.dirname(root) 663 | except NameError: 664 | return {"version": "0+unknown", "full-revisionid": None, 665 | "dirty": None, 666 | "error": "unable to find root of source tree", 667 | "date": None} 668 | 669 | try: 670 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 671 | return render(pieces, cfg.style) 672 | except NotThisMethod: 673 | pass 674 | 675 | try: 676 | if cfg.parentdir_prefix: 677 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 678 | except NotThisMethod: 679 | pass 680 | 681 | return {"version": "0+unknown", "full-revisionid": None, 682 | "dirty": None, 683 | "error": "unable to compute version", "date": None} 684 | -------------------------------------------------------------------------------- /tesseract_jax/primitive.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import functools 5 | from collections.abc import Sequence 6 | from typing import Any 7 | 8 | import jax.tree 9 | import numpy as np 10 | from jax import ShapeDtypeStruct, dtypes, extend 11 | from jax.core import ShapedArray 12 | from jax.interpreters import ad, mlir, xla 13 | from jax.tree_util import PyTreeDef 14 | from jax.typing import ArrayLike 15 | from tesseract_core import Tesseract 16 | 17 | from tesseract_jax.tesseract_compat import Jaxeract 18 | 19 | tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch") 20 | tesseract_dispatch_p.multiple_results = True 21 | tesseract_dispatch_p.def_impl( 22 | functools.partial(xla.apply_primitive, tesseract_dispatch_p) 23 | ) 24 | 25 | 26 | class _Hashable: 27 | def __init__(self, obj: Any) -> None: 28 | self.wrapped = obj 29 | 30 | def __hash__(self) -> int: 31 | try: 32 | return hash(self.wrapped) 33 | except TypeError: 34 | return id(self.wrapped) 35 | 36 | 37 | def split_args( 38 | flat_args: Sequence[Any], is_static_mask: Sequence[bool] 39 | ) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]: 40 | """Split a flat argument list into a tuple (array_args, static_args).""" 41 | static_args = tuple( 42 | _make_hashable(arg) 43 | for arg, is_static in zip(flat_args, is_static_mask, strict=True) 44 | if is_static 45 | ) 46 | array_args = tuple( 47 | arg 48 | for arg, is_static in zip(flat_args, is_static_mask, strict=True) # fmt: skip 49 | if not is_static 50 | ) 51 | 52 | return array_args, static_args 53 | 54 | 55 | @tesseract_dispatch_p.def_abstract_eval 56 | def tesseract_dispatch_abstract_eval( 57 | *array_args: ArrayLike | ShapedArray, 58 | static_args: tuple[_Hashable, ...], 59 | input_pytreedef: PyTreeDef, 60 | output_pytreedef: PyTreeDef, 61 | output_avals: tuple[ShapeDtypeStruct, ...], 62 | is_static_mask: tuple[bool, ...], 63 | client: Jaxeract, 64 | eval_func: str, 65 | ) -> tuple: 66 | """Define how to dispatch evals and pipe arguments.""" 67 
| if eval_func not in ( 68 | "apply", 69 | "jacobian_vector_product", 70 | "vector_jacobian_product", 71 | ): 72 | raise NotImplementedError(eval_func) 73 | 74 | n_primals = len(is_static_mask) - sum(is_static_mask) 75 | 76 | if eval_func == "vector_jacobian_product": 77 | # We mustn't run forward evaluation of shapes, as out 78 | # of vjp has the same shapes as the primals; thus we can return early 79 | return tuple(array_args[:n_primals]) 80 | 81 | # Those have the same shape as the outputs 82 | assert eval_func in ("apply", "jacobian_vector_product") 83 | return tuple(jax.core.ShapedArray(aval.shape, aval.dtype) for aval in output_avals) 84 | 85 | 86 | def tesseract_dispatch_jvp_rule( 87 | in_args: tuple[ArrayLike, ...], 88 | tan_args: tuple[ArrayLike, ...], 89 | static_args: tuple[_Hashable, ...], 90 | input_pytreedef: PyTreeDef, 91 | output_pytreedef: PyTreeDef, 92 | output_avals: tuple[ShapeDtypeStruct, ...], 93 | is_static_mask: tuple[bool, ...], 94 | client: Jaxeract, 95 | eval_func: str, 96 | ) -> tuple[tuple[ArrayLike, ...], tuple[ArrayLike, ...]]: 97 | """Defines how to dispatch jvp operation.""" 98 | if eval_func != "apply": 99 | raise RuntimeError("Cannot take higher-order derivatives") 100 | 101 | # https://github.com/jax-ml/jax/issues/16303#issuecomment-1585295819 102 | # mattjj: taking a narrow pigeon-holed view, anywhere you see a symbolic 103 | # zero `Zero(AbstractToken)`, i.e. in a JVP or transpose rule 104 | # (not in ad.py's backward_pass), you probably want to instantiate 105 | # it so that it's no longer symbolic 106 | 107 | # TODO: create a mask for Zero (essentially, jvp_in)? or maybe substitute it 108 | # with something that jax still likes, while not wasting memory and time? 
109 | 110 | tan_args_ = tuple( 111 | ( 112 | jax.numpy.zeros_like(arg.aval) 113 | if isinstance(arg, jax._src.ad_util.Zero) 114 | else arg 115 | ) 116 | for arg in tan_args 117 | ) 118 | 119 | jvp = tesseract_dispatch_p.bind( 120 | *in_args, 121 | *tan_args_, 122 | static_args=static_args, 123 | input_pytreedef=input_pytreedef, 124 | output_pytreedef=output_pytreedef, 125 | output_avals=output_avals, 126 | is_static_mask=is_static_mask, 127 | client=client, 128 | eval_func="jacobian_vector_product", 129 | ) 130 | 131 | res = tesseract_dispatch_p.bind( 132 | *in_args, 133 | static_args=static_args, 134 | input_pytreedef=input_pytreedef, 135 | output_pytreedef=output_pytreedef, 136 | output_avals=output_avals, 137 | is_static_mask=is_static_mask, 138 | client=client, 139 | eval_func="apply", 140 | ) 141 | 142 | return tuple(res), tuple(jvp) 143 | 144 | 145 | ad.primitive_jvps[tesseract_dispatch_p] = tesseract_dispatch_jvp_rule 146 | 147 | 148 | def tesseract_dispatch_transpose_rule( 149 | cotangent: Sequence[ArrayLike], 150 | *args: ArrayLike, 151 | static_args: tuple[_Hashable, ...], 152 | input_pytreedef: PyTreeDef, 153 | output_pytreedef: PyTreeDef, 154 | output_avals: tuple[ShapeDtypeStruct, ...], 155 | is_static_mask: tuple[bool, ...], 156 | client: Jaxeract, 157 | eval_func: str, 158 | ) -> tuple[ArrayLike | None, ...]: 159 | """Defines how to dispatch vjp operation.""" 160 | assert eval_func in ("jacobian_vector_product",) 161 | 162 | n_primals = len(is_static_mask) - sum(is_static_mask) 163 | args = args[:n_primals] 164 | 165 | cotan_args_ = tuple( 166 | ( 167 | jax.numpy.zeros_like(arg.aval) 168 | if isinstance(arg, jax._src.ad_util.Zero) 169 | else arg 170 | ) 171 | for arg in cotangent 172 | ) 173 | 174 | vjp = tesseract_dispatch_p.bind( 175 | *args, 176 | *cotan_args_, 177 | static_args=static_args, 178 | input_pytreedef=input_pytreedef, 179 | output_pytreedef=output_pytreedef, 180 | output_avals=output_avals, 181 | is_static_mask=is_static_mask, 182 | 
client=client, 183 | eval_func="vector_jacobian_product", 184 | ) 185 | # TODO: I'm not sure this makes sense given these docs: 186 | # https://jax.readthedocs.io/en/latest/jax-primitives.html#transposition 187 | # "A tuple with the cotangent of the inputs, with the value None corresponding to the constant arguments" 188 | # ...but if I provide only cotangent, jax complains, and if I investigate its internals, 189 | # I see it chokes on map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out), 190 | # where eqn.invars ends up being longer than cts_out. 191 | 192 | return tuple([None] * len(args) + vjp) 193 | 194 | 195 | ad.primitive_transposes[tesseract_dispatch_p] = tesseract_dispatch_transpose_rule 196 | 197 | 198 | def tesseract_dispatch_lowering( 199 | ctx: Any, 200 | *array_args: ArrayLike | ShapedArray | Any, 201 | static_args: tuple[_Hashable, ...], 202 | input_pytreedef: PyTreeDef, 203 | output_pytreedef: PyTreeDef, 204 | output_avals: tuple[ShapeDtypeStruct, ...], 205 | is_static_mask: tuple[bool, ...], 206 | client: Jaxeract, 207 | eval_func: str, 208 | ) -> Any: 209 | """Defines how to dispatch lowering the computation.""" 210 | 211 | def _dispatch(*args: ArrayLike) -> Any: 212 | static_args_ = tuple(_unpack_hashable(arg) for arg in static_args) 213 | out = getattr(client, eval_func)( 214 | args, 215 | static_args_, 216 | input_pytreedef, 217 | output_pytreedef, 218 | output_avals, 219 | is_static_mask, 220 | ) 221 | if not isinstance(out, tuple): 222 | out = (out,) 223 | return out 224 | 225 | result, _, keepalive = mlir.emit_python_callback( 226 | ctx, 227 | _dispatch, 228 | None, 229 | array_args, 230 | ctx.avals_in, 231 | ctx.avals_out, 232 | has_side_effect=False, 233 | ) 234 | ctx.module_context.add_keepalive(keepalive) 235 | return result 236 | 237 | 238 | mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering) 239 | 240 | 241 | def _check_dtype(dtype: Any) -> None: 242 | dt = np.dtype(dtype) 243 | if 
dtypes.canonicalize_dtype(dt) != dt: 244 | raise ValueError( 245 | "Cannot return 64-bit values when `jax_enable_x64` is disabled. " 246 | "Try enabling it with `jax.config.update('jax_enable_x64', True)`." 247 | ) 248 | 249 | 250 | def _is_static(x: Any) -> bool: 251 | if isinstance(x, jax.core.Tracer): 252 | return False 253 | return True 254 | 255 | 256 | def _make_hashable(obj: Any) -> _Hashable: 257 | return _Hashable(obj) 258 | 259 | 260 | def _unpack_hashable(obj: _Hashable) -> Any: 261 | return obj.wrapped 262 | 263 | 264 | def apply_tesseract( 265 | tesseract_client: Tesseract, 266 | inputs: Any, 267 | ) -> Any: 268 | """Applies the given Tesseract object to the inputs. 269 | 270 | This function is fully traceable and can be used in JAX transformations like 271 | jit, grad, etc. It will automatically dispatch to the appropriate Tesseract 272 | endpoint based on the requested operation. 273 | 274 | Example: 275 | >>> from tesseract_core import Tesseract 276 | >>> from tesseract_jax import apply_tesseract 277 | >>> 278 | >>> # Create a Tesseract object and some inputs 279 | >>> tesseract_client = Tesseract.from_image("univariate") 280 | >>> tesseract_client.serve() 281 | >>> inputs = {"x": jax.numpy.array(1.0), "y": jax.numpy.array(2.0)} 282 | >>> 283 | >>> # Apply the Tesseract object to the inputs 284 | >>> # (this calls tesseract_client.apply under the hood) 285 | >>> apply_tesseract(tesseract_client, inputs) 286 | {'result': Array(100., dtype=float64)} 287 | >>> 288 | >>> # Compute the gradient of the outputs with respect to the inputs 289 | >>> # (this calls tesseract_client.vector_jacobian_product under the hood) 290 | >>> def apply_fn(x): 291 | ... res = apply_tesseract(tesseract_client, x) 292 | ... 
return res["result"].sum() 293 | >>> grad_fn = jax.grad(apply_fn) 294 | >>> grad_fn(inputs) 295 | {'x': Array(-400., dtype=float64, weak_type=True), 'y': Array(200., dtype=float64, weak_type=True)} 296 | 297 | Args: 298 | tesseract_client: The Tesseract object to apply. 299 | inputs: The inputs to apply to the Tesseract object. 300 | 301 | Returns: 302 | The outputs of the Tesseract object after applying the inputs. 303 | """ 304 | if not isinstance(tesseract_client, Tesseract): 305 | raise TypeError( 306 | "The first argument must be a Tesseract object. " 307 | f"Got {type(tesseract_client)} instead." 308 | ) 309 | 310 | if "abstract_eval" not in tesseract_client.available_endpoints: 311 | raise ValueError( 312 | "Given Tesseract object does not support abstract_eval, " 313 | "which is required for compatibility with JAX." 314 | ) 315 | 316 | client = Jaxeract(tesseract_client) 317 | 318 | flat_args, input_pytreedef = jax.tree.flatten(inputs) 319 | is_static_mask = tuple(_is_static(arg) for arg in flat_args) 320 | array_args, static_args = split_args(flat_args, is_static_mask) 321 | 322 | # Get abstract values for outputs, so we can unflatten them later 323 | output_pytreedef, avals = None, None 324 | avals = client.abstract_eval( 325 | array_args, 326 | static_args, 327 | input_pytreedef, 328 | output_pytreedef, 329 | avals, 330 | is_static_mask, 331 | ) 332 | 333 | is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x 334 | flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval) 335 | for aval in flat_avals: 336 | if not is_aval(aval): 337 | continue 338 | _check_dtype(aval["dtype"]) 339 | 340 | flat_avals = tuple( 341 | jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"]) 342 | for aval in flat_avals 343 | ) 344 | 345 | # Apply the primitive 346 | out = tesseract_dispatch_p.bind( 347 | *array_args, 348 | static_args=static_args, 349 | input_pytreedef=input_pytreedef, 350 | output_pytreedef=output_pytreedef, 
351 | output_avals=flat_avals, 352 | is_static_mask=is_static_mask, 353 | client=client, 354 | eval_func="apply", 355 | ) 356 | 357 | # Unflatten the output 358 | return jax.tree.unflatten(output_pytreedef, out) 359 | -------------------------------------------------------------------------------- /tesseract_jax/tesseract_compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from typing import Any, TypeAlias 5 | 6 | import jax.tree 7 | from jax import ShapeDtypeStruct 8 | from jax.tree_util import PyTreeDef 9 | from jax.typing import ArrayLike 10 | from tesseract_core import Tesseract 11 | 12 | PyTree: TypeAlias = Any 13 | 14 | 15 | def unflatten_args( 16 | array_args: tuple[ArrayLike, ...], 17 | static_args: tuple[Any, ...], 18 | input_pytreedef: PyTreeDef, 19 | is_static_mask: tuple[bool, ...], 20 | remove_static_args: bool = False, 21 | ) -> PyTree: 22 | """Unflatten lists of arguments (static and not) into a pytree.""" 23 | combined_args = [] 24 | static_iter = iter(static_args) 25 | array_iter = iter(array_args) 26 | 27 | for is_static in is_static_mask: 28 | if is_static: 29 | elem = next(static_iter) 30 | elem = elem.wrapped if hasattr(elem, "wrapped") else elem 31 | 32 | if remove_static_args: 33 | combined_args.append(None) 34 | else: 35 | combined_args.append(elem) 36 | 37 | else: 38 | combined_args.append(next(array_iter)) 39 | 40 | result = jax.tree.unflatten(input_pytreedef, combined_args) 41 | 42 | if remove_static_args: 43 | result = _prune_nones(result) 44 | 45 | return result 46 | 47 | 48 | def _prune_nones(tree: PyTree) -> PyTree: 49 | if isinstance(tree, dict): 50 | return {k: _prune_nones(v) for k, v in tree.items() if v is not None} 51 | elif isinstance(tree, tuple | list): 52 | return type(tree)(_prune_nones(v) for v in tree if v is not None) 53 | else: 54 | return tree 55 | 56 | 57 | def 
_pytree_to_tesseract_flat(pytree: PyTree) -> list[tuple]: 58 | leaves = jax.tree_util.tree_flatten_with_path(pytree)[0] 59 | 60 | flat_list = [] 61 | for jax_path, val in leaves: 62 | tesseract_path = "" 63 | first_elem = True 64 | for elem in jax_path: 65 | if hasattr(elem, "key"): 66 | if not first_elem: 67 | tesseract_path += "." 68 | tesseract_path += elem.key 69 | elif hasattr(elem, "idx"): 70 | tesseract_path += f"[{elem.idx}]" 71 | first_elem = False 72 | flat_list.append((tesseract_path, val)) 73 | 74 | return flat_list 75 | 76 | 77 | class Jaxeract: 78 | """A wrapper around a Tesseract client to make its signature compatible with JAX primitives.""" 79 | 80 | def __init__(self, tesseract_client: Tesseract) -> None: 81 | """Initialize the Tesseract client.""" 82 | self.client = tesseract_client 83 | 84 | self.tesseract_input_args = tuple( 85 | arg 86 | for arg in self.client.openapi_schema["components"]["schemas"][ 87 | "Apply_InputSchema" 88 | ]["properties"] 89 | ) 90 | # We need this to adhere to jax convention on tree flattening (sort keys alphabetically) 91 | # Only outermost level should be sufficient. 92 | self.tesseract_input_args = tuple(sorted(self.tesseract_input_args)) 93 | 94 | self.tesseract_output_args = tuple( 95 | arg 96 | for arg in self.client.openapi_schema["components"]["schemas"][ 97 | "Apply_OutputSchema" 98 | ]["properties"] 99 | ) 100 | 101 | self.differentiable_input_paths = self.client.input_schema[ 102 | "differentiable_arrays" 103 | ] 104 | 105 | self.differentiable_output_paths = self.client.output_schema[ 106 | "differentiable_arrays" 107 | ] 108 | 109 | self.available_methods = self.client.available_endpoints 110 | 111 | def abstract_eval( 112 | self, 113 | array_args: tuple[ArrayLike, ...], 114 | static_args: tuple[Any, ...], 115 | input_pytreedef: PyTreeDef, 116 | output_pytreedef: PyTreeDef | None, 117 | output_avals: tuple[ShapeDtypeStruct, ...] 
| None, 118 | is_static_mask: tuple[bool, ...], 119 | ) -> PyTree: 120 | """Run an abstract evaluation on a Tesseract. 121 | 122 | This used in order to get output shapes given input shapes. 123 | """ 124 | avals = unflatten_args(array_args, static_args, input_pytreedef, is_static_mask) 125 | 126 | abstract_inputs = jax.tree.map( 127 | lambda x: ( 128 | {"shape": x.shape, "dtype": x.dtype.name} if hasattr(x, "shape") else x 129 | ), 130 | avals, 131 | ) 132 | 133 | out_data = self.client.abstract_eval(abstract_inputs) 134 | return out_data 135 | 136 | def apply( 137 | self, 138 | array_args: tuple[ArrayLike, ...], 139 | static_args: tuple[Any, ...], 140 | input_pytreedef: PyTreeDef, 141 | output_pytreedef: PyTreeDef, 142 | output_avals: tuple[ShapeDtypeStruct, ...], 143 | is_static_mask: tuple[bool, ...], 144 | ) -> PyTree: 145 | """Call the Tesseract's apply endpoint with the given arguments.""" 146 | inputs = unflatten_args( 147 | array_args, static_args, input_pytreedef, is_static_mask 148 | ) 149 | 150 | out_data = self.client.apply(inputs) 151 | 152 | out_data = tuple(jax.tree.flatten(out_data)[0]) 153 | return out_data 154 | 155 | def jacobian_vector_product( 156 | self, 157 | array_args: tuple[ArrayLike, ...], 158 | static_args: tuple[Any, ...], 159 | input_pytreedef: PyTreeDef, 160 | output_pytreedef: PyTreeDef, 161 | output_avals: tuple[ShapeDtypeStruct, ...], 162 | is_static_mask: tuple[bool, ...], 163 | ) -> PyTree: 164 | """Call the Tesseract's jvp endpoint with the given arguments.""" 165 | n_primals = len(is_static_mask) - sum(is_static_mask) 166 | primals = array_args[:n_primals] 167 | tangents = array_args[n_primals:] 168 | 169 | primal_inputs = unflatten_args( 170 | primals, static_args, input_pytreedef, is_static_mask 171 | ) 172 | tangent_inputs = unflatten_args( 173 | tangents, 174 | static_args, 175 | input_pytreedef, 176 | is_static_mask, 177 | remove_static_args=True, 178 | ) 179 | 180 | flat_tangents = 
dict(_pytree_to_tesseract_flat(tangent_inputs)) 181 | 182 | jvp_inputs = list(flat_tangents.keys()) 183 | jvp_outputs = list(self.differentiable_output_paths.keys()) 184 | 185 | out_data = self.client.jacobian_vector_product( 186 | inputs=primal_inputs, 187 | jvp_inputs=jvp_inputs, 188 | jvp_outputs=jvp_outputs, 189 | tangent_vector=flat_tangents, 190 | ) 191 | 192 | paths = [ 193 | p 194 | for p, _ in _pytree_to_tesseract_flat( 195 | jax.tree.unflatten(output_pytreedef, range(len(output_avals))) 196 | ) 197 | ] 198 | 199 | out = [] 200 | for path, aval in zip(paths, output_avals, strict=False): 201 | if path in out_data: 202 | out.append(out_data[path]) 203 | else: 204 | out.append(jax.numpy.full_like(aval, jax.numpy.nan)) 205 | 206 | return tuple(out) 207 | 208 | def vector_jacobian_product( 209 | self, 210 | array_args: tuple[ArrayLike, ...], 211 | static_args: tuple[Any, ...], 212 | input_pytreedef: PyTreeDef, 213 | output_pytreedef: PyTreeDef, 214 | output_avals: tuple[ShapeDtypeStruct, ...], 215 | is_static_mask: tuple[bool, ...], 216 | ) -> PyTree: 217 | """Call the Tesseract's vjp endpoint with the given arguments.""" 218 | n_primals = len(is_static_mask) - sum(is_static_mask) 219 | primals = array_args[:n_primals] 220 | cotangents = array_args[n_primals:] 221 | 222 | primal_inputs = unflatten_args( 223 | primals, static_args, input_pytreedef, is_static_mask 224 | ) 225 | 226 | in_keys = [k for k, _ in _pytree_to_tesseract_flat(primal_inputs)] 227 | vjp_inputs = [o for o, m in zip(in_keys, is_static_mask, strict=True) if not m] 228 | vjp_outputs = list(self.differentiable_output_paths.keys()) 229 | 230 | paths = [ 231 | p 232 | for p, _ in _pytree_to_tesseract_flat( 233 | jax.tree.unflatten(output_pytreedef, range(len(output_avals))) 234 | ) 235 | ] 236 | 237 | cotangents_dict = {} 238 | 239 | for i, p in enumerate(paths): 240 | if p in vjp_outputs: 241 | cotangents_dict[p] = cotangents[i] 242 | 243 | out_data = self.client.vector_jacobian_product( 244 | 
inputs=primal_inputs, 245 | vjp_inputs=vjp_inputs, 246 | vjp_outputs=vjp_outputs, 247 | cotangent_vector=cotangents_dict, 248 | ) 249 | 250 | out_data = tuple(jax.tree.flatten(out_data)[0]) 251 | return out_data 252 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import socket 6 | import subprocess 7 | import time 8 | from pathlib import Path 9 | 10 | import jax 11 | import pytest 12 | import requests 13 | 14 | here = Path(__file__).parent 15 | 16 | jax.config.update("jax_enable_x64", True) 17 | 18 | 19 | def get_tesseract_folders(): 20 | tesseract_folders = [ 21 | "univariate_tesseract", 22 | "nested_tesseract", 23 | # Add more as needed 24 | ] 25 | return tesseract_folders 26 | 27 | 28 | def find_free_port(): 29 | """Find a free port to use for the test server.""" 30 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 31 | s.bind(("localhost", 0)) 32 | return s.getsockname()[1] 33 | 34 | 35 | def make_tesseract_fixture(folder_name): 36 | """Factory function to create tesseract fixtures for different folders. 37 | 38 | This fixture serves a Tesseract via `tesseract-runtime` for the specified folder. 39 | This way, we can test with real Tesseracts without a running Docker daemon. 
40 | """ 41 | 42 | @pytest.fixture(scope="session") 43 | def served_tesseract(): 44 | port = find_free_port() 45 | timeout = 10 46 | 47 | env = os.environ.copy() 48 | env["TESSERACT_API_PATH"] = str(here / folder_name / "tesseract_api.py") 49 | 50 | # Start the server as a subprocess 51 | process = subprocess.Popen( 52 | [ 53 | "tesseract-runtime", 54 | "serve", 55 | "--host", 56 | "localhost", 57 | "--port", 58 | str(port), 59 | ], 60 | env=env, 61 | stdout=subprocess.PIPE, 62 | stderr=subprocess.PIPE, 63 | ) 64 | 65 | try: 66 | start_time = time.time() 67 | while True: 68 | try: 69 | requests.get(f"http://localhost:{port}/health") 70 | break 71 | except requests.exceptions.ConnectionError as exc: 72 | if time.time() - start_time > timeout: 73 | raise TimeoutError( 74 | f"Tesseract for {folder_name} did not start in time" 75 | ) from exc 76 | time.sleep(0.1) 77 | 78 | yield f"http://localhost:{port}" 79 | finally: 80 | process.terminate() 81 | process.communicate() 82 | 83 | return served_tesseract 84 | 85 | 86 | served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract") 87 | served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract") 88 | -------------------------------------------------------------------------------- /tests/nested_tesseract/tesseract_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pydantic import BaseModel, Field 5 | from tesseract_core.runtime import Array, Differentiable, Float32 6 | 7 | 8 | class Scalars(BaseModel): 9 | a: Differentiable[Float32] = Field(description="Scalar value a.", default=0.0) 10 | b: Float32 = Field(description="Scalar value b.") 11 | 12 | 13 | class Vectors(BaseModel): 14 | v: Differentiable[Array[(None,), Float32]] = Field(description="Vector value v.") 15 | w: Array[(3,), Float32] = Field( 16 | description="Vector value w.", default=[0.0, 1.0, 2.0] 17 | ) 18 | 19 | 20 | class OtherStuff(BaseModel): 21 | s: str = Field(description="String value s.") 22 | i: int = Field(description="Integer value i.") 23 | f: float = Field(description="Float value f.") 24 | 25 | 26 | class InputSchema(BaseModel): 27 | scalars: Scalars 28 | vectors: Vectors 29 | other_stuff: OtherStuff 30 | 31 | 32 | class OutputSchema(BaseModel): 33 | scalars: Scalars 34 | vectors: Vectors 35 | 36 | # apply: scales the differentiable a and v by 10 and shifts them by the non-differentiable b and w; b and w are passed through unchanged. 37 | def apply(inputs: InputSchema) -> OutputSchema: 38 | a = inputs.scalars.a 39 | b = inputs.scalars.b 40 | v = inputs.vectors.v 41 | w = inputs.vectors.w 42 | 43 | new_a = a * 10 + b 44 | new_v = v * 10 + w 45 | 46 | scalars = Scalars(a=new_a, b=b) 47 | vectors = Vectors(v=new_v, w=w) 48 | return OutputSchema(scalars=scalars, vectors=vectors) 49 | 50 | 51 | # 52 | # Optional endpoints 53 | # 54 | # Both derivative endpoints return a dict keyed by flattened path strings (e.g. "scalars.a"); every requested entry starts at 0.0 and only the a->a and v->v pairs are filled in. 55 | 56 | def jacobian_vector_product( 57 | inputs: InputSchema, 58 | jvp_inputs: set[str], 59 | jvp_outputs: set[str], 60 | tangent_vector, 61 | ): 62 | out = {dy: 0.0 for dy in jvp_outputs} 63 | if "scalars.a" in jvp_inputs and "scalars.a" in jvp_outputs: 64 | out["scalars.a"] = 10.0 * tangent_vector["scalars.a"] # d(new_a)/d(a) = 10, since apply computes new_a = a * 10 + b 65 | if "vectors.v" in jvp_inputs and "vectors.v" in jvp_outputs: 66 | out["vectors.v"] = 10.0 * tangent_vector["vectors.v"] # d(new_v)/d(v) = 10, since apply computes new_v = v * 10 + w 67 | return out 68 | 69 | 70 | def vector_jacobian_product( 71 | inputs: InputSchema, 72 | vjp_inputs: set[str], 73 | vjp_outputs: set[str], 74 | cotangent_vector, 75 |
): 76 | out = {dx: 0.0 for dx in vjp_inputs} 77 | if "scalars.a" in vjp_inputs and "scalars.a" in vjp_outputs: 78 | out["scalars.a"] = 10.0 * cotangent_vector["scalars.a"] # same factor-of-10 Jacobian as the JVP above, applied to the cotangent 79 | if "vectors.v" in vjp_inputs and "vectors.v" in vjp_outputs: 80 | out["vectors.v"] = 10.0 * cotangent_vector["vectors.v"] # same factor-of-10 Jacobian as the JVP above, applied to the cotangent 81 | return out 82 | 83 | 84 | def abstract_eval(abstract_inputs): 85 | """Calculate output shape of apply from the shape of its inputs.""" 86 | return { 87 | "scalars": { 88 | "a": abstract_inputs.scalars.a, 89 | "b": abstract_inputs.scalars.b, 90 | }, 91 | "vectors": { 92 | "v": abstract_inputs.vectors.v, 93 | "w": abstract_inputs.vectors.w, 94 | }, 95 | } 96 | -------------------------------------------------------------------------------- /tests/nested_tesseract/tesseract_config.yaml: -------------------------------------------------------------------------------- 1 | name: "nested" 2 | version: "0.1.0" 3 | description: | 4 | Unit tesseract that evaluates some functions of (nested) inputs and returns them as (nested) outputs 5 | -------------------------------------------------------------------------------- /tests/nested_tesseract/tesseract_requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasteurlabs/tesseract-jax/35615fab02bce0cbf80fdc2fe8c05c6f3aa1c0ce/tests/nested_tesseract/tesseract_requirements.txt -------------------------------------------------------------------------------- /tests/test_endtoend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import jax 5 | import numpy as np 6 | import pytest 7 | from jax.typing import ArrayLike 8 | from tesseract_core import Tesseract 9 | 10 | from tesseract_jax import apply_tesseract 11 | 12 | 13 | def _assert_pytree_isequal(a, b, rtol=None, atol=None): 14 | """Check if two PyTrees are equal.""" 15 | a_flat, a_structure = jax.tree.flatten_with_path(a) 16 | b_flat, b_structure = jax.tree.flatten_with_path(b) 17 | 18 | if a_structure != b_structure: 19 | raise AssertionError( 20 | f"PyTree structures are different:\n{a_structure}\n{b_structure}" 21 | ) 22 | 23 | if rtol is not None or atol is not None: 24 | array_compare = lambda x, y: np.testing.assert_allclose( 25 | x, y, rtol=rtol, atol=atol 26 | ) 27 | else: 28 | array_compare = lambda x, y: np.testing.assert_array_equal(x, y) 29 | 30 | failures = [] 31 | for (a_path, a_elem), (b_path, b_elem) in zip(a_flat, b_flat, strict=True): 32 | assert a_path == b_path, f"Unexpected path mismatch: {a_path} != {b_path}" 33 | try: 34 | if isinstance(a_elem, ArrayLike) or isinstance(b_elem, ArrayLike): 35 | array_compare(a_elem, b_elem) 36 | else: 37 | assert a_elem == b_elem, f"Values are different: {a_elem} != {b_elem}" 38 | except AssertionError as e: 39 | failures.append(a_path, str(e)) 40 | 41 | if failures: 42 | msg = "\n".join(f"Path: {path}, Error: {error}" for path, error in failures) 43 | raise AssertionError(f"PyTree elements are different:\n{msg}") 44 | 45 | 46 | def rosenbrock_impl(x, y, a=1.0, b=100.0): 47 | """JAX-traceable version of the Rosenbrock function used by univariate_tesseract.""" 48 | return (a - x) ** 2 + b * (y - x**2) ** 2 49 | 50 | 51 | @pytest.mark.parametrize("use_jit", [True, False]) 52 | def test_univariate_tesseract_apply(served_univariate_tesseract_raw, use_jit): 53 | rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) 54 | x, y = np.array(0.0), np.array(0.0) 55 | 56 | def f(x, y): 57 | return apply_tesseract(rosenbrock_tess, 
inputs=dict(x=x, y=y)) 58 | 59 | rosenbrock_raw = rosenbrock_impl 60 | if use_jit: 61 | f = jax.jit(f) 62 | rosenbrock_raw = jax.jit(rosenbrock_raw) 63 | 64 | # Test against Tesseract client 65 | result = f(x, y) 66 | result_ref = rosenbrock_tess.apply(dict(x=x, y=y)) 67 | _assert_pytree_isequal(result, result_ref) 68 | 69 | # Test against direct implementation 70 | result_raw = rosenbrock_raw(x, y) 71 | np.testing.assert_array_equal(result["result"], result_raw) 72 | 73 | 74 | @pytest.mark.parametrize("use_jit", [True, False]) 75 | def test_univariate_tesseract_jvp(served_univariate_tesseract_raw, use_jit): 76 | rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) 77 | 78 | # make things callable without keyword args 79 | def f(x, y): 80 | return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y)) 81 | 82 | rosenbrock_raw = rosenbrock_impl 83 | if use_jit: 84 | f = jax.jit(f) 85 | rosenbrock_raw = jax.jit(rosenbrock_raw) 86 | 87 | x, y = np.array(0.0), np.array(0.0) 88 | dx, dy = np.array(1.0), np.array(0.0) 89 | (primal, jvp) = jax.jvp(f, (x, y), (dx, dy)) 90 | 91 | # Test against Tesseract client 92 | primal_ref = rosenbrock_tess.apply(dict(x=x, y=y)) 93 | _assert_pytree_isequal(primal, primal_ref) 94 | 95 | jvp_ref = rosenbrock_tess.jacobian_vector_product( 96 | inputs=dict(x=x, y=y), 97 | jvp_inputs=["x", "y"], 98 | jvp_outputs=["result"], 99 | tangent_vector=dict(x=dx, y=dy), 100 | ) 101 | _assert_pytree_isequal(jvp, jvp_ref) 102 | 103 | # Test against direct implementation 104 | _, jvp_raw = jax.jvp(rosenbrock_raw, (x, y), (dx, dy)) 105 | np.testing.assert_array_equal(jvp["result"], jvp_raw) 106 | 107 | 108 | @pytest.mark.parametrize("use_jit", [True, False]) 109 | def test_univariate_tesseract_vjp(served_univariate_tesseract_raw, use_jit): 110 | rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) 111 | 112 | def f(x, y): 113 | return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y)) 114 | 115 | rosenbrock_raw = rosenbrock_impl 116 
| if use_jit: 117 | f = jax.jit(f) 118 | rosenbrock_raw = jax.jit(rosenbrock_raw) 119 | 120 | x, y = np.array(0.0), np.array(0.0) 121 | (primal, f_vjp) = jax.vjp(f, x, y) 122 | 123 | if use_jit: 124 | f_vjp = jax.jit(f_vjp) 125 | 126 | vjp = f_vjp(primal) 127 | 128 | # Test against Tesseract client 129 | primal_ref = rosenbrock_tess.apply(dict(x=x, y=y)) 130 | _assert_pytree_isequal(primal, primal_ref) 131 | 132 | vjp_ref = rosenbrock_tess.vector_jacobian_product( 133 | inputs=dict(x=x, y=y), 134 | vjp_inputs=["x", "y"], 135 | vjp_outputs=["result"], 136 | cotangent_vector=primal_ref, 137 | ) 138 | # JAX vjp returns a flat tuple, so unflatten it to match the Tesseract output (dict with keys vjp_inputs) 139 | vjp = {"x": vjp[0], "y": vjp[1]} 140 | _assert_pytree_isequal(vjp, vjp_ref) 141 | 142 | # Test against direct implementation 143 | primal_raw, f_vjp_raw = jax.vjp(rosenbrock_raw, x, y) 144 | if use_jit: 145 | f_vjp_raw = jax.jit(f_vjp_raw) 146 | vjp_raw = f_vjp_raw(primal_raw) 147 | vjp_raw = {"x": vjp_raw[0], "y": vjp_raw[1]} 148 | _assert_pytree_isequal(vjp, vjp_raw) 149 | 150 | 151 | @pytest.mark.parametrize("use_jit", [True, False]) 152 | def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit): 153 | nested_tess = Tesseract(served_nested_tesseract_raw) 154 | a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") 155 | v, w = ( 156 | np.array([1.0, 2.0, 3.0], dtype="float32"), 157 | np.array([5.0, 7.0, 9.0], dtype="float32"), 158 | ) 159 | 160 | def f(a, v, s, i): 161 | return apply_tesseract( 162 | nested_tess, 163 | inputs={ 164 | "scalars": {"a": a, "b": b}, 165 | "vectors": {"v": v, "w": w}, 166 | "other_stuff": {"s": s, "i": i, "f": 2.718}, 167 | }, 168 | ) 169 | 170 | if use_jit: 171 | f = jax.jit(f, static_argnames=["s", "i"]) 172 | 173 | result = f(a, v, "hello", 3) 174 | result_ref = nested_tess.apply( 175 | inputs={ 176 | "scalars": {"a": a, "b": b}, 177 | "vectors": {"v": v, "w": w}, 178 | "other_stuff": {"s": "hello", 
"i": 3, "f": 2.718}, 179 | } 180 | ) 181 | _assert_pytree_isequal(result, result_ref) 182 | 183 | 184 | @pytest.mark.parametrize("use_jit", [True, False]) 185 | def test_nested_tesseract_jvp(served_nested_tesseract_raw, use_jit): 186 | nested_tess = Tesseract(served_nested_tesseract_raw) 187 | a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") 188 | v, w = ( 189 | np.array([1.0, 2.0, 3.0], dtype="float32"), 190 | np.array([5.0, 7.0, 9.0], dtype="float32"), 191 | ) 192 | 193 | def f(a, v): 194 | return apply_tesseract( 195 | nested_tess, 196 | inputs=dict( 197 | scalars={"a": a, "b": b}, 198 | vectors={"v": v, "w": w}, 199 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 200 | ), 201 | ) 202 | 203 | if use_jit: 204 | f = jax.jit(f) 205 | 206 | (primal, jvp) = jax.jvp(f, (a, v), (a, v)) 207 | 208 | primal_ref = nested_tess.apply( 209 | inputs=dict( 210 | scalars={"a": a, "b": b}, 211 | vectors={"v": v, "w": w}, 212 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 213 | ) 214 | ) 215 | _assert_pytree_isequal(primal, primal_ref) 216 | 217 | jvp_ref = nested_tess.jacobian_vector_product( 218 | inputs=dict( 219 | scalars={"a": a, "b": b}, 220 | vectors={"v": v, "w": w}, 221 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 222 | ), 223 | jvp_inputs=["scalars.a", "vectors.v"], 224 | jvp_outputs=["scalars.a", "vectors.v"], 225 | tangent_vector={"scalars.a": a, "vectors.v": v}, 226 | ) 227 | # JAX returns a nested dict, so we need to flatten it to match the Tesseract output (dict with keys jvp_outputs) 228 | jvp = {"scalars.a": jvp["scalars"]["a"], "vectors.v": jvp["vectors"]["v"]} 229 | _assert_pytree_isequal(jvp, jvp_ref) 230 | 231 | 232 | @pytest.mark.parametrize("use_jit", [True, False]) 233 | def test_nested_tesseract_vjp(served_nested_tesseract_raw, use_jit): 234 | nested_tess = Tesseract(served_nested_tesseract_raw) 235 | 236 | a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") 237 | v, w = ( 238 | np.array([1.0, 2.0, 
3.0], dtype="float32"), 239 | np.array([5.0, 7.0, 9.0], dtype="float32"), 240 | ) 241 | 242 | def f(a, v): 243 | return apply_tesseract( 244 | nested_tess, 245 | inputs=dict( 246 | scalars={"a": a, "b": b}, 247 | vectors={"v": v, "w": w}, 248 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 249 | ), 250 | ) 251 | 252 | if use_jit: 253 | f = jax.jit(f) 254 | 255 | (primal, f_vjp) = jax.vjp(f, a, v) 256 | 257 | if use_jit: 258 | f_vjp = jax.jit(f_vjp) 259 | 260 | vjp = f_vjp(primal) 261 | 262 | primal_ref = nested_tess.apply( 263 | inputs=dict( 264 | scalars={"a": a, "b": b}, 265 | vectors={"v": v, "w": w}, 266 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 267 | ) 268 | ) 269 | _assert_pytree_isequal(primal, primal_ref) 270 | 271 | vjp_ref = nested_tess.vector_jacobian_product( 272 | inputs=dict( 273 | scalars={"a": a, "b": b}, 274 | vectors={"v": v, "w": w}, 275 | other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, 276 | ), 277 | vjp_inputs=["scalars.a", "vectors.v"], 278 | vjp_outputs=["scalars.a", "vectors.v"], 279 | cotangent_vector={ 280 | "scalars.a": primal_ref["scalars"]["a"], 281 | "vectors.v": primal_ref["vectors"]["v"], 282 | }, 283 | ) 284 | # JAX vjp returns a flat tuple, so unflatten it to match the Tesseract output (dict with keys vjp_inputs) 285 | vjp = {"scalars.a": vjp[0], "vectors.v": vjp[1]} 286 | _assert_pytree_isequal(vjp, vjp_ref) 287 | 288 | 289 | @pytest.mark.parametrize("use_jit", [True, False]) 290 | def test_partial_differentiation(served_univariate_tesseract_raw, use_jit): 291 | """Test that differentiation works correctly in cases where some inputs are constants.""" 292 | rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) 293 | x, y = np.array(0.0), np.array(0.0) 294 | 295 | def f(y): 296 | return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] 297 | 298 | if use_jit: 299 | f = jax.jit(f) 300 | 301 | # Test forward application 302 | result = f(y) 303 | result_ref = rosenbrock_tess.apply(dict(x=x, 
y=y))["result"] 304 | _assert_pytree_isequal(result, result_ref) 305 | 306 | # Test gradient 307 | grad = jax.grad(f)(y) 308 | grad_ref = rosenbrock_tess.vector_jacobian_product( 309 | inputs=dict(x=x, y=y), 310 | vjp_inputs=["y"], 311 | vjp_outputs=["result"], 312 | cotangent_vector=dict(result=1.0), 313 | )["y"] 314 | _assert_pytree_isequal(grad, grad_ref) 315 | 316 | 317 | def test_tesseract_as_jax_pytree(served_univariate_tesseract_raw): 318 | """Test that Tesseract can be used as a JAX PyTree.""" 319 | tess = Tesseract(served_univariate_tesseract_raw) 320 | 321 | @jax.jit 322 | def f(x, y, tess): 323 | return apply_tesseract(tess, inputs=dict(x=x, y=y))["result"] 324 | 325 | x, y = np.array(0.0), np.array(0.0) 326 | result = f(x, y, tess) 327 | result_ref = rosenbrock_impl(x, y) 328 | _assert_pytree_isequal(result, result_ref) 329 | -------------------------------------------------------------------------------- /tests/univariate_tesseract/tesseract_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pasteur Labs. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import jax 5 | from pydantic import BaseModel, Field 6 | from tesseract_core.runtime import Differentiable, Float64, ShapeDType 7 | 8 | jax.config.update("jax_enable_x64", True) 9 | 10 | 11 | def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float: 12 | return (a - x) ** 2 + b * (y - x**2) ** 2 13 | 14 | 15 | # 16 | # Schemas 17 | # 18 | 19 | 20 | class InputSchema(BaseModel): 21 | x: Differentiable[Float64] = Field(description="Scalar value x.", default=0.0) 22 | y: Differentiable[Float64] = Field(description="Scalar value y.", default=0.0) 23 | a: Float64 = Field(description="Scalar parameter a.", default=1.0) 24 | b: Float64 = Field(description="Scalar parameter b.", default=100.0) 25 | 26 | 27 | class OutputSchema(BaseModel): 28 | result: Differentiable[Float64] = Field( 29 | description="Result of Rosenbrock function evaluation." 30 | ) 31 | 32 | 33 | # 34 | # Required endpoints 35 | # 36 | 37 | 38 | def apply(inputs: InputSchema) -> OutputSchema: 39 | """Evaluates the Rosenbrock function given input values and parameters.""" 40 | result = rosenbrock(inputs.x, inputs.y, a=inputs.a, b=inputs.b) 41 | return OutputSchema(result=result) 42 | 43 | 44 | # 45 | # Optional endpoints 46 | # 47 | 48 | 49 | def jacobian( 50 | inputs: InputSchema, 51 | jac_inputs: set[str], 52 | jac_outputs: set[str], 53 | ): 54 | rosenbrock_signature = ["x", "y", "a", "b"] 55 | 56 | jac_result = {dy: {} for dy in jac_outputs} 57 | for dx in jac_inputs: 58 | grad_func = jax.jacrev(rosenbrock, argnums=rosenbrock_signature.index(dx)) 59 | for dy in jac_outputs: 60 | jac_result[dy][dx] = grad_func(inputs.x, inputs.y, inputs.a, inputs.b) 61 | 62 | return jac_result 63 | 64 | 65 | def jacobian_vector_product( 66 | inputs: InputSchema, 67 | jvp_inputs: set[str], 68 | jvp_outputs: set[str], 69 | tangent_vector, 70 | ): 71 | # NOTE: This is a naive implementation of JVP, which is not efficient. 
72 | jac = jacobian(inputs, jvp_inputs, jvp_outputs) 73 | out = {} 74 | for dy in jvp_outputs: 75 | out[dy] = sum(jac[dy][dx] * tangent_vector[dx] for dx in jvp_inputs) 76 | return out 77 | 78 | 79 | def vector_jacobian_product( 80 | inputs: InputSchema, 81 | vjp_inputs: set[str], 82 | vjp_outputs: set[str], 83 | cotangent_vector, 84 | ): 85 | # NOTE: This is a naive implementation of VJP, which is not efficient. 86 | jac = jacobian(inputs, vjp_inputs, vjp_outputs) 87 | out = {} 88 | for dx in vjp_inputs: 89 | out[dx] = sum(jac[dy][dx] * cotangent_vector[dy] for dy in vjp_outputs) 90 | return out 91 | 92 | 93 | def abstract_eval(abstract_inputs): 94 | """Calculate output shape of apply from the shape of its inputs.""" 95 | return {"result": ShapeDType(shape=(), dtype="float64")} 96 | -------------------------------------------------------------------------------- /tests/univariate_tesseract/tesseract_config.yaml: -------------------------------------------------------------------------------- 1 | name: "univariate" 2 | version: "0.1.0" 3 | description: | 4 | Unit tesseract that evaluates the rosenbrock function and derivatives. 5 | -------------------------------------------------------------------------------- /tests/univariate_tesseract/tesseract_requirements.txt: -------------------------------------------------------------------------------- 1 | jax[cpu]==0.4.28 2 | --------------------------------------------------------------------------------