├── .github
└── workflows
│ ├── black-ruff.yml
│ ├── codeql.yml
│ ├── linux-ci.yml
│ ├── release.yml
│ ├── wheels-any.yml
│ └── windows-macos-ci.yml
├── .gitignore
├── CHANGELOGS.md
├── LICENSE
├── MANIFEST.in
├── NOTICE
├── README.md
├── benchmarks
├── README.txt
├── bench_plot_onnxruntime_decision_tree.py
├── bench_plot_onnxruntime_hgb.py
├── bench_plot_onnxruntime_linreg.py
├── bench_plot_onnxruntime_logreg.py
├── bench_plot_onnxruntime_random_forest.py
├── bench_plot_onnxruntime_random_forest_reg.py
├── bench_plot_onnxruntime_svm_reg.py
└── post_graph.py
├── docs
├── api_summary.rst
├── conf.py
├── examples
│ ├── Au-Salon-de-l-agriculture-la-campagne-recrute.jpg
│ ├── README.txt
│ ├── daisy_wikipedia.jpg
│ ├── plot_backend.py
│ ├── plot_benchmark_cdist.py
│ ├── plot_benchmark_pipeline.py
│ ├── plot_black_op.py
│ ├── plot_cast_transformer.py
│ ├── plot_complex_pipeline.py
│ ├── plot_convert_decision_function.py
│ ├── plot_convert_model.py
│ ├── plot_convert_syntax.py
│ ├── plot_convert_zipmap.py
│ ├── plot_custom_model.py
│ ├── plot_custom_parser.py
│ ├── plot_custom_parser_alternative.py
│ ├── plot_errors_onnxruntime.py
│ ├── plot_gpr.py
│ ├── plot_intermediate_outputs.py
│ ├── plot_investigate_pipeline.py
│ ├── plot_logging.py
│ ├── plot_metadata.py
│ ├── plot_nmf.py
│ ├── plot_onnx_operators.py
│ ├── plot_output_onnx_single_probability.py
│ ├── plot_pipeline.py
│ ├── plot_pipeline_lightgbm.py
│ ├── plot_pipeline_xgboost.py
│ └── plot_tfidfvectorizer.py
├── exts
│ ├── github_link.py
│ └── sphinx_skl2onnx_extension.py
├── images
│ └── woe.png
├── index.rst
├── index_tutorial.rst
├── introduction.rst
├── logo_main.png
├── parameterized.rst
├── pipeline.png
├── pipeline.rst
├── requirements.txt
├── supported.rst
├── tests
│ ├── test_documentation_examples.py
│ ├── test_utils_benchmark.py
│ └── test_utils_classes.py
├── tutorial
│ ├── README.txt
│ ├── plot_abegin_convert_pipeline.py
│ ├── plot_bbegin_measure_time.py
│ ├── plot_catwoe_transformer.py
│ ├── plot_cbegin_opset.py
│ ├── plot_dbegin_options.py
│ ├── plot_dbegin_options_list.py
│ ├── plot_dbegin_options_zipmap.py
│ ├── plot_ebegin_float_double.py
│ ├── plot_fbegin_investigate.py
│ ├── plot_gbegin_cst.py
│ ├── plot_gbegin_dataframe.py
│ ├── plot_gconverting.py
│ ├── plot_gexternal_catboost.py
│ ├── plot_gexternal_lightgbm.py
│ ├── plot_gexternal_lightgbm_reg.py
│ ├── plot_gexternal_xgboost.py
│ ├── plot_icustom_converter.py
│ ├── plot_jcustom_syntax.py
│ ├── plot_jfunction_transformer.py
│ ├── plot_kcustom_converter_wrapper.py
│ ├── plot_lcustom_options.py
│ ├── plot_mcustom_parser.py
│ ├── plot_ngrams.py
│ ├── plot_transformer_discrepancy.py
│ ├── plot_usparse_xgboost.py
│ ├── plot_weird_pandas_and_hash.py
│ ├── plot_wext_pyod_forest.py
│ └── plot_woe_transformer.py
├── tutorial_1-5_external.rst
├── tutorial_1_simple.rst
├── tutorial_2-5_extlib.rst
├── tutorial_2_new_converter.rst
└── tutorial_4_advanced.rst
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
├── skl2onnx
├── __init__.py
├── __main__.py
├── _parse.py
├── _supported_operators.py
├── algebra
│ ├── __init__.py
│ ├── automation.py
│ ├── complex_functions.py
│ ├── custom_ops.py
│ ├── graph_state.py
│ ├── onnx_operator.py
│ ├── onnx_operator_mixin.py
│ ├── onnx_ops.py
│ ├── onnx_subgraph_operator_mixin.py
│ ├── sklearn_ops.py
│ └── type_helper.py
├── common
│ ├── __init__.py
│ ├── _apply_operation.py
│ ├── _container.py
│ ├── _onnx_optimisation_common.py
│ ├── _registration.py
│ ├── _topology.py
│ ├── data_types.py
│ ├── exceptions.py
│ ├── graph_builder_opset.py
│ ├── onnx_optimisation_identity.py
│ ├── shape_calculator.py
│ ├── tree_ensemble.py
│ ├── utils.py
│ ├── utils_checking.py
│ ├── utils_classifier.py
│ └── utils_sklearn.py
├── convert.py
├── helpers
│ ├── __init__.py
│ ├── integration.py
│ ├── investigate.py
│ ├── onnx_helper.py
│ └── onnx_rare_helper.py
├── operator_converters
│ ├── __init__.py
│ ├── _gp_kernels.py
│ ├── ada_boost.py
│ ├── array_feature_extractor.py
│ ├── bagging.py
│ ├── binariser.py
│ ├── calibrated_classifier_cv.py
│ ├── cast_op.py
│ ├── class_labels.py
│ ├── common.py
│ ├── concat_op.py
│ ├── cross_decomposition.py
│ ├── decision_tree.py
│ ├── decomposition.py
│ ├── dict_vectoriser.py
│ ├── feature_hasher.py
│ ├── feature_selection.py
│ ├── flatten_op.py
│ ├── function_transformer.py
│ ├── gamma_regressor.py
│ ├── gaussian_mixture.py
│ ├── gaussian_process.py
│ ├── gradient_boosting.py
│ ├── grid_search_cv.py
│ ├── id_op.py
│ ├── imputer_op.py
│ ├── isolation_forest.py
│ ├── k_bins_discretiser.py
│ ├── k_means.py
│ ├── kernel_pca.py
│ ├── label_binariser.py
│ ├── label_encoder.py
│ ├── linear_classifier.py
│ ├── linear_regressor.py
│ ├── local_outlier_factor.py
│ ├── multilayer_perceptron.py
│ ├── multioutput.py
│ ├── multiply_op.py
│ ├── naive_bayes.py
│ ├── nearest_neighbours.py
│ ├── normaliser.py
│ ├── one_hot_encoder.py
│ ├── one_vs_one_classifier.py
│ ├── one_vs_rest_classifier.py
│ ├── ordinal_encoder.py
│ ├── ovr_decision_function.py
│ ├── pipelines.py
│ ├── polynomial_features.py
│ ├── power_transformer.py
│ ├── quadratic_discriminant_analysis.py
│ ├── quantile_transformer.py
│ ├── random_forest.py
│ ├── random_projection.py
│ ├── random_trees_embedding.py
│ ├── ransac_regressor.py
│ ├── replace_op.py
│ ├── scaler_op.py
│ ├── sequence.py
│ ├── sgd_classifier.py
│ ├── sgd_oneclass_svm.py
│ ├── stacking.py
│ ├── support_vector_machines.py
│ ├── target_encoder.py
│ ├── text_vectoriser.py
│ ├── tfidf_transformer.py
│ ├── tfidf_vectoriser.py
│ ├── tuned_threshold_classifier.py
│ ├── voting_classifier.py
│ ├── voting_regressor.py
│ └── zip_map.py
├── proto
│ └── __init__.py
├── shape_calculators
│ ├── __init__.py
│ ├── array_feature_extractor.py
│ ├── cast_op.py
│ ├── class_labels.py
│ ├── concat.py
│ ├── cross_decomposition.py
│ ├── dict_vectorizer.py
│ ├── ensemble_shapes.py
│ ├── feature_hasher.py
│ ├── feature_selection.py
│ ├── flatten.py
│ ├── function_transformer.py
│ ├── gaussian_process.py
│ ├── grid_search_cv.py
│ ├── identity.py
│ ├── imputer.py
│ ├── isolation_forest.py
│ ├── k_bins_discretiser.py
│ ├── k_means.py
│ ├── kernel_pca.py
│ ├── label_binariser.py
│ ├── label_encoder.py
│ ├── linear_classifier.py
│ ├── linear_regressor.py
│ ├── local_outlier_factor.py
│ ├── mixture.py
│ ├── multioutput.py
│ ├── multiply.py
│ ├── nearest_neighbours.py
│ ├── one_hot_encoder.py
│ ├── one_vs_one_classifier.py
│ ├── one_vs_rest_classifier.py
│ ├── ordinal_encoder.py
│ ├── ovr_decision_function.py
│ ├── pipelines.py
│ ├── polynomial_features.py
│ ├── power_transformer.py
│ ├── quadratic_discriminant_analysis.py
│ ├── quantile_transformer.py
│ ├── random_projection.py
│ ├── random_trees_embedding.py
│ ├── replace_op.py
│ ├── scaler.py
│ ├── sequence.py
│ ├── sgd_oneclass_svm.py
│ ├── support_vector_machines.py
│ ├── svd.py
│ ├── target_encoder.py
│ ├── text_vectorizer.py
│ ├── tfidf_transformer.py
│ ├── tuned_threshold_classifier.py
│ ├── voting_classifier.py
│ ├── voting_regressor.py
│ └── zip_map.py
├── sklapi
│ ├── __init__.py
│ ├── cast_regressor.py
│ ├── cast_transformer.py
│ ├── register.py
│ ├── replace_transformer.py
│ ├── sklearn_text.py
│ ├── sklearn_text_onnx.py
│ ├── woe_transformer.py
│ └── woe_transformer_onnx.py
└── tutorial
│ ├── __init__.py
│ ├── benchmark.py
│ └── imagenet_classes.py
├── tests
├── benchmark.py
├── datasets
│ ├── small_titanic.csv
│ ├── treecl.onnx
│ ├── treecl2.onnx
│ └── treecl3.onnx
├── test_algebra_cascade.py
├── test_algebra_complex.py
├── test_algebra_converters.py
├── test_algebra_custom_model.py
├── test_algebra_custom_model_sub_estimator.py
├── test_algebra_deprecation.py
├── test_algebra_double.py
├── test_algebra_onnx_doc.py
├── test_algebra_onnx_operator_mixin_syntax.py
├── test_algebra_onnx_operators.py
├── test_algebra_onnx_operators_if.py
├── test_algebra_onnx_operators_opset.py
├── test_algebra_onnx_operators_scan.py
├── test_algebra_onnx_operators_sparse.py
├── test_algebra_onnx_operators_sub_estimator.py
├── test_algebra_onnx_operators_wrapped.py
├── test_algebra_symbolic.py
├── test_algebra_test_helper.py
├── test_algebra_to_onnx.py
├── test_convert.py
├── test_convert_options.py
├── test_custom_transformer_ordwoe.py
├── test_custom_transformer_tsne.py
├── test_investigate.py
├── test_issues_2024.py
├── test_issues_2025.py
├── test_onnx_helper.py
├── test_onnx_rare_helper.py
├── test_onnxruntime.py
├── test_op10.py
├── test_opset13.py
├── test_optimisation.py
├── test_options.py
├── test_other_converter_library_pipelines.py
├── test_parsing_options.py
├── test_raw_name.py
├── test_scikit_pandas.py
├── test_shapes.py
├── test_sklearn_adaboost_converter.py
├── test_sklearn_array_feature_extractor.py
├── test_sklearn_bagging_converter.py
├── test_sklearn_binarizer_converter.py
├── test_sklearn_calibrated_classifier_cv_converter.py
├── test_sklearn_cast_regressor.py
├── test_sklearn_cast_transformer.py
├── test_sklearn_classifiers_extreme.py
├── test_sklearn_concat.py
├── test_sklearn_constant_predictor.py
├── test_sklearn_count_vectorizer_converter.py
├── test_sklearn_count_vectorizer_converter_bug.py
├── test_sklearn_custom_nmf.py
├── test_sklearn_decision_tree_converters.py
├── test_sklearn_dict_vectorizer_converter.py
├── test_sklearn_documentation.py
├── test_sklearn_double_tensor_type_cls.py
├── test_sklearn_double_tensor_type_reg.py
├── test_sklearn_double_tensor_type_tr.py
├── test_sklearn_feature_hasher.py
├── test_sklearn_feature_selection_converters.py
├── test_sklearn_feature_union.py
├── test_sklearn_function_transformer_converter.py
├── test_sklearn_gamma_regressor.py
├── test_sklearn_gaussian_mixture_converter.py
├── test_sklearn_gaussian_process_classifier.py
├── test_sklearn_gaussian_process_regressor.py
├── test_sklearn_glm_classifier_converter.py
├── test_sklearn_glm_regressor_converter.py
├── test_sklearn_gradient_boosting_converters.py
├── test_sklearn_grid_search_cv_converter.py
├── test_sklearn_imputer_converter.py
├── test_sklearn_isolation_forest.py
├── test_sklearn_k_bins_discretiser_converter.py
├── test_sklearn_k_means_converter.py
├── test_sklearn_kernel_pca_converter.py
├── test_sklearn_label_binariser_converter.py
├── test_sklearn_label_encoder_converter.py
├── test_sklearn_local_outlier_factor.py
├── test_sklearn_mlp_converter.py
├── test_sklearn_multi_output.py
├── test_sklearn_naive_bayes_converter.py
├── test_sklearn_nearest_neighbour_converter.py
├── test_sklearn_normalizer_converter.py
├── test_sklearn_one_hot_encoder_converter.py
├── test_sklearn_one_vs_one_classifier_converter.py
├── test_sklearn_one_vs_rest_classifier_converter.py
├── test_sklearn_ordinal_encoder.py
├── test_sklearn_passive_aggressive_classifier_converter.py
├── test_sklearn_pca_converter.py
├── test_sklearn_perceptron_converter.py
├── test_sklearn_pipeline.py
├── test_sklearn_pipeline_concat_tfidf.py
├── test_sklearn_pipeline_within_pipeline.py
├── test_sklearn_pls_regression.py
├── test_sklearn_polynomial_features_converter.py
├── test_sklearn_power_transformer.py
├── test_sklearn_quadratic_discriminant_analysis_converter.py
├── test_sklearn_quantile_transformer.py
├── test_sklearn_random_forest_converters.py
├── test_sklearn_random_projection.py
├── test_sklearn_random_trees_embedding.py
├── test_sklearn_replace_transformer.py
├── test_sklearn_scaler_converter.py
├── test_sklearn_sgd_classifier_converter.py
├── test_sklearn_sgd_oneclass_svm_converter.py
├── test_sklearn_stacking.py
├── test_sklearn_svm_converters.py
├── test_sklearn_target_encoder_converter.py
├── test_sklearn_text.py
├── test_sklearn_tfidf_transformer_converter.py
├── test_sklearn_tfidf_transformer_converter_sparse.py
├── test_sklearn_tfidf_vectorizer_converter.py
├── test_sklearn_tfidf_vectorizer_converter_char.py
├── test_sklearn_tfidf_vectorizer_converter_dataset.py
├── test_sklearn_tfidf_vectorizer_converter_pipeline.py
├── test_sklearn_tfidf_vectorizer_converter_regex.py
├── test_sklearn_truncated_svd.py
├── test_sklearn_tuned_threshold_classifier.py
├── test_sklearn_voting_classifier_converter.py
├── test_sklearn_voting_regressor_converter.py
├── test_sklearn_woe_transformer.py
├── test_supported_converters.py
├── test_topology_prune.py
├── test_utils
│ ├── __init__.py
│ ├── main.py
│ ├── reference_implementation_afe.py
│ ├── reference_implementation_helper.py
│ ├── reference_implementation_ml.py
│ ├── reference_implementation_svm.py
│ ├── reference_implementation_text.py
│ ├── reference_implementation_tree.py
│ ├── reference_implementation_zipmap.py
│ ├── tests_helper.py
│ ├── utils_backend.py
│ ├── utils_backend_onnx.py
│ └── utils_backend_onnxruntime.py
├── test_utils_sklearn.py
└── test_variable_names.py
└── tests_onnxmltools
├── test_columns.py
├── test_issues_onnxmltools_2024.py
├── test_lightgbm.py
└── test_xgboost_converters.py
/.github/workflows/black-ruff.yml:
--------------------------------------------------------------------------------
1 | name: Black Format Checker
2 | on: [push, pull_request]
3 | jobs:
4 | black-format-check:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v2
8 | - uses: psf/black@stable
9 | with:
10 | options: "--diff --check"
11 | src: "."
12 | ruff-format-check:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - uses: chartboost/ruff-action@v1
17 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | name: "CodeQL"
2 |
3 | on:
4 | push:
5 | branches: [ 'main' ]
6 | pull_request:
7 | # The branches below must be a subset of the branches above
8 | branches: [ 'main' ]
9 | schedule:
10 | - cron: '14 5 * * 6'
11 |
12 | jobs:
13 | analyze:
14 | name: Analyze
15 | runs-on: ubuntu-latest
16 | permissions:
17 | actions: read
18 | contents: read
19 | security-events: write
20 |
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | language: [ 'python' ]
25 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
26 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
27 |
28 | steps:
29 | - name: Checkout repository
30 | uses: actions/checkout@v3
31 |
32 | # Initializes the CodeQL tools for scanning.
33 | - name: Initialize CodeQL
34 | uses: github/codeql-action/init@v2
35 | with:
36 | languages: ${{ matrix.language }}
37 | # If you wish to specify custom queries, you can do so here or in a config file.
38 | # By default, queries listed here will override any specified in a config file.
39 | # Prefix the list here with "+" to use these queries and those in the config file.
40 |
41 | # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
42 | queries: +security-and-quality
43 |
44 |
45 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
46 | # If this step fails, then you should remove it and run the build manually (see below)
47 | - name: Autobuild
48 | uses: github/codeql-action/autobuild@v2
49 |
50 | # ℹ️ Command-line programs to run using the OS shell.
51 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
52 |
53 | # If the Autobuild fails above, remove it and uncomment the following three lines.
54 | # Modify them (or add more) to build the code for your project; please refer to the EXAMPLE below for guidance.
55 |
56 | # - run: |
57 | # echo "Run, Build Application using script"
58 | # ./location_of_script_within_repo/buildscript.sh
59 |
60 | - name: Perform CodeQL Analysis
61 | uses: github/codeql-action/analyze@v2
62 | with:
63 | category: "/language:${{matrix.language}}"
64 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 | permissions:
23 | id-token: write
24 | environment: release
25 | steps:
26 | - uses: actions/checkout@v4
27 | - name: Set up Python
28 | uses: actions/setup-python@v5
29 | with:
30 | python-version: '3.x'
31 |
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install build
36 |
37 | - name: Build package
38 | run: python -m build
39 |
40 | - name: Publish package
41 | uses: pypa/gh-action-pypi-publish@release/v1
42 | with:
43 | attestations: true
44 |
--------------------------------------------------------------------------------
/.github/workflows/wheels-any.yml:
--------------------------------------------------------------------------------
1 | name: Build Any Wheel
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - 'releases/**'
8 |
9 | jobs:
10 | build_wheels:
11 | name: Build wheels on ${{ matrix.os }}
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | os: [ubuntu-latest]
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 |
20 | - uses: actions/setup-python@v4
21 | with:
22 | python-version: '3.12'
23 |
24 | - name: build wheel
25 | run: python -m pip wheel . -v
26 |
27 | - name: install twine
28 | run: python -m pip install twine
29 |
30 | - name: check wheel
31 | run: python -m twine check ./skl2onnx*.whl
32 |
33 | - uses: actions/upload-artifact@v4
34 | with:
35 | path: ./skl2onnx*.whl
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Visual Studio Code files
2 | .vscode
3 |
4 | # IPython notebook checkpoints
5 | .ipynb_checkpoints
6 |
7 | # Compiled python
8 | *.pyc
9 |
10 | # setup.py intermediates
11 | .eggs
12 | *.egg-info/
13 | dist/
14 | build/
15 |
16 | # PyCharm files
17 | .idea
18 |
19 | # OSX dir files
20 | .DS_Store
21 |
22 | # Windows
23 | *.bat
24 |
25 | # test generated files
26 | *.onnx
27 | *.dot*
28 | *.whl
29 | .pytest_cache
30 | .cache
31 | htmlcov
32 | coverage.xml
33 | .coverage
34 | __dump_data/
35 | junit/
36 | tests_dump/
37 | skl2onnx/algebra/_cache/*.rst
38 | docs/auto_examples
39 | docs/examples/graph*.*
40 | docs/examples/*.onnx
41 | docs/examples/pipeline*.dot*
42 | docs/sg_execution_times.rst
43 | tests/TESTDUMP
44 | tests/tests_dump
45 | tests/graph.dot*
46 | docs/examples/tiny_yolov2*
47 | docs/examples/imagenet_class_index.json
48 | TESTDUMP/*
49 | htmlcov/*
50 | tests/temp_onnx_helper_load_save.onnx
51 | tests/*.new
52 | benchmarks/*.csv
53 | benchmarks/*.png
54 | tests/Operators*.md
55 | docs/examples/*.pkl
56 | tests/debug_gp.onnx
57 | tests/test*.onnx
58 | tests_onnxmltools/*.pkl
59 | tests_onnxmltools/tests/*
60 | tests_onnxmltools/tests_dump/*
61 | docs/tests/*.dot*
62 | tests/*.dot*
63 | tests/*.onnx
64 | docs/tests/*.onnx
65 | docs/examples/validator_classifier.dot.png
66 | docs/examples/validator_classifier.dot
67 | docs/examples/mixture*.*
68 | docs/examples/cast1*
69 | docs/examples/cast2*
70 | docs/auto_tutorial
71 | docs/tutorial/*.onnx
72 | docs/tutorial/*.jpg
73 | docs/tutorial/*.png
74 | docs/tutorial/*.dot
75 | docs/tutorial/catboost_info
76 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # include
2 | include *.rst
3 | recursive-include docs *
4 | recursive-include tests *
5 | include LICENSE
6 | include README.md
7 | include MANIFEST.in
8 | include requirements.txt
9 | include skl2onnx/algebra/_cache/*.rst
10 |
11 | # exclude from sdist
12 | recursive-exclude benchmarks *
13 | recursive-exclude .azure-pipelines *
14 | recursive-exclude tests/tests_dump *
15 | recursive-exclude tests_onnxmltools/tests_dump *
16 | recursive-exclude tests/test_utils/__pycache__ *
17 | recursive-exclude docs/notebooks *
18 | exclude *.onnx
19 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | sklearn-onnx
2 | Copyright (c) 2018-2023 Microsoft Corporation
3 |
4 | This product includes software developed at
5 | The LF AI & Data Foundation (https://lfaidata.foundation/).
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |

4 |
5 | [](https://pypi.org/project/skl2onnx)
6 | [](https://github.com/onnx/sklearn-onnx/actions/workflows/linux-ci.yml)
7 | [](https://github.com/onnx/sklearn-onnx/actions/workflows/windows-macos-ci.yml)
8 | [](https://github.com/psf/black)
9 |
10 | ## Introduction
11 | *sklearn-onnx* converts [scikit-learn](https://scikit-learn.org/stable/) models to [ONNX](https://github.com/onnx/onnx).
12 | Once in the ONNX format, you can use tools like [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high performance scoring.
13 | All converters are tested with [onnxruntime](https://onnxruntime.ai/).
14 | Any external converter can be registered to convert scikit-learn pipelines
15 | including models or transformers coming from external libraries.
16 |
17 | ## Documentation
18 | Full documentation including tutorials is available at [https://onnx.ai/sklearn-onnx/](https://onnx.ai/sklearn-onnx/).
19 | [Supported scikit-learn Models](https://onnx.ai/sklearn-onnx/supported.html)
20 | Last supported opset is 21.
21 |
22 | You may also find answers in [existing issues](https://github.com/onnx/sklearn-onnx/issues?utf8=%E2%9C%93&q=is%3Aissue)
23 | or submit a new one.
24 |
25 | ## Installation
26 | You can install from [PyPI](https://pypi.org/project/skl2onnx/):
27 | ```
28 | pip install skl2onnx
29 | ```
30 | Or you can install from the source with the latest changes.
31 | ```
32 | pip install git+https://github.com/onnx/sklearn-onnx.git
33 | ```
34 |
35 | ## Getting started
36 |
37 | ```python
38 | # Train a model.
39 | import numpy as np
40 | from sklearn.datasets import load_iris
41 | from sklearn.model_selection import train_test_split
42 | from sklearn.ensemble import RandomForestClassifier
43 |
44 | iris = load_iris()
45 | X, y = iris.data, iris.target
46 | X = X.astype(np.float32)
47 | X_train, X_test, y_train, y_test = train_test_split(X, y)
48 | clr = RandomForestClassifier()
49 | clr.fit(X_train, y_train)
50 |
51 | # Convert into ONNX format.
52 | from skl2onnx import to_onnx
53 |
54 | onx = to_onnx(clr, X[:1])
55 | with open("rf_iris.onnx", "wb") as f:
56 | f.write(onx.SerializeToString())
57 |
58 | # Compute the prediction with onnxruntime.
59 | import onnxruntime as rt
60 |
61 | sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
62 | input_name = sess.get_inputs()[0].name
63 | label_name = sess.get_outputs()[0].name
64 | pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]
65 | ```
66 |
67 | ## Contribute
68 | We welcome contributions in the form of feedback, ideas, or code.
69 |
70 | ## License
71 | [Apache License v2.0](LICENSE)
72 |
--------------------------------------------------------------------------------
/benchmarks/README.txt:
--------------------------------------------------------------------------------
1 | To run the benchmark:
2 |
3 | All benchmarks produce CSV files written in subfolder *results*.
4 | Benchmarks can be run the following way:
5 |
6 | ::
7 |
8 | python bench_plot_onnxruntime_linreg.py
9 | python bench_plot_onnxruntime_logreg.py
10 | python bench_plot_onnxruntime_random_forest_reg.py
11 | python bench_plot_onnxruntime_svm_reg.py
12 |
13 | In subfolder *results*, the script post_graph produces
14 | a graph for each of them.
15 |
16 | ::
17 |
18 | python results/post_graph.py
19 |
20 |
--------------------------------------------------------------------------------
/docs/examples/Au-Salon-de-l-agriculture-la-campagne-recrute.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/docs/examples/Au-Salon-de-l-agriculture-la-campagne-recrute.jpg
--------------------------------------------------------------------------------
/docs/examples/README.txt:
--------------------------------------------------------------------------------
1 | Gallery of examples
2 | ===================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
--------------------------------------------------------------------------------
/docs/examples/daisy_wikipedia.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/docs/examples/daisy_wikipedia.jpg
--------------------------------------------------------------------------------
/docs/examples/plot_backend.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 |
6 | .. _l-example-backend-api:
7 |
8 | ONNX Runtime Backend for ONNX
9 | =============================
10 |
11 | .. index:: backend
12 |
13 | *ONNX Runtime* extends the
14 | `onnx backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
16 | to run predictions using this runtime.
17 | Let's use the API to compute the prediction
18 | of a simple logistic regression model.
19 | """
20 |
21 | import skl2onnx
22 | import onnxruntime
23 | import onnx
24 | import sklearn
25 | from sklearn.datasets import load_iris
26 | from sklearn.linear_model import LogisticRegression
27 | import numpy
28 | from onnxruntime import get_device
29 | import numpy as np
30 | import onnxruntime.backend as backend
31 |
32 |
33 | #######################################
34 | # Let's create an ONNX graph first.
35 |
36 | data = load_iris()
37 | X, Y = data.data, data.target
38 | logreg = LogisticRegression(C=1e5).fit(X, Y)
39 | model = skl2onnx.to_onnx(logreg, X.astype(np.float32))
40 | name = "logreg_iris.onnx"
41 | with open(name, "wb") as f:
42 | f.write(model.SerializeToString())
43 |
44 | #######################################
45 | # Let's use ONNX backend API to test it.
46 |
47 | model = onnx.load(name)
48 | rep = backend.prepare(model)
49 | x = np.array(
50 | [[-1.0, -2.0, 5.0, 6.0], [-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, 7.0, 8.0]],
51 | dtype=np.float32,
52 | )
53 | label, proba = rep.run(x)
54 | print("label={}".format(label))
55 | print("probabilities={}".format(proba))
56 |
57 | ########################################
58 | # The device depends on how the package was compiled,
59 | # GPU or CPU.
60 | print(get_device())
61 |
62 | ########################################
63 | # The backend can also directly load the model
64 | # without using *onnx*.
65 |
66 | rep = backend.prepare(name)
67 | x = np.array(
68 | [[-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, -3.0, -4.0]],
69 | dtype=np.float32,
70 | )
71 | label, proba = rep.run(x)
72 | print("label={}".format(label))
73 | print("probabilities={}".format(proba))
74 |
75 | #######################################
76 | # The backend API is implemented by other frameworks
77 | # and makes it easier to switch between multiple runtimes
78 | # with the same API.
79 |
80 | #################################
81 | # **Versions used for this example**
82 |
83 | print("numpy:", numpy.__version__)
84 | print("scikit-learn:", sklearn.__version__)
85 | print("onnx: ", onnx.__version__)
86 | print("onnxruntime: ", onnxruntime.__version__)
87 | print("skl2onnx: ", skl2onnx.__version__)
88 |
--------------------------------------------------------------------------------
/docs/examples/plot_convert_decision_function.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | .. _l-rf-example-decision-function:
6 |
7 | Probabilities or raw scores
8 | ===========================
9 |
10 | A classifier usually returns a matrix of probabilities.
11 | By default, *sklearn-onnx* creates an ONNX graph
12 | which returns probabilities but it may skip that
13 | step and return raw scores if the model implements
14 | the method *decision_function*. Option ``'raw_scores'``
15 | is used to change the default behaviour. Let's see
16 | that on a simple example.
17 |
18 | Train a model and convert it
19 | ++++++++++++++++++++++++++++
20 |
21 | """
22 |
23 | import numpy
24 | import sklearn
25 | from sklearn.datasets import load_iris
26 | from sklearn.model_selection import train_test_split
27 | import onnxruntime as rt
28 | import onnx
29 | import skl2onnx
30 | from skl2onnx.common.data_types import FloatTensorType
31 | from skl2onnx import convert_sklearn
32 | from sklearn.linear_model import LogisticRegression
33 |
34 | iris = load_iris()
35 | X, y = iris.data, iris.target
36 | X_train, X_test, y_train, y_test = train_test_split(X, y)
37 | clr = LogisticRegression(max_iter=500)
38 | clr.fit(X_train, y_train)
39 | print(clr)
40 |
41 | initial_type = [("float_input", FloatTensorType([None, 4]))]
42 | onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
43 |
44 | ############################
45 | # Output type
46 | # +++++++++++
47 | #
48 | # Let's confirm the output type of the probabilities
49 | # is a list of dictionaries with onnxruntime.
50 |
51 | sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
52 | res = sess.run(None, {"float_input": X_test.astype(numpy.float32)})
53 | print("skl", clr.predict_proba(X_test[:1]))
54 | print("onnx", res[1][:2])
55 |
56 | ###################################
57 | # Raw scores and decision_function
58 | # ++++++++++++++++++++++++++++++++
59 | #
60 |
61 | initial_type = [("float_input", FloatTensorType([None, 4]))]
62 | options = {id(clr): {"raw_scores": True}}
63 | onx2 = convert_sklearn(
64 | clr, initial_types=initial_type, options=options, target_opset=12
65 | )
66 |
67 | sess2 = rt.InferenceSession(
68 | onx2.SerializeToString(), providers=["CPUExecutionProvider"]
69 | )
70 | res2 = sess2.run(None, {"float_input": X_test.astype(numpy.float32)})
71 | print("skl", clr.decision_function(X_test[:1]))
72 | print("onnx", res2[1][:2])
73 |
74 | #################################
75 | # **Versions used for this example**
76 |
77 | print("numpy:", numpy.__version__)
78 | print("scikit-learn:", sklearn.__version__)
79 | print("onnx: ", onnx.__version__)
80 | print("onnxruntime: ", rt.__version__)
81 | print("skl2onnx: ", skl2onnx.__version__)
82 |
--------------------------------------------------------------------------------
/docs/examples/plot_convert_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | .. _l-rf-iris-example:
6 |
7 | Train, convert and predict a model
8 | ==================================
9 |
10 | Training and deploying a model usually involves the
11 | three following steps:
12 |
13 | * train a pipeline with *scikit-learn*,
14 | * convert it into *ONNX* with *sklearn-onnx*,
15 | * predict with *onnxruntime*.
16 |
17 | Train a model
18 | +++++++++++++
19 |
20 | A very basic example using random forest and
21 | the iris dataset.
22 | """
23 |
24 | import skl2onnx
25 | import onnx
26 | import sklearn
27 | from sklearn.linear_model import LogisticRegression
28 | import numpy
29 | import onnxruntime as rt
30 | from skl2onnx.common.data_types import FloatTensorType
31 | from skl2onnx import convert_sklearn
32 | from sklearn.datasets import load_iris
33 | from sklearn.model_selection import train_test_split
34 | from sklearn.ensemble import RandomForestClassifier
35 |
36 | iris = load_iris()
37 | X, y = iris.data, iris.target
38 | X_train, X_test, y_train, y_test = train_test_split(X, y)
39 | clr = RandomForestClassifier()
40 | clr.fit(X_train, y_train)
41 | print(clr)
42 |
43 | ###########################
44 | # Convert a model into ONNX
45 | # +++++++++++++++++++++++++
46 |
47 | initial_type = [("float_input", FloatTensorType([None, 4]))]
48 | onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
49 |
50 | with open("rf_iris.onnx", "wb") as f:
51 | f.write(onx.SerializeToString())
52 |
53 | ###################################
54 | # Compute the prediction with ONNX Runtime
55 | # ++++++++++++++++++++++++++++++++++++++++
56 | sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
57 | input_name = sess.get_inputs()[0].name
58 | label_name = sess.get_outputs()[0].name
59 | pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
60 | print(pred_onx)
61 |
62 | #######################################
63 | # Full example with a logistic regression
64 |
65 | clr = LogisticRegression()
66 | clr.fit(X_train, y_train)
67 | initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]
68 | onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
69 | with open("logreg_iris.onnx", "wb") as f:
70 | f.write(onx.SerializeToString())
71 |
72 | sess = rt.InferenceSession("logreg_iris.onnx", providers=["CPUExecutionProvider"])
73 | input_name = sess.get_inputs()[0].name
74 | label_name = sess.get_outputs()[0].name
75 | pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
76 | print(pred_onx)
77 |
78 |
79 | #################################
80 | # **Versions used for this example**
81 |
82 | print("numpy:", numpy.__version__)
83 | print("scikit-learn:", sklearn.__version__)
84 | print("onnx: ", onnx.__version__)
85 | print("onnxruntime: ", rt.__version__)
86 | print("skl2onnx: ", skl2onnx.__version__)
87 |
--------------------------------------------------------------------------------
/docs/examples/plot_logging.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | .. _l-example-logging:
6 |
7 | Logging, verbose
8 | ================
9 |
10 | The conversion of a pipeline fails if it contains an object without any
11 | associated converter. It may also fail if one of the objects is mapped
12 | by a custom converter. If the error message is not explicit enough,
13 | it is possible to enable logging.
14 |
15 | Train a model
16 | +++++++++++++
17 |
18 | A very basic example using random forest and
19 | the iris dataset.
20 | """
21 |
22 | import logging
23 | import numpy
24 | import onnx
25 | import onnxruntime as rt
26 | import sklearn
27 | from sklearn.datasets import load_iris
28 | from sklearn.model_selection import train_test_split
29 | from sklearn.tree import DecisionTreeClassifier
30 | from skl2onnx.common.data_types import FloatTensorType
31 | from skl2onnx import convert_sklearn
32 | import skl2onnx
33 |
34 | iris = load_iris()
35 | X, y = iris.data, iris.target
36 | X_train, X_test, y_train, y_test = train_test_split(X, y)
37 | clr = DecisionTreeClassifier()
38 | clr.fit(X_train, y_train)
39 | print(clr)
40 |
41 | ###########################
42 | # Convert a model into ONNX
43 | # +++++++++++++++++++++++++
44 |
45 | initial_type = [("float_input", FloatTensorType([None, 4]))]
46 | onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
47 |
48 |
49 | sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
50 | input_name = sess.get_inputs()[0].name
51 | label_name = sess.get_outputs()[0].name
52 | pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
53 | print(pred_onx)
54 |
55 | ########################################
56 | # Conversion with parameter verbose
57 | # +++++++++++++++++++++++++++++++++
58 | #
59 | # verbose is a parameter which prints messages on the standard output.
60 | # It tells which converter is called. `verbose=1` usually shows what *skl2onnx*
61 | # is doing to convert a pipeline. `verbose=2+`
62 | # is reserved for information within converters.
63 |
64 | convert_sklearn(clr, initial_types=initial_type, target_opset=12, verbose=1)
65 |
66 | ########################################
67 | # Conversion with logging
68 | # +++++++++++++++++++++++
69 | #
70 | # This is very detailed logging. It shows which operators or variables
71 | # (outputs of converters) are processed, which nodes are created...
72 | # This information may be useful when a custom converter is being
73 | # implemented.
74 |
75 | logger = logging.getLogger("skl2onnx")
76 | logger.setLevel(logging.DEBUG)
77 |
78 | convert_sklearn(clr, initial_types=initial_type, target_opset=12)
79 |
80 | ###########################
81 | # And to disable it.
82 |
83 | logger.setLevel(logging.INFO)
84 |
85 | convert_sklearn(clr, initial_types=initial_type, target_opset=12)
86 |
87 | logger.setLevel(logging.WARNING)
88 |
89 | #################################
90 | # **Versions used for this example**
91 |
92 | print("numpy:", numpy.__version__)
93 | print("scikit-learn:", sklearn.__version__)
94 | print("onnx: ", onnx.__version__)
95 | print("onnxruntime: ", rt.__version__)
96 | print("skl2onnx: ", skl2onnx.__version__)
97 |
--------------------------------------------------------------------------------
/docs/examples/plot_metadata.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | Metadata
6 | ========
7 |
8 | .. index:: metadata
9 |
10 | ONNX format contains metadata related to how the
11 | model was produced. It is useful when the model
12 | is deployed to production to keep track of which
13 | instance was used at a specific time.
14 | Let's see how to do that with a simple
15 | logistic regression model trained with
16 | *scikit-learn*.
17 | """
18 |
19 | import skl2onnx
20 | import onnxruntime
21 | import sklearn
22 | import numpy
23 | from onnxruntime import InferenceSession
24 | import onnx
25 | from onnxruntime.datasets import get_example
26 |
27 | example = get_example("logreg_iris.onnx")
28 |
29 | model = onnx.load(example)
30 |
31 | print("doc_string={}".format(model.doc_string))
32 | print("domain={}".format(model.domain))
33 | print("ir_version={}".format(model.ir_version))
34 | print("metadata_props={}".format(model.metadata_props))
35 | print("model_version={}".format(model.model_version))
36 | print("producer_name={}".format(model.producer_name))
37 | print("producer_version={}".format(model.producer_version))
38 |
39 | #############################
40 | # With *ONNX Runtime*:
41 |
42 | sess = InferenceSession(example, providers=["CPUExecutionProvider"])
43 | meta = sess.get_modelmeta()
44 |
45 | print("custom_metadata_map={}".format(meta.custom_metadata_map))
46 | print("description={}".format(meta.description))
47 | print("domain={}".format(meta.domain))
48 | print("graph_name={}".format(meta.graph_name))
49 | print("producer_name={}".format(meta.producer_name))
50 | print("version={}".format(meta.version))
51 |
52 | #################################
53 | # **Versions used for this example**
54 |
55 | print("numpy:", numpy.__version__)
56 | print("scikit-learn:", sklearn.__version__)
57 | print("onnx: ", onnx.__version__)
58 | print("onnxruntime: ", onnxruntime.__version__)
59 | print("skl2onnx: ", skl2onnx.__version__)
60 |
--------------------------------------------------------------------------------
/docs/examples/plot_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | Draw a pipeline
6 | ===============
7 |
8 | There is no other way to look into one model stored
9 | in ONNX format than looking into its nodes with
10 | *onnx*. This example demonstrates
11 | how to draw a model and to retrieve it in *json*
12 | format.
13 |
14 | Retrieve a model in JSON format
15 | +++++++++++++++++++++++++++++++
16 |
17 | That's the simplest way.
18 | """
19 |
20 | import skl2onnx
21 | import onnxruntime
22 | import sklearn
23 | import numpy
24 | import matplotlib.pyplot as plt
25 | import os
26 | from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
27 | from onnx import ModelProto
28 | import onnx
29 | from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMul
30 |
31 | onnx_fct = OnnxAdd(
32 | OnnxMul("X", numpy.array([2], dtype=numpy.float32), op_version=12),
33 | numpy.array([[1, 0], [0, 1]], dtype=numpy.float32),
34 | output_names=["Y"],
35 | op_version=12,
36 | )
37 |
38 | X = numpy.array([[4, 5], [-2, 3]], dtype=numpy.float32)
39 | model = onnx_fct.to_onnx({"X": X}, target_opset=12)
40 | print(model)
41 |
42 | filename = "example1.onnx"
43 | with open(filename, "wb") as f:
44 | f.write(model.SerializeToString())
45 |
46 |
47 | #################################
48 | # Draw a model with ONNX
49 | # ++++++++++++++++++++++
50 | # We use `net_drawer.py
51 | # <https://github.com/onnx/onnx/blob/main/onnx/tools/net_drawer.py>`_
52 | # included in *onnx* package.
53 | # We use *onnx* to load the model
54 | # in a different way than before.
55 |
56 |
57 | model = ModelProto()
58 | with open(filename, "rb") as fid:
59 | content = fid.read()
60 | model.ParseFromString(content)
61 |
62 | ###################################
63 | # We convert it into a graph.
64 | pydot_graph = GetPydotGraph(
65 | model.graph,
66 | name=model.graph.name,
67 | rankdir="TB",
68 | node_producer=GetOpNodeProducer("docstring"),
69 | )
70 | pydot_graph.write_dot("graph.dot")
71 |
72 | #######################################
73 | # Then into an image
74 | os.system("dot -O -Tpng graph.dot")
75 |
76 | ################################
77 | # Which we display...
78 | image = plt.imread("graph.dot.png")
79 | plt.imshow(image)
80 | plt.axis("off")
81 |
82 | #################################
83 | # **Versions used for this example**
84 |
85 | print("numpy:", numpy.__version__)
86 | print("scikit-learn:", sklearn.__version__)
87 | print("onnx: ", onnx.__version__)
88 | print("onnxruntime: ", onnxruntime.__version__)
89 | print("skl2onnx: ", skl2onnx.__version__)
90 |
--------------------------------------------------------------------------------
/docs/exts/github_link.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # Source: https://github.com/scikit-learn/scikit-learn/blob/
4 | # main/doc/sphinxext/github_link.py
5 | from operator import attrgetter
6 | import inspect
7 | import subprocess
8 | import os
9 | import sys
10 | from functools import partial
11 |
12 | REVISION_CMD = "git rev-parse --short HEAD"
13 |
14 |
def _get_git_revision():
    """Return the short hash of the current git ``HEAD``, or ``None``.

    The command held in ``REVISION_CMD`` is executed. When git cannot be
    started or exits with a non-zero status, a message is printed and
    ``None`` is returned instead of raising.
    """
    command = REVISION_CMD.split()
    try:
        raw = subprocess.check_output(command)
    except (subprocess.CalledProcessError, OSError):
        print("Failed to execute git to get revision")
        return None
    # check_output returns bytes with a trailing newline.
    return raw.strip().decode("utf-8")
22 |
23 |
24 | def _linkcode_resolve(domain, info, package, url_fmt, revision):
25 | """Determine a link to online source for a class/method/function
26 | This is called by sphinx.ext.linkcode
27 | An example with a long-untouched module that everyone has
28 | >>> _linkcode_resolve('py', {'module': 'tty',
29 | ... 'fullname': 'setraw'},
30 | ... package='tty',
31 | ... url_fmt='http://hg.python.org/cpython/file/'
32 | ... '{revision}/Lib/{package}/{path}#L{lineno}',
33 | ... revision='xxxx')
34 | 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18'
35 | """
36 |
37 | if revision is None:
38 | return
39 | if domain not in ("py", "pyx"):
40 | return
41 | if not info.get("module") or not info.get("fullname"):
42 | return
43 |
44 | class_name = info["fullname"].split(".")[0]
45 | module = __import__(info["module"], fromlist=[class_name])
46 | obj = attrgetter(info["fullname"])(module)
47 |
48 | # Unwrap the object to get the correct source
49 | # file in case that is wrapped by a decorator
50 | obj = inspect.unwrap(obj)
51 |
52 | try:
53 | fn = inspect.getsourcefile(obj)
54 | except Exception:
55 | fn = None
56 | if not fn:
57 | try:
58 | fn = inspect.getsourcefile(sys.modules[obj.__module__])
59 | except Exception:
60 | fn = None
61 | if not fn:
62 | return
63 |
64 | fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__))
65 | try:
66 | lineno = inspect.getsourcelines(obj)[1]
67 | except Exception:
68 | lineno = ""
69 | return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno)
70 |
71 |
def make_linkcode_resolve(package, url_fmt):
    """Build the ``linkcode_resolve`` callable for a given URL format.

    The git revision is looked up once, when this function is called,
    and bound to the returned callable together with the other arguments.

    :param package: name of the root module of the package
    :param url_fmt: URL template, along the lines of
        ``'https://github.com/USER/PROJECT/blob/{revision}/{package}/{path}#L{lineno}'``
    :return: ``_linkcode_resolve`` with *revision*, *package* and
        *url_fmt* already filled in
    """
    bound = {
        "revision": _get_git_revision(),
        "package": package,
        "url_fmt": url_fmt,
    }
    return partial(_linkcode_resolve, **bound)
84 |
--------------------------------------------------------------------------------
/docs/images/woe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/docs/images/woe.png
--------------------------------------------------------------------------------
/docs/index_tutorial.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | Tutorial
5 | ========
6 |
7 | .. index:: tutorial
8 |
9 | The tutorial goes from a simple example which
10 | converts a pipeline to a more complex example
11 | involving operators not actually implemented in
12 | :epkg:`ONNX operators` or :epkg:`ONNX ML operators`.
13 |
14 | .. toctree::
15 | :maxdepth: 2
16 |
17 | tutorial_1_simple
18 | tutorial_1-5_external
19 | tutorial_2_new_converter
20 | tutorial_4_advanced
21 | tutorial_2-5_extlib
22 |
23 | The tutorial was tested with the following versions:
24 |
25 | .. runpython::
26 | :showcode:
27 |
28 | try:
29 | import catboost
30 | except Exception as e:
31 | print("Unable to import catboost due to", e)
32 | catboost = None
33 | import numpy
34 | import scipy
35 | import sklearn
36 | import lightgbm
37 | import onnx
38 | import onnxmltools
39 | import onnxruntime
40 | import xgboost
41 | import skl2onnx
42 |
43 | mods = [numpy, scipy, sklearn, lightgbm, xgboost, catboost,
44 | onnx, onnxmltools, onnxruntime,
45 | skl2onnx]
46 | mods = [(m.__name__, m.__version__) for m in mods if m is not None]
47 | mx = max(len(_[0]) for _ in mods) + 1
48 | for name, vers in sorted(mods):
49 | print("%s%s%s" % (name, " " * (mx - len(name)), vers))
50 |
--------------------------------------------------------------------------------
/docs/logo_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/docs/logo_main.png
--------------------------------------------------------------------------------
/docs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/docs/pipeline.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | autopep8
2 | catboost
3 | category_encoders
4 | coverage
5 | flake8
6 | furo
7 | joblib
8 | lightgbm; sys_platform != 'darwin'
9 | loky
10 | matplotlib
11 | mlinsights>=0.3.631
12 | nbsphinx
13 | onnx
14 | onnx-array-api
15 | onnxmltools
16 | onnxruntime
17 | pillow
18 | py-spy
19 | pandas
20 | pydot
21 | pyinstrument
22 | pyod
23 | pytest
24 | pytest-cov
25 | skl2onnx
26 | sphinx
27 | sphinxcontrib-blockdiag
28 | sphinx-gallery
29 | sphinx-runpython
30 | tabulate
31 | tqdm
32 | wheel
33 | xgboost
34 |
--------------------------------------------------------------------------------
/docs/supported.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | =============================
5 | Supported scikit-learn Models
6 | =============================
7 |
8 | *skl2onnx* currently can convert the following list
9 | of models for *skl2onnx* :skl2onnxversion:`v`. They
10 | were tested using *onnxruntime* :skl2onnxversion:`rt`.
11 | All the following classes overload the following methods
12 | as :class:`OnnxSklearnPipeline` does. They wrap existing
13 | *scikit-learn* classes by dynamically creating a new one
14 | which inherits from :class:`OnnxOperatorMixin` which
15 | implements *to_onnx* methods.
16 |
17 | .. _l-converter-list:
18 |
19 | Covered Converters
20 | ==================
21 |
22 | .. covered-sklearn-ops::
23 |
24 | Converters Documentation
25 | ========================
26 |
27 | .. supported-sklearn-ops::
28 |
29 | Pipeline
30 | ========
31 |
32 | .. autoclass:: skl2onnx.algebra.sklearn_ops.OnnxSklearnPipeline
33 | :members: to_onnx, to_onnx_operator, onnx_parser, onnx_shape_calculator, onnx_converter
34 |
35 | .. autoclass:: skl2onnx.algebra.sklearn_ops.OnnxSklearnColumnTransformer
36 | :members: to_onnx, to_onnx_operator, onnx_parser, onnx_shape_calculator, onnx_converter
37 |
38 | .. autoclass:: skl2onnx.algebra.sklearn_ops.OnnxSklearnFeatureUnion
39 | :members: to_onnx, to_onnx_operator, onnx_parser, onnx_shape_calculator, onnx_converter
40 |
41 | Available ONNX operators
42 | ========================
43 |
44 | *skl2onnx* maps every ONNX operator to a class
45 | easy to insert into a graph. These operators get
46 | dynamically added and the list depends on the installed
47 | *ONNX* package. The documentation for these operators
48 | can be found on github: `ONNX Operators.md
49 | <https://github.com/onnx/onnx/blob/main/docs/Operators.md>`_
50 | and `ONNX-ML Operators
51 | <https://github.com/onnx/onnx/blob/main/docs/Operators-ml.md>`_.
52 | Associated to `onnxruntime <https://github.com/microsoft/onnxruntime>`_,
53 | the mapping makes it easy to check the output
54 | of the *ONNX* operators on any data as shown
55 | in example :ref:`l-onnx-operators`.
56 |
57 | .. supported-onnx-ops::
58 |
--------------------------------------------------------------------------------
/docs/tests/test_utils_benchmark.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | @brief test log(time=3s)
5 | """
6 |
7 | import unittest
8 | import numpy
9 | from skl2onnx.tutorial import measure_time
10 |
11 |
class TestMeasureTime(unittest.TestCase):
    """Smoke test for ``skl2onnx.tutorial.measure_time``."""

    def test_vector_count(self):
        """Every tested setting must report an ``"average"`` entry."""

        def fct():
            return numpy.ones((1000, 5))

        # (div_by_number, number) pairs, in the order they are measured.
        settings = [(False, 100), (True, 100), (True, 1000)]
        for div_by_number, number in settings:
            res = measure_time(
                "fct",
                context={"fct": fct},
                div_by_number=div_by_number,
                number=number,
            )
            self.assertIn("average", res)
24 |
25 |
26 | if __name__ == "__main__":
27 | unittest.main()
28 |
--------------------------------------------------------------------------------
/docs/tests/test_utils_classes.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | @brief test log(time=3s)
5 | """
6 |
7 | import unittest
8 | from skl2onnx.tutorial.imagenet_classes import class_names
9 |
10 |
class TestUtilsClasses(unittest.TestCase):
    """Checks the ``class_names`` table shipped with the tutorial helpers."""

    def test_classes(self):
        """``class_names`` is a dictionary holding exactly 1000 entries."""
        self.assertIsInstance(class_names, dict)
        count = len(class_names)
        self.assertEqual(count, 1000)
16 |
17 |
18 | if __name__ == "__main__":
19 | unittest.main()
20 |
--------------------------------------------------------------------------------
/docs/tutorial/README.txt:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
--------------------------------------------------------------------------------
/docs/tutorial/plot_ngrams.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | .. _example-ngrams:
5 |
6 | Tricky issue when converting CountVectorizer or TfidfVectorizer
7 | ===============================================================
8 |
9 | This issue is described at `scikit-learn/issues/13733
10 | <https://github.com/scikit-learn/scikit-learn/issues/13733>`_.
11 | If a CountVectorizer or a TfidfVectorizer produces a token with a space,
12 | skl2onnx cannot know if it is a bigram or a unigram with a space.
13 |
14 | A simple example impossible to convert
15 | ++++++++++++++++++++++++++++++++++++++
16 | """
17 |
18 | import pprint
19 | import numpy
20 | from numpy.testing import assert_almost_equal
21 | from onnxruntime import InferenceSession
22 | from sklearn.feature_extraction.text import TfidfVectorizer
23 | from skl2onnx import to_onnx
24 | from skl2onnx.sklapi import TraceableTfidfVectorizer
25 | import skl2onnx.sklapi.register # noqa: F401
26 |
27 | corpus = numpy.array(
28 | [
29 | "This is the first document.",
30 | "This document is the second document.",
31 | "Is this the first document?",
32 | "",
33 | ]
34 | ).reshape((4,))
35 |
36 | pattern = r"\b[a-z ]{1,10}\b"
37 | mod1 = TfidfVectorizer(ngram_range=(1, 2), token_pattern=pattern)
38 | mod1.fit(corpus)
39 |
40 |
41 | ######################################
42 | # Unigrams and bi-grams are placed into the following container
43 | # which maps it to its column index.
44 |
45 | pprint.pprint(mod1.vocabulary_)
46 |
47 |
48 | ####################################
49 | # Conversion.
50 |
51 | try:
52 | to_onnx(mod1, corpus)
53 | except RuntimeError as e:
54 | print(e)
55 |
56 |
57 | #######################################
58 | # TraceableTfidfVectorizer
59 | # ++++++++++++++++++++++++
60 | #
61 | # Class :class:`TraceableTfidfVectorizer` is equivalent to
62 | # :class:`sklearn.feature_extraction.text.TfidfVectorizer`
63 | # but stores the unigrams and bi-grams of the vocabulary with tuple
64 | # instead of concatenating every piece into a string.
65 |
66 |
67 | mod2 = TraceableTfidfVectorizer(ngram_range=(1, 2), token_pattern=pattern)
68 | mod2.fit(corpus)
69 |
70 | pprint.pprint(mod2.vocabulary_)
71 |
72 | #######################################
73 | # Let's check it produces the same results.
74 |
75 | assert_almost_equal(mod1.transform(corpus).todense(), mod2.transform(corpus).todense())
76 |
77 | ####################################
78 | # Conversion. Line `import skl2onnx.sklapi.register`
79 | # was added to register the converters associated with this
80 | # new class. By default, only converters for scikit-learn are
81 | # declared.
82 |
83 | onx = to_onnx(mod2, corpus)
84 | sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
85 | got = sess.run(None, {"X": corpus})
86 |
87 | ###################################
88 | # Let's check if there are discrepancies...
89 |
90 | assert_almost_equal(mod2.transform(corpus).todense(), got[0])
91 |
--------------------------------------------------------------------------------
/docs/tutorial_1-5_external.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 | Using converters from other libraries
4 | =====================================
5 |
6 | Before starting to write our own converter,
7 | we can use some available in libraries other
8 | than :epkg:`sklearn-onnx`. :epkg:`onnxmltools` implements
9 | converters for :epkg:`xgboost` and :epkg:`LightGBM`.
10 | The following examples show how to use the converters when the
11 | models are part of a pipeline.
12 |
13 | .. toctree::
14 | :maxdepth: 1
15 |
16 | auto_tutorial/plot_gexternal_lightgbm
17 | auto_tutorial/plot_gexternal_lightgbm_reg
18 | auto_tutorial/plot_gexternal_xgboost
19 | auto_tutorial/plot_gexternal_catboost
20 |
--------------------------------------------------------------------------------
/docs/tutorial_1_simple.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | The easy case
5 | =============
6 |
7 | The easy case is when the machine learned model
8 | can be converted into ONNX with a converting library
9 | without writing any specific code. That means that a converter
10 | exists for the model or each piece of the model,
11 | the converter produces an ONNX graph where every node
12 | is part of the existing ONNX specifications, and the runtime
13 | used to compute the predictions implements every node
14 | used in the ONNX graph.
15 |
16 | .. toctree::
17 | :maxdepth: 1
18 |
19 | auto_tutorial/plot_abegin_convert_pipeline
20 | auto_tutorial/plot_bbegin_measure_time
21 | auto_tutorial/plot_cbegin_opset
22 | auto_tutorial/plot_dbegin_options
23 | auto_tutorial/plot_dbegin_options_zipmap
24 | auto_tutorial/plot_dbegin_options_list
25 | auto_tutorial/plot_ebegin_float_double
26 | auto_tutorial/plot_fbegin_investigate
27 | auto_tutorial/plot_gbegin_cst
28 | auto_tutorial/plot_gbegin_dataframe
29 | auto_tutorial/plot_gconverting
30 |
--------------------------------------------------------------------------------
/docs/tutorial_2-5_extlib.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 | Write converters for other libraries
4 | ====================================
5 |
6 | *sklearn-onnx* only converts models from *scikit-learn*. It
7 | implements a mechanism to register converters from other libraries.
8 | Converters for models from other libraries will not be added to
9 | *sklearn-onnx*. Every library has its own maintenance cycle and
10 | it would become difficult to maintain a package having too many
11 | dependencies. The following examples were added to show how to
12 | develop converters for new libraries.
13 |
14 | .. toctree::
15 | :maxdepth: 1
16 |
17 | auto_tutorial/plot_wext_pyod_forest
18 |
--------------------------------------------------------------------------------
/docs/tutorial_2_new_converter.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 | A custom converter for a custom model
4 | =====================================
5 |
6 | When :epkg:`sklearn-onnx` converts a :epkg:`scikit-learn`
7 | pipeline, it looks into every transformer and predictor
8 | and fetches the associated converter. The resulting
9 | ONNX graph combines the outcome of every converter
10 | in a single graph. If a model does not have its converter,
11 | it displays an error message saying that a converter is missing.
12 |
13 | .. runpython::
14 | :showcode:
15 |
16 | import numpy
17 | from sklearn.linear_model import LogisticRegression
18 | from skl2onnx import to_onnx
19 |
20 |
21 | class MyLogisticRegression(LogisticRegression):
22 | pass
23 |
24 |
25 | X = numpy.array([[0, 0.1]])
26 | try:
27 | to_onnx(MyLogisticRegression(), X)
28 | except Exception as e:
29 | print(e)
30 |
31 | Following sections show how to create a custom converter.
32 | It assumes this new converter is not meant to be added to
33 | this package but only to be registered and used when converting
34 | a pipeline. To contribute and add a converter
35 | for a :epkg:`scikit-learn` model, the logic is still the same,
36 | only the converter registration changes. `PR 737
37 | <https://github.com/onnx/sklearn-onnx/pull/737>`_ can be used as
38 | an example.
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 |
43 | auto_tutorial/plot_icustom_converter
44 | auto_tutorial/plot_jcustom_syntax
45 | auto_tutorial/plot_jfunction_transformer
46 | auto_tutorial/plot_kcustom_converter_wrapper
47 | auto_tutorial/plot_lcustom_options
48 | auto_tutorial/plot_mcustom_parser
49 |
--------------------------------------------------------------------------------
/docs/tutorial_4_advanced.rst:
--------------------------------------------------------------------------------
1 | .. SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | Advanced scenarios
5 | ==================
6 |
7 | Unexpected discrepancies may appear. This is a list of examples
8 | with issues and resolved issues.
9 |
10 | .. toctree::
11 | :maxdepth: 1
12 |
13 | auto_tutorial/plot_ngrams
14 | auto_tutorial/plot_usparse_xgboost
15 | auto_tutorial/plot_woe_transformer
16 | auto_tutorial/plot_output_onnx_single_probability
17 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # tests
2 | black
3 | jinja2
4 | onnxruntime-extensions
5 | onnxscript
6 | pandas
7 | py-cpuinfo
8 | pybind11
9 | pytest
10 | pytest-cov
11 | ruff
12 | wheel
13 |
14 | # docs/examples
15 | xgboost
16 | lightgbm; sys_platform != 'darwin'
17 | matplotlib
18 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx>=1.2.1
2 | scikit-learn>=1.1
3 |
--------------------------------------------------------------------------------
/skl2onnx/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Main entry point to the converter from the *scikit-learn* to *onnx*.
5 | """
6 |
7 | __version__ = "1.19.1"
8 | __author__ = "Microsoft"
9 | __producer__ = "skl2onnx"
10 | __producer_version__ = __version__
11 | __domain__ = "ai.onnx"
12 | __model_version__ = 0
13 | __max_supported_opset__ = 21 # Converters are tested up to this version.
14 |
15 |
16 | from .convert import convert_sklearn, to_onnx, wrap_as_onnx_mixin
17 | from ._supported_operators import update_registered_converter, get_model_alias
18 | from ._parse import update_registered_parser
19 | from .proto import get_latest_tested_opset_version
20 |
21 |
def supported_converters(from_sklearn=False):
    """
    Returns the list of supported converters.
    To find the converter associated to a specific model,
    the library gets the name of the model class,
    adds ``'Sklearn'`` as a prefix and retrieves
    the associated converter if available.

    :param from_sklearn: every supported model is mapped to a converter
        by a name prefixed with ``'Sklearn'``. If True, the function
        only returns the converters whose name starts with ``'Sklearn'``
        and removes that prefix. If False (default), every registered
        name is returned unchanged.
    :return: sorted list of supported models as strings
    """
    from .common._registration import _converter_pool

    # The two following imports populate the list of supported
    # converters, the imported names themselves are not used.
    from . import shape_calculators  # noqa: F401
    from . import operator_converters  # noqa: F401

    names = sorted(_converter_pool.keys())
    if not from_sklearn:
        return names
    prefix = "Sklearn"
    return [name[len(prefix) :] for name in names if name.startswith(prefix)]
46 |
--------------------------------------------------------------------------------
/skl2onnx/__main__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import sys
3 | from textwrap import dedent
4 |
5 |
6 | def _help():
7 | print(
8 | dedent(
9 | """
10 | python -m skl2onnx [command]
11 |
12 | command is:
13 |
14 | setup generate rst documentation for every ONNX operator
15 | before building the package"""
16 | )
17 | )
18 |
19 |
def _setup():
    """Run the ``setup`` command.

    Calls ``dynamic_class_creation(True)`` from
    ``skl2onnx.algebra.onnx_ops``. The import is done here so that
    it only happens when the command is requested.
    """
    from skl2onnx.algebra import onnx_ops

    onnx_ops.dynamic_class_creation(True)
24 |
25 |
def main(argv):
    """
    Entry point of ``python -m skl2onnx``.

    :param argv: command line arguments as found in ``sys.argv``,
        the usage is displayed if it has at most one element
        or if it contains ``--help``
    """
    usage_requested = len(argv) <= 1 or "--help" in argv
    if not usage_requested and "setup" in argv:
        print("generate rst documentation for every ONNX operator")
        _setup()
        return

    # No command, ``--help`` or an unknown command: displays the usage.
    _help()
37 |
38 |
# Runs the command line only when the module is executed as a script.
if __name__ == "__main__":
    main(sys.argv)
41 |
--------------------------------------------------------------------------------
/skl2onnx/algebra/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from .onnx_operator import OnnxOperator
5 |
--------------------------------------------------------------------------------
/skl2onnx/algebra/custom_ops.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from .onnx_operator import OnnxOperator
4 |
5 |
class OnnxCDist(OnnxOperator):
    """
    Defines custom operator *CDist* (domain ``com.microsoft``),
    it is not defined by ONNX specifications but in onnxruntime.
    """

    operator_name = "CDist"
    domain = "com.microsoft"
    since_version = 1
    is_deprecated = False
    past_version = {}
    expected_inputs = [("X", "T"), ("Y", "T")]
    expected_outputs = [("dist", "T")]
    input_range = [2, 2]
    output_range = [1, 1]

    def __init__(self, X, Y, metric="sqeuclidean", op_version=None, **kwargs):
        """
        :param X: array or OnnxOperatorMixin
        :param Y: array or OnnxOperatorMixin
        :param metric: distance type
        :param op_version: opset version
        :param kwargs: additional parameters
        """
        # Same keyword order as an explicit call:
        # metric, op_version, then the additional parameters.
        parameters = {"metric": metric, "op_version": op_version, **kwargs}
        OnnxOperator.__init__(self, X, Y, **parameters)
34 |
35 |
class OnnxSolve(OnnxOperator):
    """
    Defines custom operator *Solve* (domain ``com.microsoft``),
    it is not defined by ONNX specifications but in onnxruntime.
    """

    operator_name = "Solve"
    domain = "com.microsoft"
    since_version = 1
    is_deprecated = False
    past_version = {}
    expected_inputs = [("A", "T"), ("Y", "T")]
    expected_outputs = [("X", "T")]
    input_range = [2, 2]
    output_range = [1, 1]

    def __init__(self, A, Y, lower=False, transposed=False, op_version=None, **kwargs):
        """
        :param A: array or OnnxOperatorMixin
        :param Y: array or OnnxOperatorMixin
        :param lower: see :epkg:`solve`
        :param transposed: see :epkg:`solve`
        :param op_version: opset version
        :param kwargs: additional parameters
        """
        # Same keyword order as an explicit call:
        # lower, transposed, op_version, then the additional parameters.
        parameters = {
            "lower": lower,
            "transposed": transposed,
            "op_version": op_version,
            **kwargs,
        }
        OnnxOperator.__init__(self, A, Y, **parameters)
71 |
--------------------------------------------------------------------------------
/skl2onnx/algebra/onnx_subgraph_operator_mixin.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from .onnx_operator_mixin import OnnxOperatorMixin
4 |
5 |
class OnnxSubGraphOperatorMixin(OnnxOperatorMixin):
    """
    :class:`OnnxOperatorMixin` for converters.

    The class does not add any attribute or method to
    :class:`OnnxOperatorMixin`, it only defines a distinct type.
    """
10 |
--------------------------------------------------------------------------------
/skl2onnx/common/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from .exceptions import MissingShapeCalculator, MissingConverter
4 |
--------------------------------------------------------------------------------
/skl2onnx/common/exceptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Common errors.
5 | """
6 |
# Explanation appended to the message of every exception defined below.
_missing_converter = """
It usually means the pipeline being converted contains a
transformer or a predictor with no corresponding converter
implemented in sklearn-onnx. If the converter is implemented
in another library, you need to register
the converter so that it can be used by sklearn-onnx (function
update_registered_converter). If the model is not yet covered
by sklearn-onnx, you may raise an issue to
https://github.com/onnx/sklearn-onnx/issues
to get the converter implemented or even contribute to the
project. If the model is a custom model, a new converter must
be implemented. Examples can be found in the gallery.
"""


class MissingShapeCalculator(RuntimeError):
    """
    Raised when there is no registered shape calculator
    for a machine learning operator.
    """

    def __init__(self, msg):
        """
        :param msg: specific part of the message, the generic
            explanation about missing converters is appended to it
        """
        super().__init__(msg + _missing_converter)


class MissingConverter(RuntimeError):
    """
    Raised when there is no registered converter
    for a machine learning operator. If the model is
    part of scikit-learn, you may raise an issue at
    https://github.com/onnx/sklearn-onnx/issues.
    """

    def __init__(self, msg):
        """
        :param msg: specific part of the message, the generic
            explanation about missing converters is appended to it
        """
        super().__init__(msg + _missing_converter)
42 |
--------------------------------------------------------------------------------
/skl2onnx/common/utils_checking.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from inspect import signature
5 | from collections import OrderedDict
6 |
7 |
def check_signature(fct, reference, skip=None):
    """
    Checks that two functions have the same signature
    (same number of parameters, same parameter names).
    Raises an exception otherwise.

    Parameters ``op_type``, ``op_domain``, ``op_version`` are ignored
    in both signatures when they appear after the third position
    unless their default value is None.

    :param fct: function to check
    :param reference: function defining the expected signature
    :param skip: names excluded from the name comparison, a mismatch
        is accepted if one of the two compared names belongs to it,
        None means no name is excluded
    :raises TypeError: the functions do not have the same number
        of parameters
    :raises NameError: two parameters at the same position
        have different names
    """
    ignored = ("op_type", "op_domain", "op_version")
    # No name to skip means every mismatch must be reported.
    skipped = () if skip is None else skip

    def select_parameters(pars):
        # Keeps the parameters in order, without the ignored ones.
        return OrderedDict(
            (name, p)
            for i, (name, p) in enumerate(pars.items())
            if not (i >= 3 and name in ignored and p.default is not None)
        )

    fct_pars = select_parameters(signature(fct).parameters)
    ref_pars = select_parameters(signature(reference).parameters)
    if len(fct_pars) != len(ref_pars):
        raise TypeError(
            "Function '{}' must have {} parameters but has {}."
            "".format(fct.__name__, len(ref_pars), len(fct_pars))
        )
    for i, (a, b) in enumerate(zip(fct_pars, ref_pars)):
        if a != b and a not in skipped and b not in skipped:
            raise NameError(
                "Parameter name mismatch at position {}."
                "Function '{}' has '{}' but '{}' is expected."
                "".format(i + 1, fct.__name__, a, b)
            )
44 |
--------------------------------------------------------------------------------
/skl2onnx/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from .investigate import collect_intermediate_steps, compare_objects
5 | from .investigate import enumerate_pipeline_models
6 | from .integration import add_onnx_graph
7 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/array_feature_extractor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..proto import onnx_proto
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_array_feature_extractor(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Extracts a subset of columns. This is used by *ColumnTransformer*.

    The column indices are stored in an INT64 initializer consumed
    by a node *ArrayFeatureExtractor* (domain ``ai.onnx.ml``).
    Every index must be an integer.
    """
    indices_name = scope.get_unique_variable_name("column_indices")
    indices = operator.column_indices

    for position, index in enumerate(indices):
        assert isinstance(index, int), (
            "Column {0}:'{1}' indices must be specified "
            "as integers. This error may happen when "
            "column names are used to define a "
            "ColumnTransformer. Column name in input data "
            "do not necessarily match input variables "
            "defined for the ONNX model."
        ).format(position, index)

    container.add_initializer(
        indices_name, onnx_proto.TensorProto.INT64, [len(indices)], indices
    )

    node_inputs = [operator.inputs[0].full_name, indices_name]
    container.add_node(
        "ArrayFeatureExtractor",
        node_inputs,
        operator.outputs[0].full_name,
        name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
        op_domain="ai.onnx.ml",
    )
41 |
42 |
# Registers the converter under alias 'SklearnArrayFeatureExtractor'.
register_converter(
    "SklearnArrayFeatureExtractor", convert_sklearn_array_feature_extractor
)
46 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/binariser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..proto import onnx_proto
5 | from ..common.data_types import DoubleTensorType
6 | from ..common._registration import register_converter
7 | from ..common._topology import Scope, Operator
8 | from ..common._container import ModelComponentContainer
9 | from .common import concatenate_variables
10 |
11 |
def convert_sklearn_binarizer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for *Binarizer*.

    scikit-learn maps every value strictly greater than the threshold
    to 1 and every other value, including a value equal to the
    threshold, to 0.

    Double inputs are converted into nodes *Greater* and *Where*
    working on double constants, any other input relies on
    operator *Binarizer* from domain ``ai.onnx.ml``.
    """
    feature_name = concatenate_variables(scope, operator.inputs, container)
    threshold = float(operator.raw_operator.threshold)

    if isinstance(operator.inputs[0].type, DoubleTensorType):
        name0 = scope.get_unique_variable_name("cst0")
        name1 = scope.get_unique_variable_name("cst1")
        thres = scope.get_unique_variable_name("th")
        container.add_initializer(name0, onnx_proto.TensorProto.DOUBLE, [], [0.0])
        container.add_initializer(name1, onnx_proto.TensorProto.DOUBLE, [], [1.0])
        container.add_initializer(
            thres, onnx_proto.TensorProto.DOUBLE, [], [threshold]
        )
        binbool = scope.get_unique_variable_name("binbool")
        # Strict comparison: a value equal to the threshold must produce 0.
        container.add_node(
            "Greater",
            [feature_name, thres],
            binbool,
            name=scope.get_unique_operator_name("Greater"),
        )
        # Where(condition, 1, 0). The node name is made unique so that
        # several binarizers can live in the same graph.
        container.add_node(
            "Where",
            [binbool, name1, name0],
            operator.output_full_names,
            name=scope.get_unique_operator_name("Where"),
        )
        return

    op_type = "Binarizer"
    attrs = {
        "name": scope.get_unique_operator_name(op_type),
        "threshold": threshold,
    }
    container.add_node(
        op_type,
        feature_name,
        operator.output_full_names,
        op_domain="ai.onnx.ml",
        **attrs,
    )
53 |
54 |
55 | register_converter("SklearnBinarizer", convert_sklearn_binarizer)
56 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/cast_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_cast
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 | from .._supported_operators import sklearn_operator_name_map
9 |
10 |
def convert_sklearn_cast(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Casts the first input into the element type
    declared by the first output.
    """
    source = operator.inputs[0]
    target = operator.outputs[0]
    # The ONNX element type is read from the type of the output variable.
    elem_type = target.type.to_onnx_type().tensor_type.elem_type
    apply_cast(scope, source.full_name, target.full_name, container, to=elem_type)
19 |
20 |
def convert_sklearn_cast_regressor(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Declares the wrapped estimator as a local operator writing into an
    intermediate variable, then casts that variable into the element
    type declared by the first output.
    """
    estimator = operator.raw_operator.estimator

    alias = sklearn_operator_name_map[type(estimator)]
    inner_operator = scope.declare_local_operator(alias, estimator)
    inner_operator.inputs = operator.inputs

    # The intermediate variable has the same type class as the first input.
    input_type_class = operator.inputs[0].type.__class__
    inner_output = scope.declare_local_variable("cast_est", input_type_class())
    inner_operator.outputs.append(inner_output)
    inner_output_name = inner_output.onnx_name

    target = operator.outputs[0]
    elem_type = target.type.to_onnx_type().tensor_type.elem_type
    apply_cast(scope, inner_output_name, target.full_name, container, to=elem_type)
40 |
41 |
42 | register_converter("SklearnCastTransformer", convert_sklearn_cast)
43 | register_converter("SklearnCastRegressor", convert_sklearn_cast_regressor)
44 | register_converter("SklearnCast", convert_sklearn_cast)
45 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/common.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_cast
5 | from ..common.data_types import (
6 | Int64TensorType,
7 | FloatTensorType,
8 | DoubleTensorType,
9 | StringTensorType,
10 | guess_proto_type,
11 | )
12 |
13 |
def concatenate_variables(scope, variables, container, main_type=None):
    """
    Adds the nodes needed to merge all input variables into
    a single tensor and returns the name of that tensor.

    Every variable whose type is not an instance of *main_type*
    is cast into that type first. If there is only one variable,
    nothing is merged and its (possibly cast) name is returned.
    Otherwise the variables are combined by a node *FeatureVectorizer*
    (domain ``ai.onnx.ml``) followed by a node *CastLike* unless
    *main_type* is ``FloatTensorType``.

    :param scope: scope used to create unique names
    :param variables: variables to concatenate
    :param container: container receiving the new nodes
    :param main_type: expected type class, the type class of
        the first variable if None
    :return: name of the resulting tensor
    :raises RuntimeError: if the type of a variable is not one of
        ``FloatTensorType``, ``Int64TensorType``, ``DoubleTensorType``,
        ``StringTensorType``
    """
    if main_type is None:
        main_type = variables[0].type.__class__

    # Check if it's possible to concatenate those inputs.
    accepted_types = {
        FloatTensorType,
        Int64TensorType,
        DoubleTensorType,
        StringTensorType,
    }
    found_types = {type(var.type) for var in variables}
    if not found_types.issubset(accepted_types):
        raise RuntimeError(
            "Numerical tensor(s) and string tensor(s) cannot be concatenated."
        )

    names = []  # names of the tensors to concatenate
    dims = []  # second dimension of every variable

    for var in variables:
        if isinstance(var.type, main_type):
            name = var.full_name
        else:
            # The variable is cast into the main type.
            proto_type = guess_proto_type(main_type())
            name = scope.get_unique_variable_name("cast")
            apply_cast(scope, var.full_name, name, container, to=proto_type)
        names.append(name)
        # We assume input variables' shape are [1, C_1], ..., [1, C_n],
        # if there are n inputs.
        dims.append(var.type.shape[1])

    if len(names) == 1:
        # No need to concatenate tensors if there is only one input.
        return names[0]

    # To combine all inputs, we need a FeatureVectorizer.
    op_type = "FeatureVectorizer"
    node_name = scope.get_unique_operator_name(op_type)
    merged = scope.get_unique_variable_name("concatenated")
    container.add_node(
        op_type,
        names,
        merged,
        op_domain="ai.onnx.ml",
        name=node_name,
        inputdimensions=dims,
    )
    if main_type == FloatTensorType:
        return merged

    # Cast output as FeatureVectorizer always produces float32.
    merged_cast = scope.get_unique_variable_name("concatenated_cast")
    container.add_node("CastLike", [merged, names[0]], merged_cast)
    return merged_cast
78 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/concat_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_concat, apply_cast
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_concat(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Concatenates all inputs along axis 1. An input whose type class
    differs from the type class of the first output is cast
    into the output element type first.
    """
    expected_type = operator.outputs[0].type
    to_concat = []
    for variable in operator.inputs:
        if variable.type.__class__ is not expected_type.__class__:
            cast_name = scope.get_unique_variable_name(
                "{}_cast".format(variable.full_name)
            )
            elem_type = expected_type.to_onnx_type().tensor_type.elem_type
            apply_cast(scope, variable.full_name, cast_name, container, to=elem_type)
            to_concat.append(cast_name)
        else:
            to_concat.append(variable.full_name)

    apply_concat(scope, to_concat, operator.outputs[0].full_name, container, axis=1)
26 |
27 |
28 | register_converter("SklearnConcat", convert_sklearn_concat)
29 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/cross_decomposition.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from ..proto import onnx_proto
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 | from ..common.data_types import Int64TensorType, guess_numpy_type, guess_proto_type
9 | from ..algebra.onnx_ops import OnnxAdd, OnnxCast, OnnxDiv, OnnxMatMul, OnnxSub
10 |
11 |
def _skl150() -> bool:
    """
    Tells if the installed version of scikit-learn is 1.5.0 or later.
    """
    import sklearn
    from packaging.version import Version

    installed = Version(sklearn.__version__)
    return installed >= Version("1.5.0")
17 |
18 |
def convert_pls_regression(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for *PLSRegression*.

    The graph is built with operators *Sub*, *Div*, *MatMul*, *Add*
    (and *Cast* for int64 inputs). The attributes read from the model
    depend on which ones the fitted model exposes and on
    the installed version of scikit-learn.
    """
    X = operator.inputs[0]
    op = operator.raw_operator
    opv = container.target_opset
    # Computations are done in float64 only if the input is a double tensor,
    # in float32 in any other case.
    dtype = guess_numpy_type(X.type)
    if dtype != np.float64:
        dtype = np.float32
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:
        proto_dtype = onnx_proto.TensorProto.FLOAT

    # Integer inputs are cast into the float type chosen above.
    if isinstance(X.type, Int64TensorType):
        X = OnnxCast(X, to=proto_dtype, op_version=opv)

    # The mean and the standard deviation of X are stored in public
    # or private attributes depending on the model.
    # At this stage, ``coefs`` holds the mean of X, the name is reused
    # below for the regression coefficients.
    coefs = op.x_mean_ if hasattr(op, "x_mean_") else op._x_mean
    std = op.x_std_ if hasattr(op, "x_std_") else op._x_std
    if hasattr(op, "intercept_") and _skl150():
        # scikit-learn==1.5.0
        # https://github.com/scikit-learn/scikit-learn/pull/28612
        # prediction = (X - x_mean) @ coef_.T + intercept_
        ym = op.intercept_
        centered_x = OnnxSub(X, coefs.astype(dtype), op_version=opv)
        coefs = op.coef_.T.astype(dtype)
        dot = OnnxMatMul(centered_x, coefs, op_version=opv)
    else:
        # prediction = ((X - x_mean) / x_std) @ coefficients + y_mean
        # NOTE(review): the test below looks at attribute "x_mean_"
        # to choose between y_mean_ and _y_mean - confirm this is intended.
        ym = op.y_mean_ if hasattr(op, "x_mean_") else op._y_mean

        norm_x = OnnxDiv(
            OnnxSub(X, coefs.astype(dtype), op_version=opv),
            std.astype(dtype),
            op_version=opv,
        )
        if hasattr(op, "set_predict_request"):
            # new in 1.3
            coefs = op.coef_.T.astype(dtype)
        else:
            coefs = op.coef_.astype(dtype)
        dot = OnnxMatMul(norm_x, coefs, op_version=opv)

    # The last node writes into the outputs of the operator.
    pred = OnnxAdd(dot, ym.astype(dtype), op_version=opv, output_names=operator.outputs)
    pred.add_to(scope, container)
61 |
62 |
63 | register_converter("SklearnPLSRegression", convert_pls_regression)
64 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/dict_vectoriser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numbers
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_dict_vectorizer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for *DictVectorizer*, the model becomes a node
    *DictVectorizer* from domain ``ai.onnx.ml``. The vocabulary
    is ``feature_names_``, it must contain only strings
    or only integers.

    When a *DictVectorizer* receives a string value,
    scikit-learn adds a separator to disambiguate strings
    and still outputs floats. scikit-learn contains
    the following lines:

    ::

        if isinstance(v, str):
            f = "%s%s%s" % (f, self.separator, v)
            v = 1

    This cannot be implemented in ONNX. The converter
    raises an exception when every feature name contains
    the separator.

    :raises RuntimeError: a feature name is duplicated or
        every feature name contains the separator
    :raises ValueError: feature names are neither all strings
        nor all integers
    """
    model = operator.raw_operator
    op_type = "DictVectorizer"
    attrs = {"name": scope.get_unique_operator_name(op_type)}

    if all(isinstance(name, str) for name in model.feature_names_):
        # all strings, scikit-learn does the following:
        vocabulary = []
        seen = set()
        with_separator = 0
        for name in model.feature_names_:
            if model.separator in name:
                with_separator += 1
            if name in seen:
                raise RuntimeError("Duplicated category '{}'.".format(name))
            seen.add(name)
            vocabulary.append(name)
        if with_separator >= len(vocabulary):
            raise RuntimeError(
                "All categories contain a separator '{}'. "
                "This case is not supported by the converter. "
                "The mapping must map to numbers not string.".format(model.separator)
            )
        attrs["string_vocabulary"] = vocabulary
    elif all(isinstance(name, numbers.Integral) for name in model.feature_names_):
        attrs["int64_vocabulary"] = [int(name) for name in model.feature_names_]
    else:
        raise ValueError("Keys must be all integers or all strings.")

    container.add_node(
        op_type,
        operator.input_full_names,
        operator.output_full_names,
        op_domain="ai.onnx.ml",
        **attrs,
    )
66 |
67 |
68 | register_converter("SklearnDictVectorizer", convert_sklearn_dict_vectorizer)
69 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/feature_selection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..proto import onnx_proto
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_feature_selection(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter shared by the feature selection models.
    The selected columns are extracted from the first input
    with a node *ArrayFeatureExtractor* (domain ``ai.onnx.ml``).

    :raises RuntimeError: if the model did not select any feature
    """
    op = operator.raw_operator
    # Get indices of the features selected
    index = op.get_support(indices=True)
    if len(index) == 0:
        raise RuntimeError(
            "Model '{}' did not select any feature. "
            "This model cannot be converted into ONNX."
            "".format(op.__class__.__name__)
        )

    # ``index`` contains column positions, not a boolean mask:
    # a selection reduced to the first column (``[0]``) is a valid
    # selection and must be extracted like any other one.
    column_indices_name = scope.get_unique_variable_name("column_indices")
    container.add_initializer(
        column_indices_name, onnx_proto.TensorProto.INT64, [len(index)], index
    )

    container.add_node(
        "ArrayFeatureExtractor",
        [operator.inputs[0].full_name, column_indices_name],
        operator.outputs[0].full_name,
        op_domain="ai.onnx.ml",
        name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
    )
41 |
42 |
43 | register_converter("SklearnGenericUnivariateSelect", convert_sklearn_feature_selection)
44 | register_converter("SklearnRFE", convert_sklearn_feature_selection)
45 | register_converter("SklearnRFECV", convert_sklearn_feature_selection)
46 | register_converter("SklearnSelectFdr", convert_sklearn_feature_selection)
47 | register_converter("SklearnSelectFpr", convert_sklearn_feature_selection)
48 | register_converter("SklearnSelectFromModel", convert_sklearn_feature_selection)
49 | register_converter("SklearnSelectFwe", convert_sklearn_feature_selection)
50 | register_converter("SklearnSelectKBest", convert_sklearn_feature_selection)
51 | register_converter("SklearnSelectPercentile", convert_sklearn_feature_selection)
52 | register_converter("SklearnVarianceThreshold", convert_sklearn_feature_selection)
53 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/flatten_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_converter
4 | from ..common._topology import Scope, Operator
5 | from ..common._container import ModelComponentContainer
6 |
7 |
def convert_sklearn_flatten(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Adds a single node *Flatten* with ``axis=1`` going
    from the first input to the first output.
    """
    node_name = scope.get_unique_operator_name("Flatten")
    source = operator.inputs[0].full_name
    target = operator.outputs[0].full_name
    container.add_node("Flatten", source, target, name=node_name, axis=1)
19 |
20 |
21 | register_converter("SklearnFlatten", convert_sklearn_flatten)
22 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/function_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_converter
5 | from ..common._apply_operation import apply_concat, apply_identity
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_function_transformer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for *FunctionTransformer*, only the identity
    (``func is None``) is supported. A single input is copied
    to the output, several inputs are concatenated along axis 1.

    :raises RuntimeError: if the model holds a transform function
    """
    if operator.raw_operator.func is not None:
        raise RuntimeError(
            "FunctionTransformer is not supported unless the "
            "transform function is None (= identity). "
            "You may raise an issue at "
            "https://github.com/onnx/sklearn-onnx/issues."
        )

    output_name = operator.outputs[0].full_name
    if len(operator.inputs) != 1:
        input_names = [variable.full_name for variable in operator.inputs]
        apply_concat(scope, input_names, output_name, container, axis=1)
        return
    apply_identity(scope, operator.inputs[0].full_name, output_name, container)
36 |
37 |
38 | register_converter("SklearnFunctionTransformer", convert_sklearn_function_transformer)
39 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/grid_search_cv.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._apply_operation import apply_identity
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 | from .._supported_operators import sklearn_operator_name_map
8 |
9 |
def convert_sklearn_grid_search_cv(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for scikit-learn's GridSearchCV.

    The best estimator is declared as a local operator which receives
    the inputs and the options of the grid search. Each of its outputs
    is copied into the corresponding output of the grid search.
    """
    grid_search = operator.raw_operator
    options = scope.get_options(grid_search)
    best_estimator = grid_search.best_estimator_
    alias = sklearn_operator_name_map[type(best_estimator)]
    best_operator = scope.declare_local_operator(alias, best_estimator)

    # The options given to the grid search are attached to the best estimator.
    container.add_options(id(best_estimator), options)
    scope.add_options(id(best_estimator), options)
    best_operator.inputs = operator.inputs

    for output in operator.outputs:
        local_output = scope.declare_local_variable(output.onnx_name, type=output.type)
        best_operator.outputs.append(local_output)
        apply_identity(scope, local_output.full_name, output.full_name, container)
29 |
30 |
# Registers the converter under alias 'SklearnGridSearchCV'.
# NOTE(review): options="passthrough" presumably means the options
# are forwarded to the best estimator - confirm in register_converter.
register_converter(
    "SklearnGridSearchCV", convert_sklearn_grid_search_cv, options="passthrough"
)
34 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/id_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_identity
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_identity(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Copies the first input into the first output
    with an identity operator named after ``CIdentity``.
    """
    source = operator.inputs[0].full_name
    target = operator.outputs[0].full_name
    node_name = scope.get_unique_operator_name("CIdentity")
    apply_identity(scope, source, target, container, operator_name=node_name)
20 |
21 |
22 | register_converter("SklearnIdentity", convert_sklearn_identity)
23 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/label_encoder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numpy as np
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 |
9 |
def convert_sklearn_label_encoder(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for *LabelEncoder*, the model becomes a node
    *LabelEncoder* (domain ``ai.onnx.ml``, version 2) mapping
    every class to its position in ``classes_``.

    Float classes are stored in ``keys_floats``, integer
    (signed or unsigned) and boolean classes in ``keys_int64s``,
    any other class is encoded in utf-8 and stored in ``keys_strings``.

    :raises RuntimeError: if the target opset for domain
        ``ai.onnx.ml`` is below 2
    """
    op = operator.raw_operator
    op_type = "LabelEncoder"
    attrs = {"name": scope.get_unique_operator_name(op_type)}
    classes = op.classes_
    if np.issubdtype(classes.dtype, np.floating):
        attrs["keys_floats"] = classes
    elif np.issubdtype(classes.dtype, np.integer) or classes.dtype == np.bool_:
        # np.integer covers signed and unsigned integers, unsigned classes
        # have no method ``encode`` and cannot go through the string branch.
        attrs["keys_int64s"] = [int(i) for i in classes]
    else:
        attrs["keys_strings"] = np.array([s.encode("utf-8") for s in classes])
    attrs["values_int64s"] = np.arange(len(classes))

    cop = container.target_opset_any_domain("ai.onnx.ml")
    if cop is not None and cop < 2:
        raise RuntimeError(
            "LabelEncoder requires at least opset 2 for domain 'ai.onnx.ml' "
            "not {}".format(cop)
        )

    container.add_node(
        op_type,
        operator.input_full_names,
        operator.output_full_names,
        op_domain="ai.onnx.ml",
        op_version=2,
        **attrs,
    )
40 |
41 |
42 | register_converter("SklearnLabelEncoder", convert_sklearn_label_encoder)
43 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/multiply_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_mul
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 | from ..common.data_types import guess_proto_type
9 |
10 |
def convert_sklearn_multiply(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Multiplies every input by the constant ``operator.operand``.
    Inputs and outputs are paired by position, one initializer holding
    the operand is created for each pair with the type of the input.
    """
    for variable, result in zip(operator.inputs, operator.outputs):
        operand_name = scope.get_unique_variable_name("operand")
        operand_type = guess_proto_type(variable.type)
        container.add_initializer(operand_name, operand_type, [], [operator.operand])

        factors = [variable.full_name, operand_name]
        apply_mul(scope, factors, result.full_name, container)
27 |
28 |
29 | register_converter("SklearnMultiply", convert_sklearn_multiply)
30 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/normaliser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 | from ..common._apply_operation import apply_normalizer
8 | from ..common.data_types import DoubleTensorType
9 | from .common import concatenate_variables
10 |
11 |
def convert_sklearn_normalizer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for scikit-learn's Normalizer.

    Several inputs are first merged into a single variable with
    ``concatenate_variables``; a single input is used directly. The
    scikit-learn norm name (``"max"``, ``"l1"``, ``"l2"``) is mapped to
    its upper-case counterpart and passed to ``apply_normalizer``.

    :raises RuntimeError: if ``raw_operator.norm`` is not one of the
        supported norm names
    """
    if len(operator.inputs) > 1:
        # Multiple input tensors are combined into one feature variable.
        feature_name = concatenate_variables(scope, operator.inputs, container)
    else:
        # No concatenation is needed, the first variable's name is used.
        feature_name = operator.inputs[0].full_name

    op = operator.raw_operator
    norm_map = {"max": "MAX", "l1": "L1", "l2": "L2"}
    if op.norm in norm_map:
        norm = norm_map[op.norm]
    else:
        # The space after "issue" was missing in the original message,
        # which rendered as "issueat https://...".
        raise RuntimeError(
            "Invalid norm '%s'. You may raise an issue "
            "at https://github.com/onnx/sklearn-onnx/"
            "issues." % op.norm
        )

    # Anything whose exact type is not DoubleTensorType is handled as float.
    use_float = type(operator.inputs[0].type) not in (DoubleTensorType,)
    apply_normalizer(
        scope,
        feature_name,
        operator.outputs[0].full_name,
        container,
        norm=norm,
        use_float=use_float,
    )
41 |
42 |
43 | register_converter("SklearnNormalizer", convert_sklearn_normalizer)
44 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/pipelines.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from sklearn.base import is_classifier
4 | from sklearn.pipeline import Pipeline
5 | from ..common._registration import register_converter
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 | from ..common._apply_operation import apply_cast
9 | from ..common.data_types import guess_proto_type
10 | from .._parse import _parse_sklearn
11 |
12 |
def convert_pipeline(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for a scikit-learn Pipeline.

    Every step is parsed in order with ``_parse_sklearn``; the outputs of
    one step are the inputs of the next one. The outputs of the last step
    are then connected to the pipeline outputs, with an ``Identity`` node
    when the pipeline output type is an instance of the step output type,
    and with a cast otherwise.

    :raises RuntimeError: if the last step does not produce as many
        outputs as the pipeline declares
    """
    model = operator.raw_operator
    inputs = operator.inputs
    for step in model.steps:
        step_model = step[1]
        if is_classifier(step_model) or isinstance(step_model, Pipeline):
            # Options are registered under the id of the step model:
            # zipmap is disabled for classifiers and nested pipelines.
            scope.add_options(id(step_model), options={"zipmap": False})
            container.add_options(id(step_model), options={"zipmap": False})
        outputs = _parse_sklearn(scope, step_model, inputs, custom_parsers=None)
        inputs = outputs

    if len(outputs) != len(operator.outputs):
        # The original message had the two counts swapped: the first
        # number must be the pipeline's, the second the last step's.
        raise RuntimeError(
            "Mismatch between pipeline output %d and "
            "last step outputs %d." % (len(operator.outputs), len(outputs))
        )

    for fr, to in zip(outputs, operator.outputs):
        if isinstance(to.type, type(fr.type)):
            container.add_node(
                "Identity",
                fr.full_name,
                to.full_name,
                name=scope.get_unique_operator_name("Id" + operator.onnx_name),
            )
        else:
            # The pipeline output type differs from the last step's
            # output type, so the value is cast to the declared type.
            apply_cast(
                scope,
                fr.full_name,
                to.full_name,
                container,
                operator_name=scope.get_unique_operator_name(
                    "Cast" + operator.onnx_name
                ),
                to=guess_proto_type(to.type),
            )
50 |
51 |
def convert_feature_union(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Placeholder converter for FeatureUnion.

    It never produces any node and always raises.

    :raises NotImplementedError: unconditionally
    """
    reason = (
        "This converter not needed so far. "
        "It is usually handled during parsing."
    )
    raise NotImplementedError(reason)
58 |
59 |
def convert_column_transformer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Placeholder converter for ColumnTransformer.

    It never produces any node and always raises.

    :raises NotImplementedError: unconditionally
    """
    reason = (
        "This converter not needed so far. "
        "It is usually handled during parsing."
    )
    raise NotImplementedError(reason)
66 |
67 |
68 | register_converter(
69 | "SklearnPipeline",
70 | convert_pipeline,
71 | options={
72 | "zipmap": [True, False, "columns"],
73 | "nocl": [True, False],
74 | "output_class_labels": [False, True],
75 | "raw_scores": [True, False],
76 | },
77 | )
78 | register_converter("SklearnFeatureUnion", convert_feature_union)
79 | register_converter("SklearnColumnTransformer", convert_column_transformer)
80 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/random_projection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from ..common._registration import register_converter
5 | from ..common.data_types import guess_numpy_type
6 | from ..common._topology import Scope, Operator
7 | from ..common._container import ModelComponentContainer
8 | from ..algebra.onnx_ops import OnnxMatMul
9 |
10 |
def convert_random_projection(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for random projections
    (registered for ``SklearnGaussianRandomProjection``).

    The output is the matrix product of the input and the transpose of
    ``components_``. The components are cast to float64 when the input
    numpy type is float64 and to float32 in every other case.
    """
    op_in = operator.inputs[0]
    op_out = operator.outputs[0].full_name
    op = operator.raw_operator
    opv = container.target_opset

    dtype = guess_numpy_type(op_in.type)
    if dtype != np.float64:
        dtype = np.float32

    y = OnnxMatMul(
        op_in, op.components_.T.astype(dtype), op_version=opv, output_names=[op_out]
    )
    y.add_to(scope, container)
27 |
28 |
29 | register_converter("SklearnGaussianRandomProjection", convert_random_projection)
30 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/random_trees_embedding.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 | from ..algebra.onnx_operator import OnnxSubEstimator
8 | from ..algebra.onnx_ops import OnnxIdentity, OnnxConcat, OnnxReshape
9 |
10 |
def convert_sklearn_random_tree_embedding(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for RandomTreesEmbedding.

    Each fitted tree is converted as a sub-estimator with the option
    ``decision_leaf`` enabled; its second output is reshaped to a single
    column. The columns are concatenated along axis 1 and fed to the
    fitted ``one_hot_encoder_``, whose result becomes the output.

    :raises RuntimeError: if the model was built with ``sparse_output``
    """
    model = operator.raw_operator
    if model.sparse_output:
        raise RuntimeError(
            "The converter cannot convert the model with sparse outputs."
        )

    data = operator.inputs[0]
    target_names = operator.outputs
    opset = container.target_opset

    columns = []
    for tree in model.estimators_:
        tree_node = OnnxSubEstimator(
            tree, data, op_version=opset, options={"decision_leaf": True}
        )
        column_shape = np.array([-1, 1], dtype=np.int64)
        columns.append(OnnxReshape(tree_node[1], column_shape, op_version=opset))

    stacked = OnnxConcat(*columns, axis=1, op_version=opset)
    encoded = OnnxSubEstimator(model.one_hot_encoder_, stacked, op_version=opset)
    final = OnnxIdentity(encoded, op_version=opset, output_names=target_names)
    final.add_to(scope, container)
36 |
37 |
38 | register_converter("SklearnRandomTreesEmbedding", convert_sklearn_random_tree_embedding)
39 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/ransac_regressor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from .._supported_operators import sklearn_operator_name_map
5 | from ..common._apply_operation import apply_identity
6 | from ..common._registration import register_converter
7 | from ..common._topology import Scope, Operator
8 | from ..common._container import ModelComponentContainer
9 |
10 |
def convert_sklearn_ransac_regressor(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for RANSACRegressor.

    The fitted ``estimator_`` is declared as a local operator reading the
    same inputs; its single output variable ``label`` has the same type
    class as the first input and is copied to the output with an
    identity.
    """
    inner_model = operator.raw_operator.estimator_
    inner_alias = sklearn_operator_name_map[type(inner_model)]
    inner_operator = scope.declare_local_operator(inner_alias, inner_model)
    inner_operator.inputs = operator.inputs

    label_type = operator.inputs[0].type.__class__()
    label_var = scope.declare_local_variable("label", label_type)
    inner_operator.outputs.append(label_var)

    destination = operator.outputs[0].full_name
    apply_identity(scope, label_var.full_name, destination, container)
28 |
29 |
30 | register_converter("SklearnRANSACRegressor", convert_sklearn_ransac_regressor)
31 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/replace_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 | from ..common.data_types import guess_proto_type
8 |
9 |
def convert_sklearn_replace_transformer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for the ``SklearnReplaceTransformer`` operator.

    Builds ``Where(Equal(X, from_value), to_value, X)``: every element
    equal to ``from_value`` is replaced by ``to_value``, the others are
    kept. Both constants are stored with the proto type of the input.
    """
    model = operator.raw_operator
    source = operator.inputs[0].full_name
    target = operator.outputs[0].full_name
    elem_type = guess_proto_type(operator.inputs[0].type)

    # The name hints ("nan_name", "zero_name") are kept as they are even
    # though the constants hold arbitrary to/from values.
    replacement = scope.get_unique_variable_name("nan_name")
    container.add_initializer(replacement, elem_type, [1], [model.to_value])
    searched = scope.get_unique_variable_name("zero_name")
    container.add_initializer(searched, elem_type, [1], [model.from_value])

    matches = scope.get_unique_variable_name("mask_name")
    equal_node = scope.get_unique_operator_name("Equal")
    container.add_node("Equal", [source, searched], matches, name=equal_node)

    where_node = scope.get_unique_operator_name("Where")
    container.add_node(
        "Where", [matches, replacement, source], target, name=where_node
    )
38 |
39 |
40 | register_converter("SklearnReplaceTransformer", convert_sklearn_replace_transformer)
41 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/sequence.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..proto import onnx_proto
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 |
8 |
def convert_sklearn_sequence_at(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for the ``SklearnSequenceAt`` operator.

    Stores ``operator.index`` as an INT64 initializer declared with an
    empty shape and adds a ``SequenceAt`` node reading the first input
    at that position.
    """
    position = operator.index
    position_name = scope.get_unique_variable_name("seq_at%d" % position)
    int64 = onnx_proto.TensorProto.INT64
    container.add_initializer(position_name, int64, [], [position])

    sequence_name = operator.inputs[0].full_name
    node_name = scope.get_unique_operator_name("SequenceAt%d" % position)
    container.add_node(
        "SequenceAt",
        [sequence_name, position_name],
        operator.outputs[0].full_name,
        name=node_name,
    )
21 |
22 |
def convert_sklearn_sequence_construct(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for the ``SklearnSequenceConstruct`` operator.

    Adds a single ``SequenceConstruct`` node taking every input of the
    operator, in order, and writing to its first output.
    """
    element_names = []
    for variable in operator.inputs:
        element_names.append(variable.full_name)

    sequence_name = operator.outputs[0].full_name
    node_name = scope.get_unique_operator_name("SequenceConstruct")
    container.add_node(
        "SequenceConstruct", element_names, sequence_name, name=node_name
    )
32 |
33 |
34 | register_converter("SklearnSequenceAt", convert_sklearn_sequence_at)
35 | register_converter("SklearnSequenceConstruct", convert_sklearn_sequence_construct)
36 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/sgd_oneclass_svm.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._apply_operation import apply_cast, apply_sub
5 | from ..common.data_types import BooleanTensorType, Int64TensorType, guess_proto_type
6 | from ..common._registration import register_converter
7 | from ..common._topology import Scope, Operator
8 | from ..common._container import ModelComponentContainer
9 | from ..proto import onnx_proto
10 |
11 |
def convert_sklearn_sgd_oneclass_svm(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for SGDOneClassSVM.

    The second output is ``X @ coef_.T - offset_``; the first output is
    the sign of that value cast to INT64. Constants are stored as DOUBLE
    when the input proto type is DOUBLE and as FLOAT otherwise; boolean
    and int64 inputs are cast to that type first.
    """
    svm = operator.raw_operator
    weights = svm.coef_.T
    shift = svm.offset_
    label_name, score_name = operator.output_full_names[:2]
    features = operator.inputs[0].full_name

    float_type = guess_proto_type(operator.inputs[0].type)
    if float_type != onnx_proto.TensorProto.DOUBLE:
        float_type = onnx_proto.TensorProto.FLOAT

    needs_cast = isinstance(
        operator.inputs[0].type, (BooleanTensorType, Int64TensorType)
    )
    if needs_cast:
        features = scope.get_unique_variable_name("cast_input")
        apply_cast(
            scope, operator.input_full_names, features, container, to=float_type
        )

    weights_name = scope.get_unique_variable_name("coef")
    container.add_initializer(
        weights_name, float_type, weights.shape, weights.ravel()
    )

    shift_name = scope.get_unique_variable_name("offset")
    container.add_initializer(shift_name, float_type, shift.shape, shift)

    product_name = scope.get_unique_variable_name("matmul_result")
    container.add_node(
        "MatMul",
        [features, weights_name],
        product_name,
        name=scope.get_unique_operator_name("MatMul"),
    )

    apply_sub(scope, [product_name, shift_name], score_name, container, broadcast=0)

    sign_name = scope.get_unique_variable_name("class_prediction")
    container.add_node("Sign", score_name, sign_name, op_version=9)
    apply_cast(
        scope, sign_name, label_name, container, to=onnx_proto.TensorProto.INT64
    )
57 |
58 |
59 | register_converter("SklearnSGDOneClassSVM", convert_sklearn_sgd_oneclass_svm)
60 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/tfidf_vectoriser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from onnx import onnx_pb as onnx_proto
4 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
5 | from ..common._apply_operation import apply_identity
6 | from ..common.data_types import FloatTensorType, DoubleTensorType, guess_proto_type
7 | from ..common._registration import register_converter
8 | from ..common._topology import Scope, Operator
9 | from ..common._container import ModelComponentContainer
10 | from .._supported_operators import sklearn_operator_name_map
11 |
12 |
def convert_sklearn_tfidf_vectoriser(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for scikit-learn's TfidfVectorizer.

    Two local operators are declared around the same raw model: the one
    registered for ``CountVectorizer``, reading the inputs, followed by
    the one registered for ``TfidfTransformer``. The result of the
    second one is copied to the output with an identity.

    The intermediate variables are double tensors when the input proto
    type is DOUBLE and float tensors in every other case.
    """
    tfidf_op = operator.raw_operator
    op_type = sklearn_operator_name_map[CountVectorizer]
    cv_operator = scope.declare_local_operator(op_type, tfidf_op)
    cv_operator.inputs = operator.inputs

    # Number of columns: largest vocabulary index plus one.
    columns = max(tfidf_op.vocabulary_.values()) + 1

    # Only two tensor classes are possible, so the previous
    # "unexpected dtype" error branch could never be reached.
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype == onnx_proto.TensorProto.DOUBLE:
        clr = DoubleTensorType
    else:
        clr = FloatTensorType

    cv_output_name = scope.declare_local_variable(
        "count_vec_output", clr([None, columns])
    )
    cv_operator.outputs.append(cv_output_name)

    op_type = sklearn_operator_name_map[TfidfTransformer]
    tfidf_operator = scope.declare_local_operator(op_type, tfidf_op)
    tfidf_operator.inputs.append(cv_output_name)
    tfidf_output_name = scope.declare_local_variable("tfidf_output", clr())
    tfidf_operator.outputs.append(tfidf_output_name)

    apply_identity(
        scope, tfidf_output_name.full_name, operator.outputs[0].full_name, container
    )
49 |
50 |
51 | register_converter(
52 | "SklearnTfidfVectorizer",
53 | convert_sklearn_tfidf_vectoriser,
54 | options={
55 | "tokenexp": None,
56 | "separators": None,
57 | "nan": [True, False],
58 | "keep_empty_string": [True, False],
59 | "locale": None,
60 | },
61 | )
62 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/tuned_threshold_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_converter
4 | from ..common._topology import Scope, Operator
5 | from ..common._container import ModelComponentContainer
6 | from ..common.data_types import Int64TensorType
7 | from .._supported_operators import sklearn_operator_name_map
8 |
9 |
def convert_sklearn_tuned_threshold_classifier(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter for TunedThresholdClassifierCV.

    The fitted ``estimator_`` is declared as a local operator on the same
    inputs with two outputs, an int64 label and a second variable typed
    like the second output of this operator. Both are forwarded to the
    outputs with ``Identity`` nodes.
    """
    inner_model = operator.raw_operator.estimator_
    inner_alias = sklearn_operator_name_map[type(inner_model)]

    inner_operator = scope.declare_local_operator(inner_alias, inner_model)
    inner_operator.inputs = operator.inputs

    label_var = scope.declare_local_variable("label_tuned", Int64TensorType())
    proba_type = operator.outputs[1].type.__class__()
    proba_var = scope.declare_local_variable("proba_tuned", proba_type)
    inner_operator.outputs.append(label_var)
    inner_operator.outputs.append(proba_var)

    # Output 0 receives the label, output 1 the second local variable.
    for position, local_var in enumerate((label_var, proba_var)):
        container.add_node(
            "Identity",
            [local_var.onnx_name],
            [operator.outputs[position].full_name],
        )
32 |
33 |
34 | register_converter(
35 | "SklearnTunedThresholdClassifierCV",
36 | convert_sklearn_tuned_threshold_classifier,
37 | options={
38 | "zipmap": [True, False, "columns"],
39 | "output_class_labels": [False, True],
40 | "nocl": [True, False],
41 | },
42 | )
43 |
--------------------------------------------------------------------------------
/skl2onnx/operator_converters/voting_regressor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_converter
5 | from ..common._topology import Scope, Operator
6 | from ..common._container import ModelComponentContainer
7 | from ..common._apply_operation import apply_mul
8 | from ..common.data_types import guess_proto_type
9 | from .._supported_operators import sklearn_operator_name_map
10 |
11 |
def convert_voting_regressor(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converts a *VotingRegressor* into *ONNX* format.

    Every estimator of ``estimators_`` that is not None is declared as a
    local operator on the same inputs. Its prediction is multiplied by
    its normalized weight (or by ``1 / len(estimators_)`` when no weights
    were given), flattened, and all contributions are added with a
    ``Sum`` node.
    """
    op = operator.raw_operator
    proto_dtype = guess_proto_type(operator.outputs[0].type)

    # The builtin sum works for lists and tuples as well as numpy arrays,
    # whereas ``weights.sum()`` only exists on arrays. Computed once.
    weights = op.weights
    total_weight = None if weights is None else sum(weights)

    vars_names = []
    for i, estimator in enumerate(op.estimators_):
        if estimator is None:
            continue

        op_type = sklearn_operator_name_map[type(estimator)]

        this_operator = scope.declare_local_operator(op_type, estimator)
        this_operator.inputs = operator.inputs

        var_name = scope.declare_local_variable(
            "var_%d" % i, operator.outputs[0].type.__class__()
        )
        this_operator.outputs.append(var_name)
        var_name = var_name.onnx_name

        if weights is not None:
            val = weights[i] / total_weight
        else:
            val = 1.0 / len(op.estimators_)

        weights_name = scope.get_unique_variable_name("w%d" % i)
        container.add_initializer(weights_name, proto_dtype, [1], [val])
        wvar_name = scope.get_unique_variable_name("wvar_%d" % i)
        apply_mul(scope, [var_name, weights_name], wvar_name, container, broadcast=1)

        flat_name = scope.get_unique_variable_name("fvar_%d" % i)
        container.add_node("Flatten", wvar_name, flat_name)
        vars_names.append(flat_name)

    container.add_node(
        "Sum",
        vars_names,
        operator.outputs[0].full_name,
        name=scope.get_unique_operator_name("Sum"),
    )
57 |
58 |
59 | register_converter("SklearnVotingRegressor", convert_voting_regressor)
60 |
--------------------------------------------------------------------------------
/skl2onnx/proto/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | # Rather than using ONNX protobuf definition throughout our codebase,
5 | # we import ONNX protobuf definition here so that we can conduct quick
6 | # fixes by overwriting ONNX functions without changing any lines
7 | # elsewhere.
8 | from onnx import onnx_pb as onnx_proto
9 | from onnx import defs
10 |
11 | # Overwrite the make_tensor defined in onnx.helper because of a bug
12 | # (string tensor get assigned twice)
13 | from onnx.onnx_pb import TensorProto, ValueInfoProto
14 |
15 | try: # noqa: SIM105
16 | from onnx.onnx_pb import SparseTensorProto
17 | except ImportError:
18 | # onnx is too old.
19 | pass
20 |
21 |
def get_opset_number_from_onnx():
    """
    Gives the value of ``onnx.defs.onnx_opset_version()``, the latest
    opset version supported by the installed *onnx* package.
    """
    latest_opset = defs.onnx_opset_version()
    return latest_opset
28 |
29 |
def get_latest_tested_opset_version():
    """
    Returns the smaller of two opset versions: the package-level
    ``__max_supported_opset__`` and the opset version reported by
    :func:`get_opset_number_from_onnx`.
    """
    # Imported here rather than at module level, as in the original code.
    from .. import __max_supported_opset__

    candidates = (__max_supported_opset__, get_opset_number_from_onnx())
    return min(candidates)
41 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/array_feature_extractor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_array_feature_extractor(operator):
    """
    Shape calculator for ``SklearnArrayFeatureExtractor``.

    The output has the same type class as the first input, the same
    first dimension, and one column per entry of
    ``operator.column_indices``.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    source = operator.inputs[0]
    n_rows = source.get_first_dimension()
    n_cols = len(operator.column_indices)
    tensor_cls = source.type.__class__
    operator.outputs[0].type = tensor_cls([n_rows, n_cols])
14 |
15 |
16 | register_shape_calculator(
17 | "SklearnArrayFeatureExtractor", calculate_sklearn_array_feature_extractor
18 | )
19 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/cast_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 | from ..common.data_types import _guess_numpy_type
7 | from ..common.shape_calculator import calculate_linear_regressor_output_shapes
8 |
9 |
def calculate_sklearn_cast(operator):
    """
    Shape calculator for ``SklearnCast``: only checks that the operator
    has one input and one output, no type or shape is changed.
    """
    check_input_and_output_numbers(
        operator,
        input_count_range=1,
        output_count_range=1,
    )
12 |
13 |
def calculate_sklearn_cast_transformer(operator):
    """
    Shape calculator for ``SklearnCastTransformer``.

    The output type is derived by ``_guess_numpy_type`` from the
    transformer's ``dtype`` and the shape of the first input.
    """
    check_input_and_output_numbers(
        operator,
        input_count_range=1,
        output_count_range=1,
    )
    target_dtype = operator.raw_operator.dtype
    input_shape = operator.inputs[0].type.shape
    operator.outputs[0].type = _guess_numpy_type(target_dtype, input_shape)
19 |
20 |
21 | register_shape_calculator("SklearnCast", calculate_sklearn_cast)
22 | register_shape_calculator("SklearnCastTransformer", calculate_sklearn_cast_transformer)
23 | register_shape_calculator(
24 | "SklearnCastRegressor", calculate_linear_regressor_output_shapes
25 | )
26 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/class_labels.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_class_labels(operator):
    """
    Shape calculator for ``SklearnClassLabels``: only checks that the
    operator has a single output, no type or shape is changed.
    """
    expected_outputs = 1
    check_input_and_output_numbers(operator, output_count_range=expected_outputs)
10 |
11 |
12 | register_shape_calculator("SklearnClassLabels", calculate_sklearn_class_labels)
13 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/cross_decomposition.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
6 | from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
7 |
8 |
def calculate_pls_regression_output_shapes(operator):
    """
    Shape calculator for PLSRegression.

    Accepts one float, int64 or double tensor with two dimensions. The
    output is a double tensor for a double input and a float tensor
    otherwise, shaped ``[N, coef_.shape[1]]``.

    :raises RuntimeError: if the input does not have exactly two
        dimensions
    """
    accepted = [FloatTensorType, Int64TensorType, DoubleTensorType]
    check_input_and_output_numbers(operator, input_count_range=1)
    check_input_and_output_types(operator, good_input_types=accepted)

    source = operator.inputs[0]
    if len(source.type.shape) != 2:
        raise RuntimeError("Input must be a [N, C]-tensor")

    model = operator.raw_operator
    if source.type.__class__ != DoubleTensorType:
        out_cls = FloatTensorType
    else:
        out_cls = source.type.__class__
    n_rows = source.get_first_dimension()
    n_targets = model.coef_.shape[1]
    operator.outputs[0].type = out_cls([n_rows, n_targets])
24 |
25 |
26 | register_shape_calculator(
27 | "SklearnPLSRegression", calculate_pls_regression_output_shapes
28 | )
29 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/dict_vectorizer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_dict_vectorizer_output_shapes(operator):
    """
    Shape calculator for DictVectorizer.

    Expects one input and one output. The output shape is set to
    ``[None, C]`` where C is the number of entries in the model's
    ``feature_names_``.
    """
    check_input_and_output_numbers(
        operator,
        input_count_range=1,
        output_count_range=1,
    )
    n_features = len(operator.raw_operator.feature_names_)
    result = operator.outputs[0]
    result.type.shape = [None, n_features]
18 |
19 |
20 | register_shape_calculator(
21 | "SklearnDictVectorizer", calculate_sklearn_dict_vectorizer_output_shapes
22 | )
23 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/feature_hasher.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from ..common.data_types import (
5 | StringTensorType,
6 | Int64TensorType,
7 | FloatTensorType,
8 | DoubleTensorType,
9 | )
10 | from ..common._registration import register_shape_calculator
11 | from ..common.utils import check_input_and_output_numbers
12 | from ..common.utils import check_input_and_output_types
13 |
14 |
def calculate_sklearn_feature_hasher(operator):
    """
    Shape calculator for FeatureHasher.

    Accepts string or int64 tensors. The output is shaped
    ``[N, n_features]`` and typed after the model's ``dtype``: float32
    gives a float tensor, float64 a double tensor, and int32, uint32 or
    int64 an int64 tensor.

    :raises RuntimeError: for any other ``dtype``
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator, good_input_types=[StringTensorType, Int64TensorType]
    )

    n_rows = operator.inputs[0].get_first_dimension()
    hasher = operator.raw_operator
    out_shape = [n_rows, hasher.n_features]

    # Checked in order; the first group containing the dtype wins.
    dispatch = (
        ((np.float32,), FloatTensorType),
        ((np.float64,), DoubleTensorType),
        ((np.int32, np.uint32, np.int64), Int64TensorType),
    )
    for dtypes, tensor_cls in dispatch:
        if hasher.dtype in dtypes:
            operator.outputs[0].type = tensor_cls(shape=out_shape)
            return

    raise RuntimeError(
        f"Converter is not implemented for FeatureHasher.dtype={hasher.dtype}."
    )
34 |
35 |
36 | register_shape_calculator("SklearnFeatureHasher", calculate_sklearn_feature_hasher)
37 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/feature_selection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_select(operator):
    """
    Shape calculator shared by the feature selectors.

    The output has the same type class as the first input, the same
    first dimension, and as many columns as the sum of the mask returned
    by the model's ``get_support()``.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    source = operator.inputs[0]
    n_rows = source.get_first_dimension()
    support = operator.raw_operator.get_support()
    n_kept = support.sum()
    operator.outputs[0].type = source.type.__class__([n_rows, n_kept])
14 |
15 |
16 | register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_select)
17 | register_shape_calculator("SklearnRFE", calculate_sklearn_select)
18 | register_shape_calculator("SklearnRFECV", calculate_sklearn_select)
19 | register_shape_calculator("SklearnSelectFdr", calculate_sklearn_select)
20 | register_shape_calculator("SklearnSelectFpr", calculate_sklearn_select)
21 | register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_select)
22 | register_shape_calculator("SklearnSelectFwe", calculate_sklearn_select)
23 | register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select)
24 | register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_select)
25 | register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_select)
26 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/flatten.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatType, Int64Type, StringType, TensorType
6 | from ..common.utils import check_input_and_output_numbers
7 |
8 |
def calculate_sklearn_flatten(operator):
    """
    Shape calculator for ``SklearnFlatten``.

    Expects one input and one output. With N the first dimension of the
    input and C its number of columns (``shape[1]`` for tensors, 1 for
    int64, float and string scalar types, unknown otherwise), the output
    shape is:

    * ``[N, C]`` when C is unknown (unchanged historical behaviour),
    * ``[None]`` when C is known but N is not,
    * ``[N * C]`` when both are known.
    """
    check_input_and_output_numbers(operator, output_count_range=1, input_count_range=1)
    i = operator.inputs[0]
    N = i.get_first_dimension()

    if isinstance(i.type, TensorType):
        # May be None when the number of columns is unknown.
        C = i.type.shape[1]
    elif isinstance(i.type, (Int64Type, FloatType, StringType)):
        C = 1
    else:
        C = None

    if C is None:
        operator.outputs[0].type.shape = [N, C]
    elif N is None:
        # ``None * C`` would raise a TypeError: with an unknown number of
        # rows, the flattened length is unknown as well.
        operator.outputs[0].type.shape = [None]
    else:
        operator.outputs[0].type.shape = [N * C]
26 |
27 |
28 | register_shape_calculator("SklearnFlatten", calculate_sklearn_flatten)
29 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/function_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import copy
5 | from ..common._registration import register_shape_calculator
6 |
7 |
def calculate_sklearn_function_transformer_output_shapes(operator):
    """
    Shape calculator for FunctionTransformer, supported only when its
    ``func`` is None (identity), in which case it merges columns.

    The output type is a deep copy of the first input type, shaped
    ``[N, C]`` where N is the first dimension of the first input and C
    the sum of the second dimensions of all inputs, or None as soon as
    one of them is unknown.

    :raises RuntimeError: if ``func`` is not None
    """
    if operator.raw_operator.func is not None:
        raise RuntimeError(
            "FunctionTransformer is not supported unless the "
            "transform function is None (= identity). "
            "You may raise an issue at "
            "https://github.com/onnx/sklearn-onnx/issues."
        )

    first = operator.inputs[0]
    n_rows = first.get_first_dimension()

    n_cols = 0
    for variable in operator.inputs:
        width = variable.type.shape[1]
        if width is None:
            # One unknown width makes the total unknown; the remaining
            # inputs are not inspected.
            n_cols = None
            break
        n_cols += width

    merged_type = copy.deepcopy(first.type)
    merged_type.shape = [n_rows, n_cols]
    operator.outputs[0].type = merged_type
31 |
32 |
33 | register_shape_calculator(
34 | "SklearnFunctionTransformer", calculate_sklearn_function_transformer_output_shapes
35 | )
36 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/gaussian_process.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.shape_calculator import calculate_linear_classifier_output_shapes
6 | from ..common.data_types import FloatTensorType, DoubleTensorType
7 | from ..common.utils import check_input_and_output_types
8 |
9 |
def calculate_sklearn_gaussian_process_regressor_shape(operator):
    """
    Shape calculator for GaussianProcessRegressor.

    The first output (the mean) receives the shape ``[N, dim]`` where
    ``N`` is the first dimension of the single input and ``dim`` is the
    second dimension of the fitted ``y_train_`` (1 when ``y_train_`` is
    missing, None or one-dimensional). The shape of the optional second
    output (covariance or standard deviation) is not modified here.

    Raises RuntimeError if the operator does not have exactly one input
    or if it has neither one nor two outputs.
    """
    check_input_and_output_types(
        operator,
        good_input_types=[FloatTensorType, DoubleTensorType],
        good_output_types=[FloatTensorType, DoubleTensorType],
    )
    if len(operator.inputs) != 1:
        raise RuntimeError(
            "Only one input vector is allowed for GaussianProcessRegressor."
        )
    if len(operator.outputs) not in (1, 2):
        # The message now matches the check: both one and two outputs
        # are accepted.
        raise RuntimeError(
            "One or two outputs are expected for GaussianProcessRegressor."
        )

    variable = operator.inputs[0]

    N = variable.get_first_dimension()
    op = operator.raw_operator

    # Output 1 is mean
    # Output 2 is cov or std
    y_train = getattr(op, "y_train_", None)
    if y_train is None or len(y_train.shape) == 1:
        dim = 1
    else:
        dim = y_train.shape[1]
    operator.outputs[0].type.shape = [N, dim]
35 |
36 |
37 | register_shape_calculator(
38 | "SklearnGaussianProcessRegressor",
39 | calculate_sklearn_gaussian_process_regressor_shape,
40 | )
41 | register_shape_calculator(
42 | "SklearnGaussianProcessClassifier", calculate_linear_classifier_output_shapes
43 | )
44 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/grid_search_cv.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import logging
4 | from ..common._registration import register_shape_calculator, get_shape_calculator
5 | from .._supported_operators import sklearn_operator_name_map
6 |
7 |
def convert_sklearn_grid_search_cv(operator):
    """
    Shape calculator for GridSearchCV.

    Looks up the alias of the type of ``best_estimator_``, wraps the best
    estimator in a new raw operator, runs the shape calculator registered
    for that alias on it and adopts its outputs. When no alias is known
    for the estimator type, a warning is logged and the operator is left
    unchanged.
    """
    best = operator.raw_operator.best_estimator_
    estimator_type = type(best)
    alias = sklearn_operator_name_map.get(estimator_type, None)

    if alias is not None:
        delegate = operator.new_raw_operator(best, alias)
        get_shape_calculator(alias)(delegate)
        operator.outputs = delegate.outputs
        return

    logging.getLogger("skl2onnx").warning(
        "[convert_sklearn_grid_search_cv] failed to find alias "
        "to model type %r.",
        estimator_type,
    )
24 |
25 |
26 | register_shape_calculator("SklearnGridSearchCV", convert_sklearn_grid_search_cv)
27 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/identity.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_identity(operator):
    """
    Shape calculator for the identity operator.

    Expects exactly one input and one output and makes the output refer
    to the very same type object as the input (no copy is made).
    """
    check_input_and_output_numbers(
        operator, input_count_range=1, output_count_range=1
    )
    source = operator.inputs[0]
    target = operator.outputs[0]
    target.type = source.type
11 |
12 |
13 | register_shape_calculator("SklearnIdentity", calculate_sklearn_identity)
14 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/imputer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import (
6 | FloatTensorType,
7 | Int64TensorType,
8 | DoubleTensorType,
9 | StringTensorType,
10 | )
11 | from ..common.utils import check_input_and_output_numbers
12 | from ..common.utils import check_input_and_output_types
13 |
14 |
def calculate_sklearn_imputer_output_shapes(operator):
    """
    Allowed input/output patterns are
        1. [N, C_1], ..., [N, C_n] ---> [N, C_1 + ... + C_n]

    Several inputs may be received; they are treated as concatenated
    along the C-axis and the resulting shape becomes the output shape.
    Every input must have the same tensor type as the output, otherwise
    a RuntimeError is raised. The output width is None as soon as one
    input width is unknown.
    """
    check_input_and_output_numbers(
        operator, input_count_range=[1, None], output_count_range=1
    )
    accepted = [
        FloatTensorType,
        Int64TensorType,
        DoubleTensorType,
        StringTensorType,
    ]
    check_input_and_output_types(operator, good_input_types=accepted)

    result = operator.outputs[0]
    result_cls = type(result.type)
    for candidate in operator.inputs:
        if isinstance(candidate.type, result_cls):
            continue
        raise RuntimeError(
            "Inputs and outputs should have the same type "
            "%r != %r." % (type(candidate.type), result_cls)
        )

    n_rows = operator.inputs[0].get_first_dimension()
    n_cols = 0
    for candidate in operator.inputs:
        width = candidate.type.shape[1]
        if width is None:
            n_cols = None
            break
        n_cols += width

    result.type.shape = [n_rows, n_cols]
54 |
55 |
56 | register_shape_calculator("SklearnImputer", calculate_sklearn_imputer_output_shapes)
57 | register_shape_calculator(
58 | "SklearnSimpleImputer", calculate_sklearn_imputer_output_shapes
59 | )
60 | register_shape_calculator("SklearnBinarizer", calculate_sklearn_imputer_output_shapes)
61 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/isolation_forest.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.data_types import Int64TensorType
5 |
6 |
def calculate_isolation_forest_output_shapes(operator):
    """
    Shape calculator for IsolationForest.

    The first output becomes an ``Int64TensorType`` of shape ``[N, 1]``
    and the second output keeps its type but gets the shape ``[N, 1]``,
    ``N`` being the first dimension of the first input.
    """
    n_rows = operator.inputs[0].get_first_dimension()
    outputs = operator.outputs
    outputs[0].type = Int64TensorType([n_rows, 1])
    outputs[1].type.shape = [n_rows, 1]
11 |
12 |
13 | register_shape_calculator(
14 | "SklearnIsolationForest", calculate_isolation_forest_output_shapes
15 | )
16 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/k_bins_discretiser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
5 | from ..common._registration import register_shape_calculator
6 | from ..common.utils import check_input_and_output_numbers
7 | from ..common.utils import check_input_and_output_types
8 |
9 |
def calculate_sklearn_k_bins_discretiser(operator):
    """
    Shape calculator for KBinsDiscretizer.

    The output shape is ``[M, N]`` where ``M`` is the first dimension of
    the first input. ``N`` is the number of entries of ``n_bins_`` when
    the encoding is ``"ordinal"`` and the sum of ``n_bins_`` for any
    other encoding.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator, good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType]
    )

    n_rows = operator.inputs[0].get_first_dimension()
    discretizer = operator.raw_operator
    if discretizer.encode == "ordinal":
        n_cols = len(discretizer.n_bins_)
    else:
        n_cols = sum(discretizer.n_bins_)
    operator.outputs[0].type.shape = [n_rows, n_cols]
20 |
21 |
22 | register_shape_calculator(
23 | "SklearnKBinsDiscretizer", calculate_sklearn_k_bins_discretiser
24 | )
25 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/k_means.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
6 | from ..common.utils import check_input_and_output_types
7 |
8 |
def calculate_sklearn_kmeans_output_shapes(operator):
    """
    Shape calculator for KMeans and MiniBatchKMeans.

    Requires exactly one input and two outputs. The first output gets the
    shape ``[N]`` and the second one ``[N, n_clusters]`` where ``N`` is
    the first dimension of the input.
    """
    numeric_tensors = [Int64TensorType, FloatTensorType, DoubleTensorType]
    check_input_and_output_types(
        operator,
        good_input_types=numeric_tensors,
        good_output_types=numeric_tensors,
    )
    if len(operator.inputs) != 1:
        raise RuntimeError("Only one input vector is allowed for KMeans.")
    if len(operator.outputs) != 2:
        raise RuntimeError("Two outputs are expected for KMeans.")

    n_rows = operator.inputs[0].get_first_dimension()
    model = operator.raw_operator
    # Exactly two outputs are guaranteed by the check above.
    first_out, second_out = operator.outputs
    first_out.type.shape = [n_rows]
    second_out.type.shape = [n_rows, model.n_clusters]
25 |
26 |
27 | register_shape_calculator("SklearnKMeans", calculate_sklearn_kmeans_output_shapes)
28 | register_shape_calculator(
29 | "SklearnMiniBatchKMeans", calculate_sklearn_kmeans_output_shapes
30 | )
31 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/kernel_pca.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType, DoubleTensorType
6 | from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
7 |
8 |
def calculate_sklearn_kernel_pca_output_shapes(operator):
    """
    Shape calculator for KernelPCA.

    The output shape is ``[N, C]`` where ``N`` is the first dimension of
    the input and ``C`` is the length of ``eigenvalues_`` when the model
    has that attribute, or of ``lambdas_`` otherwise.
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    float_tensors = [FloatTensorType, DoubleTensorType]
    check_input_and_output_types(
        operator,
        good_input_types=float_tensors,
        good_output_types=float_tensors,
    )
    n_rows = operator.inputs[0].get_first_dimension()
    model = operator.raw_operator
    if hasattr(model, "eigenvalues_"):
        spectrum = model.eigenvalues_
    else:
        spectrum = model.lambdas_
    operator.outputs[0].type.shape = [n_rows, spectrum.shape[0]]
21 |
22 |
def calculate_sklearn_kernel_centerer_output_shapes(operator):
    """
    Shape calculator for KernelCenterer.

    The output shape is ``[N, C]`` where ``N`` is the first dimension of
    the input and ``C`` is the length of the fitted ``K_fit_rows_``.
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    float_tensors = [FloatTensorType, DoubleTensorType]
    check_input_and_output_types(
        operator,
        good_input_types=float_tensors,
        good_output_types=float_tensors,
    )
    n_rows = operator.inputs[0].get_first_dimension()
    n_cols = operator.raw_operator.K_fit_rows_.shape[0]
    operator.outputs[0].type.shape = [n_rows, n_cols]
33 |
34 |
35 | register_shape_calculator(
36 | "SklearnKernelCenterer", calculate_sklearn_kernel_centerer_output_shapes
37 | )
38 | register_shape_calculator(
39 | "SklearnKernelPCA", calculate_sklearn_kernel_pca_output_shapes
40 | )
41 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/label_binariser.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import Int64TensorType, StringTensorType
6 | from ..common.utils import check_input_and_output_numbers
7 | from ..common.utils import check_input_and_output_types
8 |
9 |
def calculate_sklearn_label_binariser_output_shapes(operator):
    """
    Shape calculator for LabelBinarizer.

    The output is an ``Int64TensorType`` of shape ``[N, n_classes]``
    where ``N`` is the first dimension of the first input and
    ``n_classes`` the length of the fitted ``classes_``.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator, good_input_types=[Int64TensorType, StringTensorType]
    )

    n_rows = operator.inputs[0].get_first_dimension()
    n_classes = len(operator.raw_operator.classes_)
    operator.outputs[0].type = Int64TensorType([n_rows, n_classes])
18 |
19 |
20 | register_shape_calculator(
21 | "SklearnLabelBinarizer", calculate_sklearn_label_binariser_output_shapes
22 | )
23 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/label_encoder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import copy
5 | from ..common._registration import register_shape_calculator
6 | from ..common.data_types import FloatTensorType
7 | from ..common.data_types import Int64TensorType, StringTensorType
8 | from ..common.utils import check_input_and_output_numbers
9 | from ..common.utils import check_input_and_output_types
10 |
11 |
def calculate_sklearn_label_encoder_output_shapes(operator):
    """
    Shape calculator for LabelEncoder.

    The label encoder only alters the values of its input, not the
    shape, so the output is an ``Int64TensorType`` built from an
    independent copy of the first input's shape.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    accepted = [FloatTensorType, Int64TensorType, StringTensorType]
    check_input_and_output_types(operator, good_input_types=accepted)

    # Deep copy so that the output shape never aliases the input shape.
    shape_copy = copy.deepcopy(operator.inputs[0].type.shape)
    operator.outputs[0].type = Int64TensorType(shape_copy)
24 |
25 |
26 | register_shape_calculator(
27 | "SklearnLabelEncoder", calculate_sklearn_label_encoder_output_shapes
28 | )
29 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/linear_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.shape_calculator import calculate_linear_classifier_output_shapes
6 |
7 |
8 | register_shape_calculator(
9 | "SklearnLinearClassifier", calculate_linear_classifier_output_shapes
10 | )
11 | register_shape_calculator("SklearnLinearSVC", calculate_linear_classifier_output_shapes)
12 | register_shape_calculator(
13 | "SklearnAdaBoostClassifier", calculate_linear_classifier_output_shapes
14 | )
15 | register_shape_calculator(
16 | "SklearnBaggingClassifier", calculate_linear_classifier_output_shapes
17 | )
18 | register_shape_calculator(
19 | "SklearnBernoulliNB", calculate_linear_classifier_output_shapes
20 | )
21 | register_shape_calculator(
22 | "SklearnCategoricalNB", calculate_linear_classifier_output_shapes
23 | )
24 | register_shape_calculator(
25 | "SklearnComplementNB", calculate_linear_classifier_output_shapes
26 | )
27 | register_shape_calculator(
28 | "SklearnGaussianNB", calculate_linear_classifier_output_shapes
29 | )
30 | register_shape_calculator(
31 | "SklearnMultinomialNB", calculate_linear_classifier_output_shapes
32 | )
33 | register_shape_calculator(
34 | "SklearnCalibratedClassifierCV", calculate_linear_classifier_output_shapes
35 | )
36 | register_shape_calculator(
37 | "SklearnMLPClassifier", calculate_linear_classifier_output_shapes
38 | )
39 | register_shape_calculator(
40 | "SklearnSGDClassifier", calculate_linear_classifier_output_shapes
41 | )
42 | register_shape_calculator(
43 | "SklearnStackingClassifier", calculate_linear_classifier_output_shapes
44 | )
45 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/linear_regressor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
6 | from ..common.shape_calculator import calculate_linear_regressor_output_shapes
7 | from ..common.data_types import (
8 | BooleanTensorType,
9 | DoubleTensorType,
10 | FloatTensorType,
11 | Int64TensorType,
12 | )
13 |
14 |
def calculate_bayesian_ridge_output_shapes(operator):
    """
    Allowed input/output patterns are
        1. [N, C] ---> [N, 1]
        2. [N, C] ---> [N, 1], [N, 1]

    This operator produces a prediction for every example in a batch.
    If the input batch size is N, the first output shape is [N, 1], or
    [N, coef_.shape[1]] when the fitted ``coef_`` has more than one
    dimension. When a second output is declared it receives the shape
    [N, 1] as well.

    The output tensor class follows the input when the input is a float
    or a double tensor and is ``FloatTensorType`` otherwise.
    """
    check_input_and_output_numbers(
        operator, input_count_range=1, output_count_range=[1, 2]
    )
    check_input_and_output_types(
        operator,
        good_input_types=[
            BooleanTensorType,
            DoubleTensorType,
            FloatTensorType,
            Int64TensorType,
        ],
    )

    inp0 = operator.inputs[0].type
    if isinstance(inp0, (FloatTensorType, DoubleTensorType)):
        cls_type = inp0.__class__
    else:
        cls_type = FloatTensorType

    N = operator.inputs[0].get_first_dimension()
    if (
        hasattr(operator.raw_operator, "coef_")
        and len(operator.raw_operator.coef_.shape) > 1
    ):
        operator.outputs[0].type = cls_type([N, operator.raw_operator.coef_.shape[1]])
    else:
        operator.outputs[0].type = cls_type([N, 1])

    # The number of inputs is validated to be exactly 1 above, so the
    # optional second tensor can only be detected through the outputs.
    if len(operator.outputs) == 2:
        # option return_std is True
        operator.outputs[1].type = cls_type([N, 1])
55 |
56 |
57 | register_shape_calculator(
58 | "SklearnAdaBoostRegressor", calculate_linear_regressor_output_shapes
59 | )
60 | register_shape_calculator(
61 | "SklearnBaggingRegressor", calculate_linear_regressor_output_shapes
62 | )
63 | register_shape_calculator(
64 | "SklearnBayesianRidge", calculate_bayesian_ridge_output_shapes
65 | )
66 | register_shape_calculator(
67 | "SklearnLinearRegressor", calculate_linear_regressor_output_shapes
68 | )
69 | register_shape_calculator("SklearnLinearSVR", calculate_linear_regressor_output_shapes)
70 | register_shape_calculator(
71 | "SklearnMLPRegressor", calculate_linear_regressor_output_shapes
72 | )
73 | register_shape_calculator(
74 | "SklearnPoissonRegressor", calculate_linear_regressor_output_shapes
75 | )
76 | register_shape_calculator(
77 | "SklearnRANSACRegressor", calculate_linear_regressor_output_shapes
78 | )
79 | register_shape_calculator(
80 | "SklearnStackingRegressor", calculate_linear_regressor_output_shapes
81 | )
82 | register_shape_calculator(
83 | "SklearnTweedieRegressor", calculate_linear_regressor_output_shapes
84 | )
85 | register_shape_calculator(
86 | "SklearnGammaRegressor", calculate_linear_regressor_output_shapes
87 | )
88 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/local_outlier_factor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.data_types import Int64TensorType
5 |
6 |
def calculate_local_outlier_factor_output_shapes(operator):
    """
    Shape calculator for LocalOutlierFactor.

    The first output becomes an ``Int64TensorType`` of shape ``[N, 1]``
    and the second output keeps its type but gets the shape ``[N, 1]``,
    ``N`` being the first dimension of the first input.
    """
    n_rows = operator.inputs[0].get_first_dimension()
    outputs = operator.outputs
    outputs[0].type = Int64TensorType([n_rows, 1])
    outputs[1].type.shape = [n_rows, 1]
11 |
12 |
13 | register_shape_calculator(
14 | "SklearnLocalOutlierFactor", calculate_local_outlier_factor_output_shapes
15 | )
16 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/mixture.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
6 | from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
7 |
8 |
def calculate_gaussian_mixture_output_shapes(operator):
    """
    Shape calculator for GaussianMixture and BayesianGaussianMixture.

    Expects one two-dimensional input and two or three outputs. The
    first output becomes an ``Int64TensorType`` of shape ``[N, 1]``, the
    second gets the shape ``[N, n_components]`` and the third one, when
    present, the shape ``[N, 1]``.
    """
    check_input_and_output_numbers(
        operator, input_count_range=1, output_count_range=[2, 3]
    )
    check_input_and_output_types(
        operator, good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType]
    )

    rank = len(operator.inputs[0].type.shape)
    if rank != 2:
        raise RuntimeError("Input must be a [N, C]-tensor")

    mixture = operator.raw_operator
    n_rows = operator.inputs[0].get_first_dimension()
    outputs = operator.outputs
    outputs[0].type = Int64TensorType([n_rows, 1])
    outputs[1].type.shape = [n_rows, mixture.n_components]
    if len(outputs) > 2:
        outputs[2].type.shape = [n_rows, 1]
26 |
27 |
28 | register_shape_calculator(
29 | "SklearnGaussianMixture", calculate_gaussian_mixture_output_shapes
30 | )
31 | register_shape_calculator(
32 | "SklearnBayesianGaussianMixture", calculate_gaussian_mixture_output_shapes
33 | )
34 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/multioutput.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 | from ..common.data_types import SequenceType
7 |
8 | _stack = []
9 |
10 |
def multioutput_regressor_shape_calculator(operator):
    """
    Shape calculator for MultiOutputRegressor.

    The output keeps its tensor class and is rebuilt with the shape
    ``[N, C]`` where ``N`` is the first dimension of the input and ``C``
    the number of fitted estimators.
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    source = operator.inputs[0]
    target = operator.outputs[0]
    n_rows = source.get_first_dimension()
    n_targets = len(operator.raw_operator.estimators_)
    tensor_cls = target.type.__class__
    target.type = tensor_cls([n_rows, n_targets])
19 |
20 |
def multioutput_classifier_shape_calculator(operator):
    """
    Shape calculator for MultiOutputClassifier.

    Expects one input and two outputs, the second of which must be a
    ``SequenceType``. Only the first output's shape is set, to
    ``[N, C]`` where ``N`` is the first dimension of the input and ``C``
    the number of fitted estimators.
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=2)
    proba_type = operator.outputs[1].type
    if not isinstance(proba_type, SequenceType):
        raise RuntimeError(
            "Probabilites should be a sequence not %r." % proba_type
        )
    n_rows = operator.inputs[0].get_first_dimension()
    n_estimators = len(operator.raw_operator.estimators_)
    operator.outputs[0].type.shape = [n_rows, n_estimators]
33 |
34 |
35 | register_shape_calculator(
36 | "SklearnMultiOutputRegressor", multioutput_regressor_shape_calculator
37 | )
38 | register_shape_calculator(
39 | "SklearnMultiOutputClassifier", multioutput_classifier_shape_calculator
40 | )
41 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/multiply.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import copy
3 |
4 | from ..common._registration import register_shape_calculator
5 |
6 |
def calculate_sklearn_multiply(operator):
    """
    Shape calculator for the multiply operator.

    Inputs and outputs are paired in order and every output receives a
    shallow copy of the type of its matching input. Pairing stops at the
    shorter of the two lists.
    """
    pairs = zip(operator.inputs, operator.outputs)
    for source, target in pairs:
        target.type = copy.copy(source.type)
10 |
11 |
12 | register_shape_calculator("SklearnMultiply", calculate_sklearn_multiply)
13 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/one_hot_encoder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numpy as np
5 | from ..common._registration import register_shape_calculator
6 | from ..common.data_types import FloatTensorType, Int64TensorType
7 |
8 |
def calculate_sklearn_one_hot_encoder_output_shapes(operator):
    """
    Shape calculator for OneHotEncoder.

    The output width is the total number of categories over all
    features, after removing for each feature the category whose index
    is given by ``drop_idx_`` (when that attribute exists and is not
    None). The output is an ``Int64TensorType`` when the encoder dtype
    is a signed integer and a ``FloatTensorType`` otherwise.
    """
    encoder = operator.raw_operator
    n_columns = 0
    for position, levels in enumerate(encoder.categories_):
        dropped = getattr(encoder, "drop_idx_", None)
        if dropped is not None:
            # Boolean mask keeping every category but the dropped one.
            keep = np.arange(len(levels)) != dropped[position]
            levels = levels[keep]
        n_columns += len(levels)

    n_rows = operator.inputs[0].get_first_dimension()
    if np.issubdtype(encoder.dtype, np.signedinteger):
        tensor_cls = Int64TensorType
    else:
        tensor_cls = FloatTensorType
    operator.outputs[0].type = tensor_cls([n_rows, n_columns])
21 |
22 |
23 | register_shape_calculator(
24 | "SklearnOneHotEncoder", calculate_sklearn_one_hot_encoder_output_shapes
25 | )
26 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/one_vs_one_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.shape_calculator import calculate_linear_classifier_output_shapes
5 |
6 |
7 | register_shape_calculator(
8 | "SklearnOneVsOneClassifier", calculate_linear_classifier_output_shapes
9 | )
10 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/one_vs_rest_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.data_types import Int64TensorType
5 | from ..common.shape_calculator import calculate_linear_classifier_output_shapes
6 |
7 |
def calculate_constant_predictor_output_shapes(operator):
    """
    Shape calculator for the constant predictor.

    The first output becomes an ``Int64TensorType`` of shape ``[N]`` and
    the second output keeps its type but gets the shape ``[N, 2]``,
    ``N`` being the first dimension of the first input.
    """
    n_rows = operator.inputs[0].get_first_dimension()
    outputs = operator.outputs
    outputs[0].type = Int64TensorType([n_rows])
    outputs[1].type.shape = [n_rows, 2]
12 |
13 |
14 | register_shape_calculator(
15 | "Sklearn_ConstantPredictor", calculate_constant_predictor_output_shapes
16 | )
17 |
18 | register_shape_calculator(
19 | "SklearnOneVsRestClassifier", calculate_linear_classifier_output_shapes
20 | )
21 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/ordinal_encoder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numpy as np
5 | from ..common._registration import register_shape_calculator
6 | from ..common.data_types import Int64TensorType, FloatTensorType
7 |
8 |
def calculate_sklearn_ordinal_encoder_output_shapes(operator):
    """
    Shape calculator for OrdinalEncoder.

    The output shape is ``[N, C]`` where ``N`` is the first dimension of
    the first input and ``C`` is the sum of the second dimensions of all
    inputs. ``C`` is None as soon as one input has an unknown second
    dimension, instead of failing while adding None to an integer.
    The output is a ``FloatTensorType`` when the encoder dtype is a
    floating type and an ``Int64TensorType`` otherwise.
    """
    ordinal_op = operator.raw_operator

    op_features = 0
    for variable in operator.inputs:
        width = variable.type.shape[1]
        if width is None:
            # An unknown width cannot be summed: the total is unknown too.
            op_features = None
            break
        op_features += width

    if np.issubdtype(ordinal_op.dtype, np.floating):
        tensor_cls = FloatTensorType
    else:
        tensor_cls = Int64TensorType
    operator.outputs[0].type = tensor_cls(
        [operator.inputs[0].get_first_dimension(), op_features]
    )
20 |
21 |
22 | register_shape_calculator(
23 | "SklearnOrdinalEncoder", calculate_sklearn_ordinal_encoder_output_shapes
24 | )
25 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/ovr_decision_function.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 |
6 |
def calculate_sklearn_ovr_decision_function(operator):
    """
    Shape calculator for the one-vs-rest decision function.

    The output type is a new instance of the first input's type class
    built with the shape ``[N, n_classes]``, where ``N`` is the first
    dimension of the first input and ``n_classes`` the length of the
    fitted ``classes_``.
    """
    source = operator.inputs[0]
    n_rows = source.get_first_dimension()
    tensor_cls = source.type.__class__
    n_classes = len(operator.raw_operator.classes_)
    operator.outputs[0].type = tensor_cls([n_rows, n_classes])
12 |
13 |
14 | register_shape_calculator(
15 | "SklearnOVRDecisionFunction", calculate_sklearn_ovr_decision_function
16 | )
17 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/pipelines.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 |
5 |
def pipeline_shape_calculator(operator):
    """
    Shape calculator registered for pipelines.

    Intentionally does nothing: the operator is left untouched.
    """
    return None
8 |
9 |
def feature_union_shape_calculator(operator):
    """
    Shape calculator registered for feature unions.

    Intentionally does nothing: the operator is left untouched.
    """
    return None
12 |
13 |
def column_transformer_shape_calculator(operator):
    """
    Shape calculator registered for column transformers.

    Intentionally does nothing: the operator is left untouched.
    """
    return None
16 |
17 |
18 | register_shape_calculator("SklearnPipeline", pipeline_shape_calculator)
19 | register_shape_calculator("SklearnFeatureUnion", feature_union_shape_calculator)
20 | register_shape_calculator(
21 | "SklearnColumnTransformer", column_transformer_shape_calculator
22 | )
23 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/polynomial_features.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import copy
5 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
6 | from ..common._registration import register_shape_calculator
7 | from ..common.utils import check_input_and_output_numbers
8 | from ..common.utils import check_input_and_output_types
9 |
10 |
def calculate_sklearn_polynomial_features(operator):
    """
    Shape calculator for PolynomialFeatures.

    The output type is a deep copy of the first input's type with the
    shape ``[N, n_output_features_]``, ``N`` being the first dimension
    of the first input.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator, good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType]
    )

    source = operator.inputs[0]
    n_rows = source.get_first_dimension()
    poly = operator.raw_operator
    target = operator.outputs[0]
    target.type = copy.deepcopy(source.type)
    target.type.shape = [n_rows, poly.n_output_features_]
21 |
22 |
23 | register_shape_calculator(
24 | "SklearnPolynomialFeatures", calculate_sklearn_polynomial_features
25 | )
26 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/power_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType
6 |
7 |
def powertransformer_shape_calculator(operator):
    """
    Shape calculator for PowerTransformer.

    The output is a ``FloatTensorType`` with the same two dimensions as
    the first input, whose shape must unpack into exactly two values.
    """
    source = operator.inputs[0]
    target = operator.outputs[0]
    n_rows, n_cols = source.type.shape
    target.type = FloatTensorType([n_rows, n_cols])
14 |
15 |
16 | register_shape_calculator("SklearnPowerTransformer", powertransformer_shape_calculator)
17 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/quadratic_discriminant_analysis.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.data_types import Int64TensorType, StringTensorType
5 |
6 |
def calculate_quadratic_discriminant_analysis_shapes(operator):
    """
    Shape calculator for QuadraticDiscriminantAnalysis.

    The first output becomes a ``StringTensorType`` when every fitted
    class is a ``str`` and an ``Int64TensorType`` otherwise, with the
    shape ``[1, None]``. The second output keeps its type and gets the
    shape ``[None, n_classes]``.
    """
    labels = operator.raw_operator.classes_
    only_strings = all(isinstance(label, str) for label in labels)
    label_cls = StringTensorType if only_strings else Int64TensorType

    n_classes = len(labels)
    operator.outputs[0].type = label_cls([1, None])
    operator.outputs[1].type.shape = [None, n_classes]
17 |
18 |
19 | register_shape_calculator(
20 | "SklearnQuadraticDiscriminantAnalysis",
21 | calculate_quadratic_discriminant_analysis_shapes,
22 | )
23 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/quantile_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import copy
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
6 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
7 |
8 |
def quantile_transformer_shape_calculator(operator):
    """
    Shape calculator for QuantileTransformer.

    The output type is a deep copy of the first input's type with the
    shape ``[N, C]`` where ``N`` is the first dimension of the first
    input and ``C`` the second dimension of the fitted ``quantiles_``.
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator, good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType]
    )

    source = operator.inputs[0]
    n_rows = source.get_first_dimension()
    transformer = operator.raw_operator
    target = operator.outputs[0]
    target.type = copy.deepcopy(source.type)
    target.type.shape = [n_rows, transformer.quantiles_.shape[1]]
20 |
21 |
22 | register_shape_calculator(
23 | "SklearnQuantileTransformer", quantile_transformer_shape_calculator
24 | )
25 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/random_projection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 |
6 |
def random_projection_shape_calculator(operator):
    """
    Shape calculator for GaussianRandomProjection.

    The first output keeps its type and gets the shape ``[N, C]`` where
    ``N`` is the first dimension of the first input and ``C`` the first
    dimension of the fitted ``components_``.
    """
    n_rows = operator.inputs[0].get_first_dimension()
    n_components = operator.raw_operator.components_.shape[0]
    operator.outputs[0].type.shape = [n_rows, n_components]
14 |
15 |
16 | register_shape_calculator(
17 | "SklearnGaussianRandomProjection", random_projection_shape_calculator
18 | )
19 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/random_trees_embedding.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numpy as np
5 | from ..common._registration import register_shape_calculator
6 | from ..common.data_types import FloatTensorType, Int64TensorType
7 |
8 |
def calculate_sklearn_random_trees_embedding_output_shapes(operator):
    """Infer the output type of a RandomTreesEmbedding operator.

    The number of output columns is the total number of categories kept by
    the model's internal ``one_hot_encoder_``: for every feature, the
    category at ``drop_idx_`` is removed when the encoder defines a
    non-``None`` ``drop_idx_``. The output is an ``Int64TensorType`` when
    the encoder dtype is a signed integer and a ``FloatTensorType``
    otherwise, with shape ``[N, total]``.

    :param operator: operator exposing ``inputs``, ``outputs`` and
        ``raw_operator``
    """
    encoder = operator.raw_operator.one_hot_encoder_
    dropped = getattr(encoder, "drop_idx_", None)

    n_columns = 0
    for position, kept in enumerate(encoder.categories_):
        if dropped is not None:
            keep_mask = np.arange(len(kept)) != dropped[position]
            kept = kept[keep_mask]
        n_columns += len(kept)

    n_rows = operator.inputs[0].get_first_dimension()
    if np.issubdtype(encoder.dtype, np.signedinteger):
        tensor_cls = Int64TensorType
    else:
        tensor_cls = FloatTensorType
    operator.outputs[0].type = tensor_cls([n_rows, n_columns])


register_shape_calculator(
    "SklearnRandomTreesEmbedding",
    calculate_sklearn_random_trees_embedding_output_shapes,
)
27 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/replace_op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_replace_transformer(operator):
    """Give the single output the same type as the single input.

    The input type object is assigned as is (no copy is made), so input
    and output end up sharing one type instance.

    :param operator: operator exposing ``inputs`` and ``outputs``
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    source_type = operator.inputs[0].type
    operator.outputs[0].type = source_type


register_shape_calculator(
    "SklearnReplaceTransformer", calculate_sklearn_replace_transformer
)
16 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/scaler.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import numbers
5 | from ..common._registration import register_shape_calculator
6 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
7 | from ..common.utils import check_input_and_output_numbers
8 | from ..common.utils import check_input_and_output_types
9 |
10 |
def calculate_sklearn_scaler_output_shapes(operator):
    """
    Allowed input/output patterns are
        1. [N, C_1], ..., [N, C_n] ---> [N, C_1 + ... + C_n]

    Similar to imputer, this operator can take multiple input feature
    tensors and concatenate them along C-axis.

    The output shape is ``[N, C]`` where ``N`` is the first dimension of
    the first input and ``C`` is the sum of the inputs' second dimensions.
    ``C`` is ``None`` as soon as one second dimension is not an integer.

    :param operator: operator exposing ``inputs`` and ``outputs``
    :raises RuntimeError: if the inputs do not all share the same first
        dimension
    """
    check_input_and_output_numbers(
        operator, input_count_range=[1, None], output_count_range=1
    )
    check_input_and_output_types(
        operator,
        good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType],
        good_output_types=[FloatTensorType, DoubleTensorType],
    )
    # Inputs: multiple float- and integer-tensors
    # Output: one float tensor
    # One pass is enough: the set of batch sizes does not depend on which
    # input is being looked at.
    batch_sizes = {variable.get_first_dimension() for variable in operator.inputs}
    if len(batch_sizes) > 1:
        raise RuntimeError("Batch size must be identical across inputs.")

    N = operator.inputs[0].get_first_dimension()
    C = 0
    for variable in operator.inputs:
        c = variable.get_second_dimension()
        if not isinstance(c, numbers.Integral):
            # An unknown width makes the concatenated width unknown too.
            C = None
            break
        C += c

    operator.outputs[0].type.shape = [N, C]


register_shape_calculator("SklearnRobustScaler", calculate_sklearn_scaler_output_shapes)
register_shape_calculator("SklearnScaler", calculate_sklearn_scaler_output_shapes)
register_shape_calculator("SklearnNormalizer", calculate_sklearn_scaler_output_shapes)
register_shape_calculator("SklearnMinMaxScaler", calculate_sklearn_scaler_output_shapes)
register_shape_calculator("SklearnMaxAbsScaler", calculate_sklearn_scaler_output_shapes)
51 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/sequence.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 |
5 |
def calculate_sklearn_sequence_at(operator):
    """Placeholder shape calculator: leaves *operator* untouched.

    :param operator: ignored
    """
    return None
8 |
9 |
def calculate_sklearn_sequence_construct(operator):
    """Placeholder shape calculator: leaves *operator* untouched.

    :param operator: ignored
    """
    return None
12 |
13 |
# Register the two no-op calculators under the "SklearnSequenceAt" and
# "SklearnSequenceConstruct" operator names.
register_shape_calculator("SklearnSequenceAt", calculate_sklearn_sequence_at)
register_shape_calculator(
    "SklearnSequenceConstruct", calculate_sklearn_sequence_construct
)
18 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/sgd_oneclass_svm.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.data_types import Int64TensorType
5 |
6 |
def calculate_sgd_oneclass_svm_output_shapes(operator):
    """Infer the output shapes of an SGDOneClassSVM operator.

    The first output becomes a 1-D ``Int64TensorType`` of length ``N``,
    where ``N`` is the first dimension of the first input. The second
    output keeps its existing type and only gets the 1-D shape ``[N]``.

    :param operator: operator exposing ``inputs`` and two ``outputs``
    """
    n_rows = operator.inputs[0].get_first_dimension()
    operator.outputs[0].type = Int64TensorType([n_rows])
    operator.outputs[1].type.shape = [n_rows]


register_shape_calculator(
    "SklearnSGDOneClassSVM", calculate_sgd_oneclass_svm_output_shapes
)
22 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/svd.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType
6 | from ..common.utils import check_input_and_output_numbers
7 | from ..common.utils import check_input_and_output_types
8 |
9 |
def calculate_sklearn_truncated_svd_output_shapes(operator):
    """
    Allowed input/output patterns are
        1. [N, C] ---> [N, K]

    Transform feature dimension from C to K.

    ``K`` is read from ``n_components`` when the operator type is
    ``"SklearnTruncatedSVD"`` and from ``n_components_`` otherwise. The
    output is a ``DoubleTensorType`` when the input is one, and a
    ``FloatTensorType`` in every other case.

    :param operator: operator exposing ``inputs``, ``outputs``, ``type``
        and ``raw_operator``
    :raises RuntimeError: if the first input is not a 2-D tensor
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    check_input_and_output_types(
        operator,
        good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType],
        good_output_types=[FloatTensorType, DoubleTensorType],
    )

    source = operator.inputs[0]
    if len(source.type.shape) != 2:
        raise RuntimeError("Only 2-D tensor(s) can be input(s).")

    input_cls = source.type.__class__
    # Anything that is not double precision is produced as float.
    output_cls = input_cls if input_cls == DoubleTensorType else FloatTensorType

    n_rows = source.get_first_dimension()
    if operator.type == "SklearnTruncatedSVD":
        n_components = operator.raw_operator.n_components
    else:
        n_components = operator.raw_operator.n_components_

    operator.outputs[0].type = output_cls([n_rows, n_components])


register_shape_calculator(
    "SklearnIncrementalPCA", calculate_sklearn_truncated_svd_output_shapes
)
register_shape_calculator("SklearnPCA", calculate_sklearn_truncated_svd_output_shapes)
register_shape_calculator(
    "SklearnTruncatedSVD", calculate_sklearn_truncated_svd_output_shapes
)
47 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/target_encoder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.data_types import FloatTensorType
6 | from ..common.data_types import Int64TensorType, StringTensorType
7 | from ..common.utils import check_input_and_output_numbers
8 | from ..common.utils import check_input_and_output_types
9 |
10 |
def calculate_sklearn_target_encoder_output_shapes(operator):
    """
    Infer the output type of a TargetEncoder operator.

    The output is always a ``FloatTensorType`` of shape ``[N, K]`` where
    ``N`` is the first dimension of the first input and ``K`` is the
    length of the fitted model's ``categories_``.

    :param operator: operator exposing ``inputs``, ``outputs`` and
        ``raw_operator``
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    accepted_input_types = [FloatTensorType, Int64TensorType, StringTensorType]
    check_input_and_output_types(operator, good_input_types=accepted_input_types)

    n_rows = operator.inputs[0].get_first_dimension()
    n_encoded = len(operator.raw_operator.categories_)
    operator.outputs[0].type = FloatTensorType(shape=[n_rows, n_encoded])


register_shape_calculator(
    "SklearnTargetEncoder", calculate_sklearn_target_encoder_output_shapes
)
30 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/text_vectorizer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_text_vectorizer_output_shapes(operator):
    """
    Set the output shape of a text vectorizer to ``[None, C]``.

    ``C`` is the largest value stored in the fitted model's
    ``vocabulary_`` mapping, plus one. The first dimension is left
    unknown (``None``).

    :param operator: operator exposing ``outputs`` and ``raw_operator``
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)

    vocabulary = operator.raw_operator.vocabulary_
    n_columns = max(vocabulary.values()) + 1
    operator.outputs[0].type.shape = [None, n_columns]


register_shape_calculator(
    "SklearnCountVectorizer", calculate_sklearn_text_vectorizer_output_shapes
)
register_shape_calculator(
    "SklearnTfidfVectorizer", calculate_sklearn_text_vectorizer_output_shapes
)
27 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/tfidf_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 |
7 |
def calculate_sklearn_tfidf_transformer_output_shapes(operator):
    """Set the output shape of a TfidfTransformer to ``[1, C]``.

    ``C`` is the second entry of the first input's shape.

    :param operator: operator exposing ``inputs`` and ``outputs``
    """
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    input_shape = operator.inputs[0].type.shape
    # NOTE(review): the first dimension is hard-coded to 1 rather than
    # taken from the input -- confirm this is intended.
    operator.outputs[0].type.shape = [1, input_shape[1]]


register_shape_calculator(
    "SklearnTfidfTransformer", calculate_sklearn_tfidf_transformer_output_shapes
)
17 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/tuned_threshold_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.utils import check_input_and_output_numbers
5 | from ..common.shape_calculator import _infer_linear_classifier_output_types
6 |
7 |
def tuned_threshold_classifier_shape_calculator(operator):
    """Shape calculator for TunedThresholdClassifierCV.

    Validates the operator with ``output_count_range=2``, then delegates
    the inference of the output types to
    ``_infer_linear_classifier_output_types``.

    :param operator: operator to update in place
    """
    check_input_and_output_numbers(operator, output_count_range=2)
    _infer_linear_classifier_output_types(operator)


register_shape_calculator(
    "SklearnTunedThresholdClassifierCV", tuned_threshold_classifier_shape_calculator
)
17 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/voting_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from ..common._registration import register_shape_calculator
4 | from ..common.utils import check_input_and_output_numbers
5 | from ..common.shape_calculator import _infer_linear_classifier_output_types
6 |
7 |
def voting_classifier_shape_calculator(operator):
    """Shape calculator for VotingClassifier.

    Validates the operator with ``output_count_range=2``, then delegates
    the inference of the output types to
    ``_infer_linear_classifier_output_types``.

    :param operator: operator to update in place
    """
    check_input_and_output_numbers(operator, output_count_range=2)
    _infer_linear_classifier_output_types(operator)


register_shape_calculator("SklearnVotingClassifier", voting_classifier_shape_calculator)
15 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/voting_regressor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 | from ..common.utils import check_input_and_output_numbers
6 | from ..common.shape_calculator import _infer_linear_regressor_output_types
7 |
8 |
def voting_regressor_shape_calculator(operator):
    """Shape calculator for VotingRegressor.

    Validates the operator with ``output_count_range=1`` and returns
    whatever ``_infer_linear_regressor_output_types`` returns for it.

    :param operator: operator to update in place
    :return: result of ``_infer_linear_regressor_output_types``
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    inferred = _infer_linear_regressor_output_types(operator)
    return inferred


register_shape_calculator("SklearnVotingRegressor", voting_regressor_shape_calculator)
15 |
--------------------------------------------------------------------------------
/skl2onnx/shape_calculators/zip_map.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from ..common._registration import register_shape_calculator
5 |
6 |
def calculate_sklearn_zipmap(operator):
    """Infer the output types of a ZipMap operator.

    The operator must have as many outputs as inputs, and that number must
    be 1 or 2. With a single input nothing is changed. With two inputs:

    * the first output receives a new instance of the first input's type
      class, built from the first input's shape;
    * if the second output already has a type, the ``value_type`` of its
      ``element_type`` is replaced by a new, shape-less (``[]``) instance
      of the second input's type class.

    :param operator: operator exposing ``inputs`` and ``outputs``
    :raises RuntimeError: if the input and output counts differ or are
        neither 1 nor 2
    """
    n_inputs = len(operator.inputs)
    n_outputs = len(operator.outputs)
    if n_inputs != n_outputs or n_inputs not in (1, 2):
        # The check covers two distinct failures (mismatch, unsupported
        # count), so report the actual numbers instead of only the first.
        raise RuntimeError(
            "SklearnZipMap expects the same number of inputs and outputs. "
            "Supported counts are 1 or 2, got %d input(s) and %d output(s)."
            % (n_inputs, n_outputs)
        )
    if n_inputs == 2:
        first_type = operator.inputs[0].type
        operator.outputs[0].type = first_type.__class__(first_type.shape)
        second_output_type = operator.outputs[1].type
        if second_output_type is not None:
            second_output_type.element_type.value_type = operator.inputs[
                1
            ].type.__class__([])
23 |
24 |
def calculate_sklearn_zipmap_columns(operator):
    """Infer the output types of a ZipMapColumns operator.

    The first output receives a new instance of the first input's type
    class, built from the first input's shape. Every following output
    keeps its type and only gets the 1-D shape ``[N]``, ``N`` being the
    first dimension of the first input.

    :param operator: operator exposing ``inputs`` and ``outputs``
    """
    first_input = operator.inputs[0]
    n_rows = first_input.get_first_dimension()
    operator.outputs[0].type = first_input.type.__class__(first_input.type.shape)
    for column_output in operator.outputs[1:]:
        column_output.type.shape = [n_rows]
32 |
33 |
34 | register_shape_calculator("SklearnZipMap", calculate_sklearn_zipmap)
35 | register_shape_calculator("SklearnZipMapColumns", calculate_sklearn_zipmap_columns)
36 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from .cast_transformer import CastTransformer
5 | from .cast_regressor import CastRegressor
6 | from .replace_transformer import ReplaceTransformer
7 | from .sklearn_text import TraceableCountVectorizer, TraceableTfidfVectorizer
8 | from .woe_transformer import WOETransformer
9 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/cast_regressor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from sklearn.base import RegressorMixin, BaseEstimator
5 |
6 | try:
7 | from sklearn.utils.validation import _deprecate_positional_args
8 | except ImportError:
9 |
10 | def _deprecate_positional_args(x):
11 | return x
12 |
13 |
class CastRegressor(RegressorMixin, BaseEstimator):
    """
    Cast predictions into a specific types.
    This should be used to minimize the conversion
    of a pipeline using float32 instead of double
    when onnx do not support double.

    Parameters
    ----------
    estimator : regressor
        wrapped regressor
    dtype : numpy type,
        output are cast into that type
    """

    @_deprecate_positional_args
    def __init__(self, estimator, *, dtype=np.float32):
        self.dtype = dtype
        self.estimator = estimator

    def _cast(self, a, name):
        """
        Returns *a* converted to ``self.dtype``.

        :param a: object exposing ``astype`` and ``dtype``
        :param name: label used in the error message
        :raises ValueError: if the conversion fails, chained to the
            original error
        """
        try:
            a2 = a.astype(self.dtype)
        except ValueError as e:
            raise ValueError(
                "Unable to cast {} from {} into {}.".format(name, a.dtype, self.dtype)
            ) from e
        return a2

    def fit(self, X, y=None, sample_weight=None):
        """
        Fits the wrapped estimator (in place) and returns *self*.

        *sample_weight* is only forwarded when it is given, so that
        wrapped regressors whose ``fit`` does not accept this
        parameter can still be trained.
        """
        if sample_weight is None:
            self.estimator.fit(X, y=y)
        else:
            self.estimator.fit(X, y=y, sample_weight=sample_weight)
        return self

    def predict(self, X, y=None):
        """
        Predicts and casts the prediction.
        """
        return self._cast(self.estimator.predict(X), "predict(X)")

    def decision_function(self, X, y=None):
        """
        Calls *decision_function* and casts the outputs.

        :raises AttributeError: if the wrapped estimator has no
            ``decision_function``
        """
        if not hasattr(self.estimator, "decision_function"):
            raise AttributeError(
                "%r object has no attribute 'decision_function'."
                % self.estimator.__class__.__name__
            )
        return self._cast(self.estimator.decision_function(X), "decision_function(X)")
66 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/cast_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from sklearn.base import TransformerMixin, BaseEstimator
5 |
6 | try:
7 | from sklearn.utils.validation import _deprecate_positional_args
8 | except ImportError:
9 |
10 | def _deprecate_positional_args(x):
11 | return x
12 |
13 |
class CastTransformer(TransformerMixin, BaseEstimator):
    """
    Transformer converting its input into a given numpy type.
    It is meant to keep a pipeline in float32 rather than double
    so that fewer conversions are needed.

    Parameters
    ----------
    dtype : numpy type,
        type the data is converted into
    """

    @_deprecate_positional_args
    def __init__(self, *, dtype=np.float32):
        self.dtype = dtype

    def _cast(self, a, name):
        """
        Converts *a* into ``self.dtype``.

        Objects exposing both ``values`` and ``iloc`` are treated as
        dataframes and replaced by their ``values``. Anything else that
        is not a numpy array must at least expose ``astype``.

        :param a: data to convert
        :param name: label used in error messages
        :raises TypeError: if *a* cannot be handled
        :raises ValueError: if the conversion itself fails
        """
        array = a
        if not isinstance(array, np.ndarray):
            is_dataframe = hasattr(array, "values") and hasattr(array, "iloc")
            if is_dataframe:
                array = array.values
            elif not hasattr(array, "astype"):
                raise TypeError("{} must be a numpy array or a dataframe.".format(name))
        try:
            return array.astype(self.dtype)
        except ValueError as exc:
            raise ValueError(
                "Unable to cast {} from {} into {}.".format(
                    name, array.dtype, self.dtype
                )
            ) from exc

    def fit(self, X, y=None, sample_weight=None):
        """
        Learns nothing: only verifies that *X* can be converted
        into *dtype*, the converted data is discarded.
        """
        self._cast(X, "X")
        return self

    def transform(self, X, y=None):
        """
        Returns *X* converted into *dtype*.
        """
        return self._cast(X, "X")
57 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/register.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from .sklearn_text_onnx import register as register_text
4 | from .woe_transformer_onnx import register as register_woe
5 |
6 |
7 | register_text()
8 | register_woe()
9 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/replace_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import numpy as np
4 | from sklearn.base import TransformerMixin, BaseEstimator
5 |
6 | try:
7 | from sklearn.utils.validation import _deprecate_positional_args
8 | except ImportError:
9 |
10 | def _deprecate_positional_args(x):
11 | return x
12 |
13 |
class ReplaceTransformer(TransformerMixin, BaseEstimator):
    """
    Substitutes every occurrence of one value with another one,
    for instance 0 with nan.

    Parameters
    ----------
    from_value : value to look for
    to_value : value written instead
    dtype: dtype of replaced values
    """

    @_deprecate_positional_args
    def __init__(self, *, from_value=0, to_value=np.nan, dtype=np.float32):
        BaseEstimator.__init__(self)
        self.dtype = dtype
        self.from_value = from_value
        self.to_value = to_value

    def _replace(self, a):
        """
        Returns *a* with *from_value* replaced by *to_value*.

        Inputs exposing ``todense`` (sparse matrices) are returned
        unchanged when replacing 0 by nan, any other replacement on
        them raises *RuntimeError*.
        """
        if not hasattr(a, "todense"):
            return np.where(a == self.from_value, self.to_value, a)
        if np.isnan(self.to_value) and self.from_value == 0:
            # The replacement is implicit for this kind of input.
            return a
        raise RuntimeError(
            "Unable to replace 0 by nan one value by another in sparse matrix."
        )

    def fit(self, X, y=None, sample_weight=None):
        """
        Learns nothing: only verifies that the replacement can be
        applied to *X*, the result is discarded.
        """
        self._replace(X)
        return self

    def transform(self, X, y=None):
        """
        Returns *X* after the replacement.
        """
        return self._replace(X)
55 |
--------------------------------------------------------------------------------
/skl2onnx/sklapi/sklearn_text_onnx.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | from .. import update_registered_converter
4 | from ..shape_calculators.text_vectorizer import (
5 | calculate_sklearn_text_vectorizer_output_shapes,
6 | )
7 | from ..operator_converters.text_vectoriser import convert_sklearn_text_vectorizer
8 | from ..operator_converters.tfidf_vectoriser import convert_sklearn_tfidf_vectoriser
9 | from .sklearn_text import TraceableCountVectorizer, TraceableTfidfVectorizer
10 |
11 |
def register():
    """Register the converters of TraceableCountVectorizer and
    TraceableTfidfVectorizer.

    Both classes use the same shape calculator and the same set of
    options; only the alias and the converter differ.
    """

    def vectorizer_options():
        # A new dictionary (and new lists) for each registration, so the
        # two registrations never share mutable state.
        return {
            "tokenexp": None,
            "separators": None,
            "nan": [True, False],
            "keep_empty_string": [True, False],
            "locale": None,
        }

    registrations = (
        (
            TraceableCountVectorizer,
            "Skl2onnxTraceableCountVectorizer",
            convert_sklearn_text_vectorizer,
        ),
        (
            TraceableTfidfVectorizer,
            "Skl2onnxTraceableTfidfVectorizer",
            convert_sklearn_tfidf_vectoriser,
        ),
    )
    for model_class, alias, converter in registrations:
        update_registered_converter(
            model_class,
            alias,
            calculate_sklearn_text_vectorizer_output_shapes,
            converter,
            options=vectorizer_options(),
        )
42 |
--------------------------------------------------------------------------------
/skl2onnx/tutorial/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Shortcuts to *tutorial*.
5 | """
6 |
7 | from .benchmark import measure_time
8 |
--------------------------------------------------------------------------------
/skl2onnx/tutorial/benchmark.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tools to help benchmarking.
5 | """
6 |
7 | from timeit import Timer
8 | import numpy
9 |
10 |
def measure_time(stmt, context, repeat=10, number=50, div_by_number=False):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: string
    :param context: variable to know in a dictionary
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :return: dictionary with keys *average*, *deviation*, *min_exec*,
        *max_exec*, *repeat*, *number*

    .. runpython::
        :showcode:

        from skl2onnx.tutorial import measure_time
        from math import cos

        res = measure_time("cos(x)", context=dict(cos=cos, x=5.))
        print(res)

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement,
    unless *div_by_number* is true, in which case every duration
    is divided by *number*.
    """
    tim = Timer(stmt, globals=context)
    res = numpy.array(tim.repeat(repeat=repeat, number=number))
    if div_by_number:
        res /= number
    mean = numpy.mean(res)
    # Standard deviation computed from centered values: the shortcut
    # sqrt(E[x^2] - E[x]^2) can become slightly negative through rounding
    # and then produces NaN.
    dev = numpy.std(res)
    mes = dict(
        average=mean,
        deviation=dev,
        min_exec=numpy.min(res),
        max_exec=numpy.max(res),
        repeat=repeat,
        number=number,
    )
    return mes
53 |
--------------------------------------------------------------------------------
/tests/benchmark.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | You can run this file to get a report on every tested model conversion.
5 |
6 | ::
7 |
8 | python -u tests/benchmark.py
9 |
10 | Folder contains the model to compare implemented as unit tests.
11 | """
12 |
13 | import os
14 | import sys
15 | import unittest
16 | import warnings
17 |
18 |
def run_all_tests(folder=None, verbose=True):
    """
    Runs all unit tests or unit tests specific to one library.
    The tests produce a series of files dumped into ``folder``
    which can be later used to tests a backend (or a runtime).

    The function sets the environment variables ``ONNXTESTDUMP``,
    ``ONNXTESTDUMPERROR`` and ``ONNXTESTBENCHMARK``, loads every entry of
    this file's directory whose name starts with ``test``, runs the
    resulting suites and finally writes ``report_backend.xlsx``
    into ``folder``.

    :param folder: where to put the dumped files,
        ``"TESTDUMP"`` if None
    :param verbose: verbose
    :return: the report returned by ``make_report_backend``
    :raises FileNotFoundError: if a test folder does not exist
    :raises RuntimeError: if a loaded suite cannot be iterated
    """
    if folder is None:
        folder = "TESTDUMP"
    os.environ["ONNXTESTDUMP"] = folder
    os.environ["ONNXTESTDUMPERROR"] = "1"
    os.environ["ONNXTESTBENCHMARK"] = "1"

    if verbose:
        print("[benchmark] look into '{0}'".format(folder))

    try:
        import onnxmltools  # noqa: F401
    except ImportError:
        warnings.warn("Cannot import onnxmltools. Some tests won't work.")

    this = os.path.abspath(os.path.dirname(__file__))
    subs = [this]
    loader = unittest.TestLoader()
    suites = []

    for sub in subs:
        fold = os.path.join(this, sub)
        if not os.path.exists(fold):
            raise FileNotFoundError("Unable to find '{0}'".format(fold))

        # The folder must be importable while the test modules are loaded
        # by name; the entry is removed afterwards even if loading fails.
        sys.path.append(fold)
        try:
            names = [_ for _ in os.listdir(fold) if _.startswith("test")]
            for name in names:
                name = os.path.splitext(name)[0]
                ts = loader.loadTestsFromName(name)
                suites.append(ts)
        finally:
            # Removes the first occurrence of the folder in sys.path.
            sys.path.remove(fold)

    with warnings.catch_warnings():
        warnings.filterwarnings(category=DeprecationWarning, action="ignore")
        warnings.filterwarnings(category=FutureWarning, action="ignore")
        runner = unittest.TextTestRunner()
        for ts in suites:
            for k in ts:
                try:
                    # Only the class name of the first test is printed.
                    for t in k:
                        print(t.__class__.__name__)
                        break
                except TypeError as e:
                    raise RuntimeError("Unable to run test '{}'.".format(ts)) from e
            runner.run(ts)

    from test_utils.tests_helper import make_report_backend

    df = make_report_backend(folder, as_df=True)

    from pandas import set_option

    set_option("display.max_columns", None)
    set_option("display.max_rows", None)
    exfile = os.path.join(folder, "report_backend.xlsx")
    df.to_excel(exfile)
    if verbose:
        print("[benchmark] wrote report in '{0}'".format(exfile))
    return df
89 |
90 |
if __name__ == "__main__":
    # The dump folder is the optional first command line argument.
    cli_args = sys.argv[1:]
    run_all_tests(folder=cli_args[0] if cli_args else None)
94 |
--------------------------------------------------------------------------------
/tests/datasets/treecl.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/tests/datasets/treecl.onnx
--------------------------------------------------------------------------------
/tests/datasets/treecl2.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/tests/datasets/treecl2.onnx
--------------------------------------------------------------------------------
/tests/datasets/treecl3.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/sklearn-onnx/eaac0e13333962a2391a33c9d5192e382b7a985d/tests/datasets/treecl3.onnx
--------------------------------------------------------------------------------
/tests/test_algebra_complex.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from numpy.testing import assert_almost_equal
4 | from onnxruntime import InferenceSession
5 |
6 | try:
7 | from onnxruntime.capi.onnxruntime_pybind11_state import (
8 | InvalidGraph,
9 | Fail,
10 | InvalidArgument,
11 | )
12 | except ImportError:
13 | InvalidGraph = RuntimeError
14 | InvalidArgument = RuntimeError
15 | Fail = RuntimeError
16 | from skl2onnx.common.data_types import Complex64TensorType, Complex128TensorType
17 | from skl2onnx.algebra.onnx_ops import OnnxAdd
18 | from test_utils import TARGET_OPSET
19 |
20 |
class TestAlgebraComplex(unittest.TestCase):
    """Checks OnnxAdd with complex64 and complex128 tensors."""

    @unittest.skipIf(Complex64TensorType is None, reason="not available")
    @unittest.skipIf(TARGET_OPSET < 13, reason="not implemented")
    def test_complex(self):
        """Builds X + (1+2j) for every opset from 10 to 19 not above
        TARGET_OPSET, checks the element type written in the model and,
        when the runtime accepts the graph, the computed values."""
        cases = (
            (np.complex64, Complex64TensorType, 14),
            (np.complex128, Complex128TensorType, 15),
        )
        opsets = [opv for opv in range(10, 20) if opv <= TARGET_OPSET]

        for dt, var, pr in cases:
            X = np.array([[1 - 2j, -12j], [-1 - 2j, 1 + 2j]]).astype(dt)

            for opv in opsets:
                with self.subTest(dt=dt, opset=opv):
                    node = OnnxAdd(
                        "X",
                        np.array([1 + 2j], dtype=dt),
                        output_names=["Y"],
                        op_version=opv,
                    )
                    onx = node.to_onnx(
                        [("X", var((None, 2)))],
                        outputs=[("Y", var())],
                        target_opset=opv,
                    )
                    self.assertIn("elem_type: %d" % pr, str(onx))

                    try:
                        ort = InferenceSession(
                            onx.SerializeToString(), providers=["CPUExecutionProvider"]
                        )
                    except InvalidGraph as e:
                        # A runtime rejecting complex tensors is not a
                        # failure of the converter.
                        if "Type Error: Type 'tensor(complex" not in str(e):
                            raise
                        continue
                    assert ort is not None
                    got = ort.run(None, {"X": X})[0]
                    assert_almost_equal(X + np.array([1 + 2j]), got)
59 |
60 |
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
63 |
--------------------------------------------------------------------------------
/tests/test_algebra_double.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import packaging.version as pv
5 | import numpy
6 | from numpy.testing import assert_almost_equal
7 | from skl2onnx.algebra.onnx_ops import OnnxMatMul, OnnxSub
8 | import onnxruntime
9 | from onnxruntime import InferenceSession
10 | from test_utils import TARGET_OPSET
11 |
12 |
class TestAlgebraDouble(unittest.TestCase):
    """Checks an algebra graph working on float64 tensors."""

    @unittest.skipIf(TARGET_OPSET < 10, reason="not available")
    @unittest.skipIf(
        pv.Version(onnxruntime.__version__) <= pv.Version("0.4.0"),
        reason="Sub(7) not available",
    )
    def test_algebra_converter(self):
        """Converts ``X @ coef - intercept`` and compares the runtime
        output with the expected values."""
        weights = numpy.array([[1, 2], [3, 4]], dtype=numpy.float64)
        bias = numpy.array([1], dtype=numpy.float64)
        features = numpy.array([[1, -2], [3, -4]], dtype=numpy.float64)

        product = OnnxMatMul("X", weights, op_version=TARGET_OPSET)
        graph = OnnxSub(product, bias, output_names=["Y"], op_version=TARGET_OPSET)
        model = graph.to_onnx({"X": features}, target_opset=TARGET_OPSET)

        session = InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        predicted = session.run(None, {"X": features})[0]
        expected = numpy.array([[-6.0, -7.0], [-10.0, -11.0]])
        assert_almost_equal(predicted, expected)
37 |
38 |
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
41 |
--------------------------------------------------------------------------------
/tests/test_algebra_onnx_doc.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import sys
5 | import numpy as np
6 | from numpy.testing import assert_almost_equal
7 | import onnx
8 | from skl2onnx.algebra.onnx_ops import dynamic_class_creation
9 | from skl2onnx.algebra.automation import get_rst_doc_sklearn
10 | from test_utils import TARGET_OPSET
11 |
12 |
class TestAlgebraOnnxDoc(unittest.TestCase):
    """Checks the dynamically created operators and the generated
    documentation."""

    def setUp(self):
        self._algebra = dynamic_class_creation()

    def predict_with_onnxruntime(self, model_def, *inputs):
        """Runs *model_def* with onnxruntime.

        *inputs* are matched, in order, with the input names of the
        session. The result maps every output name to its value.
        """
        import onnxruntime as ort

        session = ort.InferenceSession(
            model_def.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        input_names = [i.name for i in session.get_inputs()]
        feeds = dict(zip(input_names, inputs))
        results = session.run(None, feeds)
        output_names = [o.name for o in session.get_outputs()]
        return dict(zip(output_names, results))

    @unittest.skipIf(TARGET_OPSET < 10, reason="not available")
    def test_transpose2(self):
        """Two identical transpositions swapping the first two axes must
        return the original tensor."""
        from skl2onnx.algebra.onnx_ops import OnnxTranspose

        inner = OnnxTranspose("X", perm=[1, 0, 2], op_version=TARGET_OPSET)
        node = OnnxTranspose(
            inner,
            perm=[1, 0, 2],
            output_names=["Y"],
            op_version=TARGET_OPSET,
        )
        X = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)

        model_def = node.to_onnx({"X": X})
        onnx.checker.check_model(model_def)
        outputs = self.predict_with_onnxruntime(model_def, X)
        assert_almost_equal(outputs["Y"], X)

    @unittest.skipIf(
        sys.platform.startswith("win"), reason="onnx schema are incorrect on Windows"
    )
    @unittest.skipIf(TARGET_OPSET <= 20, reason="not available")
    def test_doc_sklearn(self):
        """The generated documentation must contain the anchor of
        OnnxSklearnBernoulliNB."""
        rst = get_rst_doc_sklearn()
        assert (
            ".. _l-sklops-OnnxSklearnBernoulliNB:" in rst
        ), f"Unable to find a substring in {rst}"
55 |
56 |
57 | if __name__ == "__main__":
58 | unittest.main(verbosity=2)
59 |
--------------------------------------------------------------------------------
/tests/test_algebra_onnx_operators_opset.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import numpy as np
5 | from numpy.testing import assert_almost_equal
6 | import onnx
7 | import onnxruntime as ort
8 | from skl2onnx.algebra.onnx_ops import OnnxPad
9 |
10 |
class TestOnnxOperatorsOpset(unittest.TestCase):
    """Checks operators declared with an explicit, older ``op_version``."""

    @unittest.skipIf(onnx.defs.onnx_opset_version() < 10, "irrelevant")
    def test_pad_opset_10(self):
        """Pad declared with op_version=2 is converted for target opset 10 and run."""
        data = np.array([[0, 1]], dtype=np.float32)
        node = OnnxPad(
            "X",
            output_names=["Y"],
            mode="constant",
            value=1.5,
            pads=[0, 1, 0, 1],
            op_version=2,
        )
        proto = node.to_onnx({"X": data}, target_opset=10)
        onnx.checker.check_model(proto)

        session = ort.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        # Inputs and outputs are matched to their names in declaration order.
        feeds = dict(zip([i.name for i in session.get_inputs()], (data,)))
        fetched = session.run(None, feeds)
        by_name = dict(zip([o.name for o in session.get_outputs()], fetched))

        expected = np.array([[1.5, 0.0, 1.0, 1.5]], dtype=np.float32)
        assert_almost_equal(expected, by_name["Y"])
39 |
40 |
41 | if __name__ == "__main__":
42 | unittest.main()
43 |
--------------------------------------------------------------------------------
/tests/test_issues_2025.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import unittest
3 | from sklearn.utils._testing import ignore_warnings
4 | from sklearn.exceptions import ConvergenceWarning
5 |
6 |
class TestInvestigate2025(unittest.TestCase):
    """Regression tests for issues reported in 2025."""

    @ignore_warnings(category=(ConvergenceWarning, FutureWarning))
    def test_issue_1161_gaussian(self):
        """GaussianProcessRegressor with a WhiteKernel converts with return_std."""
        # https://github.com/onnx/sklearn-onnx/issues/1161
        import numpy as np
        from sklearn.gaussian_process import GaussianProcessRegressor
        from sklearn.gaussian_process.kernels import WhiteKernel
        from skl2onnx import convert_sklearn
        from skl2onnx.common.data_types import FloatTensorType

        # Small one-dimensional training set.
        features = np.array([[1], [3], [5], [6], [7], [8], [10], [12], [14], [15]])
        targets = np.array([3, 2, 7, 8, 7, 6, 9, 11, 10, 12])

        # Fit the regressor on top of a white-noise kernel.
        regressor = GaussianProcessRegressor(
            kernel=WhiteKernel(), n_restarts_optimizer=10, alpha=1e-2
        )
        regressor.fit(features, targets)

        # Convert while asking for the standard deviation output.
        input_types = [("float_input", FloatTensorType([None, 1]))]
        converted = convert_sklearn(
            regressor,
            initial_types=input_types,
            options={GaussianProcessRegressor: {"return_std": True}},
        )
        self.assertIsNotNone(converted)
38 |
39 |
40 | if __name__ == "__main__":
41 | unittest.main(verbosity=2)
42 |
--------------------------------------------------------------------------------
/tests/test_onnx_rare_helper.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests on functions in *onnx_rare_helper*.
5 | """
6 |
7 | import unittest
8 | from sklearn.datasets import load_iris
9 | from sklearn.cluster import KMeans
10 | from sklearn.neighbors import NearestNeighbors
11 | from onnx.defs import onnx_opset_version
12 | from skl2onnx import convert_sklearn
13 | from skl2onnx.common.data_types import FloatTensorType
14 | from skl2onnx.helpers.onnx_rare_helper import upgrade_opset_number
15 | from test_utils import TARGET_OPSET
16 |
17 |
class TestOnnxRareHelper(unittest.TestCase):
    """Exercises ``upgrade_opset_number`` on converted models."""

    def test_kmeans_upgrade(self):
        """A KMeans model converted for opset 7 can be upgraded to opset 8."""
        features = load_iris().data
        estimator = KMeans(n_clusters=3)
        estimator.fit(features)
        converted = convert_sklearn(
            estimator,
            "kmeans",
            [("input", FloatTensorType([None, 4]))],
            target_opset=7,
        )
        upgraded = upgrade_opset_number(converted, 8)
        assert "version: 8" in str(upgraded)

    @unittest.skipIf(onnx_opset_version() < 11, reason="Needs opset >= 11")
    def test_knn_upgrade(self):
        """Asking for a lower opset raises; a failing upgrade explains itself."""
        dataset = load_iris()
        features, _ = dataset.data, dataset.target

        neighbors = NearestNeighbors(n_neighbors=3, radius=None)
        neighbors.fit(features)

        converted = convert_sklearn(
            neighbors,
            "up",
            [("input", FloatTensorType([None, 4]))],
            target_opset=9,
        )
        # Requesting opset 8 for an opset 9 model must raise RuntimeError.
        with self.assertRaises(RuntimeError):
            upgrade_opset_number(converted, 8)
        # Upgrading may raise, but only with the expected message.
        try:
            upgrade_opset_number(converted, TARGET_OPSET)
        except RuntimeError as exc:
            assert "was updated" in str(exc)
50 |
51 |
52 | if __name__ == "__main__":
53 | unittest.main()
54 |
--------------------------------------------------------------------------------
/tests/test_raw_name.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy
4 | import onnxruntime as rt
5 | from numpy.testing import assert_almost_equal
6 | from skl2onnx import convert_sklearn
7 | from skl2onnx.common.data_types import FloatTensorType
8 | from sklearn.datasets import load_iris
9 | from sklearn.linear_model import LogisticRegression
10 | from test_utils import TARGET_OPSET
11 |
12 |
class RawNameTest(unittest.TestCase):
    """Checks that unusual input names survive the conversion."""

    # Input names tried by the test, some holding special characters.
    _raw_names = (
        "float_input",
        "float_input--",
        "float_input(",
        "float_input)",
    )

    @staticmethod
    def _load_data():
        """Returns the first two iris features and the iris targets."""
        dataset = load_iris()
        return dataset.data[:, :2], dataset.target

    @staticmethod
    def _train_model(X, y):
        """Fits a default LogisticRegression on ``(X, y)``."""
        classifier = LogisticRegression()
        return classifier.fit(X, y)

    @staticmethod
    def _get_initial_types(X, raw_name):
        """Declares one float input named *raw_name* with ``X.shape[1]`` columns."""
        n_features = X.shape[1]
        return [(raw_name, FloatTensorType([None, n_features]))]

    @staticmethod
    def _predict(clr_onnx, X):
        """Runs the converted model and returns its first output (the label)."""
        session = rt.InferenceSession(
            clr_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        feed_name = session.get_inputs()[0].name
        fetch_name = session.get_outputs()[0].name
        feeds = {feed_name: X.astype(numpy.float32)}
        return session.run([fetch_name], feeds)[0]

    def test_raw_name(self):
        """
        Every raw input name must lead to a graph which compiles
        and which predicts the same labels as scikit-learn.
        """
        X, y = self._load_data()
        classifier = self._train_model(X, y)
        expected = classifier.predict(X)
        for raw_name in self._raw_names:
            with self.subTest(raw_name=raw_name):
                converted = convert_sklearn(
                    classifier,
                    initial_types=self._get_initial_types(X, raw_name),
                    target_opset=TARGET_OPSET,
                )
                got = self._predict(converted, X)
                assert_almost_equal(expected, got)
61 |
62 |
63 | if __name__ == "__main__":
64 | unittest.main()
65 |
--------------------------------------------------------------------------------
/tests/test_scikit_pandas.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests the conversion of a sklearn-pandas DataFrameMapper.
5 | """
6 |
7 | import unittest
8 | import pandas
9 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
10 |
11 | from skl2onnx.common.data_types import FloatTensorType
12 | from skl2onnx import convert_sklearn
13 |
14 |
def has_scikit_pandas():
    """Tells whether ``sklearn_pandas`` can be imported.

    :return: True if the import succeeds, False on ImportError
    """
    try:
        import sklearn_pandas  # noqa: F401
    except ImportError:
        return False
    else:
        return True
22 |
23 |
def dataframe_mapper_shape_calculator(operator):
    """Shape calculator registered for DataFrameMapper in the test below.

    :param operator: object exposing an ``inputs`` sequence
    :raises RuntimeError: when the operator has exactly one input
    """
    input_count = len(operator.inputs)
    if input_count != 1:
        return
    raise RuntimeError("DataFrameMapper has no associated parser.")
27 |
28 |
class TestOtherLibrariesInPipelineScikitPandas(unittest.TestCase):
    """Checks how the converter reacts to a sklearn-pandas DataFrameMapper."""

    @unittest.skipIf(not has_scikit_pandas(), reason="scikit-pandas not installed")
    def test_scikit_pandas(self):
        """A RuntimeError raised by the conversion must mention the missing parser."""
        from sklearn_pandas import DataFrameMapper

        frame = pandas.DataFrame(
            {
                "feat1": [1, 2, 3, 4, 5, 6],
                "feat2": [1.0, 2.0, 3.0, 2.0, 3.0, 4.0],
            }
        )
        steps = [
            (["feat1", "feat2"], StandardScaler()),
            (["feat1", "feat2"], MinMaxScaler()),
        ]
        mapper = DataFrameMapper(steps)

        input_types = [("input", FloatTensorType([None, frame.shape[1]]))]
        shape_calculators = {DataFrameMapper: dataframe_mapper_shape_calculator}
        try:
            convert_sklearn(
                mapper,
                "predictable_tsne",
                input_types,
                custom_shape_calculators=shape_calculators,
            )
        except RuntimeError as exc:
            assert "DataFrameMapper has no associated parser." in str(exc)
59 |
60 |
61 | if __name__ == "__main__":
62 | unittest.main()
63 |
--------------------------------------------------------------------------------
/tests/test_sklearn_binarizer_converter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests scikit-learn's binarizer converter.
5 | """
6 |
7 | import unittest
8 | import numpy as np
9 | from sklearn.preprocessing import Binarizer
10 | from skl2onnx import convert_sklearn
11 | from skl2onnx.common.data_types import FloatTensorType
12 | from test_utils import dump_data_and_model, TARGET_OPSET
13 |
14 |
class TestSklearnBinarizer(unittest.TestCase):
    """Converts a Binarizer and compares it with scikit-learn."""

    def test_model_binarizer(self):
        """Binarizer(threshold=0.5) on a fixed 3x3 float32 matrix."""
        rows = [[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]
        matrix = np.array(rows, dtype=np.float32)
        binarizer = Binarizer(threshold=0.5)
        binarizer.fit(matrix)
        input_types = [("input", FloatTensorType(matrix.shape))]
        converted = convert_sklearn(
            binarizer,
            "scikit-learn binarizer",
            input_types,
            target_opset=TARGET_OPSET,
        )
        self.assertIsNotNone(converted)
        dump_data_and_model(
            matrix, binarizer, converted, basename="SklearnBinarizer-SkipDim1"
        )
32 |
33 |
34 | if __name__ == "__main__":
35 | unittest.main()
36 |
--------------------------------------------------------------------------------
/tests/test_sklearn_classifiers_extreme.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import numpy as np
5 |
6 | try:
7 | from onnx.reference import ReferenceEvaluator
8 | except ImportError:
9 | ReferenceEvaluator = None
10 | from sklearn.tree import DecisionTreeClassifier
11 | from onnxruntime import InferenceSession
12 | from skl2onnx import to_onnx
13 | from test_utils import TARGET_OPSET
14 |
15 |
class TestSklearnClassifiersExtreme(unittest.TestCase):
    """Corner cases for classifiers."""

    def test_one_training_class(self):
        """A tree trained on a single class gives identical outputs once converted."""
        x = np.eye(4, dtype=np.float32)
        y = np.array([5, 5, 5, 5], dtype=np.int64)

        tree = DecisionTreeClassifier().fit(x, y)
        expected = [tree.predict(x), tree.predict_proba(x)]
        onx = to_onnx(tree, x, target_opset=TARGET_OPSET, options={"zipmap": False})

        # The reference evaluator is only tried when onnx provides it.
        runtimes = []
        if ReferenceEvaluator is not None:
            runtimes.append(lambda proto: ReferenceEvaluator(proto, verbose=0))
        runtimes.append(
            lambda proto: InferenceSession(
                proto.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        )

        for make_session in runtimes:
            outputs = make_session(onx).run(None, {"X": x})
            self.assertEqual(len(outputs), len(expected))
            for want, got in zip(expected, outputs):
                self.assertEqual(want.tolist(), got.tolist())
44 |
45 |
46 | if __name__ == "__main__":
47 | unittest.main(verbosity=2)
48 |
--------------------------------------------------------------------------------
/tests/test_sklearn_constant_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """Tests scikit-learn's _ConstantPredictor converter."""
4 |
5 | import unittest
6 | import numpy as np
7 | from sklearn.multiclass import _ConstantPredictor
8 | from onnxruntime import __version__ as ort_version
9 | from skl2onnx import to_onnx
10 |
11 | from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType
12 |
13 | from test_utils import dump_data_and_model, TARGET_OPSET
14 |
# Keeps only the first two dot-separated components of the onnxruntime
# version string, e.g. "1.17.3" becomes "1.17".
ort_version = ".".join(ort_version.split(".")[:2])
16 |
17 |
class TestConstantPredictorConverter(unittest.TestCase):
    """Converts ``sklearn.multiclass._ConstantPredictor`` for float and double inputs."""

    def _check_constant_predictor(self, tensor_type, dtype, basename):
        """Fits, converts and validates a _ConstantPredictor.

        :param tensor_type: tensor type class used to declare the ONNX input
        :param dtype: numpy dtype the test data is cast to before the comparison
        :param basename: name handed to ``dump_data_and_model``
        """
        model = _ConstantPredictor()
        X = np.array([[1, 2]])
        y = np.array([0])
        model.fit(X, y)
        test_x = np.array([[1, 0], [2, 8]])

        # The model name is passed by keyword: the second positional
        # parameter of to_onnx is the sample data ``X``, not the name.
        model_onnx = to_onnx(
            model,
            name="scikit-learn ConstantPredictor",
            initial_types=[("input", tensor_type([None, X.shape[1]]))],
            target_opset=TARGET_OPSET,
            options={"zipmap": False},
        )

        # Checks the model itself; ``model_onnx is not None`` is a boolean
        # and would make assertIsNotNone always succeed.
        self.assertIsNotNone(model_onnx)
        dump_data_and_model(
            test_x.astype(dtype),
            model,
            model_onnx,
            basename=basename,
        )

    def test_constant_predictor_float(self):
        """Float32 input."""
        self._check_constant_predictor(
            FloatTensorType, np.float32, "SklearnConstantPredictorFloat"
        )

    def test_constant_predictor_double(self):
        """Float64 input."""
        self._check_constant_predictor(
            DoubleTensorType, np.float64, "SklearnConstantPredictorDouble"
        )
64 |
65 |
66 | if __name__ == "__main__":
67 | unittest.main(verbosity=3)
68 |
--------------------------------------------------------------------------------
/tests/test_sklearn_quantile_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests scikit-learn's QuantileTransformer converter.
5 | """
6 | import unittest
7 | import numpy as np
8 | from sklearn.preprocessing import QuantileTransformer
9 | from skl2onnx import convert_sklearn
10 | from skl2onnx.common.data_types import FloatTensorType
11 | from test_utils import dump_data_and_model, TARGET_OPSET
12 |
13 |
class TestSklearnQuantileTransformer(unittest.TestCase):
    """Converts QuantileTransformer and compares it with scikit-learn."""

    def _convert_and_dump(self, X, basename):
        """Fits ``QuantileTransformer(n_quantiles=6)`` on *X*, converts it and
        validates the converted model on *X* cast to float32.

        :param X: training data, also used as test data
        :param basename: name handed to ``dump_data_and_model``
        """
        model = QuantileTransformer(n_quantiles=6).fit(X)
        model_onnx = convert_sklearn(
            model,
            "test",
            [("input", FloatTensorType([None, X.shape[1]]))],
            target_opset=TARGET_OPSET,
        )
        dump_data_and_model(
            X.astype(np.float32),
            model,
            model_onnx,
            basename=basename,
        )

    def test_quantile_transformer_simple(self):
        """Two monotonic float32 columns."""
        X = np.empty((100, 2), dtype=np.float32)
        X[:, 0] = np.arange(X.shape[0])
        X[:, 1] = np.arange(X.shape[0]) * 2
        self._convert_and_dump(X, "SklearnQuantileTransformerSimple")

    def test_quantile_transformer_int(self):
        """Integer training data with many ties."""
        # A seeded generator keeps the test data identical between runs.
        rng = np.random.RandomState(0)
        X = rng.randint(0, 5, (100, 20))
        self._convert_and_dump(X, "SklearnQuantileTransformerInt")

    def test_quantile_transformer_nan(self):
        """Float data holding two missing values."""
        # A seeded generator keeps the test data identical between runs.
        rng = np.random.RandomState(0)
        X = rng.randint(0, 5, (100, 20))
        X = X.astype(np.float32)
        X[0][0] = np.nan
        X[1][1] = np.nan
        self._convert_and_dump(X, "SklearnQuantileTransformerNan")
67 |
68 |
69 | if __name__ == "__main__":
70 | unittest.main()
71 |
--------------------------------------------------------------------------------
/tests/test_sklearn_random_projection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import unittest
5 | import packaging.version as pv
6 | import numpy as np
7 | import onnxruntime
8 | from sklearn.random_projection import GaussianRandomProjection
9 | from skl2onnx import convert_sklearn, to_onnx
10 | from skl2onnx.common.data_types import FloatTensorType
11 | from test_utils import dump_data_and_model, TARGET_OPSET
12 |
# True when the installed onnxruntime is older than 0.5.0; used below to
# skip the tests.
nort = pv.Version(onnxruntime.__version__) < pv.Version("0.5.0")
14 |
15 |
class TestSklearnRandomProjection(unittest.TestCase):
    """Converts GaussianRandomProjection for float32 and float64 inputs."""

    @unittest.skipIf(TARGET_OPSET < 9 or nort, reason="MatMul not available")
    def test_gaussian_random_projection_float32(self):
        """Conversion through convert_sklearn with a float32 input."""
        generator = np.random.RandomState(42)
        projection = GaussianRandomProjection(n_components=4)
        X = generator.rand(10, 5)
        fitted = projection.fit(X)
        assert fitted.transform(X).shape[1] == 4
        input_types = [("inputs", FloatTensorType([None, X.shape[1]]))]
        converted = convert_sklearn(
            fitted,
            "scikit-learn GaussianRandomProjection",
            input_types,
            target_opset=TARGET_OPSET,
        )
        self.assertIsNotNone(converted)
        dump_data_and_model(
            X.astype(np.float32),
            fitted,
            converted,
            basename="GaussianRandomProjection",
        )

    @unittest.skipIf(TARGET_OPSET < 9 or nort, reason="MatMul not available")
    def test_gaussian_random_projection_float64(self):
        """Conversion through to_onnx with a float64 input."""
        generator = np.random.RandomState(42)
        projection = GaussianRandomProjection(n_components=4)
        X = generator.rand(10, 5).astype(np.float64)
        fitted = projection.fit(X)
        converted = to_onnx(fitted, X[:1], target_opset=TARGET_OPSET)
        self.assertIsNotNone(converted)
        dump_data_and_model(
            X, fitted, converted, basename="GaussianRandomProjection64"
        )
44 |
45 |
46 | if __name__ == "__main__":
47 | unittest.main()
48 |
--------------------------------------------------------------------------------
/tests/test_sklearn_random_trees_embedding.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import numpy
5 | from onnxruntime import InferenceSession
6 |
7 | try:
8 | # scikit-learn >= 0.22
9 | from sklearn.utils._testing import ignore_warnings
10 | except ImportError:
11 | # scikit-learn < 0.22
12 | from sklearn.utils.testing import ignore_warnings
13 | from sklearn.exceptions import ConvergenceWarning
14 | from sklearn.datasets import make_regression
15 | from sklearn.ensemble import RandomTreesEmbedding
16 | from skl2onnx import to_onnx
17 | from test_utils import TARGET_OPSET, dump_data_and_model
18 |
19 |
class TestSklearnRandomTreeEmbeddings(unittest.TestCase):
    """Converts RandomTreesEmbedding and compares it with scikit-learn."""

    def check_model(self, model, X, name="X"):
        """Loads *model* with onnxruntime and runs it on the first 7 rows of *X*.

        :param model: ONNX model exposing ``SerializeToString``
        :param X: input data
        :param name: name of the model input
        :return: outputs returned by onnxruntime
        :raises AssertionError: if the model cannot be loaded or run, the
            message includes the model itself
        """
        try:
            sess = InferenceSession(
                model.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        except Exception as e:
            raise AssertionError("Unable to load model\n%s" % str(model)) from e
        try:
            return sess.run(None, {name: X[:7]})
        except Exception as e:
            raise AssertionError(
                "Unable to run model X.shape=%r X.dtype=%r\n%s"
                % (X[:7].shape, X.dtype, str(model))
            ) from e

    @ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
    def test_random_trees_embedding(self):
        """Dense RandomTreesEmbedding on a small regression dataset."""
        X, _ = make_regression(
            n_features=5, n_samples=100, n_targets=1, random_state=42, n_informative=3
        )
        X = X.astype(numpy.float32)

        model = RandomTreesEmbedding(
            n_estimators=3, max_depth=2, sparse_output=False
        ).fit(X)
        model.transform(X)
        model_onnx = to_onnx(model, X[:1], target_opset=TARGET_OPSET)
        # The model is no longer written to "model.onnx" in the current
        # directory: check_model already serializes it and the file was
        # never read back nor removed.
        self.check_model(model_onnx, X)
        dump_data_and_model(
            X.astype(numpy.float32),
            model,
            model_onnx,
            basename="SklearnRandomTreesEmbedding",
        )
57 |
58 |
59 | if __name__ == "__main__":
60 | unittest.main()
61 |
--------------------------------------------------------------------------------
/tests/test_sklearn_replace_transformer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests the converter of skl2onnx's ReplaceTransformer.
5 | """
6 |
7 | import unittest
8 | import numpy
9 | from sklearn.pipeline import Pipeline
10 |
11 | try:
12 | from sklearn.compose import ColumnTransformer
13 | except ImportError:
14 | ColumnTransformer = None
15 | from skl2onnx.sklapi import ReplaceTransformer
16 | from skl2onnx import convert_sklearn
17 | from skl2onnx.common.data_types import FloatTensorType
18 | from test_utils import dump_data_and_model, TARGET_OPSET
19 |
20 |
class TestSklearnCastTransformerConverter(unittest.TestCase):
    """Converts a pipeline holding a ReplaceTransformer."""

    def common_test_replace_transformer(self, dtype, input_type):
        """Fits, converts and validates a one-step ReplaceTransformer pipeline.

        :param dtype: numpy dtype used for the data and for the transformer
        :param input_type: tensor type class used to declare the ONNX input
        """
        model = Pipeline(
            [
                ("replace", ReplaceTransformer(dtype=dtype)),
            ]
        )
        data = numpy.array(
            [[0.1, 0.2, 3.1], [1, 1, 0], [0, 2, 1], [1, 0, 2]], dtype=dtype
        )
        model.fit(data)
        pred = model.steps[0][1].transform(data)
        self.assertEqual(pred.dtype, dtype)
        model_onnx = convert_sklearn(
            model,
            "cast",
            [("input", input_type([None, 3]))],
            target_opset=TARGET_OPSET,
        )
        self.assertIsNotNone(model_onnx)
        dump_data_and_model(
            data,
            model,
            model_onnx,
            # input_type is a class: its own name is wanted here, whereas
            # ``input_type.__class__.__name__`` is always "type".
            basename="SklearnCastTransformer{}".format(input_type.__name__),
        )

    @unittest.skipIf(TARGET_OPSET < 11, reason="not supported")
    def test_replace_transformer(self):
        """Float32 data declared as a FloatTensorType input."""
        self.common_test_replace_transformer(numpy.float32, FloatTensorType)
51 |
52 |
53 | if __name__ == "__main__":
54 | unittest.main()
55 |
--------------------------------------------------------------------------------
/tests/test_sklearn_sgd_oneclass_svm_converter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """Tests scikit-learn's SGDOneClassSVM converter."""
4 |
5 | import unittest
6 | import numpy as np
7 |
8 | try:
9 | from sklearn.linear_model import SGDOneClassSVM
10 | except ImportError:
11 | SGDOneClassSVM = None
12 | from onnxruntime import __version__ as ort_version
13 | from skl2onnx import convert_sklearn
14 |
15 | from skl2onnx.common.data_types import (
16 | FloatTensorType,
17 | )
18 |
19 | from test_utils import dump_data_and_model, TARGET_OPSET
20 |
# Keeps only the first two dot-separated components of the onnxruntime
# version string, e.g. "1.17.3" becomes "1.17".
ort_version = ".".join(ort_version.split(".")[:2])
22 |
23 |
class TestSGDOneClassSVMConverter(unittest.TestCase):
    """Converts SGDOneClassSVM and compares it with scikit-learn."""

    @unittest.skipIf(SGDOneClassSVM is None, reason="scikit-learn<1.0")
    def test_model_sgd_oneclass_svm(self):
        """SGDOneClassSVM fitted on four 2D points."""
        train = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
        estimator = SGDOneClassSVM(random_state=42)
        estimator.fit(train)
        samples = np.array([[0, 0], [-1, -1], [1, 1]]).astype(np.float32)
        # Checks scikit-learn itself is able to predict on the test data.
        estimator.predict(samples)

        input_types = [("input", FloatTensorType([None, train.shape[1]]))]
        converted = convert_sklearn(
            estimator,
            "scikit-learn SGD OneClass SVM",
            input_types,
            target_opset=TARGET_OPSET,
        )

        self.assertIsNotNone(converted)
        dump_data_and_model(
            samples.astype(np.float32),
            estimator,
            converted,
            basename="SklearnSGDOneClassSVMBinaryHinge",
        )
47 |
48 |
49 | if __name__ == "__main__":
50 | unittest.main(verbosity=3)
51 |
--------------------------------------------------------------------------------
/tests/test_sklearn_tfidf_transformer_converter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # coding: utf-8
4 | """
5 | Tests scikit-learn's TfidfTransformer converter.
6 | """
7 |
8 | import unittest
9 | import numpy
10 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
11 | from skl2onnx import convert_sklearn
12 | from skl2onnx.common.data_types import FloatTensorType
13 | from test_utils import dump_data_and_model, TARGET_OPSET
14 |
15 |
class TestSklearnTfidfTransformerConverter(unittest.TestCase):
    """Converts TfidfTransformer for several parameter combinations."""

    def test_model_tfidf_transform(self):
        """Loops over norm, smooth_idf and use_idf with sublinear_tf=False."""
        documents = [
            "This is the first document.",
            "This document is the second document.",
            "And this is the third one.",
            "Is this the first document?",
            "Troisième document en français",
        ]
        corpus = numpy.array(documents).reshape((5, 1))
        vectorizer = CountVectorizer(ngram_range=(1, 1))
        counts = vectorizer.fit_transform(corpus.ravel()).todense()
        data = numpy.array(counts.astype(numpy.float32))

        for sublinear_tf in (False, True):
            if sublinear_tf:
                # scikit-learn applies a log on a matrix
                # but only on strictly positive coefficients
                break
            for norm in (None, "l1", "l2"):
                for smooth_idf in (False, True):
                    for use_idf in (False, True):
                        transformer = TfidfTransformer(
                            norm=norm,
                            use_idf=use_idf,
                            smooth_idf=smooth_idf,
                            sublinear_tf=sublinear_tf,
                        )
                        transformer.fit(data)
                        input_types = [
                            ("input", FloatTensorType([None, data.shape[1]]))
                        ]
                        converted = convert_sklearn(
                            transformer,
                            "TfidfTransformer",
                            input_types,
                            target_opset=TARGET_OPSET,
                        )
                        self.assertIsNotNone(converted)
                        # The basename encodes the parameter combination.
                        suffix = "".join(
                            [
                                norm.upper() if norm else "",
                                "Sub" if sublinear_tf else "",
                                "Idf" if use_idf else "",
                                "Smooth" if smooth_idf else "",
                            ]
                        )
                        dump_data_and_model(
                            data,
                            transformer,
                            converted,
                            basename="SklearnTfidfTransform" + suffix,
                        )
64 |
65 |
66 | if __name__ == "__main__":
67 | unittest.main()
68 |
--------------------------------------------------------------------------------
/tests/test_sklearn_tfidf_transformer_converter_sparse.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | """
3 | Tests a CountVectorizer + TfidfTransformer pipeline on the 20 newsgroups dataset.
4 | """
5 |
6 | import packaging.version as pv
7 | import unittest
8 | import urllib.error
9 | import sys
10 | import onnx
11 | from sklearn.datasets import fetch_20newsgroups
12 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
13 | from sklearn.pipeline import Pipeline
14 | import onnxruntime as ort
15 | from skl2onnx.common.data_types import StringTensorType
16 | from skl2onnx import convert_sklearn
17 | from test_utils import dump_data_and_model, TARGET_OPSET
18 |
# Backend string given to dump_data_and_model: only "onnxruntime" when the
# installed onnx is older than 1.16.0, otherwise "onnx;onnxruntime".
BACKEND = (
    "onnxruntime"
    if pv.Version(onnx.__version__) < pv.Version("1.16.0")
    else "onnx;onnxruntime"
)
24 |
25 |
class TestSklearnTfidfVectorizerSparse(unittest.TestCase):
    """Converts a text pipeline trained on the 20 newsgroups dataset."""

    @unittest.skipIf(
        TARGET_OPSET < 9,
        # issue with encoding
        reason="https://github.com/onnx/onnx/pull/1734",
    )
    @unittest.skipIf(TARGET_OPSET < 18, reason="too long")
    @unittest.skipIf(
        pv.Version(ort.__version__) <= pv.Version("0.2.1"),
        reason="sparse not supported",
    )
    @unittest.skipIf(sys.platform != "linux", reason="too long")
    def test_model_tfidf_transform_bug(self):
        """CountVectorizer followed by TfidfTransformer on four categories."""
        selected = [
            "alt.atheism",
            "soc.religion.christian",
            "comp.graphics",
            "sci.med",
        ]
        try:
            newsgroups = fetch_20newsgroups(
                subset="train", categories=selected, shuffle=True, random_state=0
            )
        except urllib.error.HTTPError as e:
            raise unittest.SkipTest(f"HTTP fails due to {e}")
        pipeline = Pipeline(
            [("vect", CountVectorizer()), ("tfidf", TfidfTransformer())]
        )
        # Prepends a word holding a non-ASCII character to the first document.
        newsgroups.data[0] = "bruît " + newsgroups.data[0]
        pipeline.fit(newsgroups.data, newsgroups.target)
        converted = convert_sklearn(
            pipeline,
            name="DocClassifierCV-Tfidf",
            initial_types=[("input", StringTensorType([5]))],
            target_opset=TARGET_OPSET,
        )
        dump_data_and_model(
            newsgroups.data[5:10],
            pipeline,
            converted,
            basename="SklearnPipelineTfidfTransformer",
            backend=BACKEND,
        )
69 |
70 |
71 | if __name__ == "__main__":
72 | unittest.main()
73 |
--------------------------------------------------------------------------------
/tests/test_sklearn_tfidf_vectorizer_converter_dataset.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Tests scikit-learn's tfidf converter using downloaded data.
5 | """
6 |
7 | import unittest
8 | import urllib.error
9 | import packaging.version as pv
10 | import numpy as np
11 | import onnx
12 | from sklearn.model_selection import train_test_split
13 | from sklearn.feature_extraction.text import TfidfVectorizer
14 | from sklearn.datasets import fetch_20newsgroups
15 | from skl2onnx import convert_sklearn
16 | from skl2onnx.common.data_types import StringTensorType
17 | from test_utils import dump_data_and_model, TARGET_OPSET
18 |
# Backend string given to dump_data_and_model: only "onnxruntime" when the
# installed onnx is older than 1.16.0, otherwise "onnx;onnxruntime".
BACKEND = (
    "onnxruntime"
    if pv.Version(onnx.__version__) < pv.Version("1.16.0")
    else "onnx;onnxruntime"
)
24 |
25 |
class TestSklearnTfidfVectorizerDataSet(unittest.TestCase):
    """Converts TfidfVectorizer trained on downloaded 20 newsgroups data."""

    def _check_20newsgroups(self, vectorizer, basename):
        """Fits *vectorizer* on half of the first 100 documents, converts it
        and validates the converted model on the other half.

        The test is skipped when the download fails with an HTTP error.
        """
        try:
            dataset = fetch_20newsgroups()
        except urllib.error.HTTPError as e:
            raise unittest.SkipTest(f"HTTP fails due to {e}")
        X = np.array(dataset.data)[:100]
        y = np.array(dataset.target)[:100]
        X_train, X_test, _, _ = train_test_split(
            X, y, test_size=0.5, random_state=42
        )

        fitted = vectorizer.fit(X_train)
        converted = convert_sklearn(
            fitted,
            "cv",
            [("input", StringTensorType(X_test.shape))],
            target_opset=TARGET_OPSET,
        )
        dump_data_and_model(
            X_test,
            fitted,
            converted,
            basename=basename,
            backend=BACKEND,
        )

    @unittest.skipIf(TARGET_OPSET < 9, reason="not available")
    @unittest.skipIf(TARGET_OPSET < 18, reason="too long")
    def test_tfidf_20newsgroups(self):
        """Default TfidfVectorizer."""
        self._check_20newsgroups(
            TfidfVectorizer(), "SklearnTfidfVectorizer20newsgroups"
        )

    @unittest.skipIf(TARGET_OPSET < 9, reason="not available")
    @unittest.skipIf(TARGET_OPSET < 18, reason="too long")
    def test_tfidf_20newsgroups_nolowercase(self):
        """TfidfVectorizer with lowercase=False."""
        self._check_20newsgroups(
            TfidfVectorizer(lowercase=False),
            "SklearnTfidfVectorizer20newsgroupsNOLower",
        )
80 |
81 |
# Runs the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
84 |
--------------------------------------------------------------------------------
/tests/test_sklearn_truncated_svd.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | import unittest
5 |
6 | import numpy as np
7 | from sklearn.decomposition import TruncatedSVD
8 |
9 | from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
10 | from skl2onnx import convert_sklearn
11 | from test_utils import create_tensor
12 | from test_utils import dump_data_and_model, TARGET_OPSET
13 |
14 |
class TestTruncatedSVD(unittest.TestCase):
    """Conversion tests for *TruncatedSVD*."""

    def setUp(self):
        # The numpy random generator is reseeded before every test.
        np.random.seed(0)

    def _convert_and_dump(self, fitted, data, input_type, basename):
        """
        Converts a fitted *TruncatedSVD*, checks a model was produced
        and passes everything to *dump_data_and_model*.

        :param fitted: fitted *TruncatedSVD*
        :param data: array given to *dump_data_and_model*
        :param input_type: tensor type declared for the input named "input"
        :param basename: value of parameter *basename* given to
            *dump_data_and_model*
        """
        converted = convert_sklearn(
            fitted,
            initial_types=[("input", input_type)],
            target_opset=TARGET_OPSET,
        )
        self.assertIsNotNone(converted)
        dump_data_and_model(data, fitted, converted, basename=basename)

    def test_truncated_svd(self):
        n_rows, n_cols, n_components = 2, 3, 2
        data = create_tensor(n_rows, n_cols)

        fitted = TruncatedSVD(n_components=n_components)
        fitted.fit(data)
        self._convert_and_dump(
            fitted,
            data,
            FloatTensorType(shape=[None, n_cols]),
            "SklearnTruncatedSVD",
        )

    def test_truncated_svd_arpack(self):
        data = create_tensor(10, 10)
        fitted = TruncatedSVD(
            n_components=5, algorithm="arpack", n_iter=10, tol=0.1, random_state=42
        )
        fitted = fitted.fit(data)
        # The full shape of the data is declared here, first dimension included.
        self._convert_and_dump(
            fitted,
            data,
            FloatTensorType(shape=data.shape),
            "SklearnTruncatedSVDArpack",
        )

    def test_truncated_svd_int(self):
        data = create_tensor(5, 5).astype(np.int64)
        fitted = TruncatedSVD(n_iter=20, random_state=42).fit(data)
        self._convert_and_dump(
            fitted,
            data,
            Int64TensorType([None, data.shape[1]]),
            "SklearnTruncatedSVDInt",
        )
56 |
57 |
# Runs the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
60 |
--------------------------------------------------------------------------------
/tests/test_sklearn_tuned_threshold_classifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import unittest
4 | import numpy as np
5 | from sklearn.datasets import make_classification
6 | from sklearn.ensemble import RandomForestClassifier
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.utils._testing import ignore_warnings
9 | from skl2onnx import to_onnx
10 | from skl2onnx.common.data_types import FloatTensorType
11 | from test_utils import dump_data_and_model, TARGET_OPSET
12 |
13 |
def has_tuned_theshold_classifier():
    """
    Tells whether ``TunedThresholdClassifierCV`` can be imported
    from ``sklearn.model_selection``.

    :return: True if the import succeeds, False if it raises ImportError
    """
    try:
        from sklearn.model_selection import TunedThresholdClassifierCV  # noqa: F401
    except ImportError:
        available = False
    else:
        available = True
    return available
20 |
21 |
class TestSklearnTunedThresholdClassifierConverter(unittest.TestCase):
    """Conversion test for ``TunedThresholdClassifierCV``."""

    @unittest.skipIf(
        not has_tuned_theshold_classifier(),
        reason="TunedThresholdClassifierCV not available",
    )
    @ignore_warnings(category=FutureWarning)
    def test_tuned_threshold_classifier(self):
        # Imported here because the class may be missing from the installed
        # scikit-learn, in which case the test is skipped.
        from sklearn.model_selection import TunedThresholdClassifierCV

        features, labels = make_classification(
            n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42
        )
        X_train, X_test, y_train, _ = train_test_split(
            features, labels, stratify=labels, random_state=42
        )

        tuned = TunedThresholdClassifierCV(
            RandomForestClassifier(random_state=0), scoring="balanced_accuracy"
        )
        tuned = tuned.fit(X_train, y_train)

        n_features = X_train.shape[1]
        # The model is converted for the opset just below TARGET_OPSET and
        # without the ZipMap operator.
        converted = to_onnx(
            tuned,
            initial_types=[("X", FloatTensorType([None, n_features]))],
            target_opset=TARGET_OPSET - 1,
            options={"zipmap": False},
        )
        self.assertIsNotNone(converted)

        # Only the first ten test rows, cast to float32, are compared.
        sample = X_test[:10].astype(np.float32)
        dump_data_and_model(
            sample,
            tuned,
            converted,
            basename="SklearnTunedThresholdClassifier",
        )
56 |
57 |
# Runs the tests of this module, with verbose output, when the file is
# executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
60 |
--------------------------------------------------------------------------------
/tests/test_sklearn_voting_regressor_converter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """Tests VotingRegressor converter."""
4 |
5 | import unittest
6 | import numpy
7 | from sklearn.linear_model import LinearRegression
8 |
9 | try:
10 | from sklearn.ensemble import VotingRegressor
11 | except ImportError:
12 | # New in 0.21
13 | VotingRegressor = None
14 | from sklearn.tree import DecisionTreeRegressor
15 | from skl2onnx import convert_sklearn
16 | from skl2onnx.common.data_types import (
17 | BooleanTensorType,
18 | FloatTensorType,
19 | Int64TensorType,
20 | )
21 | from test_utils import dump_data_and_model, fit_regression_model, TARGET_OPSET
22 |
23 |
def model_to_test():
    """
    Builds the unfitted model used by the tests of this file.

    :return: a *VotingRegressor* combining a *LinearRegression* ("lr")
        and a *DecisionTreeRegressor* ("dt")
    """
    estimators = [
        ("lr", LinearRegression()),
        ("dt", DecisionTreeRegressor()),
    ]
    return VotingRegressor(estimators)
31 |
32 |
class TestVotingRegressorConverter(unittest.TestCase):
    """Conversion tests for *VotingRegressor* with several input types."""

    def _check_voting_regression(
        self, tensor_type, basename, as_float32=False, **fit_options
    ):
        """
        Fits the model returned by *model_to_test*, converts it and passes
        the result to *dump_data_and_model*, comparing output 0 only.

        :param tensor_type: tensor type class declared for the input
        :param basename: value of parameter *basename* given to
            *dump_data_and_model*
        :param as_float32: if True, the data is cast to float32 before
            being given to *dump_data_and_model*
        :param fit_options: keyword arguments forwarded to
            *fit_regression_model*
        """
        fitted, data = fit_regression_model(model_to_test(), **fit_options)
        converted = convert_sklearn(
            fitted,
            "voting regression",
            [("input", tensor_type([None, data.shape[1]]))],
            target_opset=TARGET_OPSET,
        )
        self.assertIsNotNone(converted)
        inputs = data.astype(numpy.float32) if as_float32 else data
        dump_data_and_model(
            inputs,
            fitted,
            converted,
            basename=basename,
            comparable_outputs=[0],
        )

    @unittest.skipIf(VotingRegressor is None, reason="new in 0.21")
    def test_model_voting_regression(self):
        self._check_voting_regression(
            FloatTensorType, "SklearnVotingRegressor-Dec4", as_float32=True
        )

    @unittest.skipIf(VotingRegressor is None, reason="new in 0.21")
    def test_model_voting_regression_int(self):
        self._check_voting_regression(
            Int64TensorType, "SklearnVotingRegressorInt-Dec4", is_int=True
        )

    @unittest.skipIf(VotingRegressor is None, reason="new in 0.21")
    def test_model_voting_regression_bool(self):
        self._check_voting_regression(
            BooleanTensorType, "SklearnVotingRegressorBool", is_bool=True
        )
87 |
88 |
# Runs the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
91 |
--------------------------------------------------------------------------------
/tests/test_utils/main.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | from skl2onnx.proto import onnx_proto
5 | from skl2onnx.common import utils as convert_utils
6 |
7 |
def set_model_domain(model, domain):
    """
    Stores *domain* in field ``domain`` of an ONNX model.

    :param model: instance of an ONNX model
    :param domain: string containing the domain name of the model
    :raises ValueError: if *model* is not a ``ModelProto`` or if
        *domain* is rejected by ``is_string_type``

    Example:

    ::

        from test_utils import set_model_domain
        onnx_model = load_model("SqueezeNet.onnx")
        set_model_domain(onnx_model, "com.acme")
    """
    is_onnx_model = model is not None and isinstance(model, onnx_proto.ModelProto)
    if not is_onnx_model:
        raise ValueError("Parameter model is not an onnx model.")
    if convert_utils.is_string_type(domain):
        model.domain = domain
        return
    raise ValueError("Parameter domain must be a string type.")
27 |
28 |
def set_model_version(model, version):
    """
    Stores *version* in field ``model_version`` of an ONNX model.

    :param model: instance of an ONNX model
    :param version: integer containing the version of the model
    :raises ValueError: if *model* is not a ``ModelProto`` or if
        *version* is rejected by ``is_numeric_type``

    Example:

    ::

        from test_utils import set_model_version
        onnx_model = load_model("SqueezeNet.onnx")
        set_model_version(onnx_model, 1)
    """
    is_onnx_model = model is not None and isinstance(model, onnx_proto.ModelProto)
    if not is_onnx_model:
        raise ValueError("Parameter model is not an onnx model.")
    if convert_utils.is_numeric_type(version):
        model.model_version = version
        return
    raise ValueError("Parameter version must be a numeric type.")
48 |
49 |
def set_model_doc_string(model, doc, override=False):
    """
    Stores *doc* in field ``doc_string`` of an ONNX model.

    :param model: instance of an ONNX model
    :param doc: string containing the doc string that describes the model.
    :param override: unless it is the value ``False``, an existing doc
        string may be replaced by an empty one
    :raises ValueError: if *model* is not a ``ModelProto``, if *doc* is
        rejected by ``is_string_type``, or if a non-empty doc string
        would be replaced by an empty one while *override* is False

    Example:

    ::

        from test_utils import set_model_doc_string
        onnx_model = load_model("SqueezeNet.onnx")
        set_model_doc_string(onnx_model, "Sample doc string")
    """
    is_onnx_model = model is not None and isinstance(model, onnx_proto.ModelProto)
    if not is_onnx_model:
        raise ValueError("Parameter model is not an onnx model.")
    if not convert_utils.is_string_type(doc):
        raise ValueError("Parameter doc must be a string type.")

    # Erasing an existing doc string needs an explicit opt-in.
    erases_existing = bool(model.doc_string) and not doc
    if erases_existing and override is False:
        raise ValueError(
            "Failed to overwrite the doc string with a blank string,"
            " set override to True if intentional."
        )
    model.doc_string = doc
76 |
--------------------------------------------------------------------------------
/tests/test_utils/reference_implementation_afe.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | """
3 | Helpers to test runtimes.
4 | """
5 |
6 | from onnx.defs import onnx_opset_version
7 |
8 |
9 | def _array_feature_extrator(data, indices):
10 | """
11 | Implementation of operator *ArrayFeatureExtractor*
12 | with :epkg:`numpy`.
13 | """
14 | if len(indices.shape) == 2 and indices.shape[0] == 1:
15 | index = indices.ravel().tolist()
16 | add = len(index)
17 | elif len(indices.shape) == 1:
18 | index = indices.tolist()
19 | add = len(index)
20 | else:
21 | add = 1
22 | for s in indices.shape:
23 | add *= s
24 | index = indices.ravel().tolist()
25 | new_shape = (1, add) if len(data.shape) == 1 else list(data.shape[:-1]) + [add]
26 | try:
27 | tem = data[..., index]
28 | except IndexError as e:
29 | raise RuntimeError(f"data.shape={data.shape}, indices={indices}") from e
30 | res = tem.reshape(new_shape)
31 | return res
32 |
33 |
# The reference operator is only defined when the installed onnx package
# reports an opset version of at least 18.
if onnx_opset_version() >= 18:
    from onnx.reference.op_run import OpRun

    class ArrayFeatureExtractor(OpRun):
        """Operator *ArrayFeatureExtractor* implemented as an ``OpRun``."""

        # Domain of the operator.
        op_domain = "ai.onnx.ml"

        def _run(self, data, indices):
            """
            Runtime for operator *ArrayFeatureExtractor*.

            :param data: array the values are extracted from
            :param indices: positions to extract
            :return: tuple with one element, the extracted array

            .. warning::
                ONNX specifications may be imprecise in some cases.
                When the input data is a vector (one dimension),
                the output still has two dimensions, like a matrix
                with one row.
                The implementation follows what :epkg:`onnxruntime` does in
                ``array_feature_extractor.cc``.
            """
            res = _array_feature_extrator(data, indices)
            return (res,)
55 |
--------------------------------------------------------------------------------