├── .github └── workflows │ ├── black-ruff.yml │ ├── check-urls.yml │ ├── codeql.yml │ ├── documentation.yml │ └── wheels-any.yml ├── .gitignore ├── CHANGELOGS.rst ├── CODE_OF_CONDUCT.md ├── LICENSE.txt ├── MANIFEST.in ├── README.rst ├── _doc ├── _static │ └── logo.png ├── api │ ├── array_api.rst │ ├── array_api_numpy.rst │ ├── array_api_ort.rst │ ├── docs.rst │ ├── f8.rst │ ├── graph_api.rst │ ├── index.rst │ ├── light_api.rst │ ├── npx_array_api.rst │ ├── npx_core_api.rst │ ├── npx_functions.rst │ ├── npx_jit_eager.rst │ ├── npx_numpy.rst │ ├── npx_tensors.rst │ ├── npx_types.rst │ ├── npx_var.rst │ ├── onnx_tools.rst │ ├── ort.rst │ ├── plotting.rst │ ├── profiling.rst │ ├── reference.rst │ ├── tools.rst │ └── translate_api.rst ├── command_lines.rst ├── conf.py ├── examples │ ├── README.txt │ ├── data │ │ └── small.onnx │ ├── plot_benchmark_rf.py │ ├── plot_f8.py │ ├── plot_first_example.py │ ├── plot_onnx_diff.py │ ├── plot_onnxruntime.py │ ├── plot_optimization.py │ └── plot_profiling.py ├── index.rst ├── license.rst ├── long_outputs.rst ├── run_coverage.sh ├── tech │ ├── aapi.rst │ └── index.rst └── tutorial │ ├── benchmarks.rst │ ├── graph_api.rst │ ├── index.rst │ ├── light_api.rst │ ├── numpy_api.rst │ ├── onnx_api.rst │ └── tools.rst ├── _unittests ├── onnx-numpy-skips.txt ├── onnx-ort-skips.txt ├── test_array_api.sh ├── ut_array_api │ ├── test_array_apis.py │ ├── test_hypothesis_array_api.py │ ├── test_onnx_numpy.py │ └── test_onnx_ort.py ├── ut_graph_api │ ├── data │ │ └── debug_7951-CPUep.0.onnx │ ├── test_graph_builder.py │ └── test_graph_builder_optim.py ├── ut_light_api │ ├── test_backend_export.py │ └── test_light_api.py ├── ut_npx │ ├── test_npx.py │ └── test_sklearn_array_api.py ├── ut_ort │ ├── data │ │ ├── prof_base.xlsx │ │ └── prof_opti.xlsx │ ├── test_ort_optimizer.py │ ├── test_ort_profile.py │ ├── test_ort_tensor.py │ └── test_sklearn_array_api_ort.py ├── ut_plotting │ ├── data │ │ ├── bug_Hardmax.onnx │ │ ├── 
onnx_text_plot_tree_cls_2.onnx │ │ ├── prof.csv │ │ └── tree_torch.onnx │ ├── test_dot_plot.py │ ├── test_graphviz.py │ ├── test_stat_plot.py │ └── test_text_plot.py ├── ut_reference │ ├── test_array_tensor.py │ ├── test_backend_extended_reference_evaluator.py │ ├── test_evaluator_yield.py │ └── test_reference_ops.py ├── ut_tools │ └── test_replace_constants.py ├── ut_translate_api │ ├── _data │ │ ├── custom_ops_type_inference_fails_0.onnx │ │ └── stft_inlined_batch_1.onnx │ ├── test_translate.py │ ├── test_translate_builder.py │ └── test_translate_classic.py ├── ut_validation │ ├── data │ │ └── small.onnx │ ├── test_diff.py │ ├── test_docs.py │ ├── test_f8.py │ └── test_tools.py ├── ut_xrun_doc │ ├── test_command_lines1.py │ ├── test_documentation_examples.py │ └── test_profiling.py └── win_test_array_api.bat ├── azure-pipelines.yml ├── onnx_array_api ├── __init__.py ├── __main__.py ├── _command_lines_parser.py ├── _helpers.py ├── annotations.py ├── array_api │ ├── __init__.py │ ├── _onnx_common.py │ ├── onnx_numpy.py │ └── onnx_ort.py ├── cache.py ├── ext_test_case.py ├── graph_api │ ├── __init__.py │ └── graph_builder.py ├── light_api │ ├── __init__.py │ ├── _op_var.py │ ├── _op_vars.py │ ├── model.py │ └── var.py ├── npx │ ├── __init__.py │ ├── npx_array_api.py │ ├── npx_constants.py │ ├── npx_core_api.py │ ├── npx_function_implementation.py │ ├── npx_functions.py │ ├── npx_functions_test.py │ ├── npx_graph_builder.py │ ├── npx_helper.py │ ├── npx_jit_eager.py │ ├── npx_numpy_tensors.py │ ├── npx_tensors.py │ ├── npx_types.py │ └── npx_var.py ├── ort │ ├── __init__.py │ ├── ort_optimizers.py │ ├── ort_profile.py │ └── ort_tensors.py ├── plotting │ ├── __init__.py │ ├── _helper.py │ ├── dot_plot.py │ ├── graphviz_helper.py │ ├── stat_plot.py │ └── text_plot.py ├── profiling.py ├── reference │ ├── __init__.py │ ├── evaluator.py │ ├── evaluator_yield.py │ └── ops │ │ ├── __init__.py │ │ ├── op_cast_like.py │ │ ├── op_concat.py │ │ ├── op_constant_of_shape.py │ │ 
├── op_fused_matmul.py │ │ ├── op_memcpy_host.py │ │ ├── op_quick_gelu.py │ │ └── op_scatter_elements.py ├── tools │ ├── __init__.py │ └── replace_constants.py ├── translate_api │ ├── __init__.py │ ├── base_emitter.py │ ├── builder_emitter.py │ ├── inner_emitter.py │ ├── light_emitter.py │ ├── make_helper.py │ └── translate.py └── validation │ ├── __init__.py │ ├── diff.py │ ├── diff2html-ui-slim.min.js │ ├── diff2html.min.css │ ├── docs.py │ ├── f8.py │ └── tools.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg └── setup.py /.github/workflows/black-ruff.yml: -------------------------------------------------------------------------------- 1 | name: Black + Ruff Format Checker 2 | on: [push, pull_request] 3 | jobs: 4 | black-format-check: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: psf/black@stable 9 | with: 10 | options: "--diff --check" 11 | src: "." 12 | ruff-format-check: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: chartboost/ruff-action@v1 17 | -------------------------------------------------------------------------------- /.github/workflows/check-urls.yml: -------------------------------------------------------------------------------- 1 | name: Check URLs 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | schedule: 7 | # ┌───────────── minute (0 - 59) 8 | # │ ┌───────────── hour (0 - 23) 9 | # │ │ ┌───────────── day of the month (1 - 31) 10 | # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) 11 | # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) 12 | # │ │ │ │ │ 13 | # │ │ │ │ │ 14 | # │ │ │ │ │ 15 | # * * * * * 16 | - cron: '30 1 * * 0' 17 | 18 | jobs: 19 | build: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | 25 | - name: urls-checker-code 26 | uses: urlstechie/urlchecker-action@master 27 | with: 28 | subfolder: onnx_array_api 29 | file_types: .md,.py,.rst,.ipynb 30 | print_all: false 31 | 
timeout: 2 32 | retry_count# : 2 33 | # exclude_urls: https://dumps.wikimedia.org/other/pageviews/%Y/%Y-%m/pageviews-%Y%m%d-%H0000.gz,https://dumps.wikimedia.org/frwiki/latest/latest-all-titles-in-ns0.gz 34 | exclude_patterns: https://dumps.wikimedia.org/ 35 | # force_pass : true 36 | 37 | - name: urls-checker-docs 38 | uses: urlstechie/urlchecker-action@master 39 | with: 40 | subfolder: _doc 41 | file_types: .md,.py,.rst,.ipynb 42 | print_all: false 43 | timeout: 2 44 | retry_count# : 2 45 | exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx 46 | exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx 47 | # force_pass : true 48 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "Code Scanning - Action" 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | # ┌───────────── minute (0 - 59) 10 | # │ ┌───────────── hour (0 - 23) 11 | # │ │ ┌───────────── day of the month (1 - 31) 12 | # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) 13 | # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) 14 | # │ │ │ │ │ 15 | # │ │ │ │ │ 16 | # │ │ │ │ │ 17 | # * * * * * 18 | - cron: '30 1 * * 0' 19 | 20 | jobs: 21 | CodeQL-Build: 22 | # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest 23 | runs-on: ubuntu-latest 24 | 25 | permissions: 26 | # required for all workflows 27 | security-events: write 28 | 29 | # only required for workflows in private repositories 30 | actions: read 31 | contents: read 32 | 33 | steps: 34 | - name: Checkout repository 35 | uses: actions/checkout@v3 36 | 37 | # Initializes the CodeQL tools for scanning. 
38 | - name: Initialize CodeQL 39 | uses: github/codeql-action/init@v2 40 | # Override language selection by uncommenting this and choosing your languages 41 | # with: 42 | # languages: go, javascript, csharp, python, cpp, java, ruby 43 | 44 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 45 | # If this step fails, then you should remove it and run the build manually (see below). 46 | - name: Autobuild 47 | uses: github/codeql-action/autobuild@v2 48 | 49 | # ℹ️ Command-line programs to run using the OS shell. 50 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 51 | 52 | # ✏️ If the Autobuild fails above, remove it and uncomment the following 53 | # three lines and modify them (or add more) to build your code if your 54 | # project uses a compiled language 55 | 56 | #- run: | 57 | # make bootstrap 58 | # make release 59 | 60 | - name: Perform CodeQL Analysis 61 | uses: github/codeql-action/analyze@v2 62 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Documentation and Code Coverage 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: 7 | - closed 8 | branches: 9 | - main 10 | 11 | jobs: 12 | run: 13 | name: Build documentation on ${{ matrix.os }} 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: '3.12' 25 | 26 | - uses: tlylt/install-graphviz@v1 27 | 28 | - name: Install pandoc 29 | run: sudo apt-get install -y pandoc 30 | 31 | - name: Install requirements 32 | run: python -m pip install -r requirements.txt 33 | 34 | - name: Install requirements dev 35 | run: python -m pip install -r requirements-dev.txt 36 | 37 | - name: Cache pip 38 | uses: 
actions/cache@v4 39 | with: 40 | path: ~/.cache/pip 41 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }} 42 | restore-keys: | 43 | ${{ runner.os }}-pip- 44 | ${{ runner.os }}- 45 | 46 | - name: Generate coverage report 47 | run: | 48 | pip install pytest 49 | pip install pytest-cov 50 | export PYTHONPATH=. 51 | pytest --cov=./onnx_array_api/ --cov-report=xml --durations=10 --ignore-glob=**LONG*.py --ignore-glob=**notebook*.py 52 | export PYTHONPATH= 53 | 54 | - name: Upload coverage reports to Codecov 55 | uses: codecov/codecov-action@v3 56 | env: 57 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 58 | 59 | - name: Install 60 | run: python -m pip install -e . -v 61 | 62 | - name: Copy license, changelogs 63 | run: | 64 | cp LICENSE* ./_doc 65 | cp CHANGELOGS* ./_doc 66 | 67 | - name: Documentation 68 | run: python -m sphinx ./_doc ./dist/html -n -w doc.txt 69 | 70 | - name: Summary 71 | run: cat doc.txt 72 | 73 | - name: Check for errors and warnings 74 | run: | 75 | if [[ $(grep ERROR doc.txt) ]]; then 76 | echo "Documentation produces errors." 77 | grep ERROR doc.txt 78 | exit 1 79 | fi 80 | if [[ $(grep WARNING doc.txt) ]]; then 81 | echo "Documentation produces warnings." 82 | grep WARNING doc.txt 83 | exit 1 84 | fi 85 | 86 | - uses: actions/upload-artifact@v4 87 | with: 88 | path: ./dist/html/** 89 | -------------------------------------------------------------------------------- /.github/workflows/wheels-any.yml: -------------------------------------------------------------------------------- 1 | name: Build Any Wheel 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - 'releases/**' 8 | 9 | jobs: 10 | build_wheels: 11 | name: Build wheels on ${{ matrix.os }} 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - uses: actions/setup-python@v4 21 | with: 22 | python-version: '3.12' 23 | 24 | - name: build wheel 25 | run: python -m pip wheel . 
26 | 27 | - uses: actions/upload-artifact@v4 28 | with: 29 | path: ./onnx_array_api*.whl 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyd 3 | *.dylib 4 | *.so 5 | *.whl 6 | *.xlsx 7 | coverage.html/* 8 | _cache/* 9 | .coverage 10 | dist/* 11 | build/* 12 | .eggs/* 13 | .hypothesis/* 14 | *egg-info/* 15 | onnxruntime_profile* 16 | prof 17 | test*.png 18 | _doc/sg_execution_times.rst 19 | _doc/auto_examples/* 20 | _doc/examples/_cache/* 21 | _doc/examples/onnxruntime_profile* 22 | _doc/examples/plot_*.png 23 | _doc/examples/plot_*.xlsx 24 | _doc/examples/data/*.optimized.onnx 25 | _doc/examples/*.html 26 | _doc/_static/require.js 27 | _doc/_static/viz.js 28 | _doc/LICENSE.txt 29 | _doc/CHANGELOGS.rst 30 | _unittests/ut__main/*.png 31 | _unittests/ut__main/_cache/* 32 | _unittests/ut__main/*.html 33 | _unittests/.hypothesis/* 34 | -------------------------------------------------------------------------------- /CHANGELOGS.rst: -------------------------------------------------------------------------------- 1 | Change Logs 2 | =========== 3 | 4 | 0.3.1 5 | +++++ 6 | 7 | * :pr:`100`: updates requirements, add 3.12 8 | * :pr:`96`: supports local functions in translator 9 | * :pr:`95`: improves translation to GraphBuilder 10 | 11 | 0.3.0 12 | +++++ 13 | 14 | * :pr:`93`: fixes evaluator type in ``compare_onnx_execution`` 15 | * :pr:`92`: avoids recursion errors in profiling 16 | * :pr:`87`: adds command line to replace constant by ConstantOfShape 17 | * :pr:`79`: first draft to export to GraphBuilder 18 | * :pr:`77`: supports ConcatOfShape and Slice with the light API 19 | 20 | 0.2.0 21 | +++++ 22 | 23 | * :pr:`76`, :pr:`79`: add a mode to compare models without execution 24 | * :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator 25 | * :pr:`71`: adds tools to compare two onnx graphs 26 | * :pr:`61`: adds function to plot
onnx model as graphs 27 | * :pr:`60`: supports translation of local functions 28 | * :pr:`59`: add methods to update nodes in GraphAPI 29 | 30 | 0.1.3 31 | +++++ 32 | 33 | * :pr:`57`: implements GraphBuilder 34 | * :pr:`49`: adds command line to export a model into code 35 | * :pr:`48`: support for subgraph in light API 36 | * :pr:`47`: extends export onnx to code to support inner API 37 | * :pr:`46`: adds an export to convert an onnx graph into light API code 38 | * :pr:`45`: fixes light API for operators with two outputs 39 | 40 | 0.1.2 41 | +++++ 42 | 43 | * :pr:`42`: first sketch for a very simple API to create onnx graph in one or two lines 44 | * :pr:`27`: add function from_array_extended to convert 45 | an array to a TensorProto, including bfloat16 and float 8 types 46 | * :pr:`24`: add ExtendedReferenceEvaluator to support scenario 47 | for the Array API onnx does not support 48 | * :pr:`22`: support OrtValue in function *ort_profile* 49 | * :pr:`17`: implements ArrayAPI 50 | * :pr:`3`: fixes Array API with onnxruntime and scikit-learn 51 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | We are a community based on openness, as well as friendly and didactic discussions. 4 | 5 | We aspire to treat everybody equally, and value their contributions. 6 | 7 | Decisions are made based on technical merit and consensus. 8 | 9 | Code is not the only way to help the project. Reviewing pull requests, 10 | answering questions to help others on mailing lists or issues, organizing and 11 | teaching tutorials, working on the website, improving the documentation, are 12 | all priceless contributions. 
13 | 14 | We abide by the principles of openness, respect, and consideration of others of 15 | the Python Software Foundation: https://www.python.org/psf/codeofconduct/ 16 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-2025, Xavier Dupré 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune _doc 2 | prune _unittests 3 | exclude *.bat 4 | exclude *.yml 5 | exclude *.git* 6 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | 2 | .. 
image:: https://github.com/sdpython/onnx-array-api/raw/main/_doc/_static/logo.png 3 | :width: 120 4 | 5 | onnx-array-api: APIs to create ONNX Graphs 6 | ========================================== 7 | 8 | .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api 9 | :target: https://dev.azure.com/xavierdupre3/onnx-array-api/ 10 | 11 | .. image:: https://badge.fury.io/py/onnx-array-api.svg 12 | :target: http://badge.fury.io/py/onnx-array-api 13 | 14 | .. image:: http://img.shields.io/github/issues/sdpython/onnx-array-api.png 15 | :alt: GitHub Issues 16 | :target: https://github.com/sdpython/onnx-array-api/issues 17 | 18 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 19 | :alt: MIT License 20 | :target: https://opensource.org/license/MIT/ 21 | 22 | .. image:: https://img.shields.io/github/repo-size/sdpython/onnx-array-api 23 | :target: https://github.com/sdpython/onnx-array-api/ 24 | :alt: size 25 | 26 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg 27 | :target: https://github.com/psf/black 28 | 29 | .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J 30 | :target: https://codecov.io/gh/sdpython/onnx-array-api 31 | 32 | **onnx-array-api** implements APIs to create custom ONNX graphs. 33 | The objective is to speed up the implementation of converter libraries. 34 | The library is released on 35 | `pypi/onnx-array-api <https://pypi.org/project/onnx-array-api/>`_ 36 | and its documentation is published at 37 | `APIs to create ONNX Graphs <https://sdpython.github.io/doc/onnx-array-api/dev/>`_. 38 | 39 | Numpy API 40 | +++++++++ 41 | 42 | The first one matches the **numpy API**. 43 | It gives the user the ability to convert functions written 44 | following the numpy API into ONNX as well as to 45 | execute them. 46 | 47 | ..
code-block:: python 48 | 49 | import numpy as np 50 | from onnx_array_api.npx import absolute, jit_onnx 51 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 52 | 53 | def l1_loss(x, y): 54 | return absolute(x - y).sum() 55 | 56 | 57 | def l2_loss(x, y): 58 | return ((x - y) ** 2).sum() 59 | 60 | 61 | def myloss(x, y): 62 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 63 | 64 | 65 | jitted_myloss = jit_onnx(myloss) 66 | 67 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 68 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 69 | 70 | res = jitted_myloss(x, y) 71 | print(res) 72 | 73 | print(onnx_simple_text_plot(jitted_myloss.get_onnx())) 74 | 75 | :: 76 | 77 | [0.042] 78 | opset: domain='' version=18 79 | input: name='x0' type=dtype('float32') shape=['', ''] 80 | input: name='x1' type=dtype('float32') shape=['', ''] 81 | Sub(x0, x1) -> r__0 82 | Abs(r__0) -> r__1 83 | ReduceSum(r__1, keepdims=0) -> r__2 84 | output: name='r__2' type=dtype('float32') shape=None 85 | 86 | It supports eager mode as well: 87 | 88 | .. code-block:: python 89 | 90 | import numpy as np 91 | from onnx_array_api.npx import absolute, eager_onnx 92 | 93 | 94 | def l1_loss(x, y): 95 | err = absolute(x - y).sum() 96 | print(f"l1_loss={err.numpy()}") 97 | return err 98 | 99 | 100 | def l2_loss(x, y): 101 | err = ((x - y) ** 2).sum() 102 | print(f"l2_loss={err.numpy()}") 103 | return err 104 | 105 | 106 | def myloss(x, y): 107 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 108 | 109 | 110 | eager_myloss = eager_onnx(myloss) 111 | 112 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 113 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 114 | 115 | res = eager_myloss(x, y) 116 | print(res) 117 | 118 | :: 119 | 120 | l1_loss=[0.04] 121 | l2_loss=[0.002] 122 | [0.042] 123 | 124 | Light API 125 | +++++++++ 126 | 127 | The second API, or **Light API**, tends to do everything in one line.
128 | It is inspired by the `Reverse Polish Notation 129 | <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_. 130 | The squared Euclidean distance looks like the following: 131 | 132 | .. code-block:: python 133 | 134 | import numpy as np 135 | from onnx_array_api.light_api import start 136 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 137 | 138 | model = ( 139 | start() 140 | .vin("X") 141 | .vin("Y") 142 | .bring("X", "Y") 143 | .Sub() 144 | .rename("dxy") 145 | .cst(np.array([2], dtype=np.int64), "two") 146 | .bring("dxy", "two") 147 | .Pow() 148 | .ReduceSum() 149 | .rename("Z") 150 | .vout() 151 | .to_onnx() 152 | ) 153 | 154 | GraphBuilder API 155 | ++++++++++++++++ 156 | 157 | Almost every converting library (converting a machine-learned model to ONNX) implements 158 | its own graph builder and customizes it for its needs. 159 | GraphBuilder handles frequent tasks such as naming intermediate 160 | results and loading or saving ONNX models. It can also be used to extend an existing graph. 161 | 162 | ..
code-block:: python 163 | 164 | import numpy as np 165 | from onnx_array_api.graph_api import GraphBuilder 166 | 167 | g = GraphBuilder() 168 | g.make_tensor_input("X", np.float32, (None, None)) 169 | g.make_tensor_input("Y", np.float32, (None, None)) 170 | r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class, 171 | # it ensures the name is unique 172 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically 173 | # converts the array to a tensor 174 | r2 = g.make_node("Pow", [r1, init]) 175 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because 176 | # the user wants to choose the name 177 | g.make_tensor_output("Z", np.float32, (None, None)) 178 | 179 | onx = g.to_onnx() # final conversion to onnx 180 | -------------------------------------------------------------------------------- /_doc/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_doc/_static/logo.png -------------------------------------------------------------------------------- /_doc/api/array_api.rst: -------------------------------------------------------------------------------- 1 | onnx_array_api.array_api 2 | ======================== 3 | 4 | .. toctree:: 5 | 6 | array_api_numpy 7 | array_api_ort 8 | npx_array_api 9 | -------------------------------------------------------------------------------- /_doc/api/array_api_numpy.rst: -------------------------------------------------------------------------------- 1 | onnx_array_api.array_api.onnx_numpy 2 | ============================================= 3 | 4 | .. 
automodule:: onnx_array_api.array_api.onnx_numpy 5 | :members: 6 | -------------------------------------------------------------------------------- /_doc/api/array_api_ort.rst: -------------------------------------------------------------------------------- 1 | onnx_array_api.array_api.onnx_ort 2 | ================================= 3 | 4 | .. automodule:: onnx_array_api.array_api.onnx_ort 5 | :members: 6 | -------------------------------------------------------------------------------- /_doc/api/docs.rst: -------------------------------------------------------------------------------- 1 | validation.docs 2 | =============== 3 | 4 | make_euclidean 5 | ++++++++++++++ 6 | 7 | .. autofunction:: onnx_array_api.validation.docs.make_euclidean 8 | -------------------------------------------------------------------------------- /_doc/api/f8.rst: -------------------------------------------------------------------------------- 1 | Float 8 2 | ======= 3 | 4 | .. automodule:: onnx_array_api.validation.f8 5 | :members: 6 | -------------------------------------------------------------------------------- /_doc/api/graph_api.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | onnx_array_api.graph_api 3 | ======================== 4 | 5 | 6 | GraphBuilder 7 | ============ 8 | 9 | .. autoclass:: onnx_array_api.graph_api.GraphBuilder 10 | :members: 11 | 12 | NodePattern 13 | =========== 14 | 15 | .. autoclass:: onnx_array_api.graph_api.NodePattern 16 | :members: 17 | 18 | OptimizationOptions 19 | =================== 20 | 21 | .. autoclass:: onnx_array_api.graph_api.graph_builder.OptimizationOptions 22 | :members: 23 | -------------------------------------------------------------------------------- /_doc/api/index.rst: -------------------------------------------------------------------------------- 1 | 2 | === 3 | API 4 | === 5 | 6 | .. 
toctree:: 7 | :maxdepth: 1 8 | 9 | array_api 10 | graph_api 11 | light_api 12 | translate_api 13 | npx_core_api 14 | npx_functions 15 | npx_jit_eager 16 | npx_numpy 17 | npx_tensors 18 | npx_types 19 | npx_var 20 | onnx_tools 21 | ort 22 | plotting 23 | reference 24 | tools 25 | profiling 26 | f8 27 | docs 28 | -------------------------------------------------------------------------------- /_doc/api/light_api.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | onnx_array_api.light_api 3 | ======================== 4 | 5 | 6 | Main API 7 | ======== 8 | 9 | start 10 | +++++ 11 | 12 | .. autofunction:: onnx_array_api.light_api.start 13 | 14 | g 15 | + 16 | 17 | .. autofunction:: onnx_array_api.light_api.g 18 | 19 | Classes for the Light API 20 | ========================= 21 | 22 | domain 23 | ++++++ 24 | 25 | ..autofunction:: onnx_array_api.light_api.domain 26 | 27 | BaseVar 28 | +++++++ 29 | 30 | .. autoclass:: onnx_array_api.light_api.var.BaseVar 31 | :members: 32 | 33 | OnnxGraph 34 | +++++++++ 35 | 36 | .. autoclass:: onnx_array_api.light_api.OnnxGraph 37 | :members: 38 | 39 | ProtoType 40 | +++++++++ 41 | 42 | .. autoclass:: onnx_array_api.light_api.model.ProtoType 43 | :members: 44 | 45 | SubDomain 46 | +++++++++ 47 | 48 | .. autoclass:: onnx_array_api.light_api.var.SubDomain 49 | :members: 50 | 51 | Var 52 | +++ 53 | 54 | .. autoclass:: onnx_array_api.light_api.Var 55 | :members: 56 | :inherited-members: 57 | 58 | Vars 59 | ++++ 60 | 61 | .. autoclass:: onnx_array_api.light_api.Vars 62 | :members: 63 | :inherited-members: 64 | 65 | Available operators 66 | =================== 67 | 68 | One input 69 | +++++++++ 70 | 71 | .. autoclass:: onnx_array_api.light_api._op_var.OpsVar 72 | :members: 73 | 74 | Two inputs or more 75 | ++++++++++++++++++ 76 | 77 | .. 
autoclass:: onnx_array_api.light_api._op_vars.OpsVars 78 | :members: 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /_doc/api/npx_array_api.rst: -------------------------------------------------------------------------------- 1 | onnx_array_api.npx.npx_array_api 2 | ================================ 3 | 4 | .. automodule:: onnx_array_api.npx.npx_array_api.BaseArrayApi 5 | :members: 6 | -------------------------------------------------------------------------------- /_doc/api/npx_core_api.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | npx_core_api 3 | ============ 4 | 5 | cst 6 | === 7 | 8 | .. autofunction:: onnx_array_api.npx.npx_core_api.cst 9 | 10 | make_tuple 11 | ========== 12 | 13 | .. autofunction:: onnx_array_api.npx.npx_core_api.make_tuple 14 | 15 | tuple_var 16 | ========= 17 | 18 | .. autofunction:: onnx_array_api.npx.npx_core_api.tuple_var 19 | 20 | npxapi_inline 21 | ============= 22 | 23 | .. autofunction:: onnx_array_api.npx.npx_core_api.npxapi_inline 24 | 25 | npxapi_function 26 | =============== 27 | 28 | .. autofunction:: onnx_array_api.npx.npx_core_api.npxapi_function 29 | 30 | var 31 | === 32 | 33 | .. autofunction:: onnx_array_api.npx.npx_core_api.var 34 | -------------------------------------------------------------------------------- /_doc/api/npx_functions.rst: -------------------------------------------------------------------------------- 1 | npx.npx_functions 2 | ================= 3 | 4 | .. autofunction:: onnx_array_api.npx.npx_functions.abs 5 | 6 | .. autofunction:: onnx_array_api.npx.npx_functions.absolute 7 | 8 | .. autofunction:: onnx_array_api.npx.npx_functions.arccos 9 | 10 | .. autofunction:: onnx_array_api.npx.npx_functions.arccosh 11 | 12 | .. autofunction:: onnx_array_api.npx.npx_functions.amax 13 | 14 | .. autofunction:: onnx_array_api.npx.npx_functions.amin 15 | 16 | .. 
autofunction:: onnx_array_api.npx.npx_functions.arange 17 | 18 | .. autofunction:: onnx_array_api.npx.npx_functions.argmax 19 | 20 | .. autofunction:: onnx_array_api.npx.npx_functions.argmin 21 | 22 | .. autofunction:: onnx_array_api.npx.npx_functions.arcsin 23 | 24 | .. autofunction:: onnx_array_api.npx.npx_functions.arcsinh 25 | 26 | .. autofunction:: onnx_array_api.npx.npx_functions.arctan 27 | 28 | .. autofunction:: onnx_array_api.npx.npx_functions.arctanh 29 | 30 | .. autofunction:: onnx_array_api.npx.npx_functions.cdist 31 | 32 | .. autofunction:: onnx_array_api.npx.npx_functions.ceil 33 | 34 | .. autofunction:: onnx_array_api.npx.npx_functions.clip 35 | 36 | .. autofunction:: onnx_array_api.npx.npx_functions.compress 37 | 38 | .. autofunction:: onnx_array_api.npx.npx_functions.compute 39 | 40 | .. autofunction:: onnx_array_api.npx.npx_functions.concat 41 | 42 | .. autofunction:: onnx_array_api.npx.npx_functions.cos 43 | 44 | .. autofunction:: onnx_array_api.npx.npx_functions.cosh 45 | 46 | .. autofunction:: onnx_array_api.npx.npx_functions.cumsum 47 | 48 | .. autofunction:: onnx_array_api.npx.npx_functions.det 49 | 50 | .. autofunction:: onnx_array_api.npx.npx_functions.dot 51 | 52 | .. autofunction:: onnx_array_api.npx.npx_functions.einsum 53 | 54 | .. autofunction:: onnx_array_api.npx.npx_functions.erf 55 | 56 | .. autofunction:: onnx_array_api.npx.npx_functions.exp 57 | 58 | .. autofunction:: onnx_array_api.npx.npx_functions.expand_dims 59 | 60 | .. autofunction:: onnx_array_api.npx.npx_functions.expit 61 | 62 | .. autofunction:: onnx_array_api.npx.npx_functions.floor 63 | 64 | .. autofunction:: onnx_array_api.npx.npx_functions.hstack 65 | 66 | .. autofunction:: onnx_array_api.npx.npx_functions.copy 67 | 68 | .. autofunction:: onnx_array_api.npx.npx_functions.identity 69 | 70 | .. autofunction:: onnx_array_api.npx.npx_functions.isnan 71 | 72 | .. autofunction:: onnx_array_api.npx.npx_functions.log 73 | 74 | .. 
autofunction:: onnx_array_api.npx.npx_functions.log1p 75 | 76 | .. autofunction:: onnx_array_api.npx.npx_functions.matmul 77 | 78 | .. autofunction:: onnx_array_api.npx.npx_functions.pad 79 | 80 | .. autofunction:: onnx_array_api.npx.npx_functions.reciprocal 81 | 82 | .. autofunction:: onnx_array_api.npx.npx_functions.relu 83 | 84 | .. autofunction:: onnx_array_api.npx.npx_functions.round 85 | 86 | .. autofunction:: onnx_array_api.npx.npx_functions.sigmoid 87 | 88 | .. autofunction:: onnx_array_api.npx.npx_functions.sign 89 | 90 | .. autofunction:: onnx_array_api.npx.npx_functions.sin 91 | 92 | .. autofunction:: onnx_array_api.npx.npx_functions.sinh 93 | 94 | .. autofunction:: onnx_array_api.npx.npx_functions.squeeze 95 | 96 | .. autofunction:: onnx_array_api.npx.npx_functions.tan 97 | 98 | .. autofunction:: onnx_array_api.npx.npx_functions.tanh 99 | 100 | .. autofunction:: onnx_array_api.npx.npx_functions.topk 101 | 102 | .. autofunction:: onnx_array_api.npx.npx_functions.transpose 103 | 104 | .. autofunction:: onnx_array_api.npx.npx_functions.unsqueeze 105 | 106 | .. autofunction:: onnx_array_api.npx.npx_functions.vstack 107 | 108 | .. autofunction:: onnx_array_api.npx.npx_functions.where 109 | -------------------------------------------------------------------------------- /_doc/api/npx_jit_eager.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | npx_jit_eager 3 | ============= 4 | 5 | eager_onnx 6 | ========== 7 | 8 | .. autofunction:: onnx_array_api.npx.npx_jit_eager.eager_onnx 9 | 10 | EagerOnnx 11 | ========= 12 | 13 | .. autoclass:: onnx_array_api.npx.npx_jit_eager.EagerOnnx 14 | :members: 15 | 16 | JitEager 17 | ======== 18 | 19 | .. autoclass:: onnx_array_api.npx.npx_jit_eager.JitEager 20 | :members: 21 | 22 | jit_onnx 23 | ======== 24 | 25 | .. autofunction:: onnx_array_api.npx.npx_jit_eager.jit_onnx 26 | 27 | JitOnnx 28 | ======= 29 | 30 | .. 
autoclass:: onnx_array_api.npx.npx_jit_eager.JitOnnx 31 | :members: 32 | -------------------------------------------------------------------------------- /_doc/api/npx_numpy.rst: -------------------------------------------------------------------------------- 1 | npx.npx_numpy_tensors 2 | ===================== 3 | 4 | EagerNumpyTensor 5 | ++++++++++++++++ 6 | 7 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.EagerNumpyTensor 8 | :members: 9 | 10 | JitNumpyTensor 11 | ++++++++++++++ 12 | 13 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.JitNumpyTensor 14 | :members: 15 | 16 | NumpyTensor 17 | +++++++++++ 18 | 19 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.NumpyTensor 20 | :members: -------------------------------------------------------------------------------- /_doc/api/npx_tensors.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | npx_tensors 3 | =========== 4 | 5 | 6 | EagerTensor 7 | =========== 8 | 9 | .. autoclass:: onnx_array_api.npx.npx_tensors.EagerTensor 10 | :members: 11 | -------------------------------------------------------------------------------- /_doc/api/npx_types.rst: -------------------------------------------------------------------------------- 1 | npx.npx_types 2 | ============= 3 | 4 | DType 5 | +++++ 6 | 7 | .. autoclass:: onnx_array_api.npx.npx_types.DType 8 | :members: 9 | 10 | ElemType 11 | ++++++++ 12 | 13 | .. autoclass:: onnx_array_api.npx.npx_types.ElemType 14 | :members: 15 | 16 | OptParType 17 | ++++++++++ 18 | 19 | .. autoclass:: onnx_array_api.npx.npx_types.OptParType 20 | :members: 21 | 22 | OptTensorType 23 | +++++++++++++ 24 | 25 | .. autoclass:: onnx_array_api.npx.npx_types.OptTensorType 26 | :members: 27 | 28 | ParType 29 | +++++++ 30 | 31 | .. autoclass:: onnx_array_api.npx.npx_types.ParType 32 | :members: 33 | 34 | Scalar 35 | ++++++ 36 | 37 | .. 
autoclass:: onnx_array_api.npx.npx_types.Scalar 38 | :members: 39 | 40 | SequenceType 41 | ++++++++++++ 42 | 43 | .. autoclass:: onnx_array_api.npx.npx_types.SequenceType 44 | :members: 45 | 46 | ShapeType 47 | +++++++++ 48 | 49 | .. autoclass:: onnx_array_api.npx.npx_types.ShapeType 50 | :members: 51 | 52 | TensorType 53 | ++++++++++ 54 | 55 | .. autoclass:: onnx_array_api.npx.npx_types.TensorType 56 | :members: 57 | 58 | TupleType 59 | +++++++++ 60 | 61 | .. autoclass:: onnx_array_api.npx.npx_types.TupleType 62 | :members: 63 | 64 | Shortcuts 65 | ========= 66 | 67 | .. autoclass:: onnx_array_api.npx.npx_types.Bool 68 | 69 | .. autoclass:: onnx_array_api.npx.npx_types.BFloat16 70 | 71 | .. autoclass:: onnx_array_api.npx.npx_types.Float16 72 | 73 | .. autoclass:: onnx_array_api.npx.npx_types.Float32 74 | 75 | .. autoclass:: onnx_array_api.npx.npx_types.Float64 76 | 77 | .. autoclass:: onnx_array_api.npx.npx_types.Int8 78 | 79 | .. autoclass:: onnx_array_api.npx.npx_types.Int16 80 | 81 | .. autoclass:: onnx_array_api.npx.npx_types.Int32 82 | 83 | .. autoclass:: onnx_array_api.npx.npx_types.Int64 84 | 85 | .. autoclass:: onnx_array_api.npx.npx_types.UInt8 86 | 87 | .. autoclass:: onnx_array_api.npx.npx_types.UInt16 88 | 89 | .. autoclass:: onnx_array_api.npx.npx_types.UInt32 90 | 91 | .. autoclass:: onnx_array_api.npx.npx_types.UInt64 92 | -------------------------------------------------------------------------------- /_doc/api/npx_var.rst: -------------------------------------------------------------------------------- 1 | npx.npx_var 2 | =========== 3 | 4 | Var 5 | +++ 6 | 7 | .. autoclass:: onnx_array_api.npx.npx_var.Var 8 | :members: 9 | 10 | Cst, Input 11 | ++++++++++ 12 | 13 | .. autoclass:: onnx_array_api.npx.npx_var.Cst 14 | :members: 15 | 16 | .. autoclass:: onnx_array_api.npx.npx_var.Input 17 | :members: 18 | 19 | ManyIdentity 20 | ++++++++++++ 21 | 22 | .. 
autoclass:: onnx_array_api.npx.npx_var.ManyIdentity 23 | :members: 24 | 25 | Par 26 | +++ 27 | 28 | .. autoclass:: onnx_array_api.npx.npx_var.Par 29 | :members: 30 | 31 | -------------------------------------------------------------------------------- /_doc/api/onnx_tools.rst: -------------------------------------------------------------------------------- 1 | onnx tools 2 | ========== 3 | 4 | Differences 5 | +++++++++++ 6 | 7 | .. autofunction:: onnx_array_api.validation.diff.html_diff 8 | 9 | .. autofunction:: onnx_array_api.validation.diff.text_diff 10 | 11 | Protos 12 | ++++++ 13 | 14 | .. autofunction:: onnx_array_api.validation.tools.randomize_proto 15 | -------------------------------------------------------------------------------- /_doc/api/ort.rst: -------------------------------------------------------------------------------- 1 | ort 2 | === 3 | 4 | ort_optimized_model 5 | +++++++++++++++++++ 6 | 7 | .. autofunction:: onnx_array_api.ort.ort_optimizers.ort_optimized_model 8 | 9 | EagerOrtTensor 10 | ++++++++++++++ 11 | 12 | .. autoclass:: onnx_array_api.ort.ort_tensors.EagerOrtTensor 13 | :members: 14 | 15 | JitOrtTensor 16 | ++++++++++++ 17 | 18 | .. autoclass:: onnx_array_api.ort.ort_tensors.JitOrtTensor 19 | :members: 20 | 21 | OrtTensor 22 | +++++++++ 23 | 24 | .. autoclass:: onnx_array_api.ort.ort_tensors.OrtTensor 25 | :members: 26 | 27 | merge_ort_profile 28 | +++++++++++++++++ 29 | 30 | .. autofunction:: onnx_array_api.ort.ort_profile.merge_ort_profile 31 | 32 | ort_profile 33 | +++++++++++ 34 | 35 | .. autofunction:: onnx_array_api.ort.ort_profile.ort_profile 36 | -------------------------------------------------------------------------------- /_doc/api/plotting.rst: -------------------------------------------------------------------------------- 1 | plotting 2 | ======== 3 | 4 | Dot 5 | +++ 6 | 7 | .. autofunction:: onnx_array_api.plotting.dot_plot.to_dot 8 | 9 | .. 
autofunction:: onnx_array_api.plotting.graphviz_helper.plot_dot 10 | 11 | Statistics 12 | ++++++++++ 13 | 14 | .. autofunction:: onnx_array_api.plotting.stat_plot.plot_ort_profile 15 | 16 | Text 17 | ++++ 18 | 19 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_text_plot_tree 20 | 21 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_text_plot_io 22 | 23 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_simple_text_plot 24 | -------------------------------------------------------------------------------- /_doc/api/profiling.rst: -------------------------------------------------------------------------------- 1 | profiling 2 | ========= 3 | 4 | ProfileNode 5 | +++++++++++ 6 | 7 | .. autoclass:: onnx_array_api.profiling.ProfileNode 8 | 9 | profile 10 | +++++++ 11 | 12 | .. autofunction:: onnx_array_api.profiling.profile 13 | 14 | profile2graph 15 | +++++++++++++ 16 | 17 | .. autofunction:: onnx_array_api.profiling.profile2graph 18 | -------------------------------------------------------------------------------- /_doc/api/reference.rst: -------------------------------------------------------------------------------- 1 | reference 2 | ========= 3 | 4 | ExtendedReferenceEvaluator 5 | ++++++++++++++++++++++++++ 6 | 7 | .. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator 8 | :members: 9 | 10 | ResultType 11 | ++++++++++ 12 | 13 | .. autoclass:: onnx_array_api.reference.ResultType 14 | :members: 15 | 16 | ResultExecution 17 | +++++++++++++++ 18 | 19 | .. autoclass:: onnx_array_api.reference.ResultExecution 20 | :members: 21 | 22 | YieldEvaluator 23 | ++++++++++++++ 24 | 25 | .. autoclass:: onnx_array_api.reference.YieldEvaluator 26 | :members: 27 | 28 | DistanceExecution 29 | +++++++++++++++++ 30 | 31 | .. autoclass:: onnx_array_api.reference.DistanceExecution 32 | :members: 33 | 34 | compare_onnx_execution 35 | ++++++++++++++++++++++ 36 | 37 | .. 
autofunction:: onnx_array_api.reference.compare_onnx_execution 38 | -------------------------------------------------------------------------------- /_doc/api/tools.rst: -------------------------------------------------------------------------------- 1 | tools 2 | ===== 3 | 4 | Benchmark 5 | +++++++++ 6 | 7 | .. autofunction:: onnx_array_api.ext_test_case.measure_time 8 | 9 | Manipulations 10 | +++++++++++++ 11 | 12 | .. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape 13 | 14 | Examples 15 | ++++++++ 16 | 17 | .. autofunction:: onnx_array_api.ext_test_case.example_path 18 | 19 | Unit tests 20 | ++++++++++ 21 | 22 | .. autofunction:: onnx_array_api.ext_test_case.ignore_warnings 23 | 24 | .. autoclass:: onnx_array_api.ext_test_case.ExtTestCase 25 | :members: 26 | -------------------------------------------------------------------------------- /_doc/api/translate_api.rst: -------------------------------------------------------------------------------- 1 | ============================ 2 | onnx_array_api.translate_api 3 | ============================ 4 | 5 | 6 | Main API 7 | ======== 8 | 9 | translate 10 | +++++++++ 11 | 12 | .. autofunction:: onnx_array_api.translate_api.translate 13 | 14 | make_helper 15 | +++++++++++ 16 | 17 | .. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended 18 | 19 | .. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute 20 | 21 | Classes for the Translater 22 | ========================== 23 | 24 | BaseEmitter 25 | +++++++++++ 26 | 27 | .. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter 28 | :members: 29 | 30 | EventType 31 | +++++++++ 32 | 33 | .. autoclass:: onnx_array_api.translate_api.base_emitter.EventType 34 | :members: 35 | 36 | InnerEmitter 37 | ++++++++++++ 38 | 39 | .. 
autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter 40 | :members: 41 | 42 | InnerEmitterShortInitializer 43 | ++++++++++++++++++++++++++++ 44 | 45 | .. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer 46 | :members: 47 | 48 | LightEmitter 49 | ++++++++++++ 50 | 51 | .. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter 52 | :members: 53 | 54 | Translater 55 | ++++++++++ 56 | 57 | .. autoclass:: onnx_array_api.translate_api.translate.Translater 58 | :members: 59 | -------------------------------------------------------------------------------- /_doc/command_lines.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | command lines 3 | ============= 4 | 5 | compare 6 | ======= 7 | 8 | The command compares the intermediate results of two onnx models executed on the same inputs. 9 | 10 | :: 11 | 12 | python -m onnx_array_api compare -m1 model1.onnx -m2 model2.onnx -v 1 13 | 14 | Output example:: 15 | 16 | [compare_onnx_execution] got 2 inputs 17 | [compare_onnx_execution] execute first model 18 | [compare_onnx_execution] got 5 results 19 | [compare_onnx_execution] execute second model 20 | [compare_onnx_execution] got 5 results 21 | [compare_onnx_execution] compute edit distance 22 | [compare_onnx_execution] got 4 pairs 23 | [compare_onnx_execution] done 24 | = | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X 25 | = | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y 26 | = | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res 27 | = | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z 28 | 29 | .. runpython:: 30 | 31 | from onnx_array_api._command_lines_parser import get_parser_compare 32 | get_parser_compare().print_help() 33 | 34 | See function :func:`onnx_array_api.reference.compare_onnx_execution`. 35 | 36 | translate 37 | ========= 38 | 39 | The command converts an onnx file into some code. 40 | 41 | :: 42 | 43 | python -m onnx_array_api translate ...
44 | 45 | Output example:: 46 | 47 | not yet ready 48 | 49 | .. runpython:: 50 | 51 | from onnx_array_api._command_lines_parser import get_parser_translate 52 | get_parser_translate().print_help() 53 | -------------------------------------------------------------------------------- /_doc/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from sphinx_runpython.github_link import make_linkcode_resolve 4 | from sphinx_runpython.conf_helper import has_dvipng, has_dvisvgm 5 | from onnx_array_api import __version__ 6 | 7 | extensions = [ 8 | "sphinx.ext.autodoc", 9 | "sphinx.ext.coverage", 10 | "sphinx.ext.githubpages", 11 | "sphinx.ext.ifconfig", 12 | "sphinx.ext.intersphinx", 13 | "sphinx.ext.mathjax", 14 | "sphinx.ext.viewcode", 15 | "sphinx.ext.todo", 16 | "sphinx_gallery.gen_gallery", 17 | "sphinx_issues", 18 | "matplotlib.sphinxext.plot_directive", 19 | "sphinx_runpython.epkg", 20 | "sphinx_runpython.gdot", 21 | "sphinx_runpython.runpython", 22 | ] 23 | 24 | if has_dvisvgm(): 25 | extensions.append("sphinx.ext.imgmath") 26 | imgmath_image_format = "svg" 27 | elif has_dvipng(): 28 | extensions.append("sphinx.ext.pngmath") 29 | imgmath_image_format = "png" 30 | else: 31 | extensions.append("sphinx.ext.mathjax") 32 | 33 | templates_path = ["_templates"] 34 | html_logo = "_static/logo.png" 35 | source_suffix = ".rst" 36 | master_doc = "index" 37 | project = "onnx-array-api" 38 | copyright = "2023-2024, Xavier Dupré" 39 | author = "Xavier Dupré" 40 | version = __version__ 41 | release = __version__ 42 | language = "en" 43 | exclude_patterns = [] 44 | pygments_style = "sphinx" 45 | todo_include_todos = True 46 | 47 | html_theme = "furo" 48 | html_theme_path = ["_static"] 49 | html_theme_options = {} 50 | html_static_path = ["_static"] 51 | html_sourcelink_suffix = "" 52 | 53 | issues_github_path = "sdpython/onnx-array-api" 54 | 55 | # The following is used by sphinx.ext.linkcode to provide links to 
github 56 | linkcode_resolve = make_linkcode_resolve( 57 | "onnx-array-api", 58 | ( 59 | "https://github.com/sdpython/onnx-array-api/" 60 | "blob/{revision}/{package}/" 61 | "{path}#L{lineno}" 62 | ), 63 | ) 64 | 65 | latex_elements = { 66 | "papersize": "a4", 67 | "pointsize": "10pt", 68 | "title": project, 69 | } 70 | 71 | intersphinx_mapping = { 72 | "matplotlib": ("https://matplotlib.org/", None), 73 | "numpy": ("https://numpy.org/doc/stable", None), 74 | "onnx": ("https://onnx.ai/onnx/", None), 75 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 76 | "python": (f"https://docs.python.org/{sys.version_info.major}", None), 77 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 78 | "sklearn": ("https://scikit-learn.org/stable/", None), 79 | "sklearn-onnx": ("https://onnx.ai/sklearn-onnx/", None), 80 | "torch": ("https://pytorch.org/docs/stable/", None), 81 | } 82 | 83 | # Check intersphinx reference targets exist 84 | nitpicky = True 85 | # See also scikit-learn/scikit-learn#26761 86 | nitpick_ignore = [ 87 | ("py:class", "False"), 88 | ("py:class", "True"), 89 | ("py:class", "pipeline.Pipeline"), 90 | ("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"), 91 | ] 92 | 93 | nitpick_ignore_regex = [ 94 | ("py:func", ".*numpy[.].*"), 95 | ("py:func", ".*scipy[.].*"), 96 | ("py:class", ".*onnxruntime[.].*"), 97 | ("py:class", ".*onnx_array_api.npx.npx_types.OptParTypeTupleType_.*"), 98 | ("py:class", ".*onnx_array_api.npx.npx_types.ParType[a-z].*"), 99 | ("py:class", ".*onnx_array_api.npx.npx_types.OptTensorType_.*"), 100 | ("py:class", ".*onnx_array_api.npx.npx_types.TensorType_.*"), 101 | ("py:class", ".*onnx_array_api.npx.npx_types.[ui].*"), 102 | ] 103 | 104 | sphinx_gallery_conf = { 105 | # path to your examples scripts 106 | "examples_dirs": os.path.join(os.path.dirname(__file__), "examples"), 107 | # path where to save gallery generated examples 108 | "gallery_dirs": "auto_examples", 109 | } 110 | 111 | 
epkg_dictionary = { 112 | "Array API": "https://data-apis.org/array-api/", 113 | "ArrayAPI": ( 114 | "https://data-apis.org/array-api/", 115 | ("2022.12/API_specification/generated/array_api.{0}.html", 1), 116 | ), 117 | "ast": "https://docs.python.org/3/library/ast.html", 118 | "cProfile.Profile": "https://docs.python.org/3/library/profile.html#profile.Profile", 119 | "DOT": "https://graphviz.org/doc/info/lang.html", 120 | "Graphviz": "https://graphviz.org/", 121 | "inner API": "https://onnx.ai/onnx/intro/python.html", 122 | "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation", 123 | "onnx": "https://onnx.ai/onnx/", 124 | "onnx-graphsurgeon": "https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon", 125 | "onnx.helper": "https://onnx.ai/onnx/api/helper.html", 126 | "ONNX": "https://onnx.ai/", 127 | "ONNX Operators": "https://onnx.ai/onnx/operators/", 128 | "onnxruntime": "https://onnxruntime.ai/", 129 | "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html", 130 | "numpy": "https://numpy.org/", 131 | "numba": "https://numba.pydata.org/", 132 | "onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"), 133 | "onnxscript": "https://github.com/microsoft/onnxscript", 134 | "pyinstrument": "https://github.com/joerick/pyinstrument", 135 | "python": "https://www.python.org/", 136 | "pytorch": "https://pytorch.org/", 137 | "reverse Polish notation": "https://en.wikipedia.org/wiki/Reverse_Polish_notation", 138 | "scikit-learn": "https://scikit-learn.org/stable/", 139 | "scipy": "https://scipy.org/", 140 | "sklearn-onnx": "https://onnx.ai/sklearn-onnx/", 141 | "spox": "https://github.com/Quantco/spox", 142 | "sphinx-gallery": "https://github.com/sphinx-gallery/sphinx-gallery", 143 | "tensorflow": "https://www.tensorflow.org/", 144 | "tensorflow-onnx": "https://github.com/onnx/tensorflow-onnx", 145 | "torch": "https://pytorch.org/docs/stable/torch.html", 146 | "torch.onnx": 
"https://pytorch.org/docs/stable/onnx.html", 147 | # 148 | "C_OrtValue": ( 149 | "https://onnxruntime.ai/docs/api/csharp/api/Microsoft.ML.OnnxRuntime.OrtValue.html" 150 | ), 151 | "OrtValue": ( 152 | "https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.OrtValue" 153 | ), 154 | } 155 | -------------------------------------------------------------------------------- /_doc/examples/README.txt: -------------------------------------------------------------------------------- 1 | Example gallery 2 | =============== 3 | 4 | A couple of examples to illustrate different implementation 5 | of dot product (see also :epkg:`sphinx-gallery`). 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /_doc/examples/data/small.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_doc/examples/data/small.onnx -------------------------------------------------------------------------------- /_doc/examples/plot_f8.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. _l-example-float8: 3 | 4 | About float 8 5 | ============= 6 | 7 | Float 8 types were recently introduced to speed up the 8 | training of deep learning models. 9 | 10 | Possible values 11 | +++++++++++++++ 12 | 13 | First E4M3FN. 14 | """ 15 | 16 | import pprint 17 | from onnx_array_api.validation.f8 import CastFloat8 18 | 19 | pprint.pprint(CastFloat8.values_e4m3fn) 20 | 21 | 22 | ############################################ 23 | # Then E5M2. 24 | 25 | pprint.pprint(CastFloat8.values_e5m2) 26 | -------------------------------------------------------------------------------- /_doc/examples/plot_first_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | .. 
_l-onnx-array-first-api-example: 4 | 5 | First examples with onnx-array-api 6 | ================================== 7 | 8 | This demonstrates an easy case with :epkg:`onnx-array-api`. 9 | It shows how a function can be easily converted into 10 | ONNX. 11 | 12 | A loss function from numpy to ONNX 13 | ++++++++++++++++++++++++++++++++++ 14 | 15 | The first example takes a loss function and converts it into ONNX. 16 | """ 17 | 18 | import numpy as np 19 | 20 | from onnx_array_api.npx import absolute, jit_onnx 21 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 22 | 23 | 24 | ################################ 25 | # The function looks like a numpy function. 26 | def l1_loss(x, y): 27 | return absolute(x - y).sum() 28 | 29 | 30 | ################################ 31 | # The function needs to be converted into ONNX with function jit_onnx. 32 | # jitted_l1_loss is a wrapper. It intercepts all calls to l1_loss. 33 | # When it happens, it checks the input types and creates the 34 | # corresponding ONNX graph. 35 | jitted_l1_loss = jit_onnx(l1_loss) 36 | 37 | ################################ 38 | # First execution and conversion to ONNX. 39 | # The wrapper caches the created onnx graph. 40 | # It reuses it if the input types and the number of dimensions are the same. 41 | # It creates a new one otherwise and keeps the old one. 42 | 43 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 44 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 45 | 46 | res = jitted_l1_loss(x, y) 47 | print(res) 48 | 49 | #################################### 50 | # The ONNX graph can be accessed the following way. 51 | print(onnx_simple_text_plot(jitted_l1_loss.get_onnx())) 52 | 53 | ################################ 54 | # We can also define a more complex loss by computing L1 loss on 55 | # the first column and L2 loss on the second one.
56 | 57 | 58 | def l1_loss(x, y): 59 | return absolute(x - y).sum() 60 | 61 | 62 | def l2_loss(x, y): 63 | return ((x - y) ** 2).sum() 64 | 65 | 66 | def myloss(x, y): 67 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 68 | 69 | 70 | jitted_myloss = jit_onnx(myloss) 71 | 72 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 73 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 74 | 75 | res = jitted_myloss(x, y) 76 | print(res) 77 | 78 | print(onnx_simple_text_plot(jitted_myloss.get_onnx())) 79 | 80 | ############################ 81 | # Eager mode 82 | # ++++++++++ 83 | 84 | import numpy as np 85 | 86 | from onnx_array_api.npx import absolute, eager_onnx 87 | 88 | 89 | def l1_loss(x, y): 90 | """ 91 | err is a type inheriting from 92 | :class:`EagerTensor `. 93 | It needs to be converted to numpy first before any display. 94 | """ 95 | err = absolute(x - y).sum() 96 | print(f"l1_loss={err.numpy()}") 97 | return err 98 | 99 | 100 | def l2_loss(x, y): 101 | err = ((x - y) ** 2).sum() 102 | print(f"l2_loss={err.numpy()}") 103 | return err 104 | 105 | 106 | def myloss(x, y): 107 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 108 | 109 | 110 | ################################# 111 | # Eager mode is enabled by function :func:`eager_onnx 112 | # `. 113 | # It intercepts all calls to `myloss`. On the first call, 114 | # it replaces a numpy array by a tensor corresponding to the 115 | # selected runtime, here numpy as well through 116 | # :class:`EagerNumpyTensor 117 | # `. 118 | eager_myloss = eager_onnx(myloss) 119 | 120 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 121 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 122 | 123 | ################################# 124 | # First execution and conversion to ONNX. 125 | # The wrapper caches many ONNX graphs corresponding to 126 | # simple operators (`+`, `-`, `/`, `*`, ...), reduce functions, 127 | # and any other function from the API.
128 | # It reuses it if the input types and the number of dimensions are the same. 129 | # It creates a new one otherwise and keeps the old ones. 130 | res = eager_myloss(x, y) 131 | print(res) 132 | 133 | ################################ 134 | # There is no ONNX graph to show. Every operation 135 | # is converted into small ONNX graphs. 136 | -------------------------------------------------------------------------------- /_doc/examples/plot_onnx_diff.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | .. _l-onnx-diff-example: 4 | 5 | Compares the conversions of the same model with different options 6 | ================================================================= 7 | 8 | The script compares two onnx models obtained with the same trained 9 | scikit-learn model but converted with different options. 10 | 11 | A model 12 | +++++++ 13 | """ 14 | 15 | from sklearn.mixture import GaussianMixture 16 | from sklearn.datasets import load_iris 17 | from sklearn.model_selection import train_test_split 18 | from skl2onnx import to_onnx 19 | from onnx_array_api.reference import compare_onnx_execution 20 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 21 | 22 | 23 | data = load_iris() 24 | X_train, X_test = train_test_split(data.data) 25 | model = GaussianMixture() 26 | model.fit(X_train) 27 | 28 | ################################# 29 | # Conversion to onnx 30 | # ++++++++++++++++++ 31 | 32 | onx = to_onnx( 33 | model, X_train[:1], options={id(model): {"score_samples": True}}, target_opset=12 34 | ) 35 | 36 | print(onnx_simple_text_plot(onx)) 37 | 38 | ################################## 39 | # Conversion to onnx without ReduceLogSumExp 40 | # ++++++++++++++++++++++++++++++++++++++++++ 41 | 42 | onx2 = to_onnx( 43 | model, 44 | X_train[:1], 45 | options={id(model): {"score_samples": True}}, 46 | black_op={"ReduceLogSumExp"}, 47 | target_opset=12, 48 | ) 49 | 50 | print(onnx_simple_text_plot(onx2)) 51 | 52 |
53 | ############################################# 54 | # Differences 55 | # +++++++++++ 56 | # 57 | # Function :func:`onnx_array_api.reference.compare_onnx_execution` 58 | # compares the intermediate results of two onnx models. Then it finds 59 | # the best alignment between the two models using an edit distance. 60 | 61 | res1, res2, align, dc = compare_onnx_execution(onx, onx2, verbose=1) 62 | print("------------") 63 | text = dc.to_str(res1, res2, align) 64 | print(text) 65 | 66 | ############################### 67 | # See :ref:`l-long-output-compare_onnx_execution` for a better view. 68 | # The display shows that ReduceSumSquare was replaced by Mul + ReduceSum, 69 | # and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add. 70 | -------------------------------------------------------------------------------- /_doc/examples/plot_onnxruntime.py: -------------------------------------------------------------------------------- 1 | """ 2 | First examples with onnxruntime 3 | =============================== 4 | 5 | Example :ref:`l-onnx-array-first-api-example` defines a custom 6 | loss and then executes it with class 7 | :class:`onnx.reference.ReferenceEvaluator`. 8 | This example replaces it with :epkg:`onnxruntime`.
9 | 10 | Example 11 | +++++++ 12 | """ 13 | 14 | import numpy as np 15 | 16 | from onnx_array_api.npx import absolute, jit_onnx 17 | from onnx_array_api.ort.ort_tensors import JitOrtTensor, OrtTensor 18 | 19 | 20 | def l1_loss(x, y): 21 | return absolute(x - y).sum() 22 | 23 | 24 | def l2_loss(x, y): 25 | return ((x - y) ** 2).sum() 26 | 27 | 28 | def myloss(x, y): 29 | l1 = l1_loss(x[:, 0], y[:, 0]) 30 | l2 = l2_loss(x[:, 1], y[:, 1]) 31 | return l1 + l2 32 | 33 | 34 | ort_myloss = jit_onnx(myloss, JitOrtTensor, target_opsets={"": 17}, ir_version=8) 35 | 36 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 37 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 38 | 39 | xort = OrtTensor.from_array(x) 40 | yort = OrtTensor.from_array(y) 41 | 42 | res = ort_myloss(xort, yort) 43 | print(res.numpy()) 44 | 45 | ############################### 46 | # Profiling 47 | # +++++++++ 48 | from onnx_array_api.profiling import profile, profile2graph 49 | 50 | x = np.random.randn(10000, 2).astype(np.float32) 51 | y = np.random.randn(10000, 2).astype(np.float32) 52 | xort = OrtTensor.from_array(x) 53 | yort = OrtTensor.from_array(y) 54 | 55 | 56 | def loop_ort(n): 57 | for _ in range(n): 58 | ort_myloss(xort, yort) 59 | 60 | 61 | def loop_numpy(n): 62 | for _ in range(n): 63 | myloss(x, y) 64 | 65 | 66 | def loop(n=1000): 67 | loop_numpy(n) 68 | loop_ort(n) 69 | 70 | 71 | ps = profile(loop)[0] 72 | root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1]) 73 | text = root.to_text() 74 | print(text) 75 | 76 | ############################## 77 | # Benchmark 78 | # +++++++++ 79 | 80 | from pandas import DataFrame 81 | from tqdm import tqdm 82 | 83 | from onnx_array_api.ext_test_case import measure_time 84 | 85 | data = [] 86 | for n in tqdm([1, 10, 100, 1000, 10000, 100000]): 87 | x = np.random.randn(n, 2).astype(np.float32) 88 | y = np.random.randn(n, 2).astype(np.float32) 89 | 90 | obs = measure_time(lambda x=x, y=y: myloss(x, y)) 91 | obs["name"] 
= "numpy" 92 | obs["n"] = n 93 | data.append(obs) 94 | 95 | xort = OrtTensor.from_array(x) 96 | yort = OrtTensor.from_array(y) 97 | obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort)) 98 | obs["name"] = "ort" 99 | obs["n"] = n 100 | data.append(obs) 101 | 102 | df = DataFrame(data) 103 | piv = df.pivot(index="n", columns="name", values="average") 104 | piv 105 | 106 | ############################ 107 | # Plots 108 | # +++++ 109 | 110 | import matplotlib.pyplot as plt 111 | 112 | fig, ax = plt.subplots(1, 2, figsize=(12, 4)) 113 | piv.plot( 114 | title="Comparison between numpy and onnxruntime", logx=True, logy=True, ax=ax[0] 115 | ) 116 | piv["ort/numpy"] = piv["ort"] / piv["numpy"] 117 | piv["ort/numpy"].plot(title="Ratio ort/numpy", logx=True, ax=ax[1]) 118 | fig.savefig("plot_onnxruntime.png") 119 | -------------------------------------------------------------------------------- /_doc/examples/plot_optimization.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | .. _l-onnx-array-onnxruntime-optimization: 4 | 5 | Optimization with onnxruntime 6 | ============================= 7 | 8 | *onnxruntime* optimizes the onnx graph by default before running 9 | the inference. It modifies, fuses or adds new operators. 10 | Some of them are standard onnx operators, some of them 11 | are implemented in onnxruntime (see `Supported Operators 12 | `_). 13 | This example looks into the differences between two models.
14 | 15 | Optimize a model with onnxruntime 16 | +++++++++++++++++++++++++++++++++ 17 | """ 18 | 19 | import os 20 | from pprint import pprint 21 | import numpy 22 | from pandas import DataFrame 23 | import matplotlib.pyplot as plt 24 | from onnx import load 25 | from onnx_array_api.ext_test_case import example_path 26 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 27 | from onnx_array_api.validation.diff import text_diff, html_diff 28 | from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions 29 | from onnx_array_api.ext_test_case import measure_time 30 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model 31 | 32 | 33 | filename = example_path("data/small.onnx") 34 | optimized = filename + ".optimized.onnx" 35 | 36 | if not os.path.exists(optimized): 37 | ort_optimized_model(filename, output=optimized) 38 | print(optimized) 39 | 40 | ############################# 41 | # Output comparison 42 | # +++++++++++++++++ 43 | 44 | so = SessionOptions() 45 | so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL 46 | img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32) 47 | 48 | sess = InferenceSession(filename, so, providers=["CPUExecutionProvider"]) 49 | sess_opt = InferenceSession(optimized, so, providers=["CPUExecutionProvider"]) 50 | input_name = sess.get_inputs()[0].name 51 | out = sess.run(None, {input_name: img})[0] 52 | out_opt = sess_opt.run(None, {input_name: img})[0] 53 | if out.shape != out_opt.shape: 54 | print("ERROR shape are different {out.shape} != {out_opt.shape}") 55 | diff = numpy.abs(out - out_opt).max() 56 | print(f"Differences: {diff}") 57 | 58 | #################################### 59 | # Difference 60 | # ++++++++++ 61 | # 62 | # Unoptimized model. 
63 | 64 | with open(filename, "rb") as f: 65 | model = load(f) 66 | print("first model to text...") 67 | text1 = onnx_simple_text_plot(model, indent=False) 68 | print(text1) 69 | 70 | ##################################### 71 | # Optimized model. 72 | 73 | 74 | with open(optimized, "rb") as f: 75 | model = load(f) 76 | print("second model to text...") 77 | text2 = onnx_simple_text_plot(model, indent=False) 78 | print(text2) 79 | 80 | ######################################## 81 | # Differences 82 | 83 | print("differences...") 84 | print(text_diff(text1, text2)) 85 | 86 | ##################################### 87 | # HTML version. 88 | 89 | print("html differences...") 90 | output = html_diff(text1, text2) 91 | with open("diff_html.html", "w", encoding="utf-8") as f: 92 | f.write(output) 93 | print("done.") 94 | 95 | ##################################### 96 | # Benchmark 97 | # +++++++++ 98 | 99 | img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32) 100 | 101 | t1 = measure_time(lambda: sess.run(None, {input_name: img}), repeat=25, number=25) 102 | t1["name"] = "original" 103 | print("Original model") 104 | pprint(t1) 105 | 106 | t2 = measure_time(lambda: sess_opt.run(None, {input_name: img}), repeat=25, number=25) 107 | t2["name"] = "optimized" 108 | print("Optimized") 109 | pprint(t2) 110 | 111 | 112 | ############################ 113 | # Plots 114 | # +++++ 115 | 116 | 117 | fig, ax = plt.subplots(1, 1, figsize=(12, 4)) 118 | 119 | df = DataFrame([t1, t2]).set_index("name") 120 | df 121 | 122 | ####################################### 123 | # And the graph is: 124 | 125 | ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6) 126 | ax.set_title("Measure performance of optimized model\nlower is better") 127 | plt.grid() 128 | fig.savefig("plot_optimization.png") 129 | -------------------------------------------------------------------------------- /_doc/examples/plot_profiling.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | 3 | .. _l-onnx-array-onnxruntime-profiling: 4 | 5 | Profiling with onnxruntime 6 | ========================== 7 | 8 | *onnxruntime* optimizes the onnx graph by default before running 9 | the inference. It modifies, fuses or add new operators. 10 | Some of them are standard onnx operators, some of them 11 | are implemented in onnxruntime (see `Supported Operators 12 | `_). 13 | This example profiles the two models. 14 | 15 | Optimize a model with onnxruntime 16 | +++++++++++++++++++++++++++++++++ 17 | """ 18 | 19 | import os 20 | import numpy 21 | import matplotlib.pyplot as plt 22 | from onnxruntime import get_available_providers 23 | from onnx_array_api.ext_test_case import example_path 24 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model 25 | from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile 26 | from onnx_array_api.plotting.stat_plot import plot_ort_profile 27 | 28 | 29 | suffix = "" 30 | filename = example_path(f"data/small{suffix}.onnx") 31 | optimized = filename + ".optimized.onnx" 32 | print(f"model={filename!r}") 33 | 34 | if not os.path.exists(optimized): 35 | ort_optimized_model(filename, output=optimized) 36 | print(f"optimized={optimized!r}") 37 | 38 | ############################# 39 | # .. _l-example-ort-profiling: 40 | # 41 | # Profiling 42 | # +++++++++ 43 | 44 | feeds = {"input": numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)} 45 | prof_base = ort_profile( 46 | filename, 47 | feeds, 48 | repeat=6, 49 | disable_optimization=True, 50 | providers=["CPUExecutionProvider"], 51 | ) 52 | prof_base.to_excel(f"prof_base{suffix}.xlsx", index=False) 53 | prof_base 54 | 55 | ####################################### 56 | # And the optimized model. 
57 | 58 | prof_opti = ort_profile( 59 | optimized, 60 | feeds, 61 | repeat=6, 62 | disable_optimization=True, 63 | providers=["CPUExecutionProvider"], 64 | ) 65 | prof_opti.to_excel(f"prof_opti{suffix}.xlsx", index=False) 66 | prof_opti 67 | 68 | ####################################### 69 | # And the graph is: 70 | 71 | unique_op = set(prof_base["args_op_name"]) 72 | fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col") 73 | plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline") 74 | plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized") 75 | fig.tight_layout() 76 | fig.savefig(f"plot_profiling{suffix}.png") 77 | 78 | ################################################## 79 | # Merging profiles 80 | # ++++++++++++++++ 81 | # 82 | # Let's try to compare both profiles assuming every iteration 83 | # process the same image and the input and output size are the 84 | # same at every iteration. 85 | 86 | merge, gr = merge_ort_profile(prof_base, prof_opti) 87 | merge.to_excel(f"plot_profiling_merged{suffix}.xlsx", index=False) 88 | merge 89 | 90 | ##################################################### 91 | # More detailed 92 | 93 | gr.to_excel(f"plot_profiling_merged_details{suffix}.xlsx", index=False) 94 | gr 95 | 96 | ################################ 97 | # Final plot 98 | # ++++++++++ 99 | 100 | # let's filter out unsignificant operator. 
101 | grmax = gr["durbase"] + gr["duropti"] 102 | total = grmax.sum() 103 | grmax /= total 104 | gr = gr[grmax >= 0.01] 105 | 106 | 107 | fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True) 108 | gr[["durbase", "duropti"]].plot.barh(ax=ax[0]) 109 | ax[0].set_title("Side by side duration") 110 | gr = gr.copy() 111 | gr[["countbase", "countopti"]].plot.barh(ax=ax[1]) 112 | ax[1].set_title("Side by side count") 113 | fig.tight_layout() 114 | fig.savefig(f"plot_profiling_side_by_side{suffix}.png") 115 | 116 | 117 | ######################################## 118 | # On CUDA 119 | # +++++++ 120 | 121 | 122 | if "CUDAExecutionProvider" in get_available_providers(): 123 | print("Profiling on CUDA") 124 | prof_base = ort_profile( 125 | filename, 126 | feeds, 127 | repeat=6, 128 | disable_optimization=True, 129 | providers=["CUDAExecutionProvider"], 130 | ) 131 | prof_base.to_excel(f"prof_cuda_base{suffix}.xlsx", index=False) 132 | 133 | prof_opti = ort_profile( 134 | optimized, 135 | feeds, 136 | repeat=6, 137 | disable_optimization=True, 138 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 139 | ) 140 | prof_opti.to_excel(f"prof_cuda_opti{suffix}.xlsx", index=False) 141 | 142 | unique_op = set(prof_base["args_op_name"]) 143 | fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col") 144 | plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline") 145 | plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized") 146 | fig.tight_layout() 147 | fig.savefig(f"plot_profiling_cuda{suffix}.png") 148 | 149 | merge, gr = merge_ort_profile(prof_base, prof_opti) 150 | merge.to_excel(f"plot_profiling_merged{suffix}.xlsx", index=False) 151 | gr.to_excel(f"plot_profiling_merged_details{suffix}.xlsx", index=False) 152 | 153 | grmax = gr["durbase"] + gr["duropti"] 154 | total = grmax.sum() 155 | grmax /= total 156 | gr = gr[grmax >= 0.01] 157 | 158 | fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), 
sharey=True) 159 | gr[["durbase", "duropti"]].plot.barh(ax=ax[0]) 160 | ax[0].set_title("Side by side duration") 161 | gr = gr.copy() 162 | gr[["countbase", "countopti"]].plot.barh(ax=ax[1]) 163 | ax[1].set_title("Side by side count") 164 | fig.tight_layout() 165 | fig.savefig(f"plot_profiling_side_by_side_cuda{suffix}.png") 166 | 167 | else: 168 | print(f"CUDA not available in {get_available_providers()}.") 169 | fig, ax = None, None 170 | 171 | ax 172 | -------------------------------------------------------------------------------- /_doc/index.rst: -------------------------------------------------------------------------------- 1 | 2 | onnx-array-api: APIs to create ONNX Graphs 3 | ========================================== 4 | 5 | .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api 6 | :target: https://dev.azure.com/xavierdupre3/onnx-array-api/ 7 | 8 | .. image:: https://badge.fury.io/py/onnx-array-api.svg 9 | :target: http://badge.fury.io/py/onnx-array-api 10 | 11 | .. image:: http://img.shields.io/github/issues/sdpython/onnx-array-api.png 12 | :alt: GitHub Issues 13 | :target: https://github.com/sdpython/onnx-array-api/issues 14 | 15 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 16 | :alt: MIT License 17 | :target: https://opensource.org/license/MIT/ 18 | 19 | .. image:: https://img.shields.io/github/repo-size/sdpython/onnx-array-api 20 | :target: https://github.com/sdpython/onnx-array-api/ 21 | :alt: size 22 | 23 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg 24 | :target: https://github.com/psf/black 25 | 26 | .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J 27 | :target: https://codecov.io/gh/sdpython/onnx-array-api 28 | 29 | **onnx-array-api** implements APIs to create custom ONNX graphs. 30 | The objective is to speed up the implementation of converter libraries. 31 | 32 | .. 
toctree:: 33 | :maxdepth: 1 34 | :caption: Contents 35 | 36 | tutorial/index 37 | api/index 38 | tech/index 39 | command_lines 40 | auto_examples/index 41 | 42 | .. toctree:: 43 | :maxdepth: 1 44 | :caption: More 45 | 46 | CHANGELOGS 47 | license 48 | long_outputs 49 | 50 | Sources available on 51 | `github/onnx-array-api `_. 52 | 53 | GraphBuilder API 54 | ++++++++++++++++ 55 | 56 | Almost every converting library (converting a machine learned model to ONNX) is implementing 57 | its own graph builder and customizes it for its needs. 58 | It handles some frequent tasks such as giving names to intermediate 59 | results, loading, saving onnx models. It can be used as well to extend an existing graph. 60 | See :ref:`l-graph-api`. 61 | 62 | .. runpython:: 63 | :showcode: 64 | 65 | import numpy as np 66 | from onnx_array_api.graph_api import GraphBuilder 67 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 68 | 69 | g = GraphBuilder() 70 | g.make_tensor_input("X", np.float32, (None, None)) 71 | g.make_tensor_input("Y", np.float32, (None, None)) 72 | r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class, 73 | # it ensures the name is unique 74 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically 75 | # converts the array to a tensor 76 | r2 = g.make_node("Pow", [r1, init]) 77 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because 78 | # the user wants to choose the name 79 | g.make_tensor_output("Z", np.float32, (None, None)) 80 | 81 | onx = g.to_onnx() # final conversion to onnx 82 | 83 | print(onnx_simple_text_plot(onx)) 84 | 85 | Light API 86 | +++++++++ 87 | 88 | The syntax is inspired from the 89 | `Reverse Polish Notation `_. 90 | This kind of API is easy to use to build new graphs, 91 | less easy to extend an existing graph. See :ref:`l-light-api`. 92 | 93 | .. 
runpython:: 94 | :showcode: 95 | 96 | import numpy as np 97 | from onnx_array_api.light_api import start 98 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 99 | 100 | model = ( 101 | start() 102 | .vin("X") 103 | .vin("Y") 104 | .bring("X", "Y") 105 | .Sub() 106 | .rename("dxy") 107 | .cst(np.array([2], dtype=np.int64), "two") 108 | .bring("dxy", "two") 109 | .Pow() 110 | .ReduceSum() 111 | .rename("Z") 112 | .vout() 113 | .to_onnx() 114 | ) 115 | 116 | print(onnx_simple_text_plot(model)) 117 | 118 | Numpy API 119 | +++++++++ 120 | 121 | Writing ONNX graphs requires knowing the ONNX syntax unless 122 | it is possible to reuse an existing syntax such as :epkg:`numpy`. 123 | This is what this API does. 124 | This kind of API is easy to use to build new graphs, 125 | but almost impossible to use to extend existing graphs as it usually requires 126 | knowing onnx for that. See :ref:`l-numpy-api-onnx`. 127 | 128 | .. runpython:: 129 | :showcode: 130 | :warningout: DeprecationWarning, FutureWarning 131 | :process: 132 | 133 | import numpy as np # A 134 | from onnx_array_api.npx import absolute, jit_onnx 135 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 136 | 137 | def l1_loss(x, y): 138 | return absolute(x - y).sum() 139 | 140 | 141 | def l2_loss(x, y): 142 | return ((x - y) ** 2).sum() 143 | 144 | 145 | def myloss(x, y): 146 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 147 | 148 | 149 | jitted_myloss = jit_onnx(myloss) 150 | 151 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 152 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 153 | res = jitted_myloss(x, y) 154 | print(res) 155 | 156 | print(onnx_simple_text_plot(jitted_myloss.get_onnx())) 157 | 158 | ..
gdot:: 159 | :script: DOT-SECTION 160 | :process: 161 | 162 | # index 163 | import numpy as np 164 | from onnx_array_api.npx import absolute, jit_onnx 165 | from onnx_array_api.plotting.dot_plot import to_dot 166 | 167 | 168 | def l1_loss(x, y): 169 | return absolute(x - y).sum() 170 | 171 | 172 | def l2_loss(x, y): 173 | return ((x - y) ** 2).sum() 174 | 175 | 176 | def myloss(x, y): 177 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 178 | 179 | 180 | jitted_myloss = jit_onnx(myloss) 181 | 182 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 183 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 184 | res = jitted_myloss(x, y) 185 | print(to_dot(jitted_myloss.get_onnx())) 186 | 187 | Older versions 188 | ++++++++++++++ 189 | 190 | * `0.3.0 <../v0.3.0/index.html>`_ 191 | * `0.2.0 <../v0.2.0/index.html>`_ 192 | * `0.1.3 <../v0.1.3/index.html>`_ 193 | * `0.1.2 <../v0.1.2/index.html>`_ 194 | -------------------------------------------------------------------------------- /_doc/license.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | 4 | .. literalinclude:: LICENSE.txt 5 | :language: none 6 | -------------------------------------------------------------------------------- /_doc/long_outputs.rst: -------------------------------------------------------------------------------- 1 | :hide-toc: 2 | 3 | ========================== 4 | Long outputs uneasy to see 5 | ========================== 6 | 7 | onnx 8 | ==== 9 | 10 | .. _l-long-output-compare_onnx_execution: 11 | 12 | onnx_array_api.reference.compare_onnx_execution 13 | +++++++++++++++++++++++++++++++++++++++++++++++ 14 | 15 | From example :ref:`l-onnx-diff-example` for function 16 | :func:`onnx_array_api.reference.compare_onnx_execution`. 17 | See also `raw rendering `_. 
18 | 19 | :: 20 | 21 | 1 = | INITIA float64 1 HAAA Ad_Addcst | INITIA float64 1 HAAA Ad_Addcst 22 | 2 = | INITIA float64 4x4 ADZF Ge_Gemmcst | INITIA float64 4x4 ADZF Ge_Gemmcst 23 | 3 = | INITIA float64 4 USEA Ge_Gemmcst1 | INITIA float64 4 USEA Ge_Gemmcst1 24 | 4 = | INITIA float64 1 AAAA Mu_Mulcst | INITIA float64 1 AAAA Mu_Mulcst 25 | 5 = | INITIA float64 1 DAAA Ad_Addcst1 | INITIA float64 1 DAAA Ad_Addcst1 26 | 6 = | INITIA float64 1 AAAA Ad_Addcst2 | INITIA float64 1 AAAA Ad_Addcst2 27 | 7 = | INPUT float64 1x4 AAAA X | INPUT float64 1x4 AAAA X 28 | 8 = | RESULT float64 1x4 UTFC Gemm Ge_Y0 | RESULT float64 1x4 UTFC Gemm Ge_Y0 29 | 9 + | | RESULT float64 1x4 TIEG Mul Mu_C01 30 | 10 ~ | RESULT float64 1x1 NAAA ReduceSumS Re_reduced0 | RESULT float64 1x1 NAAA ReduceSum Re_reduced0 31 | 11 = | RESULT float64 1x1 NAAA Concat Co_concat_re | RESULT float64 1x1 NAAA Concat Co_concat_re 32 | 12 = | RESULT float64 1x1 UAAA Add Ad_C02 | RESULT float64 1x1 UAAA Add Ad_C02 33 | 13 = | RESULT float64 1x1 DAAA Mul Mu_C0 | RESULT float64 1x1 DAAA Mul Mu_C0 34 | 14 = | RESULT float64 1x1 GAAA Add Ad_C01 | RESULT float64 1x1 GAAA Add Ad_C01 35 | 15 = | RESULT float64 1x1 GAAA Add Ad_C0 | RESULT float64 1x1 GAAA Add Ad_C0 36 | 16 = | RESULT int64 1x1 AAAA ArgMax label | RESULT int64 1x1 AAAA ArgMax label 37 | 17 + | | RESULT float64 1x1 GAAA ReduceMax Re_reduced03 38 | 18 + | | RESULT float64 1x1 AAAA Sub Su_C01 39 | 19 + | | RESULT float64 1x1 BAAA Exp Ex_output0 40 | 20 + | | RESULT float64 1x1 BAAA ReduceSum Re_reduced02 41 | 21 + | | RESULT float64 1x1 AAAA Log Lo_output0 42 | 22 ~ | RESULT float64 1x1 GAAA ReduceLogS score_sample | RESULT float64 1x1 GAAA Add score_sample 43 | 23 = | RESULT float64 1x1 AAAA Sub Su_C0 | RESULT float64 1x1 AAAA Sub Su_C0 44 | 24 = | RESULT float64 1x1 BAAA Exp probabilitie | RESULT float64 1x1 BAAA Exp probabilitie 45 | 25 = | OUTPUT int64 1x1 AAAA label | OUTPUT int64 1x1 AAAA label 46 | 26 = | OUTPUT float64 1x1 BAAA probabilitie | OUTPUT 
float64 1x1 BAAA probabilitie 47 | 27 = | OUTPUT float64 1x1 GAAA score_sample | OUTPUT float64 1x1 GAAA score_sample 48 | -------------------------------------------------------------------------------- /_doc/run_coverage.sh: -------------------------------------------------------------------------------- 1 | python3 -m pytest --cov --cov-report html:_doc/_static/cov_html _unittests 2 | -------------------------------------------------------------------------------- /_doc/tech/aapi.rst: -------------------------------------------------------------------------------- 1 | .. _l-array-api-painpoint: 2 | 3 | Difficulty to implement an Array API for ONNX 4 | ============================================= 5 | 6 | Implementing the full array API is not always easy with :epkg:`onnx`. 7 | Python is not statically typed and many different types can be used 8 | to represent a value. Argument *axis* can be an integer or a tuple 9 | (see `min from Array API 10 | `_ 12 | for example). On the other hand, `ReduceMin from ONNX 13 | `_ 14 | expects the axes as a tensor. 15 | 16 | Performance 17 | +++++++++++ 18 | 19 | The Array API must work in eager mode and for every operation, 20 | it generates an ONNX graph and executes it with a specific 21 | backend. It can be :epkg:`numpy`, :epkg:`onnxruntime` or any other 22 | backend. The generation of every graph takes a significant amount of time 23 | and must be avoided, so these graphs are cached. But a graph can be reused 24 | only if nothing but the inputs - in the ONNX sense - changes. If a parameter changes, 25 | a new graph must be created and cached. Method :meth:`JitEager.make_key 26 | ` 27 | generates a unique key based on the inputs it receives and 28 | the signature of the function to call. If the key is the same, 29 | a cached onnx graph can be reused on the second call. 30 | 31 | However, eager mode - using a small single onnx graph for every operation - 32 | is not the most efficient one.
At the same time, the design must allow 33 | to merge every needed operation into a bigger graph. 34 | Bigger graphs can be more easily optimized by the backend. 35 | 36 | Input vs parameter 37 | ++++++++++++++++++ 38 | 39 | An input is a tensor or an array; a parameter is any other type. 40 | Following onnx semantics, an input is variable while a parameter is frozen 41 | and cannot be changed: it is a constant. A good design would be 42 | to consider any named input (`**kwargs`) a parameter and 43 | any positional input (`*args`) a tensor. But the Array API does not follow that 44 | design. Function `astype 45 | _` 47 | takes two inputs. Operator `Cast 48 | _` 49 | takes one input and a frozen parameter `to`. 50 | And python allows `astype(x, dtype)` as well as `astype(x, dtype=dtype)` 51 | unless the signature enforces one calling convention over the other. 52 | There may be ambiguities from time to time. 53 | Besides, from the onnx point of view, argument dtype should be named. 54 | 55 | Tensor type 56 | +++++++++++ 57 | 58 | An :class:`EagerTensor ` 59 | must be used to represent any tensor. 60 | This class defines the backend to use as well. 61 | :class:`EagerNumpyTensor 62 | ` 63 | for :epkg:`numpy`, :class:`EagerOrtTensor 64 | ` 65 | for :epkg:`onnxruntime`. Since the Array API is new, 66 | existing packages do not fully support the API yet, if they support it at all 67 | (:epkg:`scikit-learn`). Some numpy arrays may still be used. 68 | 69 | Inplace 70 | +++++++ 71 | 72 | ONNX has no notion of inplace computation. Therefore something 73 | like `coefs[:, 1] = 1` is not valid unless some code is written 74 | to create another tensor. The current design supports some of these 75 | by storing every call to `__setitem__`. The user sees `coefs` 76 | but the framework sees that `coefs` holds a reference to another 77 | tensor. That's the one the framework needs to use. However, although 78 | `__setitem__` is usually used for efficiency, it becomes inefficient 79 | with this design and should be avoided.
This assumption may be true 80 | when the backend is relying on CPU but not on GPU. 81 | A function such as `empty 82 | `_ should be avoided as it 84 | has to be followed by calls to `__setitem__`. 85 | 86 | Eager or compilation 87 | ++++++++++++++++++++ 88 | 89 | Eager mode is what the Array API implies. 90 | Every function is converted into an ONNX graph based 91 | on its inputs without any knowledge of how these inputs 92 | were obtained. This graph is then executed before going 93 | to the next call of a function from the API. 94 | The conversion of a machine learned model 95 | into ONNX implies the gathering of all these operations 96 | into a graph. It means using a mode that records all the function 97 | calls to compile every tiny onnx graph into a unique graph. 98 | 99 | Iterators and Reduction 100 | +++++++++++++++++++++++ 101 | 102 | An efficient implementation of function 103 | :func:`numpy.any` or :func:`numpy.all` returns 104 | as soon as the result is known. :func:`numpy.all` is 105 | false whenever the first false condition is met. 106 | The same goes for :func:`numpy.any` which is true 107 | whenever the first true condition is met. 108 | There is no such operator in ONNX (opset <= 20) because 109 | it is unlikely to appear in a machine learned model. 110 | However, it is heavily used when two results are 111 | compared in unit tests. The ONNX implementation is 112 | not efficient for that reason but it only impacts 113 | the unit tests. 114 | 115 | Types 116 | +++++ 117 | 118 | :epkg:`onnx` supports more types than :epkg:`numpy` does. 119 | It is not always easy to deal with bfloat16 or float8 types. 120 | -------------------------------------------------------------------------------- /_doc/tech/index.rst: -------------------------------------------------------------------------------- 1 | Technical Details 2 | ================= 3 | 4 | ..
toctree:: 5 | :maxdepth: 2 6 | 7 | aapi -------------------------------------------------------------------------------- /_doc/tutorial/benchmarks.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | Benchmarks 3 | ========== 4 | 5 | A list of benchmarks used to improve the performance of 6 | ONNX components (onnx, onnxruntime, onnx-array-api, ...). 7 | 8 | .. toctree:: 9 | 10 | ../auto_examples/plot_benchmark_rf 11 | -------------------------------------------------------------------------------- /_doc/tutorial/graph_api.rst: -------------------------------------------------------------------------------- 1 | .. _l-graph-api: 2 | 3 | ================================= 4 | GraphBuilder: common API for ONNX 5 | ================================= 6 | 7 | This is a very common way to build ONNX graphs. There are some 8 | annoying steps while building an ONNX graph. The first one is to 9 | give unique names to every intermediate result in the graph. The second 10 | is the conversion from numpy arrays to onnx tensors. A *graph builder*, 11 | here implemented by class 12 | :class:`GraphBuilder ` 13 | usually makes these two frequent tasks easier. 14 | 15 | ..
runpython:: 16 | :showcode: 17 | 18 | import numpy as np 19 | from onnx_array_api.graph_api import GraphBuilder 20 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 21 | 22 | g = GraphBuilder() 23 | g.make_tensor_input("X", np.float32, (None, None)) 24 | g.make_tensor_input("Y", np.float32, (None, None)) 25 | r1 = g.make_node("Sub", ["X", "Y"]) # the name of the output is chosen by the class, 26 | # it ensures the name is unique 27 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically 28 | # converts the array to a tensor 29 | r2 = g.make_node("Pow", [r1, init]) 30 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because 31 | # the user wants to choose the name 32 | g.make_tensor_output("Z", np.float32, (None, None)) 33 | 34 | onx = g.to_onnx() # final conversion to onnx 35 | 36 | print(onnx_simple_text_plot(onx)) 37 | 38 | A simpler version of the same code produces the same graph. 39 | 40 | .. runpython:: 41 | :showcode: 42 | 43 | import numpy as np 44 | from onnx_array_api.graph_api import GraphBuilder 45 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 46 | 47 | g = GraphBuilder() 48 | g.make_tensor_input("X", np.float32, (None, None)) 49 | g.make_tensor_input("Y", np.float32, (None, None)) 50 | r1 = g.op.Sub("X", "Y") # the method name indicates which operator to use, 51 | # this can be used when there is no ambiguity about the 52 | # number of outputs 53 | r2 = g.op.Pow(r1, np.array([2], dtype=np.int64)) 54 | g.op.ReduceSum(r2, outputs=["Z"]) # the method still wants the user to specify the name 55 | g.make_tensor_output("Z", np.float32, (None, None)) 56 | 57 | onx = g.to_onnx() 58 | 59 | print(onnx_simple_text_plot(onx)) 60 | -------------------------------------------------------------------------------- /_doc/tutorial/index.rst: -------------------------------------------------------------------------------- 1 | 2 | ======== 3 | Tutorial 4 | ======== 5 |
6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | onnx_api 10 | graph_api 11 | light_api 12 | numpy_api 13 | tools 14 | benchmarks 15 | -------------------------------------------------------------------------------- /_doc/tutorial/light_api.rst: -------------------------------------------------------------------------------- 1 | .. _l-light-api: 2 | 3 | ========================================== 4 | Light API for ONNX: everything in one line 5 | ========================================== 6 | 7 | It is inspired by the :epkg:`reverse Polish notation`. 8 | The following example implements the squared euclidean distance. 9 | This API tries to keep it simple and intuitive for short functions. 10 | 11 | .. runpython:: 12 | :showcode: 13 | 14 | import numpy as np 15 | from onnx_array_api.light_api import start 16 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 17 | 18 | model = ( 19 | start() 20 | .vin("X") 21 | .vin("Y") 22 | .bring("X", "Y") 23 | .Sub() 24 | .rename("dxy") 25 | .cst(np.array([2], dtype=np.int64), "two") 26 | .bring("dxy", "two") 27 | .Pow() 28 | .ReduceSum() 29 | .rename("Z") 30 | .vout() 31 | .to_onnx() 32 | ) 33 | 34 | print(onnx_simple_text_plot(model)) 35 | 36 | There are two kinds of methods: the graph methods, playing with the graph structure, 37 | and the methods for operators, starting with an upper case letter. 38 | 39 | Graph methods 40 | ============= 41 | 42 | Any graph must start with function :func:`start `. 43 | It is usually followed by `vin` to add an input. 44 | 45 | * bring (:meth:`Var.bring `, 46 | :meth:`Vars.bring `): 47 | assembles multiple results into a set before calling an operator taking multiple inputs, 48 | * cst (:meth:`Var.cst `, 49 | :meth:`Vars.cst `): 50 | adds a constant tensor to the graph, 51 | * rename (:meth:`Var.rename `, 52 | :meth:`Vars.rename `): 53 | renames or gives a name to a variable in order to call it later.
54 | * vin (:meth:`Var.vin `, 55 | :meth:`Vars.vin `): 56 | adds an input to the graph, 57 | * vout (:meth:`Var.vout `, 58 | :meth:`Vars.vout `): 59 | declares an existing result as an output. 60 | 61 | These methods are implemented in class :class:`onnx_array_api.light_api.var.BaseVar`. 62 | 63 | Operator methods 64 | ================ 65 | 66 | They are described in :epkg:`ONNX Operators` and redefined in a stable API 67 | so that the definition should not change depending on the opset. 68 | :class:`onnx_array_api.light_api.Var` defines all operators taking only one input. 69 | :class:`onnx_array_api.light_api.Vars` defines all other operators. 70 | 71 | Numpy methods 72 | ============= 73 | 74 | Numpy users expect methods such as `reshape`, property `shape` or 75 | operator `+` to be available as well, and that is the case. They are 76 | defined in class :class:`Var ` or 77 | :class:`Vars ` depending on the number of 78 | inputs they require. Their names start with a lower case letter. 79 | 80 | Other domains 81 | ============= 82 | 83 | The following example uses operator *Normalizer* from domain 84 | *ai.onnx.ml*. The operator name is called with the syntax 85 | `.`. The domain may have dots in its name 86 | but it must follow the python definition of a variable. 87 | The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`. 88 | 89 | ..
runpython:: 90 | :showcode: 91 | 92 | import numpy as np 93 | from onnx_array_api.light_api import start 94 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot 95 | 96 | model = ( 97 | start(opset=19, opsets={"ai.onnx.ml": 3}) 98 | .vin("X") 99 | .reshape((-1, 1)) 100 | .rename("USE") 101 | .ai.onnx.ml.Normalizer(norm="MAX") 102 | .rename("Y") 103 | .vout() 104 | .to_onnx() 105 | ) 106 | 107 | print(onnx_simple_text_plot(model)) 108 | -------------------------------------------------------------------------------- /_doc/tutorial/numpy_api.rst: -------------------------------------------------------------------------------- 1 | .. _l-numpy-api-onnx: 2 | 3 | ================== 4 | Numpy API for ONNX 5 | ================== 6 | 7 | Many users have difficulties writing onnx graphs. 8 | Many packages try to simplify it by implementing 9 | their own API very close to onnx operators 10 | (`sklearn-onnx `_, 11 | `tf2onnx `_, 12 | `spox `_, 13 | `onnx-script `_). 14 | This contribution tries a different approach by implementing 15 | a numpy API for ONNX. It does not cover everything numpy 16 | or ONNX can do but it can easily be used to define 17 | loss functions for example without knowing too much about ONNX. 18 | 19 | .. note:: control flow 20 | 21 | The first version (onnx==1.15) does not support control flow yet (tests and loops). 22 | There is no easy syntax for that yet and the main challenge is to deal with local context. 23 | 24 | You can read :ref:`l-array-api-painpoint` as well. 25 | 26 | Overview 27 | ======== 28 | 29 | .. toctree:: 30 | 31 | ../auto_examples/plot_first_example 32 | ../auto_examples/plot_onnxruntime 33 | -------------------------------------------------------------------------------- /_doc/tutorial/tools.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Tools 3 | ===== 4 | 5 | Some useful tools.
6 | 7 | Text representation 8 | =================== 9 | 10 | Plotting a graph is great, but the plot is difficult to read when 11 | the graph is big, and it is slow to render. 12 | :func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot` 13 | prints out a text representation. 14 | 15 | Differences between two models 16 | ============================== 17 | 18 | How to understand the differences between two models 19 | assuming they are producing the same outputs? 20 | Example :ref:`l-onnx-diff-example` shows how to do it. 21 | -------------------------------------------------------------------------------- /_unittests/onnx-numpy-skips.txt: -------------------------------------------------------------------------------- 1 | # API failures 2 | # see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt 3 | # uses __setitem__ 4 | array_api_tests/test_creation_functions.py::test_arange 5 | array_api_tests/test_creation_functions.py::test_asarray_arrays 6 | array_api_tests/test_creation_functions.py::test_empty 7 | array_api_tests/test_creation_functions.py::test_empty_like 8 | array_api_tests/test_creation_functions.py::test_eye 9 | array_api_tests/test_creation_functions.py::test_full 10 | array_api_tests/test_creation_functions.py::test_full_like 11 | array_api_tests/test_creation_functions.py::test_ones 12 | array_api_tests/test_creation_functions.py::test_ones_like 13 | array_api_tests/test_creation_functions.py::test_zeros 14 | array_api_tests/test_creation_functions.py::test_zeros_like 15 | # fails due to precision issues 16 | array_api_tests/test_creation_functions.py::test_linspace 17 | array_api_tests/test_creation_functions.py::test_meshgrid 18 | -------------------------------------------------------------------------------- /_unittests/onnx-ort-skips.txt: -------------------------------------------------------------------------------- 1 | # Not implemented by onnxruntime 2 | array_api_tests/test_creation_functions.py::test_arange 3 |
array_api_tests/test_creation_functions.py::test_asarray_scalars 4 | array_api_tests/test_creation_functions.py::test_asarray_arrays 5 | array_api_tests/test_creation_functions.py::test_empty 6 | array_api_tests/test_creation_functions.py::test_empty_like 7 | array_api_tests/test_creation_functions.py::test_eye 8 | array_api_tests/test_creation_functions.py::test_full 9 | array_api_tests/test_creation_functions.py::test_full_like 10 | array_api_tests/test_creation_functions.py::test_linspace 11 | array_api_tests/test_creation_functions.py::test_meshgrid 12 | array_api_tests/test_creation_functions.py::test_ones 13 | array_api_tests/test_creation_functions.py::test_ones_like 14 | array_api_tests/test_creation_functions.py::test_zeros 15 | array_api_tests/test_creation_functions.py::test_zeros_like 16 | -------------------------------------------------------------------------------- /_unittests/test_array_api.sh: -------------------------------------------------------------------------------- 1 | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy 2 | pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like || exit 1 3 | # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help 4 | pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 5 | -------------------------------------------------------------------------------- /_unittests/ut_array_api/test_array_apis.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from inspect import isfunction, ismethod 3 | import numpy as np 4 | from onnx_array_api.ext_test_case import ExtTestCase 5 | from onnx_array_api.array_api import onnx_numpy as xpn 6 | from onnx_array_api.array_api import onnx_ort as xpo 7 | 8 | # from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor 9 | # from 
onnx_array_api.ort.ort_tensors import EagerOrtTensor 10 | 11 | 12 | class TestArraysApis(ExtTestCase): 13 | def test_zeros_numpy_1(self): 14 | c = xpn.zeros(1) 15 | d = c.numpy() 16 | self.assertEqualArray(np.array([0], dtype=np.float64), d) 17 | 18 | def test_zeros_ort_1(self): 19 | c = xpo.zeros(1) 20 | d = c.numpy() 21 | self.assertEqualArray(np.array([0], dtype=np.float64), d) 22 | 23 | def test_ffinfo(self): 24 | dt = np.float32 25 | fi1 = np.finfo(dt) 26 | fi2 = xpn.finfo(dt) 27 | fi3 = xpo.finfo(dt) 28 | dt1 = fi1.dtype 29 | dt2 = fi2.dtype 30 | dt3 = fi3.dtype 31 | self.assertEqual(dt2, dt3) 32 | self.assertNotEqual(dt1.__class__, dt2.__class__) 33 | mi1 = fi1.min 34 | mi2 = fi2.min 35 | self.assertEqual(mi1, mi2) 36 | mi1 = fi1.smallest_normal 37 | mi2 = fi2.smallest_normal 38 | self.assertEqual(mi1, mi2) 39 | for n in dir(fi1): 40 | if n.startswith("__"): 41 | continue 42 | if n in {"machar"}: 43 | continue 44 | v1 = getattr(fi1, n) 45 | with self.subTest(att=n): 46 | v2 = getattr(fi2, n) 47 | v3 = getattr(fi3, n) 48 | if isfunction(v1) or ismethod(v1): 49 | try: 50 | v1 = v1() 51 | except TypeError: 52 | continue 53 | v2 = v2() 54 | v3 = v3() 55 | if v1 != v2: 56 | raise AssertionError( 57 | f"12: info disagree on name {n!r}: {v1} != {v2}, " 58 | f"type(v1)={type(v1)}, type(v2)={type(v2)}, " 59 | f"ismethod={ismethod(v1)}." 60 | ) 61 | if v2 != v3: 62 | raise AssertionError( 63 | f"23: info disagree on name {n!r}: {v2} != {v3}, " 64 | f"type(v1)={type(v1)}, type(v2)={type(v2)}, " 65 | f"ismethod={ismethod(v1)}." 
66 | ) 67 | 68 | def test_iiinfo(self): 69 | dt = np.int64 70 | fi1 = np.iinfo(dt) 71 | fi2 = xpn.iinfo(dt) 72 | fi3 = xpo.iinfo(dt) 73 | dt1 = fi1.dtype 74 | dt2 = fi2.dtype 75 | dt3 = fi3.dtype 76 | self.assertEqual(dt2, dt3) 77 | self.assertNotEqual(dt1.__class__, dt2.__class__) 78 | mi1 = fi1.min 79 | mi2 = fi2.min 80 | self.assertEqual(mi1, mi2) 81 | for n in dir(fi1): 82 | if n.startswith("__"): 83 | continue 84 | if n in {"machar"}: 85 | continue 86 | v1 = getattr(fi1, n) 87 | with self.subTest(att=n): 88 | v2 = getattr(fi2, n) 89 | v3 = getattr(fi3, n) 90 | if isfunction(v1) or ismethod(v1): 91 | try: 92 | v1 = v1() 93 | except TypeError: 94 | continue 95 | v2 = v2() 96 | v3 = v3() 97 | if v1 != v2: 98 | raise AssertionError( 99 | f"12: info disagree on name {n!r}: {v1} != {v2}, " 100 | f"type(v1)={type(v1)}, type(v2)={type(v2)}, " 101 | f"ismethod={ismethod(v1)}." 102 | ) 103 | if v2 != v3: 104 | raise AssertionError( 105 | f"23: info disagree on name {n!r}: {v2} != {v3}, " 106 | f"type(v1)={type(v1)}, type(v2)={type(v2)}, " 107 | f"ismethod={ismethod(v1)}." 
108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main(verbosity=2) 113 | -------------------------------------------------------------------------------- /_unittests/ut_array_api/test_onnx_ort.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from onnx_array_api.ext_test_case import ExtTestCase 4 | from onnx_array_api.array_api import onnx_ort as xp 5 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor 6 | from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor 7 | 8 | 9 | class TestOnnxOrt(ExtTestCase): 10 | def test_abs(self): 11 | c = EagerTensor(np.array([4, 5], dtype=np.int64)) 12 | mat = xp.zeros(c, dtype=xp.int64) 13 | matnp = mat.numpy() 14 | self.assertEqual(matnp.shape, (4, 5)) 15 | self.assertNotEmpty(matnp[0, 0]) 16 | a = xp.absolute(mat) 17 | self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) 18 | 19 | def test_matmul(self): 20 | for cls in [EagerTensor, EagerNumpyTensor]: 21 | for dtype in (np.float32, np.float64): 22 | X = cls( 23 | np.array( 24 | [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], 25 | dtype=dtype, 26 | ) 27 | ) 28 | coef = cls(np.array([[1e-13, 8]], dtype=dtype).T) 29 | self.assertEqualArray( 30 | np.array([[1e-13, 8]], dtype=dtype), coef.numpy().T 31 | ) 32 | expected = X.numpy() @ coef.numpy() 33 | got = X @ coef 34 | try: 35 | self.assertEqualArray(expected, got.numpy()) 36 | except AssertionError as e: 37 | raise AssertionError( 38 | f"Discrepancies (1) with cls={cls.__name__}, dtype={dtype}" 39 | ) from e 40 | 41 | coef = np.array([[1e-13, 8]], dtype=dtype).T 42 | expected = X.numpy() @ coef 43 | got = X @ coef 44 | try: 45 | self.assertEqualArray(expected, got.numpy()) 46 | except AssertionError as e: 47 | raise AssertionError( 48 | f"Discrepancies (2) with cls={cls.__name__}, dtype={dtype}" 49 | ) from e 50 | 51 | 52 | if __name__ == "__main__": 53 | # import logging 54 | 55 | # 
logging.basicConfig(level=logging.DEBUG) 56 | # TestOnnxOrt().test_matmul() 57 | unittest.main(verbosity=2) 58 | -------------------------------------------------------------------------------- /_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx -------------------------------------------------------------------------------- /_unittests/ut_graph_api/test_graph_builder_optim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import onnx 4 | from onnx.inliner import inline_local_functions 5 | from onnx_array_api.ext_test_case import ExtTestCase 6 | from onnx_array_api.graph_api.graph_builder import GraphBuilder 7 | 8 | 9 | class TestGraphBuilderOptim(ExtTestCase): 10 | def test_wcheck_afiles(self): 11 | import onnxruntime 12 | 13 | data = os.path.join(os.path.dirname(__file__), "data") 14 | filename = [f for f in os.listdir(data) if f.endswith(".onnx")] 15 | for f in filename: 16 | with self.subTest(f=f): 17 | onx = onnx.load(os.path.join(data, f)) 18 | sess = onnxruntime.InferenceSession( 19 | os.path.join(data, f), providers=["CPUExecutionProvider"] 20 | ) 21 | assert sess 22 | onxi = inline_local_functions(onx) 23 | sess = onnxruntime.InferenceSession( 24 | onxi.SerializeToString(), providers=["CPUExecutionProvider"] 25 | ) 26 | assert sess 27 | g = GraphBuilder(onxi) 28 | g.optimize(check_order=True) 29 | g.check_order() 30 | onx2 = g.to_onnx() 31 | sess2 = onnxruntime.InferenceSession( 32 | onx2.SerializeToString(), providers=["CPUExecutionProvider"] 33 | ) 34 | assert sess2 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main(verbosity=2) 39 | -------------------------------------------------------------------------------- 
/_unittests/ut_npx/test_sklearn_array_api.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from packaging.version import Version 4 | from onnx.defs import onnx_opset_version 5 | from sklearn import config_context, __version__ as sklearn_version 6 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 7 | from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings 8 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor 9 | 10 | 11 | DEFAULT_OPSET = onnx_opset_version() 12 | 13 | 14 | class TestSklearnArrayAPI(ExtTestCase): 15 | @unittest.skipIf( 16 | Version(sklearn_version) <= Version("1.2.2"), 17 | reason="reshape ArrayAPI not followed", 18 | ) 19 | @ignore_warnings(DeprecationWarning) 20 | @unittest.skip("not maintained") 21 | def test_sklearn_array_api_linear_discriminant(self): 22 | X = np.array( 23 | [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float64 24 | ) 25 | y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64) 26 | ana = LinearDiscriminantAnalysis() 27 | ana.fit(X, y) 28 | expected = ana.predict(X) 29 | 30 | new_x = EagerNumpyTensor(X) 31 | self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x)) 32 | with config_context(array_api_dispatch=True): 33 | # It fails if scikit-learn <= 1.2.2 because the ArrayAPI 34 | # is not strictly applied. 
35 | got = ana.predict(new_x) 36 | self.assertEqualArray(expected, got.numpy()) 37 | 38 | @unittest.skipIf( 39 | Version(sklearn_version) <= Version("1.2.2"), 40 | reason="reshape ArrayAPI not followed", 41 | ) 42 | @ignore_warnings(DeprecationWarning) 43 | @unittest.skip("not maintained") 44 | def test_sklearn_array_api_linear_discriminant_float32(self): 45 | X = np.array( 46 | [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32 47 | ) 48 | y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64) 49 | ana = LinearDiscriminantAnalysis() 50 | ana.fit(X, y) 51 | expected = ana.predict(X) 52 | 53 | new_x = EagerNumpyTensor(X) 54 | self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x)) 55 | with config_context(array_api_dispatch=True): 56 | # It fails if scikit-learn <= 1.2.2 because the ArrayAPI 57 | # is not strictly applied. 58 | got = ana.predict(new_x) 59 | self.assertEqualArray(expected, got.numpy()) 60 | 61 | 62 | if __name__ == "__main__": 63 | # import logging 64 | 65 | # logging.basicConfig(level=logging.DEBUG) 66 | unittest.main(verbosity=2) 67 | -------------------------------------------------------------------------------- /_unittests/ut_ort/data/prof_base.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_ort/data/prof_base.xlsx -------------------------------------------------------------------------------- /_unittests/ut_ort/data/prof_opti.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_ort/data/prof_opti.xlsx -------------------------------------------------------------------------------- /_unittests/ut_ort/test_ort_optimizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 
| import numpy as np 4 | from onnx.defs import onnx_opset_version 5 | from onnx_array_api.npx import absolute, jit_onnx 6 | from onnx_array_api.ext_test_case import ExtTestCase 7 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model 8 | 9 | 10 | DEFAULT_OPSET = onnx_opset_version() 11 | 12 | 13 | class TestOrtOptimizer(ExtTestCase): 14 | def test_ort_optimizers(self): 15 | def l1_loss(x, y): 16 | return absolute(x - y).sum() 17 | 18 | def l2_loss(x, y): 19 | return ((x - y) ** 2).sum() 20 | 21 | def myloss(x, y): 22 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 23 | 24 | jitted_myloss = jit_onnx(myloss) 25 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 26 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 27 | jitted_myloss(x, y) 28 | onx = jitted_myloss.get_onnx() 29 | self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError) 30 | optimized = ort_optimized_model(onx) 31 | self.assertIn('op_type: "Squeeze"', str(optimized)) 32 | self.assertIn("initializer {", str(optimized)) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main(verbosity=2) 37 | -------------------------------------------------------------------------------- /_unittests/ut_ort/test_ort_profile.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import numpy as np 4 | from pandas import DataFrame, read_excel 5 | from onnx_array_api.npx import absolute, jit_onnx 6 | from onnx_array_api.ext_test_case import ExtTestCase 7 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model 8 | from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile 9 | from onnxruntime.capi._pybind_state import ( 10 | OrtValue as C_OrtValue, 11 | OrtDevice as C_OrtDevice, 12 | ) 13 | 14 | 15 | class TestOrtProfile(ExtTestCase): 16 | def test_ort_profile(self): 17 | def l1_loss(x, y): 18 | return absolute(x - y).sum() 19 | 20 | def l2_loss(x, y): 21 | return ((x 
- y) ** 2).sum() 22 | 23 | def myloss(x, y): 24 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 25 | 26 | jitted_myloss = jit_onnx(myloss) 27 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 28 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 29 | jitted_myloss(x, y) 30 | onx = jitted_myloss.get_onnx() 31 | feeds = {"x0": x, "x1": y} 32 | self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError) 33 | optimized = ort_optimized_model(onx) 34 | prof = ort_profile(optimized, feeds) 35 | self.assertIsInstance(prof, DataFrame) 36 | prof = ort_profile(optimized, feeds, as_df=False) 37 | self.assertIsInstance(prof, list) 38 | 39 | def test_ort_profile_first_it_out(self): 40 | def l1_loss(x, y): 41 | return absolute(x - y).sum() 42 | 43 | def l2_loss(x, y): 44 | return ((x - y) ** 2).sum() 45 | 46 | def myloss(x, y): 47 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 48 | 49 | jitted_myloss = jit_onnx(myloss) 50 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 51 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 52 | jitted_myloss(x, y) 53 | onx = jitted_myloss.get_onnx() 54 | feeds = {"x0": x, "x1": y} 55 | self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError) 56 | optimized = ort_optimized_model(onx) 57 | prof = ort_profile(optimized, feeds) 58 | events = { 59 | "kernel_time", 60 | "SequentialExecutor::Execute", 61 | "model_run", 62 | "model_loading_array", 63 | "session_initialization", 64 | } 65 | self.assertEqual(set(prof["event_name"]), events) 66 | agg = ort_profile(optimized, feeds, first_it_out=True, agg=True) 67 | self.assertIsInstance(agg, DataFrame) 68 | self.assertLess(agg.shape[0], prof.shape[0]) 69 | self.assertEqual(set(agg.reset_index(drop=False)["event_name"]), events) 70 | agg = ort_profile( 71 | optimized, feeds, first_it_out=True, agg=True, agg_op_name=False 72 | ) 73 | self.assertIsInstance(agg, DataFrame) 74 | self.assertLess(agg.shape[0], 
prof.shape[0]) 75 | self.assertEqual(set(agg.reset_index(drop=False)["event_name"]), events) 76 | 77 | def test_ort_profile_ort_value(self): 78 | def to_ort_value(m): 79 | device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) 80 | ort_value = C_OrtValue.ortvalue_from_numpy(m, device) 81 | return ort_value 82 | 83 | def l1_loss(x, y): 84 | return absolute(x - y).sum() 85 | 86 | def l2_loss(x, y): 87 | return ((x - y) ** 2).sum() 88 | 89 | def myloss(x, y): 90 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1]) 91 | 92 | jitted_myloss = jit_onnx(myloss) 93 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) 94 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) 95 | jitted_myloss(x, y) 96 | onx = jitted_myloss.get_onnx() 97 | np_feeds = {"x0": x, "x1": y} 98 | feeds = {k: to_ort_value(v) for k, v in np_feeds.items()} 99 | 100 | self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError) 101 | optimized = ort_optimized_model(onx) 102 | prof = ort_profile(optimized, feeds) 103 | self.assertIsInstance(prof, DataFrame) 104 | prof = ort_profile(optimized, feeds, as_df=False) 105 | self.assertIsInstance(prof, list) 106 | 107 | def test_merge_ort_profile(self): 108 | data = os.path.join(os.path.dirname(__file__), "data") 109 | df1 = read_excel(os.path.join(data, "prof_base.xlsx")) 110 | df2 = read_excel(os.path.join(data, "prof_opti.xlsx")) 111 | merged, gr = merge_ort_profile(df1, df2) 112 | self.assertEqual(merged.shape, (23, 9)) 113 | self.assertEqual( 114 | list(merged.columns), 115 | [ 116 | "args_op_name", 117 | "args_output_type_shape", 118 | "args_input_type_shape", 119 | "args_provider", 120 | "idx", 121 | "durbase", 122 | "countbase", 123 | "duropti", 124 | "countopti", 125 | ], 126 | ) 127 | self.assertEqual(gr.shape, (19, 4)) 128 | self.assertEqual( 129 | list(gr.columns), ["durbase", "duropti", "countbase", "countopti"] 130 | ) 131 | 132 | 133 | if __name__ == "__main__": 134 | unittest.main(verbosity=2) 
135 | -------------------------------------------------------------------------------- /_unittests/ut_ort/test_sklearn_array_api_ort.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from packaging.version import Version 4 | from onnx.defs import onnx_opset_version 5 | from sklearn import config_context, __version__ as sklearn_version 6 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 7 | from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows 8 | from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor 9 | 10 | 11 | DEFAULT_OPSET = onnx_opset_version() 12 | 13 | 14 | class TestSklearnArrayAPIOrt(ExtTestCase): 15 | @unittest.skipIf( 16 | Version(sklearn_version) <= Version("1.2.2"), 17 | reason="reshape ArrayAPI not followed", 18 | ) 19 | @skipif_ci_windows("Unstable on Windows.") 20 | @unittest.skip("discontinued") 21 | def test_sklearn_array_api_linear_discriminant_ort(self): 22 | X = np.array( 23 | [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float64 24 | ) 25 | y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64) 26 | ana = LinearDiscriminantAnalysis() 27 | ana.fit(X, y) 28 | expected = ana.predict(X) 29 | 30 | new_x = EagerOrtTensor(OrtTensor.from_array(X)) 31 | self.assertEqual(new_x.device_name, "Cpu") 32 | self.assertStartsWith( 33 | "EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x) 34 | ) 35 | with config_context(array_api_dispatch=True): 36 | got = ana.predict(new_x) 37 | self.assertEqualArray(expected, got.numpy()) 38 | 39 | @unittest.skipIf( 40 | Version(sklearn_version) <= Version("1.2.2"), 41 | reason="reshape ArrayAPI not followed", 42 | ) 43 | @skipif_ci_windows("Unstable on Windows.") 44 | @unittest.skip("discontinued") 45 | def test_sklearn_array_api_linear_discriminant_ort_float32(self): 46 | X = np.array( 47 | [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32 48 | ) 
49 | y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64) 50 | ana = LinearDiscriminantAnalysis() 51 | ana.fit(X, y) 52 | expected = ana.predict(X) 53 | 54 | new_x = EagerOrtTensor(OrtTensor.from_array(X)) 55 | self.assertEqual(new_x.device_name, "Cpu") 56 | self.assertStartsWith( 57 | "EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x) 58 | ) 59 | with config_context(array_api_dispatch=True): 60 | got = ana.predict(new_x) 61 | self.assertEqualArray(expected, got.numpy()) 62 | 63 | 64 | if __name__ == "__main__": 65 | # import logging 66 | 67 | # logging.basicConfig(level=logging.DEBUG) 68 | unittest.main(verbosity=2) 69 | -------------------------------------------------------------------------------- /_unittests/ut_plotting/data/bug_Hardmax.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/bug_Hardmax.onnx -------------------------------------------------------------------------------- /_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx -------------------------------------------------------------------------------- /_unittests/ut_plotting/data/tree_torch.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/tree_torch.onnx -------------------------------------------------------------------------------- /_unittests/ut_plotting/test_graphviz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import onnx.parser 4 | from 
onnx_array_api.ext_test_case import ( 5 | ExtTestCase, 6 | skipif_ci_windows, 7 | skipif_ci_apple, 8 | ) 9 | from onnx_array_api.plotting.dot_plot import to_dot 10 | from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot 11 | 12 | 13 | class TestGraphviz(ExtTestCase): 14 | @classmethod 15 | def _get_graph(cls): 16 | return onnx.parser.parse_model( 17 | """ 18 | 19 | agraph (float[N] x) => (float[N] z) { 20 | two = Constant () 21 | four = Add(two, two) 22 | z = Mul(x, x) 23 | }""" 24 | ) 25 | 26 | @skipif_ci_windows("graphviz not installed") 27 | @skipif_ci_apple("graphviz not installed") 28 | def test_draw_graph_graphviz(self): 29 | fout = "test_draw_graph_graphviz.png" 30 | dot = to_dot(self._get_graph()) 31 | draw_graph_graphviz(dot, image=fout) 32 | self.assertExists(os.path.exists(fout)) 33 | 34 | @skipif_ci_windows("graphviz not installed") 35 | @skipif_ci_apple("graphviz not installed") 36 | def test_draw_graph_graphviz_proto(self): 37 | fout = "test_draw_graph_graphviz_proto.png" 38 | dot = self._get_graph() 39 | draw_graph_graphviz(dot, image=fout) 40 | self.assertExists(os.path.exists(fout)) 41 | 42 | @skipif_ci_windows("graphviz not installed") 43 | @skipif_ci_apple("graphviz not installed") 44 | def test_plot_dot(self): 45 | dot = to_dot(self._get_graph()) 46 | ax = plot_dot(dot) 47 | ax.get_figure().savefig("test_plot_dot.png") 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main(verbosity=2) 52 | -------------------------------------------------------------------------------- /_unittests/ut_plotting/test_stat_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import pandas 4 | import matplotlib.pyplot as plt 5 | from onnx_array_api.ext_test_case import ExtTestCase, matplotlib_test 6 | from onnx_array_api.plotting.stat_plot import plot_ort_profile 7 | 8 | 9 | class TestStatPlot(ExtTestCase): 10 | @matplotlib_test() 11 | def 
test_plot_ort_profile(self): 12 | data = os.path.join(os.path.dirname(__file__), "data", "prof.csv") 13 | df = pandas.read_csv(data) 14 | _, ax = plt.subplots(2, 1) 15 | plot_ort_profile(df, ax0=ax[0], ax1=ax[1]) 16 | 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /_unittests/ut_reference/test_array_tensor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from onnx import TensorProto 4 | from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info 5 | from onnx_array_api.ext_test_case import ExtTestCase 6 | from onnx_array_api.reference import ( 7 | to_array_extended, 8 | from_array_extended, 9 | ExtendedReferenceEvaluator, 10 | ) 11 | 12 | 13 | class TestArrayTensor(ExtTestCase): 14 | def test_from_array(self): 15 | for dt in (np.float32, np.float16, np.uint16, np.uint8): 16 | with self.subTest(dtype=dt): 17 | a = np.array([0, 1, 2], dtype=dt) 18 | t = from_array_extended(a, "a") 19 | b = to_array_extended(t) 20 | self.assertEqualArray(a, b) 21 | t2 = from_array_extended(b, "a") 22 | self.assertEqual(t.SerializeToString(), t2.SerializeToString()) 23 | 24 | def test_from_array_f8(self): 25 | def make_model_f8(fr, to): 26 | model = make_model( 27 | make_graph( 28 | [make_node("Cast", ["X"], ["Y"], to=to)], 29 | "cast", 30 | [make_tensor_value_info("X", fr, None)], 31 | [make_tensor_value_info("Y", to, None)], 32 | ) 33 | ) 34 | return model 35 | 36 | for dt in (np.float32, np.float16, np.uint16, np.uint8): 37 | with self.subTest(dtype=dt): 38 | a = np.array([0, 1, 2], dtype=dt) 39 | b = from_array_extended(a, "a") 40 | for to in [ 41 | TensorProto.FLOAT8E4M3FN, 42 | TensorProto.FLOAT8E4M3FNUZ, 43 | TensorProto.FLOAT8E5M2, 44 | TensorProto.FLOAT8E5M2FNUZ, 45 | TensorProto.BFLOAT16, 46 | ]: 47 | with self.subTest(fr=b.data_type, to=to): 48 | model = 
make_model_f8(b.data_type, to) 49 | ref = ExtendedReferenceEvaluator(model) 50 | got = ref.run(None, {"X": a})[0] 51 | back = from_array_extended(got, "a") 52 | self.assertEqual(to, back.data_type) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main(verbosity=2) 57 | -------------------------------------------------------------------------------- /_unittests/ut_reference/test_reference_ops.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from onnx import TensorProto 4 | from onnx.helper import ( 5 | make_graph, 6 | make_model, 7 | make_node, 8 | make_tensor_value_info, 9 | make_opsetid, 10 | ) 11 | from onnx_array_api.ext_test_case import ExtTestCase 12 | from onnx_array_api.reference import ExtendedReferenceEvaluator 13 | 14 | 15 | class TestReferenceOps(ExtTestCase): 16 | 17 | def test_fused_matmul(self): 18 | model = make_model( 19 | make_graph( 20 | [make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")], 21 | "name", 22 | [ 23 | make_tensor_value_info("X", TensorProto.FLOAT, None), 24 | make_tensor_value_info("Y", TensorProto.FLOAT, None), 25 | ], 26 | [make_tensor_value_info("Z", TensorProto.FLOAT, None)], 27 | ), 28 | opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], 29 | ) 30 | ref = ExtendedReferenceEvaluator(model) 31 | a = np.arange(4).reshape(-1, 2) 32 | got = ref.run(None, {"X": a, "Y": a}) 33 | self.assertEqualArray(a @ a, got[0]) 34 | 35 | def test_fused_matmul11(self): 36 | model = make_model( 37 | make_graph( 38 | [ 39 | make_node( 40 | "FusedMatMul", 41 | ["X", "Y"], 42 | ["Z"], 43 | transA=1, 44 | transB=1, 45 | domain="com.microsoft", 46 | ) 47 | ], 48 | "name", 49 | [ 50 | make_tensor_value_info("X", TensorProto.FLOAT, None), 51 | make_tensor_value_info("Y", TensorProto.FLOAT, None), 52 | ], 53 | [make_tensor_value_info("Z", TensorProto.FLOAT, None)], 54 | ), 55 | opset_imports=[make_opsetid("", 18), 
make_opsetid("com.microsoft", 1)], 56 | ) 57 | ref = ExtendedReferenceEvaluator(model) 58 | a = np.arange(4).reshape(-1, 2) 59 | got = ref.run(None, {"X": a, "Y": a}) 60 | self.assertEqualArray(a.T @ a.T, got[0]) 61 | 62 | def test_memcpy(self): 63 | model = make_model( 64 | make_graph( 65 | [ 66 | make_node("MemcpyToHost", ["X"], ["Z"]), 67 | make_node("MemcpyFromHost", ["X"], ["Z"]), 68 | ], 69 | "name", 70 | [make_tensor_value_info("X", TensorProto.FLOAT, None)], 71 | [make_tensor_value_info("Z", TensorProto.FLOAT, None)], 72 | ), 73 | opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], 74 | ir_version=9, 75 | ) 76 | a = np.arange(4).reshape(-1, 2).astype(np.float32) 77 | ref = ExtendedReferenceEvaluator(model) 78 | got = ref.run(None, {"X": a}) 79 | self.assertEqualArray(a, got[0]) 80 | 81 | def test_quick_gelu(self): 82 | from onnxruntime import InferenceSession 83 | 84 | for alpha in [0.0, 2.0]: 85 | model = make_model( 86 | make_graph( 87 | [ 88 | make_node( 89 | "QuickGelu", 90 | ["X"], 91 | ["Z"], 92 | domain="com.microsoft", 93 | alpha=alpha, 94 | ) 95 | ], 96 | "name", 97 | [make_tensor_value_info("X", TensorProto.FLOAT, None)], 98 | [make_tensor_value_info("Z", TensorProto.FLOAT, None)], 99 | ), 100 | opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], 101 | ir_version=9, 102 | ) 103 | sess = InferenceSession( 104 | model.SerializeToString(), providers=["CPUExecutionProvider"] 105 | ) 106 | a = np.arange(4).reshape(-1, 2).astype(np.float32) 107 | expected = sess.run(None, {"X": a}) 108 | ref = ExtendedReferenceEvaluator(model) 109 | got = ref.run(None, {"X": a}) 110 | self.assertEqualArray(expected[0], got[0]) 111 | 112 | def test_scatter_elements(self): 113 | model = make_model( 114 | make_graph( 115 | [ 116 | make_node( 117 | "ScatterElements", 118 | ["data", "indices", "updates"], 119 | ["Z"], 120 | axis=3, 121 | reduction="add", 122 | ) 123 | ], 124 | "name", 125 | [ 126 | make_tensor_value_info("data", 
TensorProto.FLOAT, None), 127 | make_tensor_value_info("indices", TensorProto.INT64, None), 128 | make_tensor_value_info("updates", TensorProto.FLOAT, None), 129 | ], 130 | [make_tensor_value_info("Z", TensorProto.FLOAT, None)], 131 | ), 132 | opset_imports=[make_opsetid("", 18)], 133 | ) 134 | data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2)) 135 | indices = np.array([[[[0]]]], dtype=np.int64) 136 | updates = np.array([[[[1]]]], dtype=np.float32) 137 | y = np.array( 138 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32 139 | ).reshape((2, 2, 2, 2)) 140 | ref = ExtendedReferenceEvaluator(model) 141 | got = ref.run(None, {"data": data, "indices": indices, "updates": updates}) 142 | self.assertEqualArray(y, got[0]) 143 | 144 | 145 | if __name__ == "__main__": 146 | unittest.main(verbosity=2) 147 | -------------------------------------------------------------------------------- /_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx -------------------------------------------------------------------------------- /_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx -------------------------------------------------------------------------------- /_unittests/ut_validation/data/small.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_validation/data/small.onnx 
-------------------------------------------------------------------------------- /_unittests/ut_validation/test_diff.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onnx import load 3 | from onnx.checker import check_model 4 | from onnx_array_api.ext_test_case import ExtTestCase 5 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model 6 | from onnx_array_api.validation.diff import text_diff, html_diff 7 | 8 | 9 | class TestDiff(ExtTestCase): 10 | def test_diff_optimized(self): 11 | data = self.relative_path(__file__, "data", "small.onnx") 12 | with open(data, "rb") as f: 13 | model = load(f) 14 | optimized = ort_optimized_model(model) 15 | check_model(optimized) 16 | diff = text_diff(model, optimized) 17 | self.assertIn("^^^^^^^^^^^^^^^^", diff) 18 | ht = html_diff(model, optimized) 19 | self.assertIn("", ht) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main(verbosity=2) 24 | -------------------------------------------------------------------------------- /_unittests/ut_validation/test_docs.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from onnx.reference import ReferenceEvaluator 4 | from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows 5 | from onnx_array_api.validation.docs import make_euclidean, make_euclidean_skl2onnx 6 | 7 | 8 | class TestDocs(ExtTestCase): 9 | def test_make_euclidean(self): 10 | model = make_euclidean() 11 | 12 | ref = ReferenceEvaluator(model) 13 | X = np.random.rand(3, 4).astype(np.float32) 14 | Y = np.random.rand(3, 4).astype(np.float32) 15 | expected = ((X - Y) ** 2).sum(keepdims=1) 16 | got = ref.run(None, {"X": X, "Y": Y})[0] 17 | self.assertEqualArray(expected, got) 18 | 19 | def test_make_euclidean_skl2onnx(self): 20 | model = make_euclidean_skl2onnx() 21 | 22 | ref = ReferenceEvaluator(model) 23 | X = np.random.rand(3, 4).astype(np.float32) 
24 | Y = np.random.rand(3, 4).astype(np.float32) 25 | expected = ((X - Y) ** 2).sum(keepdims=1) 26 | got = ref.run(None, {"X": X, "Y": Y})[0] 27 | self.assertEqualArray(expected, got) 28 | 29 | @skipif_ci_windows("Unstable on Windows.") 30 | def test_make_euclidean_np(self): 31 | from onnx_array_api.npx import jit_onnx 32 | 33 | def l2_loss(x, y): 34 | return ((x - y) ** 2).sum(keepdims=1) 35 | 36 | jitted_myloss = jit_onnx(l2_loss) 37 | dummy1 = np.array([0], dtype=np.float32) 38 | dummy2 = np.array([1], dtype=np.float32) 39 | # unstable on windows? 40 | jitted_myloss(dummy1, dummy2) 41 | model = jitted_myloss.get_onnx() 42 | 43 | ref = ReferenceEvaluator(model) 44 | X = np.random.rand(3, 4).astype(np.float32) 45 | Y = np.random.rand(3, 4).astype(np.float32) 46 | expected = ((X - Y) ** 2).sum(keepdims=1) 47 | got = ref.run(None, {"x0": X, "x1": Y})[0] 48 | self.assertEqualArray(expected, got) 49 | 50 | def test_make_euclidean_light(self): 51 | from onnx_array_api.light_api import start 52 | 53 | model = ( 54 | start() 55 | .vin("X") 56 | .vin("Y") 57 | .bring("X", "Y") 58 | .Sub() 59 | .rename("dxy") 60 | .cst(np.array([2], dtype=np.int64), "two") 61 | .bring("dxy", "two") 62 | .Pow() 63 | .ReduceSum() 64 | .rename("Z") 65 | .vout() 66 | .to_onnx() 67 | ) 68 | 69 | ref = ReferenceEvaluator(model) 70 | X = np.random.rand(3, 4).astype(np.float32) 71 | Y = np.random.rand(3, 4).astype(np.float32) 72 | expected = ((X - Y) ** 2).sum(keepdims=1) 73 | got = ref.run(None, {"X": X, "Y": Y})[0] 74 | self.assertEqualArray(expected, got) 75 | 76 | def test_ort_make_euclidean(self): 77 | from onnxruntime import InferenceSession 78 | 79 | model = make_euclidean(opset=18) 80 | 81 | ref = InferenceSession( 82 | model.SerializeToString(), providers=["CPUExecutionProvider"] 83 | ) 84 | X = np.random.rand(3, 4).astype(np.float32) 85 | Y = np.random.rand(3, 4).astype(np.float32) 86 | expected = ((X - Y) ** 2).sum(keepdims=1) 87 | got = ref.run(None, {"X": X, "Y": Y})[0] 88 | 
self.assertEqualArray(expected, got, atol=1e-6) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main(verbosity=2) 93 | -------------------------------------------------------------------------------- /_unittests/ut_validation/test_tools.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onnx import load 3 | from onnx.checker import check_model 4 | from onnx_array_api.ext_test_case import ExtTestCase 5 | from onnx_array_api.validation.tools import randomize_proto 6 | 7 | 8 | class TestTools(ExtTestCase): 9 | def test_randomize_proto(self): 10 | data = self.relative_path(__file__, "data", "small.onnx") 11 | with open(data, "rb") as f: 12 | model = load(f) 13 | check_model(model) 14 | rnd = randomize_proto(model) 15 | self.assertEqual(len(model.SerializeToString()), len(rnd.SerializeToString())) 16 | check_model(rnd) 17 | 18 | 19 | if __name__ == "__main__": 20 | unittest.main(verbosity=2) 21 | -------------------------------------------------------------------------------- /_unittests/ut_xrun_doc/test_command_lines1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | from contextlib import redirect_stdout 5 | from io import StringIO 6 | from onnx import TensorProto 7 | from onnx.helper import ( 8 | make_graph, 9 | make_model, 10 | make_node, 11 | make_opsetid, 12 | make_tensor_value_info, 13 | ) 14 | from onnx_array_api.ext_test_case import ExtTestCase 15 | from onnx_array_api._command_lines_parser import ( 16 | get_main_parser, 17 | get_parser_compare, 18 | get_parser_translate, 19 | get_parser_replace, 20 | main, 21 | ) 22 | 23 | 24 | class TestCommandLines1(ExtTestCase): 25 | def test_main_parser(self): 26 | st = StringIO() 27 | with redirect_stdout(st): 28 | get_main_parser().print_help() 29 | text = st.getvalue() 30 | self.assertIn("translate", text) 31 | 32 | def test_parser_translate(self): 33 | st = 
StringIO() 34 | with redirect_stdout(st): 35 | get_parser_translate().print_help() 36 | text = st.getvalue() 37 | self.assertIn("model", text) 38 | 39 | def test_parser_replace(self): 40 | st = StringIO() 41 | with redirect_stdout(st): 42 | get_parser_replace().print_help() 43 | text = st.getvalue() 44 | self.assertIn("model", text) 45 | 46 | def test_command_translate(self): 47 | X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) 48 | Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) 49 | Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None]) 50 | graph = make_graph( 51 | [ 52 | make_node("Add", ["X", "Y"], ["res"]), 53 | make_node("Cos", ["res"], ["Z"]), 54 | ], 55 | "g", 56 | [X, Y], 57 | [Z], 58 | ) 59 | onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) 60 | 61 | with tempfile.TemporaryDirectory() as root: 62 | model_file = os.path.join(root, "model.onnx") 63 | with open(model_file, "wb") as f: 64 | f.write(onnx_model.SerializeToString()) 65 | 66 | args = ["translate", "-m", model_file] 67 | st = StringIO() 68 | with redirect_stdout(st): 69 | main(args) 70 | 71 | code = st.getvalue() 72 | self.assertIn("model = make_model(", code) 73 | 74 | args = ["translate", "-m", model_file, "-a", "light"] 75 | st = StringIO() 76 | with redirect_stdout(st): 77 | main(args) 78 | 79 | code = st.getvalue() 80 | self.assertIn("start(opset=", code) 81 | 82 | def test_parser_compare(self): 83 | st = StringIO() 84 | with redirect_stdout(st): 85 | get_parser_compare().print_help() 86 | text = st.getvalue() 87 | self.assertIn("model1", text) 88 | 89 | def test_command_compare(self): 90 | X = make_tensor_value_info("X", TensorProto.FLOAT, [5, 6]) 91 | Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) 92 | Z = make_tensor_value_info("Z", TensorProto.FLOAT, [5, 6]) 93 | graph = make_graph( 94 | [ 95 | make_node("Add", ["X", "Y"], ["res"]), 96 | make_node("Cos", ["res"], ["Z"]), 97 | ], 98 | "g", 99 | [X, Y], 100 | [Z], 101 | 
) 102 | onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) 103 | 104 | with tempfile.TemporaryDirectory() as root: 105 | model_file = os.path.join(root, "model.onnx") 106 | with open(model_file, "wb") as f: 107 | f.write(onnx_model.SerializeToString()) 108 | 109 | args = ["compare", "-m1", model_file, "-m2", model_file, "-v", "1"] 110 | st = StringIO() 111 | with redirect_stdout(st): 112 | main(args) 113 | 114 | code = st.getvalue() 115 | self.assertIn("[compare_onnx_execution]", code) 116 | self.assertIn("ADFF", code) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main(verbosity=2) 121 | -------------------------------------------------------------------------------- /_unittests/ut_xrun_doc/test_documentation_examples.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import sys 4 | import importlib 5 | import subprocess 6 | import time 7 | from onnx_array_api import __file__ as onnx_array_api_file 8 | from onnx_array_api.ext_test_case import ExtTestCase, is_windows 9 | 10 | VERBOSE = 0 11 | ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_array_api_file, "..", ".."))) 12 | 13 | 14 | def import_source(module_file_path, module_name): 15 | if not os.path.exists(module_file_path): 16 | raise FileNotFoundError(module_file_path) 17 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path) 18 | if module_spec is None: 19 | raise FileNotFoundError( 20 | "Unable to find '{}' in '{}'.".format(module_name, module_file_path) 21 | ) 22 | module = importlib.util.module_from_spec(module_spec) 23 | return module_spec.loader.exec_module(module) 24 | 25 | 26 | class TestDocumentationExamples(ExtTestCase): 27 | def run_test(self, fold: str, name: str, verbose=0) -> int: 28 | ppath = os.environ.get("PYTHONPATH", "") 29 | if not ppath: 30 | os.environ["PYTHONPATH"] = ROOT 31 | elif ROOT not in ppath: 32 | sep = ";" if is_windows() else ":" 33 | 
os.environ["PYTHONPATH"] = ppath + sep + ROOT 34 | perf = time.perf_counter() 35 | try: 36 | mod = import_source(fold, os.path.splitext(name)[0]) 37 | assert mod is not None 38 | except FileNotFoundError: 39 | # try another way 40 | cmds = [sys.executable, "-u", os.path.join(fold, name)] 41 | p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 42 | res = p.communicate() 43 | out, err = res 44 | st = err.decode("ascii", errors="ignore") 45 | if st and "Traceback" in st: 46 | if '"dot" not found in path.' in st: 47 | # dot not installed, this part 48 | # is tested in onnx framework 49 | if verbose: 50 | print(f"failed: {name!r} due to missing dot.") 51 | return 0 52 | raise AssertionError( # noqa: B904 53 | "Example '{}' (cmd: {} - exec_prefix='{}') " 54 | "failed due to\n{}" 55 | "".format(name, cmds, sys.exec_prefix, st) 56 | ) 57 | dt = time.perf_counter() - perf 58 | if verbose: 59 | print(f"{dt:.3f}: run {name!r}") 60 | return 1 61 | 62 | @classmethod 63 | def add_test_methods(cls): 64 | this = os.path.abspath(os.path.dirname(__file__)) 65 | fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples")) 66 | found = os.listdir(fold) 67 | for name in found: 68 | if not name.startswith("plot_") or not name.endswith(".py"): 69 | continue 70 | short_name = os.path.split(os.path.splitext(name)[0])[-1] 71 | 72 | def _test_(self, name=name): 73 | res = self.run_test(fold, name, verbose=VERBOSE) 74 | self.assertTrue(res) 75 | 76 | setattr(cls, f"test_{short_name}", _test_) 77 | 78 | 79 | TestDocumentationExamples.add_test_methods() 80 | 81 | if __name__ == "__main__": 82 | unittest.main(verbosity=2) 83 | -------------------------------------------------------------------------------- /_unittests/ut_xrun_doc/test_profiling.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief test tree node (time=5s) 3 | """ 4 | 5 | import os 6 | import sys 7 | import time 8 | import unittest 9 | from 
io import StringIO 10 | from pstats import SortKey 11 | 12 | import pandas 13 | 14 | from onnx_array_api import __file__ as rootfile 15 | from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings 16 | from onnx_array_api.profiling import ProfileNode, profile, profile2df, profile2graph 17 | 18 | 19 | class TestProfiling(ExtTestCase): 20 | def test_profile(self): 21 | def simple(): 22 | df = pandas.DataFrame( 23 | [{"A": "x", "AA": "xx", "AAA": "xxx"}, {"AA": "xxxxxxx", "AAA": "xxx"}] 24 | ) 25 | return df.to_csv(StringIO()) 26 | 27 | rootrem = os.path.normpath( 28 | os.path.abspath(os.path.join(os.path.dirname(rootfile), "..")) 29 | ) 30 | ps, res = profile(simple, rootrem=rootrem) 31 | res = res.replace("\\", "/") 32 | self.assertIn("function calls", res) 33 | self.assertNotEmpty(ps) 34 | 35 | ps, res = profile(simple) 36 | res = res.replace("\\", "/") 37 | self.assertIn("function calls", res) 38 | self.assertNotEmpty(ps) 39 | 40 | @ignore_warnings(FutureWarning) 41 | def test_profile_df(self): 42 | def simple(): 43 | def simple2(): 44 | df = pandas.DataFrame( 45 | [ 46 | {"A": "x", "AA": "xx", "AAA": "xxx"}, 47 | {"AA": "xxxxxxx", "AAA": "xxx"}, 48 | ] 49 | ) 50 | return df.to_csv(StringIO()) 51 | 52 | return simple2() 53 | 54 | rootrem = os.path.normpath( 55 | os.path.abspath(os.path.join(os.path.dirname(rootfile), "..")) 56 | ) 57 | ps, df = profile(simple, rootrem=rootrem, as_df=True) 58 | self.assertIsInstance(df, pandas.DataFrame) 59 | self.assertEqual(df.loc[0, "namefct"].split("-")[-1], "simple") 60 | self.assertNotEmpty(ps) 61 | df = profile2df(ps, False) 62 | self.assertIsInstance(df, list) 63 | self.assertIsInstance(df[0], dict) 64 | df = profile2df(ps, True) 65 | self.assertIsInstance(df, pandas.DataFrame) 66 | 67 | def test_profile_df_verbose(self): 68 | calls = [0] 69 | 70 | def f0(t): 71 | calls[0] += 1 72 | time.sleep(t) 73 | 74 | def f1(t): 75 | calls[0] += 1 76 | time.sleep(t) 77 | 78 | def f2(): 79 | calls[0] += 1 80 | f1(0.1) 81 | 
f1(0.01) 82 | 83 | def f3(): 84 | calls[0] += 1 85 | f0(0.2) 86 | f1(0.5) 87 | 88 | def f4(): 89 | calls[0] += 1 90 | f2() 91 | f3() 92 | 93 | ps = profile(f4)[0] 94 | df = self.capture(lambda: profile2df(ps, verbose=True, fLOG=print))[0] 95 | dfi = df.set_index("fct") 96 | self.assertEqual(dfi.loc["f4", "ncalls1"], 1) 97 | self.assertEqual(dfi.loc["f4", "ncalls2"], 1) 98 | 99 | @unittest.skipIf(sys.version_info[:2] < (3, 7), reason="not supported") 100 | def test_profile_graph(self): 101 | calls = [0] 102 | 103 | def f0(t): 104 | calls[0] += 1 105 | time.sleep(t) 106 | 107 | def f1(t): 108 | calls[0] += 1 109 | time.sleep(t) 110 | 111 | def f2(): 112 | calls[0] += 1 113 | f1(0.1) 114 | f1(0.01) 115 | 116 | def f3(): 117 | calls[0] += 1 118 | f0(0.2) 119 | f1(0.5) 120 | 121 | def f4(): 122 | calls[0] += 1 123 | f2() 124 | f3() 125 | 126 | ps = profile(f4)[0] 127 | profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1]) 128 | root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1]) 129 | self.assertEqual(len(nodes), 6) 130 | self.assertIsInstance(nodes, dict) 131 | self.assertIsInstance(root, ProfileNode) 132 | self.assertIn("(", str(root)) 133 | dicts = root.as_dict() 134 | self.assertEqual(10, len(dicts)) 135 | text = root.to_text() 136 | self.assertIn("1 1", text) 137 | self.assertIn(" f1", text) 138 | text = root.to_text(fct_width=20) 139 | self.assertIn("...", text) 140 | root.to_text(sort_key=SortKey.CUMULATIVE) 141 | root.to_text(sort_key=SortKey.TIME) 142 | self.assertRaise( 143 | lambda: root.to_text(sort_key=SortKey.NAME), NotImplementedError 144 | ) 145 | js = root.to_json(indent=2) 146 | self.assertIn('"details"', js) 147 | js = root.to_json(as_str=False) 148 | self.assertIsInstance(js, dict) 149 | 150 | def test_profile_graph_recursive2(self): 151 | def f0(t): 152 | if t < 0.2: 153 | time.sleep(t) 154 | else: 155 | f1(t - 0.1) 156 | 157 | def f1(t): 158 | if t < 0.1: 159 | time.sleep(t) 160 | else: 161 | f0(t) 162 | 163 | def 
f4(): 164 | f1(0.3) 165 | 166 | ps = profile(f4)[0] 167 | profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1]) 168 | root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1]) 169 | self.assertEqual(len(nodes), 4) 170 | text = root.to_text() 171 | self.assertIn(" f1", text) 172 | js = root.to_json(indent=2) 173 | self.assertIn('"details"', js) 174 | 175 | def test_profile_graph_recursive1(self): 176 | def f0(t): 177 | if t < 0.1: 178 | time.sleep(t) 179 | else: 180 | f0(t - 0.1) 181 | 182 | def f4(): 183 | f0(0.15) 184 | 185 | ps = profile(f4)[0] 186 | profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1]) 187 | root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1]) 188 | self.assertEqual(len(nodes), 3) 189 | text = root.to_text() 190 | self.assertIn(" f0", text) 191 | js = root.to_json(indent=2) 192 | self.assertIn('"details"', js) 193 | 194 | 195 | if __name__ == "__main__": 196 | unittest.main() 197 | -------------------------------------------------------------------------------- /_unittests/win_test_array_api.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy 3 | python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py::test_arange || exit 1 4 | python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 5 | -------------------------------------------------------------------------------- /onnx_array_api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | APIs to create ONNX Graphs. 
3 | """ 4 | 5 | __version__ = "0.3.1" 6 | __author__ = "Xavier Dupré" 7 | -------------------------------------------------------------------------------- /onnx_array_api/__main__.py: -------------------------------------------------------------------------------- 1 | from ._command_lines_parser import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /onnx_array_api/_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Any 3 | from onnx import helper, TensorProto 4 | 5 | 6 | def np_dtype_to_tensor_dtype(dtype: Any): 7 | """ 8 | Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`. 9 | """ 10 | try: 11 | dt = helper.np_dtype_to_tensor_dtype(dtype) 12 | except (KeyError, ValueError): 13 | if dtype == np.float32: 14 | dt = TensorProto.FLOAT 15 | elif dtype == np.float64: 16 | dt = TensorProto.DOUBLE 17 | elif dtype == np.int64: 18 | dt = TensorProto.INT64 19 | elif dtype == np.int32: 20 | dt = TensorProto.INT32 21 | elif dtype == np.int16: 22 | dt = TensorProto.INT16 23 | elif dtype == np.int8: 24 | dt = TensorProto.INT8 25 | elif dtype == np.uint64: 26 | dt = TensorProto.UINT64 27 | elif dtype == np.uint32: 28 | dt = TensorProto.UINT32 29 | elif dtype == np.uint16: 30 | dt = TensorProto.UINT16 31 | elif dtype == np.uint8: 32 | dt = TensorProto.UINT8 33 | elif dtype == np.float16: 34 | dt = TensorProto.FLOAT16 35 | elif dtype in (bool, np.bool_): 36 | dt = TensorProto.BOOL 37 | elif dtype in (str, np.str_): 38 | dt = TensorProto.STRING 39 | elif dtype is int: 40 | dt = TensorProto.INT64 41 | elif dtype is float: 42 | dt = TensorProto.DOUBLE 43 | elif dtype == np.complex64: 44 | dt = TensorProto.COMPLEX64 45 | elif dtype == np.complex128: 46 | dt = TensorProto.COMPLEX128 47 | else: 48 | raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904 49 | return dt 50 | 
-------------------------------------------------------------------------------- /onnx_array_api/annotations.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 2 | import numpy as np 3 | from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto 4 | from onnx.helper import np_dtype_to_tensor_dtype 5 | 6 | NP_DTYPE = np.dtype 7 | ELEMENT_TYPE = Union[int, NP_DTYPE] 8 | SHAPE_TYPE = Tuple[int, ...] 9 | VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray] 10 | GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto] 11 | 12 | AI_ONNX_ML = "ai.onnx.ml" 13 | 14 | ELEMENT_TYPE_NAME = { 15 | getattr(TensorProto, k): k 16 | for k in dir(TensorProto) 17 | if isinstance(getattr(TensorProto, k), int) and "_" not in k 18 | } 19 | 20 | 21 | class SubDomain: 22 | pass 23 | 24 | 25 | def domain(domain: str, op_type: Optional[str] = None) -> Callable: 26 | """ 27 | Registers one operator into a sub domain. It should be used as a 28 | decorator. One example: 29 | 30 | .. 
code-block:: python 31 | 32 | @domain("ai.onnx.ml") 33 | def Normalizer(self, norm: str = "MAX"): 34 | return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml") 35 | """ 36 | names = [op_type] 37 | 38 | def decorate(op_method: Callable) -> Callable: 39 | if names[0] is None: 40 | names[0] = op_method.__name__ 41 | 42 | def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: 43 | return op_method(self.parent, *args, **kwargs) 44 | 45 | wrapper.__qual__name__ = f"[{domain}]{names[0]}" 46 | wrapper.__name__ = f"[{domain}]{names[0]}" 47 | wrapper.__domain__ = domain 48 | return wrapper 49 | 50 | return decorate 51 | 52 | 53 | _type_numpy = { 54 | np.float32: TensorProto.FLOAT, 55 | np.float64: TensorProto.DOUBLE, 56 | np.float16: TensorProto.FLOAT16, 57 | np.int8: TensorProto.INT8, 58 | np.int16: TensorProto.INT16, 59 | np.int32: TensorProto.INT32, 60 | np.int64: TensorProto.INT64, 61 | np.uint8: TensorProto.UINT8, 62 | np.uint16: TensorProto.UINT16, 63 | np.uint32: TensorProto.UINT32, 64 | np.uint64: TensorProto.UINT64, 65 | np.bool_: TensorProto.BOOL, 66 | np.str_: TensorProto.STRING, 67 | np.complex64: TensorProto.COMPLEX64, 68 | np.complex128: TensorProto.COMPLEX128, 69 | } 70 | 71 | 72 | def elem_type_int(elem_type: ELEMENT_TYPE) -> int: 73 | """ 74 | Converts an element type into an onnx element type (int). 75 | 76 | :param elem_type: integer or numpy type 77 | :return: int 78 | """ 79 | if isinstance(elem_type, int): 80 | return elem_type 81 | if elem_type in _type_numpy: 82 | return _type_numpy[elem_type] 83 | return np_dtype_to_tensor_dtype(elem_type) 84 | 85 | 86 | def _pick_dim(d, empty_dim): 87 | if d.dim_value: 88 | return d.dim_value 89 | if d.dim_param: 90 | return d.dim_param 91 | return empty_dim 92 | 93 | 94 | def make_shape(shape: TensorShapeProto, empty_dim: Optional[Any] = None) -> SHAPE_TYPE: 95 | "Extracts a shape from a tensor type." 
96 | if hasattr(shape, "dim"): 97 | res = [_pick_dim(d, empty_dim=empty_dim) for i, d in enumerate(shape.dim)] 98 | return tuple(res) 99 | return None 100 | -------------------------------------------------------------------------------- /onnx_array_api/array_api/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Dict 2 | import warnings 3 | import numpy as np 4 | from onnx import TensorProto 5 | from .._helpers import np_dtype_to_tensor_dtype 6 | from ..npx.npx_types import DType 7 | from ..npx import npx_functions 8 | 9 | 10 | supported_functions = [ 11 | "abs", 12 | "absolute", 13 | "all", 14 | "any", 15 | "arange", 16 | "asarray", 17 | "astype", 18 | "empty", 19 | "equal", 20 | "eye", 21 | "full", 22 | "full_like", 23 | "isdtype", 24 | "isfinite", 25 | "isinf", 26 | "isnan", 27 | "linspace", 28 | "ones", 29 | "ones_like", 30 | "reshape", 31 | "sum", 32 | "take", 33 | "zeros", 34 | "zeros_like", 35 | ] 36 | 37 | 38 | def _finfo(dtype): 39 | """ 40 | Similar to :class:`numpy.finfo`. 41 | """ 42 | dt = dtype.np_dtype if isinstance(dtype, DType) else dtype 43 | res = np.finfo(dt) 44 | d = {} 45 | for k, v in res.__dict__.items(): 46 | if k.startswith("__"): 47 | continue 48 | if isinstance(v, (np.float32, np.float64, np.float16)): 49 | d[k] = float(v) 50 | elif isinstance(v, (np.complex128, np.complex64)): 51 | d[k] = complex(v) 52 | else: 53 | d[k] = v 54 | d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) 55 | nres = type("finfo", (res.__class__,), d) 56 | setattr(nres, "smallest_normal", float(res.smallest_normal)) # noqa: B010 57 | setattr(nres, "tiny", float(res.tiny)) # noqa: B010 58 | return nres 59 | 60 | 61 | def _iinfo(dtype): 62 | """ 63 | Similar to :class:`numpy.finfo`. 
64 | """ 65 | dt = dtype.np_dtype if isinstance(dtype, DType) else dtype 66 | res = np.iinfo(dt) 67 | d = {} 68 | for k, v in res.__dict__.items(): 69 | if k.startswith("__"): 70 | continue 71 | if isinstance( 72 | v, 73 | ( 74 | np.int16, 75 | np.int32, 76 | np.int64, 77 | np.uint16, 78 | np.uint32, 79 | np.uint64, 80 | np.int8, 81 | np.uint8, 82 | ), 83 | ): 84 | d[k] = int(v) 85 | else: 86 | d[k] = v 87 | d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) 88 | nres = type("iinfo", (res.__class__,), d) 89 | setattr(nres, "min", int(res.min)) # noqa: B010 90 | setattr(nres, "max", int(res.max)) # noqa: B010 91 | return nres 92 | 93 | 94 | def array_api_wrap_function(f: Callable, TEagerTensor: type) -> Callable: 95 | """ 96 | Converts an eager function takeing EagerTensor into a function 97 | available through an Array API. 98 | 99 | :param callable: function 100 | :param TEagerTensor: EagerTensor class 101 | :return: new function 102 | """ 103 | 104 | def wrap(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: 105 | new_args = [] 106 | for a in args: 107 | if isinstance(a, np.ndarray): 108 | b = TEagerTensor(a) 109 | else: 110 | b = a 111 | new_args.append(b) 112 | res = f(TEagerTensor, *new_args, **kwargs) 113 | return res 114 | 115 | wrap.__doc__ = f.__doc__ 116 | return wrap 117 | 118 | 119 | def _finalize_array_api(module, function_names, TEagerTensor): 120 | """ 121 | Adds common attributes to Array API defined in this modules 122 | such as types. 123 | """ 124 | from . 
import _onnx_common 125 | 126 | module.float16 = DType(TensorProto.FLOAT16) 127 | module.float32 = DType(TensorProto.FLOAT) 128 | module.float64 = DType(TensorProto.DOUBLE) 129 | module.complex64 = DType(TensorProto.COMPLEX64) 130 | module.complex128 = DType(TensorProto.COMPLEX128) 131 | module.int8 = DType(TensorProto.INT8) 132 | module.int16 = DType(TensorProto.INT16) 133 | module.int32 = DType(TensorProto.INT32) 134 | module.int64 = DType(TensorProto.INT64) 135 | module.uint8 = DType(TensorProto.UINT8) 136 | module.uint16 = DType(TensorProto.UINT16) 137 | module.uint32 = DType(TensorProto.UINT32) 138 | module.uint64 = DType(TensorProto.UINT64) 139 | module.bfloat16 = DType(TensorProto.BFLOAT16) 140 | setattr(module, "bool", DType(TensorProto.BOOL)) # noqa: B010 141 | setattr(module, "str", DType(TensorProto.STRING)) # noqa: B010 142 | setattr(module, "finfo", _finfo) # noqa: B010 143 | setattr(module, "iinfo", _iinfo) # noqa: B010 144 | 145 | if function_names is None: 146 | function_names = supported_functions 147 | 148 | for name in function_names: 149 | f = getattr(_onnx_common, name, None) 150 | if f is None: 151 | f2 = getattr(npx_functions, name, None) 152 | if f2 is None: 153 | warnings.warn( 154 | f"Function {name!r} is not available in {module!r}.", 155 | stacklevel=0, 156 | ) 157 | continue 158 | f = lambda TEagerTensor, *args, _f=f2, **kwargs: _f( # noqa: E731 159 | *args, **kwargs 160 | ) 161 | setattr(module, name, array_api_wrap_function(f, TEagerTensor)) 162 | -------------------------------------------------------------------------------- /onnx_array_api/array_api/onnx_numpy.py: -------------------------------------------------------------------------------- 1 | from ..npx.npx_numpy_tensors import EagerNumpyTensor 2 | from . import _finalize_array_api 3 | 4 | 5 | def _finalize(): 6 | """ 7 | Adds common attributes to Array API defined in this modules 8 | such as types. 9 | """ 10 | from . 
import onnx_numpy 11 | 12 | _finalize_array_api(onnx_numpy, None, EagerNumpyTensor) 13 | 14 | 15 | _finalize() 16 | -------------------------------------------------------------------------------- /onnx_array_api/array_api/onnx_ort.py: -------------------------------------------------------------------------------- 1 | from ..ort.ort_tensors import EagerOrtTensor 2 | from . import _finalize_array_api 3 | 4 | 5 | def _finalize(): 6 | """ 7 | Adds common attributes to Array API defined in this modules 8 | such as types. 9 | """ 10 | from . import onnx_ort 11 | 12 | _finalize_array_api(onnx_ort, None, EagerOrtTensor) 13 | 14 | 15 | _finalize() 16 | -------------------------------------------------------------------------------- /onnx_array_api/cache.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def get_cache_file(filename: str, remove: bool = False): 5 | """ 6 | Returns a name in the cache folder `~/.onnx-array-api`. 
7 | 8 | :param filename: filename 9 | :param remove: remove if exists 10 | :return: full filename 11 | """ 12 | home = Path.home() 13 | folder = home / ".onnx-array-api" 14 | if not folder.exists(): 15 | folder.mkdir() 16 | name = folder / filename 17 | if name.exists(): 18 | name.unlink() 19 | return name 20 | -------------------------------------------------------------------------------- /onnx_array_api/graph_api/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_builder import GraphBuilder, NodePattern 2 | -------------------------------------------------------------------------------- /onnx_array_api/light_api/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from onnx import ModelProto 3 | from ..annotations import domain 4 | from .model import OnnxGraph, ProtoType 5 | from .var import Var, Vars 6 | 7 | 8 | def start( 9 | opset: Optional[int] = None, 10 | opsets: Optional[Dict[str, int]] = None, 11 | ir_version: Optional[int] = None, 12 | ) -> OnnxGraph: 13 | """ 14 | Starts an onnx model. 15 | 16 | :param opset: main opset version 17 | :param opsets: others opsets as a dictionary 18 | :param ir_version: specify the ir_version as well 19 | :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph` 20 | 21 | A very simple model: 22 | 23 | .. runpython:: 24 | :showcode: 25 | 26 | from onnx_array_api.light_api import start 27 | 28 | onx = start().vin("X").Neg().rename("Y").vout().to_onnx() 29 | print(onx) 30 | 31 | Another with operator Add: 32 | 33 | .. 
runpython:: 34 | :showcode: 35 | 36 | from onnx_array_api.light_api import start 37 | 38 | onx = ( 39 | start() 40 | .vin("X") 41 | .vin("Y") 42 | .bring("X", "Y") 43 | .Add() 44 | .rename("Z") 45 | .vout() 46 | .to_onnx() 47 | ) 48 | print(onx) 49 | """ 50 | return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version) 51 | 52 | 53 | def g() -> OnnxGraph: 54 | """ 55 | Starts a subgraph. 56 | :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph` 57 | """ 58 | return OnnxGraph(proto_type=ProtoType.GRAPH) 59 | -------------------------------------------------------------------------------- /onnx_array_api/npx/__init__.py: -------------------------------------------------------------------------------- 1 | from .npx_core_api import cst, make_tuple, npxapi_function, npxapi_inline, var 2 | from .npx_functions import * 3 | from .npx_jit_eager import eager_onnx, jit_onnx 4 | from .npx_types import ElemType, OptParType, ParType, SequenceType, TensorType 5 | -------------------------------------------------------------------------------- /onnx_array_api/npx/npx_constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_OPSETS = {"": 18, "ai.onnx.ml": 3} 2 | FUNCTION_DOMAIN = "FUNCTION-DOMAIN" 3 | ONNX_DOMAIN = "ONNX-DOMAIN" 4 | 5 | _OPSET_TO_IR_VERSION = { 6 | 14: 7, 7 | 15: 8, 8 | 16: 8, 9 | 17: 8, 10 | 18: 8, 11 | 19: 9, 12 | } 13 | 14 | DEFAULT_IR_VERSION = _OPSET_TO_IR_VERSION[DEFAULT_OPSETS[""]] 15 | -------------------------------------------------------------------------------- /onnx_array_api/npx/npx_function_implementation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | from onnx import FunctionProto, ValueInfoProto 4 | from onnx.helper import make_function, make_graph, make_node, make_opsetid 5 | 6 | from .npx_constants import FUNCTION_DOMAIN 7 | 8 | 9 | def get_function_implementation( 10 | domop: 
Tuple[str, str], 11 | node_inputs: List[str], 12 | node_outputs: List[str], 13 | opsets: Dict[str, int], 14 | **kwargs: Any, 15 | ) -> FunctionProto: 16 | """ 17 | Returns a :class:`onnx.FunctionProto` for a specific proto. 18 | 19 | :param domop: domain, function 20 | :param node_inputs: list of input names 21 | :param node_outputs: list of output names 22 | :param opsets: available opsets 23 | :kwargs: any other parameters 24 | :return: FunctionProto 25 | """ 26 | if domop[0] != FUNCTION_DOMAIN: 27 | raise ValueError( 28 | f"This function only considers function for domain " 29 | f"{FUNCTION_DOMAIN!r} not {domop[0]!r}." 30 | ) 31 | if domop[1] == "CDist": 32 | return _get_cdist_implementation(node_inputs, node_outputs, opsets, **kwargs) 33 | raise ValueError(f"Unable to return an implementation of function {domop!r}.") 34 | 35 | 36 | def _get_cdist_implementation( 37 | node_inputs: List[str], 38 | node_outputs: List[str], 39 | opsets: Dict[str, int], 40 | **kwargs: Any, 41 | ) -> FunctionProto: 42 | """ 43 | Returns the CDist implementation as a function. 44 | """ 45 | if len(node_inputs) != 2: 46 | raise ValueError(f"cdist has two inputs not {len(node_inputs)}.") 47 | if len(node_outputs) != 1: 48 | raise ValueError(f"cdist has one outputs not {len(node_outputs)}.") 49 | if opsets is None: 50 | raise ValueError("opsets cannot be None.") 51 | if "" not in opsets: 52 | raise ValueError( 53 | "Opsets for domain '' must be specified but opsets={opsets!r}." 
54 | ) 55 | if set(kwargs) != {"metric"}: 56 | raise ValueError(f"kwargs={kwargs} must contain metric and only metric.") 57 | metric = kwargs["metric"] 58 | if opsets is not None and "com.microsoft" in opsets: 59 | node = make_node( 60 | "CDist", ["xa", "xb"], ["z"], domain="com.microsoft", metric=metric 61 | ) 62 | return make_function( 63 | "npx", 64 | f"CDist_{metric}", 65 | ["xa", "xb"], 66 | ["z"], 67 | [node], 68 | [make_opsetid("com.microsoft", 1)], 69 | ) 70 | 71 | if metric in ("euclidean", "sqeuclidean"): 72 | # subgraph 73 | nodes = [ 74 | make_node("Sub", ["next", "next_in"], ["diff"]), 75 | make_node("Constant", [], ["axis"], value_ints=[1]), 76 | make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0), 77 | make_node("Identity", ["next_in"], ["next_out"]), 78 | ] 79 | 80 | def make_value(name): 81 | value = ValueInfoProto() 82 | value.name = name 83 | return value 84 | 85 | graph = make_graph( 86 | nodes, 87 | "loop", 88 | [make_value("next_in"), make_value("next")], 89 | [make_value("next_out"), make_value("scan_out")], 90 | ) 91 | 92 | scan = make_node( 93 | "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph 94 | ) 95 | if metric == "euclidean": 96 | final = make_node("Sqrt", ["zout"], ["z"]) 97 | else: 98 | final = make_node("Identity", ["zout"], ["z"]) 99 | return make_function( 100 | "npx", 101 | f"CDist_{metric}", 102 | ["xa", "xb"], 103 | ["z"], 104 | [scan, final], 105 | [make_opsetid("", opsets[""])], 106 | ) 107 | 108 | raise RuntimeError( 109 | f"There is no implementation for cdist and metric={metric!r} yet." 
110 | ) 111 | -------------------------------------------------------------------------------- /onnx_array_api/npx/npx_functions_test.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | from .npx_core_api import ( 6 | cst, 7 | make_tuple, 8 | npxapi_function, 9 | npxapi_inline, 10 | tuple_var, 11 | var, 12 | ) 13 | from .npx_types import ( 14 | ElemType, 15 | OptParType, 16 | ParType, 17 | SequenceType, 18 | TensorType, 19 | TupleType, 20 | ) 21 | 22 | 23 | @npxapi_function 24 | def _min_max( 25 | x: TensorType[ElemType.numerics, "T"], 26 | ) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]: 27 | return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax")) 28 | 29 | 30 | @npxapi_inline 31 | def _min_max_inline( 32 | x: TensorType[ElemType.numerics, "T"], 33 | ) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]: 34 | return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax")) 35 | 36 | 37 | @npxapi_function 38 | def absolute( 39 | x: TensorType[ElemType.numerics, "T"], 40 | ) -> TensorType[ElemType.numerics, "T"]: 41 | "See :func:`numpy.absolute`." 42 | return var(x, op="Abs") 43 | 44 | 45 | @npxapi_function 46 | def addition( 47 | x: TensorType[ElemType.numerics, "T"], y: TensorType[ElemType.numerics, "T"] 48 | ) -> TensorType[ElemType.numerics, "T"]: 49 | "See :func:`numpy.addition`." 50 | return var(x, y, op="Add") 51 | 52 | 53 | @npxapi_function 54 | def argmin( 55 | x: TensorType[ElemType.numerics, "T"], 56 | axis: OptParType[int] = 0, 57 | keepdims: OptParType[int] = 0, 58 | ) -> TensorType[ElemType.numerics, "T"]: 59 | """ 60 | See :func:`numpy.argmin`. 
61 | """ 62 | return var(x, op="ArgMin", axis=axis, keepdims=keepdims) 63 | 64 | 65 | @npxapi_function 66 | def concat( 67 | *x: SequenceType[TensorType[ElemType.numerics, "T"]], axis: ParType[int] = 0 68 | ) -> TensorType[ElemType.numerics, "T"]: 69 | """ 70 | Operator concat, handle :func:`numpy.vstack` and 71 | :func:`numpy.hstack`. 72 | """ 73 | if len(x) <= 1: 74 | raise RuntimeError(f"N={len(x)}<=1 elements to concatenate.") 75 | return var(*x, op="Concat", axis=axis) 76 | 77 | 78 | @npxapi_function 79 | def copy(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: 80 | "Makes a copy." 81 | return var(x, op="Identity") 82 | 83 | 84 | @npxapi_function 85 | def log1p(x: TensorType[ElemType.floats, "T"]) -> TensorType[ElemType.floats, "T"]: 86 | "See :func:`numpy.log1p`." 87 | x1 = var(x, var(cst(np.array([1], dtype=np.int64)), x, op="CastLike"), op="Add") 88 | return var(x1, op="Log") 89 | 90 | 91 | @npxapi_function 92 | def negative( 93 | x: TensorType[ElemType.numerics, "T"], 94 | ) -> TensorType[ElemType.numerics, "T"]: 95 | "See :func:`numpy.negative`." 96 | return var(x, op="Neg") 97 | 98 | 99 | @npxapi_function 100 | def relu( 101 | x: TensorType[ElemType.numerics, "T"], 102 | ) -> TensorType[ElemType.numerics, "T"]: 103 | "See :func:`numpy.addition`." 104 | return var(var(absolute(x), x, op="Add"), var(cst(2), x, op="CastLike"), op="Div") 105 | 106 | 107 | @npxapi_function 108 | def topk( 109 | x: TensorType[ElemType.numerics, "T"], 110 | k: TensorType[ElemType.int64, "I", (1,)], 111 | axis: OptParType[int] = -1, 112 | largest: OptParType[int] = 1, 113 | sorted: OptParType[int] = 1, 114 | ) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.int64, "I"]]: 115 | "See :func:`numpy.argsort`." 
116 | return make_tuple(2, x, k, op="TopK", axis=axis, largest=largest, sorted=sorted) 117 | 118 | 119 | @npxapi_function 120 | def transpose( 121 | x: TensorType[ElemType.numerics, "T"], perm: ParType[Tuple[int]] = (1, 0) 122 | ) -> TensorType[ElemType.numerics, "T"]: 123 | "See :func:`numpy.transpose`." 124 | return var(x, op="Transpose", perm=list(perm)) 125 | -------------------------------------------------------------------------------- /onnx_array_api/ort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/onnx_array_api/ort/__init__.py -------------------------------------------------------------------------------- /onnx_array_api/ort/ort_optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | from onnx import ModelProto, load 3 | from onnxruntime import InferenceSession, SessionOptions 4 | from onnxruntime.capi._pybind_state import GraphOptimizationLevel 5 | from ..cache import get_cache_file 6 | 7 | 8 | def ort_optimized_model( 9 | onx: Union[str, ModelProto], 10 | level: str = "ORT_ENABLE_ALL", 11 | output: Optional[str] = None, 12 | ) -> Union[str, ModelProto]: 13 | """ 14 | Returns the optimized model used by onnxruntime before 15 | running computing the inference. 16 | 17 | :param onx: ModelProto 18 | :param level: optimization level, `'ORT_ENABLE_BASIC'`, 19 | `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'` 20 | :param output: output file if the proposed cache is not wanted 21 | :return: optimized model 22 | """ 23 | glevel = getattr(GraphOptimizationLevel, level, None) 24 | if glevel is None: 25 | raise ValueError( 26 | f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}." 
27 | ) 28 | 29 | if output is not None: 30 | cache = output 31 | else: 32 | cache = get_cache_file("ort_optimized_model.onnx", remove=True) 33 | so = SessionOptions() 34 | so.graph_optimization_level = glevel 35 | so.optimized_model_filepath = str(cache) 36 | InferenceSession( 37 | onx if isinstance(onx, str) else onx.SerializeToString(), 38 | so, 39 | providers=["CPUExecutionProvider"], 40 | ) 41 | if output is None and not cache.exists(): 42 | raise RuntimeError(f"The optimized model {str(cache)!r} not found.") 43 | if output is not None: 44 | return output 45 | if isinstance(onx, str): 46 | return str(cache) 47 | opt_onx = load(str(cache)) 48 | cache.unlink() 49 | return opt_onx 50 | -------------------------------------------------------------------------------- /onnx_array_api/plotting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/onnx_array_api/plotting/__init__.py -------------------------------------------------------------------------------- /onnx_array_api/plotting/stat_plot.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | import pandas 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_ort_profile( 7 | df: pandas.DataFrame, 8 | ax0: Optional[Any] = None, 9 | ax1: Optional[Any] = None, 10 | title: Optional[str] = None, 11 | ) -> Any: 12 | """ 13 | Plots time spend in computation based on dataframe 14 | produced by function :func:`ort_profile 15 | `. 16 | 17 | :param df: dataframe 18 | :param ax0: first axis to draw time 19 | :param ax1: second axis to draw occurences 20 | :param title: graph title 21 | :return: ax0 22 | 23 | See :ref:`l-example-ort-profiling` for an example. 
24 | """ 25 | if ax0 is None: 26 | ax0 = plt.gca() # pragma: no cover 27 | 28 | gr_dur = ( 29 | df[["dur", "args_op_name"]].groupby("args_op_name").sum().sort_values("dur") 30 | ) 31 | gr_dur.plot.barh(ax=ax0) 32 | if title is not None: 33 | ax0.set_title(title) 34 | if ax1 is not None: 35 | gr_n = ( 36 | df[["dur", "args_op_name"]] 37 | .groupby("args_op_name") 38 | .count() 39 | .sort_values("dur") 40 | ) 41 | gr_n = gr_n.loc[gr_dur.index, :] 42 | gr_n.plot.barh(ax=ax1) 43 | ax1.set_title("n occurences") 44 | return ax0 45 | -------------------------------------------------------------------------------- /onnx_array_api/reference/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | from onnx import TensorProto 4 | from onnx.numpy_helper import from_array as onnx_from_array 5 | from onnx.reference.ops.op_cast import ( 6 | bfloat16, 7 | float8e4m3fn, 8 | float8e4m3fnuz, 9 | float8e5m2, 10 | float8e5m2fnuz, 11 | ) 12 | from onnx.reference.op_run import to_array_extended 13 | from .evaluator import ExtendedReferenceEvaluator 14 | from .evaluator_yield import ( 15 | DistanceExecution, 16 | ResultExecution, 17 | ResultType, 18 | YieldEvaluator, 19 | compare_onnx_execution, 20 | ) 21 | 22 | 23 | def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto: 24 | """ 25 | Converts an array into a TensorProto. 
26 | 27 | :param tensor: numpy array 28 | :param name: name 29 | :return: TensorProto 30 | """ 31 | dt = tensor.dtype 32 | if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn": 33 | to = TensorProto.FLOAT8E4M3FN 34 | dt_to = np.uint8 35 | elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz": 36 | to = TensorProto.FLOAT8E4M3FNUZ 37 | dt_to = np.uint8 38 | elif dt == float8e5m2 and dt.descr[0][0] == "e5m2": 39 | to = TensorProto.FLOAT8E5M2 40 | dt_to = np.uint8 41 | elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz": 42 | to = TensorProto.FLOAT8E5M2FNUZ 43 | dt_to = np.uint8 44 | elif dt == bfloat16 and dt.descr[0][0] == "bfloat16": 45 | to = TensorProto.BFLOAT16 46 | dt_to = np.uint16 47 | else: 48 | return onnx_from_array(tensor, name) 49 | 50 | t = onnx_from_array(tensor.astype(dt_to), name) 51 | t.data_type = to 52 | return t 53 | -------------------------------------------------------------------------------- /onnx_array_api/reference/evaluator.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import Any, Dict, List, Optional, Union 3 | from onnx import FunctionProto, ModelProto 4 | from onnx.defs import get_schema 5 | from onnx.reference import ReferenceEvaluator 6 | from onnx.reference.op_run import OpRun 7 | from .ops.op_cast_like import CastLike_15, CastLike_19 8 | from .ops.op_concat import Concat 9 | from .ops.op_constant_of_shape import ConstantOfShape 10 | from .ops.op_fused_matmul import FusedMatMul 11 | from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost 12 | from .ops.op_quick_gelu import QuickGelu 13 | from .ops.op_scatter_elements import ScatterElements 14 | 15 | 16 | logger = getLogger("onnx-array-api-eval") 17 | 18 | 19 | class ExtendedReferenceEvaluator(ReferenceEvaluator): 20 | """ 21 | This class replaces the python implementation by custom implementation. 
22 | The Array API extends many operator to all types not supported 23 | by the onnx specifications. The evaluator allows to test 24 | scenarios outside what an onnx backend bound to the official onnx 25 | operators definition could do. 26 | 27 | :: 28 | 29 | from onnx.reference import ReferenceEvaluator 30 | from onnx.reference.c_ops import Conv 31 | ref = ReferenceEvaluator(..., new_ops=[Conv]) 32 | """ 33 | 34 | default_ops = [ 35 | Concat, 36 | CastLike_15, 37 | CastLike_19, 38 | ConstantOfShape, 39 | FusedMatMul, 40 | MemcpyFromHost, 41 | MemcpyToHost, 42 | QuickGelu, 43 | ScatterElements, 44 | ] 45 | 46 | @staticmethod 47 | def filter_ops(proto, new_ops, opsets): 48 | if opsets is None and isinstance(proto, (ModelProto, FunctionProto)): 49 | opsets = {d.domain: d.version for d in proto.opset_import} 50 | best = {} 51 | renamed = {} 52 | for cl in new_ops: 53 | if "_" not in cl.__name__: 54 | continue 55 | vers = cl.__name__.split("_") 56 | try: 57 | v = int(vers[-1]) 58 | except ValueError: 59 | # not a version 60 | continue 61 | if opsets is not None and v > opsets.get(cl.op_domain, 1): 62 | continue 63 | renamed[cl.__name__] = cl 64 | key = cl.op_domain, "_".join(vers[:-1]) 65 | if key not in best or best[key][0] < v: 66 | best[key] = (v, cl) 67 | 68 | modified = [] 69 | for cl in new_ops: 70 | if cl.__name__ not in renamed: 71 | modified.append(cl) 72 | for k, v in best.items(): 73 | atts = {"domain": k[0]} 74 | bases = (v[1],) 75 | if not hasattr(v[1], "op_schema"): 76 | atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain) 77 | new_cl = type(k[1], bases, atts) 78 | modified.append(new_cl) 79 | 80 | new_ops = modified 81 | return new_ops 82 | 83 | def __init__( 84 | self, 85 | proto: Any, 86 | opsets: Optional[Dict[str, int]] = None, 87 | functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None, 88 | verbose: int = 0, 89 | new_ops: Optional[List[OpRun]] = None, 90 | **kwargs, 91 | ): 92 | if new_ops is None: 93 | new_ops = 
ExtendedReferenceEvaluator.default_ops 94 | else: 95 | new_ops = new_ops.copy() 96 | new_ops.extend(ExtendedReferenceEvaluator.default_ops) 97 | new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets) 98 | 99 | ReferenceEvaluator.__init__( 100 | self, 101 | proto, 102 | opsets=opsets, 103 | functions=functions, 104 | verbose=verbose, 105 | new_ops=new_ops, 106 | **kwargs, 107 | ) 108 | 109 | def _log(self, level: int, pattern: str, *args: List[Any]) -> None: 110 | if level < self.verbose: 111 | new_args = [self._log_arg(a) for a in args] 112 | print(pattern % tuple(new_args)) 113 | else: 114 | logger.debug(pattern, *args) 115 | 116 | def run(self, *args, **kwargs): 117 | """ 118 | See :meth:`onnx.reference.ReferenceEvaluator.run`. 119 | """ 120 | if len(args) == 1 and isinstance(args[0], list): 121 | feeds = dict(zip(self.input_names, args[0])) 122 | return self.run(None, feeds, **kwargs) 123 | return ReferenceEvaluator.run(self, *args, **kwargs) 124 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_cast_like.py: -------------------------------------------------------------------------------- 1 | from onnx.helper import np_dtype_to_tensor_dtype 2 | from onnx.onnx_pb import TensorProto 3 | from onnx.reference.op_run import OpRun 4 | from onnx.reference.ops.op_cast import ( 5 | bfloat16, 6 | cast_to, 7 | float8e4m3fn, 8 | float8e4m3fnuz, 9 | float8e5m2, 10 | float8e5m2fnuz, 11 | ) 12 | 13 | 14 | def _cast_like(x, y, saturate): 15 | if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": 16 | # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 17 | to = TensorProto.BFLOAT16 18 | elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn": 19 | to = 
TensorProto.FLOAT8E4M3FN 20 | elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz": 21 | to = TensorProto.FLOAT8E4M3FNUZ 22 | elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2": 23 | to = TensorProto.FLOAT8E5M2 24 | elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz": 25 | to = TensorProto.FLOAT8E5M2FNUZ 26 | else: 27 | to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore 28 | return (cast_to(x, to, saturate),) 29 | 30 | 31 | class CastLike_15(OpRun): 32 | def _run(self, x, y): # type: ignore 33 | return _cast_like(x, y, True) 34 | 35 | 36 | class CastLike_19(OpRun): 37 | def _run(self, x, y, saturate=None): # type: ignore 38 | return _cast_like(x, y, saturate) 39 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_concat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from onnx.reference.op_run import OpRun 4 | 5 | 6 | class Concat(OpRun): 7 | def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray: 8 | if axis >= len(a.shape): # type: ignore 9 | new_shape = a.shape + (1,) * (axis + 1 - len(a.shape)) # type: ignore 10 | return a.reshape(new_shape) 11 | return a 12 | 13 | def _run(self, *args, axis=None): # type: ignore 14 | targs = tuple(self._preprocess(a, axis) for a in args) 15 | return (np.concatenate(targs, axis),) # type: ignore 16 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_constant_of_shape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnx.reference.op_run import OpRun 3 | 4 | 5 | class ConstantOfShape(OpRun): 6 | @staticmethod 7 | def _process(value): 8 | cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value 9 | if isinstance(value, np.ndarray): 10 | if not value.shape: 11 | cst = value 12 | elif value.size > 0: 
13 | cst = value.ravel()[0] 14 | else: 15 | raise ValueError(f"Unexpected fill_value={value!r}") 16 | if isinstance(cst, bool): 17 | cst = np.bool_(cst) 18 | elif isinstance(cst, int): 19 | cst = np.int64(cst) 20 | elif isinstance(cst, float): 21 | cst = np.float64(cst) 22 | elif isinstance(cst, complex): 23 | cst = np.complex128(cst) 24 | elif cst is None: 25 | cst = np.float32(0) 26 | if not isinstance( 27 | cst, 28 | ( 29 | np.float16, 30 | np.float32, 31 | np.float64, 32 | np.complex64, 33 | np.complex128, 34 | np.int64, 35 | np.int32, 36 | np.int16, 37 | np.int8, 38 | np.uint64, 39 | np.uint32, 40 | np.uint16, 41 | np.uint8, 42 | np.bool_, 43 | ), 44 | ): 45 | raise TypeError(f"value must be a real not {type(cst)}") 46 | return cst 47 | 48 | def _run(self, data, value=None): 49 | cst = self._process(value) 50 | try: 51 | res = np.full(tuple(data), cst) 52 | except TypeError as e: 53 | raise RuntimeError( 54 | f"Unable to create a constant of shape " 55 | f"{data!r} with value {cst!r} " 56 | f"(raw value={value!r})." 
57 | ) from e 58 | return (res,) 59 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_fused_matmul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnx.reference.op_run import OpRun 3 | 4 | 5 | class FusedMatMul(OpRun): 6 | op_domain = "com.microsoft" 7 | 8 | def _run( 9 | self, 10 | A, 11 | B, 12 | alpha: float = 1, 13 | transA: int = 0, 14 | transB: int = 0, 15 | transBatchA: int = 0, 16 | transBatchB: int = 0, 17 | ): 18 | assert ( 19 | transBatchA == 0 20 | ), f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}" 21 | assert ( 22 | transBatchB == 0 23 | ), f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}" 24 | if transA: 25 | perm = list(range(len(A.shape))) 26 | dim = len(perm) 27 | perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2] 28 | A = np.transpose(A, perm) 29 | if transB: 30 | perm = list(range(len(B.shape))) 31 | dim = len(perm) 32 | perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2] 33 | B = np.transpose(B, perm) 34 | a = np.array(alpha, dtype=A.dtype) 35 | return (np.matmul(A, B) * a,) 36 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_memcpy_host.py: -------------------------------------------------------------------------------- 1 | from onnx.reference.op_run import OpRun 2 | 3 | 4 | class MemcpyFromHost(OpRun): 5 | def _run(self, x): 6 | return (x,) 7 | 8 | 9 | class MemcpyToHost(OpRun): 10 | def _run(self, x): 11 | return (x,) 12 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_quick_gelu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnx.reference.op_run import OpRun 3 | 4 | 5 | def sigmoid(x): # type: ignore 6 | if x > 0: 7 | return 1 / (1 + np.exp(-x)) 8 | return 
np.exp(x) / (1 + np.exp(x)) 9 | 10 | 11 | class QuickGelu(OpRun): 12 | op_domain = "com.microsoft" 13 | 14 | def __init__(self, onnx_node, run_params): # type: ignore 15 | OpRun.__init__(self, onnx_node, run_params) 16 | self.vf = np.vectorize(sigmoid) 17 | 18 | def _run(self, X, alpha=1.0): 19 | if len(X.shape) == 0: 20 | return ((X * sigmoid(X * alpha)).astype(X.dtype),) 21 | if X.size == 0: 22 | return (X,) 23 | return ((X * self.vf(X * alpha)).astype(X.dtype),) 24 | -------------------------------------------------------------------------------- /onnx_array_api/reference/ops/op_scatter_elements.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from onnx.reference.op_run import OpRun 4 | 5 | 6 | def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore 7 | if reduction == "add": 8 | 9 | def f(x, y): 10 | return x + y 11 | 12 | elif reduction == "min": 13 | 14 | def f(x, y): 15 | return min(x, y) 16 | 17 | elif reduction == "max": 18 | 19 | def f(x, y): 20 | return max(x, y) 21 | 22 | else: 23 | 24 | def f(x, y): 25 | return y 26 | 27 | if axis < 0: 28 | axis = data.ndim + axis 29 | 30 | if len(data.shape) == 1 and axis == 0: 31 | scattered = np.copy(data) 32 | for pos, up in zip(indices, updates): 33 | scattered[pos] = f(scattered[pos], up) 34 | return scattered 35 | 36 | if len(indices.shape) == 2: 37 | scattered = np.copy(data) 38 | if axis == 0: 39 | for i in range(indices.shape[0]): 40 | for j in range(indices.shape[1]): 41 | scattered[indices[i, j], j] = f( 42 | scattered[indices[i, j], j], updates[i, j] 43 | ) 44 | else: 45 | for i in range(indices.shape[0]): 46 | for j in range(indices.shape[1]): 47 | scattered[i, indices[i, j]] = f( 48 | scattered[i, indices[i, j]], updates[i, j] 49 | ) 50 | return scattered 51 | 52 | if len(indices.shape) == 3: 53 | scattered = np.copy(data) 54 | if axis == 0: 55 | for i in range(indices.shape[0]): 56 | for j in 
range(indices.shape[1]): 57 | for k in range(indices.shape[2]): 58 | scattered[indices[i, j, k], j, k] = f( 59 | scattered[indices[i, j, k], j, k], updates[i, j, k] 60 | ) 61 | elif axis == 1: 62 | for i in range(indices.shape[0]): 63 | for j in range(indices.shape[1]): 64 | for k in range(indices.shape[2]): 65 | scattered[i, indices[i, j, k], k] = f( 66 | scattered[i, indices[i, j, k], k], updates[i, j, k] 67 | ) 68 | elif axis == 2: 69 | for i in range(indices.shape[0]): 70 | for j in range(indices.shape[1]): 71 | for k in range(indices.shape[2]): 72 | scattered[i, j, indices[i, j, k]] = f( 73 | scattered[i, j, indices[i, j, k]], updates[i, j, k] 74 | ) 75 | return scattered 76 | 77 | if len(indices.shape) == 4: 78 | scattered = np.copy(data) 79 | if axis == 3: 80 | for a in range(indices.shape[0]): 81 | for i in range(indices.shape[1]): 82 | for j in range(indices.shape[2]): 83 | for k in range(indices.shape[3]): 84 | scattered[a, i, j, indices[a, i, j, k]] = f( 85 | scattered[a, i, j, indices[a, i, j, k]], 86 | updates[a, i, j, k], 87 | ) 88 | return scattered 89 | 90 | raise RuntimeError( 91 | f"Not implemented for indices.shape={indices.shape} and axis={axis}" 92 | ) 93 | 94 | 95 | class ScatterElements(OpRun): 96 | def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore 97 | res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction) 98 | return (res,) 99 | -------------------------------------------------------------------------------- /onnx_array_api/tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /onnx_array_api/translate_api/__init__.py: -------------------------------------------------------------------------------- 1 | from onnx import ModelProto 2 | from .translate import Translater 3 | from .inner_emitter import InnerEmitter, InnerEmitterShortInitializer 4 | from 
.builder_emitter import BuilderEmitter 5 | 6 | 7 | def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str: 8 | """ 9 | Translates an ONNX proto into a code using :ref:`l-light-api` 10 | to describe the ONNX graph. 11 | 12 | :param proto: model to translate 13 | :param single_line: as a single line or not 14 | :param api: API to export into, 15 | default is `"light"` and this is handle by class 16 | :class:`onnx_array_api.translate_api.light_emitter.LightEmitter`, 17 | another value is `"onnx"` which is the inner API implemented 18 | in onnx package, `"builder"` follows the syntax for the 19 | class :class:`onnx_array_api.graph_api.GraphBuilder`, 20 | `"onnx-short"` replaces long initializer with random values 21 | :return: code 22 | 23 | .. runpython:: 24 | :showcode: 25 | 26 | from onnx_array_api.light_api import start 27 | from onnx_array_api.translate_api import translate 28 | 29 | onx = ( 30 | start() 31 | .vin("X") 32 | .reshape((-1, 1)) 33 | .Transpose(perm=[1, 0]) 34 | .rename("Y") 35 | .vout() 36 | .to_onnx() 37 | ) 38 | code = translate(onx) 39 | print(code) 40 | 41 | The inner API from onnx package is also available. 42 | 43 | .. runpython:: 44 | :showcode: 45 | 46 | from onnx_array_api.light_api import start 47 | from onnx_array_api.translate_api import translate 48 | 49 | onx = ( 50 | start() 51 | .vin("X") 52 | .reshape((-1, 1)) 53 | .Transpose(perm=[1, 0]) 54 | .rename("Y") 55 | .vout() 56 | .to_onnx() 57 | ) 58 | code = translate(onx, api="onnx") 59 | print(code) 60 | 61 | The :class:`GraphBuilder 62 | ` API returns this: 63 | 64 | .. 
runpython:: 65 | :showcode: 66 | 67 | from onnx_array_api.light_api import start 68 | from onnx_array_api.translate_api import translate 69 | 70 | onx = ( 71 | start() 72 | .vin("X") 73 | .reshape((-1, 1)) 74 | .Transpose(perm=[1, 0]) 75 | .rename("Y") 76 | .vout() 77 | .to_onnx() 78 | ) 79 | code = translate(onx, api="builder") 80 | print(code) 81 | """ 82 | if api == "light": 83 | tr = Translater(proto) 84 | return tr.export(single_line=single_line, as_str=True) 85 | if api == "onnx": 86 | tr = Translater(proto, emitter=InnerEmitter()) 87 | return tr.export(as_str=True) 88 | if api == "onnx-short": 89 | tr = Translater(proto, emitter=InnerEmitterShortInitializer()) 90 | return tr.export(as_str=True) 91 | if api == "builder": 92 | tr = Translater(proto, emitter=BuilderEmitter()) 93 | return tr.export(as_str=True) 94 | raise ValueError(f"Unexpected value {api!r} for api.") 95 | -------------------------------------------------------------------------------- /onnx_array_api/translate_api/light_emitter.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from ..annotations import ELEMENT_TYPE_NAME 3 | from .base_emitter import BaseEmitter 4 | 5 | 6 | class LightEmitter(BaseEmitter): 7 | """ 8 | Converts event into proper code. 
9 | """ 10 | 11 | def join(self, rows: List[str], single_line: bool = False) -> str: 12 | "Join the rows" 13 | if single_line: 14 | return ".".join(rows) 15 | return "".join(["(\n ", "\n .".join(rows), "\n)"]) 16 | 17 | def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: 18 | opsets = kwargs.get("opsets", {}) 19 | opset = opsets.get("", None) 20 | if opset is not None: 21 | del opsets[""] 22 | args = [] 23 | if opset: 24 | args.append(f"opset={opset}") 25 | if opsets: 26 | args.append(f"opsets={opsets}") 27 | return [f"start({', '.join(args)})"] 28 | 29 | def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: 30 | return ["to_onnx()"] 31 | 32 | def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: 33 | return [] 34 | 35 | def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: 36 | return [] 37 | 38 | def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: 39 | return [] 40 | 41 | def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: 42 | name = kwargs["name"] 43 | value = kwargs["value"] 44 | repl = {"bool": "bool_", "object": "object_", "str": "str_"} 45 | sdtype = repl.get(str(value.dtype), str(str(value.dtype))) 46 | return [ 47 | f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))", 48 | f"rename({name!r})", 49 | ] 50 | 51 | def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: 52 | name = kwargs["name"] 53 | elem_type = kwargs.get("elem_type", None) 54 | shape = kwargs.get("shape", None) 55 | if elem_type and shape: 56 | return [ 57 | f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, " 58 | f"shape={shape!r})" 59 | ] 60 | if elem_type: 61 | return [ 62 | f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" 63 | ] 64 | return [f"vin({name!r})"] 65 | 66 | def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: 67 | inst = [] 68 | if "name" in kwargs: 69 | name = kwargs["name"] 70 | inst.append(f"bring({name!r})") 71 | elem_type = 
kwargs.get("elem_type", None) 72 | shape = kwargs.get("shape", None) 73 | if elem_type and shape: 74 | inst.append( 75 | f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, " 76 | f"shape={shape!r})" 77 | ) 78 | elif elem_type: 79 | inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})") 80 | else: 81 | inst.append("vout()") 82 | return inst 83 | 84 | def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: 85 | op_type = kwargs["op_type"] 86 | inputs = kwargs["inputs"] 87 | outputs = kwargs["outputs"] 88 | if kwargs.get("domain", "") != "": 89 | domain = kwargs["domain"] 90 | op_type = f"{domain}.{op_type}" 91 | atts = kwargs.get("atts", {}) 92 | args = [] 93 | for k, v in atts.items(): 94 | before, vatt = self.render_attribute_value(v) 95 | if before: 96 | raise NotImplementedError("Graph attribute not supported yet.") 97 | args.append(f"{k}={vatt}") 98 | 99 | str_inputs = ", ".join([f"{i!r}" for i in inputs]) 100 | inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] 101 | if len(outputs) == 1: 102 | inst.append(f"rename({outputs[0]!r})") 103 | else: 104 | str_outputs = ", ".join([f"{o!r}" for o in outputs]) 105 | inst.append(f"rename({str_outputs})") 106 | return inst 107 | -------------------------------------------------------------------------------- /onnx_array_api/translate_api/make_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Sequence 2 | from onnx import AttributeProto, NodeProto 3 | from onnx.helper import make_attribute 4 | 5 | 6 | def make_ref_attribute( 7 | key: str, attr_type: int, ref_attr_name: Optional[str] = None 8 | ) -> AttributeProto: 9 | """ 10 | Creates an attribute. 
11 | 12 | :param key: atttribute name 13 | :param attr_type: attribute type 14 | :param ref_attr_name: if not None, link this attribute 15 | to a function attribute 16 | :return: attribute 17 | """ 18 | att = AttributeProto() 19 | att.name = key 20 | att.type = attr_type 21 | att.ref_attr_name = ref_attr_name 22 | return att 23 | 24 | 25 | def make_node_extended( 26 | op_type: str, 27 | inputs: Sequence[str], 28 | outputs: Sequence[str], 29 | name: Optional[str] = None, 30 | doc_string: Optional[str] = None, 31 | domain: Optional[str] = None, 32 | **kwargs: Any, 33 | ) -> NodeProto: 34 | """ 35 | Constructs a NodeProto. 36 | 37 | :param op_type: The name of the operator to construct 38 | :param inputs: list of input names 39 | :param outputs: list of output names 40 | :param name: optional unique identifier for NodeProto 41 | :param doc_string: optional documentation string for NodeProto 42 | :param domain: optional domain for NodeProto. 43 | If it's None, we will just use default domain (which is empty) 44 | :param kwargs: the attributes of the node. 
45 | :return: node proto 46 | """ 47 | node = NodeProto() 48 | node.op_type = op_type 49 | node.input.extend(inputs) 50 | node.output.extend(outputs) 51 | if name: 52 | node.name = name 53 | if doc_string: 54 | node.doc_string = doc_string 55 | if domain is not None: 56 | node.domain = domain 57 | if kwargs: 58 | for key, value in sorted(kwargs.items()): 59 | if value is None: 60 | continue 61 | if isinstance(value, AttributeProto): 62 | node.attribute.append(value) 63 | else: 64 | node.attribute.append(make_attribute(key, value)) 65 | return node 66 | -------------------------------------------------------------------------------- /onnx_array_api/validation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /onnx_array_api/validation/diff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import difflib 3 | import textwrap 4 | from typing import Union 5 | from onnx import ModelProto 6 | 7 | 8 | def _get_diff_template(): 9 | import jinja2 10 | 11 | tpl = textwrap.dedent( 12 | """ 13 |
14 | 15 | 17 | 42 | """ 43 | ) 44 | path = os.path.abspath(os.path.dirname(__file__)) 45 | path = path.replace("\\", "/") 46 | path = f"file://{path}" 47 | tpl = tpl.replace("__PATH__", path) 48 | return jinja2.Template(tpl, autoescape=True) 49 | 50 | 51 | def text_diff(text1: Union[ModelProto, str], text2: Union[ModelProto, str]) -> str: 52 | """ 53 | Produces a string showing the differences between 54 | two strings. 55 | 56 | :param text1: first string, or a ModelProto rendered with onnx_simple_text_plot 57 | :param text2: second string, or a ModelProto rendered with onnx_simple_text_plot 58 | :return: differences, the lines produced by difflib.Differ joined together 59 | """ 60 | if not isinstance(text1, str): 61 | from ..plotting.text_plot import onnx_simple_text_plot 62 | 63 | text1 = onnx_simple_text_plot(text1, indent=False) 64 | if not isinstance(text2, str): 65 | from ..plotting.text_plot import onnx_simple_text_plot 66 | 67 | text2 = onnx_simple_text_plot(text2, indent=False) 68 | differ = difflib.Differ() 69 | result = list( 70 | differ.compare(text1.splitlines(keepends=True), text2.splitlines(keepends=True)) 71 | ) 72 | raw = "".join(result) 73 | return raw 74 | 75 | 76 | def html_diff( 77 | text1: Union[ModelProto, str], 78 | text2: Union[ModelProto, str], 79 | title: str = "html_diff", 80 | div_name: str = "div_name", 81 | header: bool = True, 82 | ) -> str: 83 | """ 84 | Produces an HTML string showing the differences between 85 | two strings.
86 | 87 | :param text1: first string or ModelProto, compared with text_diff 88 | :param text2: second string or ModelProto, compared with text_diff 89 | :param title: title passed to the HTML template 90 | :param div_name: html format, section name, rendered as ``div_<div_name>`` 91 | :param header: if True, add header and html main tags (NOTE(review): ``header`` is not referenced in the body shown here, confirm) 92 | :return: differences rendered with the HTML template 93 | """ 94 | raw = text_diff(text1, text2) 95 | diff = _get_diff_template().render( 96 | title=title, 97 | version1=text1, 98 | version2=text2, 99 | div_name=f"div_{div_name}", 100 | diff_content=raw, 101 | ) 102 | return f"\n{diff}\n\n" 103 | -------------------------------------------------------------------------------- /onnx_array_api/validation/diff2html.min.css: -------------------------------------------------------------------------------- 1 | .d2h-wrapper{text-align:left}.d2h-file-header{background-color:#f7f7f7;border-bottom:1px solid #d8d8d8;font-family:Source Sans Pro,Helvetica Neue,Helvetica,Arial,sans-serif;height:35px;padding:5px 10px}.d2h-file-header,.d2h-file-stats{display:-webkit-box;display:-ms-flexbox;display:flex}.d2h-file-stats{font-size:14px;margin-left:auto}.d2h-lines-added{border:1px solid #b4e2b4;border-radius:5px 0 0 5px;color:#399839;padding:2px;text-align:right;vertical-align:middle}.d2h-lines-deleted{border:1px solid #e9aeae;border-radius:0 5px 5px 0;color:#c33;margin-left:1px;padding:2px;text-align:left;vertical-align:middle}.d2h-file-name-wrapper{-webkit-box-align:center;-ms-flex-align:center;align-items:center;display:-webkit-box;display:-ms-flexbox;display:flex;font-size:15px;width:100%}.d2h-file-name{overflow-x:hidden;text-overflow:ellipsis;white-space:nowrap}.d2h-file-wrapper{border:1px solid #ddd;border-radius:3px;margin-bottom:1em}.d2h-file-collapse{-webkit-box-pack:end;-ms-flex-pack:end;-webkit-box-align:center;-ms-flex-align:center;align-items:center;border:1px solid #ddd;border-radius:3px;cursor:pointer;display:none;font-size:12px;justify-content:flex-end;padding:4px 8px}.d2h-file-collapse.d2h-selected{background-color:#c8e1ff}.d2h-file-collapse-input{margin:0 4px 0
0}.d2h-diff-table{border-collapse:collapse;font-family:Menlo,Consolas,monospace;font-size:13px;width:100%}.d2h-files-diff{display:-webkit-box;display:-ms-flexbox;display:flex;width:100%}.d2h-file-diff{overflow-y:hidden}.d2h-file-diff.d2h-d-none,.d2h-files-diff.d2h-d-none{display:none}.d2h-file-side-diff{display:inline-block;margin-bottom:-8px;margin-right:-4px;overflow-x:scroll;overflow-y:hidden;width:50%}.d2h-code-line{padding:0 8em}.d2h-code-line,.d2h-code-side-line{display:inline-block;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none;white-space:nowrap;width:100%}.d2h-code-side-line{padding:0 4.5em}.d2h-code-line-ctn{word-wrap:normal;background:none;display:inline-block;padding:0;-webkit-user-select:text;-moz-user-select:text;-ms-user-select:text;user-select:text;vertical-align:middle;white-space:pre;width:100%}.d2h-code-line del,.d2h-code-side-line del{background-color:#ffb6ba}.d2h-code-line del,.d2h-code-line ins,.d2h-code-side-line del,.d2h-code-side-line ins{border-radius:.2em;display:inline-block;margin-top:-1px;text-decoration:none;vertical-align:middle}.d2h-code-line ins,.d2h-code-side-line ins{background-color:#97f295;text-align:left}.d2h-code-line-prefix{word-wrap:normal;background:none;display:inline;padding:0;white-space:pre}.line-num1{float:left}.line-num1,.line-num2{-webkit-box-sizing:border-box;box-sizing:border-box;overflow:hidden;padding:0 .5em;text-overflow:ellipsis;width:3.5em}.line-num2{float:right} 2 | .d2h-code-linenumber{background-color:var(--pst-color-background);border:solid #eee;border-width:0 1px;-webkit-box-sizing:border-box;box-sizing:border-box;color:var(--pst-color-text-base);cursor:pointer;display:inline-block;position:absolute;text-align:right;width:7.5em} 3 | .d2h-code-linenumber:after{content:"\200b"} 4 | .d2h-code-side-linenumber{background-color:var(--pst-color-background);border:solid #eee;border-width:0 
1px;-webkit-box-sizing:border-box;box-sizing:border-box;color:var(--pst-color-text-base);cursor:pointer;display:inline-block;overflow:hidden;padding:0 .5em;position:absolute;text-align:right;text-overflow:ellipsis;width:4em} 5 | .d2h-code-side-linenumber:after{content:"\200b"}.d2h-code-side-emptyplaceholder,.d2h-emptyplaceholder{background-color:#f1f1f1;border-color:#e1e1e1}.d2h-code-line-prefix, 6 | .d2h-code-linenumber,.d2h-code-side-linenumber,.d2h-emptyplaceholder{-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none} 7 | .d2h-code-linenumber,.d2h-code-side-linenumber{direction:rtl}.d2h-del{background-color:#fee8e9;border-color:#e9aeae}.d2h-ins{background-color:#dfd;border-color:#b4e2b4}.d2h-info{background-color:#f8fafd;border-color:#d5e4f2;color:rgba(0,0,0,.3)}.d2h-file-diff .d2h-del.d2h-change{background-color:#fdf2d0}.d2h-file-diff .d2h-ins.d2h-change{background-color:#ded}.d2h-file-list-wrapper{margin-bottom:10px}.d2h-file-list-wrapper a{color:#3572b0;text-decoration:none}.d2h-file-list-wrapper a:visited{color:#3572b0}.d2h-file-list-header{text-align:left}.d2h-file-list-title{font-weight:700}.d2h-file-list-line{display:-webkit-box;display:-ms-flexbox;display:flex;text-align:left}.d2h-file-list{display:block;list-style:none;margin:0;padding:0}.d2h-file-list>li{border-bottom:1px solid #ddd;margin:0;padding:5px 10px}.d2h-file-list>li:last-child{border-bottom:none}.d2h-file-switch{cursor:pointer;display:none;font-size:10px}.d2h-icon{fill:currentColor;margin-right:10px;vertical-align:middle}.d2h-deleted{color:#c33}.d2h-added{color:#399839}.d2h-changed{color:#d0b44c}.d2h-moved{color:#3572b0}.d2h-tag{background-color:#fff;display:-webkit-box;display:-ms-flexbox;display:flex;font-size:10px;margin-left:5px;padding:0 2px}.d2h-deleted-tag{border:1px solid #c33}.d2h-added-tag{border:1px solid #399839}.d2h-changed-tag{border:1px solid #d0b44c}.d2h-moved-tag{border:1px solid #3572b0} 
-------------------------------------------------------------------------------- /onnx_array_api/validation/docs.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import numpy as np 3 | import onnx 4 | import onnx.helper as oh 5 | 6 | 7 | def make_euclidean( 8 | input_names: Tuple[str] = ("X", "Y"), 9 | output_name: str = "Z", 10 | elem_type: int = onnx.TensorProto.FLOAT, 11 | opset: Optional[int] = None, 12 | ) -> onnx.ModelProto: 13 | """ 14 | Creates the onnx graph corresponding to the euclidean distance. 15 | 16 | :param input_names: names of the inputs 17 | :param output_name: name of the output 18 | :param elem_type: onnx is strongly types, which type is it? 19 | :param opset: opset version 20 | :return: onnx.ModelProto 21 | """ 22 | if opset is None: 23 | opset = onnx.defs.onnx_opset_version() 24 | 25 | X = oh.make_tensor_value_info(input_names[0], elem_type, None) 26 | Y = oh.make_tensor_value_info(input_names[1], elem_type, None) 27 | Z = oh.make_tensor_value_info(output_name, elem_type, None) 28 | two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2]) 29 | n1 = oh.make_node("Sub", ["X", "Y"], ["dxy"]) 30 | n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"]) 31 | n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name]) 32 | graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two]) 33 | model = oh.make_model( 34 | graph, opset_imports=[oh.make_opsetid("", opset)], ir_version=9 35 | ) 36 | return model 37 | 38 | 39 | def make_euclidean_skl2onnx( 40 | input_names: Tuple[str] = ("X", "Y"), 41 | output_name: str = "Z", 42 | elem_type: int = onnx.TensorProto.FLOAT, 43 | opset: Optional[int] = None, 44 | ) -> onnx.ModelProto: 45 | """ 46 | Creates the onnx graph corresponding to the euclidean distance 47 | with :epkg:`sklearn-onnx`. 
48 | 49 | :param input_names: names of the inputs 50 | :param output_name: name of the output 51 | :param elem_type: onnx is strongly types, which type is it? 52 | :param opset: opset version 53 | :return: onnx.ModelProto 54 | """ 55 | if opset is None: 56 | opset = onnx.defs.onnx_opset_version() 57 | 58 | from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum 59 | 60 | dxy = OnnxSub(input_names[0], input_names[1], op_version=opset) 61 | dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset) 62 | final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name]) 63 | 64 | np_type = oh.tensor_dtype_to_np_dtype(elem_type) 65 | dummy = np.empty([1], np_type) 66 | return final.to_onnx({"X": dummy, "Y": dummy}) 67 | -------------------------------------------------------------------------------- /onnx_array_api/validation/tools.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import numpy 3 | from onnx import ( 4 | AttributeProto, 5 | GraphProto, 6 | FunctionProto, 7 | ModelProto, 8 | NodeProto, 9 | TensorProto, 10 | ) 11 | from onnx.helper import ( 12 | make_attribute, 13 | make_function, 14 | make_graph, 15 | make_model, 16 | make_node, 17 | set_model_props, 18 | ) 19 | from ..reference import from_array_extended as from_array, to_array_extended as to_array 20 | 21 | 22 | def randomize_proto( 23 | onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto], 24 | ) -> Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]: 25 | """ 26 | Randomizes float initializers or constant nodes. 
27 | 28 | :param onx: onnx model or proto 29 | :return: same object 30 | """ 31 | if isinstance(onx, TensorProto): 32 | t = to_array(onx) 33 | mini, maxi = t.min(), t.max() 34 | new_t = numpy.clip( 35 | numpy.random.random(t.shape) * (maxi - mini) + mini, mini, maxi 36 | ) 37 | return from_array(new_t.astype(t.dtype), name=onx.name) 38 | 39 | if isinstance(onx, ModelProto): 40 | new_graph = randomize_proto(onx.graph) 41 | new_functions = [randomize_proto(f) for f in onx.functions] 42 | 43 | onnx_model = make_model( 44 | new_graph, 45 | functions=new_functions, 46 | ir_version=onx.ir_version, 47 | producer_name=onx.producer_name, 48 | domain=onx.domain, 49 | doc_string=onx.doc_string, 50 | opset_imports=list(onx.opset_import), 51 | ) 52 | if onx.metadata_props: 53 | values = {p.key: p.value for p in onx.metadata_props} 54 | set_model_props(onnx_model, values) 55 | return onnx_model 56 | 57 | if isinstance(onx, (GraphProto, FunctionProto)): 58 | nodes = [] 59 | for node in onx.node: 60 | if node.op_type in "Constant": 61 | nodes.append(randomize_proto(node)) 62 | continue 63 | changed = False 64 | atts = [] 65 | for att in node.attribute: 66 | if att.type == AttributeProto.GRAPH: 67 | new_g = randomize_proto(att.g) 68 | att = make_attribute(att.name, new_g) 69 | changed = True 70 | atts.append(att) 71 | if changed: 72 | new_node = make_node( 73 | node.op_type, node.input, node.output, domain=node.domain 74 | ) 75 | new_node.attribute.extend(node.attribute) 76 | nodes.append(new_node) 77 | continue 78 | nodes.append(node) 79 | 80 | if isinstance(onx, FunctionProto): 81 | new_onx = make_function( 82 | onx.domain, 83 | onx.name, 84 | onx.input, 85 | onx.output, 86 | nodes, 87 | opset_imports=onx.opset_import, 88 | ) 89 | return new_onx 90 | 91 | inits = [randomize_proto(init) for init in onx.initializer] 92 | sp_inits = [randomize_proto(init) for init in onx.sparse_initializer] 93 | 94 | graph = make_graph( 95 | nodes, 96 | onx.name, 97 | onx.input, 98 | onx.output, 99 
| initializer=inits, 100 | sparse_initializer=sp_inits, 101 | ) 102 | return graph 103 | 104 | raise TypeError(f"Unexpected type for onx {type(onx)}.") 105 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | 3 | # Exclude a variety of commonly ignored directories. 4 | exclude = [ 5 | ".eggs", 6 | ".git", 7 | "build", 8 | "dist", 9 | ] 10 | 11 | # Same as Black. 12 | line-length = 88 13 | 14 | [tool.ruff.lint] 15 | select = [ 16 | "B", # flake8-bugbear 17 | "C4", # flake8-comprehensions 18 | #"D", # pydocstyle 19 | "E", # pycodestyle 20 | "F", # Pyflakes 21 | "G", # flake8-logging-format 22 | #"I", # isort 23 | "ISC", # flake8-implicit-str-concat 24 | "LOG", # flake8-logging 25 | #"N", # pep8-naming 26 | #"NPY", # modern numpy 27 | #"PERF", # Perflint 28 | "PIE", # flake8-pie 29 | "PYI", # flake8-pyi 30 | "RUF", # Ruff-specific rules 31 | "SIM", # flake8-simplify 32 | "SLOT", # flake8-slot 33 | "T10", # flake8-debugger 34 | #"TID", # Disallow relative imports 35 | #"TRY", # flake8-try-except-raise 36 | "UP", # pyupgrade 37 | "W", # pycodestyle 38 | "YTT", # flake8-2020 39 | ] 40 | 41 | [tool.ruff.lint.per-file-ignores] 42 | "**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"] 43 | "**/plot*.py" = ["B018"] 44 | "_doc/examples/plot_first_example.py" = ["E402", "F811"] 45 | "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"] 46 | "onnx_array_api/array_api/_onnx_common.py" = ["F821"] 47 | "onnx_array_api/graph_api/__init__.py" = ["F401"] 48 | "onnx_array_api/light_api/__init__.py" = ["F401"] 49 | "onnx_array_api/light_api/_op_var.py" = ["F821"] 50 | "onnx_array_api/light_api/_op_vars.py" = ["F821"] 51 | "onnx_array_api/annotations.py" = ["F821"] 52 | 
"onnx_array_api/light_api/model.py" = ["F821"] 53 | "onnx_array_api/translate_api/__init__.py" = ["F401"] 54 | "onnx_array_api/npx/__init__.py" = ["F401", "F403"] 55 | "onnx_array_api/npx/npx_functions.py" = ["F821"] 56 | "onnx_array_api/npx/npx_functions_test.py" = ["F821"] 57 | "onnx_array_api/npx/npx_tensors.py" = ["F821"] 58 | "onnx_array_api/npx/npx_var.py" = ["F821"] 59 | "onnx_array_api/profiling.py" = ["E731"] 60 | "onnx_array_api/reference/__init__.py" = ["F401"] 61 | "_unittests/ut_npx/test_npx.py" = ["F821"] 62 | "_unittests/ut_translate_api/test_translate_classic.py" = ["E501"] 63 | 64 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | array_api_compat 2 | array_api_strict 3 | autopep8 4 | black 5 | coverage 6 | flake8 7 | furo 8 | google-re2 9 | hypothesis 10 | isort 11 | joblib 12 | lightgbm 13 | matplotlib 14 | ml-dtypes 15 | git+https://github.com/onnx/onnxmltools.git 16 | onnxruntime>=1.17.0 17 | openpyxl 18 | packaging 19 | pandas 20 | Pillow 21 | psutil 22 | pytest 23 | pytest-cov 24 | sphinx-issues 25 | git+https://github.com/sdpython/sphinx-runpython.git 26 | ruff 27 | scikit-learn>=1.3.2 28 | git+https://github.com/onnx/sklearn-onnx.git 29 | sphinx 30 | sphinx-gallery 31 | tomli 32 | tqdm 33 | wheel 34 | xgboost 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx>=1.15.0 3 | scipy 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [options] 2 | packages = find: 3 | 4 | [options.packages.find] 5 | include = onnx_array_api* 6 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup 4 | 5 | ###################### 6 | # beginning of setup 7 | ###################### 8 | 9 | 10 | here = os.path.dirname(__file__) 11 | if here == "": 12 | here = "." 13 | package_data = {"onnx_array_api.validation": ["*.css", "*.js"]} 14 | 15 | try: 16 | with open(os.path.join(here, "requirements.txt"), "r") as f: 17 | requirements = f.read().strip(" \n\r\t").split("\n") 18 | except FileNotFoundError: 19 | requirements = [] 20 | if not requirements or requirements == [""]: 21 | requirements = ["numpy", "scipy", "onnx"] 22 | 23 | try: 24 | with open(os.path.join(here, "README.rst"), "r", encoding="utf-8") as f: 25 | long_description = "onnx-array-api:" + f.read().split("onnx-array-api:")[1] 26 | except FileNotFoundError: 27 | long_description = "" 28 | 29 | version_str = "0.1.0" 30 | with open(os.path.join(here, "onnx_array_api/__init__.py"), "r") as f: 31 | line = [ 32 | _ 33 | for _ in [_.strip("\r\n ") for _ in f.readlines()] 34 | if _.startswith("__version__") 35 | ] 36 | if line: 37 | version_str = line[0].split("=")[1].strip('" ') 38 | 39 | 40 | setup( 41 | name="onnx-array-api", 42 | version=version_str, 43 | description="Array (and numpy) API for ONNX", 44 | long_description=long_description, 45 | author="Xavier Dupré", 46 | author_email="xavier.dupre@gmail.com", 47 | url="https://github.com/sdpython/onnx-array-api", 48 | package_data=package_data, 49 | setup_requires=["numpy", "scipy"], 50 | install_requires=requirements, 51 | classifiers=[ 52 | "Intended Audience :: Science/Research", 53 | "Intended Audience :: Developers", 54 | "License :: OSI Approved :: MIT License", 55 | "Programming Language :: C", 56 | "Programming Language :: Python", 57 | "Topic :: Software Development", 58 | "Topic :: Scientific/Engineering", 59 | "Development Status :: 5 - Production/Stable", 60 | "Operating System :: Microsoft :: Windows", 61 | 
"Operating System :: POSIX", 62 | "Operating System :: Unix", 63 | "Operating System :: MacOS", 64 | "Programming Language :: Python :: 3", 65 | "Programming Language :: Python :: 3.9", 66 | "Programming Language :: Python :: 3.10", 67 | "Programming Language :: Python :: 3.11", 68 | "Programming Language :: Python :: 3.12", 69 | "Programming Language :: Python :: 3.13", 70 | ], 71 | ) 72 | --------------------------------------------------------------------------------