├── .copier-answers.yml ├── .github └── workflows │ ├── build.yaml │ ├── docs.yaml │ └── publish.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── LICENSE ├── README.md ├── docs ├── index.md └── reference │ └── torch_jax_interop.md ├── mkdocs.yaml ├── pyproject.toml ├── torch_jax_interop ├── __init__.py ├── conftest.py ├── docs_test.py ├── py.typed ├── to_jax.py ├── to_jax_module.py ├── to_jax_module_test.py ├── to_jax_test.py ├── to_torch.py ├── to_torch_module.py ├── to_torch_module_test.py ├── to_torch_test.py ├── types.py └── utils.py └── uv.lock /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier 2 | _commit: v0.0.3-2-ge44104f 3 | _src_path: gh:lebrice/tool_template 4 | project_description: Simple tools to mix and match PyTorch and Jax - Get the best 5 | of both worlds! 6 | python_version: '3.12' 7 | tool_name: torch_jax_interop 8 | your_email: fabrice.normandin@gmail.com 9 | your_name: Fabrice Normandin 10 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: 9 | - master 10 | pull_request: 11 | 12 | permissions: 13 | contents: read 14 | 15 | # https://stackoverflow.com/a/72408109/6388696 16 | # https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-concurrency-to-cancel-any-in-progress-job-or-run 17 | concurrency: 18 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | linting: 23 | name: Run linting/pre-commit checks 24 | runs-on: ubuntu-latest 25 | timeout-minutes: 5 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: actions/setup-python@v4 29 | with: 30 | python-version: "3.12" 31 | - run: pip install 'pre-commit<4.0.0' 32 | - run: pre-commit --version 33 | - run: pre-commit install 34 | - run: pre-commit run --all-files --show-diff-on-failure 35 | 36 | check_docs: 37 | needs: [linting] 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v4 41 | - name: Install the latest version of uv 42 | uses: astral-sh/setup-uv@v3 43 | with: 44 | version: "latest" 45 | enable-cache: true 46 | # https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token 47 | github-token: ${{ secrets.GITHUB_TOKEN }} 48 | cache-suffix: "3.12" 49 | - name: Pin python-version 3.12 50 | run: uv python pin 3.12 51 | - name: Install dependencies 52 | run: uv sync --frozen 53 | - name: Build the documentation (strict mode) 54 | run: uv run mkdocs build --strict 55 | 56 | unit_tests: 57 | needs: [linting] 58 | runs-on: ${{ matrix.platform }} 59 | strategy: 60 | max-parallel: 4 61 | matrix: 62 | platform: ["ubuntu-latest", "macos-latest"] 63 | python-version: ["3.12"] 64 | steps: 65 | - uses: actions/checkout@v4 66 | - name: Install the latest version of uv 67 | uses: astral-sh/setup-uv@v3 68 | with: 69 | version: "latest" 70 | enable-cache: true 71 | # https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token 72 | github-token: ${{ secrets.GITHUB_TOKEN }} 73 | cache-suffix: ${{ 
matrix.python-version }} 74 | - name: Pin python-version ${{ matrix.python-version }} 75 | run: uv python pin ${{ matrix.python-version }} 76 | - name: Install dependencies 77 | run: uv sync --frozen 78 | - name: Test with pytest 79 | run: uv run pytest -v --cov=torch_jax_interop --cov-report=xml --cov-append --gen-missing 80 | - name: Store coverage report as an artifact 81 | uses: actions/upload-artifact@v4 82 | with: 83 | name: coverage-reports-unit-tests-${{ matrix.platform }}-${{ matrix.python-version }} 84 | path: ./coverage.xml 85 | 86 | # https://about.codecov.io/blog/uploading-code-coverage-in-a-separate-job-on-github-actions/ 87 | upload-coverage-codecov: 88 | needs: [unit_tests] 89 | runs-on: ubuntu-latest 90 | name: Upload coverage reports to Codecov 91 | timeout-minutes: 5 92 | steps: 93 | - name: Checkout 94 | uses: actions/checkout@v4 95 | - name: Download artifacts 96 | uses: actions/download-artifact@v4 97 | with: 98 | pattern: coverage-reports-* 99 | merge-multiple: false 100 | # download all the artifacts in this directory (each .coverage.xml will be in a subdirectory) 101 | # Next step if this doesn't work would be to give the coverage files a unique name and use merge-multiple: true 102 | path: coverage_reports 103 | - name: Upload coverage reports to Codecov 104 | uses: codecov/codecov-action@v4 105 | with: 106 | token: ${{ secrets.CODECOV_TOKEN }} 107 | directory: coverage_reports 108 | fail_ci_if_error: true 109 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Publish docs via GitHub Pages 2 | on: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | build: 9 | name: Deploy docs 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v4 14 | - name: Install the latest version of uv 15 | uses: astral-sh/setup-uv@v3 16 | with: 17 | version: "latest" 18 | enable-cache: true # no need, uses the local uv cache. 19 | # https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token 20 | github-token: ${{ secrets.GITHUB_TOKEN }} 21 | cache-suffix: "3.12" 22 | 23 | - name: Pin python-version 24 | run: uv python pin 3.12 25 | 26 | - name: Install dependencies 27 | run: uv sync --extra docs --frozen 28 | 29 | - name: Deploy docs 30 | run: uv run mkdocs gh-deploy --force 31 | # note: Checking if we really need the one below: 32 | # uses: mhausenblas/mkdocs-deploy-gh-pages@1.9 33 | # # Or use mhausenblas/mkdocs-deploy-gh-pages@nomaterial to build without the mkdocs-material theme 34 | # env: 35 | # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 36 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Poetry when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | on: 9 | release: 10 | types: [published] 11 | workflow_dispatch: 12 | jobs: 13 | # https://docs.pypi.org/trusted-publishers/using-a-publisher/ 14 | publish: 15 | strategy: 16 | matrix: 17 | python-version: [3.12] 18 | os: [ubuntu-latest] 19 | runs-on: ${{ matrix.os }} 20 | environment: release 21 | permissions: 22 | # IMPORTANT: this permission is mandatory for trusted publishing 23 | id-token: write 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Install the latest version of uv 28 | uses: astral-sh/setup-uv@v3 29 | with: 30 | version: "latest" 31 | enable-cache: true 32 | # https://github.com/astral-sh/setup-uv?tab=readme-ov-file#github-authentication-token 33 | github-token: ${{ secrets.GITHUB_TOKEN }} 34 | cache-suffix: ${{ matrix.python-version }} 35 | - name: Pin python-version ${{ matrix.python-version }} 36 | run: uv python pin ${{ matrix.python-version }} 37 | - name: Install dependencies 38 | run: uv sync --frozen 39 | 40 | - name: Build package 41 | run: | 42 | uv build 43 | 44 | - name: Publish package distributions to PyPI 45 | uses: pypa/gh-action-pypi-publish@release/v1 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # Benchmarks created by `pytest-benchmark`. 165 | .benchmarks 166 | .vscode 167 | 168 | # Ignore tensor regression files. 169 | *.npz 170 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | require_serial: true 11 | - id: end-of-file-fixer 12 | require_serial: true 13 | # - id: check-docstring-first 14 | - id: check-added-large-files 15 | require_serial: true 16 | exclude: uv.lock 17 | - id: check-ast 18 | require_serial: true 19 | - id: check-yaml 20 | require_serial: true 21 | exclude: "mkdocs.yml" 22 | - id: debug-statements 23 | require_serial: true 24 | - id: detect-private-key 25 | require_serial: true 26 | - id: check-executables-have-shebangs 27 | require_serial: true 28 | - id: check-toml 29 | require_serial: true 30 | - id: check-case-conflict 31 | require_serial: true 32 | 33 | - repo: https://github.com/charliermarsh/ruff-pre-commit 34 | # Ruff version. 35 | rev: "v0.8.4" 36 | hooks: 37 | - id: ruff 38 | args: ["--fix"] 39 | require_serial: true 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.7.5 44 | hooks: 45 | - id: docformatter 46 | language: python 47 | args: [--in-place] 48 | require_serial: true 49 | 50 | # NOTE: Disabling this, since I'm having the glib-c2.29 weird bug. 
51 | # # yaml formatting 52 | # - repo: https://github.com/pre-commit/mirrors-prettier 53 | # rev: v2.7.1 54 | # hooks: 55 | # - id: prettier 56 | # types: [yaml] 57 | 58 | # jupyter notebook cell output clearing 59 | - repo: https://github.com/kynan/nbstripout 60 | rev: 0.8.1 61 | hooks: 62 | - id: nbstripout 63 | require_serial: true 64 | 65 | # md formatting 66 | # - repo: https://github.com/executablebooks/mdformat 67 | # rev: 0.7.20 68 | # hooks: 69 | # - id: mdformat 70 | # exclude: "docs/" # terrible, I know, but it's messing up everything with mkdocs fences! 71 | # args: ["--number"] 72 | # additional_dependencies: 73 | # - mdformat-gfm 74 | # - mdformat-tables 75 | # - mdformat_frontmatter 76 | # - mdformat-toc 77 | # - mdformat-config 78 | # - mdformat-black 79 | # # see https://github.com/KyleKing/mdformat-mkdocs 80 | # # Doesn't seem to work! 81 | # - mdformat-mkdocs[recommended]>=2.1.0 82 | # require_serial: true 83 | 84 | # word spelling linter 85 | - repo: https://github.com/codespell-project/codespell 86 | rev: v2.3.0 87 | hooks: 88 | - id: codespell 89 | args: 90 | - --skip=logs/**,data/** 91 | # - --ignore-words-list=abc,def 92 | require_serial: true 93 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fabrice Normandin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch \<-> Jax Interop Utilities 2 | 3 | Hey, you there! 4 | 5 | - Do you use PyTorch, but are curious about Jax (or vice-versa)? Would you prefer to start adding some (Jax/PyTorch) progressively into your projects rather than to start from scratch? 6 | - Want to avoid the pain of rewriting a model from an existing PyTorch codebase in Jax (or vice-versa)? 7 | - Do you like the performance benefits of Jax, but aren't prepared to sacrifice your nice PyTorch software frameworks (e.g. [Lightning](https://lightning.ai/docs/pytorch/stable/))? 
8 | 
9 | **Well I have some good news for you!** 
10 | You can have it all: Sweet, sweet jit-ed functions and automatic differentiation from Jax, as well as mature, widely-used frameworks from the PyTorch software ecosystem. 
11 | 
12 | ## What this does 
13 | 
14 | This package contains a few utility functions to simplify interoperability between jax and torch: `torch_to_jax`, `jax_to_torch`, `WrappedJaxFunction`, `torch_module_to_jax`. 
15 | 
16 | This repository contains utilities for converting PyTorch Tensors to JAX arrays and vice versa. 
17 | This conversion happens thanks to the `dlpack` format, which is a common format for exchanging tensors between different deep learning frameworks. Crucially, this format allows for zero-copy\* tensor sharing between PyTorch and JAX. 
18 | 
19 | > \* Note: For some torch tensors with specific memory layouts, for example channels-first image tensors, Jax will refuse to read the array from the dlpack, so we flatten and unflatten the data when converting, which might involve a copy. This is currently displayed as a warning on the command-line (see the short conversion sketch just before the Usage section below). 
20 | 
21 | ## Installation 
22 | 
23 | We would **highly** recommend you use [uv](https://docs.astral.sh/uv/) to manage your project dependencies. This greatly helps avoid CUDA dependency conflicts between PyTorch and Jax. 
24 | 
25 | ```bash 
26 | uv add torch-jax-interop 
27 | ``` 
28 | 
29 | Otherwise, if you don't use `uv`: 
30 | 
31 | ```bash 
32 | pip install torch-jax-interop 
33 | ``` 
34 | 
35 | > This package only depends on the base (CPU) version of Jax by default. 
36 | > If you want to also install the GPU version of jax, use `uv add torch-jax-interop[gpu]` or `uv add jax[cuda12]` directly (or the pip equivalents). 
37 | 
38 | ## Comparable projects 
39 | 
40 | - https://github.com/lucidrains/jax2torch: Seems to be the first minimal prototype for something like this. Supports jax2torch for functions, but not the other way around. 
41 | - https://github.com/subho406/pytorch2jax: Very similar. The way we convert `torch.nn.Module`s to `jax.custom_vjp` is actually based on their implementation, with some additions (support for jitting, along with more flexible input/output signatures). 
42 | - https://github.com/samuela/torch2jax: Takes a different approach: using a `torch.Tensor` subclass and `__torch_function__`. 
43 | - https://github.com/rdyro/torch2jax: Just found this, seems to have very good support for the torch to jax conversion, but not the other way around. Has additional features like specifying the depth (levels of derivatives). 
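To make the conversion described under "What this does" a bit more concrete, here is a minimal sketch (CPU-only, using nothing beyond the two conversion functions named above; whether a warning is logged or a copy is made depends on the tensor's memory layout and your jax configuration):

```python
import torch

from torch_jax_interop import jax_to_torch, torch_to_jax

# Shapes are preserved by the conversion.
image = torch.randn(1, 3, 32, 32)   # a channels-first image batch, as in typical torch code
jax_image = torch_to_jax(image)     # may log the flatten/unflatten warning mentioned above
assert jax_image.shape == (1, 3, 32, 32)

# With jax's default (x64-disabled) configuration, 64-bit torch tensors come back as
# 32-bit jax arrays (e.g. torch.int64 -> int32), as shown in the package docstrings.
indices = torch.arange(5)           # torch.int64
jax_indices = torch_to_jax(indices)

# And back to torch:
torch_image_again = jax_to_torch(jax_image)
assert isinstance(torch_image_again, torch.Tensor)
```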
44 | 
45 | ## Usage 
46 | 
47 | ```python 
48 | import torch 
49 | import jax.numpy as jnp 
50 | from torch_jax_interop import jax_to_torch, torch_to_jax 
51 | ``` 
52 | 
53 | Converting `torch.Tensor`s into `jax.Array`s: 
54 | 
55 | ```python 
56 | import jax 
57 | import torch 
58 | 
59 | tensors = { 
60 |     "x": torch.randn(5), 
61 |     "y": torch.arange(5), 
62 | } 
63 | 
64 | jax_arrays = jax.tree.map(torch_to_jax, tensors) 
65 | torch_tensors = jax.tree.map(jax_to_torch, jax_arrays) 
66 | ``` 
67 | 
68 | Passing `torch.Tensor`s to a Jax function (and `jax.Array`s to a Torch function): 
69 | 
70 | ```python 
71 | @jax_to_torch 
72 | def some_jax_function(x: jnp.ndarray) -> jnp.ndarray: 
73 |     return x + jnp.ones_like(x) 
74 | 
75 | 
76 | torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
77 | some_torch_tensor = torch.arange(5, device=torch_device) 
78 | 
79 | torch_output = some_jax_function(some_torch_tensor) 
80 | 
81 | 
82 | some_jax_array = jnp.arange(5) 
83 | 
84 | 
85 | @torch_to_jax 
86 | def some_torch_function(x: torch.Tensor) -> torch.Tensor: 
87 |     return x + torch.ones_like(x) 
88 | 
89 | 
90 | print(some_torch_function(some_jax_array)) 
91 | ``` 
92 | 
93 | ## Examples 
94 | 
95 | ### Jax to Torch nn.Module 
96 | 
97 | Suppose we have some jax function we'd like to use in a PyTorch model: 
98 | 
99 | ```python 
100 | import jax 
101 | import jax.numpy as jnp 
102 | 
103 | 
104 | def some_jax_function(params: jax.Array, x: jax.Array): 
105 |     """Some toy function that takes in some parameters and an input vector.""" 
106 |     return jnp.dot(x, params) 
107 | ``` 
108 | 
109 | By importing this: 
110 | 
111 | ```python 
112 | from torch_jax_interop import WrappedJaxFunction 
113 | ``` 
114 | 
115 | We can then wrap this jax function into a torch.nn.Module with learnable parameters: 
116 | 
117 | ```python 
118 | import torch 
119 | import torch.nn 
120 | 
121 | module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1))) 
122 | module = module.to("cpu") # jax arrays are on GPU by default, moving them to CPU for this example. 
123 | ``` 
124 | 
125 | The parameters are now learnable parameters of the module: 
126 | 
127 | ```python 
128 | dict(module.state_dict()) 
129 | {"params.0": tensor([[-0.7848], [0.8564]])} 
130 | ``` 
131 | 
132 | You can use this just like any other torch.nn.Module: 
133 | 
134 | ```python 
135 | x, y = torch.randn(2), torch.rand(1) 
136 | output = module(x) 
137 | loss = torch.nn.functional.mse_loss(output, y) 
138 | loss.backward() 
139 | 
140 | model = torch.nn.Sequential( 
141 |     torch.nn.Linear(123, 2), 
142 |     module, 
143 | ) 
144 | ``` 
145 | 
146 | Same goes for `flax.linen.Module`s: you can now use them in your torch forward / backward pass: 
147 | 
148 | ```python 
149 | import flax.linen 
150 | 
151 | 
152 | class Classifier(flax.linen.Module): 
153 |     num_classes: int = 10 
154 | 
155 |     @flax.linen.compact 
156 |     def __call__(self, x: jax.Array): 
157 |         x = x.reshape((x.shape[0], -1))  # flatten 
158 |         x = flax.linen.Dense(features=256)(x) 
159 |         x = flax.linen.relu(x) 
160 |         x = flax.linen.Dense(features=self.num_classes)(x) 
161 |         return x 
162 | 
163 | x = jax.random.uniform(key=jax.random.key(0), shape=(16, 28, 28, 1))  # example input 
164 | jax_module = Classifier(num_classes=10) 
165 | jax_params = jax_module.init(jax.random.key(0), x) 
166 | 
167 | from torch_jax_interop import WrappedJaxFunction 
168 | 
169 | torch_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params) 
170 | ``` 
171 | 
172 | ### Torch nn.Module to jax function 
173 | 
174 | ```python 
175 | >>> import torch 
176 | >>> import jax 
177 | >>> from torch_jax_interop import torch_module_to_jax 
178 | >>> model = torch.nn.Linear(3, 2, device="cuda") 
179 | >>> apply_fn, params = torch_module_to_jax(model) 
180 | 
181 | 
182 | >>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: 
183 | ...     y_pred = apply_fn(params, x) 
184 | ...     return jax.numpy.mean((y - y_pred) ** 2) 
185 | 
186 | 
187 | >>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3)) 
188 | >>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1)) 
189 | 
190 | >>> loss, grad = jax.value_and_grad(loss_function)(params, x, y) 
191 | >>> loss 
192 | Array(0.3944674, dtype=float32) 
193 | >>> grad 
194 | (Array([[-0.46541408, -0.15171866, -0.30520514], 
195 |         [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32)) 
196 | ``` 
197 | 
198 | To use `jax.jit` on the model, you need to pass an example of an output so we can 
199 | tell the JIT compiler the output shapes and dtypes to expect: 
200 | 
201 | ```python 
202 | >>> # here we reuse the same model as before: 
203 | >>> apply, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device="cuda")) 
204 | >>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: 
205 | ...     y_pred = apply(params, x) 
206 | ...     return jax.numpy.mean((y - y_pred) ** 2) 
207 | >>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y) 
208 | >>> loss 
209 | Array(0.3944674, dtype=float32) 
210 | >>> grad 
211 | (Array([[-0.46541408, -0.15171866, -0.30520514], 
212 |         [-0.7201077 , -0.23474531, -0.47222584]], dtype=float32), Array([-0.4821338, -0.7459771], dtype=float32)) 
213 | ``` 
214 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # torch_jax_interop 
2 | 
3 | ## Installation 
4 | 
5 | 1. (optional) Install uv: https://docs.astral.sh/uv/getting-started/installation/ 
6 | 
7 | 2. 
Install this package: 8 | 9 | ```console 10 | uv add torch_jax_interop 11 | ``` 12 | 13 | ## Usage 14 | 15 | ::: torch_jax_interop 16 | options: 17 | show_submodules: false 18 | members: [] 19 | -------------------------------------------------------------------------------- /docs/reference/torch_jax_interop.md: -------------------------------------------------------------------------------- 1 | ::: torch_jax_interop 2 | -------------------------------------------------------------------------------- /mkdocs.yaml: -------------------------------------------------------------------------------- 1 | site_name: torch_jax_interop Documentation 2 | site_url: https://lebrice.github.io/torch_jax_interop 3 | theme: 4 | name: material 5 | 6 | markdown_extensions: 7 | - pymdownx.highlight: # https://squidfunk.github.io/mkdocs-material/reference/code-blocks/#configuration 8 | anchor_linenums: true 9 | line_spans: __span 10 | pygments_lang_class: true 11 | - pymdownx.inlinehilite 12 | - pymdownx.snippets 13 | - pymdownx.superfences 14 | - pymdownx.magiclink 15 | 16 | plugins: 17 | - search 18 | - mkdocstrings: 19 | handlers: 20 | python: 21 | import: 22 | - https://docs.python.org/3/objects.inv 23 | - https://docs.pytest.org/en/stable/objects.inv 24 | - https://flax.readthedocs.io/en/latest/objects.inv 25 | - https://pytorch.org/docs/stable/objects.inv 26 | - https://jax.readthedocs.io/en/latest/objects.inv 27 | options: 28 | docstring_style: numpy 29 | members_order: source 30 | annotations_path: brief 31 | show_docstring_attributes: true 32 | modernize_annotations: true 33 | show_source: false 34 | show_submodules: false 35 | separate_signature: true 36 | signature_crossrefs: true 37 | show_signature_annotations: true 38 | allow_inspection: true 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "torch_jax_interop" 3 | description = "Simple tools to mix and match PyTorch and Jax - Get the best of both worlds!" 
4 | readme = "README.md" 5 | authors = [ 6 | { name = "Fabrice Normandin", email = "fabrice.normandin@gmail.com" }, 7 | ] 8 | requires-python = ">=3.10" 9 | dependencies = ["jax>=0.4.28", "torch>=2.0.0"] 10 | dynamic = ["version"] 11 | 12 | [project.optional-dependencies] 13 | gpu = ["jax[cuda12]>=0.4.28; sys_platform == 'linux'"] 14 | 15 | 16 | [dependency-groups] 17 | dev = [ 18 | "pytest>=8.3.3", 19 | "pytest-cov>=5.0.0", 20 | "uv-dynamic-versioning>=0.2.0", 21 | "mkdocs-material>=9.5.44", 22 | "flax>=0.10.2", 23 | "tensor-regression>=0.0.8", 24 | "mktestdocs>=0.2.4", 25 | "pytest-benchmark>=5.1.0", 26 | "pytest-env>=1.1.5", 27 | "mkdocstrings[python]>=0.27.0", 28 | "black>=24.10.0", 29 | ] 30 | 31 | [tool.pytest.ini_options] 32 | testpaths = ["torch_jax_interop"] 33 | norecursedirs = [".venv"] 34 | addopts = ["--doctest-modules"] 35 | 36 | [tool.pytest_env] 37 | CUBLAS_WORKSPACE_CONFIG = ":4096:8" 38 | 39 | [tool.ruff] 40 | line-length = 99 41 | 42 | [tool.docformatter] 43 | wrap-summaries = 99 44 | wrap-descriptions = 99 45 | 46 | [tool.uv] 47 | managed = true 48 | 49 | [tool.uv-dynamic-versioning] 50 | vcs = "git" 51 | style = "semver" 52 | 53 | [build-system] 54 | requires = ["hatchling", "uv-dynamic-versioning"] 55 | build-backend = "hatchling.build" 56 | 57 | [tool.hatch.build.targets.wheel] 58 | packages = ["torch_jax_interop"] 59 | 60 | [tool.hatch.version] 61 | source = "uv-dynamic-versioning" 62 | -------------------------------------------------------------------------------- /torch_jax_interop/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools to help interoperability between PyTorch and Jax code. 2 | 3 | ## Examples 4 | 5 | ### Converting [torch.Tensor][]s into [jax.Array][]s and vice-versa: 6 | 7 | ```python 8 | import jax 9 | import torch 10 | from torch_jax_interop import torch_to_jax, jax_to_torch 11 | tensors = { 12 | "x": torch.randn(5), 13 | "y": torch.arange(5), 14 | } 15 | 16 | jax_arrays = jax.tree.map(torch_to_jax, tensors) 17 | print(jax_arrays) 18 | # {'x': Array([-0.11146712, 0.12036294, -0.3696345 , -0.24041797, -1.1969243 ], dtype=float32), 19 | # 'y': Array([0, 1, 2, 3, 4], dtype=int32)} 20 | 21 | torch_tensors = jax.tree.map(jax_to_torch, jax_arrays) 22 | print(torch_tensors) 23 | # {'x': tensor([-0.1115, 0.1204, -0.3696, -0.2404, -1.1969]), 24 | # 'y': tensor([0, 1, 2, 3, 4], dtype=torch.int32)} 25 | ``` 26 | 27 | ### Using a Jax function from PyTorch: 28 | 29 | ```python 30 | @jax_to_torch 31 | def some_wrapped_jax_function(x: jax.Array) -> jax.Array: 32 | return x + jax.numpy.ones_like(x) 33 | 34 | torch_input = torch.arange(5) 35 | torch_output = some_wrapped_jax_function(torch_input) 36 | print(torch_output) 37 | # tensor([1, 2, 3, 4, 5], dtype=torch.int32) 38 | ``` 39 | 40 | ### Using a Torch function from Jax: 41 | 42 | ```python 43 | @torch_to_jax 44 | def some_wrapped_torch_function(x: torch.Tensor) -> torch.Tensor: 45 | return x + torch.ones_like(x) 46 | 47 | jax_input = jax.numpy.arange(5) 48 | jax_output = some_wrapped_torch_function(jax_input) 49 | print(jax_output) 50 | # Array([1, 2, 3, 4, 5], dtype=int32) 51 | ``` 52 | 53 | ### Differentiating through a Jax function in PyTorch: 54 | 55 | ```python 56 | def some_jax_function(params: jax.Array, x: jax.Array): 57 | '''Some toy function that takes in some parameters and an input vector.''' 58 | return jax.numpy.dot(x, params) 59 | ``` 60 | 61 | By importing this: 62 | 63 | ```python 64 | from torch_jax_interop import WrappedJaxFunction 65 
| ``` 66 | 67 | We can then wrap this jax function into a torch.nn.Module with learnable parameters: 68 | 69 | ```python 70 | module = WrappedJaxFunction(some_jax_function, jax_params=jax.random.normal(jax.random.key(0), (2, 1))) 71 | module = module.to("cpu") # jax arrays are on GPU by default, moving them to CPU for this example. 72 | ``` 73 | 74 | The parameters are now learnable parameters of the module parameters: 75 | 76 | ```python 77 | print(dict(module.state_dict())) 78 | # {'params.0': tensor([[-0.7848], 79 | # [ 0.8564]])} 80 | ``` 81 | 82 | You can use this just like any other torch.nn.Module: 83 | 84 | ```python 85 | x, y = torch.randn(2), torch.rand(1) 86 | output = module(x) 87 | loss = torch.nn.functional.mse_loss(output, y) 88 | loss.backward() 89 | ``` 90 | 91 | This also works the same way for `flax.linen.Module`s: 92 | 93 | ```python 94 | import flax 95 | class JaxModule(flax.linen.Module): 96 | output_dims: int 97 | @flax.linen.compact 98 | def __call__(self, x: jax.Array): 99 | x = x.reshape((x.shape[0], -1)) # flatten 100 | x = flax.linen.Dense(features=256)(x) 101 | x = flax.linen.relu(x) 102 | x = flax.linen.Dense(features=self.output_dims)(x) 103 | return x 104 | 105 | 106 | x = jax.random.uniform(key=jax.random.key(0), shape=(16, 28, 28, 1)) 107 | jax_module = JaxModule(output_dims=10) 108 | jax_params = jax_module.init(jax.random.key(0), x) 109 | ``` 110 | 111 | You can still of course jit your Jax code: 112 | 113 | ```python 114 | wrapped_jax_module = WrappedJaxFunction(jax.jit(jax_module.apply), jax_params=jax_params) 115 | ``` 116 | 117 | And you can then use this jax module in PyTorch: 118 | 119 | ```python 120 | x = jax_to_torch(x) 121 | y = torch.randint(0, 10, (16,), device=x.device) 122 | logits = wrapped_jax_module(x) 123 | loss = torch.nn.functional.cross_entropy(logits, y, reduction="mean") 124 | loss.backward() 125 | print({name: p.grad.shape for name, p in wrapped_jax_module.named_parameters()}) 126 | # {'params.0': torch.Size([256]), 'params.1': torch.Size([784, 256]), 'params.2': torch.Size([10]), 'params.3': torch.Size([256, 10])} 127 | ``` 128 | """ 129 | 130 | from .to_jax import torch_to_jax 131 | from .to_jax_module import torch_module_to_jax 132 | from .to_torch import jax_to_torch 133 | from .to_torch_module import WrappedJaxFunction, WrappedJaxScalarFunction 134 | 135 | __all__ = [ 136 | "jax_to_torch", 137 | "torch_to_jax", 138 | "WrappedJaxFunction", 139 | "WrappedJaxScalarFunction", 140 | "torch_module_to_jax", 141 | ] 142 | -------------------------------------------------------------------------------- /torch_jax_interop/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import flax.linen 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import pytest 8 | import torch 9 | from tensor_regression.stats import get_simple_attributes 10 | 11 | from torch_jax_interop import torch_to_jax 12 | from torch_jax_interop.to_torch import jax_to_torch, jax_to_torch_device 13 | from torch_jax_interop.utils import to_channels_last 14 | 15 | 16 | # Add support for Jax arrays in the tensor regression fixture. 
17 | @get_simple_attributes.register(jax.Array) 18 | def jax_array_simple_attributes(array: jnp.ndarray, precision: int | None) -> dict: 19 | return get_simple_attributes(jax_to_torch(array), precision=precision) 20 | 21 | 22 | DEFAULT_SEED = 123 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def seed(request: pytest.FixtureRequest): 27 | """Fixture that seeds everything for reproducibility and yields the random seed used.""" 28 | random_seed = getattr(request, "param", DEFAULT_SEED) 29 | assert isinstance(random_seed, int) or random_seed is None 30 | 31 | random_state = random.getstate() 32 | np_random_state = np.random.get_state() 33 | with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): 34 | random.seed(random_seed) 35 | np.random.seed(random_seed) 36 | torch.manual_seed(random_seed) 37 | yield random_seed 38 | 39 | random.setstate(random_state) 40 | np.random.set_state(np_random_state) 41 | 42 | 43 | @pytest.fixture(scope="session", params=["cpu", "cuda", "rocm", "tpu"], ids="backend={}".format) 44 | def jax_device(request: pytest.FixtureRequest) -> jax.Device: 45 | backend_str = request.param 46 | try: 47 | devices = jax.devices(backend=request.param) 48 | except RuntimeError: 49 | devices = None 50 | if not devices: 51 | pytest.skip(f"No devices found for backend {backend_str}.") 52 | return devices[0] 53 | 54 | 55 | @pytest.fixture(scope="session") 56 | def torch_device(request: pytest.FixtureRequest, jax_device: jax.Device) -> torch.device: 57 | param = getattr(request, "param", None) 58 | # in case of an indirect parametrization, use the specified device: 59 | if param is not None: 60 | assert isinstance(param, str | torch.device) 61 | return torch.device(param) if isinstance(param, str) else param 62 | return jax_to_torch_device(jax_device) 63 | 64 | 65 | @pytest.fixture( 66 | scope="session", 67 | params=[ 68 | pytest.param((torch.float32, jnp.float32), id="float32"), 69 | pytest.param((torch.float64, jnp.float32), id="float64"), # important! 70 | pytest.param((torch.int32, jnp.int32), id="int32"), 71 | pytest.param((torch.int64, jnp.int32), id="int64"), # important! 72 | ], 73 | ) 74 | def torch_jax_dtypes(request: pytest.FixtureRequest): 75 | return request.param 76 | 77 | 78 | @pytest.fixture(scope="session") 79 | def torch_dtype(torch_jax_dtypes: tuple[torch.dtype, jnp.dtype]) -> torch.dtype: 80 | return torch_jax_dtypes[0] 81 | 82 | 83 | @pytest.fixture(scope="session") 84 | def jax_dtype(torch_jax_dtypes: tuple[torch.dtype, jnp.dtype]) -> jnp.dtype: 85 | return torch_jax_dtypes[1] 86 | 87 | 88 | @pytest.fixture 89 | def jax_input(torch_input: torch.Tensor): 90 | return torch_to_jax(torch_input) 91 | 92 | 93 | @pytest.fixture 94 | def torch_input(torch_device: torch.device, seed: int): 95 | input_shape: tuple[int, ...] = (1, 3, 32, 32) 96 | torch_input = torch.randn( 97 | input_shape, 98 | generator=torch.Generator(device=torch_device).manual_seed(seed), 99 | device=torch_device, 100 | ) 101 | return torch_input 102 | 103 | 104 | class JaxCNN(flax.linen.Module): 105 | """A simple CNN model. 
106 | 107 | Taken from 108 | https://flax.readthedocs.io/en/latest/quick_start.html#define-network 109 | """ 110 | 111 | num_classes: int = 10 112 | 113 | @flax.linen.compact 114 | def __call__(self, x: jax.Array): 115 | x = to_channels_last(x) 116 | x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) 117 | x = flax.linen.relu(x) 118 | x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 119 | x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x) 120 | x = flax.linen.relu(x) 121 | x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 122 | 123 | x = flatten(x) 124 | x = flax.linen.Dense(features=256)(x) 125 | x = flax.linen.relu(x) 126 | x = flax.linen.Dense(features=self.num_classes)(x) 127 | return x 128 | 129 | 130 | class TorchCNN(torch.nn.Sequential): 131 | def __init__(self, num_classes: int = 10): 132 | super().__init__( 133 | torch.nn.LazyConv2d(out_channels=32, kernel_size=3), 134 | torch.nn.ReLU(), 135 | torch.nn.AvgPool2d(kernel_size=2, stride=2), 136 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3), 137 | torch.nn.ReLU(), 138 | torch.nn.AvgPool2d(kernel_size=2, stride=2), 139 | torch.nn.Flatten(), 140 | torch.nn.LazyLinear(out_features=256), 141 | torch.nn.ReLU(), 142 | torch.nn.Linear(in_features=256, out_features=num_classes), 143 | ) 144 | 145 | 146 | def flatten(x: jax.Array) -> jax.Array: 147 | return x.reshape((x.shape[0], -1)) 148 | 149 | 150 | class JaxFcNet(flax.linen.Module): 151 | num_classes: int = 10 152 | 153 | @flax.linen.compact 154 | def __call__(self, x: jax.Array): 155 | x = flatten(x) 156 | x = flax.linen.Dense(features=256)(x) 157 | x = flax.linen.relu(x) 158 | x = flax.linen.Dense(features=self.num_classes)(x) 159 | return x 160 | 161 | 162 | class TorchFcNet(torch.nn.Sequential): 163 | def __init__(self, num_classes: int = 10): 164 | super().__init__( 165 | torch.nn.Flatten(), 166 | torch.nn.LazyLinear(out_features=256), 167 | torch.nn.ReLU(), 168 | torch.nn.Linear(in_features=256, out_features=num_classes), 169 | ) 170 | 171 | 172 | @pytest.fixture 173 | def num_classes(): 174 | return 10 175 | 176 | 177 | @pytest.fixture(autouse=True, scope="session") 178 | def make_torch_deterministic(): 179 | mode = torch.get_deterministic_debug_mode() 180 | 181 | torch.set_deterministic_debug_mode("error") 182 | yield 183 | torch.set_deterministic_debug_mode(mode) 184 | 185 | 186 | @pytest.fixture(params=[TorchCNN, TorchFcNet]) 187 | def torch_network( 188 | request: pytest.FixtureRequest, 189 | seed: int, 190 | torch_input: torch.Tensor, 191 | num_classes: int, 192 | torch_device: torch.device, 193 | ): 194 | torch_network_type: type[torch.nn.Module] = request.param 195 | with ( 196 | torch_device, 197 | torch.random.fork_rng([torch_device] if torch_device.type == "cuda" else []), 198 | ): 199 | torch_network = torch_network_type(num_classes=num_classes) 200 | # initialize any un-initialized parameters in the network by doing a forward pass 201 | # with a dummy input. 202 | torch_network(torch_input) 203 | return torch_network 204 | 205 | 206 | @pytest.fixture(params=[JaxCNN, JaxFcNet]) 207 | def jax_network_and_params( 208 | request: pytest.FixtureRequest, 209 | seed: int, 210 | jax_input: jax.Array, 211 | num_classes: int, 212 | jax_device: jax.Device, 213 | ): 214 | jax_network_type: type[flax.linen.Module] 215 | jax_network_type = request.param 216 | with jax.default_device(jax_device): 217 | # todo: fix channels_last vs channels_first issues automatically in torch_to_jax? 
218 | jax_network = jax_network_type(num_classes=num_classes) 219 | jax_params = jax_network.init(jax.random.key(seed), jax_input) 220 | return jax_network, jax_params 221 | -------------------------------------------------------------------------------- /torch_jax_interop/docs_test.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | 3 | import pytest 4 | from mktestdocs.__main__ import check_codeblock, check_raw_string 5 | 6 | import torch_jax_interop 7 | from torch_jax_interop.to_jax_module import torch_module_to_jax 8 | from torch_jax_interop.to_torch_module import ( 9 | WrappedJaxFunction, 10 | ) 11 | 12 | 13 | def patched_grab_code_blocks(docstring: str, lang="python"): 14 | """Given a docstring, grab all the markdown codeblocks found in docstring. 15 | 16 | Patch for a bug in mkdoctest with indenting: 17 | - https://github.com/koaning/mktestdocs/issues/19 18 | 19 | Arguments: 20 | docstring: the docstring to analyse 21 | lang: if not None, the language that is assigned to the codeblock 22 | """ 23 | docstring_lines = docstring.splitlines() 24 | docstring = ( 25 | docstring_lines[0] + "\n" + textwrap.dedent("\n".join(docstring_lines[1:])) 26 | ) 27 | in_block = False 28 | block = "" 29 | codeblocks = [] 30 | for idx, line in enumerate(docstring.split("\n")): 31 | if line.strip().startswith("```"): 32 | if in_block: 33 | codeblocks.append(check_codeblock(block, lang=lang)) 34 | block = "" 35 | in_block = not in_block 36 | if in_block: 37 | block += line + "\n" 38 | return [c for c in codeblocks if c != ""] 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "obj", 43 | [WrappedJaxFunction, torch_jax_interop, torch_module_to_jax], 44 | ids=lambda d: getattr(d, "__qualname__", d), 45 | ) 46 | def test_member(obj): 47 | all_code = "".join(patched_grab_code_blocks(obj.__doc__, lang="python")) 48 | assert all_code, (obj, obj.__doc__) 49 | check_raw_string(all_code, lang="python") 50 | 51 | # assert False, mktestdocs.grab_code_blocks(WrappedJaxFunction.__doc__) 52 | -------------------------------------------------------------------------------- /torch_jax_interop/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lebrice/torch_jax_interop/d4e113282c3ce5b9af4357481c1348a3f419f718/torch_jax_interop/py.typed -------------------------------------------------------------------------------- /torch_jax_interop/to_jax.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections.abc 4 | import contextlib 5 | import dataclasses 6 | import functools 7 | import logging 8 | import warnings 9 | from logging import getLogger as get_logger 10 | from typing import Any, Callable, overload 11 | 12 | import jax 13 | import jax.core 14 | import jaxlib 15 | import jaxlib.xla_extension 16 | import torch 17 | import torch.func 18 | import torch.utils._pytree 19 | from jax.dlpack import from_dlpack as jax_from_dlpack # type: ignore 20 | from torch.utils.dlpack import to_dlpack as torch_to_dlpack # type: ignore 21 | 22 | from .types import ( 23 | Dataclass, 24 | DataclassType, 25 | K, 26 | NestedDict, 27 | NestedMapping, 28 | ) 29 | from .utils import ( 30 | log_once, 31 | ) 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | @overload 37 | def torch_to_jax(value: torch.Tensor, /) -> jax.Array: 38 | ... 39 | 40 | 41 | @overload 42 | def torch_to_jax(value: torch.device, /) -> jax.Device: 43 | ... 
44 | 45 | 46 | @overload 47 | def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]: 48 | ... 49 | 50 | 51 | @overload 52 | def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]: 53 | ... 54 | 55 | 56 | @overload 57 | def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]: 58 | ... 59 | 60 | 61 | @overload 62 | def torch_to_jax(value: Any, /) -> Any: 63 | ... 64 | 65 | 66 | def torch_to_jax(value: Any, /) -> Any: 67 | """Converts PyTorch tensors to JAX arrays. 68 | 69 | Converts the tensors "in-place", without the need for copies or moving data to the CPU. 70 | 71 | Args: 72 | value: a torch tensor 73 | 74 | Returns: 75 | a JAX array 76 | """ 77 | log_once( 78 | logger, 79 | message=f"No registered handler for values of type {type(value)}, returning it as-is.", 80 | level=logging.DEBUG, 81 | ) 82 | return value 83 | 84 | 85 | torch_to_jax = functools.singledispatch(torch_to_jax) # type: ignore 86 | 87 | 88 | @torch_to_jax.register(type(None)) 89 | @torch_to_jax.register(int) 90 | @torch_to_jax.register(float) 91 | @torch_to_jax.register(str) 92 | @torch_to_jax.register(bool) 93 | @torch_to_jax.register(bytes) 94 | def no_op(v: Any) -> Any: 95 | return v 96 | 97 | 98 | def _direct_conversion(v: torch.Tensor) -> jax.Array: 99 | return jax_from_dlpack(v, copy=False) 100 | 101 | 102 | def _to_from_dlpack( 103 | v: torch.Tensor, ignore_deprecation_warning: bool = True 104 | ) -> jax.Array: 105 | with warnings.catch_warnings() if ignore_deprecation_warning else contextlib.nullcontext(): 106 | # Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated 107 | # conversion method for now. 108 | # todo: Should we let it though though? 109 | if ignore_deprecation_warning: 110 | warnings.filterwarnings("ignore", category=DeprecationWarning) 111 | return jax_from_dlpack(torch_to_dlpack(v), copy=False) 112 | 113 | 114 | def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array: 115 | """Converts a PyTorch Tensor into a jax.Array. 116 | 117 | NOTE: seems like torch.float64 tensors are implicitly converted to jax.float32 tensors? 118 | TODO: 119 | - If the tensor is on the GPU, then we can use the direct conversion with jax from_dlpack. 120 | Otherwise we might have to convert to/from dlpack, which is apparently being deprecated. 121 | - ALSO: this seems to happen when jitted code is calling a pure callback. Not sure if it happens in other cases too 122 | (e.g. just calling this with a jax tensor in non-jit mode). 123 | """ 124 | value = value.detach() 125 | 126 | if value.device.type == "cpu": 127 | try: 128 | # todo: Calling jax_from_dlpack with a cpu tensor causes issues in jax pure callbacks **later**, 129 | # when they are run by jax somehow. This causes issues when using a nn.Module in jax graph. 130 | # return _direct_conversion(value) 131 | return _to_from_dlpack(value, ignore_deprecation_warning=True) 132 | 133 | except jaxlib.xla_extension.XlaRuntimeError as err: 134 | log_once( 135 | logger, 136 | message=( 137 | f"Unable to view tensor of shape {tuple(value.shape)} as a jax.Array in-place:\n" 138 | f"'{err}'\n" 139 | f"Tensors of this shape will be flattened and unflattened (which may or " 140 | f"may not involve making a copy of the tensor's data)." 
141 | ), 142 | level=logging.WARNING, 143 | ) 144 | return _direct_conversion(value.flatten()).reshape(value.shape) 145 | 146 | try: 147 | return _direct_conversion(value) 148 | except jaxlib.xla_extension.XlaRuntimeError as err: 149 | log_once( 150 | logger, 151 | message=( 152 | f"Unable to view tensor of shape {tuple(value.shape)} as a jax.Array in-place:\n" 153 | f"'{err}'\n" 154 | f"Tensors of this shape will be flattened and unflattened (which may or " 155 | f"may not involve making a copy of the tensor's data)." 156 | ), 157 | level=logging.WARNING, 158 | ) 159 | return _direct_conversion(value.flatten()).reshape(value.shape) 160 | 161 | # NOTE: This may or may not involve making a copy of the tensor. 162 | # See https://pytorch.org/docs/stable/generated/torch.flatten.html#torch.flatten 163 | return torch_to_jax_tensor(value.flatten()).reshape(value.shape) 164 | 165 | 166 | # Register it like this so the type hints are preserved on the functions (which are also called 167 | # directly in some places). 168 | torch_to_jax.register(torch.Tensor, torch_to_jax_tensor) 169 | 170 | 171 | @torch_to_jax.register(tuple) 172 | def torch_to_jax_tuple(value: tuple) -> tuple: 173 | return type(value)(*[torch_to_jax(v) for v in value]) # type: ignore 174 | 175 | 176 | @torch_to_jax.register(list) 177 | def torch_to_jax_list(value: list) -> list: 178 | return list(torch_to_jax(v) for v in value) 179 | 180 | 181 | @torch_to_jax.register(collections.abc.Mapping) 182 | def torch_to_jax_dict( 183 | value: NestedMapping[K, torch.Tensor], 184 | ) -> NestedMapping[K, jax.Array]: 185 | """Converts a dict of PyTorch tensors into a dict of jax.Arrays.""" 186 | return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) # type: ignore 187 | 188 | 189 | @torch_to_jax.register(Dataclass) 190 | def torch_to_jax_dataclass(value: DataclassType) -> DataclassType: 191 | """Converts any torch Tensors in the dataclass fields to jax arrays.""" 192 | return type(value)(**torch_to_jax(dataclasses.asdict(value))) 193 | 194 | 195 | @torch_to_jax.register(torch.device) 196 | def torch_to_jax_device(torch_device: torch.device) -> jax.Device: 197 | if torch_device.type == "cuda": 198 | backend = "gpu" 199 | elif jax.default_backend() == "tpu": 200 | backend = "tpu" 201 | else: 202 | backend = "cpu" 203 | devices = jax.devices(backend=backend) 204 | if torch_device.type == "cuda": 205 | return devices[torch_device.index] 206 | else: 207 | torch_device.index 208 | return devices[0] 209 | 210 | 211 | @torch_to_jax.register(collections.abc.Callable) 212 | def torch_to_jax_callable(torch_callable: Callable) -> Callable: 213 | """Wraps a torch function so that it can be used from jax. 214 | 215 | NOTE: You shouldn't expect jax.jit or jax.grad to work through this torch function (at least 216 | for now). 217 | """ 218 | from .to_torch import jax_to_torch 219 | 220 | @functools.wraps(torch_callable) 221 | def _wrapped(*jax_args, **jax_kwargs): 222 | torch_args = [jax_to_torch(arg) for arg in jax_args] 223 | torch_kwargs = {k: jax_to_torch(v) for k, v in jax_kwargs.items()} 224 | torch_outputs = torch_callable(*torch_args, **torch_kwargs) 225 | return torch_to_jax(torch_outputs) 226 | 227 | return _wrapped 228 | -------------------------------------------------------------------------------- /torch_jax_interop/to_jax_module.py: -------------------------------------------------------------------------------- 1 | """Utility to wrap a nn.Module into a jax function with differentiation. 
2 | 3 | TODO: Maybe convert a torch.nn.Module into a flax.linen.Module? 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | import copy 9 | import functools 10 | from logging import getLogger as get_logger 11 | from typing import Callable, Concatenate, Iterable 12 | 13 | import jax 14 | import jax.core 15 | import torch 16 | import torch.func 17 | import torch.utils._pytree 18 | 19 | from torch_jax_interop.to_jax import torch_to_jax 20 | from torch_jax_interop.types import Module 21 | 22 | from .types import JaxPyTree, Out_cov, P, TorchPyTree 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | def make_functional( 28 | module_with_state: Module[P, Out_cov], disable_autograd_tracking=False 29 | ) -> tuple[ 30 | Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...] 31 | ]: 32 | """Backward compatibility equivalent for `functorch.make_functional` in the new torch.func API. 33 | 34 | Adapted from https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf as suggested by 35 | this: https://pytorch.org/docs/master/func.migrating.html#functorch-make-functional 36 | """ 37 | params_dict = dict(module_with_state.named_parameters()) 38 | param_names = params_dict.keys() 39 | params_values = tuple(params_dict.values()) 40 | if disable_autograd_tracking: 41 | params_values = tuple(map(torch.Tensor.detach, params_values)) 42 | 43 | stateless_module = copy.deepcopy(module_with_state) 44 | stateless_module.to(device="meta") 45 | 46 | def fmodel(parameters: Iterable[torch.Tensor], *args: P.args, **kwargs: P.kwargs): 47 | parameters = tuple(parameters) 48 | if len(parameters) != len(param_names): 49 | raise RuntimeError( 50 | f"The wrapped PyTorch model {stateless_module} expected " 51 | f"{len(param_names)} parameters in its inputs, but only received " 52 | f"{len(parameters)}." 53 | ) 54 | params_dict = dict(zip(param_names, parameters)) 55 | return torch.func.functional_call(stateless_module, params_dict, args, kwargs) # type: ignore 56 | 57 | return fmodel, params_values 58 | 59 | 60 | def torch_module_to_jax( 61 | model: Module[..., torch.Tensor], example_output: torch.Tensor | None = None 62 | ) -> tuple[jax.custom_vjp[jax.Array], tuple[jax.Array, ...]]: 63 | """Wrap a pytorch model to be used in a jax computation. 
64 | 65 | Copied and adapted from https://github.com/subho406/pytorch2jax/blob/main/pytorch2jax/pytorch2jax.py#L32 66 | 67 | Example 68 | ------- 69 | 70 | ```python 71 | import torch 72 | import jax 73 | torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 74 | torch.manual_seed(0) # doctest:+ELLIPSIS 75 | # 76 | model = torch.nn.Linear(3, 2, device=torch_device) 77 | wrapped_model, params = torch_module_to_jax(model) 78 | def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: 79 | y_pred = wrapped_model(params, x) 80 | return jax.numpy.mean((y - y_pred) ** 2) 81 | x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3)) 82 | y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1)) 83 | loss, grad = jax.value_and_grad(loss_function)(params, x, y) 84 | loss # doctest: +SKIP 85 | # Array(0.5914371, dtype=float32) 86 | grad # doctest: +SKIP 87 | # (Array([[-0.02565618, -0.00836356, -0.01682458], 88 | # [ 1.0495702 , 0.34214562, 0.68827784]], dtype=float32), Array([-0.02657786, 1.0872754 ], dtype=float32)) 89 | ``` 90 | 91 | To use `jax.jit` on the model, you need to pass an example of an output so we can 92 | tell the JIT compiler the output shapes and dtypes to expect: 93 | 94 | ```python 95 | # here we reuse the same model as before: 96 | wrapped_model, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device=torch_device)) 97 | def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: 98 | y_pred = wrapped_model(params, x) 99 | return jax.numpy.mean((y - y_pred) ** 2) 100 | loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y) 101 | loss # doctest: +SKIP 102 | # Array(0.5914371, dtype=float32) 103 | grad # doctest: +SKIP 104 | # (Array([[-0.02565618, -0.00836356, -0.01682458], 105 | # [ 1.0495702 , 0.34214562, 0.68827784]], dtype=float32), Array([-0.02657786, 1.0872754 ], dtype=float32)) 106 | ``` 107 | 108 | 109 | Parameters 110 | ---------- 111 | model : torch.nn.Module 112 | A Torch module. 113 | example_output : torch.Tensor | None, optional 114 | Example of an output of the model, used to specify the expected shapes and 115 | dtypes so that this computation can be jitted. 116 | 117 | Returns 118 | ------- 119 | the functional model and the model parameters (converted to jax arrays). 120 | """ 121 | 122 | if example_output is not None: 123 | example_output = jax.tree.map(torch_to_jax, example_output) 124 | 125 | from .to_torch import jax_to_torch 126 | 127 | def j2t(v: JaxPyTree) -> TorchPyTree: 128 | if any(isinstance(v_i, jax.core.Tracer) for v_i in jax.tree.leaves(v)): 129 | # running inside JIT. 130 | return jax.pure_callback( 131 | functools.partial(jax.tree.map, jax_to_torch), v, v, vectorized=True 132 | ) 133 | return jax.tree.map(jax_to_torch, v) 134 | 135 | def t2j(v: TorchPyTree) -> JaxPyTree: 136 | if any(isinstance(v_i, jax.core.Tracer) for v_i in jax.tree.leaves(v)): 137 | # running inside JIT. 138 | return jax.pure_callback( 139 | functools.partial(jax.tree.map, torch_to_jax), v, v, vectorized=True 140 | ) 141 | return jax.tree.map(torch_to_jax, v) 142 | 143 | # Convert the PyTorch model to a functional representation and extract the model function and parameters 144 | model_fn, model_params = make_functional(model) # type: ignore 145 | 146 | # Convert the model parameters from PyTorch to JAX representations 147 | jax_model_params: tuple[jax.Array, ...] 
= tuple(map(torch_to_jax, model_params)) 148 | 149 | # Define the apply function using a custom VJP 150 | @jax.custom_vjp 151 | def apply(params, *args, **kwargs): 152 | # Convert the input data from PyTorch to JAX representations 153 | # Apply the model function to the input data. 154 | if example_output is None: 155 | if any( 156 | isinstance(v, jax.core.Tracer) 157 | for v in jax.tree.leaves((params, args, kwargs)) 158 | ): 159 | raise RuntimeError( 160 | "You need to pass `example_output` in order to JIT the torch function!" 161 | ) 162 | params = j2t(params) 163 | args = j2t(args) 164 | kwargs = j2t(kwargs) 165 | out = model_fn(params, *args, **kwargs) 166 | # Convert the output data from JAX to PyTorch 167 | out = t2j(out) 168 | return out 169 | 170 | result_shape_dtypes = t2j(example_output) 171 | # idea: use `torch.compile` as the equivalent of jax's `.jit`? 172 | jitted_model_fn = torch.compile(model_fn) 173 | 174 | def pytorch_model_callback(params, *args, **kwargs): 175 | params = jax.tree.map(jax_to_torch, params) 176 | args = jax.tree.map(jax_to_torch, args) 177 | kwargs = jax.tree.map(jax_to_torch, kwargs) 178 | out = jitted_model_fn(params, *args, **kwargs) 179 | return jax.tree.map(torch_to_jax, out) 180 | 181 | # Pass the jax params to the model function in this case, because 182 | # jax.pure_callback tries to extract the dtypes of the args. 183 | out = jax.pure_callback( 184 | pytorch_model_callback, 185 | result_shape_dtypes, 186 | params, 187 | *args, 188 | **kwargs, 189 | vectorized=True, 190 | ) 191 | # Convert the output data from JAX to PyTorch representations 192 | out = t2j(out) 193 | return out 194 | 195 | # Define the forward and backward passes for the VJP 196 | def apply_fwd(params, *args, **kwargs): 197 | return apply(params, *args, **kwargs), (params, args, kwargs) 198 | 199 | def apply_bwd(res, grads: jax.Array): 200 | params, args, kwargs = res 201 | # Convert the input data and gradients from PyTorch to JAX 202 | 203 | if isinstance(grads, jax.core.Tracer): 204 | jitted_model_fn = torch.compile(model_fn) 205 | 206 | # Compute the gradients using the model function and convert them from JAX to PyTorch 207 | def _pytorch_model_backward_callback(params, grads, *args, **kwargs): 208 | torch_params = jax.tree.map(jax_to_torch, params) 209 | torch_args = jax.tree.map(jax_to_torch, args) 210 | torch_kwargs = jax.tree.map(jax_to_torch, kwargs) 211 | torch_grads = jax.tree.map(jax_to_torch, grads) 212 | _torch_out, torch_jvp_fn = torch.func.vjp( 213 | jitted_model_fn, torch_params, *torch_args, **torch_kwargs 214 | ) 215 | torch_in_grads = torch_jvp_fn(torch_grads) 216 | return torch_in_grads 217 | 218 | # todo: this seems to depend on the model_fn used. 
Need to 219 | result_shape_dtypes = (params, args[0]) 220 | in_grads = jax.pure_callback( 221 | _pytorch_model_backward_callback, 222 | result_shape_dtypes, 223 | params, 224 | grads, 225 | *args, 226 | **kwargs, 227 | vectorized=True, 228 | ) 229 | in_grads = t2j(in_grads) 230 | return in_grads 231 | # not JITed 232 | torch_params = jax.tree.map(jax_to_torch, params) 233 | torch_args = jax.tree.map(jax_to_torch, args) 234 | torch_kwargs = jax.tree.map(jax_to_torch, kwargs) 235 | torch_grads = jax.tree.map(jax_to_torch, grads) 236 | _torch_out, torch_jvp_fn = torch.func.vjp( 237 | model_fn, torch_params, *torch_args, **torch_kwargs 238 | ) 239 | torch_in_grads = torch_jvp_fn(torch_grads) 240 | in_grads = jax.tree.map(torch_to_jax, torch_in_grads) 241 | return in_grads 242 | 243 | apply.defvjp(apply_fwd, apply_bwd) 244 | 245 | # Return the apply function and the converted model parameters 246 | return apply, jax_model_params 247 | 248 | 249 | torch_to_jax.register(torch.nn.Module, torch_module_to_jax) 250 | -------------------------------------------------------------------------------- /torch_jax_interop/to_jax_module_test.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | import pytest 5 | import torch 6 | from tensor_regression import TensorRegressionFixture 7 | 8 | from torch_jax_interop import jax_to_torch, torch_to_jax 9 | from torch_jax_interop.to_jax_module import torch_module_to_jax 10 | from torch_jax_interop.to_torch import jax_to_torch_device 11 | from torch_jax_interop.types import jit, value_and_grad 12 | from torch_jax_interop.utils import to_channels_first 13 | 14 | 15 | def test_torch_to_jax_nn_module(torch_device: torch.device): 16 | with torch_device: 17 | torch_net = torch.nn.Sequential( 18 | torch.nn.Linear(10, 10), 19 | torch.nn.ReLU(), 20 | torch.nn.Linear(10, 1), 21 | ) 22 | torch_params = dict(torch_net.named_parameters()) 23 | torch_input = torch.randn(1, 10, requires_grad=True) 24 | 25 | jax_net_fn, jax_net_params = torch_module_to_jax(torch_net) 26 | 27 | for jax_param, torch_param in zip(jax_net_params, torch_params.values()): 28 | torch.testing.assert_close(jax_to_torch(jax_param), torch_param) 29 | 30 | expected_torch_output = torch_net(torch_input) 31 | assert isinstance(expected_torch_output, torch.Tensor) 32 | assert expected_torch_output.requires_grad 33 | assert expected_torch_output.device == torch_device 34 | 35 | def _loss(output): 36 | return (output**2).mean() 37 | 38 | loss = _loss(expected_torch_output) 39 | loss.backward() 40 | # expected_torch_output.backward(gradient=torch.ones_like(expected_torch_output)) 41 | # Make a copy of the gradients so we can compare them later. 
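    # Cloning is needed because `zero_grad(set_to_none=True)` below discards the `.grad`
    # attributes before they get compared against the jax gradients.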
42 | expected_torch_grads = { 43 | k: v.grad.detach().clone() for k, v in torch_params.items() if v.grad is not None 44 | } 45 | torch_net.zero_grad(set_to_none=True) 46 | 47 | jax_input = torch_to_jax(torch_input) 48 | jax_output = jax_net_fn(jax_net_params, jax_input) 49 | 50 | torch_output = jax_to_torch(jax_output) 51 | torch.testing.assert_close(torch_output, expected_torch_output) 52 | 53 | def loss_fn(params, input): 54 | return _loss(jax_net_fn(params, input)) 55 | 56 | grad_fn = jax.grad(loss_fn, argnums=0) 57 | grads = grad_fn(jax_net_params, jax_input) 58 | jax_grads = jax.tree.map(jax_to_torch, grads) 59 | assert isinstance(jax_grads, tuple) and len(jax_grads) == len(jax_net_params) 60 | assert len(jax_grads) == len(expected_torch_grads) 61 | for jax_grad, (name, torch_grad) in zip(jax_grads, expected_torch_grads.items()): 62 | torch.testing.assert_close(jax_grad, torch_grad) 63 | 64 | 65 | @pytest.mark.parametrize("with_jit", [False, True]) 66 | @pytest.mark.parametrize("input_needs_grad", [False, True]) 67 | def test_use_torch_module_in_jax_graph( 68 | torch_network: torch.nn.Module, 69 | jax_input: jax.Array, 70 | tensor_regression: TensorRegressionFixture, 71 | num_classes: int, 72 | seed: int, 73 | with_jit: bool, 74 | torch_device: torch.device, 75 | input_needs_grad: bool, 76 | ): 77 | torch_parameters = {name: p for name, p in torch_network.named_parameters()} 78 | # todo: check that only trainable parameters have a gradient? 79 | # _is_trainable = {name: p.requires_grad for name, p in torch_parameters.items()} 80 | # _num_parameters = len(torch_parameters) 81 | # _total_num_parameters = sum( 82 | # map( 83 | # operator.methodcaller("numel"), 84 | # filter(operator.attrgetter("requires_grad"), torch_parameters.values()), 85 | # ) 86 | # ) 87 | 88 | with torch.random.fork_rng([torch_device] if torch_device.type == "cuda" else []): 89 | # Pass the example output so the fn can be jitted! 90 | example_out = torch_network(jax_to_torch(jax_input)) 91 | 92 | flat_torch_params, params_treedef = jax.tree.flatten(torch_parameters) 93 | wrapped_torch_network_fn, jax_params = torch_module_to_jax( 94 | torch_network, example_output=example_out 95 | ) 96 | 97 | assert callable(wrapped_torch_network_fn) 98 | assert isinstance(jax_params, tuple) and all(isinstance(p, jax.Array) for p in jax_params) 99 | assert len(jax_params) == len(flat_torch_params) 100 | # TODO: Why would the ordering change?! 101 | jax_param_shapes = sorted([p.shape for p in jax_params]) 102 | torch_param_shapes = sorted([p.shape for p in flat_torch_params]) 103 | assert jax_param_shapes == torch_param_shapes 104 | 105 | # BUG: values are different? Is it only due to the dtype? 
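    # Note: the commented-out check below would not work as written anyway:
    # `numpy.testing.assert_allclose` returns None, so `all(...)` over its results is always
    # falsy; a working version would call it directly inside a plain for-loop.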
106 | # assert all( 107 | # numpy.testing.assert_allclose(jax_p, torch_to_jax(torch_p)) 108 | # for jax_p, torch_p in zip( 109 | # sorted(jax_params, key=operator.attrgetter("shape")), 110 | # sorted(flat_torch_params, key=operator.attrgetter("shape")), 111 | # ) 112 | # ) 113 | 114 | batch_size = jax_input.shape[0] 115 | labels = jax.random.randint( 116 | key=jax.random.key(seed), 117 | minval=0, 118 | maxval=num_classes, 119 | shape=(batch_size,), 120 | ) 121 | 122 | def loss_fn( 123 | params: tuple[jax.Array, ...], x: jax.Array, y: jax.Array 124 | ) -> tuple[jax.Array, jax.Array]: 125 | x = to_channels_first(x) 126 | logits = wrapped_torch_network_fn(params, x) 127 | one_hot = jax.nn.one_hot(y, logits.shape[-1]) 128 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) 129 | return loss, logits 130 | 131 | if input_needs_grad: 132 | jax_input = jax_input.astype(jnp.float32) 133 | grad_fn = value_and_grad(loss_fn, argnums=[0, 1], has_aux=True) 134 | if with_jit: 135 | grad_fn = jit(grad_fn) 136 | (loss, logits), (param_grads, input_grads) = grad_fn(jax_params, jax_input, labels) 137 | assert len(param_grads) == len(jax_params) 138 | assert isinstance(input_grads, jax.Array) 139 | assert input_grads.shape == jax_input.shape 140 | else: 141 | grad_fn = value_and_grad(loss_fn, argnums=0, has_aux=True) 142 | if with_jit: 143 | grad_fn = jit(grad_fn) 144 | (loss, logits), param_grads = grad_fn(jax_params, jax_input, labels) 145 | input_grads = None 146 | assert len(param_grads) == len(jax_params) 147 | 148 | def _get_device(v: jax.Array) -> torch.device: 149 | assert len(v.devices()) == 1 150 | jax_device = v.devices().pop() 151 | return jax_to_torch_device(jax_device) 152 | 153 | if input_needs_grad: 154 | assert input_grads is not None 155 | assert _get_device(input_grads) == torch_device 156 | assert _get_device(loss) == torch_device 157 | assert _get_device(logits) == torch_device 158 | assert len(param_grads) == len(jax_params) 159 | for param, grad in zip(jax_params, param_grads): 160 | assert param.shape == grad.shape 161 | assert param.dtype == grad.dtype 162 | assert _get_device(param) == torch_device 163 | assert _get_device(grad) == torch_device 164 | 165 | grads_dict = jax.tree.unflatten(params_treedef, param_grads) 166 | 167 | tensor_regression.check( 168 | { 169 | "input": jax_input, 170 | "output": logits, 171 | "loss": loss, 172 | } 173 | | {name: p for name, p in grads_dict.items()}, 174 | include_gpu_name_in_stats=False, 175 | ) 176 | -------------------------------------------------------------------------------- /torch_jax_interop/to_jax_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | from typing import Any 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy 8 | import pytest 9 | import torch 10 | from pytest_benchmark.fixture import BenchmarkFixture 11 | from tensor_regression import TensorRegressionFixture 12 | 13 | from torch_jax_interop import jax_to_torch, torch_to_jax 14 | from torch_jax_interop.to_jax_module import torch_module_to_jax 15 | from torch_jax_interop.utils import log_once 16 | 17 | 18 | def test_jax_can_use_the_GPU(): 19 | """Test that Jax can use the GPU if it we have one.""" 20 | # NOTE: Super interesting: Seems like running just an 21 | # `import jax.numpy; print(jax.numpy.zeros(1).devices())` in a new terminal FAILS, but if you 22 | # do `import torch` before that, then it works! 
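    # A minimal way to reproduce that quirk outside of pytest (hypothetical shell commands,
    # not executed by this test):
    #   python -c "import jax.numpy as jnp; print(jnp.zeros(1).devices())"            # fails
    #   python -c "import torch, jax.numpy as jnp; print(jnp.zeros(1).devices())"     # works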
23 | import jax.numpy 24 | 25 | device = jax.numpy.zeros(1).devices().pop() 26 | if shutil.which("nvidia-smi"): 27 | assert str(device) == "cuda:0" 28 | else: 29 | assert "cpu" in str(device).lower() 30 | 31 | 32 | def test_torch_can_use_the_GPU(): 33 | """Test that torch can use the GPU if it we have one.""" 34 | 35 | assert torch.cuda.is_available() == bool(shutil.which("nvidia-smi")) 36 | 37 | 38 | @pytest.mark.parametrize( 39 | ("shape", "might_warn"), 40 | [ 41 | ((1,), False), 42 | ((10, 10), False), 43 | ((100, 100, 100), False), 44 | (tuple(range(1, 6)), True), 45 | ((1, 3, 32, 32), True), 46 | ], 47 | ids=str, 48 | ) 49 | def test_torch_to_jax_tensor( 50 | torch_device: torch.device, 51 | shape: tuple[int, ...], 52 | might_warn: bool, 53 | torch_dtype: torch.dtype, 54 | jax_dtype: jax.numpy.dtype, 55 | seed: int, 56 | benchmark: BenchmarkFixture, 57 | caplog: pytest.LogCaptureFixture, 58 | ): 59 | if numpy.prod(shape) >= 1_000_000 and torch_device.type == "cpu": 60 | pytest.skip("Skipping test with large tensor on CPU.") 61 | 62 | gen = torch.Generator(device=torch_device).manual_seed(seed) 63 | if torch_dtype.is_floating_point: 64 | torch_value = torch.rand(shape, device=torch_device, generator=gen, dtype=torch_dtype) 65 | else: 66 | torch_value = torch.randint( 67 | low=0, 68 | high=100, 69 | size=shape, 70 | device=torch_device, 71 | generator=gen, 72 | dtype=torch_dtype, 73 | ) 74 | assert torch_value.shape == shape 75 | 76 | log_once.cache_clear() 77 | with caplog.at_level(logging.WARNING): 78 | jax_value = benchmark(torch_to_jax, torch_value) 79 | if not might_warn: 80 | assert not caplog.records 81 | assert isinstance(jax_value, jax.Array) 82 | assert jax_value.dtype == jax_dtype 83 | 84 | jax_expected_device = torch_to_jax(torch_value.device) 85 | assert jax_value.devices() == {jax_expected_device} 86 | 87 | torch_numpy_value = torch_value.cpu().numpy() 88 | jax_numpy_value = numpy.asarray(jax_value) 89 | numpy.testing.assert_allclose(torch_numpy_value, jax_numpy_value) 90 | 91 | # round-trip: 92 | torch_round_trip = jax_to_torch(jax_value) 93 | assert isinstance(torch_round_trip, torch.Tensor) 94 | 95 | if torch_dtype == torch.float64: 96 | assert jax_dtype == jnp.float32 97 | assert torch_round_trip.dtype == torch.float32 98 | torch.testing.assert_close(torch_round_trip, torch_value.to(torch_round_trip.dtype)) 99 | elif torch_dtype == torch.int64: 100 | assert jax_dtype == jnp.int32 101 | assert torch_round_trip.dtype == torch.int32 102 | torch.testing.assert_close(torch_round_trip, torch_value.to(torch_round_trip.dtype)) 103 | else: 104 | torch.testing.assert_close(torch_value, torch_round_trip) 105 | 106 | 107 | def some_torch_function(x: torch.Tensor) -> torch.Tensor: 108 | return x + torch.ones_like(x) 109 | 110 | 111 | def test_torch_to_jax_function( 112 | torch_device: torch.device, 113 | benchmark: BenchmarkFixture, 114 | ): 115 | torch_input = torch.arange(5, dtype=torch.int32, device=torch_device) 116 | torch_function = some_torch_function 117 | expected_torch_output = torch_function(torch_input) 118 | 119 | jax_input = torch_to_jax(torch_input) 120 | jax_function = torch_to_jax(torch_function) 121 | jax_output = benchmark(jax_function, jax_input) 122 | 123 | torch_output = jax_to_torch(jax_output) 124 | # todo: dtypes might be mismatched for int64 and float64 125 | torch.testing.assert_close(torch_output, expected_torch_output) 126 | 127 | # todo: Should it return torch Tensor when given a torch Tensor? 128 | # ? 
= jax_function(torch_input) 129 | 130 | 131 | class FooBar: 132 | pass 133 | 134 | 135 | @pytest.mark.parametrize("unsupported_value", [FooBar()]) 136 | def test_log_once_on_unsupported_value(unsupported_value: Any, caplog: pytest.LogCaptureFixture): 137 | with caplog.at_level(logging.DEBUG): 138 | assert torch_to_jax(unsupported_value) is unsupported_value 139 | assert len(caplog.records) == 1 140 | assert "No registered handler for values of type" in caplog.records[0].getMessage() 141 | 142 | caplog.clear() 143 | with caplog.at_level(logging.DEBUG): 144 | assert torch_to_jax(unsupported_value) is unsupported_value 145 | assert len(caplog.records) == 0 146 | 147 | 148 | def test_torch_params_dont_change( 149 | torch_network: torch.nn.Module, tensor_regression: TensorRegressionFixture 150 | ): 151 | tensor_regression.check( 152 | dict(torch_network.named_parameters()), 153 | include_gpu_name_in_stats=False, 154 | ) 155 | 156 | 157 | def test_benchmark_forward_pass( 158 | torch_network: torch.nn.Module, 159 | torch_input: torch.Tensor, 160 | benchmark: BenchmarkFixture, 161 | tensor_regression: TensorRegressionFixture, 162 | ): 163 | output = torch_network(torch_input) 164 | 165 | jax_fn, params = torch_module_to_jax(torch_network, example_output=output) 166 | output = benchmark(jax_fn, params, torch_to_jax(torch_input)) 167 | tensor_regression.check( 168 | {"output": output}, 169 | include_gpu_name_in_stats=False, 170 | ) 171 | -------------------------------------------------------------------------------- /torch_jax_interop/to_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections.abc 4 | import dataclasses 5 | import functools 6 | import logging 7 | from logging import getLogger as get_logger 8 | from typing import Any, Callable, overload 9 | 10 | import jax 11 | import torch 12 | from jax.dlpack import to_dlpack as jax_to_dlpack # type: ignore (not exported there?) 13 | from torch.utils import dlpack as torch_dlpack 14 | 15 | from .types import Dataclass, DataclassType, K, NestedDict, NestedMapping 16 | from .utils import log_once 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | @overload 22 | def jax_to_torch(value: jax.Array, /) -> torch.Tensor: 23 | ... 24 | 25 | 26 | @overload 27 | def jax_to_torch(value: jax.Device, /) -> torch.device: 28 | ... 29 | 30 | 31 | @overload 32 | def jax_to_torch(value: tuple[jax.Array, ...], /) -> tuple[torch.Tensor, ...]: 33 | ... 34 | 35 | 36 | @overload 37 | def jax_to_torch(value: list[jax.Array], /) -> list[torch.Tensor]: 38 | ... 39 | 40 | 41 | @overload 42 | def jax_to_torch(value: NestedDict[K, jax.Array], /) -> NestedDict[K, torch.Tensor]: 43 | ... 44 | 45 | 46 | @overload 47 | def jax_to_torch(value: Any, /) -> Any: 48 | ... 49 | 50 | 51 | def jax_to_torch(value: Any, /) -> Any: 52 | """Converts JAX arrays to PyTorch Tensors. 53 | 54 | Converts the tensors "in-place", without the need for copies or moving data to the CPU. 55 | 56 | Args: 57 | value: jax array 58 | 59 | Returns: 60 | a PyTorch tensor 61 | """ 62 | log_once( 63 | logger, 64 | message=f"No registered handler for values of type {type(value)}, returning it as-is.", 65 | level=logging.DEBUG, 66 | ) 67 | return value 68 | 69 | 70 | # Make it a singledispatch here instead of above, so the overloads are presented as 71 | # options for code completion. 72 | jax_to_torch = functools.singledispatch(jax_to_torch) # type: ignore 73 | 74 | 75 | # Keep `None`s the same. 
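# A quick usage sketch of `jax_to_torch` (hypothetical values, not executed here; the
# behaviour relies on the handlers registered throughout this module):
#
#     import jax.numpy as jnp
#     x = jnp.zeros((2, 3))
#     jax_to_torch(x)            # -> torch.Tensor, converted via DLPack (usually zero-copy)
#     jax_to_torch({"x": [x]})   # nested containers are converted leaf by leaf
#
# The registrations just below make primitive Python leaves (`None`, ints, floats, strings,
# bools, bytes) no-ops, so they pass through unchanged: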
76 | @jax_to_torch.register(type(None)) 77 | @jax_to_torch.register(int) 78 | @jax_to_torch.register(float) 79 | @jax_to_torch.register(str) 80 | @jax_to_torch.register(bool) 81 | @jax_to_torch.register(bytes) 82 | def no_op(v: Any) -> Any: 83 | return v 84 | 85 | 86 | def jax_to_torch_tensor(value: jax.Array, /) -> torch.Tensor: 87 | """Converts a Jax array into a torch.Tensor.""" 88 | try: 89 | return torch_dlpack.from_dlpack(value) 90 | except Exception: 91 | return torch_dlpack.from_dlpack(jax_to_dlpack(value)) 92 | 93 | 94 | # Register it like this so the type hints are preserved on the functions (which are also called 95 | # directly in some places). 96 | jax_to_torch.register(jax.Array, jax_to_torch_tensor) 97 | 98 | 99 | @jax_to_torch.register(tuple) 100 | def jax_to_torch_tuple(value: tuple) -> tuple: 101 | return type(value)(*[jax_to_torch(v) for v in value]) 102 | 103 | 104 | @jax_to_torch.register(list) 105 | def jax_to_torch_list(value: list) -> list: 106 | return list(jax_to_torch(v) for v in value) 107 | 108 | 109 | @jax_to_torch.register(collections.abc.Mapping) 110 | def jax_to_torch_mapping( 111 | value: NestedMapping[str, jax.Array | Any], 112 | ) -> NestedMapping[str, torch.Tensor | Any]: 113 | """Converts a dict of Jax arrays into a dict of PyTorch tensors .""" 114 | return type(value)(**{k: jax_to_torch(v) for k, v in value.items()}) # type: ignore 115 | 116 | 117 | @jax_to_torch.register(Dataclass) 118 | def jax_to_torch_dataclass(value: DataclassType) -> DataclassType: 119 | """Converts any jax.Arrays in the dataclass fields to torch Tensors.""" 120 | return type(value)(**jax_to_torch(dataclasses.asdict(value))) 121 | 122 | 123 | @jax_to_torch.register(jax.Device) 124 | def jax_to_torch_device(jax_device: jax.Device) -> torch.device: 125 | jax_device_str = str(jax_device) 126 | if jax_device_str.startswith("cuda"): 127 | device_type, _, index = jax_device_str.partition(":") 128 | assert index.isdigit() 129 | return torch.device(device_type, int(index)) 130 | return torch.device("cpu") 131 | 132 | 133 | @jax_to_torch.register(collections.abc.Callable) 134 | def jax_to_torch_callable(jax_callable: Callable) -> Callable: 135 | """Wraps a jax function so that it can be used from pytorch. 136 | 137 | NOTE: You shouldn't the backward pass through this jax function to work (at least for now). 138 | 139 | TODO: Create a custom autograd Function that computes the gradient using jax.grad. 
140 | """ 141 | from .to_jax import torch_to_jax 142 | 143 | @functools.wraps(jax_callable) 144 | def _wrapped(*torch_args, **torch_kwargs): 145 | jax_args = [torch_to_jax(arg) for arg in torch_args] 146 | jax_kwargs = {k: torch_to_jax(v) for k, v in torch_kwargs.items()} 147 | jax_outputs = jax_callable(*jax_args, **jax_kwargs) 148 | return jax_to_torch(jax_outputs) 149 | 150 | return _wrapped 151 | -------------------------------------------------------------------------------- /torch_jax_interop/to_torch_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import operator 5 | import typing 6 | from logging import getLogger as get_logger 7 | from typing import Any, Callable, Generic, Literal, overload 8 | 9 | import chex 10 | import jax 11 | import torch 12 | from chex import PyTreeDef 13 | from typing_extensions import Unpack 14 | 15 | from torch_jax_interop.types import ( 16 | In, 17 | is_sequence_of, 18 | jit, 19 | value_and_grad, 20 | ) 21 | 22 | from .to_torch import jax_to_torch 23 | from .types import Aux, JaxPyTree, Params, TorchPyTree 24 | from .utils import log_once 25 | 26 | logger = get_logger(__name__) 27 | # todo: make it possible to have different out type than just a single tensor/array. 28 | Out = jax.Array 29 | 30 | 31 | class WrappedJaxFunction(torch.nn.Module): 32 | """Wraps a jax function that returns vectors or matrices into a `torch.nn.Module`. 33 | 34 | This function should accept parameters as a first argument, followed by some inputs 35 | (jax.Arrays) and should return a single output (jax.Array). 36 | 37 | TODOs: 38 | 39 | - [ ] Test and add support for different combinations of .requires_grad in inputs. 40 | - [ ] Add support for multiple outputs instead of a single tensor. 41 | - [ ] Somehow support pytrees as inputs instead of just jax Arrays, maybe with a 42 | classmethod that flattens / unflattens stuff? 43 | 44 | ## Examples 45 | 46 | Suppose we have some jax function we'd like to use in a PyTorch model: 47 | 48 | ```python 49 | import jax 50 | import jax.numpy as jnp 51 | def some_jax_function(params: jax.Array, x: jax.Array): 52 | '''Some toy function that takes in some parameters and an input vector.''' 53 | return jnp.dot(x, params) 54 | ``` 55 | 56 | By importing this: 57 | 58 | ```python 59 | from torch_jax_interop import WrappedJaxFunction 60 | ``` 61 | 62 | We can then wrap this jax function into a torch.nn.Module with learnable parameters: 63 | 64 | ```python 65 | import torch 66 | import torch.nn 67 | module = WrappedJaxFunction(some_jax_function, jax.random.normal(jax.random.key(0), (2, 1))) 68 | module = module.to("cpu") # jax arrays are on GPU by default, moving them to CPU for this example. 
69 | ``` 70 | 71 | The parameters are now learnable parameters of the module parameters: 72 | 73 | ```python 74 | dict(module.state_dict()) 75 | {'params.0': tensor([[-0.7848], 76 | [ 0.8564]])} 77 | ``` 78 | 79 | You can use this just like any other torch.nn.Module: 80 | 81 | ```python 82 | x, y = torch.randn(2), torch.rand(1) 83 | output = module(x) 84 | loss = torch.nn.functional.mse_loss(output, y) 85 | loss.backward() 86 | 87 | model = torch.nn.Sequential( 88 | torch.nn.Linear(123, 2), 89 | module, 90 | ) 91 | ``` 92 | """ 93 | 94 | @overload 95 | def __init__( 96 | self, 97 | jax_function: Callable[[Params, *tuple[jax.Array, ...]], jax.Array], 98 | jax_params: Params, 99 | has_aux: Literal[False] = False, 100 | clone_params: bool = False, 101 | ): 102 | ... 103 | 104 | @overload 105 | def __init__( 106 | self, 107 | jax_function: Callable[ 108 | [Params, *tuple[jax.Array, ...]], tuple[jax.Array, JaxPyTree] 109 | ], 110 | jax_params: Params, 111 | has_aux: Literal[True] = True, 112 | clone_params: bool = False, 113 | ): 114 | ... 115 | 116 | @overload 117 | def __init__( 118 | self, 119 | jax_function: Callable[[Params, *tuple[jax.Array, ...]], jax.Array] 120 | | Callable[[Params, *tuple[jax.Array, ...]], tuple[jax.Array, JaxPyTree]], 121 | jax_params: Params, 122 | has_aux: bool = ..., 123 | clone_params: bool = False, 124 | ): 125 | ... 126 | 127 | def __init__( 128 | self, 129 | jax_function: Callable[[Params, *tuple[jax.Array, ...]], jax.Array] 130 | | Callable[[Params, *tuple[jax.Array, ...]], tuple[jax.Array, JaxPyTree]], 131 | jax_params: Params, 132 | has_aux: bool = False, 133 | clone_params: bool = False, 134 | ): 135 | """Wraps the given jax function into a torch.nn.Module. 136 | 137 | Parameters 138 | ---------- 139 | jax_function: Function to wrap. 140 | jax_params: Initial value for the parameters (PyTree of jax arrays). 141 | has_aux: Whether the jax function returns an additional output (auxiliary data). 142 | clone_params: Whether the torch tensors should be copies of the jax parameters \ 143 | instead of sharing the same memory. Set this to `True` when you plan to do \ 144 | distributed training, otherwise you could run into 'invalid device \ 145 | pointer' errors. 146 | """ 147 | super().__init__() 148 | self.jax_function = jax_function 149 | self.has_aux = has_aux 150 | 151 | # Flatten the jax parameters so we can store them in a nn.ParameterList. 152 | flat_params, self.params_treedef = jax.tree.flatten(jax_params) 153 | # Register the parameters. 154 | # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError: 155 | # Invalid device pointer when trying to share the CUDA memory. 156 | flat_params = map(jax_to_torch, flat_params) 157 | if clone_params: 158 | flat_params = map(operator.methodcaller("clone"), flat_params) 159 | self.params = torch.nn.ParameterList(flat_params) 160 | 161 | def forward( 162 | self, *args: torch.Tensor 163 | ) -> torch.Tensor | tuple[torch.Tensor, TorchPyTree]: 164 | flat_inputs, inputs_treedef = jax.tree.flatten(args) 165 | # Flatten everything out before passing it to autograd. 166 | # todo: this should be `flat_outputs` and be unflattened before being returned. 
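        # Calling convention of the autograd Function below: the pytree structures are passed
        # as treedefs and only the flat tensors are passed positionally, so that autograd can
        # track each leaf; `_JaxFunction.forward` rebuilds the pytrees from the treedefs.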
167 | outputs = _JaxFunction.apply( 168 | self.jax_function, 169 | inputs_treedef, 170 | self.params_treedef, 171 | self.has_aux, 172 | *flat_inputs, 173 | *self.params, 174 | ) 175 | assert isinstance(outputs, tuple) and len(outputs) == 3 176 | output, _jvp_fn, aux = outputs 177 | if self.has_aux: 178 | return output, aux 179 | return output 180 | 181 | if typing.TYPE_CHECKING: 182 | __call__ = forward 183 | 184 | 185 | Output = jax.Array 186 | 187 | 188 | class WrappedJaxScalarFunction(WrappedJaxFunction): 189 | """Wraps a jax function that returns scalars into a `torch.nn.Module`. 190 | 191 | Compared to `WrappedJaxFunction`, this has the advantage of using jax.value_and_grad 192 | for the combined forward and backward pass. 193 | 194 | This function should accept parameters as a first argument, followed by some inputs 195 | (jax.Arrays) and should return a tuple with an output and some additional data (aux) 196 | """ 197 | 198 | @overload 199 | def __init__( 200 | self, 201 | jax_function: Callable[ 202 | [Params, *tuple[jax.Array, ...]], 203 | jax.Array, 204 | ], 205 | jax_params: Params, 206 | has_aux: Literal[False] = False, 207 | clone_params: bool = False, 208 | ): 209 | ... 210 | 211 | @overload 212 | def __init__( 213 | self, 214 | jax_function: Callable[ 215 | [Params, *tuple[jax.Array, ...]], 216 | tuple[jax.Array, JaxPyTree], 217 | ], 218 | jax_params: Params, 219 | has_aux: Literal[True] = True, 220 | clone_params: bool = False, 221 | ): 222 | ... 223 | 224 | def __init__( 225 | self, 226 | jax_function: Callable[ 227 | [Params, *tuple[jax.Array, ...]], 228 | jax.Array, 229 | ] 230 | | Callable[ 231 | [Params, *tuple[jax.Array, ...]], 232 | tuple[jax.Array, JaxPyTree], 233 | ], 234 | jax_params: Params, 235 | has_aux: bool = True, 236 | clone_params: bool = False, 237 | ): 238 | super().__init__( 239 | jax_function=jax_function, 240 | jax_params=jax_params, 241 | has_aux=has_aux, 242 | clone_params=clone_params, 243 | ) 244 | self.jax_function: Callable[ 245 | [Params, *tuple[jax.Array, ...]], tuple[jax.Array, JaxPyTree] 246 | ] 247 | self.jax_value_and_grad_function_wrt_only_params = jit( 248 | value_and_grad(jax_function, argnums=0, has_aux=has_aux) 249 | ) 250 | self._value_and_grad_fns: dict[ 251 | tuple[bool, ...], # the `.requires_grad` of the (*inputs, *params). 
252 | Callable[ 253 | [Params, *tuple[jax.Array, ...]], # same signature as the fn 254 | tuple[ 255 | jax.Array | tuple[jax.Array, JaxPyTree], # returns the output value 256 | # and gradients of either just params or params and inputs: 257 | Params | tuple[Params, Unpack[tuple[jax.Array, ...]]], 258 | ], 259 | ], 260 | ] = {} 261 | 262 | def forward(self, *inputs: torch.Tensor) -> tuple[torch.Tensor, TorchPyTree]: 263 | flat_inputs: list[torch.Tensor] 264 | flat_inputs, inputs_treedef = jax.tree.flatten(inputs) 265 | 266 | inputs_need_grad = tuple(input.requires_grad for input in flat_inputs) 267 | params_need_grad = tuple(param.requires_grad for param in self.parameters()) 268 | # IDEA: Reuse or create the `value_and_grad` fn that we need 269 | # depending on which param / input requires gradients 270 | n_inputs = inputs_treedef.num_leaves 271 | n_params = self.params_treedef.num_leaves 272 | 273 | if not self._value_and_grad_fns: 274 | # When all params need a grad and no input needs one, use the function we 275 | # already have (argnums=0): 276 | self._value_and_grad_fns[ 277 | tuple([False] * n_inputs + [True] * n_params) 278 | ] = self.jax_value_and_grad_function_wrt_only_params 279 | # When neither the params nor the inputs require a gradient, the 280 | # value_and_grad function won't be used, so we can just set the same 281 | # function. 282 | self._value_and_grad_fns[ 283 | tuple([False] * n_inputs + [False] * n_params) 284 | ] = self.jax_value_and_grad_function_wrt_only_params 285 | 286 | key = inputs_need_grad + params_need_grad 287 | # Note: when we don't need the value_and_grad function won't be used anyway. 288 | if key in self._value_and_grad_fns: 289 | # We already have the function to be used to compute the value and desired gradients 290 | value_and_grad_fn = self._value_and_grad_fns[key] 291 | else: 292 | logger.info( 293 | f"Compiling the `value_and_grad` function needed for {inputs_need_grad=} and {params_need_grad=}" 294 | ) 295 | assert all(params_need_grad) # assuming all parameters need a grad for now 296 | # NOTE: Since all parameters are passed via the `param` is the first positional 297 | # argument, currently either no parameters need a grad, or every parameter needs a grad. 298 | argnums = ( 299 | 0, 300 | *tuple( 301 | i + 1 for i, input in enumerate(flat_inputs) if input.requires_grad 302 | ), 303 | ) 304 | logger.debug(f"argnums({argnums=})") 305 | value_and_grad_fn = jit( 306 | value_and_grad( 307 | self.jax_function, 308 | argnums=argnums, 309 | has_aux=True, 310 | ) 311 | ) 312 | self._value_and_grad_fns[key] = value_and_grad_fn 313 | 314 | output = _JaxScalarFunction.apply( 315 | self.jax_function, 316 | value_and_grad_fn, 317 | inputs_treedef, 318 | self.params_treedef, 319 | *inputs, 320 | *self.params, 321 | ) 322 | assert isinstance(output, tuple) and len(output) == 2 323 | out, aux = output 324 | assert isinstance(out, torch.Tensor) 325 | return out, aux 326 | 327 | if typing.TYPE_CHECKING: 328 | __call__ = forward 329 | 330 | 331 | class _JaxFunction(torch.autograd.Function, Generic[Params]): 332 | """Wrapper for a jax function, making it usable in PyTorch's autograd system. 333 | 334 | TODOs: make this more flexible in terms of input/output signature: 335 | - [ ] Currently assumes that has_aux is False. 336 | - [ ] Currently assumes that the function returns a single array. 337 | - [ ] Currently assumes that the function accepts only params and one input... 
338 | """ 339 | 340 | @staticmethod 341 | def forward( 342 | jax_function: Callable[[Params, In], Out] 343 | | Callable[[Params, In], tuple[Out, Aux]], 344 | inputs_treedef: PyTreeDef, 345 | params_treedef: PyTreeDef, 346 | has_aux: bool, 347 | # need to flatten the params for autograd to understand that they need a gradient. 348 | *flat_inputs_and_params: torch.Tensor, 349 | ): 350 | from .to_jax import torch_to_jax 351 | 352 | n_inputs = inputs_treedef.num_leaves 353 | flat_inputs, flat_params = ( 354 | flat_inputs_and_params[:n_inputs], 355 | flat_inputs_and_params[n_inputs:], 356 | ) 357 | jax_inputs = jax.tree.unflatten(inputs_treedef, map(torch_to_jax, flat_inputs)) 358 | jax_params = jax.tree.unflatten(params_treedef, map(torch_to_jax, flat_params)) 359 | # todo: support multiple outputs and/or `has_aux=True`. 360 | if has_aux: 361 | jax_function_with_aux = typing.cast( 362 | Callable[[Params, In], tuple[Out, Aux]], jax_function 363 | ) 364 | output, jvp_function, aux = jax.vjp( 365 | jax_function_with_aux, jax_params, *jax_inputs, has_aux=has_aux 366 | ) 367 | output = jax.tree.map(jax_to_torch, output) 368 | aux = jax.tree.map(jax_to_torch, aux) 369 | return output, jvp_function, aux 370 | else: 371 | output, jvp_function = jax.vjp( 372 | jax_function, jax_params, *jax_inputs, has_aux=has_aux 373 | ) 374 | output = jax.tree.map(jax_to_torch, output) 375 | # flat_outputs, = jax.tree.leaves(output) 376 | return output, jvp_function, None 377 | 378 | if typing.TYPE_CHECKING: 379 | apply = forward # type: ignore 380 | 381 | # setup_context is responsible for calling methods and/or assigning to 382 | # the ctx object. Please do not do additional compute (e.g. add 383 | # Tensors together) in setup_context. 384 | @staticmethod 385 | def setup_context( 386 | ctx: torch.autograd.function.BackwardCFunction, inputs: tuple, output: tuple 387 | ): 388 | ( 389 | jax_function, 390 | inputs_treedef, 391 | params_treedef, 392 | has_aux, 393 | *inputs_and_params, 394 | ) = inputs 395 | output, jvp_function, aux = output 396 | # Save the function to use to compute the backward pass. 
397 | ctx.jvp_function = jvp_function # type: ignore 398 | ctx.inputs_treedef = inputs_treedef # type: ignore 399 | ctx.params_treedef = params_treedef # type: ignore 400 | ctx.has_aux = has_aux # type: ignore 401 | ctx.aux = aux # type: ignore 402 | 403 | @torch.autograd.function.once_differentiable 404 | @staticmethod 405 | def backward( 406 | ctx: torch.autograd.function.NestedIOFunction, 407 | *output_grads: Unpack[tuple[torch.Tensor, Unpack[tuple[None, ...]]]], 408 | ): 409 | from .to_jax import torch_to_jax 410 | 411 | grad_output, *_unused_output_grads = output_grads 412 | assert grad_output is not None 413 | assert all(unused_grad is None for unused_grad in _unused_output_grads) 414 | needs_input_grad = tuple(ctx.needs_input_grad) 415 | 416 | assert ( 417 | is_sequence_of(needs_input_grad, bool) 418 | and isinstance(needs_input_grad, tuple) 419 | and len(needs_input_grad) >= 5 420 | ) 421 | _, _, _, _, *inputs_and_params_need_grad = needs_input_grad 422 | 423 | jvp_function = ctx.jvp_function # type: ignore 424 | inputs_treedef: PyTreeDef = ctx.inputs_treedef # type: ignore 425 | params_treedef: PyTreeDef = ctx.params_treedef # type: ignore 426 | # fn_had_aux: bool = ctx.has_aux # type: ignore 427 | 428 | n_inputs = inputs_treedef.num_leaves 429 | n_params = params_treedef.num_leaves 430 | 431 | inputs_need_grad, params_need_grad = ( 432 | inputs_and_params_need_grad[:n_inputs], 433 | inputs_and_params_need_grad[n_inputs:], 434 | ) 435 | 436 | _jax_grad_output = torch_to_jax(grad_output) 437 | _jax_grad_params, *_jax_input_grads = jvp_function(_jax_grad_output) 438 | 439 | flat_param_grads = jax.tree.leaves(jax.tree.map(jax_to_torch, _jax_grad_params)) 440 | flat_input_grads = jax.tree.leaves(jax.tree.map(jax_to_torch, _jax_input_grads)) 441 | 442 | # Only give out the gradients if they were requested. 443 | assert len(flat_param_grads) == n_params 444 | assert len(flat_input_grads) == n_inputs 445 | flat_param_grads = tuple( 446 | flat_param_grad if params_need_grad[i] else None 447 | for i, flat_param_grad in enumerate(flat_param_grads) 448 | ) 449 | # We have gradients for inputs that don't require them. 450 | assert len(flat_input_grads) == len(inputs_need_grad) 451 | flat_input_grads = tuple( 452 | flat_input_grad if inputs_need_grad[i] else None 453 | for i, flat_input_grad in enumerate(flat_input_grads) 454 | ) 455 | 456 | return None, None, None, None, *flat_input_grads, *flat_param_grads 457 | 458 | @staticmethod 459 | def jvp( 460 | ctx, 461 | jax_function: Callable[[Params, jax.Array], jax.Array], 462 | params_treedef: PyTreeDef, 463 | input_grad: torch.Tensor, 464 | *params_grads: torch.Tensor, # need to flatten the params for autograd to understand that they need a gradient. 465 | ): 466 | # todo: debug and test this further. 467 | # https://pytorch.org/docs/stable/notes/extending.html#forward-mode-ad 468 | # Called after `forward` 469 | # Should return as many tensors as there were outputs. 470 | from .to_jax import torch_to_jax 471 | 472 | log_once( 473 | logger, 474 | message="This is untested! 
Use at your own risk!", 475 | level=logging.WARNING, 476 | ) 477 | jax_params = jax.tree.unflatten(params_treedef, map(torch_to_jax, params_grads)) 478 | primals_out, tangents_out = jax.jvp(jax_function, jax_params, input_grad) 479 | output_grads = jax.tree.map(jax_to_torch, tangents_out) 480 | return output_grads, None 481 | 482 | @staticmethod 483 | def vmap( 484 | info, 485 | in_dims: tuple[int | None, ...], 486 | jax_function: Callable[[Params, jax.Array], jax.Array], 487 | params_treedef: PyTreeDef, 488 | input: torch.Tensor, 489 | *params: torch.Tensor, 490 | ): 491 | log_once( 492 | logger, 493 | message="This is untested! Use at your own risk!", 494 | level=logging.WARNING, 495 | ) 496 | # todo: debug and test this further. 497 | _, _, input_vmap_dim, *params_vmap_dims = in_dims 498 | 499 | params_vmap_dims_dict = jax.tree.unflatten(params_treedef, params_vmap_dims) 500 | # todo: use something like functools.cache so we can jit this? 501 | vmapped_jax_function = jax.vmap( 502 | jax_function, in_axes=(params_vmap_dims_dict, input_vmap_dim) 503 | ) 504 | from .to_jax import torch_to_jax 505 | 506 | jax_params = jax.tree.unflatten(params_treedef, map(torch_to_jax, params)) 507 | jax_input = torch_to_jax(input) 508 | vmapped_result = vmapped_jax_function(jax_params, jax_input) 509 | return vmapped_result 510 | 511 | 512 | class _JaxScalarFunction(torch.autograd.Function, Generic[Params]): 513 | """Wrapper for a jax scalar-valued function, making it usable in PyTorch's autograd system. 514 | 515 | This has potentially an advantage compared to `JaxFunction` (which is more general): 516 | It gets to use (and jit) the `jax.value_and_grad` of the function. 517 | 518 | TODO: Assumes that the function "has an aux": that it returns a tuple of (val, aux). 519 | """ 520 | 521 | # todo: If we used setup_context we could maybe have **kwargs in the forward? 522 | 523 | @staticmethod 524 | def forward( 525 | ctx: torch.autograd.function.BackwardCFunction, 526 | jax_function: Callable[ 527 | [Params, *tuple[jax.Array, ...]], tuple[chex.Scalar, Aux] 528 | ], 529 | jax_value_and_grad_function: Callable[ 530 | [Params, Unpack[tuple[jax.Array, ...]]], 531 | tuple[ 532 | tuple[chex.Scalar, Aux], # outputs 533 | Params # grads of params only 534 | | tuple[ 535 | Params, Unpack[tuple[jax.Array, ...]] 536 | ], # grads of params and inputs 537 | ], 538 | ], 539 | inputs_treedef: PyTreeDef, 540 | params_treedef: PyTreeDef, 541 | # need to flatten the inputs and params for autograd to understand that they need a gradient. 542 | *flatened_inputs_and_params: torch.Tensor, 543 | ): 544 | # TODO: Keep the aux the same? Or convert to torch? 545 | 546 | from .to_jax import torch_to_jax 547 | 548 | flat_inputs, flat_params = ( 549 | flatened_inputs_and_params[: inputs_treedef.num_leaves], 550 | flatened_inputs_and_params[inputs_treedef.num_leaves :], 551 | ) 552 | jax_inputs = jax.tree.unflatten(inputs_treedef, map(torch_to_jax, flat_inputs)) 553 | jax_params = jax.tree.unflatten(params_treedef, map(torch_to_jax, flat_params)) 554 | 555 | _, _, _, _, *inputs_and_params_need_grad = ctx.needs_input_grad # type: ignore 556 | 557 | inputs_need_grad, params_need_grad = ( 558 | inputs_and_params_need_grad[: inputs_treedef.num_leaves], 559 | inputs_and_params_need_grad[inputs_treedef.num_leaves :], 560 | ) 561 | # Save these for the backward pass. 
562 | ctx.inputs_treedef = inputs_treedef # type: ignore 563 | ctx.params_treedef = params_treedef # type: ignore 564 | 565 | if not any(inputs_need_grad) and not any(params_need_grad): 566 | # Do only the forward pass. 567 | assert not any(inputs_need_grad) and not any(params_need_grad) 568 | jax_output, jax_aux = jax_function(jax_params, *jax_inputs) 569 | output = jax.tree.map(jax_to_torch, jax_output) 570 | # TODO: Keep the aux the same? Or convert to torch? 571 | aux = jax.tree.map(jax_to_torch, jax_aux) 572 | return output, aux 573 | 574 | # We only calculate the gradients we care about by changing the argnums 575 | # passed to value_and_grad. 576 | ( 577 | (jax_output, jax_aux), 578 | jax_grads_depending_on_argnums, 579 | ) = jax_value_and_grad_function(jax_params, *jax_inputs) 580 | 581 | output = jax.tree.map(jax_to_torch, jax_output) 582 | aux = jax.tree.map(jax_to_torch, jax_aux) 583 | 584 | if any(params_need_grad) and not any(inputs_need_grad): 585 | # The `value_and_grad_function` is used to calculate the gradients w.r.t. 586 | # only the parameters. 587 | flat_jax_grads = jax.tree.leaves(jax_grads_depending_on_argnums) 588 | param_grads = tuple(map(jax_to_torch, flat_jax_grads)) 589 | ctx.save_for_backward(*param_grads) 590 | return output, aux 591 | 592 | assert any(inputs_need_grad) 593 | assert all(params_need_grad) # assuming that all params need a gradient. 594 | 595 | # The `value_and_grad_function` calculated the gradients w.r.t. the 596 | # parameters and (some?) inputs. 597 | assert ( 598 | isinstance(jax_grads_depending_on_argnums, tuple) 599 | and len(jax_grads_depending_on_argnums) >= 2 600 | ) 601 | jax_param_grads, *jax_input_grads = jax_grads_depending_on_argnums 602 | param_grads = jax.tree.leaves(jax.tree.map(jax_to_torch, jax_param_grads)) 603 | # Some inputs might need gradients. 604 | assert len(jax_input_grads) <= inputs_treedef.num_leaves 605 | input_grads = jax.tree.leaves(jax.tree.map(jax_to_torch, jax_input_grads)) 606 | ctx.save_for_backward(*input_grads, *param_grads) 607 | 608 | return output, aux 609 | 610 | @torch.autograd.function.once_differentiable 611 | @staticmethod 612 | def backward( 613 | ctx: torch.autograd.function.NestedIOFunction, 614 | grad_output: torch.Tensor, 615 | _grad_aux: Any, 616 | ): 617 | assert not grad_output.shape 618 | assert (grad_output == torch.ones_like(grad_output)).all() 619 | 620 | # The gradients have already been computed with `value_and_grad`, so here we 621 | # just return them from ctx.saved_tensors depending on `ctx.needs_input_grad`. 622 | 623 | inputs_treedef: PyTreeDef = ctx.inputs_treedef # type: ignore 624 | params_treedef: PyTreeDef = ctx.params_treedef # type: ignore 625 | _, _, _, _, *flat_inputs_and_params_need_grad = ctx.needs_input_grad # type: ignore 626 | 627 | n_inputs = inputs_treedef.num_leaves 628 | n_params = params_treedef.num_leaves 629 | 630 | inputs_need_grad, params_need_grad = ( 631 | flat_inputs_and_params_need_grad[:n_inputs], 632 | flat_inputs_and_params_need_grad[n_inputs:], 633 | ) 634 | 635 | if not any(inputs_need_grad) and len(ctx.saved_tensors) == n_params: 636 | # We saved the gradients of the parameters (but not of the inputs). 637 | input_grads = [None] * inputs_treedef.num_leaves 638 | # Remove the grads of parameters that didn't require them. 
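            # In this branch `ctx.saved_tensors` holds exactly the parameter gradients, in the
            # same order as the flattened parameters.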
639 | param_grads = [ 640 | param_grad if needed_grad else None 641 | for param_grad, needed_grad in zip(ctx.saved_tensors, params_need_grad) 642 | ] 643 | return None, None, None, None, *input_grads, *param_grads 644 | 645 | assert all(params_need_grad) 646 | assert any(inputs_need_grad) 647 | 648 | saved_tensors: tuple[torch.Tensor, ...] = ctx.saved_tensors # type: ignore 649 | 650 | # Here we do it slightly differently, because we might not have `n_inputs` input 651 | # gradients saved: it depends on the `argnums` of the `value_and_grad` function 652 | # that was used in forward. 653 | 654 | n_inputs_that_needed_grad = sum(inputs_need_grad) 655 | n_params_that_needed_grad = sum(params_need_grad) 656 | # We should have as many saved tensors as there are inputs that needed a grad. 657 | assert ( 658 | len(ctx.saved_tensors) 659 | == n_inputs_that_needed_grad + n_params_that_needed_grad 660 | ) 661 | 662 | input_grads, param_grads = ( 663 | saved_tensors[:n_inputs_that_needed_grad], 664 | saved_tensors[n_inputs_that_needed_grad:], 665 | ) 666 | # Consume the saved input grads, assigning them to the right index: 667 | saved_input_grads = list(input_grads) 668 | input_grads = [ 669 | saved_input_grads.pop(0) if needed_grad else None 670 | for needed_grad in inputs_need_grad 671 | ] 672 | # Only give out a grad for parameters that required them. 673 | # note: un-needed atm, since above we assume all(params_need_grad!) 674 | param_grads = [ 675 | param_grad if needed_grad else None 676 | for param_grad, needed_grad in zip(param_grads, params_need_grad) 677 | ] 678 | return None, None, None, None, *input_grads, *param_grads 679 | -------------------------------------------------------------------------------- /torch_jax_interop/to_torch_module_test.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import flax.linen 4 | import jax 5 | import jax.test_util 6 | import optax 7 | import pytest 8 | import torch 9 | from flax.typing import VariableDict 10 | from tensor_regression import TensorRegressionFixture 11 | 12 | from torch_jax_interop import jax_to_torch 13 | from torch_jax_interop.conftest import JaxCNN, TorchCNN 14 | from torch_jax_interop.to_torch_module import ( 15 | JaxPyTree, 16 | WrappedJaxFunction, 17 | WrappedJaxScalarFunction, 18 | ) 19 | from torch_jax_interop.types import jit 20 | 21 | # TODO: The regression check in this test occasionally fails? Unable to precisely 22 | # replicate it yet. 
23 | # This test case seems to fail occasionally: 24 | # - `input_grad` tensor differs in this case: [backend=cuda-JaxFcNet-input_requires_grad=True-aux=True-jit=False-clone_params=False] 25 | 26 | 27 | @pytest.mark.parametrize("clone_params", [False, True], ids="clone_params={}".format) 28 | @pytest.mark.parametrize("use_jit", [False, True], ids="jit={}".format) 29 | @pytest.mark.parametrize("has_aux", [False, True], ids="aux={}".format) 30 | @pytest.mark.parametrize( 31 | "input_requires_grad", [False, True], ids="input_requires_grad={}".format 32 | ) 33 | @pytest.mark.parametrize( 34 | "do_regression_check", 35 | [ 36 | False, 37 | True, 38 | ], 39 | ) 40 | def test_use_jax_module_in_torch_graph( 41 | jax_network_and_params: tuple[flax.linen.Module, VariableDict], 42 | torch_input: torch.Tensor, 43 | tensor_regression: TensorRegressionFixture, 44 | num_classes: int, 45 | seed: int, 46 | has_aux: bool, 47 | use_jit: bool, 48 | clone_params: bool, 49 | input_requires_grad: bool, 50 | torch_device: torch.device, 51 | do_regression_check: bool, 52 | ): 53 | jax_network, jax_params = jax_network_and_params 54 | 55 | batch_size = torch_input.shape[0] 56 | 57 | input = torch_input.clone().detach().requires_grad_(input_requires_grad) 58 | labels = torch.randint( 59 | 0, 60 | num_classes, 61 | (batch_size,), 62 | device=input.device, 63 | generator=torch.Generator(device=input.device).manual_seed(seed), 64 | ) 65 | 66 | if not has_aux: 67 | jax_function: Callable[ 68 | [JaxPyTree, *tuple[jax.Array, ...]], jax.Array 69 | ] = jax_network.apply # type: ignore 70 | 71 | if use_jit: 72 | jax_function = jit(jax_function) 73 | 74 | wrapped_jax_module = WrappedJaxFunction( 75 | jax_function, jax_params, has_aux=has_aux, clone_params=clone_params 76 | ) 77 | 78 | logits = wrapped_jax_module(input) 79 | 80 | loss = torch.nn.functional.cross_entropy(logits, labels, reduction="mean") 81 | loss.backward() 82 | 83 | else: 84 | 85 | def jax_function_with_aux( 86 | params: JaxPyTree, *inputs: jax.Array 87 | ) -> tuple[jax.Array, JaxPyTree]: 88 | out = jax_network.apply(params, *inputs) 89 | assert isinstance(out, jax.Array) 90 | aux = {"mean": out.mean(), "max": out.max()} 91 | return out, aux 92 | 93 | if use_jit: 94 | jax_function_with_aux = jit(jax_function_with_aux) 95 | 96 | wrapped_jax_module = WrappedJaxFunction( 97 | jax_function_with_aux, 98 | jax_params, 99 | has_aux=has_aux, 100 | clone_params=clone_params, 101 | ) 102 | 103 | logits, stats_dict = wrapped_jax_module(input) 104 | loss = torch.nn.functional.cross_entropy(logits, labels, reduction="mean") 105 | loss.backward() 106 | 107 | # Check that the stats dict has the same structure but contains pytorch tensors 108 | # instead of jax arrays. 
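        # The aux values come back as plain tensors on the same device as the logits and do
        # not carry gradients, which is what the assertions below verify.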
109 | 110 | assert isinstance(stats_dict, dict) 111 | mean = stats_dict["mean"] 112 | assert isinstance(mean, torch.Tensor) 113 | assert mean.device == torch_device 114 | torch.testing.assert_close(mean, logits.mean()) 115 | assert not mean.requires_grad 116 | max = stats_dict["max"] 117 | assert isinstance(max, torch.Tensor) 118 | assert max.device == torch_device 119 | torch.testing.assert_close(max, logits.max()) 120 | assert not max.requires_grad 121 | 122 | assert len(list(wrapped_jax_module.parameters())) == len( 123 | jax.tree.leaves(jax_params) 124 | ) 125 | assert all(p.requires_grad for p in wrapped_jax_module.parameters()) 126 | assert isinstance(logits, torch.Tensor) and logits.requires_grad 127 | assert all( 128 | p.requires_grad and p.grad is not None for p in wrapped_jax_module.parameters() 129 | ) 130 | if input_requires_grad: 131 | assert input.grad is not None 132 | else: 133 | assert input.grad is None 134 | 135 | if do_regression_check: 136 | tensor_regression.check( 137 | { 138 | "input": input, 139 | "output": logits, 140 | "loss": loss, 141 | "input_grad": input.grad, 142 | } 143 | | {name: p for name, p in wrapped_jax_module.named_parameters()}, 144 | include_gpu_name_in_stats=False, 145 | ) 146 | 147 | 148 | @pytest.mark.parametrize("input_requires_grad", [False, True]) 149 | def test_use_jax_scalar_function_in_torch_graph( 150 | jax_network_and_params: tuple[flax.linen.Module, VariableDict], 151 | torch_input: torch.Tensor, 152 | tensor_regression: TensorRegressionFixture, 153 | num_classes: int, 154 | seed: int, 155 | input_requires_grad: bool, 156 | ): 157 | """Same idea, but now its the entire loss function that is in jax, not just the module.""" 158 | jax_network, jax_params = jax_network_and_params 159 | 160 | batch_size = torch_input.shape[0] 161 | 162 | @jit 163 | def loss_fn( 164 | params: VariableDict, x: jax.Array, y: jax.Array 165 | ) -> tuple[jax.Array, jax.Array]: 166 | logits = jax_network.apply(params, x) 167 | assert isinstance(logits, jax.Array) 168 | one_hot = jax.nn.one_hot(y, logits.shape[-1]) 169 | loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean() 170 | return loss, logits 171 | 172 | # todo: add a test case where the input is floating point and requires gradients. 173 | if not input_requires_grad: 174 | # note: the input can't require grad because it's an int tensor. 
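        # When the input does need a gradient, the `else` branch below first casts it to float
        # before calling `requires_grad_(True)`.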
175 | input = torch_input 176 | else: 177 | input = torch_input.float().clone().detach().requires_grad_(True) 178 | 179 | labels = torch.randint( 180 | 0, 181 | num_classes, 182 | (batch_size,), 183 | device=input.device, 184 | generator=torch.Generator(device=input.device).manual_seed(seed), 185 | ) 186 | 187 | wrapped_jax_module = WrappedJaxScalarFunction(loss_fn, jax_params) 188 | 189 | assert len(list(wrapped_jax_module.parameters())) == len( 190 | jax.tree.leaves(jax_params) 191 | ) 192 | assert all(p.requires_grad for p in wrapped_jax_module.parameters()) 193 | if not input_requires_grad: 194 | assert not input.requires_grad 195 | else: 196 | assert input.requires_grad 197 | assert not labels.requires_grad 198 | loss, logits = wrapped_jax_module(input, labels) 199 | assert isinstance(loss, torch.Tensor) and loss.requires_grad 200 | assert isinstance(logits, torch.Tensor) and logits.requires_grad 201 | loss.backward() 202 | 203 | assert all( 204 | p.requires_grad and p.grad is not None for p in wrapped_jax_module.parameters() 205 | ) 206 | if input_requires_grad: 207 | assert input.grad is not None 208 | else: 209 | assert input.grad is None 210 | 211 | tensor_regression.check( 212 | { 213 | "input": input, 214 | "output": logits, 215 | "loss": loss, 216 | "input_grad": input.grad, 217 | } 218 | | {name: p for name, p in wrapped_jax_module.named_parameters()}, 219 | include_gpu_name_in_stats=False, 220 | ) 221 | 222 | 223 | @pytest.fixture 224 | def torch_and_jax_networks_with_same_params( 225 | torch_network: torch.nn.Module, 226 | jax_network_and_params: tuple[flax.linen.Module, VariableDict], 227 | ): 228 | jax_network, jax_params = jax_network_and_params 229 | if isinstance(torch_network, TorchCNN) or isinstance(jax_network, JaxCNN): 230 | pytest.skip(reason="Params dont even lign up, its too hard to do atm.") 231 | 232 | flattened_jax_params, jax_params_treedef = jax.tree.flatten(jax_params) 233 | torch_params = list(torch_network.parameters()) 234 | assert len(flattened_jax_params) == len(torch_params) 235 | 236 | flattened_jax_params = sorted(flattened_jax_params, key=lambda p: tuple(p.shape)) 237 | torch_params = sorted(torch_params, key=lambda p: tuple(p.shape)) 238 | 239 | jax_param_shapes = [p.shape for p in flattened_jax_params] 240 | torch_param_shapes = [p.shape for p in torch_params] 241 | assert jax_param_shapes == torch_param_shapes 242 | 243 | # todo: find the equivalence between params, the ordering doesn't appear to be the same. 244 | with torch.no_grad(): 245 | for jax_param, torch_param in zip(flattened_jax_params, torch_params): 246 | assert jax_param.shape == torch_param.shape 247 | # initialize both networks with the same parameters. 
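            # Writing through `.data[:]` inside the `torch.no_grad()` block copies the values
            # in place without autograd recording the assignment.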
248 | torch_param.data[:] = jax_to_torch(jax_param)[:] 249 | 250 | return jax_network, jax_params, torch_network 251 | 252 | 253 | @pytest.mark.xfail(reason="Params dont even lign up, its too hard to do atm.") 254 | def test_jax_and_torch_modules_have_same_forward_pass( 255 | torch_and_jax_networks_with_same_params: tuple[ 256 | flax.linen.Module, VariableDict, torch.nn.Module 257 | ], 258 | torch_input: torch.Tensor, 259 | jax_input: jax.Array, 260 | ): 261 | jax_network, jax_params, torch_network = torch_and_jax_networks_with_same_params 262 | 263 | jax_output = jax_network.apply(jax_params, jax_input) 264 | torch_output = torch_network(torch_input) 265 | 266 | torch.testing.assert_close(jax_output, torch_output) 267 | -------------------------------------------------------------------------------- /torch_jax_interop/to_torch_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import jax.test_util 7 | import numpy 8 | import pytest 9 | import torch 10 | from pytest_benchmark.fixture import BenchmarkFixture 11 | 12 | from torch_jax_interop import jax_to_torch, torch_to_jax 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "shape", 17 | [ 18 | (1,), 19 | (10, 10), 20 | (10, 10, 10), 21 | pytest.param( 22 | tuple(range(1, 6)), 23 | ), 24 | ], 25 | ids=repr, 26 | ) 27 | def test_jax_to_torch_tensor( 28 | shape: tuple[int, ...], 29 | jax_device: jax.Device, 30 | torch_dtype: torch.dtype, 31 | jax_dtype: jnp.dtype, 32 | seed: int, 33 | benchmark: BenchmarkFixture, 34 | ): 35 | if numpy.prod(shape) >= 1_000_000 and jax_device.platform == "cpu": 36 | pytest.skip("Skipping test with large tensor on CPU.") 37 | 38 | key = jax.random.key(seed) 39 | # todo: don't know what the equivalent is on a np/jax dtype for checking if the dtype is 40 | # floating-point. 41 | if torch_dtype.is_floating_point: 42 | jax_value = jax.random.uniform(key=key, shape=shape, dtype=jax_dtype) 43 | else: 44 | jax_value = jax.random.randint( 45 | key=key, shape=shape, minval=0, maxval=100, dtype=jax_dtype 46 | ) 47 | jax_value = jax.device_put(jax_value, device=jax_device) 48 | 49 | torch_expected_device = jax_to_torch(jax_device) 50 | assert isinstance(torch_expected_device, torch.device) 51 | 52 | torch_value = benchmark(jax_to_torch, jax_value) 53 | assert isinstance(torch_value, torch.Tensor) 54 | assert torch_value.device == torch_expected_device 55 | 56 | # Convert the torch Tensor to a numpy array so we can compare the contents. 
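    # `.cpu()` ensures the tensor can be read by numpy even when it lives on a CUDA device.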
57 | torch_numpy_value = torch_value.cpu().numpy() 58 | numpy.testing.assert_allclose(jax_value, torch_numpy_value) 59 | 60 | # round-trip: 61 | jax_round_trip = torch_to_jax(torch_value) 62 | numpy.testing.assert_allclose(jax_round_trip, jax_value) 63 | 64 | 65 | def some_jax_function(x: jnp.ndarray) -> jnp.ndarray: 66 | return x + jnp.ones_like(x) 67 | 68 | 69 | def test_jax_to_torch_function(jax_device: torch.device, benchmark: BenchmarkFixture): 70 | jax_input: jax.Array = jax.device_put(jnp.arange(5), device=jax_device) 71 | jax_function = some_jax_function 72 | expected_jax_output = jax_function(jax_input) 73 | 74 | torch_input = jax_to_torch(jax_input) 75 | torch_function = jax_to_torch(jax_function) 76 | torch_output = benchmark(torch_function, torch_input) 77 | 78 | jax_output = torch_to_jax(torch_output) 79 | # todo: dtypes might be mismatched for int64 and float64 80 | numpy.testing.assert_allclose(jax_output, expected_jax_output) 81 | 82 | # todo: Should it return a jax.Array when given a jax.Array as input? 83 | # ? = torch_function(jax_input) 84 | 85 | 86 | class FooBar: 87 | pass 88 | 89 | 90 | @pytest.mark.parametrize("unsupported_value", [FooBar()]) 91 | def test_log_once_on_unsupported_value( 92 | unsupported_value: Any, caplog: pytest.LogCaptureFixture 93 | ): 94 | with caplog.at_level(logging.DEBUG): 95 | assert jax_to_torch(unsupported_value) is unsupported_value 96 | assert len(caplog.records) == 1 97 | assert "No registered handler for values of type" in caplog.records[0].getMessage() 98 | 99 | caplog.clear() 100 | with caplog.at_level(logging.DEBUG): 101 | assert jax_to_torch(unsupported_value) is unsupported_value 102 | assert len(caplog.records) == 0 103 | -------------------------------------------------------------------------------- /torch_jax_interop/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import functools 5 | import typing 6 | from typing import ( 7 | Any, 8 | Callable, 9 | ClassVar, 10 | FrozenSet, 11 | Literal, 12 | Mapping, 13 | ParamSpec, 14 | Protocol, 15 | Sequence, 16 | TypeGuard, 17 | TypeVar, 18 | overload, 19 | runtime_checkable, 20 | ) 21 | 22 | import chex 23 | import jax 24 | import jax.experimental 25 | import jax.experimental.checkify 26 | import torch 27 | from typing_extensions import TypeVarTuple, Unpack 28 | 29 | K = TypeVar("K") 30 | V = TypeVar("V") 31 | C = TypeVar("C", bound=Callable) 32 | Out = TypeVar("Out") 33 | P = ParamSpec("P") 34 | Aux = TypeVar("Aux") 35 | 36 | NestedDict = dict[K, V | "NestedDict[K, V]"] 37 | NestedMapping = Mapping[K, V | "NestedMapping[K, V]"] 38 | 39 | T = TypeVar("T") 40 | PyTree = T | tuple["PyTree[T]", ...] | list["PyTree[T]"] | dict[Any, "PyTree[T]"] 41 | 42 | Scalar = float | int | bool 43 | JaxPyTree = ( 44 | Scalar 45 | | jax.Array 46 | | tuple["JaxPyTree", ...] 47 | | list["JaxPyTree"] 48 | | Mapping[Any, "JaxPyTree"] 49 | ) 50 | TorchPyTree = ( 51 | Scalar 52 | | torch.Tensor 53 | | tuple["TorchPyTree", ...] 
54 | | list["TorchPyTree"] 55 | | Mapping[Any, "TorchPyTree"] 56 | ) 57 | Params = TypeVar("Params", bound=JaxPyTree) 58 | 59 | 60 | T = TypeVar("T", jax.Array, torch.Tensor) 61 | 62 | 63 | P = ParamSpec("P") 64 | Out_cov = TypeVar("Out_cov", covariant=True) 65 | 66 | 67 | @runtime_checkable 68 | class Module(Protocol[P, Out_cov]): 69 | """Protocol for a torch.nn.Module that gives better type hints for the `__call__` method.""" 70 | 71 | def forward(self, *args: P.args, **kwargs: P.kwargs) -> Out_cov: 72 | raise NotImplementedError 73 | 74 | if typing.TYPE_CHECKING: 75 | # note: Only define this for typing purposes so that we don't actually override anything. 76 | def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Out_cov: 77 | ... 78 | 79 | modules = torch.nn.Module.modules 80 | named_modules = torch.nn.Module.named_modules 81 | state_dict = torch.nn.Module.state_dict 82 | zero_grad = torch.nn.Module.zero_grad 83 | parameters = torch.nn.Module.parameters 84 | named_parameters = torch.nn.Module.named_parameters 85 | cuda = torch.nn.Module.cuda 86 | cpu = torch.nn.Module.cpu 87 | # note: the overloads on nn.Module.to cause a type-checking error about a missing `self`, 88 | # so a bound method is used here instead. This shouldn't be a problem. 89 | to = torch.nn.Module().to 90 | 91 | 92 | # NOTE: Not using a `runtime_checkable` version of the `Dataclass` protocol here, because it 93 | # doesn't work correctly in the case of `isinstance(SomeDataclassType, Dataclass)`, which returns 94 | # `True` when it should be `False` (since it's a dataclass type, not a dataclass instance), and the 95 | # runtime_checkable decorator doesn't check the type of the attribute (ClassVar vs instance 96 | # attribute). 97 | 98 | 99 | class _DataclassMeta(type): 100 | def __subclasscheck__(self, subclass: type) -> bool: 101 | return dataclasses.is_dataclass(subclass) and not dataclasses.is_dataclass( 102 | type(subclass) 103 | ) 104 | 105 | def __instancecheck__(self, instance: Any) -> bool: 106 | return dataclasses.is_dataclass(instance) and dataclasses.is_dataclass( 107 | type(instance) 108 | ) 109 | 110 | 111 | class Dataclass(metaclass=_DataclassMeta): 112 | """A class which is used to check if a given object is a dataclass. 113 | 114 | This plays nicely with @functools.singledispatch, allowing us to register functions to be used 115 | for dataclass inputs. 116 | """ 117 | 118 | 119 | class DataclassInstance(Protocol): 120 | # Copy of the type stub from dataclasses. 121 | __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] 122 | 123 | 124 | DataclassType = TypeVar("DataclassType", bound=DataclassInstance) 125 | 126 | 127 | def is_sequence_of( 128 | object: Any, item_type: type[V] | tuple[type[V], ...] 129 | ) -> TypeGuard[Sequence[V]]: 130 | """Used to check (and tell the type checker) that `object` is a sequence of items of type 131 | `V`.""" 132 | try: 133 | return all(isinstance(value, item_type) for value in object) 134 | except TypeError: 135 | return False 136 | 137 | 138 | def is_list_of( 139 | object: Any, item_type: type[V] | tuple[type[V], ...] 140 | ) -> TypeGuard[list[V]]: 141 | """Used to check (and tell the type checker) that `object` is a list of items of this type.""" 142 | return isinstance(object, list) and is_sequence_of(object, item_type) 143 | 144 |
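# --- Added illustration (not part of the original module) --------------------
# A minimal sketch of how the TypeGuard helpers above narrow types for a static
# type checker. The helper name below is hypothetical and exists only for this
# example; it assumes all tensors in `values` have the same shape.
def _example_stack_if_tensors(values: Sequence[Any]) -> torch.Tensor | None:
    if is_sequence_of(values, torch.Tensor):
        # In this branch, type checkers treat `values` as Sequence[torch.Tensor].
        return torch.stack(list(values))
    return None
# ------------------------------------------------------------------------------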
145 | def jit(fn: Callable[P, Out]) -> Callable[P, Out]: 146 | """Small type hint fix for jax's `jit` (preserves the signature of the callable).""" 147 | return jax.jit(fn) # type: ignore 148 | 149 | 150 | In = TypeVar("In") 151 | Aux = TypeVar("Aux") 152 | In2 = TypeVar("In2") 153 | In3 = TypeVar("In3") 154 | Ts = TypeVarTuple("Ts") 155 | 156 | 157 | # argnums = 0 158 | @overload 159 | def value_and_grad( 160 | fn: Callable[[In, *Ts], Out], 161 | argnums: Literal[0] = 0, 162 | has_aux: bool = ..., 163 | ) -> Callable[[In, *Ts], tuple[Out, In]]: 164 | ... 165 | 166 | 167 | @overload 168 | def value_and_grad( 169 | fn: Callable[[In, In2, *Ts], Out], 170 | argnums: tuple[Literal[0], Literal[1]], 171 | has_aux: bool = ..., 172 | ) -> Callable[[In, In2, *Ts], tuple[Out, tuple[In, In2]]]: 173 | ... 174 | 175 | 176 | @overload 177 | def value_and_grad( 178 | fn: Callable[[In, In2, In3, *Ts], Out], 179 | argnums: tuple[Literal[0], Literal[1], Literal[2]], 180 | has_aux: bool = ..., 181 | ) -> Callable[[In, In2, In3, *Ts], tuple[Out, tuple[In, In2, In3]]]: 182 | ... 183 | 184 | 185 | @overload 186 | def value_and_grad( 187 | fn: Callable[[In, *Ts], Out], 188 | argnums: tuple[Literal[0], Unpack[tuple[int, ...]]], 189 | has_aux: bool = ..., 190 | ) -> Callable[[In, Unpack[Ts]], tuple[Out, tuple[In, Unpack[Ts]]]]: 191 | ... 192 | 193 | 194 | @overload 195 | def value_and_grad( 196 | fn: Callable[[Unpack[Ts]], Out], 197 | argnums: Sequence[int], 198 | has_aux: bool = ..., 199 | ) -> Callable[[*Ts], tuple[Out, tuple[Any, ...]]]: 200 | ... 201 | 202 | 203 | def value_and_grad( # type: ignore 204 | fn: Callable[..., Out], 205 | argnums: int | Sequence[int] = 0, 206 | has_aux: bool = False, 207 | ): 208 | """Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable).""" 209 | return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) 210 | 211 | 212 | def chexify( 213 | fn: Callable[P, Out], 214 | async_check: bool = True, 215 | errors: FrozenSet[ 216 | jax.experimental.checkify.ErrorCategory 217 | ] = chex.ChexifyChecks.user, 218 | ) -> Callable[P, Out]: 219 | # Fix `chex.chexify` so it preserves the function's signature.
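# (added note) `functools.wraps(fn)` copies `fn`'s metadata (`__name__`, `__doc__`, `__module__`, `__wrapped__`) onto the wrapper returned by `chex.chexify`, which is what lets `inspect.signature` and IDEs report the original signature.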
220 | return functools.wraps(fn)(chex.chexify(fn, async_check=async_check, errors=errors)) # type: ignore 221 | -------------------------------------------------------------------------------- /torch_jax_interop/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | 4 | import jax 5 | import torch 6 | 7 | from .types import T 8 | 9 | 10 | def log_once(logger: logging.Logger, message: str, level: int): 11 | logger.log(level=level, msg=message, stacklevel=2) 12 | 13 | 14 | log_once = functools.cache(log_once) 15 | 16 | 17 | def is_channels_first(shape: torch.Size | tuple[int, ...]) -> bool: 18 | if len(shape) == 4: 19 | return is_channels_first(shape[1:]) 20 | if len(shape) != 3: 21 | return False 22 | return ( 23 | shape[0] in (1, 3) and shape[1] not in {1, 3} and shape[2] not in {1, 3} 24 | ) or (shape[0] < min(shape[1], shape[2])) 25 | 26 | 27 | def is_channels_last(shape: torch.Size | tuple[int, ...]) -> bool: 28 | if len(shape) == 4: 29 | return is_channels_last(shape[1:]) 30 | if len(shape) != 3: 31 | return False 32 | return ( 33 | shape[2] in (1, 3) and shape[0] not in {1, 3} and shape[1] not in {1, 3} 34 | ) or (shape[2] < min(shape[0], shape[1])) 35 | 36 | 37 | def to_channels_last(tensor: T) -> T: 38 | shape = tuple(tensor.shape) 39 | assert len(shape) == 3 or len(shape) == 4 40 | if not is_channels_first(shape): 41 | return tensor 42 | if isinstance(tensor, jax.Array): 43 | if len(shape) == 3: 44 | return tensor.transpose(1, 2, 0) 45 | return tensor.transpose(0, 2, 3, 1) 46 | else: 47 | if len(shape) == 3: 48 | return tensor.permute(1, 2, 0) 49 | return tensor.permute(0, 2, 3, 1) 50 | 51 | 52 | def to_channels_first(tensor: T) -> T: 53 | shape = tuple(tensor.shape) 54 | assert len(shape) == 3 or len(shape) == 4 55 | if is_channels_first(shape): 56 | return tensor 57 | if not is_channels_last(shape): 58 | return tensor 59 | if isinstance(tensor, jax.Array): 60 | if len(shape) == 3: 61 | # [H, W, C] -> [C, H, W] 62 | return tensor.transpose(2, 0, 1) 63 | # [B, H, W, C] -> [B, C, H, W] 64 | return tensor.transpose(0, 3, 1, 2) 65 | else: 66 | if len(shape) == 3: 67 | # [H, W, C] -> [C, H, W] 68 | return tensor.permute(2, 0, 1) 69 | return tensor.permute(0, 3, 1, 2) 70 | --------------------------------------------------------------------------------
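Usage sketch (added for illustration; not part of the repository): the snippet below combines the package-level conversion functions exercised in to_torch_test.py with the channel-layout helpers defined in torch_jax_interop/utils.py. Only `jax_to_torch` and `torch_to_jax` are confirmed as package-level exports by the tests above; importing the helpers from `torch_jax_interop.utils` and the shapes used here are assumptions made for this example.

import jax.numpy as jnp

from torch_jax_interop import jax_to_torch, torch_to_jax
from torch_jax_interop.utils import to_channels_first, to_channels_last

# A channels-last batch of images, as jax code typically lays them out: [B, H, W, C].
jax_images = jnp.zeros((8, 32, 32, 3))

# Convert to a torch.Tensor and move the channel axis first ([B, C, H, W]),
# the layout most torch vision models expect.
torch_images = to_channels_first(jax_to_torch(jax_images))
assert tuple(torch_images.shape) == (8, 3, 32, 32)

# Round-trip back to a jax.Array in channels-last layout.
jax_images_again = to_channels_last(torch_to_jax(torch_images))
assert jax_images_again.shape == (8, 32, 32, 3)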