├── .github
└── workflows
│ ├── black-ruff.yml
│ ├── check-urls.yml
│ ├── codeql.yml
│ ├── documentation.yml
│ └── wheels-any.yml
├── .gitignore
├── CHANGELOGS.rst
├── CODE_OF_CONDUCT.md
├── LICENSE.txt
├── MANIFEST.in
├── README.rst
├── _doc
├── _static
│ └── logo.png
├── api
│ ├── array_api.rst
│ ├── array_api_numpy.rst
│ ├── array_api_ort.rst
│ ├── docs.rst
│ ├── f8.rst
│ ├── graph_api.rst
│ ├── index.rst
│ ├── light_api.rst
│ ├── npx_array_api.rst
│ ├── npx_core_api.rst
│ ├── npx_functions.rst
│ ├── npx_jit_eager.rst
│ ├── npx_numpy.rst
│ ├── npx_tensors.rst
│ ├── npx_types.rst
│ ├── npx_var.rst
│ ├── onnx_tools.rst
│ ├── ort.rst
│ ├── plotting.rst
│ ├── profiling.rst
│ ├── reference.rst
│ ├── tools.rst
│ └── translate_api.rst
├── command_lines.rst
├── conf.py
├── examples
│ ├── README.txt
│ ├── data
│ │ └── small.onnx
│ ├── plot_benchmark_rf.py
│ ├── plot_f8.py
│ ├── plot_first_example.py
│ ├── plot_onnx_diff.py
│ ├── plot_onnxruntime.py
│ ├── plot_optimization.py
│ └── plot_profiling.py
├── index.rst
├── license.rst
├── long_outputs.rst
├── run_coverage.sh
├── tech
│ ├── aapi.rst
│ └── index.rst
└── tutorial
│ ├── benchmarks.rst
│ ├── graph_api.rst
│ ├── index.rst
│ ├── light_api.rst
│ ├── numpy_api.rst
│ ├── onnx_api.rst
│ └── tools.rst
├── _unittests
├── onnx-numpy-skips.txt
├── onnx-ort-skips.txt
├── test_array_api.sh
├── ut_array_api
│ ├── test_array_apis.py
│ ├── test_hypothesis_array_api.py
│ ├── test_onnx_numpy.py
│ └── test_onnx_ort.py
├── ut_graph_api
│ ├── data
│ │ └── debug_7951-CPUep.0.onnx
│ ├── test_graph_builder.py
│ └── test_graph_builder_optim.py
├── ut_light_api
│ ├── test_backend_export.py
│ └── test_light_api.py
├── ut_npx
│ ├── test_npx.py
│ └── test_sklearn_array_api.py
├── ut_ort
│ ├── data
│ │ ├── prof_base.xlsx
│ │ └── prof_opti.xlsx
│ ├── test_ort_optimizer.py
│ ├── test_ort_profile.py
│ ├── test_ort_tensor.py
│ └── test_sklearn_array_api_ort.py
├── ut_plotting
│ ├── data
│ │ ├── bug_Hardmax.onnx
│ │ ├── onnx_text_plot_tree_cls_2.onnx
│ │ ├── prof.csv
│ │ └── tree_torch.onnx
│ ├── test_dot_plot.py
│ ├── test_graphviz.py
│ ├── test_stat_plot.py
│ └── test_text_plot.py
├── ut_reference
│ ├── test_array_tensor.py
│ ├── test_backend_extended_reference_evaluator.py
│ ├── test_evaluator_yield.py
│ └── test_reference_ops.py
├── ut_tools
│ └── test_replace_constants.py
├── ut_translate_api
│ ├── _data
│ │ ├── custom_ops_type_inference_fails_0.onnx
│ │ └── stft_inlined_batch_1.onnx
│ ├── test_translate.py
│ ├── test_translate_builder.py
│ └── test_translate_classic.py
├── ut_validation
│ ├── data
│ │ └── small.onnx
│ ├── test_diff.py
│ ├── test_docs.py
│ ├── test_f8.py
│ └── test_tools.py
├── ut_xrun_doc
│ ├── test_command_lines1.py
│ ├── test_documentation_examples.py
│ └── test_profiling.py
└── win_test_array_api.bat
├── azure-pipelines.yml
├── onnx_array_api
├── __init__.py
├── __main__.py
├── _command_lines_parser.py
├── _helpers.py
├── annotations.py
├── array_api
│ ├── __init__.py
│ ├── _onnx_common.py
│ ├── onnx_numpy.py
│ └── onnx_ort.py
├── cache.py
├── ext_test_case.py
├── graph_api
│ ├── __init__.py
│ └── graph_builder.py
├── light_api
│ ├── __init__.py
│ ├── _op_var.py
│ ├── _op_vars.py
│ ├── model.py
│ └── var.py
├── npx
│ ├── __init__.py
│ ├── npx_array_api.py
│ ├── npx_constants.py
│ ├── npx_core_api.py
│ ├── npx_function_implementation.py
│ ├── npx_functions.py
│ ├── npx_functions_test.py
│ ├── npx_graph_builder.py
│ ├── npx_helper.py
│ ├── npx_jit_eager.py
│ ├── npx_numpy_tensors.py
│ ├── npx_tensors.py
│ ├── npx_types.py
│ └── npx_var.py
├── ort
│ ├── __init__.py
│ ├── ort_optimizers.py
│ ├── ort_profile.py
│ └── ort_tensors.py
├── plotting
│ ├── __init__.py
│ ├── _helper.py
│ ├── dot_plot.py
│ ├── graphviz_helper.py
│ ├── stat_plot.py
│ └── text_plot.py
├── profiling.py
├── reference
│ ├── __init__.py
│ ├── evaluator.py
│ ├── evaluator_yield.py
│ └── ops
│ │ ├── __init__.py
│ │ ├── op_cast_like.py
│ │ ├── op_concat.py
│ │ ├── op_constant_of_shape.py
│ │ ├── op_fused_matmul.py
│ │ ├── op_memcpy_host.py
│ │ ├── op_quick_gelu.py
│ │ └── op_scatter_elements.py
├── tools
│ ├── __init__.py
│ └── replace_constants.py
├── translate_api
│ ├── __init__.py
│ ├── base_emitter.py
│ ├── builder_emitter.py
│ ├── inner_emitter.py
│ ├── light_emitter.py
│ ├── make_helper.py
│ └── translate.py
└── validation
│ ├── __init__.py
│ ├── diff.py
│ ├── diff2html-ui-slim.min.js
│ ├── diff2html.min.css
│ ├── docs.py
│ ├── f8.py
│ └── tools.py
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── setup.cfg
└── setup.py
/.github/workflows/black-ruff.yml:
--------------------------------------------------------------------------------
1 | name: Black + Ruff Format Checker
2 | on: [push, pull_request]
3 | jobs:
4 | black-format-check:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v3
8 | - uses: psf/black@stable
9 | with:
10 | options: "--diff --check"
11 | src: "."
12 | ruff-format-check:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - uses: chartboost/ruff-action@v1
17 |
--------------------------------------------------------------------------------
/.github/workflows/check-urls.yml:
--------------------------------------------------------------------------------
1 | name: Check URLs
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 | schedule:
7 | # ┌───────────── minute (0 - 59)
8 | # │ ┌───────────── hour (0 - 23)
9 | # │ │ ┌───────────── day of the month (1 - 31)
10 | # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
11 | # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
12 | # │ │ │ │ │
13 | # │ │ │ │ │
14 | # │ │ │ │ │
15 | # * * * * *
16 | - cron: '30 1 * * 0'
17 |
18 | jobs:
19 | build:
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v3
24 |
25 | - name: urls-checker-code
26 | uses: urlstechie/urlchecker-action@master
27 | with:
28 | subfolder: onnx_array_api
29 | file_types: .md,.py,.rst,.ipynb
30 | print_all: false
31 | timeout: 2
32 | # retry_count: 2
33 | # exclude_urls: https://dumps.wikimedia.org/other/pageviews/%Y/%Y-%m/pageviews-%Y%m%d-%H0000.gz,https://dumps.wikimedia.org/frwiki/latest/latest-all-titles-in-ns0.gz
34 | exclude_patterns: https://dumps.wikimedia.org/
35 | # force_pass : true
36 |
37 | - name: urls-checker-docs
38 | uses: urlstechie/urlchecker-action@master
39 | with:
40 | subfolder: _doc
41 | file_types: .md,.py,.rst,.ipynb
42 | print_all: false
43 | timeout: 2
44 | # retry_count: 2
45 | exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
46 | exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
47 | # force_pass : true
48 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | name: "Code Scanning - Action"
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | schedule:
9 | # ┌───────────── minute (0 - 59)
10 | # │ ┌───────────── hour (0 - 23)
11 | # │ │ ┌───────────── day of the month (1 - 31)
12 | # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
13 | # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
14 | # │ │ │ │ │
15 | # │ │ │ │ │
16 | # │ │ │ │ │
17 | # * * * * *
18 | - cron: '30 1 * * 0'
19 |
20 | jobs:
21 | CodeQL-Build:
22 | # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest
23 | runs-on: ubuntu-latest
24 |
25 | permissions:
26 | # required for all workflows
27 | security-events: write
28 |
29 | # only required for workflows in private repositories
30 | actions: read
31 | contents: read
32 |
33 | steps:
34 | - name: Checkout repository
35 | uses: actions/checkout@v3
36 |
37 | # Initializes the CodeQL tools for scanning.
38 | - name: Initialize CodeQL
39 | uses: github/codeql-action/init@v2
40 | # Override language selection by uncommenting this and choosing your languages
41 | # with:
42 | # languages: go, javascript, csharp, python, cpp, java, ruby
43 |
44 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
45 | # If this step fails, then you should remove it and run the build manually (see below).
46 | - name: Autobuild
47 | uses: github/codeql-action/autobuild@v2
48 |
49 | # ℹ️ Command-line programs to run using the OS shell.
50 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
51 |
52 | # ✏️ If the Autobuild fails above, remove it and uncomment the following
53 | # three lines and modify them (or add more) to build your code if your
54 | # project uses a compiled language
55 |
56 | #- run: |
57 | # make bootstrap
58 | # make release
59 |
60 | - name: Perform CodeQL Analysis
61 | uses: github/codeql-action/analyze@v2
62 |
--------------------------------------------------------------------------------
/.github/workflows/documentation.yml:
--------------------------------------------------------------------------------
1 | name: Documentation and Code Coverage
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types:
7 | - closed
8 | branches:
9 | - main
10 |
11 | jobs:
12 | run:
13 | name: Build documentation on ${{ matrix.os }}
14 | runs-on: ${{ matrix.os }}
15 | strategy:
16 | matrix:
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 |
22 | - uses: actions/setup-python@v4
23 | with:
24 | python-version: '3.12'
25 |
26 | - uses: tlylt/install-graphviz@v1
27 |
28 | - name: Install pandoc
29 | run: sudo apt-get install -y pandoc
30 |
31 | - name: Install requirements
32 | run: python -m pip install -r requirements.txt
33 |
34 | - name: Install requirements dev
35 | run: python -m pip install -r requirements-dev.txt
36 |
37 | - name: Cache pip
38 | uses: actions/cache@v4
39 | with:
40 | path: ~/.cache/pip
41 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
42 | restore-keys: |
43 | ${{ runner.os }}-pip-
44 | ${{ runner.os }}-
45 |
46 | - name: Generate coverage report
47 | run: |
48 | pip install pytest
49 | pip install pytest-cov
50 | export PYTHONPATH=.
51 | pytest --cov=./onnx_array_api/ --cov-report=xml --durations=10 --ignore-glob=**LONG*.py --ignore-glob=**notebook*.py
52 | export PYTHONPATH=
53 |
54 | - name: Upload coverage reports to Codecov
55 | uses: codecov/codecov-action@v3
56 | env:
57 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
58 |
59 | - name: Install
60 | run: python -m pip install -e . -v
61 |
62 | - name: Copy license, changelogs
63 | run: |
64 | cp LICENSE* ./_doc
65 | cp CHANGELOGS* ./_doc
66 |
67 | - name: Documentation
68 | run: python -m sphinx ./_doc ./dist/html -n -w doc.txt
69 |
70 | - name: Summary
71 | run: cat doc.txt
72 |
73 | - name: Check for errors and warnings
74 | run: |
75 | if [[ $(grep ERROR doc.txt) ]]; then
76 | echo "Documentation produces errors."
77 | grep ERROR doc.txt
78 | exit 1
79 | fi
80 | if [[ $(grep WARNING doc.txt) ]]; then
81 | echo "Documentation produces warnings."
82 | grep WARNING doc.txt
83 | exit 1
84 | fi
85 |
86 | - uses: actions/upload-artifact@v4
87 | with:
88 | path: ./dist/html/**
89 |
--------------------------------------------------------------------------------
/.github/workflows/wheels-any.yml:
--------------------------------------------------------------------------------
1 | name: Build Any Wheel
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - 'releases/**'
8 |
9 | jobs:
10 | build_wheels:
11 | name: Build wheels on ${{ matrix.os }}
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | os: [ubuntu-latest]
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 |
20 | - uses: actions/setup-python@v4
21 | with:
22 | python-version: '3.12'
23 |
24 | - name: build wheel
25 | run: python -m pip wheel .
26 |
27 | - uses: actions/upload-artifact@v4
28 | with:
29 | path: ./onnx_array_api*.whl
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pyd
3 | *.dylib
4 | *.so
5 | *.whl
6 | *.xlsx
7 | coverage.html/*
8 | _cache/*
9 | .coverage
10 | dist/*
11 | build/*
12 | .eggs/*
13 | .hypothesis/*
14 | *egg-info/*
15 | onnxruntime_profile*
16 | prof
17 | test*.png
18 | _doc/sg_execution_times.rst
19 | _doc/auto_examples/*
20 | _doc/examples/_cache/*
21 | _doc/examples/onnxruntime_profile*
22 | _doc/examples/plot_*.png
23 | _doc/examples/plot_*.xlsx
24 | _doc/examples/data/*.optimized.onnx
25 | _doc/examples/*.html
26 | _doc/_static/require.js
27 | _doc/_static/viz.js
28 | _doc/LICENSE.txt
29 | _doc/CHANGELOGS.rst
30 | _unittests/ut__main/*.png
31 | _unittests/ut__main/_cache/*
32 | _unittests/ut__main/*.html
33 | _unittests/.hypothesis/*
34 |
--------------------------------------------------------------------------------
/CHANGELOGS.rst:
--------------------------------------------------------------------------------
1 | Change Logs
2 | ===========
3 |
4 | 0.3.1
5 | +++++
6 |
7 | * :pr:`100`: updates requirements, add 3.12
8 | * :pr:`96`: supports local functions in translator
9 | * :pr:`95`: improves translation to GraphBuilder
10 |
11 | 0.3.0
12 | +++++
13 |
14 | * :pr:`93`: fixes evaluator type in ``compare_onnx_execution``
15 | * :pr:`92`: avoids recursion errors in profiling
16 | * :pr:`87`: adds command line to replace constant by ConstantOfShape
17 | * :pr:`79`: first draft to export to GraphBuilder
18 | * :pr:`77`: supports ConcatOfShape and Slice with the light API
19 |
20 | 0.2.0
21 | +++++
22 |
23 | * :pr:`76`, :pr:`79`: add a mode to compare models without execution
24 | * :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
25 | * :pr:`71`: adds tools to compare two onnx graphs
26 | * :pr:`61`: adds function to plot onnx model as graphs
27 | * :pr:`60`: supports translation of local functions
28 | * :pr:`59`: add methods to update nodes in GraphAPI
29 |
30 | 0.1.3
31 | +++++
32 |
33 | * :pr:`57`: implements GraphBuilder
34 | * :pr:`49`: adds command line to export a model into code
35 | * :pr:`48`: support for subgraph in light API
36 | * :pr:`47`: extends export onnx to code to support inner API
37 | * :pr:`46`: adds an export to convert an onnx graph into light API code
38 | * :pr:`45`: fixes light API for operators with two outputs
39 |
40 | 0.1.2
41 | +++++
42 |
43 | * :pr:`42`: first sketch for a very simple API to create onnx graph in one or two lines
44 | * :pr:`27`: add function from_array_extended to convert
45 | an array to a TensorProto, including bfloat16 and float 8 types
46 | * :pr:`24`: add ExtendedReferenceEvaluator to support scenario
47 | for the Array API onnx does not support
48 | * :pr:`22`: support OrtValue in function *ort_profile*
49 | * :pr:`17`: implements ArrayAPI
50 | * :pr:`3`: fixes Array API with onnxruntime and scikit-learn
51 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | We are a community based on openness, as well as friendly and didactic discussions.
4 |
5 | We aspire to treat everybody equally, and value their contributions.
6 |
7 | Decisions are made based on technical merit and consensus.
8 |
9 | Code is not the only way to help the project. Reviewing pull requests,
10 | answering questions to help others on mailing lists or issues, organizing and
11 | teaching tutorials, working on the website, improving the documentation, are
12 | all priceless contributions.
13 |
14 | We abide by the principles of openness, respect, and consideration of others of
15 | the Python Software Foundation: https://www.python.org/psf/codeofconduct/
16 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023-2025, Xavier Dupré
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune _doc
2 | prune _unittests
3 | exclude *.bat
4 | exclude *.yml
5 | exclude *.git*
6 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 |
2 | .. image:: https://github.com/sdpython/onnx-array-api/raw/main/_doc/_static/logo.png
3 | :width: 120
4 |
5 | onnx-array-api: APIs to create ONNX Graphs
6 | ==========================================
7 |
8 | .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api
9 | :target: https://dev.azure.com/xavierdupre3/onnx-array-api/
10 |
11 | .. image:: https://badge.fury.io/py/onnx-array-api.svg
12 | :target: http://badge.fury.io/py/onnx-array-api
13 |
14 | .. image:: http://img.shields.io/github/issues/sdpython/onnx-array-api.png
15 | :alt: GitHub Issues
16 | :target: https://github.com/sdpython/onnx-array-api/issues
17 |
18 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg
19 | :alt: MIT License
20 | :target: https://opensource.org/license/MIT/
21 |
22 | .. image:: https://img.shields.io/github/repo-size/sdpython/onnx-array-api
23 | :target: https://github.com/sdpython/onnx-array-api/
24 | :alt: size
25 |
26 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
27 | :target: https://github.com/psf/black
28 |
29 | .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J
30 | :target: https://codecov.io/gh/sdpython/onnx-array-api
31 |
32 | **onnx-array-api** implements APIs to create custom ONNX graphs.
33 | The objective is to speed up the implementation of converter libraries.
34 | The library is released on
35 | `pypi/onnx-array-api <https://pypi.org/project/onnx-array-api/>`_
36 | and its documentation is published at
37 | `APIs to create ONNX Graphs <https://sdpython.github.io/doc/onnx-array-api/dev/>`_.
38 |
39 | Numpy API
40 | +++++++++
41 |
42 | The first one matches **numpy API**.
43 | It gives the user the ability to convert functions written
44 | following the numpy API to convert that function into ONNX as
45 | well as to execute it.
46 |
47 | .. code-block:: python
48 |
49 | import numpy as np
50 | from onnx_array_api.npx import absolute, jit_onnx
51 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
52 |
53 | def l1_loss(x, y):
54 | return absolute(x - y).sum()
55 |
56 |
57 | def l2_loss(x, y):
58 | return ((x - y) ** 2).sum()
59 |
60 |
61 | def myloss(x, y):
62 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
63 |
64 |
65 | jitted_myloss = jit_onnx(myloss)
66 |
67 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
68 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
69 |
70 | res = jitted_myloss(x, y)
71 | print(res)
72 |
73 | print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
74 |
75 | ::
76 |
77 | [0.042]
78 | opset: domain='' version=18
79 | input: name='x0' type=dtype('float32') shape=['', '']
80 | input: name='x1' type=dtype('float32') shape=['', '']
81 | Sub(x0, x1) -> r__0
82 | Abs(r__0) -> r__1
83 | ReduceSum(r__1, keepdims=0) -> r__2
84 | output: name='r__2' type=dtype('float32') shape=None
85 |
86 | It supports eager mode as well:
87 |
88 | .. code-block:: python
89 |
90 | import numpy as np
91 | from onnx_array_api.npx import absolute, eager_onnx
92 |
93 |
94 | def l1_loss(x, y):
95 | err = absolute(x - y).sum()
96 | print(f"l1_loss={err.numpy()}")
97 | return err
98 |
99 |
100 | def l2_loss(x, y):
101 | err = ((x - y) ** 2).sum()
102 | print(f"l2_loss={err.numpy()}")
103 | return err
104 |
105 |
106 | def myloss(x, y):
107 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
108 |
109 |
110 | eager_myloss = eager_onnx(myloss)
111 |
112 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
113 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
114 |
115 | res = eager_myloss(x, y)
116 | print(res)
117 |
118 | ::
119 |
120 | l1_loss=[0.04]
121 | l2_loss=[0.002]
122 | [0.042]
123 |
124 | Light API
125 | +++++++++
126 |
127 | The second API or **Light API** tends to do everything in one line.
128 | It is inspired from the `Reverse Polish Notation
129 | <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_.
130 | The euclidean distance looks like the following:
131 |
132 | .. code-block:: python
133 |
134 | import numpy as np
135 | from onnx_array_api.light_api import start
136 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
137 |
138 | model = (
139 | start()
140 | .vin("X")
141 | .vin("Y")
142 | .bring("X", "Y")
143 | .Sub()
144 | .rename("dxy")
145 | .cst(np.array([2], dtype=np.int64), "two")
146 | .bring("dxy", "two")
147 | .Pow()
148 | .ReduceSum()
149 | .rename("Z")
150 | .vout()
151 | .to_onnx()
152 | )
153 |
154 | GraphBuilder API
155 | ++++++++++++++++
156 |
157 | Almost every converting library (converting a machine learned model to ONNX) is implementing
158 | its own graph builder and customizes it for its needs.
159 | It handles some frequent tasks such as giving names to intermediate
160 | results, loading, saving onnx models. It can be used as well to extend an existing graph.
161 |
162 | .. code-block:: python
163 |
164 | import numpy as np
165 | from onnx_array_api.graph_api import GraphBuilder
166 |
167 | g = GraphBuilder()
168 | g.make_tensor_input("X", np.float32, (None, None))
169 | g.make_tensor_input("Y", np.float32, (None, None))
170 | r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
171 | # it ensures the name is unique
172 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
173 | # converts the array to a tensor
174 | r2 = g.make_node("Pow", [r1, init])
175 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
176 | # the user wants to choose the name
177 | g.make_tensor_output("Z", np.float32, (None, None))
178 |
179 | onx = g.to_onnx() # final conversion to onnx
180 |
--------------------------------------------------------------------------------
/_doc/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_doc/_static/logo.png
--------------------------------------------------------------------------------
/_doc/api/array_api.rst:
--------------------------------------------------------------------------------
1 | onnx_array_api.array_api
2 | ========================
3 |
4 | .. toctree::
5 |
6 | array_api_numpy
7 | array_api_ort
8 | npx_array_api
9 |
--------------------------------------------------------------------------------
/_doc/api/array_api_numpy.rst:
--------------------------------------------------------------------------------
1 | onnx_array_api.array_api.onnx_numpy
2 | =============================================
3 |
4 | .. automodule:: onnx_array_api.array_api.onnx_numpy
5 | :members:
6 |
--------------------------------------------------------------------------------
/_doc/api/array_api_ort.rst:
--------------------------------------------------------------------------------
1 | onnx_array_api.array_api.onnx_ort
2 | =================================
3 |
4 | .. automodule:: onnx_array_api.array_api.onnx_ort
5 | :members:
6 |
--------------------------------------------------------------------------------
/_doc/api/docs.rst:
--------------------------------------------------------------------------------
1 | validation.docs
2 | ===============
3 |
4 | make_euclidean
5 | ++++++++++++++
6 |
7 | .. autofunction:: onnx_array_api.validation.docs.make_euclidean
8 |
--------------------------------------------------------------------------------
/_doc/api/f8.rst:
--------------------------------------------------------------------------------
1 | Float 8
2 | =======
3 |
4 | .. automodule:: onnx_array_api.validation.f8
5 | :members:
6 |
--------------------------------------------------------------------------------
/_doc/api/graph_api.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | onnx_array_api.graph_api
3 | ========================
4 |
5 |
6 | GraphBuilder
7 | ============
8 |
9 | .. autoclass:: onnx_array_api.graph_api.GraphBuilder
10 | :members:
11 |
12 | NodePattern
13 | ===========
14 |
15 | .. autoclass:: onnx_array_api.graph_api.NodePattern
16 | :members:
17 |
18 | OptimizationOptions
19 | ===================
20 |
21 | .. autoclass:: onnx_array_api.graph_api.graph_builder.OptimizationOptions
22 | :members:
23 |
--------------------------------------------------------------------------------
/_doc/api/index.rst:
--------------------------------------------------------------------------------
1 |
2 | ===
3 | API
4 | ===
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | array_api
10 | graph_api
11 | light_api
12 | translate_api
13 | npx_core_api
14 | npx_functions
15 | npx_jit_eager
16 | npx_numpy
17 | npx_tensors
18 | npx_types
19 | npx_var
20 | onnx_tools
21 | ort
22 | plotting
23 | reference
24 | tools
25 | profiling
26 | f8
27 | docs
28 |
--------------------------------------------------------------------------------
/_doc/api/light_api.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | onnx_array_api.light_api
3 | ========================
4 |
5 |
6 | Main API
7 | ========
8 |
9 | start
10 | +++++
11 |
12 | .. autofunction:: onnx_array_api.light_api.start
13 |
14 | g
15 | +
16 |
17 | .. autofunction:: onnx_array_api.light_api.g
18 |
19 | Classes for the Light API
20 | =========================
21 |
22 | domain
23 | ++++++
24 |
25 | .. autofunction:: onnx_array_api.light_api.domain
26 |
27 | BaseVar
28 | +++++++
29 |
30 | .. autoclass:: onnx_array_api.light_api.var.BaseVar
31 | :members:
32 |
33 | OnnxGraph
34 | +++++++++
35 |
36 | .. autoclass:: onnx_array_api.light_api.OnnxGraph
37 | :members:
38 |
39 | ProtoType
40 | +++++++++
41 |
42 | .. autoclass:: onnx_array_api.light_api.model.ProtoType
43 | :members:
44 |
45 | SubDomain
46 | +++++++++
47 |
48 | .. autoclass:: onnx_array_api.light_api.var.SubDomain
49 | :members:
50 |
51 | Var
52 | +++
53 |
54 | .. autoclass:: onnx_array_api.light_api.Var
55 | :members:
56 | :inherited-members:
57 |
58 | Vars
59 | ++++
60 |
61 | .. autoclass:: onnx_array_api.light_api.Vars
62 | :members:
63 | :inherited-members:
64 |
65 | Available operators
66 | ===================
67 |
68 | One input
69 | +++++++++
70 |
71 | .. autoclass:: onnx_array_api.light_api._op_var.OpsVar
72 | :members:
73 |
74 | Two inputs or more
75 | ++++++++++++++++++
76 |
77 | .. autoclass:: onnx_array_api.light_api._op_vars.OpsVars
78 | :members:
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/_doc/api/npx_array_api.rst:
--------------------------------------------------------------------------------
1 | onnx_array_api.npx.npx_array_api
2 | ================================
3 |
4 | .. autoclass:: onnx_array_api.npx.npx_array_api.BaseArrayApi
5 | :members:
6 |
--------------------------------------------------------------------------------
/_doc/api/npx_core_api.rst:
--------------------------------------------------------------------------------
1 | ============
2 | npx_core_api
3 | ============
4 |
5 | cst
6 | ===
7 |
8 | .. autofunction:: onnx_array_api.npx.npx_core_api.cst
9 |
10 | make_tuple
11 | ==========
12 |
13 | .. autofunction:: onnx_array_api.npx.npx_core_api.make_tuple
14 |
15 | tuple_var
16 | =========
17 |
18 | .. autofunction:: onnx_array_api.npx.npx_core_api.tuple_var
19 |
20 | npxapi_inline
21 | =============
22 |
23 | .. autofunction:: onnx_array_api.npx.npx_core_api.npxapi_inline
24 |
25 | npxapi_function
26 | ===============
27 |
28 | .. autofunction:: onnx_array_api.npx.npx_core_api.npxapi_function
29 |
30 | var
31 | ===
32 |
33 | .. autofunction:: onnx_array_api.npx.npx_core_api.var
34 |
--------------------------------------------------------------------------------
/_doc/api/npx_functions.rst:
--------------------------------------------------------------------------------
1 | npx.npx_functions
2 | =================
3 |
4 | .. autofunction:: onnx_array_api.npx.npx_functions.abs
5 |
6 | .. autofunction:: onnx_array_api.npx.npx_functions.absolute
7 |
8 | .. autofunction:: onnx_array_api.npx.npx_functions.arccos
9 |
10 | .. autofunction:: onnx_array_api.npx.npx_functions.arccosh
11 |
12 | .. autofunction:: onnx_array_api.npx.npx_functions.amax
13 |
14 | .. autofunction:: onnx_array_api.npx.npx_functions.amin
15 |
16 | .. autofunction:: onnx_array_api.npx.npx_functions.arange
17 |
18 | .. autofunction:: onnx_array_api.npx.npx_functions.argmax
19 |
20 | .. autofunction:: onnx_array_api.npx.npx_functions.argmin
21 |
22 | .. autofunction:: onnx_array_api.npx.npx_functions.arcsin
23 |
24 | .. autofunction:: onnx_array_api.npx.npx_functions.arcsinh
25 |
26 | .. autofunction:: onnx_array_api.npx.npx_functions.arctan
27 |
28 | .. autofunction:: onnx_array_api.npx.npx_functions.arctanh
29 |
30 | .. autofunction:: onnx_array_api.npx.npx_functions.cdist
31 |
32 | .. autofunction:: onnx_array_api.npx.npx_functions.ceil
33 |
34 | .. autofunction:: onnx_array_api.npx.npx_functions.clip
35 |
36 | .. autofunction:: onnx_array_api.npx.npx_functions.compress
37 |
38 | .. autofunction:: onnx_array_api.npx.npx_functions.compute
39 |
40 | .. autofunction:: onnx_array_api.npx.npx_functions.concat
41 |
42 | .. autofunction:: onnx_array_api.npx.npx_functions.cos
43 |
44 | .. autofunction:: onnx_array_api.npx.npx_functions.cosh
45 |
46 | .. autofunction:: onnx_array_api.npx.npx_functions.cumsum
47 |
48 | .. autofunction:: onnx_array_api.npx.npx_functions.det
49 |
50 | .. autofunction:: onnx_array_api.npx.npx_functions.dot
51 |
52 | .. autofunction:: onnx_array_api.npx.npx_functions.einsum
53 |
54 | .. autofunction:: onnx_array_api.npx.npx_functions.erf
55 |
56 | .. autofunction:: onnx_array_api.npx.npx_functions.exp
57 |
58 | .. autofunction:: onnx_array_api.npx.npx_functions.expand_dims
59 |
60 | .. autofunction:: onnx_array_api.npx.npx_functions.expit
61 |
62 | .. autofunction:: onnx_array_api.npx.npx_functions.floor
63 |
64 | .. autofunction:: onnx_array_api.npx.npx_functions.hstack
65 |
66 | .. autofunction:: onnx_array_api.npx.npx_functions.copy
67 |
68 | .. autofunction:: onnx_array_api.npx.npx_functions.identity
69 |
70 | .. autofunction:: onnx_array_api.npx.npx_functions.isnan
71 |
72 | .. autofunction:: onnx_array_api.npx.npx_functions.log
73 |
74 | .. autofunction:: onnx_array_api.npx.npx_functions.log1p
75 |
76 | .. autofunction:: onnx_array_api.npx.npx_functions.matmul
77 |
78 | .. autofunction:: onnx_array_api.npx.npx_functions.pad
79 |
80 | .. autofunction:: onnx_array_api.npx.npx_functions.reciprocal
81 |
82 | .. autofunction:: onnx_array_api.npx.npx_functions.relu
83 |
84 | .. autofunction:: onnx_array_api.npx.npx_functions.round
85 |
86 | .. autofunction:: onnx_array_api.npx.npx_functions.sigmoid
87 |
88 | .. autofunction:: onnx_array_api.npx.npx_functions.sign
89 |
90 | .. autofunction:: onnx_array_api.npx.npx_functions.sin
91 |
92 | .. autofunction:: onnx_array_api.npx.npx_functions.sinh
93 |
94 | .. autofunction:: onnx_array_api.npx.npx_functions.squeeze
95 |
96 | .. autofunction:: onnx_array_api.npx.npx_functions.tan
97 |
98 | .. autofunction:: onnx_array_api.npx.npx_functions.tanh
99 |
100 | .. autofunction:: onnx_array_api.npx.npx_functions.topk
101 |
102 | .. autofunction:: onnx_array_api.npx.npx_functions.transpose
103 |
104 | .. autofunction:: onnx_array_api.npx.npx_functions.unsqueeze
105 |
106 | .. autofunction:: onnx_array_api.npx.npx_functions.vstack
107 |
108 | .. autofunction:: onnx_array_api.npx.npx_functions.where
109 |
--------------------------------------------------------------------------------
/_doc/api/npx_jit_eager.rst:
--------------------------------------------------------------------------------
1 | =============
2 | npx_jit_eager
3 | =============
4 |
5 | eager_onnx
6 | ==========
7 |
8 | .. autofunction:: onnx_array_api.npx.npx_jit_eager.eager_onnx
9 |
10 | EagerOnnx
11 | =========
12 |
13 | .. autoclass:: onnx_array_api.npx.npx_jit_eager.EagerOnnx
14 | :members:
15 |
16 | JitEager
17 | ========
18 |
19 | .. autoclass:: onnx_array_api.npx.npx_jit_eager.JitEager
20 | :members:
21 |
22 | jit_onnx
23 | ========
24 |
25 | .. autofunction:: onnx_array_api.npx.npx_jit_eager.jit_onnx
26 |
27 | JitOnnx
28 | =======
29 |
30 | .. autoclass:: onnx_array_api.npx.npx_jit_eager.JitOnnx
31 | :members:
32 |
--------------------------------------------------------------------------------
/_doc/api/npx_numpy.rst:
--------------------------------------------------------------------------------
1 | npx.npx_numpy_tensors
2 | =====================
3 |
4 | EagerNumpyTensor
5 | ++++++++++++++++
6 |
7 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.EagerNumpyTensor
8 | :members:
9 |
10 | JitNumpyTensor
11 | ++++++++++++++
12 |
13 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.JitNumpyTensor
14 | :members:
15 |
16 | NumpyTensor
17 | +++++++++++
18 |
19 | .. autoclass:: onnx_array_api.npx.npx_numpy_tensors.NumpyTensor
20 | :members:
--------------------------------------------------------------------------------
/_doc/api/npx_tensors.rst:
--------------------------------------------------------------------------------
1 | ===========
2 | npx_tensors
3 | ===========
4 |
5 |
6 | EagerTensor
7 | ===========
8 |
9 | .. autoclass:: onnx_array_api.npx.npx_tensors.EagerTensor
10 | :members:
11 |
--------------------------------------------------------------------------------
/_doc/api/npx_types.rst:
--------------------------------------------------------------------------------
1 | npx.npx_types
2 | =============
3 |
4 | DType
5 | +++++
6 |
7 | .. autoclass:: onnx_array_api.npx.npx_types.DType
8 | :members:
9 |
10 | ElemType
11 | ++++++++
12 |
13 | .. autoclass:: onnx_array_api.npx.npx_types.ElemType
14 | :members:
15 |
16 | OptParType
17 | ++++++++++
18 |
19 | .. autoclass:: onnx_array_api.npx.npx_types.OptParType
20 | :members:
21 |
22 | OptTensorType
23 | +++++++++++++
24 |
25 | .. autoclass:: onnx_array_api.npx.npx_types.OptTensorType
26 | :members:
27 |
28 | ParType
29 | +++++++
30 |
31 | .. autoclass:: onnx_array_api.npx.npx_types.ParType
32 | :members:
33 |
34 | Scalar
35 | ++++++
36 |
37 | .. autoclass:: onnx_array_api.npx.npx_types.Scalar
38 | :members:
39 |
40 | SequenceType
41 | ++++++++++++
42 |
43 | .. autoclass:: onnx_array_api.npx.npx_types.SequenceType
44 | :members:
45 |
46 | ShapeType
47 | +++++++++
48 |
49 | .. autoclass:: onnx_array_api.npx.npx_types.ShapeType
50 | :members:
51 |
52 | TensorType
53 | ++++++++++
54 |
55 | .. autoclass:: onnx_array_api.npx.npx_types.TensorType
56 | :members:
57 |
58 | TupleType
59 | +++++++++
60 |
61 | .. autoclass:: onnx_array_api.npx.npx_types.TupleType
62 | :members:
63 |
64 | Shortcuts
65 | =========
66 |
67 | .. autoclass:: onnx_array_api.npx.npx_types.Bool
68 |
69 | .. autoclass:: onnx_array_api.npx.npx_types.BFloat16
70 |
71 | .. autoclass:: onnx_array_api.npx.npx_types.Float16
72 |
73 | .. autoclass:: onnx_array_api.npx.npx_types.Float32
74 |
75 | .. autoclass:: onnx_array_api.npx.npx_types.Float64
76 |
77 | .. autoclass:: onnx_array_api.npx.npx_types.Int8
78 |
79 | .. autoclass:: onnx_array_api.npx.npx_types.Int16
80 |
81 | .. autoclass:: onnx_array_api.npx.npx_types.Int32
82 |
83 | .. autoclass:: onnx_array_api.npx.npx_types.Int64
84 |
85 | .. autoclass:: onnx_array_api.npx.npx_types.UInt8
86 |
87 | .. autoclass:: onnx_array_api.npx.npx_types.UInt16
88 |
89 | .. autoclass:: onnx_array_api.npx.npx_types.UInt32
90 |
91 | .. autoclass:: onnx_array_api.npx.npx_types.UInt64
92 |
--------------------------------------------------------------------------------
/_doc/api/npx_var.rst:
--------------------------------------------------------------------------------
1 | npx.npx_var
2 | ===========
3 |
4 | Var
5 | +++
6 |
7 | .. autoclass:: onnx_array_api.npx.npx_var.Var
8 | :members:
9 |
10 | Cst, Input
11 | ++++++++++
12 |
13 | .. autoclass:: onnx_array_api.npx.npx_var.Cst
14 | :members:
15 |
16 | .. autoclass:: onnx_array_api.npx.npx_var.Input
17 | :members:
18 |
19 | ManyIdentity
20 | ++++++++++++
21 |
22 | .. autoclass:: onnx_array_api.npx.npx_var.ManyIdentity
23 | :members:
24 |
25 | Par
26 | +++
27 |
28 | .. autoclass:: onnx_array_api.npx.npx_var.Par
29 | :members:
30 |
31 |
--------------------------------------------------------------------------------
/_doc/api/onnx_tools.rst:
--------------------------------------------------------------------------------
1 | onnx tools
2 | ==========
3 |
4 | Differences
5 | +++++++++++
6 |
7 | .. autofunction:: onnx_array_api.validation.diff.html_diff
8 |
9 | .. autofunction:: onnx_array_api.validation.diff.text_diff
10 |
11 | Protos
12 | ++++++
13 |
14 | .. autofunction:: onnx_array_api.validation.tools.randomize_proto
15 |
--------------------------------------------------------------------------------
/_doc/api/ort.rst:
--------------------------------------------------------------------------------
1 | ort
2 | ===
3 |
4 | ort_optimized_model
5 | +++++++++++++++++++
6 |
7 | .. autofunction:: onnx_array_api.ort.ort_optimizers.ort_optimized_model
8 |
9 | EagerOrtTensor
10 | ++++++++++++++
11 |
12 | .. autoclass:: onnx_array_api.ort.ort_tensors.EagerOrtTensor
13 | :members:
14 |
15 | JitOrtTensor
16 | ++++++++++++
17 |
18 | .. autoclass:: onnx_array_api.ort.ort_tensors.JitOrtTensor
19 | :members:
20 |
21 | OrtTensor
22 | +++++++++
23 |
24 | .. autoclass:: onnx_array_api.ort.ort_tensors.OrtTensor
25 | :members:
26 |
27 | merge_ort_profile
28 | +++++++++++++++++
29 |
30 | .. autofunction:: onnx_array_api.ort.ort_profile.merge_ort_profile
31 |
32 | ort_profile
33 | +++++++++++
34 |
35 | .. autofunction:: onnx_array_api.ort.ort_profile.ort_profile
36 |
--------------------------------------------------------------------------------
/_doc/api/plotting.rst:
--------------------------------------------------------------------------------
1 | plotting
2 | ========
3 |
4 | Dot
5 | +++
6 |
7 | .. autofunction:: onnx_array_api.plotting.dot_plot.to_dot
8 |
9 | .. autofunction:: onnx_array_api.plotting.graphviz_helper.plot_dot
10 |
11 | Statistics
12 | ++++++++++
13 |
14 | .. autofunction:: onnx_array_api.plotting.stat_plot.plot_ort_profile
15 |
16 | Text
17 | ++++
18 |
19 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_text_plot_tree
20 |
21 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_text_plot_io
22 |
23 | .. autofunction:: onnx_array_api.plotting.text_plot.onnx_simple_text_plot
24 |
--------------------------------------------------------------------------------
/_doc/api/profiling.rst:
--------------------------------------------------------------------------------
1 | profiling
2 | =========
3 |
4 | ProfileNode
5 | +++++++++++
6 |
7 | .. autoclass:: onnx_array_api.profiling.ProfileNode
8 |
9 | profile
10 | +++++++
11 |
12 | .. autofunction:: onnx_array_api.profiling.profile
13 |
14 | profile2graph
15 | +++++++++++++
16 |
17 | .. autofunction:: onnx_array_api.profiling.profile2graph
18 |
--------------------------------------------------------------------------------
/_doc/api/reference.rst:
--------------------------------------------------------------------------------
1 | reference
2 | =========
3 |
4 | ExtendedReferenceEvaluator
5 | ++++++++++++++++++++++++++
6 |
7 | .. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
8 | :members:
9 |
10 | ResultType
11 | ++++++++++
12 |
13 | .. autoclass:: onnx_array_api.reference.ResultType
14 | :members:
15 |
16 | ResultExecution
17 | +++++++++++++++
18 |
19 | .. autoclass:: onnx_array_api.reference.ResultExecution
20 | :members:
21 |
22 | YieldEvaluator
23 | ++++++++++++++
24 |
25 | .. autoclass:: onnx_array_api.reference.YieldEvaluator
26 | :members:
27 |
28 | DistanceExecution
29 | +++++++++++++++++
30 |
31 | .. autoclass:: onnx_array_api.reference.DistanceExecution
32 | :members:
33 |
34 | compare_onnx_execution
35 | ++++++++++++++++++++++
36 |
37 | .. autofunction:: onnx_array_api.reference.compare_onnx_execution
38 |
--------------------------------------------------------------------------------
/_doc/api/tools.rst:
--------------------------------------------------------------------------------
1 | tools
2 | =====
3 |
4 | Benchmark
5 | +++++++++
6 |
7 | .. autofunction:: onnx_array_api.ext_test_case.measure_time
8 |
9 | Manipulations
10 | +++++++++++++
11 |
12 | .. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape
13 |
14 | Examples
15 | ++++++++
16 |
17 | .. autofunction:: onnx_array_api.ext_test_case.example_path
18 |
19 | Unit tests
20 | ++++++++++
21 |
22 | .. autofunction:: onnx_array_api.ext_test_case.ignore_warnings
23 |
24 | .. autoclass:: onnx_array_api.ext_test_case.ExtTestCase
25 | :members:
26 |
--------------------------------------------------------------------------------
/_doc/api/translate_api.rst:
--------------------------------------------------------------------------------
1 | ============================
2 | onnx_array_api.translate_api
3 | ============================
4 |
5 |
6 | Main API
7 | ========
8 |
9 | translate
10 | +++++++++
11 |
12 | .. autofunction:: onnx_array_api.translate_api.translate
13 |
14 | make_helper
15 | +++++++++++
16 |
17 | .. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended
18 |
19 | .. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute
20 |
21 | Classes for the Translater
22 | ==========================
23 |
24 | BaseEmitter
25 | +++++++++++
26 |
27 | .. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter
28 | :members:
29 |
30 | EventType
31 | +++++++++
32 |
33 | .. autoclass:: onnx_array_api.translate_api.base_emitter.EventType
34 | :members:
35 |
36 | InnerEmitter
37 | ++++++++++++
38 |
39 | .. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
40 | :members:
41 |
42 | InnerEmitterShortInitializer
43 | ++++++++++++++++++++++++++++
44 |
45 | .. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer
46 | :members:
47 |
48 | LightEmitter
49 | ++++++++++++
50 |
51 | .. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter
52 | :members:
53 |
54 | Translater
55 | ++++++++++
56 |
57 | .. autoclass:: onnx_array_api.translate_api.translate.Translater
58 | :members:
59 |
--------------------------------------------------------------------------------
/_doc/command_lines.rst:
--------------------------------------------------------------------------------
1 | =============
2 | command lines
3 | =============
4 |
5 | compare
6 | =======
7 |
 8 | The function compares the execution of two onnx models.
9 |
10 | ::
11 |
12 | python -m compare -m1 model1.onnx -m2 model2.onnx -v 1
13 |
14 | Output example::
15 |
16 | [compare_onnx_execution] got 2 inputs
17 | [compare_onnx_execution] execute first model
18 | [compare_onnx_execution] got 5 results
19 | [compare_onnx_execution] execute second model
20 | [compare_onnx_execution] got 5 results
21 | [compare_onnx_execution] compute edit distance
22 | [compare_onnx_execution] got 4 pairs
23 | [compare_onnx_execution] done
24 | = | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X
25 | = | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y
26 | = | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res
27 | = | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z
28 |
29 | .. runpython::
30 |
31 | from onnx_array_api._command_lines_parser import get_parser_compare
32 | get_parser_compare().print_help()
33 |
34 | See function :func:`onnx_array_api.reference.compare_onnx_execution`.
35 |
36 | translate
37 | =========
38 |
39 | The function converts an onnx file into some code.
40 |
41 | ::
42 |
43 | python -m translate ...
44 |
45 | Output example::
46 |
47 | not yet ready
48 |
49 | .. runpython::
50 |
51 | from onnx_array_api._command_lines_parser import get_parser_translate
52 | get_parser_translate().print_help()
53 |
--------------------------------------------------------------------------------
/_doc/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from sphinx_runpython.github_link import make_linkcode_resolve
4 | from sphinx_runpython.conf_helper import has_dvipng, has_dvisvgm
5 | from onnx_array_api import __version__
6 |
7 | extensions = [
8 | "sphinx.ext.autodoc",
9 | "sphinx.ext.coverage",
10 | "sphinx.ext.githubpages",
11 | "sphinx.ext.ifconfig",
12 | "sphinx.ext.intersphinx",
13 | "sphinx.ext.mathjax",
14 | "sphinx.ext.viewcode",
15 | "sphinx.ext.todo",
16 | "sphinx_gallery.gen_gallery",
17 | "sphinx_issues",
18 | "matplotlib.sphinxext.plot_directive",
19 | "sphinx_runpython.epkg",
20 | "sphinx_runpython.gdot",
21 | "sphinx_runpython.runpython",
22 | ]
23 |
24 | if has_dvisvgm():
25 | extensions.append("sphinx.ext.imgmath")
26 | imgmath_image_format = "svg"
27 | elif has_dvipng():
28 | extensions.append("sphinx.ext.pngmath")
29 | imgmath_image_format = "png"
30 | else:
31 | extensions.append("sphinx.ext.mathjax")
32 |
33 | templates_path = ["_templates"]
34 | html_logo = "_static/logo.png"
35 | source_suffix = ".rst"
36 | master_doc = "index"
37 | project = "onnx-array-api"
38 | copyright = "2023-2024, Xavier Dupré"
39 | author = "Xavier Dupré"
40 | version = __version__
41 | release = __version__
42 | language = "en"
43 | exclude_patterns = []
44 | pygments_style = "sphinx"
45 | todo_include_todos = True
46 |
47 | html_theme = "furo"
48 | html_theme_path = ["_static"]
49 | html_theme_options = {}
50 | html_static_path = ["_static"]
51 | html_sourcelink_suffix = ""
52 |
53 | issues_github_path = "sdpython/onnx-array-api"
54 |
55 | # The following is used by sphinx.ext.linkcode to provide links to github
56 | linkcode_resolve = make_linkcode_resolve(
57 | "onnx-array-api",
58 | (
59 | "https://github.com/sdpython/onnx-array-api/"
60 | "blob/{revision}/{package}/"
61 | "{path}#L{lineno}"
62 | ),
63 | )
64 |
65 | latex_elements = {
66 | "papersize": "a4",
67 | "pointsize": "10pt",
68 | "title": project,
69 | }
70 |
71 | intersphinx_mapping = {
72 | "matplotlib": ("https://matplotlib.org/", None),
73 | "numpy": ("https://numpy.org/doc/stable", None),
74 | "onnx": ("https://onnx.ai/onnx/", None),
75 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
76 | "python": (f"https://docs.python.org/{sys.version_info.major}", None),
77 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
78 | "sklearn": ("https://scikit-learn.org/stable/", None),
79 | "sklearn-onnx": ("https://onnx.ai/sklearn-onnx/", None),
80 | "torch": ("https://pytorch.org/docs/stable/", None),
81 | }
82 |
83 | # Check intersphinx reference targets exist
84 | nitpicky = True
85 | # See also scikit-learn/scikit-learn#26761
86 | nitpick_ignore = [
87 | ("py:class", "False"),
88 | ("py:class", "True"),
89 | ("py:class", "pipeline.Pipeline"),
90 | ("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
91 | ]
92 |
93 | nitpick_ignore_regex = [
94 | ("py:func", ".*numpy[.].*"),
95 | ("py:func", ".*scipy[.].*"),
96 | ("py:class", ".*onnxruntime[.].*"),
97 | ("py:class", ".*onnx_array_api.npx.npx_types.OptParTypeTupleType_.*"),
98 | ("py:class", ".*onnx_array_api.npx.npx_types.ParType[a-z].*"),
99 | ("py:class", ".*onnx_array_api.npx.npx_types.OptTensorType_.*"),
100 | ("py:class", ".*onnx_array_api.npx.npx_types.TensorType_.*"),
101 | ("py:class", ".*onnx_array_api.npx.npx_types.[ui].*"),
102 | ]
103 |
104 | sphinx_gallery_conf = {
105 | # path to your examples scripts
106 | "examples_dirs": os.path.join(os.path.dirname(__file__), "examples"),
107 | # path where to save gallery generated examples
108 | "gallery_dirs": "auto_examples",
109 | }
110 |
111 | epkg_dictionary = {
112 | "Array API": "https://data-apis.org/array-api/",
113 | "ArrayAPI": (
114 | "https://data-apis.org/array-api/",
115 | ("2022.12/API_specification/generated/array_api.{0}.html", 1),
116 | ),
117 | "ast": "https://docs.python.org/3/library/ast.html",
118 | "cProfile.Profile": "https://docs.python.org/3/library/profile.html#profile.Profile",
119 | "DOT": "https://graphviz.org/doc/info/lang.html",
120 | "Graphviz": "https://graphviz.org/",
121 | "inner API": "https://onnx.ai/onnx/intro/python.html",
122 | "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
123 | "onnx": "https://onnx.ai/onnx/",
124 | "onnx-graphsurgeon": "https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon",
125 | "onnx.helper": "https://onnx.ai/onnx/api/helper.html",
126 | "ONNX": "https://onnx.ai/",
127 | "ONNX Operators": "https://onnx.ai/onnx/operators/",
128 | "onnxruntime": "https://onnxruntime.ai/",
129 | "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
130 | "numpy": "https://numpy.org/",
131 | "numba": "https://numba.pydata.org/",
132 | "onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"),
133 | "onnxscript": "https://github.com/microsoft/onnxscript",
134 | "pyinstrument": "https://github.com/joerick/pyinstrument",
135 | "python": "https://www.python.org/",
136 | "pytorch": "https://pytorch.org/",
137 | "reverse Polish notation": "https://en.wikipedia.org/wiki/Reverse_Polish_notation",
138 | "scikit-learn": "https://scikit-learn.org/stable/",
139 | "scipy": "https://scipy.org/",
140 | "sklearn-onnx": "https://onnx.ai/sklearn-onnx/",
141 | "spox": "https://github.com/Quantco/spox",
142 | "sphinx-gallery": "https://github.com/sphinx-gallery/sphinx-gallery",
143 | "tensorflow": "https://www.tensorflow.org/",
144 | "tensorflow-onnx": "https://github.com/onnx/tensorflow-onnx",
145 | "torch": "https://pytorch.org/docs/stable/torch.html",
146 | "torch.onnx": "https://pytorch.org/docs/stable/onnx.html",
147 | #
148 | "C_OrtValue": (
149 | "https://onnxruntime.ai/docs/api/csharp/api/Microsoft.ML.OnnxRuntime.OrtValue.html"
150 | ),
151 | "OrtValue": (
152 | "https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.OrtValue"
153 | ),
154 | }
155 |
--------------------------------------------------------------------------------
/_doc/examples/README.txt:
--------------------------------------------------------------------------------
1 | Example gallery
2 | ===============
3 |
4 | A couple of examples to illustrate different implementation
5 | of dot product (see also :epkg:`sphinx-gallery`).
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/_doc/examples/data/small.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_doc/examples/data/small.onnx
--------------------------------------------------------------------------------
/_doc/examples/plot_f8.py:
--------------------------------------------------------------------------------
1 | """
2 | .. _l-example-float8:
3 |
4 | About float 8
5 | =============
6 |
7 | Float 8 types were recently introduced to speed up the
8 | training of deep learning models.
9 |
10 | Possible values
11 | +++++++++++++++
12 |
13 | First E4M3FN.
14 | """
15 |
16 | import pprint
17 | from onnx_array_api.validation.f8 import CastFloat8
18 |
19 | pprint.pprint(CastFloat8.values_e4m3fn)
20 |
21 |
22 | ############################################
23 | # Then E5M2.
24 |
25 | pprint.pprint(CastFloat8.values_e5m2)
26 |
--------------------------------------------------------------------------------
/_doc/examples/plot_first_example.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | .. _l-onnx-array-first-api-example:
4 |
5 | First examples with onnx-array-api
6 | ==================================
7 |
8 | This demonstrates an easy case with :epkg:`onnx-array-api`.
9 | It shows how a function can be easily converted into
10 | ONNX.
11 |
12 | A loss function from numpy to ONNX
13 | ++++++++++++++++++++++++++++++++++
14 |
15 | The first example takes a loss function and converts it into ONNX.
16 | """
17 |
18 | import numpy as np
19 |
20 | from onnx_array_api.npx import absolute, jit_onnx
21 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
22 |
23 |
24 | ################################
25 | # The function looks like a numpy function.
26 | def l1_loss(x, y):
27 | return absolute(x - y).sum()
28 |
29 |
30 | ################################
31 | # The function needs to be converted into ONNX with function jit_onnx.
32 | # jitted_l1_loss is a wrapper. It intercepts all calls to l1_loss.
33 | # When it happens, it checks the input types and creates the
34 | # corresponding ONNX graph.
35 | jitted_l1_loss = jit_onnx(l1_loss)
36 |
37 | ################################
38 | # First execution and conversion to ONNX.
39 | # The wrapper caches the created onnx graph.
 40 | # It reuses it if the input types and the number of dimensions are the same.
 41 | # It creates a new one otherwise and keeps the old one.
42 |
43 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
44 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
45 |
46 | res = jitted_l1_loss(x, y)
47 | print(res)
48 |
49 | ####################################
50 | # The ONNX graph can be accessed the following way.
51 | print(onnx_simple_text_plot(jitted_l1_loss.get_onnx()))
52 |
53 | ################################
54 | # We can also define a more complex loss by computing L1 loss on
 55 | # the first column and L2 loss on the second one.
56 |
57 |
58 | def l1_loss(x, y):
59 | return absolute(x - y).sum()
60 |
61 |
62 | def l2_loss(x, y):
63 | return ((x - y) ** 2).sum()
64 |
65 |
66 | def myloss(x, y):
67 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
68 |
69 |
70 | jitted_myloss = jit_onnx(myloss)
71 |
72 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
73 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
74 |
75 | res = jitted_myloss(x, y)
76 | print(res)
77 |
78 | print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
79 |
80 | ############################
81 | # Eager mode
82 | # ++++++++++
83 |
84 | import numpy as np
85 |
86 | from onnx_array_api.npx import absolute, eager_onnx
87 |
88 |
89 | def l1_loss(x, y):
90 | """
91 | err is a type inheriting from
92 | :class:`EagerTensor `.
93 | It needs to be converted to numpy first before any display.
94 | """
95 | err = absolute(x - y).sum()
96 | print(f"l1_loss={err.numpy()}")
97 | return err
98 |
99 |
100 | def l2_loss(x, y):
101 | err = ((x - y) ** 2).sum()
102 | print(f"l2_loss={err.numpy()}")
103 | return err
104 |
105 |
106 | def myloss(x, y):
107 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
108 |
109 |
110 | #################################
111 | # Eager mode is enabled by function :func:`eager_onnx
112 | # `.
113 | # It intercepts all calls to `my_loss`. On the first call,
114 | # it replaces a numpy array by a tensor corresponding to the
115 | # selected runtime, here numpy as well through
116 | # :class:`EagerNumpyTensor
117 | # `.
118 | eager_myloss = eager_onnx(myloss)
119 |
120 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
121 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
122 |
123 | #################################
124 | # First execution and conversion to ONNX.
125 | # The wrapper caches many Onnx graphs corresponding to
126 | # simple operators (`+`, `-`, `/`, `*`, ...), reduce functions,
127 | # any other function from the API.
128 | # It reuses them if the input types and the number of dimensions are the same.
129 | # It creates a new one otherwise and keeps the old ones.
130 | res = eager_myloss(x, y)
131 | print(res)
132 |
133 | ################################
134 | # There is no ONNX graph to show. Every operation
135 | # is converted into small ONNX graphs.
136 |
--------------------------------------------------------------------------------
/_doc/examples/plot_onnx_diff.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | .. _l-onnx-diff-example:
4 |
5 | Compares the conversions of the same model with different options
6 | =================================================================
7 |
8 | The script compares two onnx models obtained with the same trained
9 | scikit-learn models but converted with different options.
10 |
11 | A model
12 | +++++++
13 | """
14 |
15 | from sklearn.mixture import GaussianMixture
16 | from sklearn.datasets import load_iris
17 | from sklearn.model_selection import train_test_split
18 | from skl2onnx import to_onnx
19 | from onnx_array_api.reference import compare_onnx_execution
20 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
21 |
22 |
23 | data = load_iris()
24 | X_train, X_test = train_test_split(data.data)
25 | model = GaussianMixture()
26 | model.fit(X_train)
27 |
28 | #################################
29 | # Conversion to onnx
30 | # ++++++++++++++++++
31 |
32 | onx = to_onnx(
33 | model, X_train[:1], options={id(model): {"score_samples": True}}, target_opset=12
34 | )
35 |
36 | print(onnx_simple_text_plot(onx))
37 |
38 | ##################################
39 | # Conversion to onnx without ReduceLogSumExp
40 | # ++++++++++++++++++++++++++++++++++++++++++
41 |
42 | onx2 = to_onnx(
43 | model,
44 | X_train[:1],
45 | options={id(model): {"score_samples": True}},
46 | black_op={"ReduceLogSumExp"},
47 | target_opset=12,
48 | )
49 |
50 | print(onnx_simple_text_plot(onx2))
51 |
52 |
53 | #############################################
54 | # Differences
55 | # +++++++++++
56 | #
57 | # Function :func:`onnx_array_api.reference.compare_onnx_execution`
58 | # compares the intermediate results of two onnx models. Then it finds
 59 | # the best alignment between the two models using an edit distance.
60 |
61 | res1, res2, align, dc = compare_onnx_execution(onx, onx2, verbose=1)
62 | print("------------")
63 | text = dc.to_str(res1, res2, align)
64 | print(text)
65 |
66 | ###############################
67 | # See :ref:`l-long-output-compare_onnx_execution` for a better view.
68 | # The display shows that ReduceSumSquare was replaced by Mul + ReduceSum,
69 | # and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add.
70 |
--------------------------------------------------------------------------------
/_doc/examples/plot_onnxruntime.py:
--------------------------------------------------------------------------------
1 | """
2 | First examples with onnxruntime
3 | ===============================
4 |
5 | Example :ref:`l-onnx-array-first-api-example` defines a custom
6 | loss and then executes it with class
7 | :class:`onnx.reference.ReferenceEvaluator`.
8 | Next example replaces it with :epkg:`onnxruntime`.
9 |
10 | Example
11 | +++++++
12 | """
13 |
14 | import numpy as np
15 |
16 | from onnx_array_api.npx import absolute, jit_onnx
17 | from onnx_array_api.ort.ort_tensors import JitOrtTensor, OrtTensor
18 |
19 |
20 | def l1_loss(x, y):
21 | return absolute(x - y).sum()
22 |
23 |
24 | def l2_loss(x, y):
25 | return ((x - y) ** 2).sum()
26 |
27 |
28 | def myloss(x, y):
29 | l1 = l1_loss(x[:, 0], y[:, 0])
30 | l2 = l2_loss(x[:, 1], y[:, 1])
31 | return l1 + l2
32 |
33 |
34 | ort_myloss = jit_onnx(myloss, JitOrtTensor, target_opsets={"": 17}, ir_version=8)
35 |
36 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
37 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
38 |
39 | xort = OrtTensor.from_array(x)
40 | yort = OrtTensor.from_array(y)
41 |
42 | res = ort_myloss(xort, yort)
43 | print(res.numpy())
44 |
45 | ###############################
46 | # Profiling
47 | # +++++++++
48 | from onnx_array_api.profiling import profile, profile2graph
49 |
50 | x = np.random.randn(10000, 2).astype(np.float32)
51 | y = np.random.randn(10000, 2).astype(np.float32)
52 | xort = OrtTensor.from_array(x)
53 | yort = OrtTensor.from_array(y)
54 |
55 |
56 | def loop_ort(n):
57 | for _ in range(n):
58 | ort_myloss(xort, yort)
59 |
60 |
61 | def loop_numpy(n):
62 | for _ in range(n):
63 | myloss(x, y)
64 |
65 |
66 | def loop(n=1000):
67 | loop_numpy(n)
68 | loop_ort(n)
69 |
70 |
71 | ps = profile(loop)[0]
72 | root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
73 | text = root.to_text()
74 | print(text)
75 |
76 | ##############################
77 | # Benchmark
78 | # +++++++++
79 |
80 | from pandas import DataFrame
81 | from tqdm import tqdm
82 |
83 | from onnx_array_api.ext_test_case import measure_time
84 |
85 | data = []
86 | for n in tqdm([1, 10, 100, 1000, 10000, 100000]):
87 | x = np.random.randn(n, 2).astype(np.float32)
88 | y = np.random.randn(n, 2).astype(np.float32)
89 |
90 | obs = measure_time(lambda x=x, y=y: myloss(x, y))
91 | obs["name"] = "numpy"
92 | obs["n"] = n
93 | data.append(obs)
94 |
95 | xort = OrtTensor.from_array(x)
96 | yort = OrtTensor.from_array(y)
97 | obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort))
98 | obs["name"] = "ort"
99 | obs["n"] = n
100 | data.append(obs)
101 |
102 | df = DataFrame(data)
103 | piv = df.pivot(index="n", columns="name", values="average")
104 | piv
105 |
106 | ############################
107 | # Plots
108 | # +++++
109 |
110 | import matplotlib.pyplot as plt
111 |
112 | fig, ax = plt.subplots(1, 2, figsize=(12, 4))
113 | piv.plot(
114 | title="Comparison between numpy and onnxruntime", logx=True, logy=True, ax=ax[0]
115 | )
116 | piv["ort/numpy"] = piv["ort"] / piv["numpy"]
117 | piv["ort/numpy"].plot(title="Ratio ort/numpy", logx=True, ax=ax[1])
118 | fig.savefig("plot_onnxruntime.png")
119 |
--------------------------------------------------------------------------------
/_doc/examples/plot_optimization.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | .. _l-onnx-array-onnxruntime-optimization:
4 |
5 | Optimization with onnxruntime
6 | =============================
7 |
8 | *onnxruntime* optimizes the onnx graph by default before running
 9 | the inference. It modifies, fuses or adds new operators.
10 | Some of them are standard onnx operators, some of them
11 | are implemented in onnxruntime (see `Supported Operators
12 | `_).
13 | This example looks into the differences of two models.
14 |
15 | Optimize a model with onnxruntime
16 | +++++++++++++++++++++++++++++++++
17 | """
18 |
19 | import os
20 | from pprint import pprint
21 | import numpy
22 | from pandas import DataFrame
23 | import matplotlib.pyplot as plt
24 | from onnx import load
25 | from onnx_array_api.ext_test_case import example_path
26 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
27 | from onnx_array_api.validation.diff import text_diff, html_diff
28 | from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
29 | from onnx_array_api.ext_test_case import measure_time
30 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model
31 |
32 |
33 | filename = example_path("data/small.onnx")
34 | optimized = filename + ".optimized.onnx"
35 |
36 | if not os.path.exists(optimized):
37 | ort_optimized_model(filename, output=optimized)
38 | print(optimized)
39 |
40 | #############################
41 | # Output comparison
42 | # +++++++++++++++++
43 |
44 | so = SessionOptions()
45 | so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
46 | img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)
47 |
48 | sess = InferenceSession(filename, so, providers=["CPUExecutionProvider"])
49 | sess_opt = InferenceSession(optimized, so, providers=["CPUExecutionProvider"])
50 | input_name = sess.get_inputs()[0].name
51 | out = sess.run(None, {input_name: img})[0]
52 | out_opt = sess_opt.run(None, {input_name: img})[0]
53 | if out.shape != out_opt.shape:
54 | print(f"ERROR shape are different {out.shape} != {out_opt.shape}")
55 | diff = numpy.abs(out - out_opt).max()
56 | print(f"Differences: {diff}")
57 |
58 | ####################################
59 | # Difference
60 | # ++++++++++
61 | #
62 | # Unoptimized model.
63 |
64 | with open(filename, "rb") as f:
65 | model = load(f)
66 | print("first model to text...")
67 | text1 = onnx_simple_text_plot(model, indent=False)
68 | print(text1)
69 |
70 | #####################################
71 | # Optimized model.
72 |
73 |
74 | with open(optimized, "rb") as f:
75 | model = load(f)
76 | print("second model to text...")
77 | text2 = onnx_simple_text_plot(model, indent=False)
78 | print(text2)
79 |
80 | ########################################
81 | # Differences
82 |
83 | print("differences...")
84 | print(text_diff(text1, text2))
85 |
86 | #####################################
87 | # HTML version.
88 |
89 | print("html differences...")
90 | output = html_diff(text1, text2)
91 | with open("diff_html.html", "w", encoding="utf-8") as f:
92 | f.write(output)
93 | print("done.")
94 |
95 | #####################################
96 | # Benchmark
97 | # +++++++++
98 |
99 | img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)
100 |
101 | t1 = measure_time(lambda: sess.run(None, {input_name: img}), repeat=25, number=25)
102 | t1["name"] = "original"
103 | print("Original model")
104 | pprint(t1)
105 |
106 | t2 = measure_time(lambda: sess_opt.run(None, {input_name: img}), repeat=25, number=25)
107 | t2["name"] = "optimized"
108 | print("Optimized")
109 | pprint(t2)
110 |
111 |
112 | ############################
113 | # Plots
114 | # +++++
115 |
116 |
117 | fig, ax = plt.subplots(1, 1, figsize=(12, 4))
118 |
119 | df = DataFrame([t1, t2]).set_index("name")
120 | df
121 |
122 | #######################################
123 | # And the graph is:
124 |
125 | ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6)
126 | ax.set_title("Measure performance of optimized model\nlower is better")
127 | plt.grid()
128 | fig.savefig("plot_optimization.png")
129 |
--------------------------------------------------------------------------------
/_doc/examples/plot_profiling.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | .. _l-onnx-array-onnxruntime-profiling:
4 |
5 | Profiling with onnxruntime
6 | ==========================
7 |
8 | *onnxruntime* optimizes the onnx graph by default before running
9 | the inference. It modifies, fuses or adds new operators.
10 | Some of them are standard onnx operators, some of them
11 | are implemented in onnxruntime (see `Supported Operators
12 | <https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md>`_).
13 | This example profiles the two models.
14 |
15 | Optimize a model with onnxruntime
16 | +++++++++++++++++++++++++++++++++
17 | """
18 |
19 | import os
20 | import numpy
21 | import matplotlib.pyplot as plt
22 | from onnxruntime import get_available_providers
23 | from onnx_array_api.ext_test_case import example_path
24 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model
25 | from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
26 | from onnx_array_api.plotting.stat_plot import plot_ort_profile
27 |
28 |
29 | suffix = ""
30 | filename = example_path(f"data/small{suffix}.onnx")
31 | optimized = filename + ".optimized.onnx"
32 | print(f"model={filename!r}")
33 |
34 | if not os.path.exists(optimized):
35 | ort_optimized_model(filename, output=optimized)
36 | print(f"optimized={optimized!r}")
37 |
38 | #############################
39 | # .. _l-example-ort-profiling:
40 | #
41 | # Profiling
42 | # +++++++++
43 |
44 | feeds = {"input": numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)}
45 | prof_base = ort_profile(
46 | filename,
47 | feeds,
48 | repeat=6,
49 | disable_optimization=True,
50 | providers=["CPUExecutionProvider"],
51 | )
52 | prof_base.to_excel(f"prof_base{suffix}.xlsx", index=False)
53 | prof_base
54 |
55 | #######################################
56 | # And the optimized model.
57 |
58 | prof_opti = ort_profile(
59 | optimized,
60 | feeds,
61 | repeat=6,
62 | disable_optimization=True,
63 | providers=["CPUExecutionProvider"],
64 | )
65 | prof_opti.to_excel(f"prof_opti{suffix}.xlsx", index=False)
66 | prof_opti
67 |
68 | #######################################
69 | # And the graph is:
70 |
71 | unique_op = set(prof_base["args_op_name"])
72 | fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
73 | plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
74 | plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized")
75 | fig.tight_layout()
76 | fig.savefig(f"plot_profiling{suffix}.png")
77 |
78 | ##################################################
79 | # Merging profiles
80 | # ++++++++++++++++
81 | #
82 | # Let's try to compare both profiles assuming every iteration
83 | # processes the same image and the input and output size are the
84 | # same at every iteration.
85 |
86 | merge, gr = merge_ort_profile(prof_base, prof_opti)
87 | merge.to_excel(f"plot_profiling_merged{suffix}.xlsx", index=False)
88 | merge
89 |
90 | #####################################################
91 | # More detailed
92 |
93 | gr.to_excel(f"plot_profiling_merged_details{suffix}.xlsx", index=False)
94 | gr
95 |
96 | ################################
97 | # Final plot
98 | # ++++++++++
99 |
100 | # let's filter out insignificant operators.
101 | grmax = gr["durbase"] + gr["duropti"]
102 | total = grmax.sum()
103 | grmax /= total
104 | gr = gr[grmax >= 0.01]
105 |
106 |
107 | fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True)
108 | gr[["durbase", "duropti"]].plot.barh(ax=ax[0])
109 | ax[0].set_title("Side by side duration")
110 | gr = gr.copy()
111 | gr[["countbase", "countopti"]].plot.barh(ax=ax[1])
112 | ax[1].set_title("Side by side count")
113 | fig.tight_layout()
114 | fig.savefig(f"plot_profiling_side_by_side{suffix}.png")
115 |
116 |
117 | ########################################
118 | # On CUDA
119 | # +++++++
120 |
121 |
122 | if "CUDAExecutionProvider" in get_available_providers():
123 | print("Profiling on CUDA")
124 | prof_base = ort_profile(
125 | filename,
126 | feeds,
127 | repeat=6,
128 | disable_optimization=True,
129 | providers=["CUDAExecutionProvider"],
130 | )
131 | prof_base.to_excel(f"prof_cuda_base{suffix}.xlsx", index=False)
132 |
133 | prof_opti = ort_profile(
134 | optimized,
135 | feeds,
136 | repeat=6,
137 | disable_optimization=True,
138 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
139 | )
140 | prof_opti.to_excel(f"prof_cuda_opti{suffix}.xlsx", index=False)
141 |
142 | unique_op = set(prof_base["args_op_name"])
143 | fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
144 | plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
145 | plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized")
146 | fig.tight_layout()
147 | fig.savefig(f"plot_profiling_cuda{suffix}.png")
148 |
149 | merge, gr = merge_ort_profile(prof_base, prof_opti)
150 | merge.to_excel(f"plot_profiling_merged{suffix}.xlsx", index=False)
151 | gr.to_excel(f"plot_profiling_merged_details{suffix}.xlsx", index=False)
152 |
153 | grmax = gr["durbase"] + gr["duropti"]
154 | total = grmax.sum()
155 | grmax /= total
156 | gr = gr[grmax >= 0.01]
157 |
158 | fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True)
159 | gr[["durbase", "duropti"]].plot.barh(ax=ax[0])
160 | ax[0].set_title("Side by side duration")
161 | gr = gr.copy()
162 | gr[["countbase", "countopti"]].plot.barh(ax=ax[1])
163 | ax[1].set_title("Side by side count")
164 | fig.tight_layout()
165 | fig.savefig(f"plot_profiling_side_by_side_cuda{suffix}.png")
166 |
167 | else:
168 | print(f"CUDA not available in {get_available_providers()}.")
169 | fig, ax = None, None
170 |
171 | ax
172 |
--------------------------------------------------------------------------------
/_doc/index.rst:
--------------------------------------------------------------------------------
1 |
2 | onnx-array-api: APIs to create ONNX Graphs
3 | ==========================================
4 |
5 | .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api
6 | :target: https://dev.azure.com/xavierdupre3/onnx-array-api/
7 |
8 | .. image:: https://badge.fury.io/py/onnx-array-api.svg
9 | :target: http://badge.fury.io/py/onnx-array-api
10 |
11 | .. image:: http://img.shields.io/github/issues/sdpython/onnx-array-api.png
12 | :alt: GitHub Issues
13 | :target: https://github.com/sdpython/onnx-array-api/issues
14 |
15 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg
16 | :alt: MIT License
17 | :target: https://opensource.org/license/MIT/
18 |
19 | .. image:: https://img.shields.io/github/repo-size/sdpython/onnx-array-api
20 | :target: https://github.com/sdpython/onnx-array-api/
21 | :alt: size
22 |
23 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
24 | :target: https://github.com/psf/black
25 |
26 | .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J
27 | :target: https://codecov.io/gh/sdpython/onnx-array-api
28 |
29 | **onnx-array-api** implements APIs to create custom ONNX graphs.
30 | The objective is to speed up the implementation of converter libraries.
31 |
32 | .. toctree::
33 | :maxdepth: 1
34 | :caption: Contents
35 |
36 | tutorial/index
37 | api/index
38 | tech/index
39 | command_lines
40 | auto_examples/index
41 |
42 | .. toctree::
43 | :maxdepth: 1
44 | :caption: More
45 |
46 | CHANGELOGS
47 | license
48 | long_outputs
49 |
50 | Sources available on
51 | `github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_.
52 |
53 | GraphBuilder API
54 | ++++++++++++++++
55 |
56 | Almost every converting library (converting a machine learned model to ONNX) implements
57 | its own graph builder and customizes it for its needs.
58 | It handles some frequent tasks such as giving names to intermediate
59 | results, loading, saving onnx models. It can be used as well to extend an existing graph.
60 | See :ref:`l-graph-api`.
61 |
62 | .. runpython::
63 | :showcode:
64 |
65 | import numpy as np
66 | from onnx_array_api.graph_api import GraphBuilder
67 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
68 |
69 | g = GraphBuilder()
70 | g.make_tensor_input("X", np.float32, (None, None))
71 | g.make_tensor_input("Y", np.float32, (None, None))
72 | r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
73 | # it ensures the name is unique
74 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
75 | # converts the array to a tensor
76 | r2 = g.make_node("Pow", [r1, init])
77 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
78 | # the user wants to choose the name
79 | g.make_tensor_output("Z", np.float32, (None, None))
80 |
81 | onx = g.to_onnx() # final conversion to onnx
82 |
83 | print(onnx_simple_text_plot(onx))
84 |
85 | Light API
86 | +++++++++
87 |
88 | The syntax is inspired from the
89 | `Reverse Polish Notation <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_.
90 | This kind of API is easy to use to build new graphs,
91 | less easy to extend an existing graph. See :ref:`l-light-api`.
92 |
93 | .. runpython::
94 | :showcode:
95 |
96 | import numpy as np
97 | from onnx_array_api.light_api import start
98 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
99 |
100 | model = (
101 | start()
102 | .vin("X")
103 | .vin("Y")
104 | .bring("X", "Y")
105 | .Sub()
106 | .rename("dxy")
107 | .cst(np.array([2], dtype=np.int64), "two")
108 | .bring("dxy", "two")
109 | .Pow()
110 | .ReduceSum()
111 | .rename("Z")
112 | .vout()
113 | .to_onnx()
114 | )
115 |
116 | print(onnx_simple_text_plot(model))
117 |
118 | Numpy API
119 | +++++++++
120 |
121 | Writing ONNX graphs requires to know ONNX syntax unless
122 | it is possible to reuse an existing syntax such as :epkg:`numpy`.
123 | This is what this API is doing.
124 | This kind of API is easy to use to build new graphs,
125 | almost impossible to use to extend new graphs as it usually requires
126 | to know onnx for that. See :ref:`l-numpy-api-onnx`.
127 |
128 | .. runpython::
129 | :showcode:
130 | :warningout: DeprecationWarning, FutureWarning
131 | :process:
132 |
133 | import numpy as np # A
134 | from onnx_array_api.npx import absolute, jit_onnx
135 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
136 |
137 | def l1_loss(x, y):
138 | return absolute(x - y).sum()
139 |
140 |
141 | def l2_loss(x, y):
142 | return ((x - y) ** 2).sum()
143 |
144 |
145 | def myloss(x, y):
146 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
147 |
148 |
149 | jitted_myloss = jit_onnx(myloss)
150 |
151 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
152 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
153 | res = jitted_myloss(x, y)
154 | print(res)
155 |
156 | print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
157 |
158 | .. gdot::
159 | :script: DOT-SECTION
160 | :process:
161 |
162 | # index
163 | import numpy as np
164 | from onnx_array_api.npx import absolute, jit_onnx
165 | from onnx_array_api.plotting.dot_plot import to_dot
166 |
167 |
168 | def l1_loss(x, y):
169 | return absolute(x - y).sum()
170 |
171 |
172 | def l2_loss(x, y):
173 | return ((x - y) ** 2).sum()
174 |
175 |
176 | def myloss(x, y):
177 | return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
178 |
179 |
180 | jitted_myloss = jit_onnx(myloss)
181 |
182 | x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
183 | y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
184 | res = jitted_myloss(x, y)
185 | print(to_dot(jitted_myloss.get_onnx()))
186 |
187 | Older versions
188 | ++++++++++++++
189 |
190 | * `0.3.0 <../v0.3.0/index.html>`_
191 | * `0.2.0 <../v0.2.0/index.html>`_
192 | * `0.1.3 <../v0.1.3/index.html>`_
193 | * `0.1.2 <../v0.1.2/index.html>`_
194 |
--------------------------------------------------------------------------------
/_doc/license.rst:
--------------------------------------------------------------------------------
1 | License
2 | =======
3 |
4 | .. literalinclude:: LICENSE.txt
5 | :language: none
6 |
--------------------------------------------------------------------------------
/_doc/long_outputs.rst:
--------------------------------------------------------------------------------
1 | :hide-toc:
2 |
3 | ==========================
4 | Long outputs uneasy to see
5 | ==========================
6 |
7 | onnx
8 | ====
9 |
10 | .. _l-long-output-compare_onnx_execution:
11 |
12 | onnx_array_api.reference.compare_onnx_execution
13 | +++++++++++++++++++++++++++++++++++++++++++++++
14 |
15 | From example :ref:`l-onnx-diff-example` for function
16 | :func:`onnx_array_api.reference.compare_onnx_execution`.
17 | See also `raw rendering `_.
18 |
19 | ::
20 |
21 | 1 = | INITIA float64 1 HAAA Ad_Addcst | INITIA float64 1 HAAA Ad_Addcst
22 | 2 = | INITIA float64 4x4 ADZF Ge_Gemmcst | INITIA float64 4x4 ADZF Ge_Gemmcst
23 | 3 = | INITIA float64 4 USEA Ge_Gemmcst1 | INITIA float64 4 USEA Ge_Gemmcst1
24 | 4 = | INITIA float64 1 AAAA Mu_Mulcst | INITIA float64 1 AAAA Mu_Mulcst
25 | 5 = | INITIA float64 1 DAAA Ad_Addcst1 | INITIA float64 1 DAAA Ad_Addcst1
26 | 6 = | INITIA float64 1 AAAA Ad_Addcst2 | INITIA float64 1 AAAA Ad_Addcst2
27 | 7 = | INPUT float64 1x4 AAAA X | INPUT float64 1x4 AAAA X
28 | 8 = | RESULT float64 1x4 UTFC Gemm Ge_Y0 | RESULT float64 1x4 UTFC Gemm Ge_Y0
29 | 9 + | | RESULT float64 1x4 TIEG Mul Mu_C01
30 | 10 ~ | RESULT float64 1x1 NAAA ReduceSumS Re_reduced0 | RESULT float64 1x1 NAAA ReduceSum Re_reduced0
31 | 11 = | RESULT float64 1x1 NAAA Concat Co_concat_re | RESULT float64 1x1 NAAA Concat Co_concat_re
32 | 12 = | RESULT float64 1x1 UAAA Add Ad_C02 | RESULT float64 1x1 UAAA Add Ad_C02
33 | 13 = | RESULT float64 1x1 DAAA Mul Mu_C0 | RESULT float64 1x1 DAAA Mul Mu_C0
34 | 14 = | RESULT float64 1x1 GAAA Add Ad_C01 | RESULT float64 1x1 GAAA Add Ad_C01
35 | 15 = | RESULT float64 1x1 GAAA Add Ad_C0 | RESULT float64 1x1 GAAA Add Ad_C0
36 | 16 = | RESULT int64 1x1 AAAA ArgMax label | RESULT int64 1x1 AAAA ArgMax label
37 | 17 + | | RESULT float64 1x1 GAAA ReduceMax Re_reduced03
38 | 18 + | | RESULT float64 1x1 AAAA Sub Su_C01
39 | 19 + | | RESULT float64 1x1 BAAA Exp Ex_output0
40 | 20 + | | RESULT float64 1x1 BAAA ReduceSum Re_reduced02
41 | 21 + | | RESULT float64 1x1 AAAA Log Lo_output0
42 | 22 ~ | RESULT float64 1x1 GAAA ReduceLogS score_sample | RESULT float64 1x1 GAAA Add score_sample
43 | 23 = | RESULT float64 1x1 AAAA Sub Su_C0 | RESULT float64 1x1 AAAA Sub Su_C0
44 | 24 = | RESULT float64 1x1 BAAA Exp probabilitie | RESULT float64 1x1 BAAA Exp probabilitie
45 | 25 = | OUTPUT int64 1x1 AAAA label | OUTPUT int64 1x1 AAAA label
46 | 26 = | OUTPUT float64 1x1 BAAA probabilitie | OUTPUT float64 1x1 BAAA probabilitie
47 | 27 = | OUTPUT float64 1x1 GAAA score_sample | OUTPUT float64 1x1 GAAA score_sample
48 |
--------------------------------------------------------------------------------
/_doc/run_coverage.sh:
--------------------------------------------------------------------------------
1 | python3 -m pytest --cov --cov-report html:_doc/_static/cov_html _unittests
2 |
--------------------------------------------------------------------------------
/_doc/tech/aapi.rst:
--------------------------------------------------------------------------------
1 | .. _l-array-api-painpoint:
2 |
3 | Difficulty to implement an Array API for ONNX
4 | =============================================
5 |
6 | Implementing the full array API is not always easy with :epkg:`onnx`.
7 | Python is not strongly typed and many different types can be used
8 | to represent a value. Argument *axis* can be an integer or a tuple
9 | (see `min from Array API
10 | <https://data-apis.org/array-api/latest/API_specification/generated/array_api.min.html>`_
12 | for example). On the other side, `ReduceMin from ONNX
13 | <https://onnx.ai/onnx/operators/onnx__ReduceMin.html>`_
14 | considers it as a tensor.
15 |
16 | Performance
17 | +++++++++++
18 |
19 | The Array API must work in eager mode and for every operation,
20 | it generates an ONNX graph and executes it with a specific
21 | backend. It can be :epkg:`numpy`, :epkg:`onnxruntime` or any other
22 | backend. The generation of every graph takes a significant amount of time.
23 | It must be avoided. These graphs are cached. But a graph can be reused
24 | only if the inputs - by ONNX semantic - change. If a parameter change,
25 | a new graph must be cached. Method :meth:`JitEager.make_key
26 | `
27 | generates a unique key based on the input it receives,
28 | the signature of the function to call. If the key is the same,
29 | a cached onnx can be reused on the second call.
30 |
31 | However, eager mode - use a small single onnx graph for every operation -
32 | is not the most efficient one. At the same time, the design must allow
33 | to merge every needed operation into a bigger graph.
34 | Bigger graphs can be more easily optimized by the backend.
35 |
36 | Input vs parameter
37 | ++++++++++++++++++
38 |
39 | An input is a tensor or array, a parameter is any other type.
40 | Following onnx semantic, an input is variable, a parameter is frozen
41 | and cannot be changed. It is a constant. A good design would be
42 | to consider any named input (`**kwargs`) a parameter and
43 | any input (`*args`) a tensor. But the Array API does not follow that
44 | design. Function `astype
45 | <https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html>`_
47 | takes two inputs. Operator `Cast
48 | <https://onnx.ai/onnx/operators/onnx__Cast.html>`_
49 | takes one input and a frozen parameter `to`.
50 | And python allows `astype(x, dtype)` as well as `astype(x, dtype=dtype)`
51 | unless the signature enforces one call over another type.
52 | There may be ambiguities from time to time.
53 | Besides, from the onnx point of view, argument dtype should be named.
54 |
55 | Tensor type
56 | +++++++++++
57 |
58 | An :class:`EagerTensor `
59 | must be used to represent any tensor.
60 | This class defines the backend to use as well.
61 | :class:`EagerNumpyTensor
62 | `
63 | for :epkg:`numpy`, :class:`EagerOrtTensor
64 | `
65 | for :epkg:`onnxruntime`. Since the Array API is new,
66 | existing packages do not fully support the API if they support it
67 | (:epkg:`scikit-learn`). Some numpy array may still be used.
68 |
69 | Inplace
70 | +++++++
71 |
72 | ONNX has no notion of inplace computation. Therefore something
73 | like `coefs[:, 1] = 1` is not valid unless some code is written
74 | to create another tensor. The current design supports some of these
75 | by storing every call to `__setitem__`. The user sees `coefs`
76 | but the framework sees that `coefs` holds a reference to another
77 | tensor. That's the one the framework needs to use. However, since
78 | `__setitem__` is used for efficiency, it becomes less than efficient
79 | with this design and should be avoided. This assumption may be true
80 | when the backend is relying on CPU but not on GPU.
81 | A function such as `empty
82 | <https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html>`_ should be avoided as it
84 | has to be followed by calls to `__setitem__`.
85 |
86 | Eager or compilation
87 | ++++++++++++++++++++
88 |
89 | Eager mode is what the Array API implies.
90 | Every function is converted into an ONNX graph based
91 | on its inputs without any knowledge of how these inputs
92 | were obtained. This graph is then executed before going
93 | to the next call of a function from the API.
94 | The conversion of a machine learned model
95 | into ONNX implies the gathering of all these operations
96 | into a graph. It means using a mode that records all the function
97 | calls to compile every tiny onnx graph into a unique graph.
98 |
99 | Iterators and Reduction
100 | +++++++++++++++++++++++
101 |
102 | An efficient implementation of function
103 | :func:`numpy.any` or :func:`numpy.all` returns
104 | as soon as the result is known. :func:`numpy.all` is
105 | false whenever the first false condition is met.
106 | Same goes for :func:`numpy.any` which is true
107 | whenever the first true condition is met.
108 | There is no such operator in ONNX (<= 20) because
109 | it is unlikely to appear in a machine learned model.
110 | However, it is highly used when two results are
111 | compared in unit tests. The ONNX implementation is
112 | not efficient due to that reason but it only impacts
113 | the unit tests.
114 |
115 | Types
116 | +++++
117 |
118 | :epkg:`onnx` supports more types than :epkg:`numpy` does.
119 | It is not always easy to deal with bfloat16 or float8 types.
120 |
--------------------------------------------------------------------------------
/_doc/tech/index.rst:
--------------------------------------------------------------------------------
1 | Technical Details
2 | =================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | aapi
--------------------------------------------------------------------------------
/_doc/tutorial/benchmarks.rst:
--------------------------------------------------------------------------------
1 | ==========
2 | Benchmarks
3 | ==========
4 |
5 | A list of benchmarks used to improve the performance of
6 | ONNX components (onnx, onnxruntime, onnx-array-api, ...).
7 |
8 | .. toctree::
9 |
10 | ../auto_examples/plot_benchmark_rf
11 |
--------------------------------------------------------------------------------
/_doc/tutorial/graph_api.rst:
--------------------------------------------------------------------------------
1 | .. _l-graph-api:
2 |
3 | =================================
4 | GraphBuilder: common API for ONNX
5 | =================================
6 |
7 | This is a very common way to build ONNX graph. There are some
8 | annoying steps while building an ONNX graph. The first one is to
9 | give unique names to every intermediate result in the graph. The second
10 | is the conversion from numpy arrays to onnx tensors. A *graph builder*,
11 | here implemented by class
12 | :class:`GraphBuilder `
13 | usually makes these two frequent tasks easier.
14 |
15 | .. runpython::
16 | :showcode:
17 |
18 | import numpy as np
19 | from onnx_array_api.graph_api import GraphBuilder
20 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
21 |
22 | g = GraphBuilder()
23 | g.make_tensor_input("X", np.float32, (None, None))
24 | g.make_tensor_input("Y", np.float32, (None, None))
25 | r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
26 | # it ensures the name is unique
27 | init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
28 | # converts the array to a tensor
29 | r2 = g.make_node("Pow", [r1, init])
30 | g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
31 | # the user wants to choose the name
32 | g.make_tensor_output("Z", np.float32, (None, None))
33 |
34 | onx = g.to_onnx() # final conversion to onnx
35 |
36 | print(onnx_simple_text_plot(onx))
37 |
38 | A simpler version of the same code to produce the same graph.
39 |
40 | .. runpython::
41 | :showcode:
42 |
43 | import numpy as np
44 | from onnx_array_api.graph_api import GraphBuilder
45 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
46 |
47 | g = GraphBuilder()
48 | g.make_tensor_input("X", np.float32, (None, None))
49 | g.make_tensor_input("Y", np.float32, (None, None))
50 | r1 = g.op.Sub("X", "Y") # the method name indicates which operator to use,
51 | # this can be used when there is no ambiguity about the
52 | # number of outputs
53 | r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
54 | g.op.ReduceSum(r2, outputs=["Z"]) # the class still wants the user to specify the name
55 | g.make_tensor_output("Z", np.float32, (None, None))
56 |
57 | onx = g.to_onnx()
58 |
59 | print(onnx_simple_text_plot(onx))
60 |
--------------------------------------------------------------------------------
/_doc/tutorial/index.rst:
--------------------------------------------------------------------------------
1 |
2 | ========
3 | Tutorial
4 | ========
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | onnx_api
10 | graph_api
11 | light_api
12 | numpy_api
13 | tools
14 | benchmarks
15 |
--------------------------------------------------------------------------------
/_doc/tutorial/light_api.rst:
--------------------------------------------------------------------------------
1 | .. _l-light-api:
2 |
3 | ==========================================
4 | Light API for ONNX: everything in one line
5 | ==========================================
6 |
7 | It is inspired from the :epkg:`reverse Polish notation`.
8 | Following example implements the euclidean distance.
9 | This API tries to keep it simple and intuitive for short functions.
10 |
11 | .. runpython::
12 | :showcode:
13 |
14 | import numpy as np
15 | from onnx_array_api.light_api import start
16 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
17 |
18 | model = (
19 | start()
20 | .vin("X")
21 | .vin("Y")
22 | .bring("X", "Y")
23 | .Sub()
24 | .rename("dxy")
25 | .cst(np.array([2], dtype=np.int64), "two")
26 | .bring("dxy", "two")
27 | .Pow()
28 | .ReduceSum()
29 | .rename("Z")
30 | .vout()
31 | .to_onnx()
32 | )
33 |
34 | print(onnx_simple_text_plot(model))
35 |
36 | There are two kinds of methods, the graph methods, playing with the graph structure,
37 | and the methods for operators starting with an upper letter.
38 |
39 | Graph methods
40 | =============
41 |
42 | Any graph must start with function :func:`start `.
43 | It is usually followed by `vin` to add an input.
44 |
45 | * bring (:meth:`Var.bring `,
46 | :meth:`Vars.bring `):
47 |   assembles multiple results into a set before calling an operator taking multiple inputs,
48 | * cst (:meth:`Var.cst `,
49 | :meth:`Vars.cst `):
50 | adds a constant tensor to the graph,
51 | * rename (:meth:`Var.rename `,
52 | :meth:`Vars.rename `):
53 |   renames or gives a name to a variable in order to call it later.
54 | * vin (:meth:`Var.vin `,
55 | :meth:`Vars.vin `):
56 | adds an input to the graph,
57 | * vout (:meth:`Var.vout `,
58 | :meth:`Vars.vout `):
59 | declares an existing result as an output.
60 |
61 | These methods are implemented in class :class:`onnx_array_api.light_api.var.BaseVar`
62 |
63 | Operator methods
64 | ================
65 |
66 | They are described in :epkg:`ONNX Operators` and redefined in a stable API
67 | so that the definition should not change depending on this opset.
68 | :class:`onnx_array_api.light_api.Var` defines all operators taking only one input.
69 | :class:`onnx_array_api.light_api.Vars` defines all other operators.
70 |
71 | Numpy methods
72 | =============
73 |
74 | Numpy users expect methods such as `reshape`, property `shape` or
75 | operator `+` to be available as well and that is the case. They are
76 | defined in class :class:`Var ` or
77 | :class:`Vars ` depending on the number of
78 | inputs they require. Their name starts with a lower letter.
79 |
80 | Other domains
81 | =============
82 |
83 | The following example uses operator *Normalizer* from domain
84 | *ai.onnx.ml*. The operator name is called with the syntax
85 | `<domain>.<operator name>`. The domain may have dots in its name
86 | but it must follow the python definition of a variable.
87 | The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`.
88 |
89 | .. runpython::
90 | :showcode:
91 |
92 | import numpy as np
93 | from onnx_array_api.light_api import start
94 | from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
95 |
96 | model = (
97 | start(opset=19, opsets={"ai.onnx.ml": 3})
98 | .vin("X")
99 | .reshape((-1, 1))
100 | .rename("USE")
101 | .ai.onnx.ml.Normalizer(norm="MAX")
102 | .rename("Y")
103 | .vout()
104 | .to_onnx()
105 | )
106 |
107 | print(onnx_simple_text_plot(model))
108 |
--------------------------------------------------------------------------------
/_doc/tutorial/numpy_api.rst:
--------------------------------------------------------------------------------
1 | .. _l-numpy-api-onnx:
2 |
3 | ==================
4 | Numpy API for ONNX
5 | ==================
6 |
7 | Many users have difficulties to write onnx graphs.
8 | Many packages try to simplify it either by implementing
9 | their own api very close to onnx operators
10 | (`sklearn-onnx <https://onnx.ai/sklearn-onnx/>`_,
11 | `tf2onnx <https://github.com/onnx/tensorflow-onnx>`_,
12 | `spox <https://github.com/Quantco/spox>`_,
13 | `onnx-script <https://github.com/microsoft/onnxscript>`_).
14 | This contribution tries a different approach by implementing
15 | a numpy API for ONNX. It does not cover everything numpy
16 | or ONNX can do but it can easily be used to define
17 | loss functions for example without knowing too much about ONNX.
18 |
19 | .. note:: control flow
20 |
21 | The first version (onnx==1.15) does not support control flow yet (test and loops).
22 | There is no easy syntax for that yet and the main challenge is to deal with local context.
23 |
24 | You read :ref:`l-array-api-painpoint` as well.
25 |
26 | Overview
27 | ========
28 |
29 | .. toctree::
30 |
31 | ../auto_examples/plot_first_example
32 | ../auto_examples/plot_onnxruntime
33 |
--------------------------------------------------------------------------------
/_doc/tutorial/tools.rst:
--------------------------------------------------------------------------------
1 | =====
2 | Tools
3 | =====
4 |
5 | Some useful tools.
6 |
7 | Text representation
8 | ===================
9 |
10 | Plotting a graph is great but difficult to read when
11 | the graph is big and it is slow.
12 | :func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot`
13 | prints out a text representation.
14 |
15 | Differences between two models
16 | ==============================
17 |
18 | How to understand the differences between two models
19 | assuming they are producing the same outputs?
20 | Example :ref:`l-onnx-diff-example` shows how to do it.
21 |
--------------------------------------------------------------------------------
/_unittests/onnx-numpy-skips.txt:
--------------------------------------------------------------------------------
1 | # API failures
2 | # see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3 | # uses __setitem__
4 | array_api_tests/test_creation_functions.py::test_arange
5 | array_api_tests/test_creation_functions.py::test_asarray_arrays
6 | array_api_tests/test_creation_functions.py::test_empty
7 | array_api_tests/test_creation_functions.py::test_empty_like
8 | array_api_tests/test_creation_functions.py::test_eye
9 | array_api_tests/test_creation_functions.py::test_full
10 | array_api_tests/test_creation_functions.py::test_full_like
11 | array_api_tests/test_creation_functions.py::test_ones
12 | array_api_tests/test_creation_functions.py::test_ones_like
13 | array_api_tests/test_creation_functions.py::test_zeros
14 | array_api_tests/test_creation_functions.py::test_zeros_like
15 | # fails due to a precision issue
16 | array_api_tests/test_creation_functions.py::test_linspace
17 | array_api_tests/test_creation_functions.py::test_meshgrid
18 |
--------------------------------------------------------------------------------
/_unittests/onnx-ort-skips.txt:
--------------------------------------------------------------------------------
1 | # Not implemented by onnxruntime
2 | array_api_tests/test_creation_functions.py::test_arange
3 | array_api_tests/test_creation_functions.py::test_asarray_scalars
4 | array_api_tests/test_creation_functions.py::test_asarray_arrays
5 | array_api_tests/test_creation_functions.py::test_empty
6 | array_api_tests/test_creation_functions.py::test_empty_like
7 | array_api_tests/test_creation_functions.py::test_eye
8 | array_api_tests/test_creation_functions.py::test_full
9 | array_api_tests/test_creation_functions.py::test_full_like
10 | array_api_tests/test_creation_functions.py::test_linspace
11 | array_api_tests/test_creation_functions.py::test_meshgrid
12 | array_api_tests/test_creation_functions.py::test_ones
13 | array_api_tests/test_creation_functions.py::test_ones_like
14 | array_api_tests/test_creation_functions.py::test_zeros
15 | array_api_tests/test_creation_functions.py::test_zeros_like
16 |
--------------------------------------------------------------------------------
/_unittests/test_array_api.sh:
--------------------------------------------------------------------------------
1 | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2 | pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like || exit 1
3 | # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
4 | pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
5 |
--------------------------------------------------------------------------------
/_unittests/ut_array_api/test_array_apis.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from inspect import isfunction, ismethod
3 | import numpy as np
4 | from onnx_array_api.ext_test_case import ExtTestCase
5 | from onnx_array_api.array_api import onnx_numpy as xpn
6 | from onnx_array_api.array_api import onnx_ort as xpo
7 |
8 | # from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9 | # from onnx_array_api.ort.ort_tensors import EagerOrtTensor
10 |
11 |
class TestArraysApis(ExtTestCase):
    """Checks the numpy-backed and onnxruntime-backed array-api namespaces
    agree with each other and with numpy's ``finfo``/``iinfo`` objects."""

    def test_zeros_numpy_1(self):
        # zeros() with a plain integer size defaults to float64, like numpy.
        c = xpn.zeros(1)
        d = c.numpy()
        self.assertEqualArray(np.array([0], dtype=np.float64), d)

    def test_zeros_ort_1(self):
        c = xpo.zeros(1)
        d = c.numpy()
        self.assertEqualArray(np.array([0], dtype=np.float64), d)

    def _check_infos(self, fi1, fi2, fi3):
        """Compares every public attribute of three dtype-info objects.

        :param fi1: numpy info object (``np.finfo`` / ``np.iinfo``), reference
        :param fi2: info object from the onnx_numpy array-api
        :param fi3: info object from the onnx_ort array-api
        :raises AssertionError: if any attribute disagrees between
            fi1/fi2 (tag ``12``) or fi2/fi3 (tag ``23``)
        """
        for n in dir(fi1):
            if n.startswith("__"):
                continue
            if n in {"machar"}:
                # deprecated numpy attribute, skipped on purpose
                continue
            v1 = getattr(fi1, n)
            with self.subTest(att=n):
                v2 = getattr(fi2, n)
                v3 = getattr(fi3, n)
                if isfunction(v1) or ismethod(v1):
                    try:
                        v1 = v1()
                    except TypeError:
                        # not callable without arguments, nothing to compare
                        continue
                    v2 = v2()
                    v3 = v3()
                if v1 != v2:
                    raise AssertionError(
                        f"12: info disagree on name {n!r}: {v1} != {v2}, "
                        f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
                        f"ismethod={ismethod(v1)}."
                    )
                if v2 != v3:
                    # bug fix: the message previously reported type(v1) and
                    # type(v2) although the comparison is between v2 and v3
                    raise AssertionError(
                        f"23: info disagree on name {n!r}: {v2} != {v3}, "
                        f"type(v2)={type(v2)}, type(v3)={type(v3)}, "
                        f"ismethod={ismethod(v2)}."
                    )

    def test_ffinfo(self):
        dt = np.float32
        fi1 = np.finfo(dt)
        fi2 = xpn.finfo(dt)
        fi3 = xpo.finfo(dt)
        self.assertEqual(fi2.dtype, fi3.dtype)
        # the array-api namespaces wrap dtypes in their own class
        self.assertNotEqual(fi1.dtype.__class__, fi2.dtype.__class__)
        self.assertEqual(fi1.min, fi2.min)
        self.assertEqual(fi1.smallest_normal, fi2.smallest_normal)
        self._check_infos(fi1, fi2, fi3)

    def test_iiinfo(self):
        dt = np.int64
        fi1 = np.iinfo(dt)
        fi2 = xpn.iinfo(dt)
        fi3 = xpo.iinfo(dt)
        self.assertEqual(fi2.dtype, fi3.dtype)
        self.assertNotEqual(fi1.dtype.__class__, fi2.dtype.__class__)
        self.assertEqual(fi1.min, fi2.min)
        self._check_infos(fi1, fi2, fi3)


if __name__ == "__main__":
    unittest.main(verbosity=2)
113 |
--------------------------------------------------------------------------------
/_unittests/ut_array_api/test_onnx_ort.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from onnx_array_api.ext_test_case import ExtTestCase
4 | from onnx_array_api.array_api import onnx_ort as xp
5 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
6 | from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor
7 |
8 |
class TestOnnxOrt(ExtTestCase):
    """Runs a few eager computations through the onnx_ort array-api."""

    def test_abs(self):
        # zeros() takes the shape as a tensor of int64.
        shape = EagerTensor(np.array([4, 5], dtype=np.int64))
        zeros = xp.zeros(shape, dtype=xp.int64)
        zeros_np = zeros.numpy()
        self.assertEqual(zeros_np.shape, (4, 5))
        self.assertNotEmpty(zeros_np[0, 0])
        result = xp.absolute(zeros)
        self.assertEqualArray(np.absolute(zeros_np), result.numpy())

    def test_matmul(self):
        data = [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]
        for cls in (EagerTensor, EagerNumpyTensor):
            for dtype in (np.float32, np.float64):
                X = cls(np.array(data, dtype=dtype))
                # tensor @ tensor
                coef = cls(np.array([[1e-13, 8]], dtype=dtype).T)
                self.assertEqualArray(
                    np.array([[1e-13, 8]], dtype=dtype), coef.numpy().T
                )
                expected = X.numpy() @ coef.numpy()
                got = X @ coef
                try:
                    self.assertEqualArray(expected, got.numpy())
                except AssertionError as e:
                    raise AssertionError(
                        f"Discrepancies (1) with cls={cls.__name__}, dtype={dtype}"
                    ) from e

                # tensor @ raw numpy array
                raw_coef = np.array([[1e-13, 8]], dtype=dtype).T
                expected = X.numpy() @ raw_coef
                got = X @ raw_coef
                try:
                    self.assertEqualArray(expected, got.numpy())
                except AssertionError as e:
                    raise AssertionError(
                        f"Discrepancies (2) with cls={cls.__name__}, dtype={dtype}"
                    ) from e


if __name__ == "__main__":
    unittest.main(verbosity=2)
58 |
--------------------------------------------------------------------------------
/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx
--------------------------------------------------------------------------------
/_unittests/ut_graph_api/test_graph_builder_optim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | import onnx
4 | from onnx.inliner import inline_local_functions
5 | from onnx_array_api.ext_test_case import ExtTestCase
6 | from onnx_array_api.graph_api.graph_builder import GraphBuilder
7 |
8 |
class TestGraphBuilderOptim(ExtTestCase):
    """Optimizes every ONNX file found in the local ``data`` folder and
    checks onnxruntime can still load the optimized model."""

    def test_wcheck_afiles(self):
        import onnxruntime

        folder = os.path.join(os.path.dirname(__file__), "data")
        model_names = [name for name in os.listdir(folder) if name.endswith(".onnx")]
        for name in model_names:
            with self.subTest(f=name):
                path = os.path.join(folder, name)
                onx = onnx.load(path)
                # the original model must load in onnxruntime
                sess = onnxruntime.InferenceSession(
                    path, providers=["CPUExecutionProvider"]
                )
                assert sess
                # inlining local functions must keep the model loadable
                inlined = inline_local_functions(onx)
                sess = onnxruntime.InferenceSession(
                    inlined.SerializeToString(), providers=["CPUExecutionProvider"]
                )
                assert sess
                # optimization must preserve a valid topological order
                builder = GraphBuilder(inlined)
                builder.optimize(check_order=True)
                builder.check_order()
                optimized = builder.to_onnx()
                sess2 = onnxruntime.InferenceSession(
                    optimized.SerializeToString(), providers=["CPUExecutionProvider"]
                )
                assert sess2


if __name__ == "__main__":
    unittest.main(verbosity=2)
39 |
--------------------------------------------------------------------------------
/_unittests/ut_npx/test_sklearn_array_api.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from packaging.version import Version
4 | from onnx.defs import onnx_opset_version
5 | from sklearn import config_context, __version__ as sklearn_version
6 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
7 | from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
8 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9 |
10 |
11 | DEFAULT_OPSET = onnx_opset_version()
12 |
13 |
class TestSklearnArrayAPI(ExtTestCase):
    """Checks scikit-learn estimators can predict on
    :class:`EagerNumpyTensor` inputs through the array API dispatch."""

    def _check_linear_discriminant(self, dtype):
        """Fits a LinearDiscriminantAnalysis on numpy arrays and checks the
        prediction on an :class:`EagerNumpyTensor` matches.

        :param dtype: numpy dtype of the input features
        """
        X = np.array(
            [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=dtype
        )
        y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64)
        ana = LinearDiscriminantAnalysis()
        ana.fit(X, y)
        expected = ana.predict(X)

        new_x = EagerNumpyTensor(X)
        self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
        with config_context(array_api_dispatch=True):
            # It fails if scikit-learn <= 1.2.2 because the ArrayAPI
            # is not strictly applied.
            got = ana.predict(new_x)
            self.assertEqualArray(expected, got.numpy())

    @unittest.skipIf(
        Version(sklearn_version) <= Version("1.2.2"),
        reason="reshape ArrayAPI not followed",
    )
    @ignore_warnings(DeprecationWarning)
    @unittest.skip("not maintained")
    def test_sklearn_array_api_linear_discriminant(self):
        self._check_linear_discriminant(np.float64)

    @unittest.skipIf(
        Version(sklearn_version) <= Version("1.2.2"),
        reason="reshape ArrayAPI not followed",
    )
    @ignore_warnings(DeprecationWarning)
    @unittest.skip("not maintained")
    def test_sklearn_array_api_linear_discriminant_float32(self):
        self._check_linear_discriminant(np.float32)


if __name__ == "__main__":
    # import logging

    # logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
67 |
--------------------------------------------------------------------------------
/_unittests/ut_ort/data/prof_base.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_ort/data/prof_base.xlsx
--------------------------------------------------------------------------------
/_unittests/ut_ort/data/prof_opti.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_ort/data/prof_opti.xlsx
--------------------------------------------------------------------------------
/_unittests/ut_ort/test_ort_optimizer.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | from onnx.defs import onnx_opset_version
5 | from onnx_array_api.npx import absolute, jit_onnx
6 | from onnx_array_api.ext_test_case import ExtTestCase
7 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model
8 |
9 |
10 | DEFAULT_OPSET = onnx_opset_version()
11 |
12 |
class TestOrtOptimizer(ExtTestCase):
    """Checks :func:`ort_optimized_model` on a small jitted loss graph."""

    def test_ort_optimizers(self):
        def l1_loss(x, y):
            return absolute(x - y).sum()

        def l2_loss(x, y):
            return ((x - y) ** 2).sum()

        def myloss(x, y):
            return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

        jitted = jit_onnx(myloss)
        x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
        y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
        # the first call triggers the conversion to ONNX
        jitted(x, y)
        model = jitted.get_onnx()
        # an invalid optimization level must be rejected
        self.assertRaise(lambda: ort_optimized_model(model, "NO"), ValueError)
        optimized = ort_optimized_model(model)
        text = str(optimized)
        self.assertIn('op_type: "Squeeze"', text)
        self.assertIn("initializer {", text)


if __name__ == "__main__":
    unittest.main(verbosity=2)
37 |
--------------------------------------------------------------------------------
/_unittests/ut_ort/test_ort_profile.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import numpy as np
4 | from pandas import DataFrame, read_excel
5 | from onnx_array_api.npx import absolute, jit_onnx
6 | from onnx_array_api.ext_test_case import ExtTestCase
7 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model
8 | from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
9 | from onnxruntime.capi._pybind_state import (
10 | OrtValue as C_OrtValue,
11 | OrtDevice as C_OrtDevice,
12 | )
13 |
14 |
class TestOrtProfile(ExtTestCase):
    """Tests profiling helpers :func:`ort_profile` and
    :func:`merge_ort_profile`."""

    @staticmethod
    def _build_loss_model():
        """Jits a small composite loss function and returns
        ``(model, feeds)`` where *feeds* maps input names to numpy arrays.

        Extracted because three tests previously duplicated this
        construction verbatim.
        """

        def l1_loss(x, y):
            return absolute(x - y).sum()

        def l2_loss(x, y):
            return ((x - y) ** 2).sum()

        def myloss(x, y):
            return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

        jitted_myloss = jit_onnx(myloss)
        x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
        y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
        # the first call triggers the conversion to ONNX
        jitted_myloss(x, y)
        return jitted_myloss.get_onnx(), {"x0": x, "x1": y}

    def test_ort_profile(self):
        onx, feeds = self._build_loss_model()
        self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
        optimized = ort_optimized_model(onx)
        prof = ort_profile(optimized, feeds)
        self.assertIsInstance(prof, DataFrame)
        prof = ort_profile(optimized, feeds, as_df=False)
        self.assertIsInstance(prof, list)

    def test_ort_profile_first_it_out(self):
        onx, feeds = self._build_loss_model()
        self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
        optimized = ort_optimized_model(onx)
        prof = ort_profile(optimized, feeds)
        events = {
            "kernel_time",
            "SequentialExecutor::Execute",
            "model_run",
            "model_loading_array",
            "session_initialization",
        }
        self.assertEqual(set(prof["event_name"]), events)
        # aggregated output must be smaller but keep the same event names
        agg = ort_profile(optimized, feeds, first_it_out=True, agg=True)
        self.assertIsInstance(agg, DataFrame)
        self.assertLess(agg.shape[0], prof.shape[0])
        self.assertEqual(set(agg.reset_index(drop=False)["event_name"]), events)
        agg = ort_profile(
            optimized, feeds, first_it_out=True, agg=True, agg_op_name=False
        )
        self.assertIsInstance(agg, DataFrame)
        self.assertLess(agg.shape[0], prof.shape[0])
        self.assertEqual(set(agg.reset_index(drop=False)["event_name"]), events)

    def test_ort_profile_ort_value(self):
        def to_ort_value(m):
            # copies a numpy array into an OrtValue stored on CPU
            device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
            return C_OrtValue.ortvalue_from_numpy(m, device)

        onx, np_feeds = self._build_loss_model()
        feeds = {k: to_ort_value(v) for k, v in np_feeds.items()}

        self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
        optimized = ort_optimized_model(onx)
        prof = ort_profile(optimized, feeds)
        self.assertIsInstance(prof, DataFrame)
        prof = ort_profile(optimized, feeds, as_df=False)
        self.assertIsInstance(prof, list)

    def test_merge_ort_profile(self):
        data = os.path.join(os.path.dirname(__file__), "data")
        df1 = read_excel(os.path.join(data, "prof_base.xlsx"))
        df2 = read_excel(os.path.join(data, "prof_opti.xlsx"))
        merged, gr = merge_ort_profile(df1, df2)
        self.assertEqual(merged.shape, (23, 9))
        self.assertEqual(
            list(merged.columns),
            [
                "args_op_name",
                "args_output_type_shape",
                "args_input_type_shape",
                "args_provider",
                "idx",
                "durbase",
                "countbase",
                "duropti",
                "countopti",
            ],
        )
        self.assertEqual(gr.shape, (19, 4))
        self.assertEqual(
            list(gr.columns), ["durbase", "duropti", "countbase", "countopti"]
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
135 |
--------------------------------------------------------------------------------
/_unittests/ut_ort/test_sklearn_array_api_ort.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from packaging.version import Version
4 | from onnx.defs import onnx_opset_version
5 | from sklearn import config_context, __version__ as sklearn_version
6 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
7 | from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
8 | from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor
9 |
10 |
11 | DEFAULT_OPSET = onnx_opset_version()
12 |
13 |
class TestSklearnArrayAPIOrt(ExtTestCase):
    """Checks scikit-learn estimators can predict on
    :class:`EagerOrtTensor` inputs through the array API dispatch."""

    def _check_linear_discriminant(self, dtype):
        """Fits a LinearDiscriminantAnalysis on numpy arrays and checks the
        prediction on an :class:`EagerOrtTensor` matches.

        :param dtype: numpy dtype of the input features
        """
        X = np.array(
            [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=dtype
        )
        y = np.array([1, 1, 1, 2, 2, 2], dtype=np.int64)
        ana = LinearDiscriminantAnalysis()
        ana.fit(X, y)
        expected = ana.predict(X)

        new_x = EagerOrtTensor(OrtTensor.from_array(X))
        self.assertEqual(new_x.device_name, "Cpu")
        self.assertStartsWith(
            "EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x)
        )
        with config_context(array_api_dispatch=True):
            got = ana.predict(new_x)
            self.assertEqualArray(expected, got.numpy())

    @unittest.skipIf(
        Version(sklearn_version) <= Version("1.2.2"),
        reason="reshape ArrayAPI not followed",
    )
    @skipif_ci_windows("Unstable on Windows.")
    @unittest.skip("discontinued")
    def test_sklearn_array_api_linear_discriminant_ort(self):
        self._check_linear_discriminant(np.float64)

    @unittest.skipIf(
        Version(sklearn_version) <= Version("1.2.2"),
        reason="reshape ArrayAPI not followed",
    )
    @skipif_ci_windows("Unstable on Windows.")
    @unittest.skip("discontinued")
    def test_sklearn_array_api_linear_discriminant_ort_float32(self):
        self._check_linear_discriminant(np.float32)


if __name__ == "__main__":
    # import logging

    # logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
69 |
--------------------------------------------------------------------------------
/_unittests/ut_plotting/data/bug_Hardmax.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/bug_Hardmax.onnx
--------------------------------------------------------------------------------
/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx
--------------------------------------------------------------------------------
/_unittests/ut_plotting/data/tree_torch.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_plotting/data/tree_torch.onnx
--------------------------------------------------------------------------------
/_unittests/ut_plotting/test_graphviz.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | import onnx.parser
4 | from onnx_array_api.ext_test_case import (
5 | ExtTestCase,
6 | skipif_ci_windows,
7 | skipif_ci_apple,
8 | )
9 | from onnx_array_api.plotting.dot_plot import to_dot
10 | from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot
11 |
12 |
class TestGraphviz(ExtTestCase):
    """Tests the graphviz helpers: dot conversion, image rendering and
    matplotlib plotting. All rendering tests are skipped on CI platforms
    where graphviz is not installed."""

    @classmethod
    def _get_graph(cls):
        # Builds a tiny model used by every test below.
        # NOTE(review): the textual model looks truncated (the Constant node
        # has no attribute and there is no opset header) — possibly angle
        # brackets were lost in a text conversion; confirm against the
        # original repository before relying on the exact string.
        return onnx.parser.parse_model(
            """

agraph (float[N] x) => (float[N] z) {
two = Constant ()
four = Add(two, two)
z = Mul(x, x)
}"""
        )

    @skipif_ci_windows("graphviz not installed")
    @skipif_ci_apple("graphviz not installed")
    def test_draw_graph_graphviz(self):
        # renders a dot string produced by to_dot into a PNG file
        fout = "test_draw_graph_graphviz.png"
        dot = to_dot(self._get_graph())
        draw_graph_graphviz(dot, image=fout)
        self.assertExists(os.path.exists(fout))

    @skipif_ci_windows("graphviz not installed")
    @skipif_ci_apple("graphviz not installed")
    def test_draw_graph_graphviz_proto(self):
        # draw_graph_graphviz also accepts a ModelProto directly
        fout = "test_draw_graph_graphviz_proto.png"
        dot = self._get_graph()
        draw_graph_graphviz(dot, image=fout)
        self.assertExists(os.path.exists(fout))

    @skipif_ci_windows("graphviz not installed")
    @skipif_ci_apple("graphviz not installed")
    def test_plot_dot(self):
        # plot_dot returns a matplotlib axis that can be saved to a figure
        dot = to_dot(self._get_graph())
        ax = plot_dot(dot)
        ax.get_figure().savefig("test_plot_dot.png")


if __name__ == "__main__":
    unittest.main(verbosity=2)
52 |
--------------------------------------------------------------------------------
/_unittests/ut_plotting/test_stat_plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | import pandas
4 | import matplotlib.pyplot as plt
5 | from onnx_array_api.ext_test_case import ExtTestCase, matplotlib_test
6 | from onnx_array_api.plotting.stat_plot import plot_ort_profile
7 |
8 |
class TestStatPlot(ExtTestCase):
    """Checks :func:`plot_ort_profile` draws on the provided axes."""

    @matplotlib_test()
    def test_plot_ort_profile(self):
        csv_path = os.path.join(os.path.dirname(__file__), "data", "prof.csv")
        profile = pandas.read_csv(csv_path)
        _, axes = plt.subplots(2, 1)
        plot_ort_profile(profile, ax0=axes[0], ax1=axes[1])


if __name__ == "__main__":
    unittest.main()
20 |
--------------------------------------------------------------------------------
/_unittests/ut_reference/test_array_tensor.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from onnx import TensorProto
4 | from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
5 | from onnx_array_api.ext_test_case import ExtTestCase
6 | from onnx_array_api.reference import (
7 | to_array_extended,
8 | from_array_extended,
9 | ExtendedReferenceEvaluator,
10 | )
11 |
12 |
class TestArrayTensor(ExtTestCase):
    """Round-trip tests for the extended tensor conversion helpers."""

    def test_from_array(self):
        # numpy array -> TensorProto -> numpy array must round-trip exactly
        for dtype in (np.float32, np.float16, np.uint16, np.uint8):
            with self.subTest(dtype=dtype):
                data = np.array([0, 1, 2], dtype=dtype)
                proto = from_array_extended(data, "a")
                recovered = to_array_extended(proto)
                self.assertEqualArray(data, recovered)
                proto2 = from_array_extended(recovered, "a")
                self.assertEqual(proto.SerializeToString(), proto2.SerializeToString())

    def test_from_array_f8(self):
        def make_model_f8(fr, to):
            # single Cast node converting from element type fr to to
            return make_model(
                make_graph(
                    [make_node("Cast", ["X"], ["Y"], to=to)],
                    "cast",
                    [make_tensor_value_info("X", fr, None)],
                    [make_tensor_value_info("Y", to, None)],
                )
            )

        small_float_types = (
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
            TensorProto.BFLOAT16,
        )
        for dtype in (np.float32, np.float16, np.uint16, np.uint8):
            with self.subTest(dtype=dtype):
                data = np.array([0, 1, 2], dtype=dtype)
                proto = from_array_extended(data, "a")
                for to in small_float_types:
                    with self.subTest(fr=proto.data_type, to=to):
                        # cast through the evaluator, then make sure the
                        # result serializes back with the expected type
                        model = make_model_f8(proto.data_type, to)
                        ref = ExtendedReferenceEvaluator(model)
                        got = ref.run(None, {"X": data})[0]
                        back = from_array_extended(got, "a")
                        self.assertEqual(to, back.data_type)


if __name__ == "__main__":
    unittest.main(verbosity=2)
57 |
--------------------------------------------------------------------------------
/_unittests/ut_reference/test_reference_ops.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from onnx import TensorProto
4 | from onnx.helper import (
5 | make_graph,
6 | make_model,
7 | make_node,
8 | make_tensor_value_info,
9 | make_opsetid,
10 | )
11 | from onnx_array_api.ext_test_case import ExtTestCase
12 | from onnx_array_api.reference import ExtendedReferenceEvaluator
13 |
14 |
class TestReferenceOps(ExtTestCase):
    """Tests :class:`ExtendedReferenceEvaluator` on operators it adds on
    top of onnx's reference evaluator, mostly from the ``com.microsoft``
    domain."""

    def test_fused_matmul(self):
        # FusedMatMul without transpose attributes behaves like MatMul.
        model = make_model(
            make_graph(
                [make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
                "name",
                [
                    make_tensor_value_info("X", TensorProto.FLOAT, None),
                    make_tensor_value_info("Y", TensorProto.FLOAT, None),
                ],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        )
        ref = ExtendedReferenceEvaluator(model)
        a = np.arange(4).reshape(-1, 2)
        got = ref.run(None, {"X": a, "Y": a})
        self.assertEqualArray(a @ a, got[0])

    def test_fused_matmul11(self):
        # transA=1 and transB=1: both inputs are transposed before the
        # product, hence the expected value a.T @ a.T below.
        model = make_model(
            make_graph(
                [
                    make_node(
                        "FusedMatMul",
                        ["X", "Y"],
                        ["Z"],
                        transA=1,
                        transB=1,
                        domain="com.microsoft",
                    )
                ],
                "name",
                [
                    make_tensor_value_info("X", TensorProto.FLOAT, None),
                    make_tensor_value_info("Y", TensorProto.FLOAT, None),
                ],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        )
        ref = ExtendedReferenceEvaluator(model)
        a = np.arange(4).reshape(-1, 2)
        got = ref.run(None, {"X": a, "Y": a})
        self.assertEqualArray(a.T @ a.T, got[0])

    def test_memcpy(self):
        # Memcpy nodes are identity-like on CPU; the evaluator must accept
        # them and return the input unchanged.
        # NOTE(review): both nodes declare the same output name "Z" and read
        # the same input "X" — looks intentional to exercise the evaluator's
        # tolerance, but confirm against the implementation.
        model = make_model(
            make_graph(
                [
                    make_node("MemcpyToHost", ["X"], ["Z"]),
                    make_node("MemcpyFromHost", ["X"], ["Z"]),
                ],
                "name",
                [make_tensor_value_info("X", TensorProto.FLOAT, None)],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
            ir_version=9,
        )
        a = np.arange(4).reshape(-1, 2).astype(np.float32)
        ref = ExtendedReferenceEvaluator(model)
        got = ref.run(None, {"X": a})
        self.assertEqualArray(a, got[0])

    def test_quick_gelu(self):
        # Compares the evaluator's QuickGelu with onnxruntime's output for
        # two values of alpha (0 disables the sigmoid scaling).
        from onnxruntime import InferenceSession

        for alpha in [0.0, 2.0]:
            model = make_model(
                make_graph(
                    [
                        make_node(
                            "QuickGelu",
                            ["X"],
                            ["Z"],
                            domain="com.microsoft",
                            alpha=alpha,
                        )
                    ],
                    "name",
                    [make_tensor_value_info("X", TensorProto.FLOAT, None)],
                    [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
                ),
                opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
                ir_version=9,
            )
            sess = InferenceSession(
                model.SerializeToString(), providers=["CPUExecutionProvider"]
            )
            a = np.arange(4).reshape(-1, 2).astype(np.float32)
            expected = sess.run(None, {"X": a})
            ref = ExtendedReferenceEvaluator(model)
            got = ref.run(None, {"X": a})
            self.assertEqualArray(expected[0], got[0])

    def test_scatter_elements(self):
        # ScatterElements with reduction="add" on axis 3: a single update of
        # 1 added at index 0 of the last axis.
        model = make_model(
            make_graph(
                [
                    make_node(
                        "ScatterElements",
                        ["data", "indices", "updates"],
                        ["Z"],
                        axis=3,
                        reduction="add",
                    )
                ],
                "name",
                [
                    make_tensor_value_info("data", TensorProto.FLOAT, None),
                    make_tensor_value_info("indices", TensorProto.INT64, None),
                    make_tensor_value_info("updates", TensorProto.FLOAT, None),
                ],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18)],
        )
        data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
        indices = np.array([[[[0]]]], dtype=np.int64)
        updates = np.array([[[[1]]]], dtype=np.float32)
        # only the very first element receives the update
        y = np.array(
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
        ).reshape((2, 2, 2, 2))
        ref = ExtendedReferenceEvaluator(model)
        got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
        self.assertEqualArray(y, got[0])


if __name__ == "__main__":
    unittest.main(verbosity=2)
147 |
--------------------------------------------------------------------------------
/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx
--------------------------------------------------------------------------------
/_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx
--------------------------------------------------------------------------------
/_unittests/ut_validation/data/small.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/_unittests/ut_validation/data/small.onnx
--------------------------------------------------------------------------------
/_unittests/ut_validation/test_diff.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from onnx import load
3 | from onnx.checker import check_model
4 | from onnx_array_api.ext_test_case import ExtTestCase
5 | from onnx_array_api.ort.ort_optimizers import ort_optimized_model
6 | from onnx_array_api.validation.diff import text_diff, html_diff
7 |
8 |
class TestDiff(ExtTestCase):
    def test_diff_optimized(self):
        """Optimizing a model must yield a non-trivial text diff and an HTML diff."""
        data = self.relative_path(__file__, "data", "small.onnx")
        with open(data, "rb") as f:
            model = load(f)
        optimized = ort_optimized_model(model)
        check_model(optimized)
        diff = text_diff(model, optimized)
        self.assertIn("^^^^^^^^^^^^^^^^", diff)
        ht = html_diff(model, optimized)
        # BUG FIX: the previous check `assertIn("", ht)` was vacuous because
        # the empty string is a substring of every string. Assert that the
        # HTML diff is actually produced instead.
        self.assertNotEmpty(ht)
20 |
21 |
22 | if __name__ == "__main__":
23 | unittest.main(verbosity=2)
24 |
--------------------------------------------------------------------------------
/_unittests/ut_validation/test_docs.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from onnx.reference import ReferenceEvaluator
4 | from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
5 | from onnx_array_api.validation.docs import make_euclidean, make_euclidean_skl2onnx
6 |
7 |
class TestDocs(ExtTestCase):
    # Each test builds the same squared-euclidean-distance model through a
    # different API and checks it against numpy: ((X - Y) ** 2).sum(keepdims=1).
    def test_make_euclidean(self):
        """Model built with onnx.helper, run with onnx's ReferenceEvaluator."""
        model = make_euclidean()

        ref = ReferenceEvaluator(model)
        X = np.random.rand(3, 4).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32)
        expected = ((X - Y) ** 2).sum(keepdims=1)
        got = ref.run(None, {"X": X, "Y": Y})[0]
        self.assertEqualArray(expected, got)

    def test_make_euclidean_skl2onnx(self):
        """Same model built through skl2onnx."""
        model = make_euclidean_skl2onnx()

        ref = ReferenceEvaluator(model)
        X = np.random.rand(3, 4).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32)
        expected = ((X - Y) ** 2).sum(keepdims=1)
        got = ref.run(None, {"X": X, "Y": Y})[0]
        self.assertEqualArray(expected, got)

    @skipif_ci_windows("Unstable on Windows.")
    def test_make_euclidean_np(self):
        """Same model produced by jitting a numpy-style function (npx)."""
        from onnx_array_api.npx import jit_onnx

        def l2_loss(x, y):
            return ((x - y) ** 2).sum(keepdims=1)

        jitted_myloss = jit_onnx(l2_loss)
        dummy1 = np.array([0], dtype=np.float32)
        dummy2 = np.array([1], dtype=np.float32)
        # unstable on windows?
        jitted_myloss(dummy1, dummy2)
        model = jitted_myloss.get_onnx()

        ref = ReferenceEvaluator(model)
        X = np.random.rand(3, 4).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32)
        expected = ((X - Y) ** 2).sum(keepdims=1)
        # jit_onnx names the traced inputs x0, x1 in positional order.
        got = ref.run(None, {"x0": X, "x1": Y})[0]
        self.assertEqualArray(expected, got)

    def test_make_euclidean_light(self):
        """Same model built with the light (fluent) API."""
        from onnx_array_api.light_api import start

        model = (
            start()
            .vin("X")
            .vin("Y")
            .bring("X", "Y")
            .Sub()
            .rename("dxy")
            .cst(np.array([2], dtype=np.int64), "two")
            .bring("dxy", "two")
            .Pow()
            .ReduceSum()
            .rename("Z")
            .vout()
            .to_onnx()
        )

        ref = ReferenceEvaluator(model)
        X = np.random.rand(3, 4).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32)
        expected = ((X - Y) ** 2).sum(keepdims=1)
        got = ref.run(None, {"X": X, "Y": Y})[0]
        self.assertEqualArray(expected, got)

    def test_ort_make_euclidean(self):
        """Runs the helper-built model with onnxruntime; float32 rounding
        differences require an absolute tolerance."""
        from onnxruntime import InferenceSession

        model = make_euclidean(opset=18)

        ref = InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        X = np.random.rand(3, 4).astype(np.float32)
        Y = np.random.rand(3, 4).astype(np.float32)
        expected = ((X - Y) ** 2).sum(keepdims=1)
        got = ref.run(None, {"X": X, "Y": Y})[0]
        self.assertEqualArray(expected, got, atol=1e-6)
89 |
90 |
91 | if __name__ == "__main__":
92 | unittest.main(verbosity=2)
93 |
--------------------------------------------------------------------------------
/_unittests/ut_validation/test_tools.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from onnx import load
3 | from onnx.checker import check_model
4 | from onnx_array_api.ext_test_case import ExtTestCase
5 | from onnx_array_api.validation.tools import randomize_proto
6 |
7 |
class TestTools(ExtTestCase):
    """Tests for :mod:`onnx_array_api.validation.tools`."""

    def test_randomize_proto(self):
        """Randomizing a valid model keeps its serialized size and validity."""
        model_path = self.relative_path(__file__, "data", "small.onnx")
        with open(model_path, "rb") as fh:
            proto = load(fh)
        check_model(proto)
        randomized = randomize_proto(proto)
        # Same number of serialized bytes: only the values changed.
        self.assertEqual(
            len(proto.SerializeToString()), len(randomized.SerializeToString())
        )
        check_model(randomized)
17 |
18 |
19 | if __name__ == "__main__":
20 | unittest.main(verbosity=2)
21 |
--------------------------------------------------------------------------------
/_unittests/ut_xrun_doc/test_command_lines1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import unittest
4 | from contextlib import redirect_stdout
5 | from io import StringIO
6 | from onnx import TensorProto
7 | from onnx.helper import (
8 | make_graph,
9 | make_model,
10 | make_node,
11 | make_opsetid,
12 | make_tensor_value_info,
13 | )
14 | from onnx_array_api.ext_test_case import ExtTestCase
15 | from onnx_array_api._command_lines_parser import (
16 | get_main_parser,
17 | get_parser_compare,
18 | get_parser_translate,
19 | get_parser_replace,
20 | main,
21 | )
22 |
23 |
class TestCommandLines1(ExtTestCase):
    def test_main_parser(self):
        """The top-level parser help lists the available sub-commands."""
        st = StringIO()
        with redirect_stdout(st):
            get_main_parser().print_help()
        text = st.getvalue()
        self.assertIn("translate", text)

    def test_parser_translate(self):
        """Help of the `translate` sub-command mentions its model argument."""
        st = StringIO()
        with redirect_stdout(st):
            get_parser_translate().print_help()
        text = st.getvalue()
        self.assertIn("model", text)

    def test_parser_replace(self):
        """Help of the `replace` sub-command mentions its model argument."""
        st = StringIO()
        with redirect_stdout(st):
            get_parser_replace().print_help()
        text = st.getvalue()
        self.assertIn("model", text)

    def test_command_translate(self):
        """End-to-end `translate` on a small Add+Cos model, with both the
        default API and the light API (``-a light``)."""
        X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
        Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
        Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None])
        graph = make_graph(
            [
                make_node("Add", ["X", "Y"], ["res"]),
                make_node("Cos", ["res"], ["Z"]),
            ],
            "g",
            [X, Y],
            [Z],
        )
        onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])

        with tempfile.TemporaryDirectory() as root:
            model_file = os.path.join(root, "model.onnx")
            with open(model_file, "wb") as f:
                f.write(onnx_model.SerializeToString())

            # Default API: the generated code rebuilds the model with onnx.helper.
            args = ["translate", "-m", model_file]
            st = StringIO()
            with redirect_stdout(st):
                main(args)

            code = st.getvalue()
            self.assertIn("model = make_model(", code)

            # Light API: the generated code uses the fluent start() builder.
            args = ["translate", "-m", model_file, "-a", "light"]
            st = StringIO()
            with redirect_stdout(st):
                main(args)

            code = st.getvalue()
            self.assertIn("start(opset=", code)

    def test_parser_compare(self):
        """Help of the `compare` sub-command mentions its first model argument."""
        st = StringIO()
        with redirect_stdout(st):
            get_parser_compare().print_help()
        text = st.getvalue()
        self.assertIn("model1", text)

    def test_command_compare(self):
        """End-to-end `compare` of a model against itself; the verbose output
        must contain the alignment report."""
        X = make_tensor_value_info("X", TensorProto.FLOAT, [5, 6])
        Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
        Z = make_tensor_value_info("Z", TensorProto.FLOAT, [5, 6])
        graph = make_graph(
            [
                make_node("Add", ["X", "Y"], ["res"]),
                make_node("Cos", ["res"], ["Z"]),
            ],
            "g",
            [X, Y],
            [Z],
        )
        onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])

        with tempfile.TemporaryDirectory() as root:
            model_file = os.path.join(root, "model.onnx")
            with open(model_file, "wb") as f:
                f.write(onnx_model.SerializeToString())

            args = ["compare", "-m1", model_file, "-m2", model_file, "-v", "1"]
            st = StringIO()
            with redirect_stdout(st):
                main(args)

            code = st.getvalue()
            self.assertIn("[compare_onnx_execution]", code)
            self.assertIn("ADFF", code)
117 |
118 |
119 | if __name__ == "__main__":
120 | unittest.main(verbosity=2)
121 |
--------------------------------------------------------------------------------
/_unittests/ut_xrun_doc/test_documentation_examples.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import sys
4 | import importlib
5 | import subprocess
6 | import time
7 | from onnx_array_api import __file__ as onnx_array_api_file
8 | from onnx_array_api.ext_test_case import ExtTestCase, is_windows
9 |
10 | VERBOSE = 0
11 | ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_array_api_file, "..", "..")))
12 |
13 |
def import_source(module_file_path, module_name):
    """
    Imports a python module from its file path.

    :param module_file_path: path of the source file
    :param module_name: name to give the imported module
    :return: the imported module
    :raises FileNotFoundError: if the file does not exist or if no import
        spec can be created for it
    """
    if not os.path.exists(module_file_path):
        raise FileNotFoundError(module_file_path)
    module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
    if module_spec is None:
        raise FileNotFoundError(
            "Unable to find '{}' in '{}'.".format(module_name, module_file_path)
        )
    module = importlib.util.module_from_spec(module_spec)
    # BUG FIX: `exec_module` always returns None, so this function used to
    # return None and the caller's `assert mod is not None` could never hold.
    # Execute the module, then return it.
    module_spec.loader.exec_module(module)
    return module
24 |
25 |
class TestDocumentationExamples(ExtTestCase):
    def run_test(self, fold: str, name: str, verbose=0) -> int:
        """
        Runs one documentation example.

        First tries to import it as a module; when that raises
        FileNotFoundError, falls back to running it in a subprocess and
        inspects stderr for failures.

        :param fold: folder containing the example
        :param name: file name of the example
        :param verbose: prints timing/skip information when non-zero
        :return: 1 on success, 0 when skipped (missing `dot` executable)
        """
        # Make sure the package root is importable from the subprocess.
        ppath = os.environ.get("PYTHONPATH", "")
        if not ppath:
            os.environ["PYTHONPATH"] = ROOT
        elif ROOT not in ppath:
            sep = ";" if is_windows() else ":"
            os.environ["PYTHONPATH"] = ppath + sep + ROOT
        perf = time.perf_counter()
        # NOTE(review): `fold` (a directory) is passed to import_source
        # instead of the full file path, so this import probably always
        # falls back to the subprocess route — confirm.
        try:
            mod = import_source(fold, os.path.splitext(name)[0])
            assert mod is not None
        except FileNotFoundError:
            # try another way
            cmds = [sys.executable, "-u", os.path.join(fold, name)]
            p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            res = p.communicate()
            out, err = res
            st = err.decode("ascii", errors="ignore")
            if st and "Traceback" in st:
                if '"dot" not found in path.' in st:
                    # dot not installed, this part
                    # is tested in onnx framework
                    if verbose:
                        print(f"failed: {name!r} due to missing dot.")
                    return 0
                raise AssertionError(  # noqa: B904
                    "Example '{}' (cmd: {} - exec_prefix='{}') "
                    "failed due to\n{}"
                    "".format(name, cmds, sys.exec_prefix, st)
                )
        dt = time.perf_counter() - perf
        if verbose:
            print(f"{dt:.3f}: run {name!r}")
        return 1

    @classmethod
    def add_test_methods(cls):
        """Creates one test method per `plot_*.py` example in _doc/examples."""
        this = os.path.abspath(os.path.dirname(__file__))
        fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
        found = os.listdir(fold)
        for name in found:
            if not name.startswith("plot_") or not name.endswith(".py"):
                continue
            short_name = os.path.split(os.path.splitext(name)[0])[-1]

            # Bind `name` as a default argument so each generated test keeps
            # its own file name despite the loop variable changing.
            def _test_(self, name=name):
                res = self.run_test(fold, name, verbose=VERBOSE)
                self.assertTrue(res)

            setattr(cls, f"test_{short_name}", _test_)
77 |
78 |
79 | TestDocumentationExamples.add_test_methods()
80 |
81 | if __name__ == "__main__":
82 | unittest.main(verbosity=2)
83 |
--------------------------------------------------------------------------------
/_unittests/ut_xrun_doc/test_profiling.py:
--------------------------------------------------------------------------------
1 | """
2 | @brief test tree node (time=5s)
3 | """
4 |
5 | import os
6 | import sys
7 | import time
8 | import unittest
9 | from io import StringIO
10 | from pstats import SortKey
11 |
12 | import pandas
13 |
14 | from onnx_array_api import __file__ as rootfile
15 | from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
16 | from onnx_array_api.profiling import ProfileNode, profile, profile2df, profile2graph
17 |
18 |
class TestProfiling(ExtTestCase):
    def test_profile(self):
        """`profile` returns a pstats object and a text report, with and
        without a root folder stripped from file names."""

        def simple():
            df = pandas.DataFrame(
                [{"A": "x", "AA": "xx", "AAA": "xxx"}, {"AA": "xxxxxxx", "AAA": "xxx"}]
            )
            return df.to_csv(StringIO())

        rootrem = os.path.normpath(
            os.path.abspath(os.path.join(os.path.dirname(rootfile), ".."))
        )
        ps, res = profile(simple, rootrem=rootrem)
        res = res.replace("\\", "/")
        self.assertIn("function calls", res)
        self.assertNotEmpty(ps)

        ps, res = profile(simple)
        res = res.replace("\\", "/")
        self.assertIn("function calls", res)
        self.assertNotEmpty(ps)

    @ignore_warnings(FutureWarning)
    def test_profile_df(self):
        """`profile(..., as_df=True)` and `profile2df` produce a DataFrame
        or a list of dicts depending on the flag."""

        def simple():
            def simple2():
                df = pandas.DataFrame(
                    [
                        {"A": "x", "AA": "xx", "AAA": "xxx"},
                        {"AA": "xxxxxxx", "AAA": "xxx"},
                    ]
                )
                return df.to_csv(StringIO())

            return simple2()

        rootrem = os.path.normpath(
            os.path.abspath(os.path.join(os.path.dirname(rootfile), ".."))
        )
        ps, df = profile(simple, rootrem=rootrem, as_df=True)
        self.assertIsInstance(df, pandas.DataFrame)
        self.assertEqual(df.loc[0, "namefct"].split("-")[-1], "simple")
        self.assertNotEmpty(ps)
        df = profile2df(ps, False)
        self.assertIsInstance(df, list)
        self.assertIsInstance(df[0], dict)
        df = profile2df(ps, True)
        self.assertIsInstance(df, pandas.DataFrame)

    def test_profile_df_verbose(self):
        """Verbose path of `profile2df` on a small call tree."""
        calls = [0]

        def f0(t):
            calls[0] += 1
            time.sleep(t)

        def f1(t):
            calls[0] += 1
            time.sleep(t)

        def f2():
            calls[0] += 1
            f1(0.1)
            f1(0.01)

        def f3():
            calls[0] += 1
            f0(0.2)
            f1(0.5)

        def f4():
            calls[0] += 1
            f2()
            f3()

        ps = profile(f4)[0]
        df = self.capture(lambda: profile2df(ps, verbose=True, fLOG=print))[0]
        dfi = df.set_index("fct")
        self.assertEqual(dfi.loc["f4", "ncalls1"], 1)
        self.assertEqual(dfi.loc["f4", "ncalls2"], 1)

    @unittest.skipIf(sys.version_info[:2] < (3, 7), reason="not supported")
    def test_profile_graph(self):
        """`profile2graph` and the text/json exports of ProfileNode."""
        calls = [0]

        def f0(t):
            calls[0] += 1
            time.sleep(t)

        def f1(t):
            calls[0] += 1
            time.sleep(t)

        def f2():
            calls[0] += 1
            f1(0.1)
            f1(0.01)

        def f3():
            calls[0] += 1
            f0(0.2)
            f1(0.5)

        def f4():
            calls[0] += 1
            f2()
            f3()

        ps = profile(f4)[0]
        profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
        root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
        self.assertEqual(len(nodes), 6)
        self.assertIsInstance(nodes, dict)
        self.assertIsInstance(root, ProfileNode)
        self.assertIn("(", str(root))
        dicts = root.as_dict()
        self.assertEqual(10, len(dicts))
        text = root.to_text()
        self.assertIn("1 1", text)
        self.assertIn(" f1", text)
        text = root.to_text(fct_width=20)
        self.assertIn("...", text)
        root.to_text(sort_key=SortKey.CUMULATIVE)
        root.to_text(sort_key=SortKey.TIME)
        # Sorting by name is not implemented.
        self.assertRaise(
            lambda: root.to_text(sort_key=SortKey.NAME), NotImplementedError
        )
        js = root.to_json(indent=2)
        self.assertIn('"details"', js)
        js = root.to_json(as_str=False)
        self.assertIsInstance(js, dict)

    def test_profile_graph_recursive2(self):
        """Mutual recursion (f0 <-> f1) must still produce a graph."""

        def f0(t):
            if t < 0.2:
                time.sleep(t)
            else:
                f1(t - 0.1)

        def f1(t):
            if t < 0.1:
                time.sleep(t)
            else:
                f0(t)

        def f4():
            f1(0.3)

        ps = profile(f4)[0]
        profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
        root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
        self.assertEqual(len(nodes), 4)
        text = root.to_text()
        self.assertIn(" f1", text)
        js = root.to_json(indent=2)
        self.assertIn('"details"', js)

    def test_profile_graph_recursive1(self):
        """Self recursion (f0 -> f0) must still produce a graph."""

        def f0(t):
            if t < 0.1:
                time.sleep(t)
            else:
                f0(t - 0.1)

        def f4():
            f0(0.15)

        ps = profile(f4)[0]
        profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
        root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
        self.assertEqual(len(nodes), 3)
        text = root.to_text()
        self.assertIn(" f0", text)
        js = root.to_json(indent=2)
        self.assertIn('"details"', js)
193 |
194 |
195 | if __name__ == "__main__":
196 | unittest.main()
197 |
--------------------------------------------------------------------------------
/_unittests/win_test_array_api.bat:
--------------------------------------------------------------------------------
@echo off
rem Runs the array-api-tests suite against the onnx_array_api numpy backend.
set ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
rem Quick smoke test first: test_arange alone.
python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py::test_arange || exit 1
rem Full creation-functions suite; known failures listed in onnx-numpy-skips.txt.
rem NOTE(review): the relative paths assume array-api-tests is checked out two
rem levels up and the script is launched from the repository root — confirm.
python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
5 |
--------------------------------------------------------------------------------
/onnx_array_api/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | APIs to create ONNX Graphs.
3 | """
4 |
5 | __version__ = "0.3.1"
6 | __author__ = "Xavier Dupré"
7 |
--------------------------------------------------------------------------------
/onnx_array_api/__main__.py:
--------------------------------------------------------------------------------
1 | from ._command_lines_parser import main
2 |
3 | if __name__ == "__main__":
4 | main()
5 |
--------------------------------------------------------------------------------
/onnx_array_api/_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Any
3 | from onnx import helper, TensorProto
4 |
5 |
def np_dtype_to_tensor_dtype(dtype: Any):
    """
    Improves :func:`onnx.helper.np_dtype_to_tensor_dtype` by also handling
    plain Python types (bool, str, int, float) and numpy scalar types the
    helper may not resolve.
    """
    try:
        return helper.np_dtype_to_tensor_dtype(dtype)
    except (KeyError, ValueError):
        # Candidates are compared with `==` in order, mirroring the
        # original elif chain.
        fallback = (
            ((np.float32,), TensorProto.FLOAT),
            ((np.float64,), TensorProto.DOUBLE),
            ((np.int64,), TensorProto.INT64),
            ((np.int32,), TensorProto.INT32),
            ((np.int16,), TensorProto.INT16),
            ((np.int8,), TensorProto.INT8),
            ((np.uint64,), TensorProto.UINT64),
            ((np.uint32,), TensorProto.UINT32),
            ((np.uint16,), TensorProto.UINT16),
            ((np.uint8,), TensorProto.UINT8),
            ((np.float16,), TensorProto.FLOAT16),
            ((bool, np.bool_), TensorProto.BOOL),
            ((str, np.str_), TensorProto.STRING),
        )
        for candidates, tensor_type in fallback:
            if any(dtype == c for c in candidates):
                return tensor_type
        # Builtin int/float are matched by identity, like the original code.
        if dtype is int:
            return TensorProto.INT64
        if dtype is float:
            return TensorProto.DOUBLE
        if dtype == np.complex64:
            return TensorProto.COMPLEX64
        if dtype == np.complex128:
            return TensorProto.COMPLEX128
        raise KeyError(f"Unable to guess type for dtype={dtype}.")  # noqa: B904
50 |
--------------------------------------------------------------------------------
/onnx_array_api/annotations.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2 | import numpy as np
3 | from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
4 | from onnx.helper import np_dtype_to_tensor_dtype
5 |
6 | NP_DTYPE = np.dtype
7 | ELEMENT_TYPE = Union[int, NP_DTYPE]
8 | SHAPE_TYPE = Tuple[int, ...]
9 | VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
10 | GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]
11 |
12 | AI_ONNX_ML = "ai.onnx.ml"
13 |
14 | ELEMENT_TYPE_NAME = {
15 | getattr(TensorProto, k): k
16 | for k in dir(TensorProto)
17 | if isinstance(getattr(TensorProto, k), int) and "_" not in k
18 | }
19 |
20 |
class SubDomain:
    """Base class for objects grouping operators of one onnx domain."""

    pass
23 |
24 |
def domain(domain: str, op_type: Optional[str] = None) -> Callable:
    """
    Registers one operator into a sub domain. It should be used as a
    decorator. One example:

    .. code-block:: python

        @domain("ai.onnx.ml")
        def Normalizer(self, norm: str = "MAX"):
            return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")

    :param domain: domain the operator belongs to
    :param op_type: operator name, defaults to the decorated function's name
    :return: the decorator
    """
    # A one-element list so the inner closure can overwrite the name.
    names = [op_type]

    def decorate(op_method: Callable) -> Callable:
        if names[0] is None:
            names[0] = op_method.__name__

        def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
            # Sub-domain objects keep the real builder in `self.parent`.
            return op_method(self.parent, *args, **kwargs)

        # BUG FIX: the attribute was previously misspelled `__qual__name__`,
        # which set a meaningless attribute and left the real qualified
        # name untouched.
        wrapper.__qualname__ = f"[{domain}]{names[0]}"
        wrapper.__name__ = f"[{domain}]{names[0]}"
        wrapper.__domain__ = domain
        return wrapper

    return decorate
51 |
52 |
# Maps numpy scalar types to the matching onnx TensorProto element type.
# Used by `elem_type_int` before falling back to np_dtype_to_tensor_dtype.
_type_numpy = {
    np.float32: TensorProto.FLOAT,
    np.float64: TensorProto.DOUBLE,
    np.float16: TensorProto.FLOAT16,
    np.int8: TensorProto.INT8,
    np.int16: TensorProto.INT16,
    np.int32: TensorProto.INT32,
    np.int64: TensorProto.INT64,
    np.uint8: TensorProto.UINT8,
    np.uint16: TensorProto.UINT16,
    np.uint32: TensorProto.UINT32,
    np.uint64: TensorProto.UINT64,
    np.bool_: TensorProto.BOOL,
    np.str_: TensorProto.STRING,
    np.complex64: TensorProto.COMPLEX64,
    np.complex128: TensorProto.COMPLEX128,
}
70 |
71 |
def elem_type_int(elem_type: ELEMENT_TYPE) -> int:
    """
    Converts an element type into an onnx element type (int).

    :param elem_type: integer or numpy type
    :return: int
    """
    if isinstance(elem_type, int):
        # Already an onnx TensorProto value.
        return elem_type
    known = _type_numpy.get(elem_type)
    if known is not None:
        return known
    return np_dtype_to_tensor_dtype(elem_type)
84 |
85 |
86 | def _pick_dim(d, empty_dim):
87 | if d.dim_value:
88 | return d.dim_value
89 | if d.dim_param:
90 | return d.dim_param
91 | return empty_dim
92 |
93 |
def make_shape(shape: TensorShapeProto, empty_dim: Optional[Any] = None) -> SHAPE_TYPE:
    """
    Extracts a shape from a tensor type.

    :param shape: a :class:`onnx.TensorShapeProto`
    :param empty_dim: value used for dimensions with neither a value
        nor a symbolic name
    :return: a tuple of ints/strings, or None when *shape* has no ``dim``
    """
    if hasattr(shape, "dim"):
        # FIX: `enumerate` produced an unused index; iterate directly.
        return tuple(_pick_dim(d, empty_dim=empty_dim) for d in shape.dim)
    return None
100 |
--------------------------------------------------------------------------------
/onnx_array_api/array_api/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Dict
2 | import warnings
3 | import numpy as np
4 | from onnx import TensorProto
5 | from .._helpers import np_dtype_to_tensor_dtype
6 | from ..npx.npx_types import DType
7 | from ..npx import npx_functions
8 |
9 |
# Default list of function names exposed through the Array API modules
# (onnx_numpy, onnx_ort) when `_finalize_array_api` receives None.
supported_functions = [
    "abs",
    "absolute",
    "all",
    "any",
    "arange",
    "asarray",
    "astype",
    "empty",
    "equal",
    "eye",
    "full",
    "full_like",
    "isdtype",
    "isfinite",
    "isinf",
    "isnan",
    "linspace",
    "ones",
    "ones_like",
    "reshape",
    "sum",
    "take",
    "zeros",
    "zeros_like",
]
36 |
37 |
def _finfo(dtype):
    """
    Similar to :class:`numpy.finfo`, but the returned object exposes the
    package's :class:`DType` as ``dtype`` and plain Python floats/complex
    instead of numpy scalars.
    """
    np_dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
    info = np.finfo(np_dt)
    attrs = {}
    for key, value in info.__dict__.items():
        if key.startswith("__"):
            continue
        if isinstance(value, (np.float16, np.float32, np.float64)):
            attrs[key] = float(value)
        elif isinstance(value, (np.complex64, np.complex128)):
            attrs[key] = complex(value)
        else:
            attrs[key] = value
    attrs["dtype"] = DType(np_dtype_to_tensor_dtype(np_dt))
    result = type("finfo", (info.__class__,), attrs)
    # These two are properties on numpy's class; expose plain floats instead.
    setattr(result, "smallest_normal", float(info.smallest_normal))  # noqa: B010
    setattr(result, "tiny", float(info.tiny))  # noqa: B010
    return result
59 |
60 |
def _iinfo(dtype):
    """
    Similar to :class:`numpy.iinfo`, but the returned object exposes the
    package's :class:`DType` as ``dtype`` and plain Python ints instead of
    numpy scalars.
    """
    dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
    res = np.iinfo(dt)
    d = {}
    for k, v in res.__dict__.items():
        if k.startswith("__"):
            continue
        if isinstance(
            v,
            (
                np.int16,
                np.int32,
                np.int64,
                np.uint16,
                np.uint32,
                np.uint64,
                np.int8,
                np.uint8,
            ),
        ):
            d[k] = int(v)
        else:
            d[k] = v
    d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
    nres = type("iinfo", (res.__class__,), d)
    # min/max are properties on numpy's class; expose plain ints instead.
    setattr(nres, "min", int(res.min))  # noqa: B010
    setattr(nres, "max", int(res.max))  # noqa: B010
    return nres
92 |
93 |
def array_api_wrap_function(f: Callable, TEagerTensor: type) -> Callable:
    """
    Converts an eager function taking EagerTensor into a function
    available through an Array API.

    :param f: function to wrap; receives `TEagerTensor` as first argument
    :param TEagerTensor: EagerTensor class used to wrap numpy arrays
    :return: new function
    """

    def wrap(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
        # Promote raw numpy arrays to EagerTensor, leave other values as is.
        converted = [
            TEagerTensor(a) if isinstance(a, np.ndarray) else a for a in args
        ]
        return f(TEagerTensor, *converted, **kwargs)

    wrap.__doc__ = f.__doc__
    return wrap
117 |
118 |
def _finalize_array_api(module, function_names, TEagerTensor):
    """
    Adds common attributes to Array API defined in this modules
    such as types.

    :param module: module object to populate
    :param function_names: names of the functions to expose,
        defaults to `supported_functions` when None
    :param TEagerTensor: EagerTensor class the wrapped functions operate on
    """
    from . import _onnx_common

    # Scalar type objects expected by the Array API specification.
    module.float16 = DType(TensorProto.FLOAT16)
    module.float32 = DType(TensorProto.FLOAT)
    module.float64 = DType(TensorProto.DOUBLE)
    module.complex64 = DType(TensorProto.COMPLEX64)
    module.complex128 = DType(TensorProto.COMPLEX128)
    module.int8 = DType(TensorProto.INT8)
    module.int16 = DType(TensorProto.INT16)
    module.int32 = DType(TensorProto.INT32)
    module.int64 = DType(TensorProto.INT64)
    module.uint8 = DType(TensorProto.UINT8)
    module.uint16 = DType(TensorProto.UINT16)
    module.uint32 = DType(TensorProto.UINT32)
    module.uint64 = DType(TensorProto.UINT64)
    module.bfloat16 = DType(TensorProto.BFLOAT16)
    # `bool` and `str` shadow builtins, hence setattr.
    setattr(module, "bool", DType(TensorProto.BOOL))  # noqa: B010
    setattr(module, "str", DType(TensorProto.STRING))  # noqa: B010
    setattr(module, "finfo", _finfo)  # noqa: B010
    setattr(module, "iinfo", _iinfo)  # noqa: B010

    if function_names is None:
        function_names = supported_functions

    for name in function_names:
        # Prefer the specialized implementation in _onnx_common, fall back
        # to the generic npx function.
        f = getattr(_onnx_common, name, None)
        if f is None:
            f2 = getattr(npx_functions, name, None)
            if f2 is None:
                warnings.warn(
                    f"Function {name!r} is not available in {module!r}.",
                    stacklevel=0,
                )
                continue
            # `_f=f2` freezes the current function in the lambda's defaults
            # (late-binding would otherwise capture the last loop value).
            f = lambda TEagerTensor, *args, _f=f2, **kwargs: _f(  # noqa: E731
                *args, **kwargs
            )
        setattr(module, name, array_api_wrap_function(f, TEagerTensor))
162 |
--------------------------------------------------------------------------------
/onnx_array_api/array_api/onnx_numpy.py:
--------------------------------------------------------------------------------
1 | from ..npx.npx_numpy_tensors import EagerNumpyTensor
2 | from . import _finalize_array_api
3 |
4 |
def _finalize():
    """
    Adds common attributes to Array API defined in this modules
    such as types.
    """
    # Import the module object itself so attributes can be attached to it.
    from . import onnx_numpy

    _finalize_array_api(onnx_numpy, None, EagerNumpyTensor)
13 |
14 |
15 | _finalize()
16 |
--------------------------------------------------------------------------------
/onnx_array_api/array_api/onnx_ort.py:
--------------------------------------------------------------------------------
1 | from ..ort.ort_tensors import EagerOrtTensor
2 | from . import _finalize_array_api
3 |
4 |
def _finalize():
    """
    Adds common attributes to Array API defined in this modules
    such as types.
    """
    # Import the module object itself so attributes can be attached to it.
    from . import onnx_ort

    _finalize_array_api(onnx_ort, None, EagerOrtTensor)
13 |
14 |
15 | _finalize()
16 |
--------------------------------------------------------------------------------
/onnx_array_api/cache.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 |
def get_cache_file(filename: str, remove: bool = False):
    """
    Returns a name in the cache folder `~/.onnx-array-api`.

    :param filename: filename
    :param remove: remove the file if it already exists
    :return: full filename (a :class:`pathlib.Path`)
    """
    home = Path.home()
    folder = home / ".onnx-array-api"
    if not folder.exists():
        folder.mkdir()
    name = folder / filename
    # BUG FIX: the file was previously removed unconditionally, ignoring the
    # `remove` parameter (which therefore had no effect and contradicted both
    # the docstring and the `remove=False` default).
    if remove and name.exists():
        name.unlink()
    return name
20 |
--------------------------------------------------------------------------------
/onnx_array_api/graph_api/__init__.py:
--------------------------------------------------------------------------------
1 | from .graph_builder import GraphBuilder, NodePattern
2 |
--------------------------------------------------------------------------------
/onnx_array_api/light_api/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from onnx import ModelProto
3 | from ..annotations import domain
4 | from .model import OnnxGraph, ProtoType
5 | from .var import Var, Vars
6 |
7 |
def start(
    opset: Optional[int] = None,
    opsets: Optional[Dict[str, int]] = None,
    ir_version: Optional[int] = None,
) -> OnnxGraph:
    """
    Starts an onnx model.

    :param opset: main opset version
    :param opsets: others opsets as a dictionary
    :param ir_version: specify the ir_version as well
    :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`

    A very simple model:

    .. runpython::
        :showcode:

        from onnx_array_api.light_api import start

        onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
        print(onx)

    Another with operator Add:

    .. runpython::
        :showcode:

        from onnx_array_api.light_api import start

        onx = (
            start()
            .vin("X")
            .vin("Y")
            .bring("X", "Y")
            .Add()
            .rename("Z")
            .vout()
            .to_onnx()
        )
        print(onx)
    """
    # The graph object accumulates nodes as methods are chained on it.
    return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
51 |
52 |
def g() -> OnnxGraph:
    """
    Starts a subgraph (a GraphProto rather than a full model).

    :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
    """
    return OnnxGraph(proto_type=ProtoType.GRAPH)
59 |
--------------------------------------------------------------------------------
/onnx_array_api/npx/__init__.py:
--------------------------------------------------------------------------------
1 | from .npx_core_api import cst, make_tuple, npxapi_function, npxapi_inline, var
2 | from .npx_functions import *
3 | from .npx_jit_eager import eager_onnx, jit_onnx
4 | from .npx_types import ElemType, OptParType, ParType, SequenceType, TensorType
5 |
--------------------------------------------------------------------------------
/onnx_array_api/npx/npx_constants.py:
--------------------------------------------------------------------------------
# Opset used per domain when the caller does not specify any.
DEFAULT_OPSETS = {"": 18, "ai.onnx.ml": 3}
# Marker domain used by get_function_implementation for locally defined
# function protos.
FUNCTION_DOMAIN = "FUNCTION-DOMAIN"
# NOTE(review): presumably marks functions expressed with standard onnx
# operators — confirm against its usage sites.
ONNX_DOMAIN = "ONNX-DOMAIN"

# Maps a main-domain opset version to the onnx IR version it requires.
_OPSET_TO_IR_VERSION = {
    14: 7,
    15: 8,
    16: 8,
    17: 8,
    18: 8,
    19: 9,
}

# IR version matching the default main-domain opset above.
DEFAULT_IR_VERSION = _OPSET_TO_IR_VERSION[DEFAULT_OPSETS[""]]
15 |
--------------------------------------------------------------------------------
/onnx_array_api/npx/npx_function_implementation.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple
2 |
3 | from onnx import FunctionProto, ValueInfoProto
4 | from onnx.helper import make_function, make_graph, make_node, make_opsetid
5 |
6 | from .npx_constants import FUNCTION_DOMAIN
7 |
8 |
def get_function_implementation(
    domop: Tuple[str, str],
    node_inputs: List[str],
    node_outputs: List[str],
    opsets: Dict[str, int],
    **kwargs: Any,
) -> FunctionProto:
    """
    Returns a :class:`onnx.FunctionProto` for a specific proto.

    :param domop: pair (domain, function name)
    :param node_inputs: list of input names
    :param node_outputs: list of output names
    :param opsets: available opsets
    :param kwargs: any other parameters
    :return: FunctionProto
    """
    domain, fname = domop
    if domain != FUNCTION_DOMAIN:
        raise ValueError(
            f"This function only considers function for domain "
            f"{FUNCTION_DOMAIN!r} not {domain!r}."
        )
    if fname == "CDist":
        return _get_cdist_implementation(node_inputs, node_outputs, opsets, **kwargs)
    raise ValueError(f"Unable to return an implementation of function {domop!r}.")
34 |
35 |
def _get_cdist_implementation(
    node_inputs: List[str],
    node_outputs: List[str],
    opsets: Dict[str, int],
    **kwargs: Any,
) -> FunctionProto:
    """
    Returns the CDist implementation as a function.

    :param node_inputs: input names, must contain exactly two
    :param node_outputs: output names, must contain exactly one
    :param opsets: available opsets, must contain the main domain ``''``
    :param kwargs: must contain exactly one key, ``metric``
    :return: FunctionProto
    :raises ValueError: if inputs, outputs, opsets or kwargs do not
        follow the constraints above
    :raises RuntimeError: if no implementation exists for the metric
    """
    if len(node_inputs) != 2:
        raise ValueError(f"cdist has two inputs not {len(node_inputs)}.")
    if len(node_outputs) != 1:
        raise ValueError(f"cdist has one outputs not {len(node_outputs)}.")
    if opsets is None:
        raise ValueError("opsets cannot be None.")
    if "" not in opsets:
        # Bug fix: the f prefix was missing, the message printed the
        # literal text '{opsets!r}' instead of the actual value.
        raise ValueError(
            f"Opsets for domain '' must be specified but opsets={opsets!r}."
        )
    if set(kwargs) != {"metric"}:
        raise ValueError(f"kwargs={kwargs} must contain metric and only metric.")
    metric = kwargs["metric"]
    if opsets is not None and "com.microsoft" in opsets:
        # onnxruntime ships CDist as a contrib operator, use it directly.
        node = make_node(
            "CDist", ["xa", "xb"], ["z"], domain="com.microsoft", metric=metric
        )
        return make_function(
            "npx",
            f"CDist_{metric}",
            ["xa", "xb"],
            ["z"],
            [node],
            [make_opsetid("com.microsoft", 1)],
        )

    if metric in ("euclidean", "sqeuclidean"):
        # Subgraph scanned over the rows of xb: for every row, the
        # squared distance to each row of xa is reduced along axis 1.
        nodes = [
            make_node("Sub", ["next", "next_in"], ["diff"]),
            make_node("Constant", [], ["axis"], value_ints=[1]),
            make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
            make_node("Identity", ["next_in"], ["next_out"]),
        ]

        def make_value(name):
            # Untyped ValueInfoProto: Scan only needs the names.
            value = ValueInfoProto()
            value.name = name
            return value

        graph = make_graph(
            nodes,
            "loop",
            [make_value("next_in"), make_value("next")],
            [make_value("next_out"), make_value("scan_out")],
        )

        scan = make_node(
            "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph
        )
        # sqeuclidean is the scan output itself, euclidean takes the root.
        if metric == "euclidean":
            final = make_node("Sqrt", ["zout"], ["z"])
        else:
            final = make_node("Identity", ["zout"], ["z"])
        return make_function(
            "npx",
            f"CDist_{metric}",
            ["xa", "xb"],
            ["z"],
            [scan, final],
            [make_opsetid("", opsets[""])],
        )

    raise RuntimeError(
        f"There is no implementation for cdist and metric={metric!r} yet."
    )
111 |
--------------------------------------------------------------------------------
/onnx_array_api/npx/npx_functions_test.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 | from .npx_core_api import (
6 | cst,
7 | make_tuple,
8 | npxapi_function,
9 | npxapi_inline,
10 | tuple_var,
11 | var,
12 | )
13 | from .npx_types import (
14 | ElemType,
15 | OptParType,
16 | ParType,
17 | SequenceType,
18 | TensorType,
19 | TupleType,
20 | )
21 |
22 |
@npxapi_function
def _min_max(
    x: TensorType[ElemType.numerics, "T"],
) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]:
    "Returns the couple ``(ReduceMin(x), ReduceMax(x))``."
    return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax"))
28 |
29 |
@npxapi_inline
def _min_max_inline(
    x: TensorType[ElemType.numerics, "T"],
) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]:
    "Inline variant of :func:`_min_max`, registered with npxapi_inline."
    return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax"))
35 |
36 |
@npxapi_function
def absolute(
    x: TensorType[ElemType.numerics, "T"],
) -> TensorType[ElemType.numerics, "T"]:
    "See :func:`numpy.absolute`. Implemented with onnx operator Abs."
    return var(x, op="Abs")
43 |
44 |
@npxapi_function
def addition(
    x: TensorType[ElemType.numerics, "T"], y: TensorType[ElemType.numerics, "T"]
) -> TensorType[ElemType.numerics, "T"]:
    "See :func:`numpy.add`. Implemented with onnx operator Add."
    return var(x, y, op="Add")
51 |
52 |
@npxapi_function
def argmin(
    x: TensorType[ElemType.numerics, "T"],
    axis: OptParType[int] = 0,
    keepdims: OptParType[int] = 0,
) -> TensorType[ElemType.numerics, "T"]:
    """
    See :func:`numpy.argmin`.
    Implemented with onnx operator ArgMin; *axis* and *keepdims*
    are forwarded as attributes.
    """
    return var(x, op="ArgMin", axis=axis, keepdims=keepdims)
63 |
64 |
@npxapi_function
def concat(
    *x: SequenceType[TensorType[ElemType.numerics, "T"]], axis: ParType[int] = 0
) -> TensorType[ElemType.numerics, "T"]:
    """
    Operator concat, handle :func:`numpy.vstack` and
    :func:`numpy.hstack`.

    :raises RuntimeError: if fewer than two tensors are given
    """
    if len(x) <= 1:
        raise RuntimeError(f"N={len(x)}<=1 elements to concatenate.")
    return var(*x, op="Concat", axis=axis)
76 |
77 |
@npxapi_function
def copy(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
    "Makes a copy, implemented with onnx operator Identity."
    return var(x, op="Identity")
82 |
83 |
@npxapi_function
def log1p(x: TensorType[ElemType.floats, "T"]) -> TensorType[ElemType.floats, "T"]:
    "See :func:`numpy.log1p`. Computes ``Log(x + 1)``."
    # CastLike aligns the constant 1 with the dtype of x before the addition.
    x1 = var(x, var(cst(np.array([1], dtype=np.int64)), x, op="CastLike"), op="Add")
    return var(x1, op="Log")
89 |
90 |
@npxapi_function
def negative(
    x: TensorType[ElemType.numerics, "T"],
) -> TensorType[ElemType.numerics, "T"]:
    "See :func:`numpy.negative`. Implemented with onnx operator Neg."
    return var(x, op="Neg")
97 |
98 |
@npxapi_function
def relu(
    x: TensorType[ElemType.numerics, "T"],
) -> TensorType[ElemType.numerics, "T"]:
    "Relu, computed as ``(abs(x) + x) / 2``: 0 for negative values, x otherwise."
    return var(var(absolute(x), x, op="Add"), var(cst(2), x, op="CastLike"), op="Div")
105 |
106 |
@npxapi_function
def topk(
    x: TensorType[ElemType.numerics, "T"],
    k: TensorType[ElemType.int64, "I", (1,)],
    axis: OptParType[int] = -1,
    largest: OptParType[int] = 1,
    sorted: OptParType[int] = 1,
) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.int64, "I"]]:
    "Returns the *k* largest (or smallest) values and their indices, see onnx TopK."
    return make_tuple(2, x, k, op="TopK", axis=axis, largest=largest, sorted=sorted)
117 |
118 |
@npxapi_function
def transpose(
    x: TensorType[ElemType.numerics, "T"], perm: ParType[Tuple[int]] = (1, 0)
) -> TensorType[ElemType.numerics, "T"]:
    "See :func:`numpy.transpose`."
    # onnx expects a list attribute, not a tuple
    return var(x, op="Transpose", perm=list(perm))
125 |
--------------------------------------------------------------------------------
/onnx_array_api/ort/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/onnx_array_api/ort/__init__.py
--------------------------------------------------------------------------------
/onnx_array_api/ort/ort_optimizers.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 | from onnx import ModelProto, load
3 | from onnxruntime import InferenceSession, SessionOptions
4 | from onnxruntime.capi._pybind_state import GraphOptimizationLevel
5 | from ..cache import get_cache_file
6 |
7 |
def ort_optimized_model(
    onx: Union[str, ModelProto],
    level: str = "ORT_ENABLE_ALL",
    output: Optional[str] = None,
) -> Union[str, ModelProto]:
    """
    Returns the optimized model used by onnxruntime before
    running computing the inference.

    :param onx: ModelProto or filename
    :param level: optimization level, `'ORT_ENABLE_BASIC'`,
        `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
    :param output: output file if the proposed cache is not wanted
    :return: optimized model, same kind as *onx* (or *output* when given)
    """
    optimization_level = getattr(GraphOptimizationLevel, level, None)
    if optimization_level is None:
        raise ValueError(
            f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
        )

    destination = (
        output
        if output is not None
        else get_cache_file("ort_optimized_model.onnx", remove=True)
    )
    options = SessionOptions()
    options.graph_optimization_level = optimization_level
    options.optimized_model_filepath = str(destination)
    # Building the session triggers the optimization pass and dumps the
    # optimized graph to *destination*; the session itself is discarded.
    InferenceSession(
        onx if isinstance(onx, str) else onx.SerializeToString(),
        options,
        providers=["CPUExecutionProvider"],
    )
    if output is None and not destination.exists():
        raise RuntimeError(f"The optimized model {str(destination)!r} not found.")
    if output is not None:
        return output
    if isinstance(onx, str):
        return str(destination)
    optimized = load(str(destination))
    destination.unlink()
    return optimized
50 |
--------------------------------------------------------------------------------
/onnx_array_api/plotting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpython/onnx-array-api/96eb50e002a6529c0c10e62414f960cecd62f0c3/onnx_array_api/plotting/__init__.py
--------------------------------------------------------------------------------
/onnx_array_api/plotting/stat_plot.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | import pandas
3 | import matplotlib.pyplot as plt
4 |
5 |
def plot_ort_profile(
    df: pandas.DataFrame,
    ax0: Optional[Any] = None,
    ax1: Optional[Any] = None,
    title: Optional[str] = None,
) -> Any:
    """
    Plots the time spent in every operator, based on the dataframe
    produced by function
    :func:`onnx_array_api.ort.ort_profile.ort_profile`.

    :param df: dataframe with columns ``dur`` and ``args_op_name``
    :param ax0: first axis to draw time
    :param ax1: second axis to draw occurences
    :param title: graph title
    :return: ax0

    See :ref:`l-example-ort-profiling` for an example.
    """
    if ax0 is None:
        ax0 = plt.gca()  # pragma: no cover

    grouped = df[["dur", "args_op_name"]].groupby("args_op_name")
    duration = grouped.sum().sort_values("dur")
    duration.plot.barh(ax=ax0)
    if title is not None:
        ax0.set_title(title)
    if ax1 is not None:
        # align the occurrence bars with the duration bars
        counts = grouped.count().sort_values("dur").loc[duration.index, :]
        counts.plot.barh(ax=ax1)
        ax1.set_title("n occurences")
    return ax0
45 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | from onnx import TensorProto
4 | from onnx.numpy_helper import from_array as onnx_from_array
5 | from onnx.reference.ops.op_cast import (
6 | bfloat16,
7 | float8e4m3fn,
8 | float8e4m3fnuz,
9 | float8e5m2,
10 | float8e5m2fnuz,
11 | )
12 | from onnx.reference.op_run import to_array_extended
13 | from .evaluator import ExtendedReferenceEvaluator
14 | from .evaluator_yield import (
15 | DistanceExecution,
16 | ResultExecution,
17 | ResultType,
18 | YieldEvaluator,
19 | compare_onnx_execution,
20 | )
21 |
22 |
def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto:
    """
    Converts an array into a TensorProto, including the float 8
    and bfloat16 element types.

    :param tensor: numpy array
    :param name: name
    :return: TensorProto
    """
    dt = tensor.dtype
    # (custom dtype, descr label, onnx element type, storage dtype).
    # The descr label must be checked as well because the raw storage
    # types compare equal to these custom dtypes.
    special_types = (
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN, np.uint8),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ, np.uint8),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2, np.uint8),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ, np.uint8),
        (bfloat16, "bfloat16", TensorProto.BFLOAT16, np.uint16),
    )
    for dtype, label, onnx_type, storage in special_types:
        if dt == dtype and dt.descr[0][0] == label:
            # serialize the raw bits, then override the element type
            proto = onnx_from_array(tensor.astype(storage), name)
            proto.data_type = onnx_type
            return proto
    return onnx_from_array(tensor, name)
53 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/evaluator.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | from typing import Any, Dict, List, Optional, Union
3 | from onnx import FunctionProto, ModelProto
4 | from onnx.defs import get_schema
5 | from onnx.reference import ReferenceEvaluator
6 | from onnx.reference.op_run import OpRun
7 | from .ops.op_cast_like import CastLike_15, CastLike_19
8 | from .ops.op_concat import Concat
9 | from .ops.op_constant_of_shape import ConstantOfShape
10 | from .ops.op_fused_matmul import FusedMatMul
11 | from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
12 | from .ops.op_quick_gelu import QuickGelu
13 | from .ops.op_scatter_elements import ScatterElements
14 |
15 |
16 | logger = getLogger("onnx-array-api-eval")
17 |
18 |
class ExtendedReferenceEvaluator(ReferenceEvaluator):
    """
    This class replaces the python implementation by custom implementation.
    The Array API extends many operator to all types not supported
    by the onnx specifications. The evaluator allows to test
    scenarios outside what an onnx backend bound to the official onnx
    operators definition could do.

    ::

        from onnx.reference import ReferenceEvaluator
        from onnx.reference.c_ops import Conv
        ref = ReferenceEvaluator(..., new_ops=[Conv])
    """

    # Custom kernels always registered, on top of any the caller provides.
    default_ops = [
        Concat,
        CastLike_15,
        CastLike_19,
        ConstantOfShape,
        FusedMatMul,
        MemcpyFromHost,
        MemcpyToHost,
        QuickGelu,
        ScatterElements,
    ]

    @staticmethod
    def filter_ops(proto, new_ops, opsets):
        """
        Filters versioned kernel classes (named ``<Op>_<version>``)
        against the opsets imported by *proto*: for every operator,
        the highest version not above the model opset wins and is
        re-exposed under the unversioned name.

        :param proto: ModelProto or FunctionProto the opsets are read
            from when *opsets* is None
        :param new_ops: candidate kernel classes
        :param opsets: opsets to filter against, may be None
        :return: filtered list of classes
        """
        if opsets is None and isinstance(proto, (ModelProto, FunctionProto)):
            opsets = {d.domain: d.version for d in proto.opset_import}
        best = {}
        renamed = {}
        for cl in new_ops:
            if "_" not in cl.__name__:
                # unversioned kernel, kept as-is
                continue
            vers = cl.__name__.split("_")
            try:
                v = int(vers[-1])
            except ValueError:
                # not a version
                continue
            if opsets is not None and v > opsets.get(cl.op_domain, 1):
                # introduced after the model opset, not applicable
                continue
            renamed[cl.__name__] = cl
            key = cl.op_domain, "_".join(vers[:-1])
            if key not in best or best[key][0] < v:
                best[key] = (v, cl)

        modified = []
        for cl in new_ops:
            if cl.__name__ not in renamed:
                modified.append(cl)
        for k, v in best.items():
            # build an alias class carrying the unversioned operator name
            # so the ReferenceEvaluator picks it up, attaching the schema
            # when the winning class does not define one
            atts = {"domain": k[0]}
            bases = (v[1],)
            if not hasattr(v[1], "op_schema"):
                atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain)
            new_cl = type(k[1], bases, atts)
            modified.append(new_cl)

        new_ops = modified
        return new_ops

    def __init__(
        self,
        proto: Any,
        opsets: Optional[Dict[str, int]] = None,
        functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None,
        verbose: int = 0,
        new_ops: Optional[List[OpRun]] = None,
        **kwargs,
    ):
        """
        :param proto: proto to evaluate
        :param opsets: opsets overriding the ones stored in *proto*
        :param functions: known local functions
        :param verbose: verbosity level, messages are printed when high
            enough, otherwise routed to the module logger
        :param new_ops: additional kernels, merged with *default_ops*
        :param kwargs: forwarded to
            :class:`onnx.reference.ReferenceEvaluator`
        """
        if new_ops is None:
            new_ops = ExtendedReferenceEvaluator.default_ops
        else:
            # copy before extending so the caller's list is not mutated
            new_ops = new_ops.copy()
            new_ops.extend(ExtendedReferenceEvaluator.default_ops)
        new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets)

        ReferenceEvaluator.__init__(
            self,
            proto,
            opsets=opsets,
            functions=functions,
            verbose=verbose,
            new_ops=new_ops,
            **kwargs,
        )

    def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
        # prints when verbose is high enough, otherwise defers to the
        # logger so messages are not lost
        if level < self.verbose:
            new_args = [self._log_arg(a) for a in args]
            print(pattern % tuple(new_args))
        else:
            logger.debug(pattern, *args)

    def run(self, *args, **kwargs):
        """
        See :meth:`onnx.reference.ReferenceEvaluator.run`.
        Also accepts a single list of inputs, matched positionally
        against ``input_names``.
        """
        if len(args) == 1 and isinstance(args[0], list):
            feeds = dict(zip(self.input_names, args[0]))
            return self.run(None, feeds, **kwargs)
        return ReferenceEvaluator.run(self, *args, **kwargs)
124 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_cast_like.py:
--------------------------------------------------------------------------------
1 | from onnx.helper import np_dtype_to_tensor_dtype
2 | from onnx.onnx_pb import TensorProto
3 | from onnx.reference.op_run import OpRun
4 | from onnx.reference.ops.op_cast import (
5 | bfloat16,
6 | cast_to,
7 | float8e4m3fn,
8 | float8e4m3fnuz,
9 | float8e5m2,
10 | float8e5m2fnuz,
11 | )
12 |
13 |
def _cast_like(x, y, saturate):
    """
    Casts *x* to the element type of *y* and returns a 1-tuple.
    """
    # (custom dtype, descr label, onnx element type). The descr label must
    # be checked too: np.uint16 == np.uint16 is True as well as
    # np.uint16 == bfloat16.
    special_types = (
        (bfloat16, "bfloat16", TensorProto.BFLOAT16),
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
    )
    for dtype, label, onnx_type in special_types:
        if y.dtype == dtype and y.dtype.descr[0][0] == label:
            to = onnx_type
            break
    else:
        to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
    return (cast_to(x, to, saturate),)
29 |
30 |
class CastLike_15(OpRun):
    # Opset 15 version: no saturate attribute, saturation always enabled.
    def _run(self, x, y):  # type: ignore
        return _cast_like(x, y, True)
34 |
35 |
class CastLike_19(OpRun):
    # Opset 19 adds the saturate attribute, forwarded as-is
    # (may be None when the attribute is absent).
    def _run(self, x, y, saturate=None):  # type: ignore
        return _cast_like(x, y, saturate)
39 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_concat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from onnx.reference.op_run import OpRun
4 |
5 |
class Concat(OpRun):
    """
    Concat with relaxed rank handling: an input whose rank is too small
    for *axis* is padded with trailing dimensions of size 1 first.
    """

    def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
        rank = len(a.shape)  # type: ignore
        if axis < rank:
            return a
        padded_shape = a.shape + (1,) * (axis + 1 - rank)  # type: ignore
        return a.reshape(padded_shape)

    def _run(self, *args, axis=None):  # type: ignore
        prepared = tuple(self._preprocess(a, axis) for a in args)
        return (np.concatenate(prepared, axis),)  # type: ignore
16 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_constant_of_shape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from onnx.reference.op_run import OpRun
3 |
4 |
class ConstantOfShape(OpRun):
    @staticmethod
    def _process(value):
        """
        Reduces the ``value`` attribute to a numpy scalar usable by
        :func:`numpy.full`. When *value* is None (absent attribute),
        defaults to ``np.float32(0)`` as the ONNX specification states.
        """
        if isinstance(value, np.ndarray):
            if not value.shape:
                cst = value
            elif value.size > 0:
                # onnx stores the fill value as a one-element tensor
                cst = value.ravel()[0]
            else:
                raise ValueError(f"Unexpected fill_value={value!r}")
        else:
            # Bug fix: a non-array value (in particular None for a missing
            # attribute) used to raise; it is now normalized below.
            cst = value
        if isinstance(cst, bool):
            cst = np.bool_(cst)
        elif isinstance(cst, int):
            cst = np.int64(cst)
        elif isinstance(cst, float):
            cst = np.float64(cst)
        elif isinstance(cst, complex):
            cst = np.complex128(cst)
        elif cst is None:
            # missing attribute: the operator defaults to float32 zeros
            cst = np.float32(0)
        if not isinstance(
            cst,
            (
                np.float16,
                np.float32,
                np.float64,
                np.complex64,
                np.complex128,
                np.int64,
                np.int32,
                np.int16,
                np.int8,
                np.uint64,
                np.uint32,
                np.uint16,
                np.uint8,
                np.bool_,
            ),
        ):
            raise TypeError(f"value must be a real not {type(cst)}")
        return cst

    def _run(self, data, value=None):
        """
        Builds an array of shape *data* filled with the scalar extracted
        from *value*.

        :raises RuntimeError: when :func:`numpy.full` rejects the value
        """
        cst = self._process(value)
        try:
            res = np.full(tuple(data), cst)
        except TypeError as e:
            raise RuntimeError(
                f"Unable to create a constant of shape "
                f"{data!r} with value {cst!r} "
                f"(raw value={value!r})."
            ) from e
        return (res,)
59 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_fused_matmul.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from onnx.reference.op_run import OpRun
3 |
4 |
class FusedMatMul(OpRun):
    """
    com.microsoft contrib operator: ``alpha * matmul(A, B)`` with
    optional transposition of the last two axes of either input.
    """

    op_domain = "com.microsoft"

    @staticmethod
    def _transpose_last_two(m):
        # permutation identical to the identity except for the last two
        # axes, which are swapped (no-op for a 1-D input)
        perm = list(range(len(m.shape)))
        n = len(perm)
        perm[n - 2], perm[n - 1] = perm[n - 1], perm[n - 2]
        return np.transpose(m, perm)

    def _run(
        self,
        A,
        B,
        alpha: float = 1,
        transA: int = 0,
        transB: int = 0,
        transBatchA: int = 0,
        transBatchB: int = 0,
    ):
        assert (
            transBatchA == 0
        ), f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}"
        assert (
            transBatchB == 0
        ), f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
        if transA:
            A = self._transpose_last_two(A)
        if transB:
            B = self._transpose_last_two(B)
        # cast alpha to the input dtype so the output dtype is preserved
        scale = np.array(alpha, dtype=A.dtype)
        return (np.matmul(A, B) * scale,)
36 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_memcpy_host.py:
--------------------------------------------------------------------------------
1 | from onnx.reference.op_run import OpRun
2 |
3 |
class MemcpyFromHost(OpRun):
    # Identity in this evaluator: there is no separate device memory.
    def _run(self, x):
        return (x,)
8 |
class MemcpyToHost(OpRun):
    # Identity in this evaluator: there is no separate device memory.
    def _run(self, x):
        return (x,)
12 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_quick_gelu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from onnx.reference.op_run import OpRun
3 |
4 |
def sigmoid(x):  # type: ignore
    """
    Numerically stable scalar sigmoid: picks the formulation whose
    exponential cannot overflow for the sign of *x*.
    """
    if x <= 0:
        z = np.exp(x)
        return z / (1 + z)
    return 1 / (1 + np.exp(-x))
9 |
10 |
class QuickGelu(OpRun):
    """com.microsoft contrib operator computing ``X * sigmoid(alpha * X)``."""

    op_domain = "com.microsoft"

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        # vectorized scalar sigmoid, built once and reused across calls
        self.vf = np.vectorize(sigmoid)

    def _run(self, X, alpha=1.0):
        # 0-d input: apply the scalar sigmoid directly
        if len(X.shape) == 0:
            return ((X * sigmoid(X * alpha)).astype(X.dtype),)
        # empty input returned unchanged, presumably because np.vectorize
        # cannot infer an output type from a size-0 array — confirm
        if X.size == 0:
            return (X,)
        return ((X * self.vf(X * alpha)).astype(X.dtype),)
24 |
--------------------------------------------------------------------------------
/onnx_array_api/reference/ops/op_scatter_elements.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from onnx.reference.op_run import OpRun
4 |
5 |
def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    """
    Scatters *updates* into a copy of *data* at the positions given by
    *indices* along *axis*, following the semantics of the onnx operator
    ScatterElements.

    :param data: source array, copied before scattering
    :param indices: integer array, same rank as *updates*
    :param updates: values to write/reduce into the copy
    :param axis: axis along which the index applies, may be negative
    :param reduction: None (plain replacement), ``"add"``, ``"min"``
        or ``"max"``
    :return: new array with the updates applied

    The previous implementation hard-coded loops for ranks 1 to 4 and
    silently returned an unmodified copy for a 4-D input with
    ``axis != 3``; this version works for any rank.
    """
    if reduction == "add":

        def f(x, y):
            return x + y

    elif reduction == "min":

        def f(x, y):
            return min(x, y)

    elif reduction == "max":

        def f(x, y):
            return max(x, y)

    else:

        def f(x, y):
            return y

    if axis < 0:
        axis = data.ndim + axis

    indices = np.asarray(indices)
    updates = np.asarray(updates)
    scattered = np.copy(data)
    # Visit the updates in C order, mirroring the nested loops of the
    # original implementation: duplicate target indices are reduced
    # sequentially, in that order.
    for coord in np.ndindex(*indices.shape):
        target = list(coord)
        target[axis] = indices[coord]
        target = tuple(target)
        scattered[target] = f(scattered[target], updates[coord])
    return scattered
93 |
94 |
class ScatterElements(OpRun):
    # Thin OpRun wrapper delegating to the module-level implementation.
    def _run(self, data, indices, updates, axis=None, reduction=None):  # type: ignore
        res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
        return (res,)
99 |
--------------------------------------------------------------------------------
/onnx_array_api/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/onnx_array_api/translate_api/__init__.py:
--------------------------------------------------------------------------------
1 | from onnx import ModelProto
2 | from .translate import Translater
3 | from .inner_emitter import InnerEmitter, InnerEmitterShortInitializer
4 | from .builder_emitter import BuilderEmitter
5 |
6 |
def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
    """
    Translates an ONNX proto into a code using :ref:`l-light-api`
    to describe the ONNX graph.

    :param proto: model to translate
    :param single_line: as a single line or not
    :param api: API to export into,
        default is `"light"` and this is handled by class
        :class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
        another value is `"onnx"` which is the inner API implemented
        in onnx package, `"builder"` follows the syntax for the
        class :class:`onnx_array_api.graph_api.GraphBuilder`,
        `"onnx-short"` replaces long initializers with random values
    :return: code

    .. runpython::
        :showcode:

        from onnx_array_api.light_api import start
        from onnx_array_api.translate_api import translate

        onx = (
            start()
            .vin("X")
            .reshape((-1, 1))
            .Transpose(perm=[1, 0])
            .rename("Y")
            .vout()
            .to_onnx()
        )
        code = translate(onx)
        print(code)

    The inner API from onnx package is also available.

    .. runpython::
        :showcode:

        from onnx_array_api.light_api import start
        from onnx_array_api.translate_api import translate

        onx = (
            start()
            .vin("X")
            .reshape((-1, 1))
            .Transpose(perm=[1, 0])
            .rename("Y")
            .vout()
            .to_onnx()
        )
        code = translate(onx, api="onnx")
        print(code)

    The :class:`onnx_array_api.graph_api.GraphBuilder` API
    returns this:

    .. runpython::
        :showcode:

        from onnx_array_api.light_api import start
        from onnx_array_api.translate_api import translate

        onx = (
            start()
            .vin("X")
            .reshape((-1, 1))
            .Transpose(perm=[1, 0])
            .rename("Y")
            .vout()
            .to_onnx()
        )
        code = translate(onx, api="builder")
        print(code)
    """
    # the light API is the only one honouring single_line
    if api == "light":
        tr = Translater(proto)
        return tr.export(single_line=single_line, as_str=True)
    emitters = {
        "onnx": InnerEmitter,
        "onnx-short": InnerEmitterShortInitializer,
        "builder": BuilderEmitter,
    }
    if api in emitters:
        tr = Translater(proto, emitter=emitters[api]())
        return tr.export(as_str=True)
    raise ValueError(f"Unexpected value {api!r} for api.")
95 |
--------------------------------------------------------------------------------
/onnx_array_api/translate_api/light_emitter.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 | from ..annotations import ELEMENT_TYPE_NAME
3 | from .base_emitter import BaseEmitter
4 |
5 |
6 | class LightEmitter(BaseEmitter):
7 | """
8 | Converts event into proper code.
9 | """
10 |
11 | def join(self, rows: List[str], single_line: bool = False) -> str:
12 | "Join the rows"
13 | if single_line:
14 | return ".".join(rows)
15 | return "".join(["(\n ", "\n .".join(rows), "\n)"])
16 |
17 | def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
18 | opsets = kwargs.get("opsets", {})
19 | opset = opsets.get("", None)
20 | if opset is not None:
21 | del opsets[""]
22 | args = []
23 | if opset:
24 | args.append(f"opset={opset}")
25 | if opsets:
26 | args.append(f"opsets={opsets}")
27 | return [f"start({', '.join(args)})"]
28 |
29 | def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
30 | return ["to_onnx()"]
31 |
32 | def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
33 | return []
34 |
35 | def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
36 | return []
37 |
38 | def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
39 | return []
40 |
41 | def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
42 | name = kwargs["name"]
43 | value = kwargs["value"]
44 | repl = {"bool": "bool_", "object": "object_", "str": "str_"}
45 | sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
46 | return [
47 | f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
48 | f"rename({name!r})",
49 | ]
50 |
51 | def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
52 | name = kwargs["name"]
53 | elem_type = kwargs.get("elem_type", None)
54 | shape = kwargs.get("shape", None)
55 | if elem_type and shape:
56 | return [
57 | f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
58 | f"shape={shape!r})"
59 | ]
60 | if elem_type:
61 | return [
62 | f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
63 | ]
64 | return [f"vin({name!r})"]
65 |
66 | def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
67 | inst = []
68 | if "name" in kwargs:
69 | name = kwargs["name"]
70 | inst.append(f"bring({name!r})")
71 | elem_type = kwargs.get("elem_type", None)
72 | shape = kwargs.get("shape", None)
73 | if elem_type and shape:
74 | inst.append(
75 | f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
76 | f"shape={shape!r})"
77 | )
78 | elif elem_type:
79 | inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
80 | else:
81 | inst.append("vout()")
82 | return inst
83 |
84 | def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
85 | op_type = kwargs["op_type"]
86 | inputs = kwargs["inputs"]
87 | outputs = kwargs["outputs"]
88 | if kwargs.get("domain", "") != "":
89 | domain = kwargs["domain"]
90 | op_type = f"{domain}.{op_type}"
91 | atts = kwargs.get("atts", {})
92 | args = []
93 | for k, v in atts.items():
94 | before, vatt = self.render_attribute_value(v)
95 | if before:
96 | raise NotImplementedError("Graph attribute not supported yet.")
97 | args.append(f"{k}={vatt}")
98 |
99 | str_inputs = ", ".join([f"{i!r}" for i in inputs])
100 | inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
101 | if len(outputs) == 1:
102 | inst.append(f"rename({outputs[0]!r})")
103 | else:
104 | str_outputs = ", ".join([f"{o!r}" for o in outputs])
105 | inst.append(f"rename({str_outputs})")
106 | return inst
107 |
--------------------------------------------------------------------------------
/onnx_array_api/translate_api/make_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Sequence
2 | from onnx import AttributeProto, NodeProto
3 | from onnx.helper import make_attribute
4 |
5 |
def make_ref_attribute(
    key: str, attr_type: int, ref_attr_name: Optional[str] = None
) -> AttributeProto:
    """
    Creates an attribute.

    :param key: attribute name
    :param attr_type: attribute type
    :param ref_attr_name: if not None, link this attribute
        to a function attribute
    :return: attribute
    """
    att = AttributeProto()
    att.name = key
    att.type = attr_type
    if ref_attr_name is not None:
        # Assigning None to a protobuf string field raises a TypeError,
        # which made the default argument unusable. Only set it when given.
        att.ref_attr_name = ref_attr_name
    return att
23 |
24 |
def make_node_extended(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    domain: Optional[str] = None,
    **kwargs: Any,
) -> NodeProto:
    """
    Constructs a NodeProto.

    :param op_type: The name of the operator to construct
    :param inputs: list of input names
    :param outputs: list of output names
    :param name: optional unique identifier for NodeProto
    :param doc_string: optional documentation string for NodeProto
    :param domain: optional domain for NodeProto.
        If it's None, we will just use default domain (which is empty)
    :param kwargs: the attributes of the node.
    :return: node proto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    # Attributes are added in sorted key order; None values are skipped and
    # ready-made AttributeProto instances are appended as-is.
    for key in sorted(kwargs):
        value = kwargs[key]
        if value is None:
            continue
        if isinstance(value, AttributeProto):
            attribute = value
        else:
            attribute = make_attribute(key, value)
        node.attribute.append(attribute)
    return node
66 |
--------------------------------------------------------------------------------
/onnx_array_api/validation/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/onnx_array_api/validation/diff.py:
--------------------------------------------------------------------------------
1 | import os
2 | import difflib
3 | import textwrap
4 | from typing import Union
5 | from onnx import ModelProto
6 |
7 |
8 | def _get_diff_template():
9 | import jinja2
10 |
11 | tpl = textwrap.dedent(
12 | """
13 |
14 |
15 |
17 |
42 | """
43 | )
44 | path = os.path.abspath(os.path.dirname(__file__))
45 | path = path.replace("\\", "/")
46 | path = f"file://{path}"
47 | tpl = tpl.replace("__PATH__", path)
48 | return jinja2.Template(tpl, autoescape=True)
49 |
50 |
def text_diff(text1: Union[ModelProto, str], text2: Union[ModelProto, str]) -> str:
    """
    Produces a string showing the differences between
    two strings.

    :param text1: first string, a ModelProto is rendered as text first
    :param text2: second string, a ModelProto is rendered as text first
    :return: differences
    """

    def _as_text(value):
        # Render any non-string (a ModelProto) with the text plotter.
        if isinstance(value, str):
            return value
        from ..plotting.text_plot import onnx_simple_text_plot

        return onnx_simple_text_plot(value, indent=False)

    left = _as_text(text1)
    right = _as_text(text2)
    compared = difflib.Differ().compare(
        left.splitlines(keepends=True), right.splitlines(keepends=True)
    )
    return "".join(compared)
74 |
75 |
def html_diff(
    text1: Union[ModelProto, str],
    text2: Union[ModelProto, str],
    title: str = "html_diff",
    div_name: str = "div_name",
    header: bool = True,
) -> str:
    """
    Produces a HTML files showing the differences between
    two strings.

    :param text1: first string
    :param text2: second string
    :param title: title
    :param div_name: html format, section name
    :param header: if True, add header and html main tags
        (NOTE(review): not referenced below - presumably consumed by the
        template or kept for compatibility; confirm before removing)
    :return: differences
    """
    rendered = _get_diff_template().render(
        title=title,
        version1=text1,
        version2=text2,
        div_name=f"div_{div_name}",
        diff_content=text_diff(text1, text2),
    )
    return f"\n{rendered}\n\n"
103 |
--------------------------------------------------------------------------------
/onnx_array_api/validation/diff2html.min.css:
--------------------------------------------------------------------------------
1 | .d2h-wrapper{text-align:left}.d2h-file-header{background-color:#f7f7f7;border-bottom:1px solid #d8d8d8;font-family:Source Sans Pro,Helvetica Neue,Helvetica,Arial,sans-serif;height:35px;padding:5px 10px}.d2h-file-header,.d2h-file-stats{display:-webkit-box;display:-ms-flexbox;display:flex}.d2h-file-stats{font-size:14px;margin-left:auto}.d2h-lines-added{border:1px solid #b4e2b4;border-radius:5px 0 0 5px;color:#399839;padding:2px;text-align:right;vertical-align:middle}.d2h-lines-deleted{border:1px solid #e9aeae;border-radius:0 5px 5px 0;color:#c33;margin-left:1px;padding:2px;text-align:left;vertical-align:middle}.d2h-file-name-wrapper{-webkit-box-align:center;-ms-flex-align:center;align-items:center;display:-webkit-box;display:-ms-flexbox;display:flex;font-size:15px;width:100%}.d2h-file-name{overflow-x:hidden;text-overflow:ellipsis;white-space:nowrap}.d2h-file-wrapper{border:1px solid #ddd;border-radius:3px;margin-bottom:1em}.d2h-file-collapse{-webkit-box-pack:end;-ms-flex-pack:end;-webkit-box-align:center;-ms-flex-align:center;align-items:center;border:1px solid #ddd;border-radius:3px;cursor:pointer;display:none;font-size:12px;justify-content:flex-end;padding:4px 8px}.d2h-file-collapse.d2h-selected{background-color:#c8e1ff}.d2h-file-collapse-input{margin:0 4px 0 0}.d2h-diff-table{border-collapse:collapse;font-family:Menlo,Consolas,monospace;font-size:13px;width:100%}.d2h-files-diff{display:-webkit-box;display:-ms-flexbox;display:flex;width:100%}.d2h-file-diff{overflow-y:hidden}.d2h-file-diff.d2h-d-none,.d2h-files-diff.d2h-d-none{display:none}.d2h-file-side-diff{display:inline-block;margin-bottom:-8px;margin-right:-4px;overflow-x:scroll;overflow-y:hidden;width:50%}.d2h-code-line{padding:0 8em}.d2h-code-line,.d2h-code-side-line{display:inline-block;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none;white-space:nowrap;width:100%}.d2h-code-side-line{padding:0 
4.5em}.d2h-code-line-ctn{word-wrap:normal;background:none;display:inline-block;padding:0;-webkit-user-select:text;-moz-user-select:text;-ms-user-select:text;user-select:text;vertical-align:middle;white-space:pre;width:100%}.d2h-code-line del,.d2h-code-side-line del{background-color:#ffb6ba}.d2h-code-line del,.d2h-code-line ins,.d2h-code-side-line del,.d2h-code-side-line ins{border-radius:.2em;display:inline-block;margin-top:-1px;text-decoration:none;vertical-align:middle}.d2h-code-line ins,.d2h-code-side-line ins{background-color:#97f295;text-align:left}.d2h-code-line-prefix{word-wrap:normal;background:none;display:inline;padding:0;white-space:pre}.line-num1{float:left}.line-num1,.line-num2{-webkit-box-sizing:border-box;box-sizing:border-box;overflow:hidden;padding:0 .5em;text-overflow:ellipsis;width:3.5em}.line-num2{float:right}
2 | .d2h-code-linenumber{background-color:var(--pst-color-background);border:solid #eee;border-width:0 1px;-webkit-box-sizing:border-box;box-sizing:border-box;color:var(--pst-color-text-base);cursor:pointer;display:inline-block;position:absolute;text-align:right;width:7.5em}
3 | .d2h-code-linenumber:after{content:"\200b"}
4 | .d2h-code-side-linenumber{background-color:var(--pst-color-background);border:solid #eee;border-width:0 1px;-webkit-box-sizing:border-box;box-sizing:border-box;color:var(--pst-color-text-base);cursor:pointer;display:inline-block;overflow:hidden;padding:0 .5em;position:absolute;text-align:right;text-overflow:ellipsis;width:4em}
5 | .d2h-code-side-linenumber:after{content:"\200b"}.d2h-code-side-emptyplaceholder,.d2h-emptyplaceholder{background-color:#f1f1f1;border-color:#e1e1e1}.d2h-code-line-prefix,
6 | .d2h-code-linenumber,.d2h-code-side-linenumber,.d2h-emptyplaceholder{-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}
7 | .d2h-code-linenumber,.d2h-code-side-linenumber{direction:rtl}.d2h-del{background-color:#fee8e9;border-color:#e9aeae}.d2h-ins{background-color:#dfd;border-color:#b4e2b4}.d2h-info{background-color:#f8fafd;border-color:#d5e4f2;color:rgba(0,0,0,.3)}.d2h-file-diff .d2h-del.d2h-change{background-color:#fdf2d0}.d2h-file-diff .d2h-ins.d2h-change{background-color:#ded}.d2h-file-list-wrapper{margin-bottom:10px}.d2h-file-list-wrapper a{color:#3572b0;text-decoration:none}.d2h-file-list-wrapper a:visited{color:#3572b0}.d2h-file-list-header{text-align:left}.d2h-file-list-title{font-weight:700}.d2h-file-list-line{display:-webkit-box;display:-ms-flexbox;display:flex;text-align:left}.d2h-file-list{display:block;list-style:none;margin:0;padding:0}.d2h-file-list>li{border-bottom:1px solid #ddd;margin:0;padding:5px 10px}.d2h-file-list>li:last-child{border-bottom:none}.d2h-file-switch{cursor:pointer;display:none;font-size:10px}.d2h-icon{fill:currentColor;margin-right:10px;vertical-align:middle}.d2h-deleted{color:#c33}.d2h-added{color:#399839}.d2h-changed{color:#d0b44c}.d2h-moved{color:#3572b0}.d2h-tag{background-color:#fff;display:-webkit-box;display:-ms-flexbox;display:flex;font-size:10px;margin-left:5px;padding:0 2px}.d2h-deleted-tag{border:1px solid #c33}.d2h-added-tag{border:1px solid #399839}.d2h-changed-tag{border:1px solid #d0b44c}.d2h-moved-tag{border:1px solid #3572b0}
--------------------------------------------------------------------------------
/onnx_array_api/validation/docs.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import numpy as np
3 | import onnx
4 | import onnx.helper as oh
5 |
6 |
def make_euclidean(
    input_names: Tuple[str] = ("X", "Y"),
    output_name: str = "Z",
    elem_type: int = onnx.TensorProto.FLOAT,
    opset: Optional[int] = None,
) -> onnx.ModelProto:
    """
    Creates the onnx graph corresponding to the euclidean distance.

    :param input_names: names of the two inputs
    :param output_name: name of the output
    :param elem_type: onnx is strongly typed, which type is it?
    :param opset: opset version, defaults to the latest installed opset
    :return: onnx.ModelProto
    """
    if opset is None:
        opset = onnx.defs.onnx_opset_version()

    X = oh.make_tensor_value_info(input_names[0], elem_type, None)
    Y = oh.make_tensor_value_info(input_names[1], elem_type, None)
    Z = oh.make_tensor_value_info(output_name, elem_type, None)
    two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2])
    # Bug fix: the Sub node used the hard-coded names ["X", "Y"] instead of
    # input_names, producing an invalid graph whenever the caller changed
    # the input names.
    n1 = oh.make_node("Sub", [input_names[0], input_names[1]], ["dxy"])
    n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
    n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
    graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
    model = oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", opset)], ir_version=9
    )
    return model
37 |
38 |
def make_euclidean_skl2onnx(
    input_names: Tuple[str] = ("X", "Y"),
    output_name: str = "Z",
    elem_type: int = onnx.TensorProto.FLOAT,
    opset: Optional[int] = None,
) -> onnx.ModelProto:
    """
    Creates the onnx graph corresponding to the euclidean distance
    with :epkg:`sklearn-onnx`.

    :param input_names: names of the two inputs
    :param output_name: name of the output
    :param elem_type: onnx is strongly typed, which type is it?
    :param opset: opset version, defaults to the latest installed opset
    :return: onnx.ModelProto
    """
    if opset is None:
        opset = onnx.defs.onnx_opset_version()

    from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum

    dxy = OnnxSub(input_names[0], input_names[1], op_version=opset)
    dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset)
    final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name])

    np_type = oh.tensor_dtype_to_np_dtype(elem_type)
    dummy = np.empty([1], np_type)
    # Bug fix: the dummy inputs were registered under the hard-coded names
    # "X" and "Y" instead of input_names, breaking conversion for any other
    # input names.
    return final.to_onnx({input_names[0]: dummy, input_names[1]: dummy})
67 |
--------------------------------------------------------------------------------
/onnx_array_api/validation/tools.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy
3 | from onnx import (
4 | AttributeProto,
5 | GraphProto,
6 | FunctionProto,
7 | ModelProto,
8 | NodeProto,
9 | TensorProto,
10 | )
11 | from onnx.helper import (
12 | make_attribute,
13 | make_function,
14 | make_graph,
15 | make_model,
16 | make_node,
17 | set_model_props,
18 | )
19 | from ..reference import from_array_extended as from_array, to_array_extended as to_array
20 |
21 |
def randomize_proto(
    onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto],
) -> Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]:
    """
    Randomizes float initializers or constant nodes.

    :param onx: onnx model or proto
    :return: new proto of the same kind, tensors are replaced by random
        values drawn in the same [min, max] interval
    """
    if isinstance(onx, TensorProto):
        t = to_array(onx)
        mini, maxi = t.min(), t.max()
        # Uniform values in [mini, maxi], cast back to the original dtype.
        new_t = numpy.clip(
            numpy.random.random(t.shape) * (maxi - mini) + mini, mini, maxi
        )
        return from_array(new_t.astype(t.dtype), name=onx.name)

    if isinstance(onx, NodeProto):
        # Bug fix: the Constant branch below recurses with a NodeProto but no
        # branch handled it, so the function always raised TypeError.
        # Randomize every tensor attribute (e.g. the 'value' of a Constant).
        atts = []
        for att in onx.attribute:
            if att.type == AttributeProto.TENSOR:
                att = make_attribute(att.name, randomize_proto(att.t))
            atts.append(att)
        new_node = make_node(onx.op_type, onx.input, onx.output, domain=onx.domain)
        new_node.attribute.extend(atts)
        return new_node

    if isinstance(onx, ModelProto):
        new_graph = randomize_proto(onx.graph)
        new_functions = [randomize_proto(f) for f in onx.functions]

        onnx_model = make_model(
            new_graph,
            functions=new_functions,
            ir_version=onx.ir_version,
            producer_name=onx.producer_name,
            domain=onx.domain,
            doc_string=onx.doc_string,
            opset_imports=list(onx.opset_import),
        )
        if onx.metadata_props:
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(onnx_model, values)
        return onnx_model

    if isinstance(onx, (GraphProto, FunctionProto)):
        nodes = []
        for node in onx.node:
            # Bug fix: was `node.op_type in "Constant"`, a substring test
            # which is not an equality check.
            if node.op_type == "Constant":
                nodes.append(randomize_proto(node))
                continue
            changed = False
            atts = []
            for att in node.attribute:
                if att.type == AttributeProto.GRAPH:
                    new_g = randomize_proto(att.g)
                    att = make_attribute(att.name, new_g)
                    changed = True
                atts.append(att)
            if changed:
                new_node = make_node(
                    node.op_type, node.input, node.output, domain=node.domain
                )
                # Bug fix: the original extended with node.attribute, which
                # discarded the randomized subgraphs collected in atts.
                new_node.attribute.extend(atts)
                nodes.append(new_node)
                continue
            nodes.append(node)

        if isinstance(onx, FunctionProto):
            new_onx = make_function(
                onx.domain,
                onx.name,
                onx.input,
                onx.output,
                nodes,
                opset_imports=onx.opset_import,
            )
            return new_onx

        inits = [randomize_proto(init) for init in onx.initializer]
        sp_inits = [randomize_proto(init) for init in onx.sparse_initializer]

        graph = make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=inits,
            sparse_initializer=sp_inits,
        )
        return graph

    raise TypeError(f"Unexpected type for onx {type(onx)}.")
105 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 |
3 | # Exclude a variety of commonly ignored directories.
4 | exclude = [
5 | ".eggs",
6 | ".git",
7 | "build",
8 | "dist",
9 | ]
10 |
11 | # Same as Black.
12 | line-length = 88
13 |
14 | [tool.ruff.lint]
15 | select = [
16 | "B", # flake8-bugbear
17 | "C4", # flake8-comprehensions
18 | #"D", # pydocstyle
19 | "E", # pycodestyle
20 | "F", # Pyflakes
21 | "G", # flake8-logging-format
22 | #"I", # isort
23 | "ISC", # flake8-implicit-str-concat
24 | "LOG", # flake8-logging
25 | #"N", # pep8-naming
26 | #"NPY", # modern numpy
27 | #"PERF", # Perflint
28 | "PIE", # flake8-pie
29 | "PYI", # flake8-pyi
30 | "RUF", # Ruff-specific rules
31 | "SIM", # flake8-simplify
32 | "SLOT", # flake8-slot
33 | "T10", # flake8-debugger
34 | #"TID", # Disallow relative imports
35 | #"TRY", # flake8-try-except-raise
36 | "UP", # pyupgrade
37 | "W", # pycodestyle
38 | "YTT", # flake8-2020
39 | ]
40 |
41 | [tool.ruff.lint.per-file-ignores]
42 | "**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"]
43 | "**/plot*.py" = ["B018"]
44 | "_doc/examples/plot_first_example.py" = ["E402", "F811"]
45 | "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
46 | "onnx_array_api/array_api/_onnx_common.py" = ["F821"]
47 | "onnx_array_api/graph_api/__init__.py" = ["F401"]
48 | "onnx_array_api/light_api/__init__.py" = ["F401"]
49 | "onnx_array_api/light_api/_op_var.py" = ["F821"]
50 | "onnx_array_api/light_api/_op_vars.py" = ["F821"]
51 | "onnx_array_api/annotations.py" = ["F821"]
52 | "onnx_array_api/light_api/model.py" = ["F821"]
53 | "onnx_array_api/translate_api/__init__.py" = ["F401"]
54 | "onnx_array_api/npx/__init__.py" = ["F401", "F403"]
55 | "onnx_array_api/npx/npx_functions.py" = ["F821"]
56 | "onnx_array_api/npx/npx_functions_test.py" = ["F821"]
57 | "onnx_array_api/npx/npx_tensors.py" = ["F821"]
58 | "onnx_array_api/npx/npx_var.py" = ["F821"]
59 | "onnx_array_api/profiling.py" = ["E731"]
60 | "onnx_array_api/reference/__init__.py" = ["F401"]
61 | "_unittests/ut_npx/test_npx.py" = ["F821"]
62 | "_unittests/ut_translate_api/test_translate_classic.py" = ["E501"]
63 |
64 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | array_api_compat
2 | array_api_strict
3 | autopep8
4 | black
5 | coverage
6 | flake8
7 | furo
8 | google-re2
9 | hypothesis
10 | isort
11 | joblib
12 | lightgbm
13 | matplotlib
14 | ml-dtypes
15 | git+https://github.com/onnx/onnxmltools.git
16 | onnxruntime>=1.17.0
17 | openpyxl
18 | packaging
19 | pandas
20 | Pillow
21 | psutil
22 | pytest
23 | pytest-cov
24 | sphinx-issues
25 | git+https://github.com/sdpython/sphinx-runpython.git
26 | ruff
27 | scikit-learn>=1.3.2
28 | git+https://github.com/onnx/sklearn-onnx.git
29 | sphinx
30 | sphinx-gallery
31 | tomli
32 | tqdm
33 | wheel
34 | xgboost
35 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | onnx>=1.15.0
3 | scipy
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [options]
2 | packages = find:
3 |
4 | [options.packages.find]
5 | include = onnx_array_api*
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import setup
4 |
######################
# beginning of setup
######################


