├── pytorch2jax
│   ├── __init__.py
│   └── pytorch2jax.py
├── setup.py
├── LICENSE
├── .gitignore
└── README.md

/pytorch2jax/__init__.py:
--------------------------------------------------------------------------------
1 | from .pytorch2jax import convert_to_pyt, convert_to_jax, convert_pytnn_to_jax, convert_pytnn_to_flax, py_to_jax_wrapper
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 | 
6 | setup(
7 |     name="pytorch2jax",
8 |     version="0.1.0",
9 |     packages=find_packages(),
10 |     long_description=long_description,
11 |     long_description_content_type="text/markdown",
12 |     python_requires=">=3.6, <4",
13 |     install_requires=["torch", "jax", "jaxlib"],
14 |     classifiers=[
15 |         "Programming Language :: Python :: 3.6",
16 |         "Programming Language :: Python :: 3.7",
17 |         "Programming Language :: Python :: 3.8",
18 |         "Programming Language :: Python :: 3.9",
19 |         "Programming Language :: Python :: 3.10",
20 |     ],
21 |     author="Subhojeet Pramanik",
22 |     description="Convert PyTorch models to Jax functions and Flax models",
23 |     url="https://github.com/subho406/Pytorch2Jax",
24 | )
25 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Subhojeet Pramanik
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch2Jax
2 | 
3 | [![PyPI version](https://badge.fury.io/py/pytorch2jax.svg)](https://badge.fury.io/py/pytorch2jax)
4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5 | 
6 | Pytorch2Jax is a small Python library that provides functions to wrap PyTorch models into Jax functions and Flax modules. It uses `dlpack` to convert tensors between PyTorch and Jax in memory and runs the PyTorch computation inside the Jax-wrapped functions. The wrapped functions are compatible with Jax reverse-mode autodiff (`jax.grad` and `jax.vjp`), implemented via `functorch.vjp`, and can be used on any `dlpack`-compatible hardware.
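The low-level converters behind this `dlpack` round trip, `convert_to_pyt` and `convert_to_jax`, are exported as well. A minimal sketch of the conversion (assuming the package and its `torch`/`jax` dependencies are installed):

```python
import torch
from pytorch2jax import convert_to_jax, convert_to_pyt

# PyTorch -> Jax -> PyTorch via dlpack; the underlying buffer is shared, not copied.
pt_tensor = torch.ones(3, 4)
jax_array = convert_to_jax(pt_tensor)  # now a jax.numpy ndarray
back = convert_to_pyt(jax_array)       # a torch.Tensor again

print(type(jax_array), type(back))
```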
7 | 
8 | ## Installation
9 | 
10 | You can install the Pytorch2Jax package from PyPI via pip:
11 | ```
12 | pip install pytorch2jax
13 | ```
14 | 
15 | ## Usage
16 | ### Example 1: Wrap a PyTorch function into a function that accepts Jax tensors
17 | 
18 | ```python
19 | import torch
20 | import jax.numpy as jnp
21 | from pytorch2jax import py_to_jax_wrapper
22 | 
23 | # Define a PyTorch function that multiplies an input tensor with a random tensor
24 | # and wrap it with the py_to_jax_wrapper decorator
25 | @py_to_jax_wrapper
26 | def fn(x):
27 |     return torch.rand((10, 10)) * x
28 | 
29 | 
30 | # Call the wrapped function on a JAX array
31 | x = jnp.ones((10, 10))
32 | output = fn(x)
33 | 
34 | # Print the output
35 | print(output)
36 | 
37 | ```
38 | 
39 | ### Example 2: Convert a PyTorch model to a JAX function and differentiate with grad
40 | 
41 | The converted Jax function can be used seamlessly with Jax's `grad` transformation to compute gradients.
42 | ```python
43 | import jax.numpy as jnp
44 | import jax
45 | 
46 | import torch.nn as pnn
47 | 
48 | from pytorch2jax import convert_pytnn_to_jax
49 | 
50 | # Create a PyTorch model
51 | pyt_model = pnn.Linear(10, 10)
52 | 
53 | # Convert the PyTorch model to a JAX function and its parameters
54 | jax_fn, params = convert_pytnn_to_jax(pyt_model)
55 | 
56 | # Define a function that uses the JAX function and returns the sum of its output
57 | def fx(x):
58 |     return jax_fn(params, x).sum()
59 | 
60 | # Compute the gradient of the function `fx` with respect to `x`
61 | grad_fx = jax.grad(fx)
62 | x = jnp.ones((10,))
63 | print(grad_fx(x))  # Prints the gradient of fx at x
64 | 
65 | ```
66 | 
67 | ### Example 3: Convert a PyTorch model to a Flax model class and do a forward pass inside another Flax module
68 | 
69 | ```python
70 | import jax.numpy as jnp
71 | import jax
72 | import torch.nn as pnn
73 | import flax.linen as jnn
74 | 
75 | from pytorch2jax import convert_pytnn_to_flax
76 | from typing import Any
77 | 
78 | # Convert the PyTorch model with the 'convert_pytnn_to_flax' function:
79 | # flax_module is the converted Flax module class and params are its parameters
80 | pyt_model = pnn.Linear(10, 10)
81 | flax_module, params = convert_pytnn_to_flax(pyt_model)
82 | 
83 | # Define a new Flax module whose flax_module attribute holds the converted class.
84 | # Its __call__ method instantiates that class and calls the instance on the input.
85 | class SampleFlaxModule(jnn.Module):
86 |     flax_module: Any
87 | 
88 |     @jnn.compact
89 |     def __call__(self, x):
90 |         return self.flax_module()(x)
91 | 
92 | # Create an instance of the new Flax module
93 | flax_model = SampleFlaxModule(flax_module)
94 | 
95 | params = flax_model.init(jax.random.PRNGKey(0), jnp.ones((10, 10)))
96 | 
97 | # Apply the Flax model to the input to get the output
98 | flax_model.apply(params, jnp.ones((10, 10)))
99 | ```
100 | 
101 | ## Contributing
102 | 
103 | If you encounter any bugs or issues while using Pytorch2Jax, or if you have any suggestions for improvements or new features, please open an issue on the GitHub repository at https://github.com/subho406/Pytorch2Jax.
104 | 
105 | ## License
106 | 
107 | Pytorch2Jax is released under the MIT License. See LICENSE for more information.
108 | 
--------------------------------------------------------------------------------
/pytorch2jax/pytorch2jax.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import functorch
3 | import torch
4 | import jax
5 | from flax import linen as jnn
6 | from torch import nn as pnn
7 | from torch.utils.dlpack import from_dlpack as pyt_from_dlpack
8 | from torch.utils.dlpack import to_dlpack as pyt_to_dlpack
9 | from jax.dlpack import to_dlpack as jax_to_dlpack
10 | from jax.dlpack import from_dlpack as jax_from_dlpack
11 | from jax import custom_vjp
12 | from typing import Any, Callable, Iterable, Optional, Tuple, Union
13 | 
14 | 
15 | # Utilities for wrapping PyTorch functions and models for use from JAX and Flax.
16 | # Define functions for converting data between PyTorch and JAX representations.
17 | def convert_to_pyt(x):
18 |     # If x is a JAX ndarray, convert it to a DLPack capsule and then to a PyTorch tensor
19 |     if isinstance(x, jnp.ndarray):
20 |         x = jax_to_dlpack(x)
21 |         x = pyt_from_dlpack(x)
22 |     return x
23 | 
24 | def convert_to_jax(x):
25 |     # If x is a PyTorch tensor, convert it to a DLPack capsule and then to a JAX ndarray
26 |     if isinstance(x, torch.Tensor):
27 |         x = pyt_to_dlpack(x)
28 |         x = jax_from_dlpack(x)
29 |     return x
30 | 
31 | 
32 | def convert_pytnn_to_jax(model):
33 |     # Convert the PyTorch model to a functional representation: a pure model function plus its parameters
34 |     model_fn, model_params = functorch.make_functional(model)
35 | 
36 | 
37 | 
38 |     # Convert the model parameters from PyTorch tensors to JAX arrays
39 |     model_params = jax.tree_map(convert_to_jax, model_params)
40 | 
41 |     # Define the apply function with a custom VJP so gradients flow back through PyTorch
42 |     @custom_vjp
43 |     def apply(params, *args, **kwargs):
44 |         # Convert the inputs from JAX arrays to PyTorch tensors
45 |         params = jax.tree_map(convert_to_pyt, params)
46 |         args = jax.tree_map(convert_to_pyt, args)
47 |         kwargs = jax.tree_map(convert_to_pyt, kwargs)
48 |         # Apply the model function to the input data
49 |         out = model_fn(params, *args, **kwargs)
50 |         # Convert the outputs from PyTorch tensors back to JAX arrays
51 |         out = jax.tree_map(convert_to_jax, out)
52 |         return out
53 | 
54 |     # Define the forward and backward passes for the VJP
55 |     def apply_fwd(params, *args, **kwargs):
56 |         return apply(params, *args, **kwargs), (params, args, kwargs)
57 | 
58 |     def apply_bwd(res, grads):
59 |         params, args, kwargs = res
60 |         # Convert the saved inputs and the incoming cotangents from JAX arrays to PyTorch tensors
61 |         params = jax.tree_map(convert_to_pyt, params)
62 |         args = jax.tree_map(convert_to_pyt, args)
63 |         kwargs = jax.tree_map(convert_to_pyt, kwargs)
64 |         grads = jax.tree_map(convert_to_pyt, grads)
65 |         # Compute the input gradients with functorch.vjp and convert them back to JAX arrays
66 |         grads = functorch.vjp(model_fn, params, *args, **kwargs)[1](grads)
67 |         grads = jax.tree_map(convert_to_jax, grads)
68 |         return grads
69 |     apply.defvjp(apply_fwd, apply_bwd)
70 | 
71 |     # Return the apply function and the converted model parameters
72 |     return apply, model_params
73 | 
74 | 
75 | def convert_pytnn_to_flax(model):
76 |     # Define a Flax module that wraps the JAX-converted PyTorch model
77 |     jax_fn, params = convert_pytnn_to_jax(model)
78 |     class FlaxModule(jnn.Module):
79 |         # Register the converted model parameters as a Flax parameter
80 |         def setup(self):
81 |             self.jax_param = self.param('jax_params', lambda x: params)
82 | 
83 |         # Apply the JAX-converted model function to the input data
84 |         def __call__(self, x):
85 |             return jax_fn(self.jax_param, x)
86 |     return FlaxModule, params
87 | 
88 | 
89 | def py_to_jax_wrapper(fun):
90 |     def wrapper(*args, **kwargs):
91 |         # Convert the inputs from JAX arrays to PyTorch tensors
92 |         args = jax.tree_map(convert_to_pyt, args)
93 |         kwargs = jax.tree_map(convert_to_pyt, kwargs)
94 |         # Call the wrapped PyTorch function
95 |         out = fun(*args, **kwargs)
96 |         # Convert the outputs from PyTorch tensors back to JAX arrays
97 |         out = jax.tree_map(convert_to_jax, out)
98 |         return out
99 |     return wrapper
--------------------------------------------------------------------------------
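For completeness, a hedged end-to-end sanity check (not a file from the repository above; it assumes the published `pytorch2jax` package plus `torch`, `jax`, and `functorch` are installed) showing that the custom VJP lets `jax.grad` differentiate with respect to both the converted parameters and the input:

```python
import jax
import jax.numpy as jnp
import torch.nn as pnn

from pytorch2jax import convert_pytnn_to_jax

# Convert a small PyTorch model; params comes back as a pytree of JAX arrays.
jax_fn, params = convert_pytnn_to_jax(pnn.Linear(4, 2))

def loss(params, x):
    # The forward pass runs in PyTorch under the hood; the output is a JAX array.
    return jax_fn(params, x).sum()

x = jnp.ones((4,))
# Gradients for both arguments flow through the apply_bwd pass defined above.
param_grads, x_grad = jax.grad(loss, argnums=(0, 1))(params, x)
print(jax.tree_map(jnp.shape, param_grads), x_grad.shape)
```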