├── .github ├── ISSUE_TEMPLATE │ └── feature_request.yml └── workflows │ ├── build.yml │ └── pythonpublish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── contributing.md ├── docs ├── _static │ └── css │ │ └── jaxonnxruntime_theme.css ├── adding_a_new_op.rst ├── conf.py ├── examples │ └── jaxonnxruntime-python-api.py ├── index.rst └── requirements.txt ├── jaxonnxruntime ├── __init__.py ├── backend.py ├── core │ ├── __init__.py │ ├── call_onnx.py │ ├── call_onnx_test.py │ ├── config_class.py │ ├── handler.py │ ├── handler_test.py │ ├── onnx_graph.py │ ├── onnx_graph_test.py │ ├── onnx_node.py │ ├── onnx_node_test.py │ ├── onnx_utils.py │ └── onnx_utils_test.py ├── experimental │ ├── call_torch │ │ ├── Notes.md │ │ ├── __init__.py │ │ ├── _call_torch.py │ │ ├── call_torch_test.py │ │ ├── call_torch_test_util.py │ │ ├── call_torch_xla.py │ │ ├── call_torch_xla_test.py │ │ └── test_data │ │ │ └── d2l_torch.py │ ├── custom_ops │ │ ├── __init__.py │ │ ├── zeros_like.py │ │ └── zeros_like_test.py │ └── export │ │ ├── __init__.py │ │ ├── exportable.py │ │ ├── exportable_test.py │ │ ├── exportable_test_utils.py │ │ ├── exportable_utils.py │ │ ├── jax_exported_test.py │ │ ├── tensorflow_exportable.py │ │ ├── tensorflow_exportable_test.py │ │ ├── torch_exportable.py │ │ └── torch_exportable_test.py ├── onnx_ops │ ├── __init__.py │ ├── abs.py │ ├── acos.py │ ├── acosh.py │ ├── add.py │ ├── and_op.py │ ├── argmax.py │ ├── argmin.py │ ├── asin.py │ ├── asinh.py │ ├── atan.py │ ├── atanh.py │ ├── averagepool.py │ ├── batchnormalization.py │ ├── bitshift.py │ ├── cast.py │ ├── castlike.py │ ├── ceil.py │ ├── clip.py │ ├── concat.py │ ├── constant.py │ ├── constantofshape.py │ ├── conv.py │ ├── cos.py │ ├── cosh.py │ ├── dequantizelinear.py │ ├── div.py │ ├── dropout.py │ ├── dropout_test.py │ ├── einsum.py │ ├── equal.py │ ├── erf.py │ ├── exp.py │ ├── expand.py │ ├── flatten.py │ ├── gather.py │ ├── gatherelements.py │ ├── gemm.py │ ├── globalaveragepool.py │ ├── greater.py │ ├── greaterorequal.py │ ├── identity.py │ ├── if_op.py │ ├── leakyrelu.py │ ├── less.py │ ├── lessorequal.py │ ├── log.py │ ├── logsoftmax.py │ ├── lrn.py │ ├── matmul.py │ ├── max.py │ ├── maxpool.py │ ├── min.py │ ├── mul.py │ ├── neg.py │ ├── nonzero.py │ ├── onehot.py │ ├── onehot_test.py │ ├── onnx_not.py │ ├── onnx_ops_utils.py │ ├── or_op.py │ ├── pad.py │ ├── pow.py │ ├── prelu.py │ ├── quantizelinear.py │ ├── range.py │ ├── reciprocal.py │ ├── reducemax.py │ ├── reducemean.py │ ├── reducesum.py │ ├── relu.py │ ├── reshape.py │ ├── scatterelements.py │ ├── scatternd.py │ ├── selu.py │ ├── shape.py │ ├── sigmoid.py │ ├── sin.py │ ├── sinh.py │ ├── slice.py │ ├── softmax.py │ ├── softplus.py │ ├── split.py │ ├── sqrt.py │ ├── squeeze.py │ ├── sub.py │ ├── sum.py │ ├── tanh.py │ ├── tile.py │ ├── topk.py │ ├── transpose.py │ ├── trilu.py │ ├── unsqueeze.py │ └── where.py ├── runner.py └── version.py ├── pylintrc ├── pyproject.toml ├── tests ├── onnx_models_test.py ├── onnx_ops_test.py └── run_all_tests.sh └── tools ├── analyze_model.py └── op_code_generator.py /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | name: Feature request 4 | description: Create a feature request for a functionality that does not currently exist. 
5 | title: "[Feature request] " 6 | labels: ["enhancement"] 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: Thanks for taking the time to create a feature request! 11 | - type: textarea 12 | id: system-info 13 | attributes: 14 | label: System information 15 | description: "jaxonnxruntime version (you are using):" 16 | validations: 17 | required: false 18 | - type: textarea 19 | id: solves-problem 20 | attributes: 21 | label: What is the problem that this feature solves? 22 | description: Please detail the discrepancy with our current functionality. 23 | validations: 24 | required: false 25 | - type: textarea 26 | id: alternatives 27 | attributes: 28 | label: Alternatives considered 29 | description: Describe the alternatives you have considered 30 | placeholder: A clear and concise description of any alternative solutions or features you've considered. 31 | validations: 32 | required: false 33 | - type: textarea 34 | id: feature 35 | attributes: 36 | label: Describe the feature 37 | description: Why is this feature necessary? What does it accomplish? 38 | validations: 39 | required: false 40 | - type: textarea 41 | id: api-impact 42 | attributes: 43 | label: Will this influence the current api (Y/N)? 44 | placeholder: If yes, how? 45 | validations: 46 | required: false 47 | - type: textarea 48 | id: feature-area 49 | attributes: 50 | label: Feature Area 51 | description: Which area does this impact? e.g., model usage, backend, best practices, shape_inference, training, test, operators, data preprocessing, CI pipelines. 52 | validations: 53 | required: false 54 | - type: dropdown 55 | id: contribute 56 | attributes: 57 | label: "Are you willing to contribute it (Y/N)" 58 | options: 59 | - "Yes" 60 | - "No" 61 | validations: 62 | required: false 63 | - type: textarea 64 | id: notes 65 | attributes: 66 | label: Notes 67 | description: Any additional information 68 | validations: 69 | required: false 70 | -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools build wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 29 | run: | 30 | python -m build 31 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | \#*\# 3 | *.pyc 4 | .tfds 5 | .DS_Store 6 | docs/**/_autosummary 7 | docs/_build 8 | dist/ 9 | build/ 10 | *.egg-info 11 | *.rej 12 | .pytype 13 | .vscode/* 14 | /.devcontainer 15 | docs/**/tmp 16 | .coverage 17 | coverage.xml 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | # Install the pre-commit hooks below with 2 | # 'pre-commit install' 3 | 4 | # Auto-update the version of the hooks with 5 | # 'pre-commit autoupdate' 6 | 7 | # Run the hooks on all files with 8 | # 'pre-commit run --all' 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.4.0 13 | hooks: 14 | - id: check-toml 15 | - id: trailing-whitespace 16 | exclude: ^docs/.*\.md$ 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAX ONNX Runtime 2 | 3 | JAX ONNX Runtime is a robust and user-friendly toolchain that enables the seamless execution of ONNX models using JAX as the backend. 4 | 5 | More specifically, this toolchain provides the following capabilities: 6 | 7 | - ONNX Model Conversion: Converts ONNX models into JAX format modules. Tested on popular large language models including GPT-2, BERT, and LLaMA. 8 | 9 | - Hardware Acceleration: Enables jit compilation of the converted JAX modules, which accelerates execution on GPUs and/or TPUs. 10 | 11 | - Compatibility with the JAX ecosystem: e.g., export models with Orbax, and serve the saved models with the TensorFlow Serving system. 12 | 13 | ## Get Started 14 | 15 | - We follow most of the interface definitions of `onnx.backend` [here](https://onnx.ai/onnx/api/backend.html). 16 | 17 | - Please see a brief example of model conversion and forward calling in [`examples/imagenet/imagenet_main.py`](https://github.com/google/jaxonnxruntime/blob/main/examples/imagenet/imagenet_main.py). 18 | 19 | ## Contributions and Discussions 20 | 21 | We believe that collaboration is the key to building remarkable software, and we wholeheartedly welcome contributions from developers like you. 22 | You can make a real impact and help shape the future of our project with contributions such as 23 | [implementing new operators](https://github.com/google/jaxonnxruntime/blob/main/docs/adding_a_new_op.rst) and increasing support for more ML models. 24 | 25 | Our contributors will have a chance to earn [Google Open Source Peer Bonus](https://opensource.google/documentation/reference/growing/peer-bonus), so that your valuable contributions won't go unnoticed. 26 | Your hard work will be rewarded both by the community and by Google. 27 | Together, let's create an amazing library and foster a supportive environment for open-source enthusiasts. 28 | 29 | Thank you for taking the time to contribute! Please see [the contribution guidelines](https://github.com/google/jaxonnxruntime/blob/main/contributing.md). 30 | 31 | ## License 32 | 33 | This project is licensed under the [Apache License](https://github.com/google/jaxonnxruntime/blob/main/LICENSE). 34 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to jaxonnxruntime 2 | 3 | 🎉🎉 First off, thank you for taking the time to contribute! 🎉🎉 4 | 5 | The following is a set of guidelines, but not rules, for contributing to jaxonnxruntime. 6 | Use your best judgment, and feel free to propose changes to this document in a pull request. 7 | 8 | We follow most of the best practices listed in the [contributing guidelines](https://github.com/google/flax/blob/main/docs/contributing.md) of the `google/flax` project. 9 | Here we only list the differences.
10 | 11 | 21 | 22 | ## How to Contribute? 23 | 24 | ### Adding Support for New Operators 25 | 26 | When running a new ONNX model, you may find that some operators are not yet implemented in our repository. 27 | You can also check which operators are missing by running the command: 28 | ```shell 29 | $ python tools/analyze_model.py 30 | ``` 31 | All source code for the JAX backend of ONNX operators is located in [jaxonnxruntime/onnx_ops](https://github.com/google/jaxonnxruntime/tree/main/jaxonnxruntime/onnx_ops). 32 | You can generate a new template implementation for an ONNX operator by running the following command, and then fill in the blanks marked by ```TODO```. 33 | Please do use JAX methods in the implementation! 34 | ```shell 35 | $ python tools/op_code_generator.py 36 | ``` 37 | 38 | After finishing the implementation, please add a unit test for the new operator by adding the line 39 | ```python 40 | include_patterns.append('test_{op_name_lower}_') 41 | ``` 42 | to [```tests/onnx_ops_test.py```](https://github.com/google/jaxonnxruntime/blob/main/tests/onnx_ops_test.py). 43 | Then run the unit tests with the following command to make sure the new operator is compatible with both ONNX and JAX. 44 | Make sure all tests pass before submitting a pull request. 45 | ```shell 46 | $ python tests/onnx_ops_test.py 47 | ``` 48 | -------------------------------------------------------------------------------- /docs/_static/css/jaxonnxruntime_theme.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-nav-content { 4 | max-width: 1290px; 5 | } 6 | 7 | .rst-content table.docutils { 8 | width: 100%; 9 | } 10 | 11 | .rst-content table.docutils td { 12 | vertical-align: top; 13 | padding: 0; 14 | } 15 | 16 | .rst-content table.docutils td p { 17 | padding: 8px; 18 | } 19 | 20 | .rst-content div[class^=highlight] { 21 | border: 0; 22 | margin: 0; 23 | } 24 | -------------------------------------------------------------------------------- /docs/examples/jaxonnxruntime-python-api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License.
28 | """Jaxonnxruntime python API example.""" 29 | from jaxonnxruntime import backend as jax_backend 30 | import numpy as np 31 | import torch 32 | 33 | import onnx 34 | 35 | MODEL_FILE = ".model.onnx" 36 | 37 | 38 | def model(): 39 | """A simple torch model to calculate addition of two tensors.""" 40 | 41 | class Model(torch.nn.Module): 42 | 43 | def forward(self, x, y): 44 | return x.add(y) 45 | 46 | return Model() 47 | 48 | 49 | def create_model(dtype: torch.dtype = torch.float32): 50 | """Create an instance of the model and export it to ONNX graph format, with dynamic size for the data.""" 51 | sample_x = torch.ones(3, dtype=dtype) 52 | sample_y = torch.zeros(3, dtype=dtype) 53 | 54 | torch.onnx.export( 55 | model(), 56 | (sample_x, sample_y), 57 | MODEL_FILE, 58 | input_names=["x", "y"], 59 | output_names=["z"], 60 | dynamic_axes={"x": {0: "array_length_x"}, "y": {0: "array_length_y"}}, 61 | ) 62 | 63 | 64 | def main(): 65 | """main function.""" 66 | create_model() 67 | onnx_model = onnx.load(MODEL_FILE) 68 | backend_rep = jax_backend.BackendRep(onnx_model) 69 | 70 | # Run the model on CPU consuming and producing numpy arrays 71 | def run(x: np.array, y: np.array) -> np.array: 72 | z = backend_rep.run({"x": x, "y": y}) 73 | return z[0] 74 | 75 | print(run(x=np.float32([1.0, 2.0, 3.0]), y=np.float32([4.0, 5.0, 6.0]))) 76 | # [array([5., 7., 9.], dtype=float32)] 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Jaxonnxruntime documentation main file, created by 2 | sphinx-quickstart on Mon Feb 17 11:41:38 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | ****************************** 7 | Jaxonnxruntime 8 | ****************************** 9 | 10 | 11 | .. div:: sd-text-left sd-font-italic 12 | 13 | JAX based ONNX backend 14 | ------------------------------ 15 | 16 | `Jaxonnxruntime` is focused on creating a JAX-based backend for the ONNX format. The benefits of using ONNX include interoperability and ease of hardware access, while JAX provides a similar API to Numpy and allows for performance speed-ups through jit compilation. 17 | 18 | `Jaxonnxruntime` implements the backend by re-writing the ONNX operator implementations in the "JAX programming way" and interpreting all data structures as PyTree. The user will be able to run the jit function on the run_model function for performance speed-up and apply other Jax transformations. 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | :caption: Getting Started 23 | 24 | adding_a_new_op -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx>=3.3.1 2 | sphinx-book-theme 3 | Pygments>=2.6.1 4 | ipykernel 5 | myst_nb 6 | recommonmark 7 | ipython_genutils 8 | sphinx-design 9 | jupytext==1.13.8 10 | torchvision 11 | 12 | # Need to pin docutils to 0.16 to make bulleted lists appear correctly on 13 | # ReadTheDocs: https://stackoverflow.com/a/68008428 14 | docutils==0.16 15 | 16 | # The next packages are for notebooks. 17 | matplotlib 18 | scikit-learn 19 | # Must install itself for notebook execution and autodocs to work. 20 | . 
21 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """core util functions for onnx.""" 16 | 17 | from jaxonnxruntime.core import config_class 18 | from jaxonnxruntime.core import handler 19 | from jaxonnxruntime.core import onnx_graph 20 | from jaxonnxruntime.core import onnx_node 21 | from jaxonnxruntime.core import onnx_utils 22 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/call_onnx_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
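The test that follows exercises `call_onnx.call_onnx_model` directly. Below is a minimal usage sketch of the same API, including the jit acceleration described in the README and docs, assuming an ONNX model file on disk; the file name, input shape, and input values are illustrative only:

```python
import jax
import numpy as np
import onnx
from jaxonnxruntime.core import call_onnx

# Load an ONNX model from disk ("model.onnx" is a placeholder path).
onnx_model = onnx.load("model.onnx")
inputs = [np.ones((1, 3), dtype=np.float32)]  # illustrative input list

# call_onnx_model returns a pure JAX function plus the model parameters.
jax_func, model_params = call_onnx.call_onnx_model(onnx_model, inputs)

# The returned function is jittable, so repeated calls run compiled.
jitted_func = jax.jit(jax_func)
outputs = jitted_func(model_params, inputs)
```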
14 | 15 | from absl.testing import absltest 16 | import jax 17 | from jaxonnxruntime.core import call_onnx 18 | from jaxonnxruntime.core import config_class 19 | from jaxonnxruntime.onnx_ops import abs 20 | import numpy as np 21 | 22 | import onnx 23 | 24 | 25 | def create_test_model(x: np.ndarray) -> onnx.ModelProto: 26 | input_tensor = onnx.ValueInfoProto( 27 | name='input', 28 | type=onnx.TypeProto( 29 | tensor_type=onnx.TypeProto.Tensor( 30 | elem_type=onnx.TensorProto.FLOAT, 31 | shape=onnx.TensorShapeProto( 32 | dim=[ 33 | onnx.TensorShapeProto.Dimension(dim_value=d) 34 | for d in x.shape 35 | ] 36 | ), 37 | ) 38 | ), 39 | ) 40 | output_tensor = onnx.ValueInfoProto( 41 | name='output', 42 | type=onnx.TypeProto( 43 | tensor_type=onnx.TypeProto.Tensor( 44 | elem_type=onnx.TensorProto.FLOAT, 45 | shape=onnx.TensorShapeProto( 46 | dim=[ 47 | onnx.TensorShapeProto.Dimension(dim_value=d) 48 | for d in x.shape 49 | ] 50 | ), 51 | ) 52 | ), 53 | ) 54 | node_abs = onnx.NodeProto(op_type='Abs', input=['input'], output=['output']) 55 | graph_def = onnx.GraphProto( 56 | node=[node_abs], 57 | name='abs_graph', 58 | input=[input_tensor], 59 | output=[output_tensor], 60 | ) 61 | model_proto = onnx.ModelProto(graph=graph_def, producer_name='onnx-example') 62 | return model_proto 63 | 64 | 65 | class TestCallOnnx(absltest.TestCase): 66 | 67 | def test_basic(self): 68 | x = np.array([-2.0, 1.0, 3.0], dtype=np.float32) 69 | model_proto = create_test_model(x) 70 | jax_func, model_params = call_onnx.call_onnx_model(model_proto, [x]) 71 | results = jax_func(model_params, [x]) 72 | expect = [np.array([2.0, 1.0, 3.0], dtype=np.float32)] 73 | np.testing.assert_array_equal(results, expect) 74 | 75 | with config_class.jaxort_experimental_support_abstract_shape(True): 76 | x = np.array([-2.0, -8.0, 3.0], dtype=np.float32) 77 | jax_func, model_params = call_onnx.call_onnx_model( 78 | model_proto, [jax.ShapeDtypeStruct(x.shape, x.dtype)] 79 | ) 80 | results = jax_func(model_params, [x]) 81 | expect = [np.array([2.0, 8.0, 3.0], dtype=np.float32)] 82 | np.testing.assert_array_equal(results, expect) 83 | 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/handler_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import handler 17 | from onnx import defs 18 | 19 | 20 | class TestHandler(absltest.TestCase): 21 | 22 | def test_get_since_version(self): 23 | class MyOpHandler(handler.Handler): 24 | pass 25 | 26 | MyOpHandler.DOMAIN = "" 27 | MyOpHandler.OP_TYPE = "Add" 28 | version = 11 29 | print("get_all_schemas", defs.get_all_schemas()) 30 | schema = defs.get_schema( 31 | MyOpHandler.OP_TYPE, 32 | max_inclusive_version=version, 33 | domain=MyOpHandler.DOMAIN, 34 | ) 35 | since_version = MyOpHandler.get_since_version(version) 36 | self.assertEqual(since_version, schema.since_version) 37 | 38 | def test_register_op(self): 39 | @handler.register_op("my_op", domain="ai.onnx") 40 | class MyOpHandler(handler.Handler): 41 | pass 42 | 43 | self.assertEqual(MyOpHandler.OP_TYPE, "my_op") 44 | self.assertEqual(MyOpHandler.DOMAIN, "ai.onnx") 45 | 46 | def test_castlike_model_version_14(self): 47 | class MyOpHandler(handler.Handler): 48 | pass 49 | 50 | MyOpHandler.DOMAIN = "" 51 | MyOpHandler.OP_TYPE = "CastLike" 52 | version = 14 53 | since_version = MyOpHandler.get_since_version(version) 54 | 55 | # CastLike was added in version 15. So there is no CastLike in model 56 | # version 14. 57 | self.assertEqual(since_version, -1) 58 | 59 | 60 | if __name__ == "__main__": 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/onnx_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import onnx_graph 17 | 18 | import onnx 19 | 20 | OnnxGraph = onnx_graph.OnnxGraph 21 | 22 | 23 | class TestOnnxGraph(absltest.TestCase): 24 | 25 | def setUp(self): 26 | # create a simple ONNX graph proto with an Add and Relu node 27 | super().setUp() 28 | graph_proto = onnx.GraphProto( 29 | input=[onnx.ValueInfoProto(name="x")], 30 | output=[onnx.ValueInfoProto(name="y")], 31 | node=[ 32 | onnx.NodeProto( 33 | name="node_0", 34 | op_type="Add", 35 | input=["x", "x"], 36 | output=["add_out"], 37 | ), 38 | onnx.NodeProto( 39 | name="node_1", 40 | op_type="Conv", 41 | input=["add_out", "weight"], 42 | output=["conv_out"], 43 | ), 44 | onnx.NodeProto( 45 | name="node_2", 46 | op_type="Relu", 47 | input=["conv_out"], 48 | output=["y"], 49 | ), 50 | ], 51 | ) 52 | self.graph = OnnxGraph(graph_proto) 53 | 54 | def test_get_real_input(self): 55 | real_input = self.graph.get_real_input() 56 | self.assertEqual(real_input, ["x", "weight"]) 57 | 58 | def test_get_parent_nodes_name(self): 59 | parent_nodes = self.graph.get_parent_nodes_name("node_1") 60 | self.assertEqual(parent_nodes, ["node_0"]) 61 | 62 | def test_get_child_nodes_name(self): 63 | child_nodes = self.graph.get_child_nodes_name("node_1") 64 | self.assertEqual(child_nodes, ["node_2"]) 65 | 66 | def test_topological_sort(self): 67 | node_order = self.graph.topological_sort() 68 | self.assertLen(node_order, 3) 69 | self.assertEqual(node_order[0].op_type, "Add") 70 | self.assertEqual(node_order[1].op_type, "Conv") 71 | self.assertEqual(node_order[2].op_type, "Relu") 72 | 73 | 74 | if __name__ == "__main__": 75 | absltest.main() 76 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/onnx_node_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import onnx_node 17 | 18 | import onnx 19 | 20 | 21 | OnnxNode = onnx_node.OnnxNode 22 | convert_onnx = onnx_node.convert_onnx 23 | 24 | 25 | class TestOnnxNode(absltest.TestCase): 26 | 27 | def test_onnx_node_init(self): 28 | # Create a dummy NodeProto object 29 | node_proto = onnx.NodeProto() 30 | node_proto.name = "test_node" 31 | node_proto.op_type = "Add" 32 | node_proto.domain = "test_domain" 33 | node_proto.attribute.add(name="test_attr", i=42) 34 | node_proto.input.extend(["input1", "input2"]) 35 | node_proto.output.extend(["output1", "output2"]) 36 | 37 | # Create an OnnxNode object 38 | node = OnnxNode(node_proto) 39 | 40 | # Test that the attributes were correctly set 41 | self.assertEqual(node.name, "test_node") 42 | self.assertEqual(node.op_type, "Add") 43 | self.assertEqual(node.domain, "test_domain") 44 | self.assertEqual(node.attrs["test_attr"], 42) 45 | self.assertEqual(node.inputs, ["input1", "input2"]) 46 | self.assertEqual(node.outputs, ["output1", "output2"]) 47 | self.assertEqual(node.node_proto, node_proto) 48 | self.assertIsNone(node.context_graph) 49 | 50 | def test_convert_onnx(self): 51 | # Test converting a few different types of attributes 52 | attr_proto = onnx.AttributeProto() 53 | attr_proto.f = 3.14 54 | self.assertLess(abs(float(convert_onnx(attr_proto)) - 3.14), 0.001) 55 | 56 | attr_proto = onnx.AttributeProto() 57 | attr_proto.i = 42 58 | self.assertEqual(convert_onnx(attr_proto), 42) 59 | 60 | attr_proto = onnx.AttributeProto() 61 | attr_proto.s = b"test_string" 62 | self.assertEqual(convert_onnx(attr_proto), "test_string") 63 | 64 | attr_proto = onnx.AttributeProto() 65 | tensor_proto = onnx.TensorProto() 66 | tensor_proto.dims.extend([2, 3]) 67 | tensor_proto.float_data.extend([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 68 | attr_proto.t.CopyFrom(tensor_proto) 69 | self.assertEqual(convert_onnx(attr_proto), tensor_proto) 70 | 71 | 72 | if __name__ == "__main__": 73 | absltest.main() 74 | -------------------------------------------------------------------------------- /jaxonnxruntime/core/onnx_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import onnx_utils 17 | 18 | import onnx 19 | 20 | 21 | class TestOnnxUtils(absltest.TestCase): 22 | 23 | def test_sanitize_tensor_names_in_graph(self): 24 | else_branch = onnx.GraphProto( 25 | input=[onnx.ValueInfoProto(name="else_in")], 26 | output=[onnx.ValueInfoProto(name="else_out")], 27 | node=[ 28 | onnx.NodeProto( 29 | name="node_0", 30 | op_type="Identity", 31 | input=["else_in"], 32 | output=["else_out"], 33 | ), 34 | ], 35 | ) 36 | else_attr = onnx.AttributeProto(g=else_branch) 37 | graph = onnx.GraphProto( 38 | input=[onnx.ValueInfoProto(name="x")], 39 | initializer=[ 40 | onnx.TensorProto(name="else_in"), 41 | ], 42 | output=[onnx.ValueInfoProto(name="y")], 43 | node=[ 44 | onnx.NodeProto( 45 | name="node_1", 46 | op_type="If", 47 | input=["x"], 48 | output=["y"], 49 | attribute=[else_attr], # Omit then branch for simplicity 50 | ), 51 | ], 52 | ) 53 | onnx_utils.sanitize_tensor_names_in_graph(graph) 54 | # Tensor names change: 55 | # x -> tensor_0, y -> tensor_1, else_in -> tensor_2, else_out -> tensor_3 56 | # Graph inputs & initializers & outputs 57 | self.assertEqual(graph.input[0].name, "tensor_0") 58 | self.assertEqual(graph.initializer[0].name, "tensor_2") 59 | self.assertEqual(graph.output[0].name, "tensor_1") 60 | # Node inputs & outputs 61 | self.assertEqual(graph.node[0].input[0], "tensor_0") 62 | self.assertEqual(graph.node[0].output[0], "tensor_1") 63 | # Subgraph inputs & outputs & nodes 64 | subgraph = graph.node[0].attribute[0].g 65 | self.assertEqual(subgraph.input[0].name, "tensor_2") 66 | self.assertEqual(subgraph.output[0].name, "tensor_3") 67 | self.assertEqual(subgraph.node[0].input[0], "tensor_2") 68 | self.assertEqual(subgraph.node[0].output[0], "tensor_3") 69 | 70 | 71 | if __name__ == "__main__": 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/call_torch/Notes.md: -------------------------------------------------------------------------------- 1 | This is only for OSS pytorch env. The internal ONNX and PyTorch is out-of-date. 2 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/call_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """call_torch module to support convert pytorch module to jax function.""" 16 | 17 | from ._call_torch import call_torch 18 | from ._call_torch import torch_tensor_to_jax_array 19 | from .call_torch_test_util import CallTorchTestCase 20 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/call_torch/_call_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Convert a PyTorch function to a JAX function.""" 16 | 17 | import io 18 | import os 19 | from typing import Any, Callable, Tuple, Union 20 | 21 | from absl import logging 22 | import jax 23 | from jaxonnxruntime.core import call_onnx 24 | import torch 25 | 26 | import onnx 27 | 28 | 29 | class TorchONNXExportError(Exception): 30 | pass 31 | 32 | 33 | def torch_tensor_to_jax_array( 34 | tensor: torch.Tensor, inplace: bool = False 35 | ) -> jax.Array: 36 | """Convert a torch tensor to a jax array.""" 37 | if not inplace: 38 | tensor = tensor.clone().detach() 39 | return jax.dlpack.from_dlpack(tensor) 40 | 41 | 42 | def call_torch( 43 | model: Union[ 44 | torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction 45 | ], 46 | args: Union[Tuple[Any, ...], torch.Tensor], 47 | onnx_dump_prefix: str | None = None, 48 | verbose: bool = False, 49 | ) -> Tuple[Callable[..., Any], Any]: 50 | """Given a PyTorch model, return its equivalent JAX function. 51 | 52 | Its API interface should be consistent with 53 | [`torch.onnx.export`](https://pytorch.org/docs/stable/onnx.html#torch.onnx.export) 54 | 55 | Args: 56 | model: the torch model to be exported. 57 | args: (tuple or torch.Tensor), model input args for torch.onnx.export. 58 | onnx_dump_prefix: The onnx_model debug directory. 59 | verbose: (bool, default False) if True, prints more debugging info. 60 | 61 | Returns: 62 | A JAX jittable function that can be invoked with JAX pytree arguments. 63 | """ 64 | file_obj = io.BytesIO() 65 | try: 66 | torch.onnx.export( 67 | model=model, 68 | args=args, 69 | f=file_obj, 70 | export_params=True, 71 | verbose=verbose, 72 | dynamic_axes=None, 73 | keep_initializers_as_inputs=False, 74 | ) 75 | except Exception as e: 76 | raise TorchONNXExportError( 77 | "torch.onnx.export fails. Please debug torch.onnx.export manually" 78 | " first." 79 | ) from e 80 | if onnx_dump_prefix: 81 | if not os.path.exists(onnx_dump_prefix): 82 | os.makedirs(onnx_dump_prefix) 83 | onnx_model_file = os.path.join(onnx_dump_prefix, "model.onnx") 84 | with open(onnx_model_file, "wb") as f: 85 | f.write(file_obj.getvalue()) 86 | logging.info("Saving debug model.onnx to %s", onnx_model_file) 87 | file_obj.seek(0) 88 | onnx_model = onnx.load(file_obj) 89 | jax_args = jax.tree_util.tree_leaves( 90 | jax.tree.map(torch_tensor_to_jax_array, args) 91 | ) 92 | jax_fn, jax_model_params = call_onnx.call_onnx_model(onnx_model, jax_args) 93 | return jax_fn, jax_model_params 94 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/custom_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define custom ONNX ops.""" 16 | 17 | 18 | # PEP 484: import as is required for names to be exported. 19 | from jaxonnxruntime.experimental.custom_ops import zeros_like as zeros_like 20 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/custom_ops/zeros_like.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Define Custom ONNX ZerosLike operator. 29 | 30 | Here we demo how to reuse the TensorFlow op implementation.
31 | """ 32 | 33 | 34 | from collections.abc import Callable, Sequence 35 | import functools 36 | import inspect 37 | from typing import Any 38 | 39 | import jax 40 | from jax.experimental import jax2tf 41 | from jaxonnxruntime.core import handler 42 | from jaxonnxruntime.core import onnx_node 43 | from jaxonnxruntime.core import onnx_utils 44 | import tensorflow as tf 45 | 46 | 47 | @handler.register_op("ZerosLike", domain="jaxonnxruntime") 48 | class ZerosLike(handler.Handler): 49 | """Implementation of the ONNX ZerosLike custom operator.""" 50 | 51 | @classmethod 52 | def _prepare( 53 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 54 | ): 55 | sig = inspect.signature(onnx_jax_impl) 56 | kwparams = [ 57 | param.name 58 | for param in sig.parameters.values() 59 | if param.kind == inspect.Parameter.KEYWORD_ONLY 60 | ] 61 | for name in kwparams: 62 | node.attrs_dict[name] = node.attrs.get(name) 63 | 64 | @classmethod 65 | def version_1( 66 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 67 | ) -> Callable[..., Any]: 68 | """ONNX version_1 Identity op.""" 69 | cls._prepare(node, inputs, tf_zeros_like) 70 | return tf_zeros_like 71 | 72 | 73 | @functools.partial(jax.jit, static_argnames="dtype") 74 | def tf_zeros_like(x: jax.Array, *, dtype: int): 75 | """https://www.tensorflow.org/api_docs/python/tf/zeros_like for more details.""" 76 | jax_dtype = onnx_utils.tensor_dtype_to_jnp_dtype(dtype) # pytype: disable=wrong-arg-types 77 | 78 | def tf_func(input0): 79 | return tf.zeros_like(input0, dtype=jax_dtype) 80 | 81 | jax_func = jax2tf.call_tf(tf_func) 82 | return jax_func(x) 83 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/custom_ops/zeros_like_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import call_onnx 17 | from jaxonnxruntime.experimental import custom_ops # pylint: disable=unused-import 18 | import numpy as np 19 | 20 | import onnx 21 | from onnx import helper as onnx_helper 22 | 23 | 24 | class ZerosLikeTest(absltest.TestCase): 25 | 26 | def test_basic(self): 27 | # Create input and output tensors. 
28 | input_tensor = onnx_helper.make_tensor_value_info( 29 | "X", onnx.TensorProto.FLOAT, [1, 3, 224, 224] 30 | ) 31 | output_tensor = onnx_helper.make_tensor_value_info( 32 | "Y", onnx.TensorProto.FLOAT, [1, 3, 224, 224] 33 | ) 34 | 35 | # Create the ZerosLike node 36 | node = onnx_helper.make_node( 37 | "ZerosLike", 38 | inputs=["X"], 39 | outputs=["Y"], 40 | dtype=onnx.TensorProto.FLOAT, 41 | domain="jaxonnxruntime", 42 | ) 43 | 44 | # Create the graph with the node 45 | graph_def = onnx_helper.make_graph( 46 | [node], 47 | "ZerosLike_Model", 48 | [input_tensor], 49 | [output_tensor], 50 | ) 51 | 52 | # Create the model 53 | onnx_model = onnx_helper.make_model( 54 | graph_def, 55 | producer_name="JAX-ONNX", 56 | opset_imports=[ 57 | onnx_helper.make_opsetid( 58 | onnx.defs.ONNX_DOMAIN, onnx.defs.onnx_opset_version() 59 | ), 60 | onnx_helper.make_opsetid("jaxonnxruntime", 1), 61 | ], 62 | ) 63 | x = np.random.randn(3, 4, 5).astype(np.float32) 64 | y = np.zeros_like(x).astype(np.float32) 65 | inputs = [x] 66 | jax_model_func, jax_model_params = call_onnx.call_onnx_model( 67 | onnx_model, inputs 68 | ) 69 | outputs = jax_model_func(jax_model_params, inputs) 70 | expect_outputs = [y] 71 | self.assertLen(outputs, 1, f"output is {outputs}") 72 | np.testing.assert_allclose(outputs[0], expect_outputs[0], atol=1e-6) 73 | 74 | 75 | if __name__ == "__main__": 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/export/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Exportable modules.""" 16 | 17 | from .exportable import Exportable as JaxExportable # pylint: disable=g-importing-member 18 | from .tensorflow_exportable import TensorflowExportable # pylint: disable=g-importing-member 19 | from .torch_exportable import TorchExportable # pylint: disable=g-importing-member 20 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/export/exportable_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for jax exportable.""" 16 | 17 | import os 18 | from typing import Any 19 | from absl.testing import absltest 20 | import chex 21 | import jax 22 | from jax import numpy as jnp 23 | from jaxonnxruntime.experimental.export import exportable 24 | from jaxonnxruntime.experimental.export import exportable_test_utils 25 | import numpy as np 26 | 27 | 28 | class ExportableTest(exportable_test_utils.ExportableTestCase): 29 | 30 | def test_basic(self): 31 | def jax_func(x): 32 | return jnp.sum(jnp.sin(x)) 33 | 34 | x = jnp.arange(32, dtype=np.float32).reshape((8, 4)) 35 | exported_inputs = (x,) 36 | 37 | exportable_obj = exportable.Exportable( 38 | jax_func, exported_inputs, {}, ['cpu', 'cuda', 'rocm', 'tpu'] 39 | ) 40 | exported = exportable_obj.export() 41 | loaded_exported = self._save_and_load_exported(exported) 42 | self.assertClassAttributeType(exported, loaded_exported) 43 | 44 | result = exported.call(*exported_inputs) 45 | result2 = loaded_exported.call(*exported_inputs) 46 | chex.assert_trees_all_close(result, result2) 47 | 48 | 49 | if __name__ == '__main__': 50 | jax.config.parse_flags_with_absl() 51 | os.environ['XLA_FLAGS'] = ( 52 | '--xla_force_host_platform_device_count=8' # Use 8 CPU devices 53 | ) 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/export/exportable_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for tf2export.""" 16 | 17 | import inspect 18 | from typing import Any 19 | from absl import logging 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import export as jax_export 23 | from jaxonnxruntime.experimental.export import exportable_utils 24 | 25 | 26 | class ExportableTestCase(parameterized.TestCase): 27 | """Base class for Exportable tests.""" 28 | 29 | def _save_and_load_exported( 30 | self, exported: jax_export.Exported 31 | ) -> jax_export.Exported: 32 | model_path = self.create_tempdir().full_path 33 | exportable_utils.save_exported(exported, model_path) 34 | loaded_exported = exportable_utils.load_exported(model_path) 35 | return loaded_exported 36 | 37 | def check_exported_call(self, exported: jax_export.Exported, *args, **kwargs): 38 | logging.info('exported.__dict__: %s', exported.__dict__) 39 | f = exported.call 40 | f = jax.jit(f) 41 | lowered = f.lower(*args, **kwargs) 42 | lowering = lowered._lowering # pylint: disable=protected-access 43 | compile_args = lowering.compile_args 44 | mlir_module_str = lowering.as_text() 45 | logging.info('compile_args: %s', compile_args) 46 | logging.info('mlir_module_str: %s', mlir_module_str) 47 | 48 | def assertClassAttributeType(self, obj: Any, other_obj: Any): # pylint: disable=invalid-name 49 | def get_attributes_and_types_inspect(obj): 50 | type_dict = {} 51 | for name, value in inspect.getmembers(obj): 52 | if not name.startswith('__'): 53 | type_dict[name] = type(value) 54 | return type_dict 55 | 56 | obj_type_dict = get_attributes_and_types_inspect(obj) 57 | other_obj_type_dict = get_attributes_and_types_inspect(other_obj) 58 | for k in obj_type_dict: 59 | self.assertEqual( 60 | obj_type_dict[k], 61 | other_obj_type_dict[k], 62 | f'Exported {k} does not match loaded exported' 63 | f' {other_obj_type_dict[k]}', 64 | ) 65 | -------------------------------------------------------------------------------- /jaxonnxruntime/experimental/export/torch_exportable_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for torch_exportable.""" 16 | 17 | from absl import logging 18 | from absl.testing import absltest 19 | import chex 20 | import jax 21 | from jaxonnxruntime.experimental.export import exportable_test_utils 22 | from jaxonnxruntime.experimental.export import exportable_utils 23 | from jaxonnxruntime.experimental.export import torch_exportable 24 | import torch 25 | 26 | 27 | class TorchExportableObjTest(exportable_test_utils.ExportableTestCase): 28 | 29 | def setUp(self): 30 | super().setUp() 31 | 32 | args = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])) 33 | kwargs = {} 34 | 35 | def f(x, y): 36 | return x + y 37 | 38 | torch_module = f 39 | 40 | self.exportable = torch_exportable.TorchExportable( 41 | torch_module, args, kwargs, ['cpu'] 42 | ) 43 | self.args = args 44 | self.kwargs = kwargs 45 | self.torch_module = torch_module 46 | 47 | def test_exportable(self): 48 | exported = self.exportable.export() 49 | logging.info('exported: %s', exported) 50 | loaded_exported = self._save_and_load_exported(exported) 51 | self.assertClassAttributeType(exported, loaded_exported) 52 | args = jax.tree_util.tree_map( 53 | exportable_utils.torch_tensor_to_jax_array, self.args 54 | ) 55 | kwargs = jax.tree_util.tree_map( 56 | exportable_utils.torch_tensor_to_jax_array, self.kwargs 57 | ) 58 | result = exported.call(*args, **kwargs) 59 | result2 = loaded_exported.call(*args, **kwargs) 60 | chex.assert_trees_all_close(result, result2) 61 | 62 | 63 | if __name__ == '__main__': 64 | jax.config.parse_flags_with_absl() 65 | jax.config.update('jax_traceback_filtering', 'off') 66 | absltest.main() 67 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/abs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Abs operator.""" 16 | 17 | from collections.abc import Callable, Sequence 18 | import functools 19 | from typing import Any 20 | 21 | import jax 22 | from jax import numpy as jnp 23 | from jaxonnxruntime.core import handler 24 | from jaxonnxruntime.core import onnx_node 25 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Abs") 30 | class Abs(handler.Handler): 31 | """Implementation of the ONNX Abs operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_1( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_1 Abs op.""" 44 | cls._prepare(node, inputs, onnx_abs) 45 | return onnx_abs 46 | 47 | @classmethod 48 | def version_6( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_6 Abs op.""" 52 | cls._prepare(node, inputs, onnx_abs) 53 | return onnx_abs 54 | 55 | @classmethod 56 | def version_13( 57 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 58 | ) -> Callable[..., Any]: 59 | """ONNX version_13 Abs op.""" 60 | cls._prepare(node, inputs, onnx_abs) 61 | return onnx_abs 62 | 63 | 64 | @functools.partial(jax.jit, static_argnames=()) 65 | def onnx_abs(*input_args): 66 | """The internal jax impl for onnx Abs op.""" 67 | assert len(input_args) == 1 68 | (x,) = input_args 69 | return jnp.abs(x) 70 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/acos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Acos operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Acos") 30 | class Acos(handler.Handler): 31 | """Implementation of the ONNX Acos operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_7( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_7 Acos op.""" 44 | cls._prepare(node, inputs, onnx_acos) 45 | return onnx_acos 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_acos(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Acos for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arccos(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/acosh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Acosh operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Acosh") 30 | class Acosh(handler.Handler): 31 | """Implementation of the ONNX Acosh operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_9( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_9 Acosh op.""" 44 | cls._prepare(node, inputs, onnx_acosh) 45 | return onnx_acosh 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_acosh(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Acosh for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arccosh(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/add.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Add operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Add") 41 | class Add(handler.Handler): 42 | """Implementation of the ONNX Add operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 Add op.""" 55 | cls._prepare(node, inputs, onnx_add) 56 | return onnx_add 57 | 58 | @classmethod 59 | def version_6( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_6 Add op.""" 63 | cls._prepare(node, inputs, onnx_add) 64 | return onnx_add 65 | 66 | @classmethod 67 | def version_7( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_7 Add op.""" 71 | cls._prepare(node, inputs, onnx_add) 72 | return onnx_add 73 | 74 | @classmethod 75 | def version_13( 76 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 77 | ) -> Callable[..., Any]: 78 | """ONNX version_13 Add op.""" 79 | cls._prepare(node, inputs, onnx_add) 80 | return onnx_add 81 | 82 | @classmethod 83 | def version_14( 84 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 85 | ) -> Callable[..., Any]: 86 | """ONNX version_14 Add op.""" 87 | cls._prepare(node, inputs, onnx_add) 88 | return onnx_add 89 | 90 | 91 | @functools.partial(jax.jit, static_argnames=()) 92 | def onnx_add(*input_args): 93 | """The internal jax impl for onnx Add op.""" 94 | assert len(input_args) == 2 95 | a, b = input_args 96 | return jnp.add(a, b) 97 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/and_op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX And operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("And") 30 | class And(handler.Handler): 31 | """Implementation of the ONNX And operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_7( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_7 And op.""" 44 | cls._prepare(node, inputs, onnx_and) 45 | return onnx_and 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_and(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#And for more details.""" 51 | assert len(input_args) == 2 52 | x, y = input_args 53 | return jnp.logical_and(x, y) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/argmax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX ArgMax operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | 27 | 28 | @handler.register_op('ArgMax') 29 | class ArgMax(handler.Handler): 30 | """Implementation of the ONNX ArgMax operator.""" 31 | 32 | @classmethod 33 | def _prepare( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | node.attrs_dict['axis'] = node.attrs.get('axis', 0) 37 | node.attrs_dict['keepdims'] = node.attrs.get('keepdims', 1) 38 | node.attrs_dict['select_last_index'] = node.attrs.get( 39 | 'select_last_index', 0 40 | ) 41 | 42 | @classmethod 43 | def version_13( 44 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 45 | ) -> Callable[..., Any]: 46 | """ONNX version_13 ArgMax op.""" 47 | cls._prepare(node, inputs, onnx_argmax) 48 | return onnx_argmax 49 | 50 | 51 | @functools.partial( 52 | jax.jit, static_argnames=('axis', 'keepdims', 'select_last_index') 53 | ) 54 | def onnx_argmax(data, *, axis, keepdims, select_last_index): 55 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ArgMax for more details.""" 56 | keepdims = False if keepdims == 0 else True 57 | if select_last_index == 0: 58 | return jnp.argmax(data, axis=axis, keepdims=keepdims) 59 | data = jnp.flip(data, axis) 60 | result = jnp.argmax(data, axis=axis) 61 | result = data.shape[axis] - result - 1 62 | if keepdims: 63 | result = jnp.expand_dims(result, axis) 64 | return result 65 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/argmin.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from collections.abc import Callable, Sequence 16 | import functools 17 | from typing import Any 18 | 19 | import jax 20 | from jax import numpy as jnp 21 | from jaxonnxruntime.core import handler 22 | from jaxonnxruntime.core import onnx_node 23 | 24 | 25 | @handler.register_op('ArgMin') 26 | class ArgMin(handler.Handler): 27 | """Implementation of the ONNX ArgMin operator.""" 28 | 29 | @classmethod 30 | def _prepare( 31 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 32 | ): 33 | node.attrs_dict['axis'] = node.attrs.get('axis', 0) 34 | node.attrs_dict['keepdims'] = node.attrs.get('keepdims', 1) 35 | node.attrs_dict['select_last_index'] = node.attrs.get( 36 | 'select_last_index', 0 37 | ) 38 | 39 | @classmethod 40 | def version_13( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_13 ArgMin op.""" 44 | cls._prepare(node, inputs, onnx_argmin) 45 | return onnx_argmin 46 | 47 | 48 | @functools.partial( 49 | jax.jit, static_argnames=('axis', 'keepdims', 'select_last_index') 50 | ) 51 | def onnx_argmin(data, *, axis, keepdims, select_last_index): 52 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ArgMin for more details.""" 53 | keepdims = False if keepdims == 0 else True 54 | if select_last_index == 0: 55 | return jnp.argmin(data, axis=axis, keepdims=keepdims) 56 | data = jnp.flip(data, axis) 57 | result = jnp.argmin(data, axis=axis) 58 | result = data.shape[axis] - result - 1 59 | if keepdims: 60 | result = jnp.expand_dims(result, axis) 61 | return result 62 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/asin.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Asin operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Asin") 30 | class Asin(handler.Handler): 31 | """Implementation of the ONNX Asin operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_7( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_7 Asin op.""" 44 | cls._prepare(node, inputs, onnx_asin) 45 | return onnx_asin 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_asin(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Asin for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arcsin(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/asinh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Asinh operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Asinh") 30 | class Asinh(handler.Handler): 31 | """Implementation of the ONNX Asinh operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_9( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_9 Asinh op.""" 44 | cls._prepare(node, inputs, onnx_asinh) 45 | return onnx_asinh 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_asinh(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Asinh for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arcsinh(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/atan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Atan operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Atan") 30 | class Atan(handler.Handler): 31 | """Implementation of the ONNX Atan operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_7( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_7 Atan op.""" 44 | cls._prepare(node, inputs, onnx_atan) 45 | return onnx_atan 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_atan(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Atan for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arctan(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/atanh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Atanh operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Atanh") 30 | class Atanh(handler.Handler): 31 | """Implementation of the ONNX Atanh operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_9( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_9 Atanh op.""" 44 | cls._prepare(node, inputs, onnx_atanh) 45 | return onnx_atanh 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_atanh(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Atanh for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.arctanh(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/bitshift.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX BitShift operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | 27 | 28 | @handler.register_op("BitShift") 29 | class BitShift(handler.Handler): 30 | """Implementation of the ONNX BitShift operator.""" 31 | 32 | @classmethod 33 | def _prepare( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | node.attrs_dict["direction"] = node.attrs.get("direction") 37 | if node.attrs_dict["direction"] is None: 38 | raise ValueError("Operator BitShift requires attribute 'direction'!") 39 | if ( 40 | node.attrs_dict["direction"] != "LEFT" 41 | and node.attrs_dict["direction"] != "RIGHT" 42 | ): 43 | raise ValueError( 44 | "Operator BitShift only supports LEFT and RIGHT directions!" 
45 | ) 46 | 47 | @classmethod 48 | def version_11( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_11 BitShift op.""" 52 | cls._prepare(node, inputs, onnx_bitshift) 53 | return onnx_bitshift 54 | 55 | 56 | @functools.partial(jax.jit, static_argnames="direction") 57 | def onnx_bitshift(x, y, *, direction): 58 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#BitShift for more details.""" 59 | if direction == "LEFT": 60 | return jnp.left_shift(x, y) 61 | return jnp.right_shift(x, y) 62 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/castlike.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX CastLike operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jaxonnxruntime.core import config_class 24 | 25 | config = config_class.config 26 | from jaxonnxruntime.core import handler 27 | from jaxonnxruntime.core import onnx_node 28 | from jaxonnxruntime.core import onnx_utils 29 | 30 | 31 | @handler.register_op("CastLike") 32 | class CastLike(handler.Handler): 33 | """Implementation of the ONNX CastLike operator.""" 34 | 35 | @classmethod 36 | def _prepare( 37 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 38 | ): 39 | from_type = cls._get_type(node, node.inputs[0], inputs[0]) 40 | node.attrs_dict["from_type"] = from_type 41 | 42 | @classmethod 43 | def _get_type( 44 | cls, node: onnx_node.OnnxNode, input_name: str, input_value: Any 45 | ): 46 | if node.context_graph.value_info_dict.get(input_name) is not None: 47 | tensor_proto = node.context_graph.value_info_dict.get(input_name) 48 | res_type = onnx_utils.tensor_dtype_to_jnp_dtype( 49 | tensor_proto.type.tensor_type.elem_type 50 | ) 51 | elif config.jaxort_only_allow_initializers_as_static_args: 52 | if input_name in node.context_graph.get_constant_dict(): 53 | tensor = node.context_graph.get_constant_dict().get(input_name) 54 | res_type = tensor.dtype 55 | else: 56 | raise ValueError( 57 | "`config.jaxort_only_allow_initializers_as_static_args = True but " 58 | f"{input_name} tensor is not constant. We can not use it" 59 | "a static argument of the `Cast` operator. 
" 60 | ) 61 | else: 62 | res_type = input_value.dtype 63 | return res_type 64 | 65 | @classmethod 66 | def version_15( 67 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 68 | ) -> Callable[..., Any]: 69 | """ONNX version_15 CastLike op.""" 70 | cls._prepare(node, inputs, onnx_castlike) 71 | return onnx_castlike 72 | 73 | @classmethod 74 | def version_19( 75 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 76 | ) -> Callable[..., Any]: 77 | """ONNX version_19 CastLike op.""" 78 | cls._prepare(node, inputs, onnx_castlike) 79 | return onnx_castlike 80 | 81 | 82 | @functools.partial(jax.jit, static_argnames=("from_type",)) 83 | def onnx_castlike(*input_args, from_type): 84 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#CastLike for more details.""" 85 | assert len(input_args) == 2 86 | inp, target = input_args 87 | return inp.view(from_type).astype(target.dtype) 88 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/ceil.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX Ceil operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Ceil") 30 | class Ceil(handler.Handler): 31 | """Implementation of the ONNX Ceil operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_6( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_6 Ceil op.""" 44 | cls._prepare(node, inputs, onnx_ceil) 45 | return onnx_ceil 46 | 47 | @classmethod 48 | def version_13( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_13 Ceil op.""" 52 | cls._prepare(node, inputs, onnx_ceil) 53 | return onnx_ceil 54 | 55 | 56 | @functools.partial(jax.jit, static_argnames=()) 57 | def onnx_ceil(*input_args): 58 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Ceil for more details.""" 59 | assert len(input_args) == 1 60 | return jnp.ceil(input_args[0]) 61 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX Clip operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | 27 | 28 | @handler.register_op('Clip') 29 | class Clip(handler.Handler): 30 | """Implementation of the ONNX Clip operator.""" 31 | 32 | @classmethod 33 | def _prepare_6( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | node.attrs_dict['amin'] = node.attrs.get('min') 37 | node.attrs_dict['amax'] = node.attrs.get('max') 38 | 39 | @classmethod 40 | def _prepare_13( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 42 | ): 43 | pass 44 | 45 | @classmethod 46 | def version_6( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 48 | ) -> Callable[..., Any]: 49 | """ONNX version_6 Clip op.""" 50 | cls._prepare_6(node, inputs, onnx_clip) 51 | return onnx_clip 52 | 53 | @classmethod 54 | def version_13( 55 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 56 | ) -> Callable[..., Any]: 57 | """ONNX version_13 Clip op.""" 58 | cls._prepare_13(node, inputs, onnx_clip) 59 | return onnx_clip 60 | 61 | 62 | @functools.partial(jax.jit, static_argnames=()) 63 | def onnx_clip(data, amin=None, amax=None): 64 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Clip for more details.""" 65 | if amin is None and amax is None: 66 | return data 67 | return jnp.clip(data, min=amin, max=amax) 68 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/concat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Define ONNX Concat operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Concat") 41 | class Concat(handler.Handler): 42 | """Implementation of the ONNX Concat operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_4( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_4 Concat op.""" 55 | cls._prepare(node, inputs, onnx_concat) 56 | return onnx_concat 57 | 58 | @classmethod 59 | def version_11( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_11 Concat op.""" 63 | cls._prepare(node, inputs, onnx_concat) 64 | return onnx_concat 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 Concat op.""" 71 | cls._prepare(node, inputs, onnx_concat) 72 | return onnx_concat 73 | 74 | 75 | @functools.partial(jax.jit, static_argnames="axis") 76 | def onnx_concat(*input_args, axis=0): 77 | """The internal jax impl for onnx Concat op.""" 78 | return jnp.concatenate(input_args, axis=axis) 79 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/constantofshape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Define ONNX ConstantOfShape operator.""" 29 | # pylint: disable=unused-argument 30 | from collections.abc import Callable, Sequence 31 | import functools 32 | from typing import Any 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | import onnx 39 | 40 | 41 | @handler.register_op('ConstantOfShape') 42 | class ConstantOfShape(handler.Handler): 43 | """Implementation of the ONNX ConstantOfShape operator.""" 44 | 45 | @classmethod 46 | def _prepare( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 48 | ): 49 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 50 | assert len(inputs) == 1 51 | node.attrs_dict['shape'] = tuple(inputs[0].tolist()) 52 | if 'value' in node.attrs_dict: 53 | np_value = onnx.numpy_helper.to_array(node.attrs_dict['value']) 54 | if len(np_value.tolist()) != 1: 55 | raise ValueError( 56 | 'ONNX ConstantOfShape op `value` attr should contain only 1 value' 57 | f' but got {np_value} on node {node.node_proto}' 58 | ) 59 | node.attrs_dict['value'] = np_value.tolist()[0] 60 | node.attrs_dict['dtype'] = np_value.dtype 61 | 62 | @classmethod 63 | def version_9( 64 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 65 | ) -> Callable[..., Any]: 66 | """ONNX version_9 ConstantOfShape op.""" 67 | cls._prepare(node, inputs, onnx_constantofshape) 68 | return onnx_constantofshape 69 | 70 | 71 | @functools.partial(jax.jit, static_argnames=('value', 'shape', 'dtype')) 72 | def onnx_constantofshape(*input_args, value=0, shape=None, dtype=jnp.float32): 73 | """The internal jax impl for onnx ConstantOfShape op.""" 74 | return jnp.full(fill_value=value, shape=shape, dtype=dtype) 75 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/cos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Cos operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Cos") 30 | class Cos(handler.Handler): 31 | """Implementation of the ONNX Cos operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_7( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_7 Cos op.""" 44 | cls._prepare(node, inputs, onnx_cos) 45 | return onnx_cos 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_cos(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Cos for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.cos(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/cosh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Cosh operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Cosh") 30 | class Cosh(handler.Handler): 31 | """Implementation of the ONNX Cosh operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_9( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_9 Cosh op.""" 44 | cls._prepare(node, inputs, onnx_cosh) 45 | return onnx_cosh 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_cosh(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Cosh for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.cosh(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/div.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Div operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Div") 41 | class Div(handler.Handler): 42 | """Implementation of the ONNX Div operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_6( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_7 Div op.""" 55 | cls._prepare(node, inputs, onnx_div) 56 | return onnx_div 57 | 58 | @classmethod 59 | def version_7( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_7 Div op.""" 63 | cls._prepare(node, inputs, onnx_div) 64 | return onnx_div 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 Div op.""" 71 | cls._prepare(node, inputs, onnx_div) 72 | return onnx_div 73 | 74 | @classmethod 75 | def version_14( 76 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 77 | ) -> Callable[..., Any]: 78 | """ONNX version_14 Div op.""" 79 | cls._prepare(node, inputs, onnx_div) 80 | return onnx_div 81 | 82 | 83 | @functools.partial(jax.jit, static_argnames=()) 84 | def onnx_div(*input_args): 85 | """The internal jax impl for onnx Div op.""" 86 | assert len(input_args) == 2 87 | x, y = input_args 88 | if jnp.issubdtype(x.dtype, jnp.integer) and jnp.issubdtype( 89 | y.dtype, jnp.integer 90 | ): 91 | return jnp.floor_divide(x, y) 92 | else: 93 | return jnp.true_divide(x, y) 94 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/einsum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Einsum operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Einsum") 30 | class Einsum(handler.Handler): 31 | """Implementation of the ONNX Einsum operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_12( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_12 Einsum op.""" 44 | cls._prepare(node, inputs, onnx_einsum) 45 | return onnx_einsum 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=("equation",)) 49 | def onnx_einsum(*input_args, equation): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Einsum for more details.""" 51 | return jnp.einsum(equation, *input_args) 52 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/equal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Equal operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Equal") 43 | class Equal(handler.Handler): 44 | """Implementation of the ONNX Equal operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_11( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_11 Equal op.""" 57 | cls._prepare(node, inputs, onnx_equal) 58 | return onnx_equal 59 | 60 | @classmethod 61 | def version_13( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_13 Equal op.""" 65 | cls._prepare(node, inputs, onnx_equal) 66 | return onnx_equal 67 | 68 | @classmethod 69 | def version_19( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_19 Equal op.""" 73 | cls._prepare(node, inputs, onnx_equal) 74 | return onnx_equal 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames=()) 78 | def onnx_equal(*input_args): 79 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Equal for more details.""" 80 | assert len(input_args) == 2 81 | return jnp.equal(input_args[0], input_args[1]) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/erf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Erf operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | import jax 22 | from jaxonnxruntime.core import handler 23 | from jaxonnxruntime.core import onnx_node 24 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 25 | 26 | 27 | @handler.register_op("Erf") 28 | class Erf(handler.Handler): 29 | """Implementation of the ONNX Erf operator.""" 30 | 31 | @classmethod 32 | def _prepare( 33 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 34 | ): 35 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 36 | 37 | @classmethod 38 | def version_13( 39 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 40 | ) -> Callable[..., Any]: 41 | """ONNX version_13 Erf op.""" 42 | cls._prepare(node, inputs, onnx_erf) 43 | return onnx_erf 44 | 45 | 46 | @functools.partial(jax.jit, static_argnames=()) 47 | def onnx_erf(*input_args): 48 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Erf for more details.""" 49 | assert len(input_args) == 1 50 | data = input_args[0] 51 | return jax.lax.erf(data) 52 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/exp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Exp operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Exp") 41 | class Exp(handler.Handler): 42 | """Implementation of the ONNX Exp operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 Exp op.""" 55 | cls._prepare(node, inputs, onnx_exp) 56 | return onnx_exp 57 | 58 | @classmethod 59 | def version_6( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_6 Exp op.""" 63 | cls._prepare(node, inputs, onnx_exp) 64 | return onnx_exp 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 Exp op.""" 71 | cls._prepare(node, inputs, onnx_exp) 72 | return onnx_exp 73 | 74 | 75 | @functools.partial(jax.jit, static_argnames=()) 76 | def onnx_exp(*input_args): 77 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Exp.""" 78 | assert len(input_args) == 1 79 | x = input_args[0] 80 | return jnp.exp(x) 81 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/flatten.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Flatten operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | import jax 22 | from jax import numpy as jnp 23 | from jaxonnxruntime.core import handler 24 | from jaxonnxruntime.core import onnx_node 25 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 26 | import numpy as np 27 | 28 | 29 | @handler.register_op("Flatten") 30 | class Flatten(handler.Handler): 31 | """Implementation of the ONNX Flatten operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_1( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_1 Flatten op.""" 44 | cls._prepare(node, inputs, onnx_flatten) 45 | return onnx_flatten 46 | 47 | @classmethod 48 | def version_11( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_11 Flatten op.""" 52 | cls._prepare(node, inputs, onnx_flatten) 53 | return onnx_flatten 54 | 55 | @classmethod 56 | def version_13( 57 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 58 | ) -> Callable[..., Any]: 59 | """ONNX version_13 Flatten op.""" 60 | cls._prepare(node, inputs, onnx_flatten) 61 | return onnx_flatten 62 | 63 | 64 | @functools.partial(jax.jit, static_argnames="axis") 65 | def onnx_flatten(*input_args, axis): 66 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Flatten for more details.""" 67 | axis = 1 if axis is None else axis 68 | assert len(input_args) == 1 69 | x = input_args[0] 70 | dim = len(x.shape) 71 | assert axis <= dim and axis >= -dim, f"axis should with [{-dim}, {dim}]" 72 | new_shape = ( 73 | (1, -1) if axis == 0 else (-1, np.prod(x.shape[axis:]).astype(int)) 74 | ) 75 | return jnp.reshape(x, new_shape) 76 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/gather.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Gather operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | 38 | 39 | @handler.register_op("Gather") 40 | class Gather(handler.Handler): 41 | """Implementation of the ONNX Gather operator.""" 42 | 43 | @classmethod 44 | def _prepare( 45 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 46 | ): 47 | node.attrs_dict["axis"] = node.attrs.get("axis", 0) 48 | 49 | @classmethod 50 | def version_1( 51 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 52 | ) -> Callable[..., Any]: 53 | """ONNX version_1 Gather op.""" 54 | cls._prepare(node, inputs, onnx_gather) 55 | return onnx_gather 56 | 57 | @classmethod 58 | def version_11( 59 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 60 | ) -> Callable[..., Any]: 61 | """ONNX version_11 Gather op.""" 62 | cls._prepare(node, inputs, onnx_gather) 63 | return onnx_gather 64 | 65 | @classmethod 66 | def version_13( 67 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 68 | ) -> Callable[..., Any]: 69 | """ONNX version_13 Gather op.""" 70 | cls._prepare(node, inputs, onnx_gather) 71 | return onnx_gather 72 | 73 | 74 | @functools.partial(jax.jit, static_argnames="axis") 75 | def onnx_gather(*input_args, axis=0): 76 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Gather.""" 77 | assert len(input_args) == 2 78 | data, indices = input_args 79 | indices = indices.astype(jnp.int64) 80 | return jnp.take(data, indices, axis=axis) 81 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/gatherelements.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX GatherElements operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | 27 | 28 | @handler.register_op('GatherElements') 29 | class GatherElements(handler.Handler): 30 | """Implementation of the ONNX GatherElements operator.""" 31 | 32 | @classmethod 33 | def _prepare( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | node.attrs_dict['axis'] = node.attrs.get('axis', 0) 37 | 38 | @classmethod 39 | def version_13( 40 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 41 | ) -> Callable[..., Any]: 42 | """ONNX version_13 GatherElements op.""" 43 | cls._prepare(node, inputs, onnx_gatherelements) 44 | return onnx_gatherelements 45 | 46 | 47 | @functools.partial(jax.jit, static_argnames=('axis',)) 48 | def onnx_gatherelements(*input_args, axis): 49 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#GatherElements for more details.""" 50 | data, index = input_args 51 | data_swaped = jnp.swapaxes(data, 0, axis) 52 | index_swaped = jnp.swapaxes(index, 0, axis).astype(int) 53 | gathered = jnp.choose(index_swaped, data_swaped, mode='wrap') 54 | return jnp.swapaxes(gathered, 0, axis) 55 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/globalaveragepool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX GlobalAveragePool operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("GlobalAveragePool") 43 | class GlobalAveragePool(handler.Handler): 44 | """Implementation of the ONNX GlobalAveragePool operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_1( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_1 GlobalAveragePool op.""" 57 | cls._prepare(node, inputs, onnx_globalaveragepool) 58 | return onnx_globalaveragepool 59 | 60 | 61 | @functools.partial(jax.jit, static_argnames=()) 62 | def onnx_globalaveragepool(*input_args): 63 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#GlobalAveragePool for more details.""" 64 | assert len(input_args) == 1 65 | x = input_args[0] 66 | y = jnp.mean(x, axis=tuple(range(2, jnp.ndim(x))), keepdims=True) 67 | return y 68 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/greater.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2024 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | 29 | """Define ONNX Greater operator.""" 30 | 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Greater") 43 | class Greater(handler.Handler): 44 | """Implementation of the ONNX Greater operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_7( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_7 Greater op.""" 57 | cls._prepare(node, inputs, onnx_greater) 58 | return onnx_greater 59 | 60 | @classmethod 61 | def version_9( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_9 Greater op.""" 65 | cls._prepare(node, inputs, onnx_greater) 66 | return onnx_greater 67 | 68 | @classmethod 69 | def version_13( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_13 Greater op.""" 73 | cls._prepare(node, inputs, onnx_greater) 74 | return onnx_greater 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames=()) 78 | def onnx_greater(*input_args): 79 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Greater for more details.""" 80 | assert len(input_args) == 2 81 | return jnp.greater(input_args[0], input_args[1]) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/greaterorequal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2024 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | 29 | """Define ONNX GreaterOrEqual operator.""" 30 | 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("GreaterOrEqual") 43 | class GreaterOrEqual(handler.Handler): 44 | """Implementation of the ONNX GreaterOrEqual operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_12( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_12 GreaterOrEqual op.""" 57 | cls._prepare(node, inputs, onnx_greaterorequal) 58 | return onnx_greaterorequal 59 | 60 | @classmethod 61 | def version_16( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_16 GreaterOrEqual op.""" 65 | cls._prepare(node, inputs, onnx_greaterorequal) 66 | return onnx_greaterorequal 67 | 68 | 69 | @functools.partial(jax.jit, static_argnames=()) 70 | def onnx_greaterorequal(*input_args): 71 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#GreaterOrEqual for more details.""" 72 | assert len(input_args) == 2 73 | return jnp.greater_equal(input_args[0], input_args[1]) 74 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/leakyrelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX LeakyRelu operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jaxonnxruntime.core import handler 37 | from jaxonnxruntime.core import onnx_node 38 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 39 | 40 | 41 | @handler.register_op("LeakyRelu") 42 | class LeakyRelu(handler.Handler): 43 | """Implementation of the ONNX LeakyRelu operator.""" 44 | 45 | @classmethod 46 | def _prepare( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 48 | ): 49 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 50 | node.attrs_dict["alpha"] = node.attrs.get("alpha", 0.01) 51 | 52 | @classmethod 53 | def version_6( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_6 LeakyRelu op.""" 57 | cls._prepare(node, inputs, onnx_leakyrelu) 58 | return onnx_leakyrelu 59 | 60 | @classmethod 61 | def version_16( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_16 LeakyRelu op.""" 65 | cls._prepare(node, inputs, onnx_leakyrelu) 66 | return onnx_leakyrelu 67 | 68 | 69 | @functools.partial(jax.jit, static_argnames=("alpha",)) 70 | def onnx_leakyrelu(*input_args, alpha): 71 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#LeakyRelu for more details.""" 72 | assert len(input_args) == 1 73 | (x,) = input_args 74 | return jax.nn.leaky_relu(x, negative_slope=alpha) 75 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/less.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Less operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Less") 43 | class Less(handler.Handler): 44 | """Implementation of the ONNX Less operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_7( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_7 Less op.""" 57 | cls._prepare(node, inputs, onnx_less) 58 | return onnx_less 59 | 60 | @classmethod 61 | def version_9( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_9 Less op.""" 65 | cls._prepare(node, inputs, onnx_less) 66 | return onnx_less 67 | 68 | @classmethod 69 | def version_13( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_13 Less op.""" 73 | cls._prepare(node, inputs, onnx_less) 74 | return onnx_less 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames=()) 78 | def onnx_less(*input_args): 79 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Less for more details.""" 80 | assert len(input_args) == 2 81 | return jnp.less(input_args[0], input_args[1]) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/lessorequal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX LessOrEqual operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("LessOrEqual") 43 | class LessOrEqual(handler.Handler): 44 | """Implementation of the ONNX LessOrEqual operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_12( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_12 LessOrEqual op.""" 57 | cls._prepare(node, inputs, onnx_lessorequal) 58 | return onnx_lessorequal 59 | 60 | @classmethod 61 | def version_16( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_16 LessOrEqual op.""" 65 | cls._prepare(node, inputs, onnx_lessorequal) 66 | return onnx_lessorequal 67 | 68 | 69 | @functools.partial(jax.jit, static_argnames=()) 70 | def onnx_lessorequal(*input_args): 71 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#LessOrEqual for more details.""" 72 | assert len(input_args) == 2 73 | return jnp.less_equal(input_args[0], input_args[1]) 74 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/log.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Log operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Log") 30 | class Log(handler.Handler): 31 | """Implementation of the ONNX Log operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_1( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_1 Log op.""" 44 | cls._prepare(node, inputs, onnx_log) 45 | return onnx_log 46 | 47 | @classmethod 48 | def version_6( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_6 Log op.""" 52 | cls._prepare(node, inputs, onnx_log) 53 | return onnx_log 54 | 55 | @classmethod 56 | def version_13( 57 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 58 | ) -> Callable[..., Any]: 59 | """ONNX version_13 Log op.""" 60 | cls._prepare(node, inputs, onnx_log) 61 | return onnx_log 62 | 63 | 64 | @functools.partial(jax.jit, static_argnames=()) 65 | def onnx_log(*input_args): 66 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Log for more details.""" 67 | assert len(input_args) == 1 68 | data = input_args[0] 69 | return jnp.log(data) 70 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/logsoftmax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX LogSoftmax operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | 27 | 28 | @handler.register_op('LogSoftmax') 29 | class LogSoftmax(handler.Handler): 30 | """Implementation of the ONNX LogSoftmax operator.""" 31 | 32 | @classmethod 33 | def _prepare_1( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | node.attrs_dict['axis'] = node.attrs.get('axis', 1) 37 | 38 | @classmethod 39 | def _prepare_11( 40 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 41 | ): 42 | cls._prepare_1(node, inputs, onnx_jax_impl) 43 | 44 | @classmethod 45 | def _prepare_13( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | node.attrs_dict['axis'] = node.attrs.get('axis', -1) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 LogSoftmax op.""" 55 | cls._prepare_1(node, inputs, onnx_logsoftmax) 56 | return onnx_logsoftmax 57 | 58 | @classmethod 59 | def version_11( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_11 LogSoftmax op.""" 63 | cls._prepare_11(node, inputs, onnx_logsoftmax) 64 | return onnx_logsoftmax 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 LogSoftmax op.""" 71 | cls._prepare_13(node, inputs, onnx_logsoftmax) 72 | return onnx_logsoftmax 73 | 74 | 75 | @functools.partial(jax.jit, static_argnames=('axis',)) 76 | def onnx_logsoftmax(*input_args, axis=-1): 77 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#LogSoftmax for more details.""" 78 | assert len(input_args) == 1 79 | data = input_args[0] 80 | res = jax.nn.softmax(data, axis=axis) 81 | return jnp.log(res) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/matmul.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Define ONNX MatMul operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("MatMul") 41 | class MatMul(handler.Handler): 42 | """Implementation of the ONNX MatMul operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 MatMul op.""" 55 | cls._prepare(node, inputs, onnx_matmul) 56 | return onnx_matmul 57 | 58 | @classmethod 59 | def version_9( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_9 MatMul op.""" 63 | cls._prepare(node, inputs, onnx_matmul) 64 | return onnx_matmul 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 MatMul op.""" 71 | cls._prepare(node, inputs, onnx_matmul) 72 | return onnx_matmul 73 | 74 | 75 | @functools.partial(jax.jit, static_argnames=()) 76 | def onnx_matmul(*input_args): 77 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MatMul.""" 78 | assert len(input_args) == 2 79 | a, b = input_args 80 | return jnp.matmul(a, b) 81 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/min.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
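ONNX Min is variadic, so the Min handler below records the number of inputs as a static arg_num attribute and folds jnp.minimum across them. A minimal sketch of that reduction, independent of the handler machinery:

import functools
import jax.numpy as jnp

inputs = [jnp.array([3.0, 1.0]), jnp.array([2.0, 5.0]), jnp.array([4.0, 0.0])]
# Element-wise minimum over an arbitrary number of input tensors.
result = functools.reduce(jnp.minimum, inputs)
# result == [2.0, 0.0]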
14 | 15 | """Define ONNX Min operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Min") 30 | class Min(handler.Handler): 31 | """Implementation of the ONNX Min operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | node.attrs_dict["arg_num"] = len(node.inputs) 39 | 40 | @classmethod 41 | def version_6( 42 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 43 | ) -> Callable[..., Any]: 44 | """ONNX version_6 Min op.""" 45 | cls._prepare(node, inputs, onnx_min) 46 | return onnx_min 47 | 48 | @classmethod 49 | def version_8( 50 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 51 | ) -> Callable[..., Any]: 52 | """ONNX version_8 Min op.""" 53 | cls._prepare(node, inputs, onnx_min) 54 | return onnx_min 55 | 56 | @classmethod 57 | def version_12( 58 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 59 | ) -> Callable[..., Any]: 60 | """ONNX version_12 Min op.""" 61 | cls._prepare(node, inputs, onnx_min) 62 | return onnx_min 63 | 64 | @classmethod 65 | def version_13( 66 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 67 | ) -> Callable[..., Any]: 68 | """ONNX version_13 Min op.""" 69 | cls._prepare(node, inputs, onnx_min) 70 | return onnx_min 71 | 72 | 73 | @functools.partial(jax.jit, static_argnames=("arg_num",)) 74 | def onnx_min(*input_args, arg_num): 75 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Min for more details.""" 76 | assert len(input_args) == arg_num 77 | res = input_args[0] 78 | for i in range(arg_num): 79 | res = jnp.minimum(res, input_args[i]) 80 | return res 81 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/mul.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Mul operator.""" 16 | from collections.abc import Callable, Sequence 17 | import functools 18 | from typing import Any 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | from jaxonnxruntime.core import handler 23 | from jaxonnxruntime.core import onnx_node 24 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 25 | 26 | 27 | @handler.register_op("Mul") 28 | class Mul(handler.Handler): 29 | """Implementation of the ONNX Mul operator.""" 30 | 31 | @classmethod 32 | def _prepare( 33 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 34 | ): 35 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 36 | 37 | @classmethod 38 | def version_6( 39 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 40 | ) -> Callable[..., Any]: 41 | """ONNX version_6 Mul op.""" 42 | cls._prepare(node, inputs, onnx_mul) 43 | return onnx_mul 44 | 45 | @classmethod 46 | def version_7( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 48 | ) -> Callable[..., Any]: 49 | """ONNX version_7 Mul op.""" 50 | cls._prepare(node, inputs, onnx_mul) 51 | return onnx_mul 52 | 53 | @classmethod 54 | def version_13( 55 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 56 | ) -> Callable[..., Any]: 57 | """ONNX version_13 Mul op.""" 58 | cls._prepare(node, inputs, onnx_mul) 59 | return onnx_mul 60 | 61 | @classmethod 62 | def version_14( 63 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 64 | ) -> Callable[..., Any]: 65 | """ONNX version_14 Mul op.""" 66 | cls._prepare(node, inputs, onnx_mul) 67 | return onnx_mul 68 | 69 | 70 | @functools.partial(jax.jit, static_argnames=()) 71 | def onnx_mul(*input_args): 72 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Mul.""" 73 | assert len(input_args) == 2 74 | a, b = input_args 75 | return jnp.multiply(a, b) 76 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/neg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Neg operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Neg") 30 | class Neg(handler.Handler): 31 | """Implementation of the ONNX Neg operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_6( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_6 Neg op.""" 44 | cls._prepare(node, inputs, onnx_neg) 45 | return onnx_neg 46 | 47 | @classmethod 48 | def version_13( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_13 Neg op.""" 52 | cls._prepare(node, inputs, onnx_neg) 53 | return onnx_neg 54 | 55 | 56 | @functools.partial(jax.jit, static_argnames=()) 57 | def onnx_neg(*input_args): 58 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Neg for more details.""" 59 | assert len(input_args) == 1 60 | return jnp.negative(input_args[0]) 61 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/nonzero.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX NonZero operator.""" 16 | from collections.abc import Callable, Sequence 17 | import functools 18 | import logging 19 | from typing import Any 20 | 21 | import jax 22 | from jax import numpy as jnp 23 | from jaxonnxruntime.core import config_class 24 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 25 | 26 | config = config_class.config 27 | from jaxonnxruntime.core import handler 28 | from jaxonnxruntime.core import onnx_node 29 | 30 | 31 | @handler.register_op("NonZero") 32 | class NonZero(handler.Handler): 33 | """Implementation of the ONNX NonZero operator.""" 34 | 35 | @classmethod 36 | def _prepare( 37 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 38 | ): 39 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 40 | 41 | assert len(inputs) == 1 42 | if config.jaxort_nonzero_use_fully_padding: 43 | node.attrs_dict["size"] = inputs[0].size 44 | if node.attrs_dict["size"] is None: 45 | raise ValueError( 46 | "NonZero Jax implementation must have static size attribute but not." 
47 | ) 48 | 49 | @classmethod 50 | def version_9( 51 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 52 | ) -> Callable[..., Any]: 53 | """ONNX version_9 NonZero op.""" 54 | cls._prepare(node, inputs, onnx_nonzero) 55 | return onnx_nonzero 56 | 57 | @classmethod 58 | def version_13( 59 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 60 | ) -> Callable[..., Any]: 61 | """ONNX version_13 NonZero op.""" 62 | cls._prepare(node, inputs, onnx_nonzero) 63 | return onnx_nonzero 64 | 65 | 66 | @functools.partial(jax.jit, static_argnames="size") 67 | def onnx_nonzero(*input_args, size): 68 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#NonZero.""" 69 | assert len(input_args) == 1 70 | logging.warning("onnx_nonzero cannot support jax.jit mode.") 71 | (x,) = input_args 72 | return jnp.stack(jnp.nonzero(x, size=size)) 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/onehot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX OneHot operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | 32 | from collections.abc import Callable, Sequence 33 | import functools 34 | from typing import Any 35 | 36 | import jax 37 | from jax import numpy as jnp 38 | from jaxonnxruntime.core import handler 39 | from jaxonnxruntime.core import onnx_node 40 | 41 | 42 | @handler.register_op("OneHot") 43 | class OneHot(handler.Handler): 44 | """Implementation of the ONNX OneHot operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | node.attrs_dict["axis"] = node.attrs.get("axis", -1) 51 | node.attrs_dict["depth"] = int(inputs[1].item()) 52 | 53 | @classmethod 54 | def version_11( 55 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 56 | ) -> Callable[..., Any]: 57 | """ONNX version_11 OneHot op.""" 58 | cls._prepare(node, inputs, onnx_onehot) 59 | return onnx_onehot 60 | 61 | 62 | @functools.partial(jax.jit, static_argnames=("depth", "axis")) 63 | def onnx_onehot(*input_args, depth, axis): 64 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#OneHot for more details.""" 65 | assert len(input_args) == 3 66 | indices, _, values = input_args 67 | indices = jnp.mod(indices, depth) 68 | encode = jax.nn.one_hot(indices, depth, axis=axis) 69 | encode = encode * (values[1] - values[0]) + values[0] 70 | return encode.astype(values.dtype) 71 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/onehot_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from jaxonnxruntime.core import onnx_node 17 | from jaxonnxruntime.onnx_ops import onehot 18 | import numpy as np 19 | 20 | import onnx 21 | 22 | NodeProto = onnx.NodeProto 23 | AttributeProto = onnx.AttributeProto 24 | OneHot = onehot.OneHot 25 | OnnxNode = onnx_node.OnnxNode 26 | 27 | 28 | class OneHotTest(absltest.TestCase): 29 | 30 | def setUp(self): 31 | super().setUp() 32 | node_proto = NodeProto(op_type='OneHot', input=['input'], output=['output']) 33 | self.node_onehot = OnnxNode(node_proto) 34 | 35 | def test_onehot(self): 36 | indices = np.array([0, -7, -8], dtype=np.int64) 37 | depth = np.float32(10) 38 | off_value, on_value = 1, 3 39 | values = np.array([off_value, on_value], dtype=np.float32) 40 | inputs = [indices, depth, values] 41 | 42 | onehot_func = OneHot.version_11(self.node_onehot, inputs) 43 | 44 | outputs = onehot_func(*inputs, **self.node_onehot.attrs_dict) 45 | 46 | expect = np.array([ 47 | [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 48 | [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 49 | [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 50 | ]) 51 | np.testing.assert_array_equal(outputs, expect) 52 | 53 | def test_onehot_static_depth(self): 54 | indices = np.array([0, -7, -8], dtype=np.int64) 55 | depth = np.float32(10) 56 | off_value, on_value = 1, 3 57 | values = np.array([off_value, on_value], dtype=np.float32) 58 | inputs = [indices, depth, values] 59 | 60 | onehot_func = OneHot.version_11(self.node_onehot, inputs) 61 | 62 | outputs_depth_10 = onehot_func(*inputs, **self.node_onehot.attrs_dict) 63 | 64 | depth = np.float32(8) 65 | inputs = [indices, depth, values] 66 | outputs_depth_8 = onehot_func(*inputs, **self.node_onehot.attrs_dict) 67 | 68 | np.testing.assert_array_equal(outputs_depth_10, outputs_depth_8) 69 | 70 | expect_outputs_depth_8 = np.array([ 71 | [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 72 | [1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 73 | [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 74 | ]) 75 | np.testing.assert_raises( 76 | AssertionError, 77 | np.testing.assert_array_equal, 78 | outputs_depth_8, 79 | expect_outputs_depth_8, 80 | ) 81 | 82 | 83 | if __name__ == '__main__': 84 | absltest.main() 85 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/onnx_not.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
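The Not handler that follows takes no attributes; a little further below, onnx_ops_utils.update_node_attrs_dict is the shared helper that turns a kernel's keyword-only parameters into ONNX attribute lookups via inspect. A small standalone sketch of that introspection pattern (the kernel name here is hypothetical, stdlib only):

import inspect

def onnx_example_kernel(*input_args, axis, keepdims):  # hypothetical kernel
  ...

sig = inspect.signature(onnx_example_kernel)
kwparams = [
    p.name
    for p in sig.parameters.values()
    if p.kind == inspect.Parameter.KEYWORD_ONLY
]
# kwparams == ['axis', 'keepdims']; each name is then looked up in node.attrs.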
14 | 15 | """Define ONNX Not operator.""" 16 | from collections.abc import Callable, Sequence 17 | import functools 18 | from typing import Any 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | from jaxonnxruntime.core import handler 23 | from jaxonnxruntime.core import onnx_node 24 | 25 | 26 | @handler.register_op("Not") 27 | class Not(handler.Handler): 28 | """Implementation of the ONNX Not operator.""" 29 | 30 | @classmethod 31 | def version_1( 32 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 33 | ) -> Callable[..., Any]: 34 | """ONNX version_1 Not op.""" 35 | return onnx_not 36 | 37 | 38 | @functools.partial(jax.jit, static_argnames=()) 39 | def onnx_not(*input_args): 40 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Not.""" 41 | assert len(input_args) == 1 42 | x = input_args[0] 43 | return (jnp.logical_not(x),) 44 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/onnx_ops_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for creating onnx ops.""" 16 | 17 | from collections.abc import Callable 18 | import inspect 19 | from typing import Any 20 | 21 | from jaxonnxruntime.core import onnx_node 22 | 23 | 24 | def update_node_attrs_dict( 25 | node: onnx_node.OnnxNode, onnx_jax_impl: Callable[..., Any] 26 | ): 27 | """Updates the node's attrs_dict with the values from the node's attrs.""" 28 | sig = inspect.signature(onnx_jax_impl) 29 | kwparams = [ 30 | param.name 31 | for param in sig.parameters.values() 32 | if param.kind == inspect.Parameter.KEYWORD_ONLY 33 | ] 34 | for name in kwparams: 35 | node.attrs_dict[name] = node.attrs.get(name, None) 36 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/or_op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Define ONNX Or operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Or") 43 | class Or(handler.Handler): 44 | """Implementation of the ONNX Or operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_7( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_7 Or op.""" 57 | cls._prepare(node, inputs, onnx_or) 58 | return onnx_or 59 | 60 | 61 | @functools.partial(jax.jit, static_argnames=()) 62 | def onnx_or(*input_args): 63 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Or for more details.""" 64 | assert len(input_args) == 2 65 | return jnp.logical_or(input_args[0], input_args[1]) 66 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/pow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
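The Pow kernel below casts the result back to the dtype of its first operand, mirroring ONNX's rule that the output type follows the base tensor. A quick illustration of why that cast matters (sketch only):

import jax.numpy as jnp

a = jnp.array([2, 3], dtype=jnp.int32)
b = jnp.array([2.0, 0.5], dtype=jnp.float32)
y = jnp.power(a, b).astype(a.dtype)
# Without the cast the result would be float32; with it, y == [4, 1] as int32.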
14 | 15 | """Define ONNX Pow operator.""" 16 | from collections.abc import Callable, Sequence 17 | import functools 18 | from typing import Any 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | from jaxonnxruntime.core import handler 23 | from jaxonnxruntime.core import onnx_node 24 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 25 | 26 | 27 | @handler.register_op("Pow") 28 | class Pow(handler.Handler): 29 | """Implementation of the ONNX Pow operator.""" 30 | 31 | @classmethod 32 | def _prepare( 33 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 34 | ): 35 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 36 | 37 | @classmethod 38 | def version_1( 39 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 40 | ) -> Callable[..., Any]: 41 | """ONNX version_1 Pow op.""" 42 | cls._prepare(node, inputs, onnx_pow) 43 | return onnx_pow 44 | 45 | @classmethod 46 | def version_7( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 48 | ) -> Callable[..., Any]: 49 | """ONNX version_7 Pow op.""" 50 | cls._prepare(node, inputs, onnx_pow) 51 | return onnx_pow 52 | 53 | @classmethod 54 | def version_12( 55 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 56 | ) -> Callable[..., Any]: 57 | """ONNX version_12 Pow op.""" 58 | cls._prepare(node, inputs, onnx_pow) 59 | return onnx_pow 60 | 61 | @classmethod 62 | def version_13( 63 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 64 | ) -> Callable[..., Any]: 65 | """ONNX version_13 Pow op.""" 66 | cls._prepare(node, inputs, onnx_pow) 67 | return onnx_pow 68 | 69 | @classmethod 70 | def version_15( 71 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 72 | ) -> Callable[..., Any]: 73 | """ONNX version_15 Pow op.""" 74 | cls._prepare(node, inputs, onnx_pow) 75 | return onnx_pow 76 | 77 | 78 | @functools.partial(jax.jit, static_argnames=()) 79 | def onnx_pow(a, b): 80 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Pow.""" 81 | return jnp.power(a, b).astype(a.dtype) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/prelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX PRelu operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("PRelu") 30 | class PRelu(handler.Handler): 31 | """Implementation of the ONNX PRelu operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_6( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_6 PRelu op.""" 44 | cls._prepare(node, inputs, onnx_prelu) 45 | return onnx_prelu 46 | 47 | @classmethod 48 | def version_16( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 50 | ) -> Callable[..., Any]: 51 | """ONNX version_16 PRelu op.""" 52 | cls._prepare(node, inputs, onnx_prelu) 53 | return onnx_prelu 54 | 55 | 56 | @functools.partial(jax.jit, static_argnames=()) 57 | def onnx_prelu(*input_args): 58 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#PRelu for more details.""" 59 | assert len(input_args) == 2 60 | data, slope = input_args 61 | return jnp.where(data >= 0, data, slope * data) 62 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/reciprocal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Reciprocal operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Reciprocal") 43 | class Reciprocal(handler.Handler): 44 | """Implementation of the ONNX Reciprocal operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_6( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_6 Reciprocal op.""" 57 | cls._prepare(node, inputs, onnx_reciprocal) 58 | return onnx_reciprocal 59 | 60 | @classmethod 61 | def version_13( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_13 Reciprocal op.""" 65 | cls._prepare(node, inputs, onnx_reciprocal) 66 | return onnx_reciprocal 67 | 68 | 69 | @functools.partial(jax.jit, static_argnames=()) 70 | def onnx_reciprocal(*input_args): 71 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Reciprocal for more details.""" 72 | assert len(input_args) == 1 73 | return jnp.reciprocal(input_args[0]) 74 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/reducemax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX ReduceMax operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | 38 | 39 | @handler.register_op('ReduceMax') 40 | class ReduceMax(handler.Handler): 41 | """Implementation of the ONNX ReduceMax operator.""" 42 | 43 | @classmethod 44 | def _prepare( 45 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 46 | ): 47 | node.attrs_dict['axes'] = node.attrs.get('axes') 48 | if len(inputs) >= 2: 49 | node.attrs_dict['axes'] = tuple(inputs[1].tolist()) 50 | node.attrs_dict['axes'] = ( 51 | None if len(node.attrs_dict['axes']) == 0 else node.attrs_dict['axes'] 52 | ) 53 | node.attrs_dict['keepdims'] = node.attrs.get('keepdims', 1) 54 | node.attrs_dict['noop_with_empty_axes'] = node.attrs.get( 55 | 'noop_with_empty_axes', 0 56 | ) 57 | 58 | @classmethod 59 | def version_13( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_13 ReduceMax op.""" 63 | cls._prepare(node, inputs, onnx_reducemax) 64 | return onnx_reducemax 65 | 66 | @classmethod 67 | def version_18( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_18 ReduceMax op.""" 71 | cls._prepare(node, inputs, onnx_reducemax) 72 | return onnx_reducemax 73 | 74 | 75 | @functools.partial( 76 | jax.jit, static_argnames=('axes', 'keepdims', 'noop_with_empty_axes') 77 | ) 78 | def onnx_reducemax( 79 | *input_args, 80 | axes=None, 81 | keepdims=1, 82 | noop_with_empty_axes=0, 83 | ): 84 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ReduceSum.""" 85 | assert len(input_args) == 1 or len(input_args) == 2 86 | data = input_args[0] 87 | if axes is None and noop_with_empty_axes > 0: 88 | return data 89 | return jnp.max(data, axis=axes, keepdims=keepdims > 0) 90 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/relu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Relu operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jaxonnxruntime.core import handler 37 | from jaxonnxruntime.core import onnx_node 38 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 39 | 40 | 41 | @handler.register_op("Relu") 42 | class Relu(handler.Handler): 43 | """Implementation of the ONNX Relu operator.""" 44 | 45 | @classmethod 46 | def _prepare( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 48 | ): 49 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 50 | 51 | @classmethod 52 | def version_6( 53 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 54 | ) -> Callable[..., Any]: 55 | """ONNX version_6 Relu op.""" 56 | cls._prepare(node, inputs, onnx_relu) 57 | return onnx_relu 58 | 59 | @classmethod 60 | def version_14( 61 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 62 | ) -> Callable[..., Any]: 63 | """ONNX version_14 Relu op.""" 64 | cls._prepare(node, inputs, onnx_relu) 65 | return onnx_relu 66 | 67 | 68 | @functools.partial(jax.jit, static_argnames=()) 69 | def onnx_relu(*input_args): 70 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Relu for more details.""" 71 | assert len(input_args) == 1 72 | x = input_args[0] 73 | return jax.nn.relu(x) 74 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/selu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Selu operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jaxonnxruntime.core import handler 24 | from jaxonnxruntime.core import onnx_node 25 | 26 | 27 | @handler.register_op('Selu') 28 | class Selu(handler.Handler): 29 | """Implementation of the ONNX Selu operator.""" 30 | 31 | @classmethod 32 | def _prepare( 33 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 34 | ): 35 | node.attrs_dict['alpha'] = node.attrs.get( 36 | 'alpha', 1.67326319217681884765625 37 | ) 38 | node.attrs_dict['gamma'] = node.attrs.get( 39 | 'gamma', 1.05070102214813232421875 40 | ) 41 | 42 | @classmethod 43 | def version_6( 44 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 45 | ) -> Callable[..., Any]: 46 | """ONNX version_6 Selu op.""" 47 | cls._prepare(node, inputs, onnx_selu) 48 | return onnx_selu 49 | 50 | 51 | @functools.partial(jax.jit, static_argnames=('alpha', 'gamma')) 52 | def onnx_selu(*input_args, alpha, gamma): 53 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Selu for more details.""" 54 | assert len(input_args) == 1 55 | data = input_args[0] 56 | return gamma * jax.nn.elu(data, alpha) 57 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Shape operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op('Shape') 41 | class Shape(handler.Handler): 42 | """Implementation of the ONNX Shape operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 Shape op.""" 55 | cls._prepare(node, inputs, onnx_shape) 56 | return onnx_shape 57 | 58 | @classmethod 59 | def version_13( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_13 Shape op.""" 63 | cls._prepare(node, inputs, onnx_shape) 64 | return onnx_shape 65 | 66 | @classmethod 67 | def version_15( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_15 Shape op.""" 71 | cls._prepare(node, inputs, onnx_shape) 72 | return onnx_shape 73 | 74 | @classmethod 75 | def version_19( 76 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 77 | ) -> Callable[..., Any]: 78 | """ONNX version_19 Shape op.""" 79 | cls._prepare(node, inputs, onnx_shape) 80 | return onnx_shape 81 | 82 | 83 | @functools.partial(jax.jit, static_argnames=('start', 'end')) 84 | def onnx_shape(*input_args, start=None, end=None): 85 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Shape.""" 86 | assert len(input_args) == 1 87 | x = input_args[0] 88 | dims = x.shape[start:end] 89 | return jnp.asarray(dims) 90 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sigmoid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Sigmoid operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jaxonnxruntime.core import handler 37 | from jaxonnxruntime.core import onnx_node 38 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 39 | 40 | 41 | @handler.register_op("Sigmoid") 42 | class Sigmoid(handler.Handler): 43 | """Implementation of the ONNX Sigmoid operator.""" 44 | 45 | @classmethod 46 | def _prepare( 47 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 48 | ): 49 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 50 | 51 | @classmethod 52 | def version_13( 53 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 54 | ) -> Callable[..., Any]: 55 | """ONNX version_13 Sigmoid op.""" 56 | cls._prepare(node, inputs, onnx_sigmoid) 57 | return onnx_sigmoid 58 | 59 | @classmethod 60 | def version_6( 61 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 62 | ) -> Callable[..., Any]: 63 | """ONNX version_6 Sigmoid op.""" 64 | cls._prepare(node, inputs, onnx_sigmoid) 65 | return onnx_sigmoid 66 | 67 | 68 | @functools.partial(jax.jit, static_argnames=()) 69 | def onnx_sigmoid(*input_args): 70 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sigmoid for more details.""" 71 | assert len(input_args) == 1 72 | return jax.nn.sigmoid(input_args[0]) 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sin.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Sin operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | @handler.register_op("Sin") 29 | class Sin(handler.Handler): 30 | """Implementation of the ONNX Sin operator.""" 31 | 32 | @classmethod 33 | def _prepare( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 37 | 38 | @classmethod 39 | def version_7( 40 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 41 | ) -> Callable[..., Any]: 42 | """ONNX version_7 Sin op.""" 43 | cls._prepare(node, inputs, onnx_sin) 44 | return onnx_sin 45 | 46 | 47 | @functools.partial(jax.jit, static_argnames=()) 48 | def onnx_sin(*input_args): 49 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sin for more details.""" 50 | assert len(input_args) == 1 51 | data = input_args[0] 52 | return jnp.sin(data) 53 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sinh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Sinh operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import handler 25 | from jaxonnxruntime.core import onnx_node 26 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 27 | 28 | 29 | @handler.register_op("Sinh") 30 | class Sinh(handler.Handler): 31 | """Implementation of the ONNX Sinh operator.""" 32 | 33 | @classmethod 34 | def _prepare( 35 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 36 | ): 37 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 38 | 39 | @classmethod 40 | def version_9( 41 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 42 | ) -> Callable[..., Any]: 43 | """ONNX version_9 Sinh op.""" 44 | cls._prepare(node, inputs, onnx_sinh) 45 | return onnx_sinh 46 | 47 | 48 | @functools.partial(jax.jit, static_argnames=()) 49 | def onnx_sinh(*input_args): 50 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sinh for more details.""" 51 | assert len(input_args) == 1 52 | data = input_args[0] 53 | return jnp.sinh(data) 54 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Softmax operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | import jax 33 | from jaxonnxruntime.core import handler 34 | from jaxonnxruntime.core import onnx_node 35 | 36 | 37 | @handler.register_op('Softmax') 38 | class Softmax(handler.Handler): 39 | """Implementation of the ONNX Softmax operator.""" 40 | 41 | @classmethod 42 | def _prepare_11( 43 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 44 | ): 45 | node.attrs_dict['axis'] = node.attrs.get('axis', 1) 46 | 47 | @classmethod 48 | def _prepare_13( 49 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 50 | ): 51 | node.attrs_dict['axis'] = node.attrs.get('axis', -1) 52 | 53 | @classmethod 54 | def version_1( 55 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 56 | ) -> Callable[..., Any]: 57 | """ONNX version_1 Softmax op.""" 58 | cls._prepare_11(node, inputs, onnx_softmax) 59 | return onnx_softmax 60 | 61 | @classmethod 62 | def version_11( 63 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 64 | ) -> Callable[..., Any]: 65 | """ONNX version_11 Softmax op.""" 66 | cls._prepare_11(node, inputs, onnx_softmax) 67 | return onnx_softmax 68 | 69 | @classmethod 70 | def version_13( 71 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 72 | ) -> Callable[..., Any]: 73 | """ONNX version_13 Softmax op.""" 74 | cls._prepare_13(node, inputs, onnx_softmax) 75 | return onnx_softmax 76 | 77 | 78 | @functools.partial(jax.jit, static_argnames=('axis',)) 79 | def onnx_softmax(*input_args, axis): 80 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Softmax.""" 81 | assert len(input_args) == 1 82 | x = input_args[0] 83 | return jax.nn.softmax(x, axis=axis) 84 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/softplus.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Define ONNX Softplus operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jaxonnxruntime.core import handler 24 | from jaxonnxruntime.core import onnx_node 25 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 26 | 27 | 28 | @handler.register_op("Softplus") 29 | class Softplus(handler.Handler): 30 | """Implementation of the ONNX Softplus operator.""" 31 | 32 | @classmethod 33 | def _prepare( 34 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 35 | ): 36 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 37 | 38 | @classmethod 39 | def version_1( 40 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 41 | ) -> Callable[..., Any]: 42 | """ONNX version_1 Softplus op.""" 43 | cls._prepare(node, inputs, onnx_softplus) 44 | return onnx_softplus 45 | 46 | 47 | @functools.partial(jax.jit, static_argnames=()) 48 | def onnx_softplus(*input_args): 49 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Softplus for more details.""" 50 | assert len(input_args) == 1 51 | data = input_args[0] 52 | return jax.nn.softplus(data) 53 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sqrt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Sqrt operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Sqrt") 41 | class Sqrt(handler.Handler): 42 | """Implementation of the ONNX Sqrt operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_6( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_6 Sqrt op.""" 55 | cls._prepare(node, inputs, onnx_sqrt) 56 | return onnx_sqrt 57 | 58 | @classmethod 59 | def version_13( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_13 Sqrt op.""" 63 | cls._prepare(node, inputs, onnx_sqrt) 64 | return onnx_sqrt 65 | 66 | 67 | @functools.partial(jax.jit, static_argnames=()) 68 | def onnx_sqrt(*input_args): 69 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sqrt.""" 70 | assert len(input_args) == 1 71 | return jnp.sqrt(*input_args) 72 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/squeeze.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Squeeze operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | 38 | 39 | @handler.register_op('Squeeze') 40 | class Squeeze(handler.Handler): 41 | """Implementation of the ONNX Squeeze operator.""" 42 | 43 | @classmethod 44 | def _prepare( 45 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 46 | ): 47 | if len(inputs) == 1: 48 | node.attrs_dict['axis'] = None 49 | else: 50 | node.attrs_dict['axis'] = tuple(inputs[1].tolist()) 51 | 52 | @classmethod 53 | def version_1( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_1 Squeeze op.""" 57 | cls._prepare(node, inputs, onnx_squeeze) 58 | return onnx_squeeze 59 | 60 | @classmethod 61 | def version_11( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_11 Squeeze op.""" 65 | cls._prepare(node, inputs, onnx_squeeze) 66 | return onnx_squeeze 67 | 68 | @classmethod 69 | def version_13( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_13 Squeeze op.""" 73 | cls._prepare(node, inputs, onnx_squeeze) 74 | return onnx_squeeze 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames='axis') 78 | def onnx_squeeze(*input_args, axis): 79 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Squeeze.""" 80 | x = input_args[0] 81 | return jnp.squeeze(x, axis=axis) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sub.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Sub operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Sub") 41 | class Sub(handler.Handler): 42 | """Implementation of the ONNX Sub operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_6( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_6 Sub op.""" 55 | cls._prepare(node, inputs, onnx_sub) 56 | return onnx_sub 57 | 58 | @classmethod 59 | def version_7( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_7 Sub op.""" 63 | cls._prepare(node, inputs, onnx_sub) 64 | return onnx_sub 65 | 66 | @classmethod 67 | def version_13( 68 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 69 | ) -> Callable[..., Any]: 70 | """ONNX version_13 Sub op.""" 71 | cls._prepare(node, inputs, onnx_sub) 72 | return onnx_sub 73 | 74 | @classmethod 75 | def version_14( 76 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 77 | ) -> Callable[..., Any]: 78 | """ONNX version_14 Sub op.""" 79 | cls._prepare(node, inputs, onnx_sub) 80 | return onnx_sub 81 | 82 | 83 | @functools.partial(jax.jit, static_argnames=()) 84 | def onnx_sub(*input_args): 85 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sub.""" 86 | assert len(input_args) == 2 87 | a, b = input_args 88 | return jnp.subtract(a, b) 89 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/sum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Sum operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Sum") 43 | class Sum(handler.Handler): 44 | """Implementation of the ONNX Sum operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_6( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_6 Sum op.""" 57 | cls._prepare(node, inputs, onnx_sum) 58 | return onnx_sum 59 | 60 | @classmethod 61 | def version_8( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_8 Sum op.""" 65 | cls._prepare(node, inputs, onnx_sum) 66 | return onnx_sum 67 | 68 | @classmethod 69 | def version_13( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_13 Sum op.""" 73 | cls._prepare(node, inputs, onnx_sum) 74 | return onnx_sum 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames=()) 78 | def onnx_sum(*input_args): 79 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sum for more details.""" 80 | return jnp.sum(jnp.array(input_args), axis=0) 81 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/tanh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Tanh operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Tanh") 41 | class Tanh(handler.Handler): 42 | """Implementation of the ONNX Tanh operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_6( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_6 Tanh op.""" 55 | cls._prepare(node, inputs, onnx_tanh) 56 | return onnx_tanh 57 | 58 | @classmethod 59 | def version_13( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_13 Tanh op.""" 63 | cls._prepare(node, inputs, onnx_tanh) 64 | return onnx_tanh 65 | 66 | 67 | @functools.partial(jax.jit, static_argnames=()) 68 | def onnx_tanh(*input_args): 69 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Tanh.""" 70 | assert len(input_args) == 1 71 | x = input_args[0] 72 | return jnp.tanh(x) 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/tile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX Tile operator.""" 16 | 17 | # pylint: disable=unused-argument 18 | # pylint: disable=g-explicit-length-test 19 | from collections.abc import Callable, Sequence 20 | import functools 21 | from typing import Any 22 | 23 | import jax 24 | from jax import numpy as jnp 25 | from jaxonnxruntime.core import config_class 26 | 27 | config = config_class.config 28 | from jaxonnxruntime.core import handler 29 | from jaxonnxruntime.core import onnx_node 30 | 31 | 32 | @handler.register_op('Tile') 33 | class Tile(handler.Handler): 34 | """Implementation of the ONNX Tile operator.""" 35 | 36 | @classmethod 37 | def _prepare( 38 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 39 | ): 40 | if config.jaxort_only_allow_initializers_as_static_args: 41 | if node.inputs[1] not in node.context_graph.get_constant_dict(): 42 | raise ValueError( 43 | f'{node.inputs[1]} is not constant but used as `repeats` of Tile' 44 | ' static argument during `jax.jit`. the jitted function gives' 45 | ' wrong results if its value changes in another input.If you know' 46 | ' what you are doing, set' 47 | ' `config.update("jaxort_only_allow_initializers_as_static_args",' 48 | ' False)` to remove this contraint.' 
49 | ) 50 | node.attrs_dict['repeats'] = tuple( 51 | node.context_graph.get_constant_dict()[node.inputs[1]].tolist() 52 | ) 53 | else: 54 | node.attrs_dict['repeats'] = tuple(inputs[1].tolist()) 55 | 56 | @classmethod 57 | def version_1( 58 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 59 | ) -> Callable[..., Any]: 60 | """ONNX version_1 Tile op.""" 61 | cls._prepare(node, inputs, onnx_tile) 62 | return onnx_tile 63 | 64 | @classmethod 65 | def version_6( 66 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 67 | ) -> Callable[..., Any]: 68 | """ONNX version_6 Tile op.""" 69 | cls._prepare(node, inputs, onnx_tile) 70 | return onnx_tile 71 | 72 | @classmethod 73 | def version_13( 74 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 75 | ) -> Callable[..., Any]: 76 | """ONNX version_13 Tile op.""" 77 | cls._prepare(node, inputs, onnx_tile) 78 | return onnx_tile 79 | 80 | 81 | @functools.partial(jax.jit, static_argnames='repeats') 82 | def onnx_tile(*input_args, repeats): 83 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Tile for more details.""" 84 | assert ( 85 | len(input_args) == 2 86 | ), f'Expected 2 input args but got {len(input_args)}' 87 | x = input_args[0] 88 | return jnp.tile(x, repeats) 89 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/transpose.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
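# Editorial sketch (not part of the repository sources): tile.py above treats
# `repeats` as a static `jax.jit` argument, so with the default
# `jaxort_only_allow_initializers_as_static_args` setting it must come from the
# graph initializers; the error message above shows the `config.update` call
# that relaxes this, at the cost of wrong results if `repeats` later changes.
from jaxonnxruntime.onnx_ops.tile import onnx_tile
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])
repeats_tensor = jnp.array([2, 1])  # the runtime input mirrored by the static tuple
tiled = onnx_tile(x, repeats_tensor, repeats=(2, 1))
# tiled == Array([[1, 2], [3, 4], [1, 2], [3, 4]])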
28 | """Define ONNX Transpose operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 38 | 39 | 40 | @handler.register_op("Transpose") 41 | class Transpose(handler.Handler): 42 | """Implementation of the ONNX Transpose operator.""" 43 | 44 | @classmethod 45 | def _prepare( 46 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 47 | ): 48 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 49 | 50 | @classmethod 51 | def version_1( 52 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 53 | ) -> Callable[..., Any]: 54 | """ONNX version_1 Transpose op.""" 55 | cls._prepare(node, inputs, onnx_transpose) 56 | return onnx_transpose 57 | 58 | @classmethod 59 | def version_13( 60 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 61 | ) -> Callable[..., Any]: 62 | """ONNX version_13 Transpose op.""" 63 | cls._prepare(node, inputs, onnx_transpose) 64 | return onnx_transpose 65 | 66 | 67 | @functools.partial(jax.jit, static_argnames="perm") 68 | def onnx_transpose(*input_args, perm): 69 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Transpose.""" 70 | assert len(input_args) == 1 71 | x = input_args[0] 72 | return jnp.transpose(x, perm) 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/trilu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Define ONNX Trilu operator.""" 16 | # pylint: disable=unused-argument 17 | # pylint: disable=g-explicit-length-test 18 | from collections.abc import Callable, Sequence 19 | import functools 20 | from typing import Any 21 | 22 | import jax 23 | from jax import numpy as jnp 24 | from jaxonnxruntime.core import config_class 25 | 26 | config = config_class.config 27 | from jaxonnxruntime.core import handler 28 | from jaxonnxruntime.core import onnx_node 29 | 30 | 31 | @handler.register_op('Trilu') 32 | class Trilu(handler.Handler): 33 | """Implementation of the ONNX Trilu operator.""" 34 | 35 | @classmethod 36 | def _prepare( 37 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 38 | ): 39 | node.attrs_dict['upper'] = node.attrs.get('upper', 1) 40 | if config.jaxort_only_allow_initializers_as_static_args: 41 | if ( 42 | len(node.inputs) == 1 43 | or node.inputs[1] not in node.context_graph.get_constant_dict() 44 | ): 45 | raise ValueError( 46 | "Trilu's `k` is not constant defined by the graph initializers but" 47 | ' used as a static argument. The function wrapped by `jax.jit` will' 48 | ' output incorrect results if its value changes in another input.' 
49 | ) 50 | node.attrs_dict['k'] = int( 51 | node.context_graph.get_constant_dict()[node.inputs[1]].tolist()[0] 52 | ) 53 | else: 54 | node.attrs_dict['k'] = int(inputs[1]) if len(inputs) == 2 else 0 55 | 56 | @classmethod 57 | def version_14( 58 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 59 | ) -> Callable[..., Any]: 60 | """ONNX version_14 Trilu op.""" 61 | cls._prepare(node, inputs, onnx_trilu) 62 | return onnx_trilu 63 | 64 | 65 | @functools.partial(jax.jit, static_argnames=('upper', 'k')) 66 | def onnx_trilu(*input_args, k, upper): 67 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Trilu for more details.""" 68 | assert len(input_args) == 1 or len(input_args) == 2 69 | data = input_args[0] 70 | if upper: 71 | return jnp.triu(data, k) 72 | return jnp.tril(data, k) 73 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/unsqueeze.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
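# Editorial sketch (not part of the repository sources): `onnx_trilu` in
# trilu.py above keeps the upper (upper=1) or lower (upper=0) triangle, with
# the static `k` shifting the reference diagonal.
from jaxonnxruntime.onnx_ops.trilu import onnx_trilu
import jax.numpy as jnp

m = jnp.arange(1, 10).reshape(3, 3)
# onnx_trilu(m, k=0, upper=1) == Array([[1, 2, 3], [0, 5, 6], [0, 0, 9]])
# onnx_trilu(m, k=1, upper=0) == Array([[1, 2, 0], [4, 5, 6], [7, 8, 9]])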
28 | """Define ONNX Unsqueeze operator.""" 29 | from collections.abc import Callable, Sequence 30 | import functools 31 | from typing import Any 32 | 33 | import jax 34 | from jax import numpy as jnp 35 | from jaxonnxruntime.core import handler 36 | from jaxonnxruntime.core import onnx_node 37 | 38 | 39 | @handler.register_op('Unsqueeze') 40 | class Unsqueeze(handler.Handler): 41 | """Implementation of the ONNX Unsqueeze operator.""" 42 | 43 | @classmethod 44 | def _prepare( 45 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 46 | ): 47 | if 'axes' in node.attrs: 48 | node.attrs_dict['axis'] = node.attrs['axes'] 49 | if len(inputs) >= 2: 50 | node.attrs_dict['axis'] = tuple(inputs[1].tolist()) 51 | 52 | @classmethod 53 | def version_1( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_1 Unsqueeze op.""" 57 | cls._prepare(node, inputs, onnx_unsqueeze) 58 | return onnx_unsqueeze 59 | 60 | @classmethod 61 | def version_11( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_11 Unsqueeze op.""" 65 | cls._prepare(node, inputs, onnx_unsqueeze) 66 | return onnx_unsqueeze 67 | 68 | @classmethod 69 | def version_13( 70 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 71 | ) -> Callable[..., Any]: 72 | """ONNX version_13 Unsqueeze op.""" 73 | cls._prepare(node, inputs, onnx_unsqueeze) 74 | return onnx_unsqueeze 75 | 76 | 77 | @functools.partial(jax.jit, static_argnames='axis') 78 | def onnx_unsqueeze(*input_args, axis: list[int]): 79 | """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Unsqueeze.""" 80 | x = input_args[0] 81 | return jnp.expand_dims(x, axis) 82 | -------------------------------------------------------------------------------- /jaxonnxruntime/onnx_ops/where.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 The Jaxonnxruntime Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
28 | """Define ONNX Where operator.""" 29 | # pylint: disable=unused-argument 30 | # pylint: disable=g-explicit-length-test 31 | from collections.abc import Callable, Sequence 32 | import functools 33 | from typing import Any 34 | 35 | import jax 36 | from jax import numpy as jnp 37 | from jaxonnxruntime.core import handler 38 | from jaxonnxruntime.core import onnx_node 39 | from jaxonnxruntime.onnx_ops import onnx_ops_utils 40 | 41 | 42 | @handler.register_op("Where") 43 | class Where(handler.Handler): 44 | """Implementation of the ONNX Where operator.""" 45 | 46 | @classmethod 47 | def _prepare( 48 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any 49 | ): 50 | onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) 51 | 52 | @classmethod 53 | def version_9( 54 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 55 | ) -> Callable[..., Any]: 56 | """ONNX version_9 Where op.""" 57 | cls._prepare(node, inputs, onnx_where) 58 | return onnx_where 59 | 60 | @classmethod 61 | def version_16( 62 | cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] 63 | ) -> Callable[..., Any]: 64 | """ONNX version_16 Where op.""" 65 | cls._prepare(node, inputs, onnx_where) 66 | return onnx_where 67 | 68 | 69 | @functools.partial(jax.jit, static_argnames=()) 70 | def onnx_where(*input_args): 71 | """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Where for more details.""" 72 | assert len(input_args) == 3 73 | cond, x, y = input_args 74 | return jnp.where(cond, x, y) 75 | -------------------------------------------------------------------------------- /jaxonnxruntime/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Jaxonnxruntime Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Current jaxonnxruntime version at head on Github.""" 16 | __version__ = "0.3.0" 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "jaxonnxruntime" 7 | description = "Jaxonnxruntime: JAX based ONNX Runtime." 
8 | keywords = [] 9 | authors = [ 10 | {name = "Jaxonnxruntime team", email = "jaxonnxruntime-dev@google.com"}, 11 | ] 12 | dependencies = [ 13 | "numpy", 14 | "jax", 15 | "jaxlib", 16 | "absl-py", 17 | "jaxtyping", 18 | "chex", 19 | ] 20 | classifiers = [ 21 | "Development Status :: 3 - Alpha", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: Apache Software License", 25 | "Programming Language :: Python :: 3.7", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | ] 28 | dynamic = ["version", "readme"] 29 | 30 | [project.optional-dependencies] 31 | all = [ 32 | "onnx==1.12.0", # Please keep this as the last line since onnx requires an old protobuf lib. 33 | ] 34 | 35 | testing = [ 36 | "mypy", 37 | "pytest", 38 | "pytest-cov", 39 | "pytest-custom_exit_code", 40 | "pytest-xdist", 41 | ] 42 | 43 | [project.urls] 44 | homepage = "https://github.com/google/jaxonnxruntime" 45 | 46 | [tool.setuptools.dynamic] 47 | readme = {file = ["README.md"], content-type = "text/markdown"} 48 | version = {attr = "jaxonnxruntime.version.__version__"} 49 | 50 | [tool.setuptools.packages.find] 51 | include = ["jaxonnxruntime*"] 52 | 53 | [tool.setuptools.package-data] 54 | jaxonnxruntime = ["*py.typed"] 55 | 56 | [tool.yapf] 57 | based_on_style = "yapf" 58 | 59 | [tool.pytype] 60 | 61 | [tool.mypy] 62 | show_error_codes = true 63 | no_implicit_optional = true 64 | disable_error_code = "attr-defined" 65 | 66 | [[tool.mypy.overrides]] 67 | module = [ 68 | "tensorflow.*", 69 | "tensorboard.*", 70 | "absl.*", 71 | "jax.*", 72 | "rich.*", 73 | "jaxlib.cuda.*", 74 | "jaxlib.cpu.*", 75 | "msgpack", 76 | "numpy.*", 77 | "optax.*", 78 | "orbax.*", 79 | "opt_einsum.*", 80 | "scipy.*", 81 | "jaxlib.mlir.*", 82 | "yaml", 83 | ] 84 | ignore_missing_imports = true 85 | 86 | [tool.pytest.ini_options] 87 | filterwarnings = [ 88 | # By default error out on any warnings. 89 | ] 90 | 91 | [tool.coverage.report] 92 | exclude_lines = [ 93 | "@abc.abstractmethod", 94 | "raise NotImplementedError", 95 | ] 96 | --------------------------------------------------------------------------------
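
The three op modules above (trilu.py, unsqueeze.py, where.py) each end in a small jit-compiled kernel that wraps a single jax.numpy primitive, with any static attributes bound by the handler's _prepare step. A minimal standalone sketch of calling those kernels directly, bypassing the Handler and OnnxNode machinery, is shown below; this sketch is not part of the repository, it assumes the package is installed and that the modules import under the paths shown in the tree, and the array values are purely illustrative.

# Standalone sketch (not repository code): exercise the jitted kernels from
# trilu.py, unsqueeze.py and where.py with made-up test arrays.
import jax.numpy as jnp

from jaxonnxruntime.onnx_ops.trilu import onnx_trilu
from jaxonnxruntime.onnx_ops.unsqueeze import onnx_unsqueeze
from jaxonnxruntime.onnx_ops.where import onnx_where

x = jnp.arange(9.0).reshape(3, 3)

# Trilu: keep the triangle above the k-th diagonal (upper=True); k and upper
# are static under jit, mirroring how _prepare stores them in attrs_dict.
print(onnx_trilu(x, k=1, upper=True))

# Unsqueeze: insert size-1 axes at positions 0 and 2, so (3, 3) -> (1, 3, 1, 3).
print(onnx_unsqueeze(x, axis=(0, 2)).shape)

# Where: elementwise select between x and -x from a boolean condition.
print(onnx_where(x > 4.0, x, -x))

In normal use these kernels are not called by hand: each version_* classmethod above returns the kernel to the runtime after _prepare has derived its static keyword arguments from the node's attributes or constant inputs.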