# Folder containing this setup.py; dirname is "" when run from the folder.
here = os.path.dirname(__file__) or "."
# Static assets shipped with the validation subpackage.
package_data = {"onnx_array_api.validation": ["*.css", "*.js"]}

# Runtime dependencies come from requirements.txt, with a fallback list
# when the file is absent or empty.
try:
    with open(os.path.join(here, "requirements.txt"), "r") as f:
        requirements = f.read().strip(" \n\r\t").split("\n")
except FileNotFoundError:
    requirements = []
if not requirements or requirements == [""]:
    requirements = ["numpy", "scipy", "onnx"]

# The long description is the README starting at the package name.
try:
    with open(os.path.join(here, "README.rst"), "r", encoding="utf-8") as f:
        long_description = "onnx-array-api:" + f.read().split("onnx-array-api:")[1]
except FileNotFoundError:
    long_description = ""

# Extract __version__ from the package without importing it.
version_str = "0.1.0"
with open(os.path.join(here, "onnx_array_api/__init__.py"), "r") as f:
    for raw_line in f.readlines():
        stripped = raw_line.strip("\r\n ")
        if stripped.startswith("__version__"):
            version_str = stripped.split("=")[1].strip('" ')
            break


setup(
    name="onnx-array-api",
    version=version_str,
    description="Array (and numpy) API for ONNX",
    long_description=long_description,
    author="Xavier Dupré",
    author_email="xavier.dupre@gmail.com",
    url="https://github.com/sdpython/onnx-array-api",
    package_data=package_data,
    setup_requires=["numpy", "scipy"],
    install_requires=requirements,
    classifiers=[
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: C",
        "Programming Language :: Python",
        "Topic :: Software Development",
        "Topic :: Scientific/Engineering",
        "Development Status :: 5 - Production/Stable",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: POSIX",
        "Operating System :: Unix",
        "Operating System :: MacOS",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
)
72 |
--------------------------------------------------------------------------------