├── python ├── packager │ ├── __init__.py │ ├── sdist.py │ ├── util.py │ └── build_config.py ├── treelite │ ├── py.typed │ ├── VERSION │ ├── gtil │ │ └── __init__.py │ ├── path_config.py │ ├── __init__.py │ ├── sklearn │ │ ├── __init__.py │ │ └── isolation_forest.py │ ├── core.py │ ├── libpath.py │ └── util.py ├── README.rst ├── hatch_build.py ├── .pylintrc └── pyproject.toml ├── tests ├── python │ ├── __init__.py │ └── metadata.py ├── examples │ ├── mushroom │ │ └── mushroom.model │ ├── sparse_categorical │ │ └── sparse_categorical.test.margin │ ├── deep_lightgbm │ │ └── model.txt │ └── toy_categorical │ │ └── toy_categorical.test.pred ├── ci_build │ ├── build_via_cmake.sh │ ├── Dockerfile.ubuntu20_amd64 │ ├── Dockerfile.ubuntu20_aarch64 │ ├── build_macos_python_wheels.sh │ ├── entrypoint.sh │ └── rename_whl.py ├── cpp │ ├── test_main.cc │ ├── CMakeLists.txt │ ├── test_model_loader.cc │ └── test_utils.cc ├── example_app │ ├── CMakeLists.txt │ ├── example.cc │ └── example.c └── serializer │ ├── compatibility_tester.py │ └── test_serializer.py ├── docs ├── .gitignore ├── _static │ ├── deployment.png │ └── custom.css ├── knobs │ ├── index.rst │ └── postprocessor.rst ├── _templates │ └── layout.html ├── requirements.txt ├── tutorials │ ├── index.rst │ └── import.rst ├── treelite-doxygen.rst ├── Makefile ├── treelite-gtil-api.rst ├── make.bat ├── treelite-api.rst ├── serialization │ └── index.rst ├── treelite-c-api.rst ├── index.rst └── install.rst ├── .isort.cfg ├── ops ├── conda_env │ ├── pre-commit.yml │ └── dev.yml ├── test-linux-python-wheel.sh ├── test-macos-python-wheel.sh ├── test-cmake-import.sh ├── build-macos.sh ├── build-linux.sh ├── build-linux-aarch64.sh ├── test-sdist.sh ├── build-windows.bat ├── macos-python-coverage.sh ├── test-win-python-wheel.bat ├── build-cpack.sh ├── win-python-coverage.bat ├── cpp-python-coverage.sh └── test-serializer-compatibility.sh ├── cmake ├── Version.cmake ├── TreeliteConfig.cmake.in ├── version.h.in ├── Doxygen.cmake ├── 
Utils.cmake ├── Sanitizer.cmake └── ExternalLibs.cmake ├── .gitattributes ├── Makefile ├── include └── treelite │ ├── error.h │ ├── pybuffer_frame.h │ ├── thread_local.h │ ├── enum │ ├── tree_node_type.h │ ├── task_type.h │ ├── operator.h │ └── typeinfo.h │ ├── detail │ ├── omp_exception.h │ ├── file_utils.h │ └── serializer_mixins.h │ ├── c_api_error.h │ ├── contiguous_array.h │ └── gtil.h ├── ACKNOWLEDGMENTS.md ├── src ├── logging.cc ├── c_api │ ├── logging.cc │ ├── c_api_utils.h │ ├── c_api_error.cc │ ├── field_accessor.cc │ ├── model.cc │ ├── serializer.cc │ ├── gtil.cc │ └── model_loader.cc ├── model_loader │ ├── detail │ │ ├── string_utils.h │ │ ├── xgboost.h │ │ ├── xgboost_json │ │ │ ├── sax_adapters.h │ │ │ └── sax_adapters.cc │ │ ├── lightgbm.h │ │ └── xgboost.cc │ └── xgboost_ubjson.cc ├── gtil │ ├── postprocessor.h │ ├── output_shape.cc │ ├── config.cc │ └── postprocessor.cc ├── enum │ ├── typeinfo.cc │ ├── tree_node_type.cc │ ├── operator.cc │ └── task_type.cc ├── model_query.cc ├── model_builder │ └── metadata.cc └── model_concat.cc ├── .readthedocs.yaml ├── dev ├── run_pylint.py └── change_version.py ├── .flake8 ├── .github └── workflows │ ├── pre-commit.yml │ ├── windows-wheel-builder.yml │ ├── cpack-builder.yml │ ├── pre-commit-update.yml │ ├── linux-wheel-builder.yml │ ├── macos-wheel-builder.yml │ ├── coverage-tests.yml │ └── misc-tests.yml ├── README.md ├── .gitignore ├── .pre-commit-config.yaml └── .clang-format /python/packager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/treelite/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/python/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /python/treelite/VERSION: -------------------------------------------------------------------------------- 1 | 4.7.0.dev0 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | _static 3 | doxyxml 4 | tmp 5 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | known_first_party=treelite 4 | -------------------------------------------------------------------------------- /docs/_static/deployment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmlc/treelite/HEAD/docs/_static/deployment.png -------------------------------------------------------------------------------- /tests/examples/mushroom/mushroom.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmlc/treelite/HEAD/tests/examples/mushroom/mushroom.model -------------------------------------------------------------------------------- /ops/conda_env/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: precommit 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - pre-commit 7 | -------------------------------------------------------------------------------- /tests/examples/sparse_categorical/sparse_categorical.test.margin: -------------------------------------------------------------------------------- 1 | -0.036716360109071977 2 | -0.088410677452890038 3 | 0.21990291345117738 4 | -------------------------------------------------------------------------------- /docs/knobs/index.rst: 
-------------------------------------------------------------------------------- 1 | ==================== 2 | Knobs and Parameters 3 | ==================== 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | postprocessor 9 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {%- block extrahead %} 4 | 5 | {% endblock %} 6 | -------------------------------------------------------------------------------- /tests/ci_build/build_via_cmake.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | set -x 4 | 5 | rm -rf build 6 | mkdir build 7 | cd build 8 | cmake .. -GNinja "$@" 9 | ninja -v 10 | cd .. 11 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx>=5.2.1 2 | sphinx_rtd_theme>=1.0.0 3 | breathe 4 | autodocsumm 5 | scikit-learn 6 | matplotlib>=2.1 7 | graphviz 8 | numpy 9 | scipy 10 | sphinx-gallery 11 | pandas 12 | -------------------------------------------------------------------------------- /python/treelite/gtil/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General Tree Inference Library (GTIL) 3 | """ 4 | 5 | from .gtil import predict, predict_leaf, predict_per_tree 6 | 7 | __all__ = ["predict", "predict_leaf", "predict_per_tree"] 8 | -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Tutorials 3 | ========= 4 | 5 | This page lists tutorials about Treelite. 6 | 7 | .. 
toctree:: 8 | :maxdepth: 1 9 | :caption: Contents: 10 | 11 | import 12 | builder 13 | edit 14 | -------------------------------------------------------------------------------- /cmake/Version.cmake: -------------------------------------------------------------------------------- 1 | function (write_version) 2 | message(STATUS "Treelite VERSION: ${PROJECT_VERSION}") 3 | configure_file( 4 | ${PROJECT_SOURCE_DIR}/cmake/version.h.in 5 | include/treelite/version.h) 6 | endfunction (write_version) 7 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | @import url('theme.css'); 2 | 3 | .red { 4 | color: red; 5 | } 6 | 7 | .wy-side-nav-search div.version { 8 | color: #FFFFFF; 9 | } 10 | 11 | .wy-table-responsive table td,.wy-table-responsive table th { 12 | white-space: normal; 13 | } 14 | -------------------------------------------------------------------------------- /python/treelite/path_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom hook to customize path for libtreelite.so 3 | """ 4 | 5 | 6 | def get_custom_libpath(): 7 | """ 8 | Get custom path for libtreelite.so. 9 | If valid, must return a directory containing libtreelite.so. 10 | """ 11 | return None 12 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Set the default behavior, in case people don't have core.autocrlf set. 2 | * text=auto 3 | 4 | # Explicitly declare text files you want to always be normalized and converted 5 | # to native line endings on checkout. 
6 | *.cc text 7 | *.h text 8 | *.proto text 9 | *.txt text 10 | *.md text 11 | Makefile text 12 | LICENSE text 13 | -------------------------------------------------------------------------------- /ops/test-linux-python-wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | echo "##[section]Installing Treelite into Python environment..." 6 | pip install --force-reinstall wheelhouse/*.whl 7 | 8 | echo "##[section]Running Python tests..." 9 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py 10 | -------------------------------------------------------------------------------- /ops/test-macos-python-wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | echo "##[section]Installing Treelite into Python environment..." 6 | pip install --force-reinstall wheelhouse/*.whl 7 | 8 | echo "##[section]Running Python tests..." 9 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py 10 | -------------------------------------------------------------------------------- /docs/treelite-doxygen.rst: -------------------------------------------------------------------------------- 1 | ================================== 2 | Documentation for the C++ codebase 3 | ================================== 4 | 5 | The core logic of Treelite is written in C++. Visit the link below to find 6 | the documentation for all C++ functions and classes defined in Treelite. 
7 | 8 | `Documentation for C++ functions and classes <./dev>`_ 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifndef LINT_LANG 2 | LINT_LANG="all" 3 | endif 4 | 5 | ifndef NPROC 6 | NPROC=1 7 | endif 8 | 9 | doxygen: 10 | cd docs; doxygen 11 | 12 | cpp-coverage: 13 | rm -rf build; mkdir build; cd build; cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug && make -j$(NPROC) 14 | 15 | all: 16 | rm -rf build; mkdir build; cd build; cmake .. && make -j$(NPROC) 17 | -------------------------------------------------------------------------------- /cmake/TreeliteConfig.cmake.in: -------------------------------------------------------------------------------- 1 | @PACKAGE_INIT@ 2 | 3 | include(CMakeFindDependencyMacro) 4 | 5 | set(USE_OPENMP @USE_OPENMP@) 6 | if(USE_OPENMP) 7 | find_dependency(OpenMP) 8 | endif() 9 | 10 | if(NOT TARGET treelite::treelite) 11 | include(${CMAKE_CURRENT_LIST_DIR}/TreeliteTargets.cmake) 12 | endif() 13 | 14 | message(STATUS "Found Treelite (found version \"${Treelite_VERSION}\")") 15 | -------------------------------------------------------------------------------- /python/README.rst: -------------------------------------------------------------------------------- 1 | ======================= 2 | Treelite Python Package 3 | ======================= 4 | 5 | |PyPI version| 6 | 7 | .. |PyPI version| image:: https://badge.fury.io/py/treelite.svg 8 | :target: http://badge.fury.io/py/treelite 9 | 10 | **Treelite** is a universal model exchange and serialization format for decision tree forests. 11 | 12 | See the documentation for more details. 13 | -------------------------------------------------------------------------------- /tests/cpp/test_main.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2020-2023 by Contributors 3 | * \file test_main.cc 4 | * \author Hyunsu Cho 5 | * \brief Launcher for C++ unit tests, using Google Test framework 6 | */ 7 | #include 8 | 9 | int main(int argc, char** argv) { 10 | testing::InitGoogleTest(&argc, argv); 11 | testing::FLAGS_gtest_death_test_style = "threadsafe"; 12 | return RUN_ALL_TESTS(); 13 | } 14 | -------------------------------------------------------------------------------- /cmake/version.h.in: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | */ 4 | #ifndef TREELITE_VERSION_H_ 5 | #define TREELITE_VERSION_H_ 6 | 7 | #define TREELITE_VER_MAJOR @treelite_VERSION_MAJOR@ 8 | #define TREELITE_VER_MINOR @treelite_VERSION_MINOR@ 9 | #define TREELITE_VER_PATCH @treelite_VERSION_PATCH@ 10 | #define TREELITE_VERSION_STR "@treelite_VERSION_MAJOR@.@treelite_VERSION_MINOR@.@treelite_VERSION_PATCH@" 11 | 12 | #endif // TREELITE_VERSION_H_ 13 | -------------------------------------------------------------------------------- /ops/conda_env/dev.yml: -------------------------------------------------------------------------------- 1 | name: dev 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - numpy 7 | - scipy 8 | - pandas 9 | - pytest 10 | - pytest-cov 11 | - hypothesis 12 | - scikit-learn 13 | - coverage 14 | - codecov 15 | - ninja 16 | - lcov 17 | - cmake 18 | - llvm-openmp 19 | - cython 20 | - lightgbm 21 | - cpplint=1.6.0 22 | - pylint 23 | - awscli 24 | - python-build 25 | - pip 26 | - pip: 27 | - cibuildwheel 28 | - xgboost>=2.1.0 29 | -------------------------------------------------------------------------------- /ops/test-cmake-import.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | # Install Treelite C++ library into the Conda env 6 | set -x 7 | rm -rf build/ 8 | mkdir build 9 | cd build 10 | cmake .. 
-DCMAKE_INSTALL_PREFIX="$CONDA_PREFIX" -DCMAKE_INSTALL_LIBDIR="lib" -GNinja 11 | ninja install 12 | 13 | # Try compiling a sample application 14 | cd ../tests/example_app/ 15 | rm -rf build/ 16 | mkdir build 17 | cd build 18 | cmake .. -GNinja 19 | ninja 20 | ./cpp_example 21 | ./c_example 22 | -------------------------------------------------------------------------------- /cmake/Doxygen.cmake: -------------------------------------------------------------------------------- 1 | function (run_doxygen) 2 | find_package(Doxygen REQUIRED dot) 3 | 4 | configure_file( 5 | ${treelite_SOURCE_DIR}/docs/Doxyfile.in 6 | ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY) 7 | add_custom_target(doc_doxygen ALL 8 | COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile 9 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 10 | COMMENT "Generate documentation for C/C++ functions" 11 | VERBATIM) 12 | endfunction (run_doxygen) 13 | -------------------------------------------------------------------------------- /ops/build-macos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | if [[ $# -ne 2 ]]; then 6 | echo "Usage: $0 [platform_id] [commit ID]" 7 | exit 1 8 | fi 9 | 10 | platform_id=$1 11 | commit_id=$2 12 | 13 | echo "##[section]Building MacOS Python wheels..." 14 | tests/ci_build/build_macos_python_wheels.sh ${platform_id} ${commit_id} 15 | 16 | echo "##[section]Uploading MacOS Python wheels to S3..." 17 | python -m awscli s3 cp wheelhouse/treelite-*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true 18 | -------------------------------------------------------------------------------- /python/treelite/__init__.py: -------------------------------------------------------------------------------- 1 | """Treelite module""" 2 | 3 | import pathlib 4 | 5 | from . 
import frontend, gtil, model_builder, sklearn 6 | from .core import TreeliteError 7 | from .model import Model 8 | 9 | VERSION_FILE = pathlib.Path(__file__).parent / "VERSION" 10 | with open(VERSION_FILE, "r", encoding="UTF-8") as _f: 11 | __version__ = _f.read().strip() 12 | 13 | __all__ = [ 14 | "Model", 15 | "frontend", 16 | "gtil", 17 | "sklearn", 18 | "model_builder", 19 | "TreeliteError", 20 | "__version__", 21 | ] 22 | -------------------------------------------------------------------------------- /include/treelite/error.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2022 by Contributors 3 | * \file error.h 4 | * \brief Exception class used throughout the Treelite codebase 5 | * \author Hyunsu Cho 6 | */ 7 | #ifndef TREELITE_ERROR_H_ 8 | #define TREELITE_ERROR_H_ 9 | 10 | #include 11 | #include 12 | 13 | namespace treelite { 14 | 15 | /*! 16 | * \brief Exception class that will be thrown by Treelite 17 | */ 18 | struct Error : public std::runtime_error { 19 | explicit Error(std::string const& s) : std::runtime_error(s) {} 20 | }; 21 | 22 | } // namespace treelite 23 | 24 | #endif // TREELITE_ERROR_H_ 25 | -------------------------------------------------------------------------------- /ops/build-linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | TAG=manylinux2014_x86_64 6 | 7 | export CIBW_BUILD=cp38-manylinux_x86_64 8 | export CIBW_ARCHS=x86_64 9 | export CIBW_BUILD_VERBOSITY=3 10 | export CIBW_MANYLINUX_X86_64_IMAGE=manylinux2014 11 | 12 | echo "##[section]Building Python wheel (amd64) for Treelite..." 13 | python -m cibuildwheel python --output-dir wheelhouse 14 | python tests/ci_build/rename_whl.py wheelhouse ${COMMIT_ID} ${TAG} 15 | 16 | echo "##[section]Uploading Python wheel (amd64)..." 
17 | python -m awscli s3 cp wheelhouse/*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true 18 | -------------------------------------------------------------------------------- /ops/build-linux-aarch64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | TAG=manylinux2014_aarch64 6 | 7 | export CIBW_BUILD=cp38-manylinux_aarch64 8 | export CIBW_ARCHS=aarch64 9 | export CIBW_BUILD_VERBOSITY=3 10 | export CIBW_MANYLINUX_AARCH64_IMAGE=manylinux2014 11 | 12 | echo "##[section]Building Python wheel (aarch64) for Treelite..." 13 | python -m cibuildwheel python --output-dir wheelhouse 14 | python tests/ci_build/rename_whl.py wheelhouse ${COMMIT_ID} ${TAG} 15 | 16 | echo "##[section]Uploading Python wheel (aarch64)..." 17 | python -m awscli s3 cp wheelhouse/*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true 18 | -------------------------------------------------------------------------------- /include/treelite/pybuffer_frame.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file pybuffer_frame.h 4 | * \brief Data structure to enable zero-copy exchange in Python 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_PYBUFFER_FRAME_H_ 9 | #define TREELITE_PYBUFFER_FRAME_H_ 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace treelite { 17 | 18 | using PyBufferFrame = TreelitePyBufferFrame; 19 | 20 | static_assert(std::is_pod::value, "PyBufferFrame must be a POD type"); 21 | 22 | } // namespace treelite 23 | 24 | #endif // TREELITE_PYBUFFER_FRAME_H_ 25 | -------------------------------------------------------------------------------- /ops/test-sdist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | echo "##[section]Building a source distribution..." 
6 | python -m build --sdist python/ --outdir . 7 | 8 | echo "##[section]Testing the source distribution..." 9 | python -m pip install --force-reinstall -v treelite-*.tar.gz 10 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py 11 | 12 | # Deploy source distribution to S3 13 | for file in ./treelite-*.tar.gz 14 | do 15 | mv "${file}" "${file%.tar.gz}+${COMMIT_ID}.tar.gz" 16 | done 17 | python -m awscli s3 cp treelite-*.tar.gz s3://treelite-wheels/ --acl public-read --region us-west-2 || true 18 | -------------------------------------------------------------------------------- /ops/build-windows.bat: -------------------------------------------------------------------------------- 1 | echo ##[section]Generating Visual Studio solution... 2 | mkdir build 3 | cd build 4 | cmake .. -G"Visual Studio 17 2022" -A x64 -DBUILD_CPP_TEST=ON 5 | if %errorlevel% neq 0 exit /b %errorlevel% 6 | 7 | echo ##[section]Building Visual Studio solution... 8 | cmake --build . --config Release -- /m 9 | if %errorlevel% neq 0 exit /b %errorlevel% 10 | cd .. 11 | 12 | echo ##[section]Running C++ tests... 13 | .\build\treelite_cpp_test.exe 14 | if %errorlevel% neq 0 exit /b %errorlevel% 15 | 16 | echo ##[section]Packaging Python wheel for Treelite... 17 | cd python 18 | pip wheel --no-deps -v . --wheel-dir dist/ 19 | if %errorlevel% neq 0 exit /b %errorlevel% 20 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS.md: -------------------------------------------------------------------------------- 1 | # Acknowledgments 2 | 3 | This page acknowledges those who provided great help in building the initial 4 | version of Treelite: 5 | 6 | * Treelite builds upon Hyunsu's earlier work at **Paul G. Allen School of 7 | Computer Science and Engineering, University of Washington at Seattle**, which was performed under the 8 | guidance of **Carlos Guestrin** and **Arvind Krishnamurthy**. 
9 | * **Tianqi Chen** (member of DMLC) offered many great ideas for the conception 10 | of the project. He also provided feedback pertaining to API design. 11 | * **Mu Li** (member of DMLC) provided advice and resources to develop Treelite 12 | into a full-fledged open source project. 13 | -------------------------------------------------------------------------------- /python/treelite/sklearn/__init__.py: -------------------------------------------------------------------------------- 1 | """Model loader to ingest scikit-learn models into Treelite""" 2 | 3 | from .exporter import export_model 4 | from .importer import import_model 5 | 6 | 7 | def import_model_with_model_builder(sklearn_model): 8 | """ 9 | This function was removed in Treelite 4.0; please use :py:meth:`~treelite.sklearn.import_model` 10 | instead. 11 | """ 12 | raise NotImplementedError( 13 | "treelite.sklearn.import_model_with_model_builder() was removed in Treelite 4.0. " 14 | "Please use treelite.sklearn.import_model() instead." 15 | ) 16 | 17 | 18 | __all__ = ["import_model", "export_model", "import_model_with_model_builder"] 19 | -------------------------------------------------------------------------------- /include/treelite/thread_local.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2021 by Contributors 3 | * \file thread_local.h 4 | * \brief Helper class for thread-local storage 5 | * \author Hyunsu Cho 6 | */ 7 | #ifndef TREELITE_THREAD_LOCAL_H_ 8 | #define TREELITE_THREAD_LOCAL_H_ 9 | 10 | namespace treelite { 11 | 12 | /*! 13 | * \brief A thread-local storage 14 | * \tparam T the type we like to store 15 | */ 16 | template 17 | class ThreadLocalStore { 18 | public: 19 | /*! 
\return get a thread local singleton */ 20 | static T* Get() { 21 | static thread_local T inst; 22 | return &inst; 23 | } 24 | }; 25 | 26 | } // namespace treelite 27 | 28 | #endif // TREELITE_THREAD_LOCAL_H_ 29 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python3 -msphinx 7 | SPHINXPROJ = treelite 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /python/packager/sdist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for building sdist 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | 8 | from .util import copy_with_logging, copytree_with_logging 9 | 10 | 11 | def copy_cpp_src_tree( 12 | cpp_src_dir: pathlib.Path, target_dir: pathlib.Path, logger: logging.Logger 13 | ) -> None: 14 | """Copy C++ source tree into build directory""" 15 | 16 | for subdir in [ 17 | "src", 18 | "include", 19 | "cmake", 20 | ]: 21 | copytree_with_logging(cpp_src_dir / subdir, target_dir / subdir, logger=logger) 22 | 23 | for filename in ["CMakeLists.txt", "LICENSE"]: 24 | copy_with_logging(cpp_src_dir.joinpath(filename), target_dir, logger=logger) 25 | -------------------------------------------------------------------------------- 
/src/logging.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2017-2023 by Contributors 3 | * \file logging.cc 4 | * \author Hyunsu Cho 5 | * \brief logging facility for treelite 6 | */ 7 | 8 | #include 9 | 10 | namespace treelite { 11 | 12 | void LogMessage::Log(std::string const& msg) { 13 | LogCallbackRegistry const* registry = LogCallbackRegistryStore::Get(); 14 | auto callback = registry->GetCallbackLogInfo(); 15 | callback(msg.c_str()); 16 | } 17 | 18 | void LogMessageWarning::Log(std::string const& msg) { 19 | LogCallbackRegistry const* registry = LogCallbackRegistryStore::Get(); 20 | auto callback = registry->GetCallbackLogWarning(); 21 | callback(msg.c_str()); 22 | } 23 | 24 | } // namespace treelite 25 | -------------------------------------------------------------------------------- /docs/treelite-gtil-api.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | General Tree Inference Library (GTIL) 3 | ===================================== 4 | 5 | GTIL is a **reference implementation** of a prediction runtime for all Treelite models. It has the following goals: 6 | 7 | * **Universal coverage**: GTIL shall support all tree ensemble models that can be represented as Treelite objects. 8 | * **Accessible code**: GTIL should be written in an easy-to-read style that can be understood by a first-time contributor. We prefer code legibility to performance optimization. 9 | * **Correct output**: As a reference implementation, GTIL should produce correct prediction outputs. 10 | 11 | .. automodule:: treelite.gtil 12 | :members: 13 | -------------------------------------------------------------------------------- /src/c_api/logging.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file logging.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for logging functions 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | int TreeliteRegisterLogCallback(void (*callback)(char const*)) { 13 | API_BEGIN(); 14 | auto* registry = treelite::LogCallbackRegistryStore::Get(); 15 | registry->RegisterCallBackLogInfo(callback); 16 | API_END(); 17 | } 18 | 19 | int TreeliteRegisterWarningCallback(void (*callback)(char const*)) { 20 | API_BEGIN(); 21 | auto* registry = treelite::LogCallbackRegistryStore::Get(); 22 | registry->RegisterCallBackLogWarning(callback); 23 | API_END(); 24 | } 25 | -------------------------------------------------------------------------------- /tests/ci_build/Dockerfile.ubuntu20_amd64: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | LABEL maintainer "DMLC" 3 | 4 | ENV GOSU_VERSION 1.13 5 | ENV DEBIAN_FRONTEND noninteractive 6 | 7 | RUN \ 8 | apt-get update && \ 9 | apt-get install -y build-essential git wget unzip tar cmake ninja-build 10 | 11 | # Install lightweight sudo (not bound to TTY) 12 | RUN set -ex; \ 13 | wget -nv -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \ 14 | chmod +x /usr/local/bin/gosu && \ 15 | gosu nobody true 16 | 17 | # Default entry-point to use if running locally 18 | # It will preserve attributes of created files 19 | COPY entrypoint.sh /scripts/ 20 | 21 | WORKDIR /workspace 22 | ENTRYPOINT ["/scripts/entrypoint.sh"] 23 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | submodules: 9 | include: all 10 | 11 | # Set the version of 
Python and other tools you might need 12 | build: 13 | os: ubuntu-22.04 14 | tools: 15 | python: "3.10" 16 | apt_packages: 17 | - graphviz 18 | - cmake 19 | - g++ 20 | - doxygen 21 | - ninja-build 22 | 23 | # Build documentation in the docs/ directory with Sphinx 24 | sphinx: 25 | configuration: docs/conf.py 26 | 27 | # Optionally declare the Python requirements required to build your docs 28 | python: 29 | install: 30 | - requirements: docs/requirements.txt 31 | -------------------------------------------------------------------------------- /tests/ci_build/Dockerfile.ubuntu20_aarch64: -------------------------------------------------------------------------------- 1 | FROM arm64v8/ubuntu:20.04 2 | LABEL maintainer "DMLC" 3 | 4 | ENV GOSU_VERSION 1.13 5 | ENV DEBIAN_FRONTEND noninteractive 6 | 7 | RUN \ 8 | apt-get update && \ 9 | apt-get install -y build-essential git wget unzip tar cmake ninja-build 10 | 11 | # Install lightweight sudo (not bound to TTY) 12 | RUN set -ex; \ 13 | wget -nv -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-arm64" && \ 14 | chmod +x /usr/local/bin/gosu && \ 15 | gosu nobody true 16 | 17 | # Default entry-point to use if running locally 18 | # It will preserve attributes of created files 19 | COPY entrypoint.sh /scripts/ 20 | 21 | WORKDIR /workspace 22 | ENTRYPOINT ["/scripts/entrypoint.sh"] 23 | -------------------------------------------------------------------------------- /dev/run_pylint.py: -------------------------------------------------------------------------------- 1 | """Wrapper for Pylint""" 2 | 3 | import os 4 | import pathlib 5 | import subprocess 6 | import sys 7 | 8 | ROOT_PATH = pathlib.Path(__file__).parent.parent.expanduser().resolve() 9 | PYPKG_PATH = ROOT_PATH / "python" 10 | PYLINTRC_PATH = PYPKG_PATH / ".pylintrc" 11 | 12 | 13 | def main(): 14 | """Wrapper for Pylint. 
Add treelite to PYTHONPATH so that pylint doesn't error out""" 15 | new_env = os.environ.copy() 16 | new_env["PYTHONPATH"] = str(PYPKG_PATH) 17 | 18 | # sys.argv[1:]: List of source files to check 19 | subprocess.run( 20 | ["pylint", "-rn", "-sn", "--rcfile", str(PYLINTRC_PATH)] + sys.argv[1:], 21 | check=True, 22 | env=new_env, 23 | ) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /python/hatch_build.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom hook to customize the behavior of Hatchling. 3 | Here, we customize the tag of the generated wheels. 4 | """ 5 | 6 | import sysconfig 7 | from typing import Any, Dict 8 | 9 | from hatchling.builders.hooks.plugin.interface import BuildHookInterface 10 | 11 | 12 | def get_tag() -> str: 13 | """Get appropriate wheel tag according to system""" 14 | tag_platform = sysconfig.get_platform().replace("-", "_").replace(".", "_") 15 | return f"py3-none-{tag_platform}" 16 | 17 | 18 | class CustomBuildHook(BuildHookInterface): 19 | """A custom build hook""" 20 | 21 | def initialize(self, version: str, build_data: Dict[str, Any]) -> None: 22 | """This step occurs immediately before each build.""" 23 | build_data["tag"] = get_tag() 24 | -------------------------------------------------------------------------------- /python/packager/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for implementing PEP 517 backend 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | import shutil 8 | 9 | 10 | def copytree_with_logging( 11 | src: pathlib.Path, dest: pathlib.Path, logger: logging.Logger 12 | ) -> None: 13 | """Call shutil.copytree() with logging""" 14 | logger.info("Copying %s -> %s", str(src), str(dest)) 15 | shutil.copytree(src, dest) 16 | 17 | 18 | def copy_with_logging( 19 | src: pathlib.Path, dest: pathlib.Path, 
logger: logging.Logger 20 | ) -> None: 21 | """Call shutil.copy() with logging""" 22 | if dest.is_dir(): 23 | logger.info("Copying %s -> %s", str(src), str(dest / src.name)) 24 | else: 25 | logger.info("Copying %s -> %s", str(src), str(dest)) 26 | shutil.copy(src, dest) 27 | -------------------------------------------------------------------------------- /ops/macos-python-coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | conda --version 6 | python --version 7 | 8 | # Run coverage test 9 | echo "##[section]Building Treelite..." 10 | set -x 11 | rm -rf build/ 12 | mkdir build 13 | cd build 14 | cmake .. -DTEST_COVERAGE=ON -DBUILD_CPP_TEST=ON -GNinja 15 | ninja -v 16 | cd .. 17 | 18 | ./build/treelite_cpp_test 19 | PYTHONPATH=./python python -m pytest --cov=treelite -v -rxXs \ 20 | --fulltrace --durations=0 tests/python 21 | lcov --directory . --capture --output-file coverage.info 22 | lcov --remove coverage.info '*dmlccore*' --output-file coverage.info 23 | lcov --remove coverage.info '*fmtlib*' --output-file coverage.info 24 | lcov --remove coverage.info '*/usr/*' --output-file coverage.info 25 | lcov --remove coverage.info '*googletest*' --output-file coverage.info 26 | codecov 27 | -------------------------------------------------------------------------------- /ops/test-win-python-wheel.bat: -------------------------------------------------------------------------------- 1 | echo ##[section]Installing Treelite into Python environment... 2 | setlocal enabledelayedexpansion 3 | python tests\ci_build\rename_whl.py python\dist %COMMIT_ID% win_amd64 4 | if %errorlevel% neq 0 exit /b %errorlevel% 5 | for /R %%i in (python\\dist\\*.whl) DO ( 6 | python -m pip install --force-reinstall "%%i" 7 | if !errorlevel! neq 0 exit /b !errorlevel! 8 | ) 9 | 10 | echo ##[section]Running Python tests... 
11 | mkdir temp 12 | python -m pytest --basetemp="%WORKING_DIR%\temp" -v -rxXs --fulltrace --durations=0 tests\python\test_sklearn_integration.py 13 | if %errorlevel% neq 0 exit /b %errorlevel% 14 | 15 | echo ##[section]Uploading Python wheels... 16 | for /R %%i in (python\\dist\\*.whl) DO ( 17 | python -m awscli s3 cp "%%i" s3://treelite-wheels/ --acl public-read --region us-west-2 || cd . 18 | ) 19 | -------------------------------------------------------------------------------- /include/treelite/enum/tree_node_type.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file tree_node_type.h 4 | * \brief Define enum type NodeType 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_ENUM_TREE_NODE_TYPE_H_ 9 | #define TREELITE_ENUM_TREE_NODE_TYPE_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite { 15 | 16 | /*! \brief Tree node type */ 17 | enum class TreeNodeType : std::int8_t { 18 | kLeafNode = 0, 19 | kNumericalTestNode = 1, 20 | kCategoricalTestNode = 2 21 | }; 22 | 23 | /*! \brief Get string representation of TreeNodeType */ 24 | std::string TreeNodeTypeToString(TreeNodeType type); 25 | 26 | /*! \brief Get NodeType from string */ 27 | TreeNodeType TreeNodeTypeFromString(std::string const& name); 28 | 29 | } // namespace treelite 30 | 31 | #endif // TREELITE_ENUM_TREE_NODE_TYPE_H_ 32 | -------------------------------------------------------------------------------- /include/treelite/enum/task_type.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file task_type.h 4 | * \brief Define enum type TaskType 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_ENUM_TASK_TYPE_H_ 9 | #define TREELITE_ENUM_TASK_TYPE_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite { 15 | 16 | /*! 17 | * \brief Enum type representing the task type. 
18 | */ 19 | enum class TaskType : std::uint8_t { 20 | kBinaryClf = 0, 21 | kRegressor = 1, 22 | kMultiClf = 2, 23 | kLearningToRank = 3, 24 | kIsolationForest = 4 25 | }; 26 | 27 | /*! \brief Get string representation of TaskType */ 28 | std::string TaskTypeToString(TaskType type); 29 | 30 | /*! \brief Get TaskType from string */ 31 | TaskType TaskTypeFromString(std::string const& str); 32 | 33 | } // namespace treelite 34 | 35 | #endif // TREELITE_ENUM_TASK_TYPE_H_ 36 | -------------------------------------------------------------------------------- /tests/example_app/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | project(example_app LANGUAGES C CXX) 3 | 4 | if(DEFINED ENV{CONDA_PREFIX}) 5 | set(CMAKE_PREFIX_PATH "$ENV{CONDA_PREFIX};${CMAKE_PREFIX_PATH}") 6 | message(STATUS "Detected Conda environment, CMAKE_PREFIX_PATH set to: ${CMAKE_PREFIX_PATH}") 7 | else() 8 | message(STATUS "No Conda environment detected") 9 | endif() 10 | 11 | find_package(Treelite REQUIRED) 12 | 13 | add_executable(cpp_example example.cc) 14 | target_link_libraries(cpp_example PRIVATE treelite::treelite) 15 | 16 | add_executable(c_example example.c) 17 | target_link_libraries(c_example PRIVATE treelite::treelite) 18 | 19 | set_target_properties(cpp_example PROPERTIES 20 | CXX_STANDARD 17 21 | CXX_STANDARD_REQUIRED YES 22 | ) 23 | 24 | set_target_properties(c_example PROPERTIES 25 | C_STANDARD 99 26 | C_STANDARD_REQUIRED YES 27 | ) 28 | -------------------------------------------------------------------------------- /ops/build-cpack.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | if [[ -z "${COMMIT_ID:-}" ]] 6 | then 7 | echo "Make sure to set environment variable COMMIT_ID" 8 | exit 1 9 | fi 10 | 11 | if [[ "$#" -lt 1 ]] 12 | then 13 | echo "Usage: $0 {amd64,aarch64}" 14 | exit 2 15 | fi 16 | 17 | arch="$1" 18 
| 19 | echo "##[section] Building Treelite for ${arch}..." 20 | tests/ci_build/ci_build.sh ubuntu20_${arch} tests/ci_build/build_via_cmake.sh 21 | 22 | echo "##[section] Packing CPack for ${arch}..." 23 | tests/ci_build/ci_build.sh ubuntu20_${arch} bash -c "cd build/ && cpack -G TGZ" 24 | for tgz in build/treelite-*-Linux.tar.gz 25 | do 26 | mv -v "${tgz}" "${tgz%-Linux.tar.gz}+${COMMIT_ID}-Linux-${arch}.tar.gz" 27 | done 28 | 29 | echo "##[section]Uploading CPack for ${arch}..." 30 | python -m awscli s3 cp build/*.tar.gz s3://treelite-cpack/ --acl public-read --region us-west-2 || true 31 | -------------------------------------------------------------------------------- /ops/win-python-coverage.bat: -------------------------------------------------------------------------------- 1 | echo ##[section]Generating Visual Studio solution... 2 | mkdir build 3 | cd build 4 | cmake .. -G"Visual Studio 17 2022" -A x64 5 | if %errorlevel% neq 0 exit /b %errorlevel% 6 | 7 | echo ##[section]Building Visual Studio solution... 8 | cmake --build . --config Release -- /m 9 | if %errorlevel% neq 0 exit /b %errorlevel% 10 | cd .. 11 | 12 | echo ##[section]Running Python tests... 13 | mkdir temp 14 | set "PYTHONPATH=./python" 15 | set "PYTEST_TMPDIR=%USERPROFILE%\AppData\Local\Temp\pytest_temp" 16 | mkdir "%PYTEST_TMPDIR%" 17 | python -m pytest --basetemp="%USERPROFILE%\AppData\Local\Temp\pytest_temp" --cov=treelite --cov-report xml -v -rxXs --fulltrace --durations=0 tests\python 18 | if %errorlevel% neq 0 exit /b %errorlevel% 19 | 20 | echo ##[section]Submitting code coverage data to CodeCov... 21 | python -m codecov -f coverage.xml 22 | if %errorlevel% neq 0 exit /b %errorlevel% 23 | -------------------------------------------------------------------------------- /src/model_loader/detail/string_utils.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file string_utils.h 4 | * \brief Helper functions for manipulating strings 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_ 9 | #define SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite::model_loader::detail { 15 | 16 | inline bool StringStartsWith(std::string const& str, std::string const& prefix) { 17 | return str.rfind(prefix, 0) == 0; 18 | } 19 | 20 | inline void StringTrimFromEnd(std::string& s) { 21 | s.erase(std::find_if( 22 | s.rbegin(), s.rend(), [](char ch) { return ch != '\n' && ch != '\r' && ch != ' '; }) 23 | .base(), 24 | s.end()); 25 | } 26 | 27 | } // namespace treelite::model_loader::detail 28 | 29 | #endif // SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_ 30 | -------------------------------------------------------------------------------- /include/treelite/enum/operator.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file operator.h 4 | * \brief Define enum type Operator 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_ENUM_OPERATOR_H_ 9 | #define TREELITE_ENUM_OPERATOR_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite { 15 | 16 | /*! \brief Type of comparison operators used in numerical test nodes */ 17 | enum class Operator : std::int8_t { 18 | kNone, 19 | kEQ, /*!< operator == */ 20 | kLT, /*!< operator < */ 21 | kLE, /*!< operator <= */ 22 | kGT, /*!< operator > */ 23 | kGE, /*!< operator >= */ 24 | }; 25 | 26 | /*! \brief Get string representation of Operator */ 27 | std::string OperatorToString(Operator type); 28 | 29 | /*! 
\brief Get Operator from string */ 30 | Operator OperatorFromString(std::string const& name); 31 | 32 | } // namespace treelite 33 | 34 | #endif // TREELITE_ENUM_OPERATOR_H_ 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=treelite 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /src/gtil/postprocessor.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2021-2023 by Contributors 3 | * \file postprocessor.h 4 | * \author Hyunsu Cho 5 | * \brief Functions to post-process prediction results 6 | */ 7 | 8 | #ifndef SRC_GTIL_POSTPROCESSOR_H_ 9 | #define SRC_GTIL_POSTPROCESSOR_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite { 15 | 16 | class Model; 17 | 18 | namespace gtil { 19 | 20 | template 21 | using PostProcessorFunc = void (*)(treelite::Model const&, std::int32_t, InputT*); 22 | 23 | template 24 | PostProcessorFunc GetPostProcessorFunc(std::string const& name); 25 | 26 | extern template PostProcessorFunc GetPostProcessorFunc(std::string const& name); 27 | extern template PostProcessorFunc GetPostProcessorFunc(std::string const& name); 28 | 29 | } // namespace gtil 30 | } // namespace treelite 31 | 32 | #endif // SRC_GTIL_POSTPROCESSOR_H_ 33 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | filename = *.py, *.pyx, *.pxd 3 | exclude = 4 | *.egg, 5 | .git, 6 | __pycache__, 7 | build/, 8 | cpp, 9 | docs, 10 | tests/cython/ 11 | 12 | # Cython Rules ignored: 13 | # E999: invalid syntax (works for Python, not Cython) 14 | # E225: Missing whitespace around operators (breaks cython casting syntax like <int>) 15 | # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*) 16 | # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax) 17 | # W503: line break before binary operator (breaks lines that start with a pointer) 18 | # W504: line break after binary operator (breaks lines that end with a pointer) 19 | 20 | extend-ignore = 21 | # handled by black 22 | E501, W503, E203 23 | # imported but unused 24 | F401 25 | # redefinition of unused 26 | F811 27 | # E203 whitespace before ':' 28 | # https://github.com/psf/black/issues/315 29 | E203 30 | 
-------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | 10 | permissions: 11 | contents: read # to fetch code (actions/checkout) 12 | 13 | defaults: 14 | run: 15 | shell: bash -l {0} 16 | 17 | concurrency: 18 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 19 | cancel-in-progress: true 20 | 21 | jobs: 22 | pre-commit: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v4 26 | - uses: conda-incubator/setup-miniconda@v3 27 | with: 28 | miniforge-variant: Miniforge3 29 | miniforge-version: latest 30 | conda-remove-defaults: "true" 31 | activate-environment: precommit 32 | environment-file: ops/conda_env/pre-commit.yml 33 | use-mamba: true 34 | - name: Run pre-commit checks 35 | run: | 36 | pre-commit install 37 | pre-commit run --all-files --color always 38 | -------------------------------------------------------------------------------- /src/c_api/c_api_utils.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file c_api_utils.h 4 | * \author Hyunsu Cho 5 | * \brief C API of Treelite, used for interfacing with other languages 6 | */ 7 | #ifndef SRC_C_API_C_API_UTILS_H_ 8 | #define SRC_C_API_C_API_UTILS_H_ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | namespace treelite::c_api { 18 | 19 | /*! \brief When returning a complex object from a C API function, we 20 | * store the object here and then return a pointer. The 21 | * storage is thread-local static storage. 
*/ 22 | struct ReturnValueEntry { 23 | std::string ret_str; 24 | std::vector ret_uint32_vec; 25 | std::vector ret_uint64_vec; 26 | std::vector ret_frames; 27 | }; 28 | using ReturnValueStore = ThreadLocalStore; 29 | 30 | } // namespace treelite::c_api 31 | 32 | #endif // SRC_C_API_C_API_UTILS_H_ 33 | -------------------------------------------------------------------------------- /ops/cpp-python-coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | echo "##[section]Installing lcov and Ninja..." 6 | sudo apt-get install lcov ninja-build 7 | 8 | echo "##[section]Building Treelite..." 9 | mkdir build/ 10 | cd build/ 11 | cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug -DBUILD_CPP_TEST=ON -GNinja 12 | ninja install -v 13 | cd .. 14 | 15 | echo "##[section]Running Google C++ tests..." 16 | ./build/treelite_cpp_test 17 | 18 | echo "##[section]Running Python integration tests..." 19 | export PYTHONPATH='./python' 20 | python -m pytest --cov=treelite -v -rxXs --fulltrace --durations=0 tests/python tests/serializer 21 | 22 | echo "##[section]Collecting coverage data..." 23 | lcov --directory . --capture --output-file coverage.info 24 | lcov --remove coverage.info '*/usr/*' --output-file coverage.info 25 | lcov --remove coverage.info '*/build/_deps/*' --output-file coverage.info 26 | 27 | echo "##[section]Submitting code coverage data to CodeCov..." 
28 | bash <(curl -s https://codecov.io/bash) -X gcov || echo "Codecov did not collect coverage reports" 29 | -------------------------------------------------------------------------------- /python/.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | ignore-paths=tests/cython,docs 3 | extension-pkg-whitelist=numpy 4 | 5 | load-plugins=pylint.extensions.no_self_use 6 | 7 | disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object,attribute-defined-outside-init,unbalanced-tuple-unpacking,too-many-lines,duplicate-code,too-many-arguments 8 | 9 | dummy-variables-rgx=(unused|)_.* 10 | 11 | [BASIC] 12 | 13 | # Enforce naming convention 14 | const-naming-style=UPPER_CASE 15 | class-naming-style=PascalCase 16 | function-naming-style=snake_case 17 | method-naming-style=snake_case 18 | attr-naming-style=snake_case 19 | argument-naming-style=snake_case 20 | variable-naming-style=snake_case 21 | class-attribute-naming-style=snake_case 22 | 23 | # Allow single-letter variables 24 | variable-rgx=[a-zA-Z_][a-z0-9_]{0,30}$ 25 | argument-rgx=[a-zA-Z_][a-z0-9_]{0,30}$ 26 | 27 | [TYPECHECK] 28 | generated-members=np.float32,np.uintc,np.uintp,np.uint32 29 | 30 | [MESSAGES CONTROL] 31 | # globally disable pylint checks (comma separated) 32 | disable=fixme 33 | -------------------------------------------------------------------------------- /src/c_api/c_api_error.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2017-2023 by Contributors 3 | * \file c_api_error.cc 4 | * \author Hyunsu Cho 5 | * \brief C error handling 6 | */ 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace treelite::c_api { 15 | 16 | struct APIErrorEntry { 17 | std::string last_error; 18 | std::string version_str; 19 | }; 20 | 21 | using APIErrorStore = ThreadLocalStore; 22 | 23 | } // namespace treelite::c_api 24 | 25 | char const* TreeliteGetLastError() { 26 | return treelite::c_api::APIErrorStore::Get()->last_error.c_str(); 27 | } 28 | 29 | void TreeliteAPISetLastError(char const* msg) { 30 | treelite::c_api::APIErrorStore::Get()->last_error = msg; 31 | } 32 | 33 | char const* TreeliteQueryTreeliteVersion() { 34 | auto& version_str = treelite::c_api::APIErrorStore::Get()->version_str; 35 | version_str = TREELITE_VERSION_STR; 36 | return version_str.c_str(); 37 | } 38 | 39 | char const* TREELITE_VERSION = "TREELITE_VERSION_" TREELITE_VERSION_STR; 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Treelite 2 | 3 | ![Coverage tests](https://github.com/dmlc/treelite/actions/workflows/coverage-tests.yml/badge.svg) 4 | [![Documentation Status](https://readthedocs.org/projects/treelite/badge/?version=latest)](http://treelite.readthedocs.io/en/latest/?badge=latest) 5 | [![codecov](https://codecov.io/gh/dmlc/treelite/branch/mainline/graph/badge.svg)](https://codecov.io/gh/dmlc/treelite) 6 | [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE) 7 | [![PyPI version](https://badge.fury.io/py/treelite.svg)](https://pypi.python.org/pypi/treelite/) 8 | [![Conda Version](https://img.shields.io/conda/vn/conda-forge/treelite.svg)](https://anaconda.org/conda-forge/treelite) 9 | 10 | [Documentation](https://treelite.readthedocs.io/en/latest) | 11 | 
[Installation](http://treelite.readthedocs.io/en/latest/install.html) | 12 | [Release Notes](NEWS.md) | 13 | [Acknowledgements](ACKNOWLEDGMENTS.md) | 14 | 15 | **Treelite** is a universal model exchange and serialization format for 16 | decision tree forests. Treelite aims to be a small library that enables 17 | other C++ applications to exchange and store decision trees on the disk 18 | as well as the network. 19 | -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(treelite_cpp_test) 2 | set_target_properties(treelite_cpp_test 3 | PROPERTIES 4 | CXX_STANDARD 17 5 | CXX_STANDARD_REQUIRED ON) 6 | target_link_libraries(treelite_cpp_test 7 | PRIVATE objtreelite rapidjson 8 | GTest::gtest GTest::gmock fmt::fmt-header-only std::mdspan) 9 | set_output_directory(treelite_cpp_test ${PROJECT_BINARY_DIR}) 10 | 11 | if(MSVC) 12 | target_compile_options(treelite_cpp_test PRIVATE 13 | /utf-8 -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) 14 | endif() 15 | 16 | if(TEST_COVERAGE) 17 | if(MSVC) 18 | message(FATAL_ERROR "Test coverage not available on Windows") 19 | endif() 20 | target_compile_options(treelite_cpp_test PUBLIC -g3 --coverage) 21 | target_link_options(treelite_cpp_test PUBLIC --coverage) 22 | endif() 23 | 24 | target_sources(treelite_cpp_test 25 | PRIVATE 26 | test_main.cc 27 | test_gtil.cc 28 | test_model_builder.cc 29 | test_model_concat.cc 30 | test_model_loader.cc 31 | test_serializer.cc 32 | test_utils.cc 33 | ) 34 | 35 | target_include_directories(treelite_cpp_test 36 | PRIVATE ${PROJECT_SOURCE_DIR}/src/ 37 | ) 38 | -------------------------------------------------------------------------------- /src/enum/typeinfo.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2017-2023 by Contributors 3 | * \file typeinfo.cc 4 | * \author Hyunsu Cho 5 | * \brief Utilities for TypeInfo enum 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace treelite { 14 | 15 | std::string TypeInfoToString(treelite::TypeInfo info) { 16 | switch (info) { 17 | case treelite::TypeInfo::kInvalid: 18 | return "invalid"; 19 | case treelite::TypeInfo::kUInt32: 20 | return "uint32"; 21 | case treelite::TypeInfo::kFloat32: 22 | return "float32"; 23 | case treelite::TypeInfo::kFloat64: 24 | return "float64"; 25 | default: 26 | TREELITE_LOG(FATAL) << "Unrecognized type"; 27 | return ""; 28 | } 29 | } 30 | 31 | TypeInfo TypeInfoFromString(std::string const& str) { 32 | if (str == "uint32") { 33 | return TypeInfo::kUInt32; 34 | } else if (str == "float32") { 35 | return TypeInfo::kFloat32; 36 | } else if (str == "float64") { 37 | return TypeInfo::kFloat64; 38 | } else { 39 | TREELITE_LOG(FATAL) << "Unrecognized type: " << str; 40 | return TypeInfo::kInvalid; 41 | } 42 | } 43 | 44 | } // namespace treelite 45 | -------------------------------------------------------------------------------- /src/enum/tree_node_type.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file tree_node_type.cc 4 | * \author Hyunsu Cho 5 | * \brief Utilities for NodeType enum 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace treelite { 14 | 15 | std::string TreeNodeTypeToString(TreeNodeType type) { 16 | switch (type) { 17 | case TreeNodeType::kLeafNode: 18 | return "leaf_node"; 19 | case TreeNodeType::kNumericalTestNode: 20 | return "numerical_test_node"; 21 | case TreeNodeType::kCategoricalTestNode: 22 | return "categorical_test_node"; 23 | default: 24 | return ""; 25 | } 26 | } 27 | 28 | TreeNodeType TreeNodeTypeFromString(std::string const& name) { 29 | if (name == "leaf_node") { 30 | return TreeNodeType::kLeafNode; 31 | } else if (name == "numerical_test_node") { 32 | return TreeNodeType::kNumericalTestNode; 33 | } else if (name == "categorical_test_node") { 34 | return TreeNodeType::kCategoricalTestNode; 35 | } else { 36 | TREELITE_LOG(FATAL) << "Unknown split type: " << name; 37 | return TreeNodeType::kLeafNode; 38 | } 39 | } 40 | 41 | } // namespace treelite 42 | -------------------------------------------------------------------------------- /python/treelite/sklearn/isolation_forest.py: -------------------------------------------------------------------------------- 1 | """Utility functions for loading IsolationForest models""" 2 | 3 | import numpy as np 4 | 5 | 6 | def harmonic(number): 7 | """Approximates the n-th harmonic number as ln(n) plus the Euler-Mascheroni constant""" 8 | return np.log(number) + np.euler_gamma 9 | 10 | 11 | def expected_depth(n_remainder): 12 | """Calculates the expected isolation depth for a remainder of uniform points""" 13 | if n_remainder <= 1: 14 | return 0.0 15 | if n_remainder == 2: 16 | return 1.0 17 | return float(2 * harmonic(n_remainder - 1) - 2 * (n_remainder - 1) / n_remainder) 18 | 19 | 20 | def calculate_depths(isolation_depths, tree, curr_node, curr_depth): 21 | """Fill in an array of isolation depths for a scikit-learn isolation forest model""" 22 | if 
tree.children_left[curr_node] == -1: 23 | isolation_depths[curr_node] = curr_depth + expected_depth( 24 | tree.n_node_samples[curr_node] 25 | ) 26 | else: 27 | calculate_depths( 28 | isolation_depths, tree, tree.children_left[curr_node], curr_depth + 1 29 | ) 30 | calculate_depths( 31 | isolation_depths, tree, tree.children_right[curr_node], curr_depth + 1 32 | ) 33 | -------------------------------------------------------------------------------- /src/model_loader/detail/xgboost.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2020-2023 by Contributors 3 | * \file xgboost.h 4 | * \brief Helper functions for loading XGBoost models 5 | * \author William Hicks 6 | */ 7 | #ifndef SRC_MODEL_LOADER_DETAIL_XGBOOST_H_ 8 | #define SRC_MODEL_LOADER_DETAIL_XGBOOST_H_ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace treelite::model_loader::detail::xgboost { 15 | 16 | struct ProbToMargin { 17 | static double Sigmoid(double base_score) { 18 | return -std::log(1.0 / base_score - 1.0); 19 | } 20 | static double Exponential(double base_score) { 21 | return std::log(base_score); 22 | } 23 | }; 24 | 25 | // Get correct prediction transform function, depending on objective function 26 | std::string GetPostProcessor(std::string const& objective_name); 27 | 28 | // Transform base score from probability into margin score 29 | double TransformBaseScoreToMargin(std::string const& postprocessor, double base_score); 30 | 31 | // Parse base score 32 | std::vector ParseBaseScore(std::string const& str); 33 | 34 | enum FeatureType { kNumerical = 0, kCategorical = 1 }; 35 | 36 | } // namespace treelite::model_loader::detail::xgboost 37 | 38 | #endif // SRC_MODEL_LOADER_DETAIL_XGBOOST_H_ 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object 
files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Python compiled 15 | *.pyc 16 | 17 | # Compiled Java classes 18 | *.class 19 | /runtime/java/treelite4j/target/ 20 | 21 | # Auxiliary Java files 22 | /runtime/java/treelite4j/.classpath 23 | /runtime/java/treelite4j/.project 24 | /runtime/java/treelite4j/.settings/ 25 | /runtime/java/treelite4j/coverage.info 26 | 27 | # Compiled Dynamic libraries 28 | *.so 29 | *.dylib 30 | *.dll 31 | 32 | # Fortran module files 33 | *.mod 34 | *.smod 35 | 36 | # Compiled Static libraries 37 | *.lai 38 | *.la 39 | *.a 40 | *.lib 41 | 42 | # build folder 43 | /build/ 44 | /python/dist/ 45 | /python/build/ 46 | /python/treelite.egg-info/ 47 | /runtime/python/dist/ 48 | /runtime/python/build/ 49 | /runtime/python/treelite_runtime.egg-info/ 50 | 51 | # Executables 52 | *.exe 53 | *.out 54 | *.app 55 | /treelite 56 | 57 | # Mac system file 58 | .DS_Store 59 | 60 | # Python wheel binaries 61 | /python/dist/ 62 | 63 | # Python cache 64 | __pycache__ 65 | 66 | # symbol directories 67 | *.dSYM/ 68 | 69 | # Project files 70 | /runtime/python/.idea/ 71 | /python/.idea/ 72 | /tests/python/.idea/ 73 | /.idea/ 74 | /.hypothesis 75 | /lint.py 76 | -------------------------------------------------------------------------------- /src/enum/operator.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file operator.cc 4 | * \author Hyunsu Cho 5 | * \brief Utilities for Operator enum 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace treelite { 14 | 15 | /*! 
\brief Get string representation of Operator */ 16 | std::string OperatorToString(Operator op) { 17 | switch (op) { 18 | case Operator::kEQ: 19 | return "=="; 20 | case Operator::kLT: 21 | return "<"; 22 | case Operator::kLE: 23 | return "<="; 24 | case Operator::kGT: 25 | return ">"; 26 | case Operator::kGE: 27 | return ">="; 28 | default: 29 | return ""; 30 | } 31 | } 32 | 33 | /*! \brief Get Operator from string */ 34 | Operator OperatorFromString(std::string const& name) { 35 | if (name == "==") { 36 | return Operator::kEQ; 37 | } else if (name == "<") { 38 | return Operator::kLT; 39 | } else if (name == "<=") { 40 | return Operator::kLE; 41 | } else if (name == ">") { 42 | return Operator::kGT; 43 | } else if (name == ">=") { 44 | return Operator::kGE; 45 | } else { 46 | TREELITE_LOG(FATAL) << "Unknown operator: " << name; 47 | return Operator::kNone; 48 | } 49 | } 50 | 51 | } // namespace treelite 52 | -------------------------------------------------------------------------------- /tests/cpp/test_model_loader.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file test_model_loader.cc 4 | * \author Hyunsu Cho 5 | * \brief C++ tests for model loader 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "model_loader/detail/string_utils.h" 15 | #include "model_loader/detail/xgboost.h" 16 | 17 | TEST(ModelLoader, StringTrim) { 18 | std::string s{"foobar\r\n"}; 19 | treelite::model_loader::detail::StringTrimFromEnd(s); 20 | EXPECT_EQ(s, "foobar"); 21 | } 22 | 23 | TEST(ModelLoader, StringStartsWith) { 24 | std::string s{"foobar"}; 25 | EXPECT_TRUE(treelite::model_loader::detail::StringStartsWith(s, "foo")); 26 | } 27 | 28 | TEST(ModelLoader, XGBoostBaseScore) { 29 | { 30 | std::string s{"[5.2008224E-1,4.665861E-1]"}; 31 | std::vector parsed = treelite::model_loader::detail::xgboost::ParseBaseScore(s); 32 | std::vector expected{5.2008224E-1, 4.665861E-1}; 33 | EXPECT_EQ(parsed, expected); 34 | } 35 | 36 | { 37 | std::string s{"4.9333417E-1"}; 38 | std::vector parsed = treelite::model_loader::detail::xgboost::ParseBaseScore(s); 39 | std::vector expected{4.9333417E-1}; 40 | EXPECT_EQ(parsed, expected); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /cmake/Utils.cmake: -------------------------------------------------------------------------------- 1 | # Set output directory of target, ignoring debug or release 2 | function(set_output_directory target dir) 3 | set_target_properties(${target} PROPERTIES 4 | RUNTIME_OUTPUT_DIRECTORY ${dir} # for executable 5 | RUNTIME_OUTPUT_DIRECTORY_DEBUG ${dir} 6 | RUNTIME_OUTPUT_DIRECTORY_RELEASE ${dir} 7 | LIBRARY_OUTPUT_DIRECTORY ${dir} # for shared library 8 | LIBRARY_OUTPUT_DIRECTORY_DEBUG ${dir} 9 | LIBRARY_OUTPUT_DIRECTORY_RELEASE ${dir} 10 | ARCHIVE_OUTPUT_DIRECTORY ${dir} # for static library 11 | ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${dir} 12 | ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${dir} 13 | ) 14 | endfunction(set_output_directory) 15 | 16 | # Set a default 
build type to release if none was specified 17 | function(set_default_configuration_release) 18 | if(CMAKE_CONFIGURATION_TYPES STREQUAL "Debug;Release;MinSizeRel;RelWithDebInfo") # multiconfig generator? 19 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release" CACHE STRING "" FORCE) 20 | elseif(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) 21 | message(STATUS "Setting build type to 'Release' as none was specified.") 22 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE ) 23 | endif() 24 | endfunction(set_default_configuration_release) 25 | -------------------------------------------------------------------------------- /.github/workflows/windows-wheel-builder.yml: -------------------------------------------------------------------------------- 1 | name: windows-wheel-builder 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | defaults: 16 | run: 17 | shell: cmd /C CALL {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | COMMIT_ID: ${{ github.sha }} 25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }} 26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} 27 | 28 | jobs: 29 | windows-wheel-builder: 30 | name: Build and test Python wheels (Windows) 31 | runs-on: windows-latest 32 | steps: 33 | - uses: actions/checkout@v4 34 | - uses: conda-incubator/setup-miniconda@v3 35 | with: 36 | miniforge-variant: Miniforge3 37 | miniforge-version: latest 38 | conda-remove-defaults: "true" 39 | activate-environment: precommit 40 | environment-file: ops/conda_env/dev.yml 41 | use-mamba: true 42 | - name: Build wheel 43 | run: | 44 | call ops/build-windows.bat 45 | - name: Test wheel 46 | run: | 47 | call 
ops/test-win-python-wheel.bat 48 | -------------------------------------------------------------------------------- /src/enum/task_type.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file task_type.cc 4 | * \author Hyunsu Cho 5 | * \brief Utilities for TaskType enum 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace treelite { 14 | 15 | std::string TaskTypeToString(TaskType type) { 16 | switch (type) { 17 | case TaskType::kBinaryClf: 18 | return "kBinaryClf"; 19 | case TaskType::kRegressor: 20 | return "kRegressor"; 21 | case TaskType::kMultiClf: 22 | return "kMultiClf"; 23 | case TaskType::kLearningToRank: 24 | return "kLearningToRank"; 25 | case TaskType::kIsolationForest: 26 | return "kIsolationForest"; 27 | default: 28 | return ""; 29 | } 30 | } 31 | 32 | TaskType TaskTypeFromString(std::string const& str) { 33 | if (str == "kBinaryClf") { 34 | return TaskType::kBinaryClf; 35 | } else if (str == "kRegressor") { 36 | return TaskType::kRegressor; 37 | } else if (str == "kMultiClf") { 38 | return TaskType::kMultiClf; 39 | } else if (str == "kLearningToRank") { 40 | return TaskType::kLearningToRank; 41 | } else if (str == "kIsolationForest") { 42 | return TaskType::kIsolationForest; 43 | } else { 44 | TREELITE_LOG(FATAL) << "Unknown task type: " << str; 45 | return TaskType::kBinaryClf; // to avoid compiler warning 46 | } 47 | } 48 | 49 | } // namespace treelite 50 | -------------------------------------------------------------------------------- /tests/example_app/example.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file example.cc 4 | * \brief Test using Treelite as a C++ library 5 | */ 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | int main(void) { 18 | std::cout << "TREELITE_VERSION = " << TREELITE_VERSION << std::endl; 19 | auto builder = treelite::model_builder::GetModelBuilder(treelite::TypeInfo::kFloat32, 20 | treelite::TypeInfo::kFloat32, 21 | treelite::model_builder::Metadata{2, treelite::TaskType::kRegressor, false, 1, {1}, {1, 1}}, 22 | treelite::model_builder::TreeAnnotation{1, {0}, {0}}, 23 | treelite::model_builder::PostProcessorFunc{"identity"}, std::vector{0.0}); 24 | builder->StartTree(); 25 | builder->StartNode(0); 26 | builder->NumericalTest(0, 0.0, true, treelite::Operator::kLT, 1, 2); 27 | builder->EndNode(); 28 | builder->StartNode(1); 29 | builder->LeafScalar(-1.0); 30 | builder->EndNode(); 31 | builder->StartNode(2); 32 | builder->LeafScalar(1.0); 33 | builder->EndNode(); 34 | builder->EndTree(); 35 | 36 | auto model = builder->CommitModel(); 37 | std::cout << model->GetNumTree() << std::endl; 38 | return 0; 39 | } 40 | -------------------------------------------------------------------------------- /src/c_api/field_accessor.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file field_accessor.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for accessing fields in Treelite model 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | int TreeliteGetHeaderField( 13 | TreeliteModelHandle model, char const* name, TreelitePyBufferFrame* out_frame) { 14 | API_BEGIN(); 15 | auto* model_ = static_cast(model); 16 | *out_frame = model_->GetHeaderField(name); 17 | API_END(); 18 | } 19 | 20 | int TreeliteGetTreeField(TreeliteModelHandle model, uint64_t tree_id, char const* name, 21 | TreelitePyBufferFrame* out_frame) { 22 | API_BEGIN(); 23 | auto* model_ = static_cast(model); 24 | *out_frame = model_->GetTreeField(tree_id, name); 25 | API_END(); 26 | } 27 | 28 | int TreeliteSetHeaderField( 29 | TreeliteModelHandle model, char const* name, TreelitePyBufferFrame frame) { 30 | API_BEGIN(); 31 | auto* model_ = static_cast(model); 32 | model_->SetHeaderField(name, frame); 33 | API_END(); 34 | } 35 | 36 | int TreeliteSetTreeField( 37 | TreeliteModelHandle model, uint64_t tree_id, char const* name, TreelitePyBufferFrame frame) { 38 | API_BEGIN(); 39 | auto* model_ = static_cast(model); 40 | model_->SetTreeField(tree_id, name, frame); 41 | API_END(); 42 | } 43 | -------------------------------------------------------------------------------- /.github/workflows/cpack-builder.yml: -------------------------------------------------------------------------------- 1 | name: cpack-builder 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read 14 | 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | COMMIT_ID: ${{ github.sha }} 25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }} 26 | 
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} 27 | 28 | jobs: 29 | cpack-builder: 30 | name: Build CPack 31 | runs-on: ${{ matrix.runner }} 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | include: 36 | - arch: aarch64 37 | runner: ubuntu-24.04-arm 38 | - arch: amd64 39 | runner: ubuntu-24.04 40 | steps: 41 | - uses: actions/checkout@v4 42 | - uses: conda-incubator/setup-miniconda@v3 43 | with: 44 | miniforge-variant: Miniforge3 45 | miniforge-version: latest 46 | conda-remove-defaults: "true" 47 | activate-environment: dev 48 | environment-file: ops/conda_env/dev.yml 49 | use-mamba: true 50 | - name: Build CPack 51 | run: | 52 | bash ops/build-cpack.sh ${{ matrix.arch }} 53 | -------------------------------------------------------------------------------- /include/treelite/enum/typeinfo.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2017-2023 by Contributors 3 | * \file typeinfo.h 4 | * \brief Defines enum type TypeInfo 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_ENUM_TYPEINFO_H_ 9 | #define TREELITE_ENUM_TYPEINFO_H_ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace treelite { 19 | 20 | /*! \brief Types used by thresholds and leaf outputs */ 21 | enum class TypeInfo : std::uint8_t { kInvalid = 0, kUInt32 = 1, kFloat32 = 2, kFloat64 = 3 }; 22 | 23 | /*! \brief Get string representation of TypeInfo */ 24 | std::string TypeInfoToString(treelite::TypeInfo info); 25 | 26 | /*! \brief Get TypeInfo from string */ 27 | TypeInfo TypeInfoFromString(std::string const& str); 28 | 29 | /*! 
30 | * \brief Convert a template type into a type info 31 | * \tparam template type to be converted 32 | * \return TypeInfo corresponding to the template type arg 33 | */ 34 | template 35 | inline TypeInfo TypeInfoFromType() { 36 | if (std::is_same_v) { 37 | return TypeInfo::kUInt32; 38 | } else if (std::is_same_v) { 39 | return TypeInfo::kFloat32; 40 | } else if (std::is_same_v) { 41 | return TypeInfo::kFloat64; 42 | } else { 43 | return TypeInfo::kInvalid; 44 | } 45 | } 46 | 47 | } // namespace treelite 48 | 49 | #endif // TREELITE_ENUM_TYPEINFO_H_ 50 | -------------------------------------------------------------------------------- /src/gtil/output_shape.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file output_shape.cc 4 | * \author Hyunsu Cho 5 | * \brief Compute output shape for GTIL, so that callers can allocate sufficient space 6 | * to hold outputs. 7 | */ 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | namespace treelite::gtil { 16 | 17 | std::vector GetOutputShape( 18 | Model const& model, std::uint64_t num_row, Configuration const& config) { 19 | auto const num_tree = model.GetNumTree(); 20 | auto const max_num_class = static_cast( 21 | *std::max_element(model.num_class.Data(), model.num_class.Data() + model.num_target)); 22 | switch (config.pred_kind) { 23 | case PredictKind::kPredictDefault: 24 | case PredictKind::kPredictRaw: 25 | if (model.num_target > 1) { 26 | return {num_row, static_cast(model.num_target), max_num_class}; 27 | } else { 28 | return {num_row, 1, max_num_class}; 29 | } 30 | case PredictKind::kPredictLeafID: 31 | return {num_row, num_tree}; 32 | case PredictKind::kPredictPerTree: 33 | return {num_row, num_tree, 34 | static_cast(model.leaf_vector_shape[0]) * model.leaf_vector_shape[1]}; 35 | default: 36 | TREELITE_LOG(FATAL) << "Unsupported model type: " << static_cast(config.pred_kind); 37 | return {}; 38 
| } 39 | } 40 | 41 | } // namespace treelite::gtil 42 | -------------------------------------------------------------------------------- /tests/examples/deep_lightgbm/model.txt: -------------------------------------------------------------------------------- 1 | tree 2 | version=v4 3 | num_class=1 4 | num_tree_per_iteration=1 5 | label_index=0 6 | max_feature_idx=0 7 | objective=regression 8 | feature_names=this 9 | feature_infos=[0:100] 10 | tree_sizes=1119 11 | 12 | Tree=0 13 | num_leaves=32 14 | num_cat=0 15 | split_feature=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 16 | split_gain=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 17 | threshold=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 18 | decision_type=2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 19 | left_child=1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 -31 20 | right_child=-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -32 21 | leaf_value=31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 0 1 22 | leaf_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 23 | leaf_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 24 | internal_value=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 25 | internal_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 26 | internal_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 27 | is_linear=0 28 | shrinkage=1 29 | 30 | 31 | end of trees 32 | 33 | feature_importances: 34 | this=31 35 | 36 | pandas_categorical:null 37 | -------------------------------------------------------------------------------- /src/model_query.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2024 by Contributors 3 | * \file model_query.cc 4 | * \author Hyunsu Cho 5 | * \brief Methods for querying various properties of tree models 6 | */ 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace { 16 | 17 | template 18 | std::uint32_t GetDepth(treelite::Tree const& tree) { 19 | // Visit all tree nodes in depth-first order, starting from node 0, using an explicit stack of node IDs 20 | std::stack st; 21 | st.push(0); 22 | std::uint32_t max_depth = 0; 23 | std::uint32_t depth = 1; // NOTE(review): single running counter (+1 per internal node popped, -1 per leaf popped), not the depth of node_id; on a chain whose right children are all leaves it alternates 2,1,2,... so max_depth stays 2. Confirm the intended depth convention and track depth per stack entry. 24 | while (!st.empty()) { 25 | int node_id = st.top(); 26 | st.pop(); 27 | if (tree.IsLeaf(node_id)) { 28 | --depth; 29 | } else { 30 | st.push(tree.LeftChild(node_id)); 31 | st.push(tree.RightChild(node_id)); 32 | ++depth; 33 | } 34 | max_depth = std::max(max_depth, depth); 35 | } 36 | return max_depth; 37 | } 38 | 39 | } // anonymous namespace 40 | 41 | namespace treelite { 42 | 43 | std::vector Model::GetTreeDepth() const { 44 | return std::visit( 45 | [](auto&& concrete_model) { 46 | std::vector depth; 47 | depth.reserve(concrete_model.trees.size()); 48 | for (auto const& tree : concrete_model.trees) { 49 | depth.push_back(GetDepth(tree)); 50 | } 51 | return depth; 52 | }, 53 | variant_); 54 | } 55 | 56 | } // namespace treelite 57 | -------------------------------------------------------------------------------- /tests/ci_build/build_macos_python_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | if [[ $# -ne 2 ]]; then 7 | echo "Usage: $0 [platform_id] [commit ID]" 8 | exit 1 9 | fi 10 | 11 | platform_id=$1 12 | shift 13 | commit_id=$1 14 | shift 15 | 16 | if [[ "$platform_id" == macosx_* ]]; then 17 | if [[ "$platform_id" == macosx_arm64 ]]; then 18 | # MacOS, Apple Silicon 19 | wheel_tag=macosx_12_0_arm64 20 | cpython_ver=310 21 | cibw_archs=arm64 22 | export MACOSX_DEPLOYMENT_TARGET=12.0 23 | elif [[ "$platform_id" == macosx_x86_64 ]]; then 24 | # MacOS, Intel 25 |
wheel_tag=macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64 26 | cpython_ver=310 27 | cibw_archs=x86_64 28 | export MACOSX_DEPLOYMENT_TARGET=10.15 29 | else 30 | echo "Platform not supported: $platform_id" 31 | exit 3 32 | fi 33 | # Set up environment variables to configure cibuildwheel 34 | export CIBW_BUILD=cp${cpython_ver}-${platform_id} 35 | export CIBW_ARCHS=${cibw_archs} 36 | export CIBW_TEST_SKIP='*-macosx_arm64' 37 | export CIBW_BUILD_VERBOSITY=3 38 | else 39 | echo "Platform not supported: $platform_id" 40 | exit 2 41 | fi 42 | 43 | # Tell delocate-wheel to not vendor libomp.dylib into the wheel 44 | export CIBW_REPAIR_WHEEL_COMMAND_MACOS="delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} --exclude libomp.dylib" 45 | 46 | python -m cibuildwheel python --output-dir wheelhouse 47 | python tests/ci_build/rename_whl.py wheelhouse ${commit_id} ${wheel_tag} 48 | -------------------------------------------------------------------------------- /docs/treelite-api.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Treelite API 3 | ============ 4 | 5 | API of Treelite Python package. 6 | 7 | .. contents:: 8 | :local: 9 | 10 | Model loaders 11 | ------------- 12 | 13 | .. automodule:: treelite.frontend 14 | :members: 15 | :member-order: bysource 16 | 17 | Scikit-learn importer 18 | --------------------- 19 | 20 | .. automodule:: treelite.sklearn 21 | :members: 22 | :member-order: bysource 23 | 24 | Model builder 25 | ------------- 26 | 27 | .. automodule:: treelite.model_builder 28 | :members: 29 | :member-order: bysource 30 | 31 | Model class 32 | ----------- 33 | 34 | .. autoclass:: treelite.Model 35 | :members: 36 | :member-order: bysource 37 | :exclude-members: load, from_xgboost, from_xgboost_json, from_lightgbm 38 | 39 | 40 | .. 
_field_accessors: 41 | 42 | Field accessors (Advanced) 43 | -------------------------- 44 | Using field accessors, users can query and modify the value of fields in a :py:class:`~treelite.Model` object. 45 | See :doc:`/tutorials/edit` for more details. 46 | 47 | .. note:: Modifying a field is an unsafe operation 48 | 49 | Treelite does not prevent users from assigning an invalid value to a field. Setting an invalid value may 50 | cause undefined behavior. Always consult :doc:`the model spec ` to carefully examine 51 | model invariants and constraints on fields. For example, most tree fields must have an array of length ``num_nodes``. 52 | 53 | .. autoclass:: treelite.model.HeaderAccessor 54 | :members: 55 | 56 | .. autoclass:: treelite.model.TreeAccessor 57 | :members: 58 | -------------------------------------------------------------------------------- /tests/ci_build/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script is a wrapper creating the same user inside container as the one 4 | # running the ci_build.sh outside the container. It also set the home directory 5 | # for the user inside container to match the same absolute path as the workspace 6 | # outside of container. Do not run this manually. It does not make sense. It is 7 | # intended to be called by ci_build.sh only. 8 | 9 | set -e 10 | 11 | COMMAND=("$@") 12 | 13 | if ! touch /this_is_writable_file_system; then 14 | echo "You can't write to your filesystem!" 15 | echo "If you are in Docker you should check you do not have too many images" \ 16 | "with too many files in them. Docker has some issue with it." 
17 | exit 1 18 | else 19 | rm /this_is_writable_file_system 20 | fi 21 | 22 | if [[ -n $CI_BUILD_UID ]] && [[ -n $CI_BUILD_GID ]]; then 23 | groupadd -o -g "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" 24 | useradd -o -m -g "${CI_BUILD_GID}" -u "${CI_BUILD_UID}" \ 25 | "${CI_BUILD_USER}" 26 | export HOME="/home/${CI_BUILD_USER}" 27 | shopt -s dotglob 28 | cp -r /root/* "$HOME/" 29 | chown -R "${CI_BUILD_UID}:${CI_BUILD_GID}" "$HOME" 30 | 31 | # Allows project-specific customization 32 | if [[ -e "/workspace/.pre_entry.sh" ]]; then 33 | gosu "${CI_BUILD_UID}:${CI_BUILD_GID}" /workspace/.pre_entry.sh 34 | fi 35 | 36 | # Enable passwordless sudo capabilities for the user 37 | chown root:"${CI_BUILD_GID}" "$(which gosu)" 38 | chmod +s "$(which gosu)"; sync 39 | 40 | exec gosu "${CI_BUILD_UID}:${CI_BUILD_GID}" "${COMMAND[@]}" 41 | else 42 | exec "${COMMAND[@]}" 43 | fi 44 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-update.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit-update 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 20 * * 1" # Run once weekly 7 | 8 | permissions: 9 | pull-requests: write 10 | contents: write 11 | 12 | defaults: 13 | run: 14 | shell: bash -l {0} 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | pre-commit-update: 22 | name: Auto-update pre-commit hooks 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/create-github-app-token@v1 26 | id: generate-token 27 | with: 28 | app-id: ${{ secrets.TOKEN_GENERATOR_APP_ID }} 29 | private-key: ${{ secrets.TOKEN_GENERATOR_APP_PRIVATE_KEY }} 30 | - uses: actions/checkout@v4 31 | - uses: conda-incubator/setup-miniconda@v3 32 | with: 33 | miniforge-variant: Miniforge3 34 | miniforge-version: latest 35 | conda-remove-defaults: "true" 36 | activate-environment: 
precommit 37 | environment-file: ops/conda_env/pre-commit.yml 38 | use-mamba: true 39 | - name: Update pre-commit hooks 40 | run: | 41 | pre-commit autoupdate 42 | - name: Create Pull Request 43 | uses: peter-evans/create-pull-request@v7 44 | if: github.ref == 'refs/heads/mainline' 45 | with: 46 | branch: create-pull-request/pre-commit-update 47 | base: mainline 48 | title: "[CI] Update pre-commit hooks" 49 | commit-message: "[CI] Update pre-commit hooks" 50 | token: ${{ steps.generate-token.outputs.token }} 51 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "hatchling>=1.12.1" 4 | ] 5 | backend-path = ["."] 6 | build-backend = "packager.pep517" 7 | 8 | [project] 9 | name = "treelite" 10 | version = "4.7.0.dev0" 11 | authors = [ 12 | {name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu"} 13 | ] 14 | description = "Treelite: Universal model exchange format for decision tree forests" 15 | readme = {file = "README.rst", content-type = "text/x-rst"} 16 | requires-python = ">=3.8" 17 | license = {text = "Apache-2.0"} 18 | classifiers = [ 19 | "License :: OSI Approved :: Apache Software License", 20 | "Development Status :: 5 - Production/Stable", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10" 27 | ] 28 | dependencies = [ 29 | "numpy", 30 | "scipy", 31 | "packaging" 32 | ] 33 | 34 | [project.urls] 35 | documentation = "https://treelite.readthedocs.io/en/latest/" 36 | repository = "https://github.com/dmlc/treelite" 37 | 38 | [project.optional-dependencies] 39 | scikit-learn = ["scikit-learn"] 40 | testing = ["scikit-learn", "pytest", "hypothesis", "pandas"] 41 | 42 | 
[tool.mypy] 43 | plugins = "numpy.typing.mypy_plugin" 44 | 45 | [tool.hatch.build.targets.wheel.hooks.custom] 46 | 47 | [tool.ruff] 48 | line-length = 120 49 | 50 | # this should be set to the oldest version of python treelite supports 51 | target-version = "py38" 52 | 53 | [tool.ruff.lint] 54 | select = [ 55 | # numpy 2.0 deprecations/removals 56 | "NPY201", 57 | ] 58 | -------------------------------------------------------------------------------- /src/gtil/config.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file config.cc 4 | * \author Hyunsu Cho 5 | * \brief Configuration handling logic for GTIL 6 | */ 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | namespace treelite::gtil { 15 | 16 | Configuration::Configuration(std::string const& config_json) { 17 | rapidjson::Document parsed_config; 18 | parsed_config.Parse(config_json); 19 | 20 | if (parsed_config.IsObject()) { 21 | auto itr = parsed_config.FindMember("predict_type"); 22 | if (itr != parsed_config.MemberEnd() && itr->value.IsString()) { 23 | auto value = std::string(itr->value.GetString()); 24 | if (value == "default") { 25 | this->pred_kind = PredictKind::kPredictDefault; 26 | } else if (value == "raw") { 27 | this->pred_kind = PredictKind::kPredictRaw; 28 | } else if (value == "leaf_id") { 29 | this->pred_kind = PredictKind::kPredictLeafID; 30 | } else if (value == "score_per_tree") { 31 | this->pred_kind = PredictKind::kPredictPerTree; 32 | } else { 33 | TREELITE_LOG(FATAL) << "Unknown prediction type: " << value; 34 | } 35 | } else { 36 | TREELITE_LOG(FATAL) << "The field \"predict_type\" must be specified"; 37 | } 38 | itr = parsed_config.FindMember("nthread"); 39 | if (itr != parsed_config.MemberEnd()) { 40 | TREELITE_CHECK(itr->value.IsInt()) << "nthread must be an integer"; 41 | this->nthread = itr->value.GetInt(); 42 | } 43 | } else { 44 | TREELITE_LOG(FATAL) << "The JSON 
string must be a valid JSON object"; 45 | } 46 | } 47 | 48 | } // namespace treelite::gtil 49 | -------------------------------------------------------------------------------- /include/treelite/detail/omp_exception.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2022-2023 by Contributors 3 | * \file omp_exception.h 4 | * \author Hyunsu Cho 5 | * \brief Utility to propagate exceptions throws inside an OpenMP block 6 | */ 7 | #ifndef TREELITE_DETAIL_OMP_EXCEPTION_H_ 8 | #define TREELITE_DETAIL_OMP_EXCEPTION_H_ 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace treelite { 16 | 17 | /*! 18 | * \brief OMP Exception class catches, saves and rethrows exception from OMP blocks 19 | */ 20 | class OMPException { 21 | private: 22 | // exception_ptr member to store the exception 23 | std::exception_ptr omp_exception_; 24 | // mutex to be acquired during catch to set the exception_ptr 25 | std::mutex mutex_; 26 | 27 | public: 28 | /*! 29 | * \brief Parallel OMP blocks should be placed within Run to save exception 30 | */ 31 | template 32 | void Run(Function f, Parameters... params) { 33 | try { 34 | f(params...); 35 | } catch (treelite::Error& ex) { 36 | std::lock_guard lock(mutex_); 37 | if (!omp_exception_) { 38 | omp_exception_ = std::current_exception(); 39 | } 40 | } catch (std::exception& ex) { 41 | std::lock_guard lock(mutex_); 42 | if (!omp_exception_) { 43 | omp_exception_ = std::current_exception(); 44 | } 45 | } 46 | } 47 | 48 | /*! 
49 | * \brief should be called from the main thread to rethrow the exception 50 | */ 51 | void Rethrow() { 52 | if (this->omp_exception_) { 53 | std::rethrow_exception(this->omp_exception_); 54 | } 55 | } 56 | }; 57 | 58 | } // namespace treelite 59 | 60 | #endif // TREELITE_DETAIL_OMP_EXCEPTION_H_ 61 | -------------------------------------------------------------------------------- /.github/workflows/linux-wheel-builder.yml: -------------------------------------------------------------------------------- 1 | name: linux-wheel-builder 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | COMMIT_ID: ${{ github.sha }} 25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }} 26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} 27 | 28 | jobs: 29 | linux-wheel-builder: 30 | name: Build and test Python wheels (Linux) 31 | runs-on: ${{ matrix.os }} 32 | strategy: 33 | matrix: 34 | include: 35 | - os: ubuntu-latest 36 | build-script: ops/build-linux.sh 37 | - os: ubuntu-22.04-arm 38 | build-script: ops/build-linux-aarch64.sh 39 | 40 | steps: 41 | - uses: actions/checkout@v4 42 | - uses: conda-incubator/setup-miniconda@v3 43 | with: 44 | miniforge-variant: Miniforge3 45 | miniforge-version: latest 46 | conda-remove-defaults: "true" 47 | activate-environment: dev 48 | environment-file: ops/conda_env/dev.yml 49 | use-mamba: true 50 | 51 | - name: Display Conda env 52 | run: | 53 | conda info 54 | conda list 55 | 56 | - name: Build wheel 57 | run: | 58 | bash ${{ matrix.build-script }} 59 | 60 | - name: Test wheel 61 | run: | 62 | bash 
ops/test-linux-python-wheel.sh 63 | -------------------------------------------------------------------------------- /ops/test-serializer-compatibility.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | echo "##[section]Building Treelite..." 6 | mkdir -p build 7 | cd build 8 | cmake .. -GNinja 9 | ninja 10 | cd .. 11 | 12 | CURRENT_VERSION=$(cat python/treelite/VERSION) 13 | 14 | echo "##[section]Testing serialization: 3.9 -> ${CURRENT_VERSION}" 15 | pip install --force-reinstall treelite==3.9.0 treelite_runtime==3.9.0 16 | python tests/serializer/compatibility_tester.py --task save --checkpoint-path checkpoint.bin \ 17 | --model-pickle-path model.pkl --expected-treelite-version 3.9.0 18 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \ 19 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \ 20 | --expected-treelite-version ${CURRENT_VERSION} 21 | 22 | echo "##[section]Testing serialization: 4.3.0 -> ${CURRENT_VERSION}" 23 | pip install --force-reinstall treelite==4.3.0 24 | python tests/serializer/compatibility_tester.py --task save --checkpoint-path checkpoint.bin \ 25 | --model-pickle-path model.pkl --expected-treelite-version 4.3.0 26 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \ 27 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \ 28 | --expected-treelite-version ${CURRENT_VERSION} 29 | 30 | echo "##[section]Testing serialization: ${CURRENT_VERSION} -> ${CURRENT_VERSION}" 31 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task save \ 32 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \ 33 | --expected-treelite-version ${CURRENT_VERSION} 34 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \ 35 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \ 36 | --expected-treelite-version 
${CURRENT_VERSION} 37 | -------------------------------------------------------------------------------- /.github/workflows/macos-wheel-builder.yml: -------------------------------------------------------------------------------- 1 | name: macos-wheel-builder 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | COMMIT_ID: ${{ github.sha }} 25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }} 26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} 27 | 28 | jobs: 29 | macos-wheel-builder: 30 | name: Build and test Python wheels (MacOS) 31 | runs-on: ${{ matrix.os }} 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | include: 36 | - os: macos-15-intel 37 | platform_id: macosx_x86_64 38 | - os: macos-14 39 | platform_id: macosx_arm64 40 | env: 41 | CIBW_PLATFORM_ID: ${{ matrix.platform_id }} # matrix entries define only os and platform_id 42 | steps: 43 | - uses: actions/checkout@v4 44 | - uses: conda-incubator/setup-miniconda@v3 45 | with: 46 | miniforge-variant: Miniforge3 47 | miniforge-version: latest 48 | conda-remove-defaults: "true" 49 | activate-environment: dev 50 | environment-file: ops/conda_env/dev.yml 51 | use-mamba: true 52 | - name: Display Conda env 53 | run: | 54 | conda info 55 | conda list 56 | - name: Build wheel 57 | run: | 58 | bash ops/build-macos.sh ${{ matrix.platform_id }} ${{ github.sha }} 59 | - name: Test wheel 60 | run: | 61 | bash ops/test-macos-python-wheel.sh 62 | -------------------------------------------------------------------------------- /include/treelite/c_api_error.h: 
-------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2017-2023 by Contributors 3 | * \file c_api_error.h 4 | * \author Hyunsu Cho 5 | * \brief Error handling for C API. 6 | */ 7 | #ifndef TREELITE_C_API_ERROR_H_ 8 | #define TREELITE_C_API_ERROR_H_ 9 | 10 | #include 11 | 12 | /*! \brief Macro to guard beginning and end section of all functions */ 13 | #define API_BEGIN() try { 14 | /*! \brief Every function starts with API_BEGIN(); 15 | and finishes with API_END() or API_END_HANDLE_ERROR */ 16 | #define API_END() \ 17 | } \ 18 | catch (std::exception & _except_) { \ 19 | return TreeliteAPIHandleException(_except_); \ 20 | } \ 21 | return 0 22 | /*! 23 | * \brief Every function starts with API_BEGIN(); 24 | * and finishes with API_END() or API_END_HANDLE_ERROR() 25 | * "Finalize" contains procedure to cleanup states when an error happens 26 | */ 27 | #define API_END_HANDLE_ERROR(Finalize) \ 28 | } \ 29 | catch (std::exception & _except_) { \ 30 | Finalize; \ 31 | return TreeliteAPIHandleException(_except_); \ 32 | } \ 33 | return 0 34 | 35 | /*! 36 | * \brief Set the last error message needed by C API 37 | * \param msg Error message to set. 38 | */ 39 | void TreeliteAPISetLastError(char const* msg); 40 | /*! 41 | * \brief handle Exception thrown out 42 | * \param e Exception object 43 | * \return The return value of API after exception is handled 44 | */ 45 | inline int TreeliteAPIHandleException(std::exception const& e) { 46 | TreeliteAPISetLastError(e.what()); 47 | return -1; 48 | } 49 | #endif // TREELITE_C_API_ERROR_H_ 50 | -------------------------------------------------------------------------------- /tests/cpp/test_utils.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file test_utils.cc 4 | * \author Hyunsu Cho 5 | * \brief C++ tests for utility functions 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include 16 | 17 | TEST(FileUtils, StreamIO) { 18 | std::string s{"Hello world"}; 19 | std::string s2; 20 | std::filesystem::path tmpdir = std::filesystem::temp_directory_path(); 21 | std::filesystem::path filepath = tmpdir / std::filesystem::u8path("ななひら.txt"); 22 | 23 | { 24 | std::ofstream ofs = treelite::detail::OpenFileForWriteAsStream(filepath); 25 | ofs.write(s.data(), s.length()); 26 | } 27 | { 28 | std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filepath); 29 | s2.resize(s.length()); 30 | ifs.read(s2.data(), s.length()); 31 | ASSERT_EQ(s, s2); 32 | } 33 | 34 | std::filesystem::remove(filepath); 35 | } 36 | 37 | TEST(FileUtils, OpenFileForReadAsFilePtr) { 38 | std::string s{"Hello world"}; 39 | std::string s2; 40 | std::filesystem::path tmpdir = std::filesystem::temp_directory_path(); 41 | std::filesystem::path filepath = tmpdir / std::filesystem::u8path("ななひら.txt"); 42 | 43 | { 44 | std::ofstream ofs(filepath, std::ios::out | std::ios::binary); 45 | ASSERT_TRUE(ofs); 46 | ofs.exceptions(std::ios::failbit | std::ios::badbit); 47 | ofs.write(s.data(), s.length()); 48 | } 49 | { 50 | FILE* fp = treelite::detail::OpenFileForReadAsFilePtr(filepath); 51 | ASSERT_TRUE(fp); 52 | s2.resize(s.length()); 53 | ASSERT_EQ(std::fread(s2.data(), sizeof(char), s.length(), fp), s.length()); 54 | ASSERT_EQ(s, s2); 55 | ASSERT_EQ(std::fclose(fp), 0); 56 | } 57 | 58 | std::filesystem::remove(filepath); 59 | } 60 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: 
https://github.com/pre-commit/pre-commit-hooks 5 | rev: v6.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | args: ["--maxkb=4000"] 12 | - repo: https://github.com/psf/black 13 | rev: 25.12.0 14 | hooks: 15 | - id: black 16 | - repo: https://github.com/pycqa/isort 17 | rev: 7.0.0 18 | hooks: 19 | - id: isort 20 | args: ["--profile", "black", "--filter-files"] 21 | - repo: https://github.com/MarcoGorelli/cython-lint 22 | rev: v0.18.1 23 | hooks: 24 | - id: cython-lint 25 | - id: double-quote-cython-strings 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 7.3.0 28 | hooks: 29 | - id: flake8 30 | args: [--config=.flake8] 31 | files: .*$ 32 | types: [file] 33 | types_or: [python] 34 | additional_dependencies: [flake8-force] 35 | - repo: https://github.com/pocc/pre-commit-hooks 36 | rev: v1.3.5 37 | hooks: 38 | - id: clang-format 39 | args: ["-i", "--style=file:.clang-format"] 40 | language: python 41 | additional_dependencies: [clang-format>=15.0] 42 | types_or: [c++] 43 | - id: cpplint 44 | language: python 45 | args: [ 46 | "--linelength=100", "--recursive", 47 | "--filter=-build/c++11,-build/include,-build/namespaces_literals,-runtime/references,-build/include_order,+build/include_what_you_use", 48 | "--root=include"] 49 | additional_dependencies: [cpplint==1.6.1] 50 | types_or: [c++] 51 | - repo: https://github.com/pre-commit/mirrors-mypy 52 | rev: v1.19.0 53 | hooks: 54 | - id: mypy 55 | additional_dependencies: [types-setuptools] 56 | - repo: https://github.com/astral-sh/ruff-pre-commit 57 | rev: v0.14.8 58 | hooks: 59 | - id: ruff 60 | args: ["--config", "python/pyproject.toml"] 61 | -------------------------------------------------------------------------------- /cmake/Sanitizer.cmake: -------------------------------------------------------------------------------- 1 | # Set appropriate compiler and linker flags for sanitizers. 
2 | # 3 | # Usage of this module: 4 | # enable_sanitizers("address;leak") 5 | 6 | # Add flags 7 | macro(enable_sanitizer sanitizer) 8 | if(${sanitizer} MATCHES "address") 9 | set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=address") 10 | 11 | elseif(${sanitizer} MATCHES "thread") 12 | set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=thread") 13 | 14 | elseif(${sanitizer} MATCHES "leak") 15 | set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=leak") 16 | 17 | elseif(${sanitizer} MATCHES "undefined") 18 | set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=undefined -fno-sanitize-recover=undefined") 19 | 20 | else() 21 | message(FATAL_ERROR "Sanitizer ${sanitizer} not supported.") 22 | endif() 23 | endmacro() 24 | 25 | macro(enable_sanitizers SANITIZERS) 26 | # Check sanitizers compatibility. 27 | foreach(_san ${SANITIZERS}) 28 | string(TOLOWER ${_san} _san) 29 | if(_san MATCHES "thread") 30 | if(${_use_other_sanitizers}) 31 | message(FATAL_ERROR 32 | "thread sanitizer is not compatible with ${_san} sanitizer.") 33 | endif() 34 | set(_use_thread_sanitizer 1) 35 | else() 36 | if(${_use_thread_sanitizer}) 37 | message(FATAL_ERROR 38 | "${_san} sanitizer is not compatible with thread sanitizer.") 39 | endif() 40 | set(_use_other_sanitizers 1) 41 | endif() 42 | endforeach() 43 | 44 | message(STATUS "Sanitizers: ${SANITIZERS}") 45 | 46 | foreach(_san ${SANITIZERS}) 47 | string(TOLOWER ${_san} _san) 48 | enable_sanitizer(${_san}) 49 | endforeach() 50 | message(STATUS "Sanitizers compile flags: ${SAN_COMPILE_FLAGS}") 51 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SAN_COMPILE_FLAGS}") 52 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SAN_COMPILE_FLAGS}") 53 | endmacro() 54 | -------------------------------------------------------------------------------- /python/packager/build_config.py: -------------------------------------------------------------------------------- 1 | """Build configuration""" 2 | 3 | import dataclasses 4 | from typing import Any, Dict, List,
Optional 5 | 6 | 7 | @dataclasses.dataclass 8 | class BuildConfiguration: # pylint: disable=R0902 9 | """Configurations use when building libtreelite""" 10 | 11 | # Whether to enable OpenMP 12 | use_openmp: bool = True 13 | # Whether to hide C++ symbols 14 | hide_cxx_symbols: bool = True 15 | # Whether to use the Treelite library that's installed in the system prefix 16 | use_system_libtreelite: bool = False 17 | # Manually configure path for the Treelite library. 18 | # Only applicable when use_system_libtreelite=True 19 | system_libtreelite_dir: str = "" 20 | 21 | def _set_config_setting(self, config_settings: Dict[str, Any]) -> None: 22 | for field_name in config_settings: 23 | config_value = config_settings[field_name] 24 | if field_name == "system_libtreelite_dir": 25 | setattr(self, field_name, config_value) 26 | else: 27 | setattr( 28 | self, 29 | field_name, 30 | (config_value.lower() in ["true", "1", "on"]), 31 | ) 32 | 33 | def update(self, config_settings: Optional[Dict[str, Any]]) -> None: 34 | """Parse config_settings from Pip (or other PEP 517 frontend)""" 35 | if config_settings is not None: 36 | self._set_config_setting(config_settings) 37 | 38 | def get_cmake_args(self) -> List[str]: 39 | """Convert build configuration to CMake args""" 40 | cmake_args = [] 41 | for field_name in [x.name for x in dataclasses.fields(self)]: 42 | if field_name in ["use_system_libtreelite", "system_libtreelite_dir"]: 43 | continue 44 | cmake_option = field_name.upper() 45 | cmake_value = "ON" if getattr(self, field_name) is True else "OFF" 46 | cmake_args.append(f"-D{cmake_option}={cmake_value}") 47 | return cmake_args 48 | -------------------------------------------------------------------------------- /tests/ci_build/rename_whl.py: -------------------------------------------------------------------------------- 1 | """Rename a Python wheel""" 2 | 3 | import argparse 4 | import glob 5 | import os 6 | from contextlib import contextmanager 7 | 8 | 9 | @contextmanager 10 
| def cd(path): # pylint: disable=C0103 11 | """Temporarily change working directory""" 12 | path = os.path.normpath(path) 13 | cwd = os.getcwd() 14 | os.chdir(path) 15 | print(f"cd {path}") 16 | try: 17 | yield path 18 | finally: 19 | os.chdir(cwd) 20 | 21 | 22 | def main(args): 23 | """Main function""" 24 | if not os.path.isdir(args.wheel_dir): 25 | raise ValueError("wheel_dir argument must be a directory") 26 | 27 | with cd(args.wheel_dir): 28 | whl_list = list(glob.glob("*.whl")) 29 | for whl_path in whl_list: 30 | basename = os.path.basename(whl_path) 31 | tokens = basename.split("-") 32 | assert len(tokens) == 5 33 | keywords = { 34 | "pkg_name": tokens[0], 35 | "version": tokens[1], 36 | "commit_id": args.commit_id, 37 | "platform_tag": args.platform_tag, 38 | } 39 | new_name = ( 40 | "{pkg_name}-{version}+{commit_id}-py3-none-{platform_tag}.whl".format( 41 | **keywords 42 | ) 43 | ) 44 | print(f"Renaming {basename} to {new_name}...") 45 | os.rename(basename, new_name) 46 | 47 | 48 | if __name__ == "__main__": 49 | DESCRIPTION = ( 50 | "Script to rename wheel(s) using a commit ID and platform tag." 51 | "Note: This script will not recurse into subdirectories." 52 | ) 53 | parser = argparse.ArgumentParser(description=DESCRIPTION) 54 | parser.add_argument("wheel_dir", type=str, help="Directory containing wheels") 55 | parser.add_argument("commit_id", type=str, help="Hash of current git commit") 56 | parser.add_argument( 57 | "platform_tag", type=str, help="Platform tag, PEP 425 compliant" 58 | ) 59 | parsed_args = parser.parse_args() 60 | main(parsed_args) 61 | -------------------------------------------------------------------------------- /docs/serialization/index.rst: -------------------------------------------------------------------------------- 1 | Notes on Serialization 2 | ====================== 3 | 4 | Treelite model objects can be serialized into three ways: 5 | 6 | * Byte sequence. A tree model can be serialized to a byte sequence in memory. 
Later, 7 | we can recover the same tree model by deserializing it from the byte sequence. 8 | * Files. Tree models can be converted into Treelite checkpoint files that can be later 9 | read back. 10 | * `Python Buffer Protocol `_, to enable 11 | zero-copy serialization in the Python programming environment. When pickling a Python 12 | object containing a Treelite model object, we can convert the Treelite model into 13 | a byte sequence without physically making copies in memory. 14 | 15 | We make certain guarantees about compatibility of serialization. It is possible to 16 | exchange serialized tree models between two different Treelite versions, as follows: 17 | 18 | .. |tick| unicode:: U+2714 19 | .. |cross| unicode:: U+2718 20 | 21 | +----------------------+--------------+--------------+--------------------+---------------+ 22 | | | To: ``=3.9`` | To: ``=4.0`` | To: ``>=4.1,<5.0`` | To: ``>=5.0`` | 23 | +----------------------+--------------+--------------+--------------------+---------------+ 24 | | From: ``=3.9`` | |tick| | |tick| | |tick| | |cross| | 25 | +----------------------+--------------+--------------+--------------------+---------------+ 26 | | From: ``=4.0`` | |cross| | |tick| | |tick| | |tick| | 27 | +----------------------+--------------+--------------+--------------------+---------------+ 28 | | From: ``>=4.1,<5.0`` | |cross| | |tick| | |tick| | |tick| | 29 | +----------------------+--------------+--------------+--------------------+---------------+ 30 | | From: ``>=5.0`` | |cross| | |cross| | |cross| | |tick| | 31 | +----------------------+--------------+--------------+--------------------+---------------+ 32 | 33 | ..
toctree:: 34 | :maxdepth: 1 35 | 36 | v4 37 | v3 38 | -------------------------------------------------------------------------------- /docs/knobs/postprocessor.rst: -------------------------------------------------------------------------------- 1 | List of postprocessor functions 2 | =============================== 3 | When predicting with tree ensemble models, we sum the margin scores from individual trees and apply a postprocessor 4 | function to transform the sum into a final prediction. This function is also known as the link function. 5 | 6 | Currently, Treelite supports the following postprocessor functions. 7 | 8 | Element-wise postprocessor functions 9 | ------------------------------------ 10 | * ``identity``: The identity function. Do not apply any transformation to the margin score vector. 11 | * ``signed_square``: Apply the function ``f(x) = sign(x) * (x**2)`` element-wise to the margin score vector. 12 | * ``hinge``: Apply the function ``f(x) = (1 if x > 0 else 0)`` element-wise to the margin score vector. 13 | * ``sigmoid``: Apply the sigmoid function ``f(x) = 1/(1+exp(-sigmoid_alpha * x))`` element-wise to the margin score 14 | vector, to transform margin scores into probability scores in the range ``[0, 1]``. The ``sigmoid_alpha`` parameter 15 | can be configured by the user. 16 | * ``exponential``: Apply the exponential function (``exp``) element-wise to the margin score vector. 17 | * ``exponential_standard_ratio``: Apply the function ``f(x) = exp2(-x / ratio_c)`` element-wise to the margin score 18 | vector. The ``ratio_c`` parameter can be configured by the user. 19 | * ``logarithm_one_plus_exp``: Apply the function ``f(x) = log(1 + exp(x))`` element-wise to the margin score vector. 20 | 21 | Row-wise postprocessor functions 22 | -------------------------------- 23 | * ``identity_multiclass``: The identity function. Do not apply any transformation to the margin score vector. 
24 | * ``softmax``: Use the softmax function ``f(x) = exp(x) / sum(exp(x))`` to the margin score vector, to transform the 25 | margin scores into probability scores in the range ``[0, 1]``. Adding up the transformed scores for all classes 26 | will yield 1. 27 | * ``multiclass_ova``: Apply the sigmoid function ``f(x) = 1/(1+exp(-sigmoid_alpha * x))`` element-wise to the margin 28 | scores. The ``sigmoid_alpha`` parameter can be configured by the user. 29 | -------------------------------------------------------------------------------- /tests/serializer/compatibility_tester.py: -------------------------------------------------------------------------------- 1 | """Script to test backward compatibility of serializer""" 2 | 3 | import argparse 4 | import pickle 5 | 6 | import lightgbm as lgb 7 | import numpy as np 8 | from packaging.version import parse as parse_version 9 | from sklearn.datasets import load_iris 10 | 11 | import treelite 12 | 13 | 14 | def _fetch_data(): 15 | X, y = load_iris(return_X_y=True) 16 | return X, y 17 | 18 | 19 | def _train_model(X, y): 20 | clf = lgb.LGBMClassifier(max_depth=3, random_state=0, n_estimators=3) 21 | clf.fit(X, y) 22 | return clf 23 | 24 | 25 | def save(args): 26 | """Save model""" 27 | X, y = _fetch_data() 28 | clf = _train_model(X, y) 29 | with open(args.model_pickle_path, "wb") as f: 30 | pickle.dump(clf, f) 31 | if parse_version(treelite.__version__) >= parse_version("4.0"): 32 | tl_model = treelite.frontend.from_lightgbm(clf.booster_) 33 | else: 34 | tl_model = treelite.Model.from_lightgbm(clf.booster_) 35 | tl_model.serialize(args.checkpoint_path) 36 | 37 | 38 | def load(args): 39 | """Load model""" 40 | X, _ = _fetch_data() 41 | with open(args.model_pickle_path, "rb") as f: 42 | clf = pickle.load(f) 43 | tl_model = treelite.Model.deserialize(args.checkpoint_path) 44 | expected_prob = clf.predict_proba(X).reshape((X.shape[0], 1, -1)) 45 | out_prob = treelite.gtil.predict(tl_model, X) 46 | 
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5) 47 | print("Test passed!") 48 | 49 | 50 | def main(args): 51 | """Main function""" 52 | if treelite.__version__ != args.expected_treelite_version: 53 | raise ValueError( 54 | f"Expected Treelite {args.expected_treelite_version} " 55 | f"but running Treelite {treelite.__version__}" 56 | ) 57 | if args.task == "save": 58 | save(args) 59 | elif args.task == "load": 60 | load(args) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--task", type=str, choices=["save", "load"], required=True) 66 | parser.add_argument("--checkpoint-path", type=str, required=True) 67 | parser.add_argument("--model-pickle-path", type=str, required=True) 68 | parser.add_argument("--expected-treelite-version", type=str, required=True) 69 | parsed_args = parser.parse_args() 70 | main(parsed_args) 71 | -------------------------------------------------------------------------------- /python/treelite/core.py: -------------------------------------------------------------------------------- 1 | """Interface with native lib""" 2 | 3 | import ctypes 4 | import os 5 | import sys 6 | import warnings 7 | 8 | from .libpath import TreeliteLibraryNotFound, find_lib_path 9 | from .util import py_str 10 | 11 | 12 | class TreeliteError(Exception): 13 | """Error thrown by Treelite""" 14 | 15 | 16 | @ctypes.CFUNCTYPE(None, ctypes.c_char_p) 17 | def _log_callback(msg: bytes) -> None: 18 | """Redirect logs from native library into Python console""" 19 | print(py_str(msg)) 20 | 21 | 22 | @ctypes.CFUNCTYPE(None, ctypes.c_char_p) 23 | def _warn_callback(msg: bytes) -> None: 24 | """Redirect warnings from native library into Python console""" 25 | warnings.warn(py_str(msg)) 26 | 27 | 28 | def _load_lib(): 29 | """Load Treelite Library.""" 30 | lib_path = [str(x) for x in find_lib_path()] 31 | if not lib_path: 32 | # Building docs 33 | return None # type: ignore 34 | if sys.version_info >= (3, 8) 
and sys.platform == "win32": 35 | # pylint: disable=no-member 36 | lib_bin_path = os.path.join(os.path.normpath(sys.base_prefix), "Library", "bin") 37 | if os.path.isdir(lib_bin_path): 38 | os.add_dll_directory(lib_bin_path) 39 | 40 | lib = ctypes.cdll.LoadLibrary(lib_path[0]) 41 | lib.TreeliteGetLastError.restype = ctypes.c_char_p 42 | lib.log_callback = _log_callback 43 | lib.warn_callback = _warn_callback 44 | if lib.TreeliteRegisterLogCallback(lib.log_callback) != 0: 45 | raise TreeliteError(py_str(lib.TreeliteGetLastError())) 46 | if lib.TreeliteRegisterWarningCallback(lib.warn_callback) != 0: 47 | raise TreeliteError(py_str(lib.TreeliteGetLastError())) 48 | return lib 49 | 50 | 51 | # Load the Treelite library globally 52 | # (do not load if called by Sphinx) 53 | if "sphinx" in sys.modules: 54 | try: 55 | _LIB = _load_lib() # pylint: disable=invalid-name 56 | except TreeliteLibraryNotFound: 57 | _LIB = None # pylint: disable=invalid-name 58 | else: 59 | _LIB = _load_lib() # pylint: disable=invalid-name 60 | 61 | 62 | def _check_call(ret: int) -> None: 63 | """Check the return value of C API call 64 | 65 | This function will raise exception when error occurs. 66 | Wrap every API call with this function. 
67 | 68 | Parameters 69 | ---------- 70 | ret : 71 | return value from API calls 72 | """ 73 | if ret != 0: 74 | raise TreeliteError(_LIB.TreeliteGetLastError().decode("utf-8")) 75 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | Standard: c++17 4 | BasedOnStyle: Google 5 | TabWidth: 2 6 | IndentWidth: 2 7 | ColumnLimit: 100 8 | UseTab: Never 9 | 10 | AccessModifierOffset: -1 11 | AlignAfterOpenBracket: DontAlign 12 | AlignConsecutiveAssignments: None 13 | AlignConsecutiveDeclarations: None 14 | AlignEscapedNewlines: Left 15 | AlignTrailingComments: false 16 | AllowAllArgumentsOnNextLine: true 17 | AllowAllConstructorInitializersOnNextLine: true 18 | AllowAllParametersOfDeclarationOnNextLine: true 19 | AllowShortBlocksOnASingleLine: Empty 20 | AllowShortCaseLabelsOnASingleLine: false 21 | AllowShortFunctionsOnASingleLine: Empty 22 | AllowShortIfStatementsOnASingleLine: Never 23 | AllowShortLambdasOnASingleLine: All 24 | AllowShortLoopsOnASingleLine: false 25 | AlwaysBreakTemplateDeclarations: Yes 26 | BinPackArguments: true 27 | BinPackParameters: true 28 | BreakBeforeBinaryOperators: All 29 | BreakBeforeBraces: Attach 30 | BreakBeforeTernaryOperators: true 31 | BreakConstructorInitializers: BeforeColon 32 | BreakInheritanceList: AfterColon 33 | BreakStringLiterals: true 34 | CompactNamespaces: false 35 | Cpp11BracedListStyle: true 36 | DerivePointerAlignment: false 37 | IndentWrappedFunctionNames: false 38 | InsertBraces: true 39 | IndentCaseLabels: false 40 | KeepEmptyLinesAtTheStartOfBlocks: false 41 | NamespaceIndentation: None 42 | PointerAlignment: Left 43 | ReflowComments: true 44 | 45 | SortIncludes: CaseInsensitive 46 | IncludeBlocks: Regroup 47 | IncludeCategories: 48 | # Headers in <> without extension. 
49 | - Regex: '<([A-Za-z0-9\Q/-_\E])+>' 50 | Priority: 1 51 | # Headers in <> from Treelite 52 | - Regex: '<(treelite)\/' 53 | Priority: 2 54 | # Headers in <> from external libraries. 55 | - Regex: '<(rapidjson|nlohmann|gtest|fmt)\/' 56 | Priority: 3 57 | # Headers in "" with extension. 58 | - Regex: '"([A-Za-z0-9.\Q/-_\E])+"' 59 | Priority: 4 60 | 61 | SpaceAfterCStyleCast: false 62 | SpaceAfterTemplateKeyword: true 63 | SpaceBeforeAssignmentOperators: true 64 | SpaceBeforeCpp11BracedList: false 65 | SpaceBeforeCtorInitializerColon: true 66 | SpaceBeforeInheritanceColon: true 67 | SpaceBeforeParens: ControlStatements 68 | SpaceBeforeRangeBasedForLoopColon: true 69 | SpaceInEmptyParentheses: false 70 | SpacesInAngles: false 71 | SpacesInCStyleCastParentheses: false 72 | SpacesInContainerLiterals: false 73 | SpacesInParentheses: false 74 | SpacesInSquareBrackets: false 75 | QualifierAlignment: Right 76 | -------------------------------------------------------------------------------- /include/treelite/contiguous_array.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file contiguous_array.h 4 | * \brief A simple array container, with owned or non-owned (externally allocated) buffer 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_CONTIGUOUS_ARRAY_H_ 9 | #define TREELITE_CONTIGUOUS_ARRAY_H_ 10 | 11 | #include 12 | #include 13 | 14 | namespace treelite { 15 | 16 | template 17 | class ContiguousArray { 18 | public: 19 | ContiguousArray(); 20 | ~ContiguousArray(); 21 | // NOTE: use Clone to make deep copy; copy constructors disabled 22 | ContiguousArray(ContiguousArray const&) = delete; 23 | ContiguousArray& operator=(ContiguousArray const&) = delete; 24 | explicit ContiguousArray(std::vector const& other); 25 | ContiguousArray& operator=(std::vector const& other); 26 | ContiguousArray(ContiguousArray&& other) noexcept; 27 | ContiguousArray& operator=(ContiguousArray&& other) noexcept; 28 | inline ContiguousArray Clone() const; 29 | inline void UseForeignBuffer(void* prealloc_buf, std::size_t size); 30 | inline T* Data(); 31 | inline T const* Data() const; 32 | inline T* End(); 33 | inline T const* End() const; 34 | inline T& Back(); 35 | inline T const& Back() const; 36 | inline std::size_t Size() const; 37 | inline bool Empty() const; 38 | inline void Reserve(std::size_t newsize); 39 | inline void Resize(std::size_t newsize); 40 | inline void Resize(std::size_t newsize, T t); 41 | inline void Clear(); 42 | inline void PushBack(T t); 43 | inline void Extend(std::vector const& other); 44 | inline void Extend(ContiguousArray const& other); 45 | inline std::vector AsVector() const; 46 | inline bool operator==(ContiguousArray const& other); 47 | /* Unsafe access, no bounds checking */ 48 | inline T& operator[](std::size_t idx); 49 | inline T const& operator[](std::size_t idx) const; 50 | /* Safe access, with bounds checking */ 51 | inline T& at(std::size_t idx); 52 | inline T const& at(std::size_t idx) const; 53 | /* Safe access, with bounds checking + check against non-existent 
node (<0) */ 54 | inline T& at(int idx); 55 | inline T const& at(int idx) const; 56 | static_assert(std::is_pod::value, "T must be POD"); 57 | 58 | private: 59 | T* buffer_; 60 | std::size_t size_; 61 | std::size_t capacity_; 62 | bool owned_buffer_; 63 | }; 64 | 65 | } // namespace treelite 66 | 67 | #include 68 | 69 | #endif // TREELITE_CONTIGUOUS_ARRAY_H_ 70 | -------------------------------------------------------------------------------- /src/model_loader/detail/xgboost_json/sax_adapters.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2024 by Contributors 3 | * \file sax_adapters.h 4 | * \brief Adapters to connect RapidJSON and nlohmann/json with the delegated handler 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_ 9 | #define SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | namespace treelite::model_loader::detail::xgboost { 22 | 23 | class DelegatedHandler; 24 | 25 | /*! 26 | * \brief Adapter for SAX parser from RapidJSON 27 | */ 28 | class RapidJSONAdapter { 29 | public: 30 | explicit RapidJSONAdapter(std::shared_ptr handler) 31 | : handler_{std::move(handler)} {} 32 | bool Null(); 33 | bool Bool(bool b); 34 | bool Int(int i); 35 | bool Uint(unsigned u); 36 | bool Int64(std::int64_t i); 37 | bool Uint64(std::uint64_t u); 38 | bool Double(double d); 39 | bool RawNumber(char const* str, std::size_t length, bool copy); 40 | bool String(char const* str, std::size_t length, bool copy); 41 | bool StartObject(); 42 | bool Key(char const* str, std::size_t length, bool copy); 43 | bool EndObject(std::size_t); 44 | bool StartArray(); 45 | bool EndArray(std::size_t); 46 | 47 | private: 48 | std::shared_ptr handler_; 49 | }; 50 | 51 | /*! 
52 | * \brief Adapter for SAX parser from nlohmann/json 53 | */ 54 | class NlohmannJSONAdapter { 55 | public: 56 | explicit NlohmannJSONAdapter(std::shared_ptr handler) 57 | : handler_{std::move(handler)} {} 58 | bool null(); 59 | bool boolean(bool val); 60 | bool number_integer(std::int64_t val); 61 | bool number_unsigned(std::uint64_t val); 62 | bool number_float(double val, std::string const&); 63 | bool string(std::string& val); 64 | bool binary(nlohmann::json::binary_t& val); 65 | bool start_object(std::size_t); 66 | bool end_object(); 67 | bool start_array(std::size_t); 68 | bool end_array(); 69 | bool key(std::string& val); 70 | bool parse_error( 71 | std::size_t position, std::string const& last_token, nlohmann::json::exception const& ex); 72 | 73 | private: 74 | std::shared_ptr handler_; 75 | }; 76 | 77 | } // namespace treelite::model_loader::detail::xgboost 78 | 79 | #endif // SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_ 80 | -------------------------------------------------------------------------------- /src/model_loader/detail/lightgbm.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2020-2023 by Contributors 3 | * \file lightgbm.h 4 | * \brief Helper functions for loading LightGBM models 5 | * \author Hyunsu Cho 6 | */ 7 | #ifndef SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_ 8 | #define SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_ 9 | 10 | #include 11 | 12 | #include 13 | 14 | #include "./string_utils.h" 15 | 16 | namespace treelite::model_loader::detail::lightgbm { 17 | 18 | /*! 19 | * \brief Canonicalize the name of an objective function. 20 | * 21 | * Some objective functions have many aliases. We use the canonical name to avoid confusion. 
22 | * 23 | * @param obj_name Name of an objective function 24 | * @return Canonical name 25 | */ 26 | inline std::string CanonicalObjective(std::string const& obj_name) { 27 | if (obj_name == "regression" || obj_name == "regression_l2" || obj_name == "l2" 28 | || obj_name == "mean_squared_error" || obj_name == "mse" || obj_name == "l2_root" 29 | || obj_name == "root_mean_squared_error" || obj_name == "rmse") { 30 | return "regression"; 31 | } else if (obj_name == "regression_l1" || obj_name == "l1" || obj_name == "mean_absolute_error" 32 | || obj_name == "mae") { 33 | return "regression_l1"; 34 | } else if (obj_name == "mape" || obj_name == "mean_absolute_percentage_error") { 35 | return "mape"; 36 | } else if (obj_name == "multiclass" || obj_name == "softmax") { 37 | return "multiclass"; 38 | } else if (obj_name == "multiclassova" || obj_name == "multiclass_ova" || obj_name == "ova" 39 | || obj_name == "ovr") { 40 | return "multiclassova"; 41 | } else if (obj_name == "cross_entropy" || obj_name == "xentropy") { 42 | return "cross_entropy"; 43 | } else if (obj_name == "cross_entropy_lambda" || obj_name == "xentlambda") { 44 | return "cross_entropy_lambda"; 45 | } else if (obj_name == "rank_xendcg" || obj_name == "xendcg" || obj_name == "xe_ndcg" 46 | || obj_name == "xe_ndcg_mart" || obj_name == "xendcg_mart") { 47 | return "rank_xendcg"; 48 | } else if (obj_name == "huber" || obj_name == "fair" || obj_name == "poisson" 49 | || obj_name == "quantile" || obj_name == "gamma" || obj_name == "tweedie" 50 | || obj_name == "binary" || obj_name == "lambdarank" || obj_name == "custom") { 51 | // These objectives have no aliases 52 | return obj_name; 53 | } else { 54 | TREELITE_LOG(FATAL) << "Unknown objective name: \"" << obj_name << "\""; 55 | return ""; 56 | } 57 | } 58 | 59 | } // namespace treelite::model_loader::detail::lightgbm 60 | 61 | #endif // SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_ 62 | -------------------------------------------------------------------------------- 
/docs/treelite-c-api.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | Treelite C API 3 | ============== 4 | 5 | Treelite exposes a set of C functions to enable interfacing with a variety of 6 | languages. This page will be most useful for: 7 | 8 | * those writing a new 9 | `language binding `_ (glue 10 | code). 11 | * those wanting to incorporate functions of Treelite into their own native 12 | libraries. 13 | 14 | **We recommend the Python API for everyday uses.** 15 | 16 | .. note:: Use of C and C++ in Treelite 17 | 18 | The core logic of Treelite is written in C++ to take advantage of higher 19 | abstractions. We provide a C-only interface here, as many more programming 20 | languages bind with C than with C++. See 21 | `this page `_ for 22 | more details. 23 | 24 | .. contents:: Contents 25 | :local: 26 | 27 | Model loader interface 28 | ---------------------- 29 | Use the following functions to load decision tree ensemble models from a file. 30 | Treelite supports multiple model file formats. 31 | 32 | .. doxygengroup:: model_loader 33 | :project: treelite 34 | :content-only: 35 | 36 | Model loader interface for scikit-learn models 37 | ---------------------------------------------- 38 | Use the following functions to load decision tree ensemble models from a scikit-learn 39 | model object. 40 | 41 | .. doxygengroup:: sklearn 42 | :project: treelite 43 | :content-only: 44 | 45 | Model builder interface 46 | ----------------------- 47 | Use the following functions to incrementally build decision tree ensemble 48 | models. 49 | 50 | .. doxygengroup:: model_builder 51 | :project: treelite 52 | :content-only: 53 | 54 | Model manager interface 55 | ----------------------- 56 | 57 | .. doxygengroup:: model_manager 58 | :project: treelite 59 | :content-only: 60 | 61 | Serializer 62 | ---------- 63 | 64 | ..
doxygengroup:: serializer 65 | :project: treelite 66 | :content-only: 67 | 68 | Getters and setters for the model object 69 | ---------------------------------------- 70 | 71 | .. doxygengroup:: accessor 72 | :project: treelite 73 | :content-only: 74 | 75 | General Tree Inference Library (GTIL) 76 | ------------------------------------- 77 | 78 | .. doxygengroup:: gtil 79 | :project: treelite 80 | :content-only: 81 | 82 | Handle types 83 | ------------ 84 | Treelite uses C++ classes to define its internal data structures. In order to 85 | pass C++ objects to C functions, *opaque handles* are used. Opaque handles 86 | are ``void*`` pointers that store raw memory addresses. 87 | 88 | .. doxygengroup:: opaque_handles 89 | :project: treelite 90 | :content-only: 91 | -------------------------------------------------------------------------------- /.github/workflows/coverage-tests.yml: -------------------------------------------------------------------------------- 1 | name: coverage-tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | CODECOV_TOKEN: afe9868c-2c27-4853-89fa-4bc5d3d2b255 25 | 26 | jobs: 27 | cpp-python-coverage: 28 | name: Run Python and C++ tests with test coverage (Linux) 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: actions/checkout@v4 32 | - uses: conda-incubator/setup-miniconda@v3 33 | with: 34 | miniforge-variant: Miniforge3 35 | miniforge-version: latest 36 | conda-remove-defaults: "true" 37 | activate-environment: dev 38 | environment-file: ops/conda_env/dev.yml 39 | use-mamba: true 40 | - name: Display Conda env 41 | run: | 42 | conda info 43 | conda list 
44 | - name: Run tests with test coverage computation 45 | run: | 46 | bash ops/cpp-python-coverage.sh 47 | win-python-coverage: 48 | name: Run Python and C++ tests with test coverage (Windows) 49 | runs-on: windows-latest 50 | steps: 51 | - uses: actions/checkout@v4 52 | - uses: conda-incubator/setup-miniconda@v3 53 | with: 54 | miniforge-variant: Miniforge3 55 | miniforge-version: latest 56 | conda-remove-defaults: "true" 57 | activate-environment: dev 58 | environment-file: ops/conda_env/dev.yml 59 | use-mamba: true 60 | - name: Display Conda env 61 | shell: cmd /C CALL {0} 62 | run: | 63 | conda info 64 | conda list 65 | - name: Run tests with test coverage computation 66 | shell: cmd /C CALL {0} 67 | run: | 68 | call ops/win-python-coverage.bat 69 | macos-python-coverage: 70 | name: Run Python and C++ tests with test coverage (MacOS) 71 | runs-on: macos-latest 72 | steps: 73 | - uses: actions/checkout@v4 74 | - uses: conda-incubator/setup-miniconda@v3 75 | with: 76 | miniforge-variant: Miniforge3 77 | miniforge-version: latest 78 | conda-remove-defaults: "true" 79 | activate-environment: dev 80 | environment-file: ops/conda_env/dev.yml 81 | use-mamba: true 82 | - name: Display Conda env 83 | run: | 84 | conda info 85 | conda list 86 | - name: Run tests with test coverage computation 87 | run: | 88 | bash ops/macos-python-coverage.sh 89 | -------------------------------------------------------------------------------- /tests/serializer/test_serializer.py: -------------------------------------------------------------------------------- 1 | """Test for serialization, via buffer protocol""" 2 | 3 | import ctypes 4 | from typing import List 5 | 6 | import numpy as np 7 | import pytest 8 | from sklearn.datasets import load_iris 9 | from sklearn.ensemble import ( 10 | ExtraTreesClassifier, 11 | GradientBoostingClassifier, 12 | RandomForestClassifier, 13 | ) 14 | 15 | import treelite 16 | from treelite.core import _LIB, _check_call 17 | from treelite.model import 
_numpy2pybuffer, _pybuffer2numpy, _TreelitePyBufferFrame 18 | from treelite.util import c_array 19 | 20 | 21 | def treelite_deserialize(frames: List[np.ndarray]) -> treelite.Model: 22 | """Deserialize model from PyBuffer frames""" 23 | buffers = [_numpy2pybuffer(frame) for frame in frames] 24 | handle = ctypes.c_void_p() 25 | _check_call( 26 | _LIB.TreeliteDeserializeModelFromPyBuffer( 27 | c_array(_TreelitePyBufferFrame, buffers), 28 | ctypes.c_size_t(len(buffers)), 29 | ctypes.byref(handle), 30 | ) 31 | ) 32 | return treelite.Model(handle=handle) 33 | 34 | 35 | def treelite_serialize( 36 | model: treelite.Model, 37 | ) -> List[np.ndarray]: 38 | """Serialize model to PyBuffer frames""" 39 | frames = ctypes.POINTER(_TreelitePyBufferFrame)() 40 | n_frames = ctypes.c_size_t() 41 | _check_call( 42 | _LIB.TreeliteSerializeModelToPyBuffer( 43 | model.handle, 44 | ctypes.byref(frames), 45 | ctypes.byref(n_frames), 46 | ) 47 | ) 48 | return [_pybuffer2numpy(frames[i]) for i in range(n_frames.value)] 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "clazz", [RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier] 53 | ) 54 | def test_serialize_as_buffer(clazz): 55 | """Test whether Treelite objects can be serialized to a buffer""" 56 | X, y = load_iris(return_X_y=True) 57 | params = {"max_depth": 5, "random_state": 0, "n_estimators": 10} 58 | if clazz == GradientBoostingClassifier: 59 | params["init"] = "zero" 60 | clf = clazz(**params) 61 | clf.fit(X, y) 62 | expected_prob = clf.predict_proba(X).reshape((X.shape[0], 1, -1)) 63 | 64 | # Prediction should be correct after a round-trip 65 | tl_model = treelite.sklearn.import_model(clf) 66 | frames = treelite_serialize(tl_model) 67 | tl_model2 = treelite_deserialize(frames) 68 | out_prob = treelite.gtil.predict(tl_model2, X) 69 | np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5) 70 | 71 | # The model should serialize to the same byte sequence after a round-trip 72 | frames2 =
treelite_serialize(tl_model2) 73 | assert len(frames) == len(frames2) 74 | for x, y in zip(frames, frames2): 75 | assert np.array_equal(x, y) 76 | -------------------------------------------------------------------------------- /src/model_builder/metadata.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file metadata.cc 4 | * \brief C++ API for constructing Model metadata 5 | * \author Hyunsu Cho 6 | */ 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | 18 | namespace treelite::model_builder { 19 | 20 | TreeAnnotation::TreeAnnotation(std::int32_t num_tree, std::vector const& target_id, 21 | std::vector const& class_id) 22 | : num_tree{num_tree}, target_id{target_id}, class_id{class_id} { 23 | TREELITE_CHECK_EQ(target_id.size(), num_tree) 24 | << "target_id field must have length equal to num_tree (" << num_tree << ")"; 25 | TREELITE_CHECK_EQ(class_id.size(), num_tree) 26 | << "class_id field must have length equal to num_tree (" << num_tree << ")"; 27 | } 28 | 29 | PostProcessorFunc::PostProcessorFunc(std::string const& name) : PostProcessorFunc(name, {}) {} 30 | 31 | PostProcessorFunc::PostProcessorFunc( 32 | std::string const& name, std::map const& config) 33 | : name(name), config(config) {} 34 | 35 | Metadata::Metadata(std::int32_t num_feature, TaskType task_type, bool average_tree_output, 36 | std::int32_t num_target, std::vector const& num_class, 37 | std::array const& leaf_vector_shape) 38 | : num_feature(num_feature), 39 | task_type(task_type), 40 | average_tree_output(average_tree_output), 41 | num_target(num_target), 42 | num_class(num_class), 43 | leaf_vector_shape(leaf_vector_shape) { 44 | TREELITE_CHECK_GT(num_target, 0) << "num_target must be at least 1"; 45 | TREELITE_CHECK_EQ(num_class.size(), num_target) 46 | << "num_class field must have length equal to num_target (" << 
num_target << ")"; 47 | if (!std::all_of(num_class.begin(), num_class.end(), [](std::int32_t e) { return e >= 1; })) { 48 | TREELITE_LOG(FATAL) << "All elements in num_class field must be at least 1."; 49 | } 50 | TREELITE_CHECK(leaf_vector_shape[0] == 1 || leaf_vector_shape[0] == num_target) 51 | << "leaf_vector_shape[0] must be either 1 or num_target (" << num_target << "). " 52 | << "Currently given: leaf_vector_shape[0] = " << leaf_vector_shape[0]; 53 | std::int32_t const max_num_class = *std::max_element(num_class.begin(), num_class.end()); 54 | TREELITE_CHECK(leaf_vector_shape[1] == 1 || leaf_vector_shape[1] == max_num_class) 55 | << "leaf_vector_shape[1] must be either 1 or max_num_class (" << max_num_class << "). " 56 | << "Currently given: leaf_vector_shape[1] = " << leaf_vector_shape[1]; 57 | } 58 | 59 | } // namespace treelite::model_builder 60 | -------------------------------------------------------------------------------- /python/treelite/libpath.py: -------------------------------------------------------------------------------- 1 | """Find the path to Treelite dynamic library files.""" 2 | 3 | import os 4 | import pathlib 5 | import sys 6 | from typing import List 7 | 8 | from .path_config import get_custom_libpath 9 | 10 | 11 | class TreeliteLibraryNotFound(Exception): 12 | """Error thrown by when Treelite is not found""" 13 | 14 | 15 | def find_lib_path() -> List[pathlib.Path]: 16 | """Find the path to Treelite dynamic library files. 17 | 18 | Returns 19 | ------- 20 | lib_path 21 | List of all found library path to Treelite 22 | """ 23 | curr_path = pathlib.Path(__file__).expanduser().absolute().parent 24 | dll_path = [ 25 | # When installed, libtreelite will be installed in /lib 26 | curr_path / "lib", 27 | # Editable installation 28 | curr_path.parent.parent / "build", 29 | # Use libtreelite from a system prefix, if available. This should be the last option.
30 | pathlib.Path(sys.base_prefix).expanduser().resolve() / "lib", 31 | ] 32 | custom_libpath = get_custom_libpath() # pylint: disable=assignment-from-none 33 | if custom_libpath: 34 | dll_path.insert(0, pathlib.Path(custom_libpath).expanduser().resolve()) 35 | 36 | if sys.platform == "win32": 37 | # On Windows, Conda may install libs in different paths 38 | sys_prefix = pathlib.Path(sys.base_prefix) 39 | dll_path.extend( 40 | [ 41 | sys_prefix / "bin", 42 | sys_prefix / "Library", 43 | sys_prefix / "Library" / "bin", 44 | sys_prefix / "Library" / "lib", 45 | ] 46 | ) 47 | dll_path = [p.joinpath("treelite.dll") for p in dll_path] 48 | elif sys.platform.startswith(("linux", "freebsd", "emscripten", "OS400")): 49 | dll_path = [p.joinpath("libtreelite.so") for p in dll_path] 50 | elif sys.platform == "darwin": 51 | dll_path = [p.joinpath("libtreelite.dylib") for p in dll_path] 52 | elif sys.platform == "cygwin": 53 | dll_path = [p.joinpath("cygtreelite.dll") for p in dll_path] 54 | else: 55 | raise RuntimeError(f"Unrecognized platform: {sys.platform}") 56 | 57 | lib_path = [p for p in dll_path if p.exists() and p.is_file()] 58 | 59 | # TREELITE_BUILD_DOC is defined by sphinx conf. 60 | if not lib_path and not os.environ.get("TREELITE_BUILD_DOC", False): 61 | link = "https://treelite.readthedocs.io/en/latest/install.html" 62 | msg = ( 63 | "Cannot find Treelite Library in the candidate path. " 64 | + "List of candidates:\n- " 65 | + ("\n- ".join(str(x) for x in dll_path)) 66 | + "\nTreelite Python package path: " 67 | + str(curr_path) 68 | + "\nsys.base_prefix: " 69 | + sys.base_prefix 70 | + "\nSee: " 71 | + link 72 | + " for installing Treelite." 73 | ) 74 | raise TreeliteLibraryNotFound(msg) 75 | return lib_path 76 | -------------------------------------------------------------------------------- /src/c_api/model.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file model.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for functions to query and modify model objects 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "./c_api_utils.h" 17 | 18 | int TreeliteDumpAsJSON(TreeliteModelHandle handle, int pretty_print, char const** out_json_str) { 19 | API_BEGIN(); 20 | auto* model_ = static_cast(handle); 21 | std::string& ret_str = treelite::c_api::ReturnValueStore::Get()->ret_str; 22 | ret_str = model_->DumpAsJSON(pretty_print != 0); 23 | *out_json_str = ret_str.c_str(); 24 | API_END(); 25 | } 26 | 27 | int TreeliteGetInputType(TreeliteModelHandle model, char const** out_str) { 28 | API_BEGIN(); 29 | auto const* model_ = static_cast(model); 30 | auto& type_str = treelite::c_api::ReturnValueStore::Get()->ret_str; 31 | type_str = treelite::TypeInfoToString(model_->GetThresholdType()); 32 | *out_str = type_str.c_str(); 33 | API_END(); 34 | } 35 | 36 | int TreeliteGetOutputType(TreeliteModelHandle model, char const** out_str) { 37 | API_BEGIN(); 38 | auto const* model_ = static_cast(model); 39 | auto& type_str = treelite::c_api::ReturnValueStore::Get()->ret_str; 40 | type_str = treelite::TypeInfoToString(model_->GetLeafOutputType()); 41 | *out_str = type_str.c_str(); 42 | API_END(); 43 | } 44 | 45 | int TreeliteQueryNumTree(TreeliteModelHandle model, std::size_t* out) { 46 | API_BEGIN(); 47 | auto const* model_ = static_cast(model); 48 | *out = model_->GetNumTree(); 49 | API_END(); 50 | } 51 | 52 | int TreeliteQueryNumFeature(TreeliteModelHandle model, int* out) { 53 | API_BEGIN(); 54 | auto const* model_ = static_cast(model); 55 | *out = model_->num_feature; 56 | API_END(); 57 | } 58 | 59 | int TreeliteConcatenateModelObjects( 60 | TreeliteModelHandle const* objs, std::size_t len, TreeliteModelHandle* out) { 61 | API_BEGIN(); 62 | std::vector model_objs(len, nullptr); 63 | std::transform(objs, objs + len, 
model_objs.begin(), 64 | [](TreeliteModelHandle e) { return static_cast(e); }); 65 | auto concatenated_model = ConcatenateModelObjects(model_objs); 66 | *out = static_cast(concatenated_model.release()); 67 | API_END(); 68 | } 69 | 70 | int TreeliteFreeModel(TreeliteModelHandle handle) { 71 | API_BEGIN(); 72 | delete static_cast(handle); 73 | API_END(); 74 | } 75 | 76 | int TreeliteGetTreeDepth(TreeliteModelHandle model, std::uint32_t** out, std::size_t* out_len) { 77 | API_BEGIN(); 78 | auto* model_ = static_cast(model); 79 | auto& ret_depth = treelite::c_api::ReturnValueStore::Get()->ret_uint32_vec; 80 | ret_depth = model_->GetTreeDepth(); 81 | *out = ret_depth.data(); 82 | *out_len = ret_depth.size(); 83 | API_END(); 84 | } 85 | -------------------------------------------------------------------------------- /src/c_api/serializer.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file serializer.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for functions to serialize model objects 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "./c_api_utils.h" 19 | 20 | int TreeliteSerializeModelToFile(TreeliteModelHandle handle, char const* filename) { 21 | API_BEGIN(); 22 | std::ofstream ofs = treelite::detail::OpenFileForWriteAsStream(filename); 23 | auto* model_ = static_cast(handle); 24 | model_->SerializeToStream(ofs); 25 | API_END(); 26 | } 27 | 28 | int TreeliteDeserializeModelFromFile(char const* filename, TreeliteModelHandle* out) { 29 | API_BEGIN(); 30 | std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filename); 31 | std::unique_ptr model = treelite::Model::DeserializeFromStream(ifs); 32 | *out = static_cast(model.release()); 33 | API_END(); 34 | } 35 | 36 | int TreeliteSerializeModelToBytes( 37 | TreeliteModelHandle handle, char const** out_bytes, std::size_t* 
out_bytes_len) { 38 | API_BEGIN(); 39 | std::ostringstream oss; 40 | oss.exceptions(std::ios::failbit | std::ios::badbit); // Throw exception on failure 41 | auto* model_ = static_cast(handle); 42 | model_->SerializeToStream(oss); 43 | 44 | std::string& ret_str = treelite::c_api::ReturnValueStore::Get()->ret_str; 45 | ret_str = oss.str(); 46 | *out_bytes = ret_str.data(); 47 | *out_bytes_len = ret_str.length(); 48 | API_END(); 49 | } 50 | 51 | int TreeliteDeserializeModelFromBytes( 52 | char const* bytes, std::size_t bytes_len, TreeliteModelHandle* out) { 53 | API_BEGIN(); 54 | std::istringstream iss(std::string(bytes, bytes_len)); 55 | iss.exceptions(std::ios::failbit | std::ios::badbit); // Throw exception on failure 56 | std::unique_ptr model = treelite::Model::DeserializeFromStream(iss); 57 | *out = static_cast(model.release()); 58 | API_END(); 59 | } 60 | 61 | int TreeliteSerializeModelToPyBuffer( 62 | TreeliteModelHandle handle, TreelitePyBufferFrame** out_frames, size_t* out_num_frames) { 63 | API_BEGIN(); 64 | auto* model_ = static_cast(handle); 65 | std::vector& ret_frames 66 | = treelite::c_api::ReturnValueStore::Get()->ret_frames; 67 | ret_frames = model_->SerializeToPyBuffer(); 68 | if (ret_frames.empty()) { 69 | *out_frames = nullptr; 70 | *out_num_frames = 0; 71 | } else { 72 | *out_frames = &ret_frames[0]; 73 | *out_num_frames = ret_frames.size(); 74 | } 75 | API_END(); 76 | } 77 | 78 | int TreeliteDeserializeModelFromPyBuffer( 79 | TreelitePyBufferFrame* frames, size_t num_frames, TreeliteModelHandle* out) { 80 | API_BEGIN(); 81 | std::vector frames_(frames, frames + num_frames); 82 | auto model = treelite::Model::DeserializeFromPyBuffer(frames_); 83 | *out = static_cast(model.release()); 84 | API_END(); 85 | } 86 | -------------------------------------------------------------------------------- /src/model_loader/xgboost_ubjson.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2024 by Contributors 3 | * \file xgboost_ubjson.cc 4 | * \brief Model loader for XGBoost model (UBJSON) 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | #include "detail/xgboost_json/delegated_handler.h" 20 | #include "detail/xgboost_json/sax_adapters.h" 21 | 22 | namespace { 23 | 24 | template 25 | std::unique_ptr ParseStream( 26 | InputType&& input_stream, nlohmann::json const& parsed_config); 27 | 28 | } // anonymous namespace 29 | 30 | namespace treelite::model_loader { 31 | 32 | std::unique_ptr LoadXGBoostModelUBJSON( 33 | std::string const& filename, std::string const& config_json) { 34 | nlohmann::json parsed_config = nlohmann::json::parse(config_json); 35 | std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filename); 36 | return ParseStream(ifs, parsed_config); 37 | } 38 | 39 | std::unique_ptr LoadXGBoostModelFromUBJSONString( 40 | std::string_view ubjson_str, std::string const& config_json) { 41 | nlohmann::json parsed_config = nlohmann::json::parse(config_json); 42 | return ParseStream(ubjson_str, parsed_config); 43 | } 44 | 45 | } // namespace treelite::model_loader 46 | 47 | namespace { 48 | 49 | template 50 | std::unique_ptr ParseStream( 51 | InputType&& input_stream, nlohmann::json const& parsed_config) { 52 | treelite::model_loader::detail::xgboost::HandlerConfig handler_config; 53 | if (parsed_config.is_object()) { 54 | auto itr = parsed_config.find("allow_unknown_field"); 55 | if (itr != parsed_config.end() && itr->is_boolean()) { 56 | handler_config.allow_unknown_field = itr->template get(); 57 | } 58 | } 59 | 60 | std::shared_ptr handler 61 | = treelite::model_loader::detail::xgboost::DelegatedHandler::create(handler_config); 62 | auto adapter 63 | = std::make_unique(handler); 64 | TREELITE_CHECK(nlohmann::json::sax_parse( 65 | input_stream, adapter.get(), nlohmann::json::input_format_t::ubjson)); 66 
| 67 | treelite::model_loader::detail::xgboost::ParsedXGBoostModel parsed = handler->get_result(); 68 | auto model = parsed.builder->CommitModel(); 69 | 70 | // Apply Dart weights 71 | if (!parsed.weight_drop.empty()) { 72 | auto& trees = std::get>(model->variant_).trees; 73 | TREELITE_CHECK_EQ(trees.size(), parsed.weight_drop.size()); 74 | for (std::size_t i = 0; i < trees.size(); ++i) { 75 | for (int nid = 0; nid < trees[i].num_nodes; ++nid) { 76 | if (trees[i].IsLeaf(nid)) { 77 | trees[i].SetLeaf( 78 | nid, static_cast(parsed.weight_drop[i] * trees[i].LeafValue(nid))); 79 | } 80 | } 81 | } 82 | } 83 | return model; 84 | } 85 | 86 | } // anonymous namespace 87 | -------------------------------------------------------------------------------- /include/treelite/detail/file_utils.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file file_utils.h 4 | * \brief Helper functions for manipulating files 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_DETAIL_FILE_UTILS_H_ 9 | #define TREELITE_DETAIL_FILE_UTILS_H_ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace treelite::detail { 19 | 20 | inline std::ifstream OpenFileForReadAsStream(std::filesystem::path const& filepath) { 21 | auto path = std::filesystem::weakly_canonical(filepath); 22 | TREELITE_CHECK(std::filesystem::exists(path)) << "Path " << filepath << " does not exist"; 23 | std::ifstream ifs(path, std::ios::in | std::ios::binary); 24 | TREELITE_CHECK(ifs) << "Could not open file " << filepath; 25 | ifs.exceptions(std::ios::badbit); // Throw exceptions on error 26 | // We don't throw on failbit, since we sometimes want to read 27 | // until the end of file in a loop of form `while 28 | // (std::getline(...))`, which will set failbit. 
29 | return ifs; 30 | } 31 | 32 | inline std::ifstream OpenFileForReadAsStream(std::string const& filename) { 33 | return OpenFileForReadAsStream(std::filesystem::u8path(filename)); 34 | } 35 | 36 | inline std::ifstream OpenFileForReadAsStream(char const* filename) { 37 | return OpenFileForReadAsStream(std::string(filename)); 38 | } 39 | 40 | inline std::ofstream OpenFileForWriteAsStream(std::filesystem::path const& filepath) { 41 | auto path = std::filesystem::weakly_canonical(filepath); 42 | TREELITE_CHECK(path.has_filename()) << "Cannot write to a directory; please specify a file"; 43 | TREELITE_CHECK(std::filesystem::exists(path.parent_path())) 44 | << "Path " << path.parent_path() << " does not exist"; 45 | std::ofstream ofs(path, std::ios::out | std::ios::binary); 46 | TREELITE_CHECK(ofs) << "Could not open file " << filepath; 47 | ofs.exceptions(std::ios::failbit | std::ios::badbit); // Throw exceptions on error 48 | return ofs; 49 | } 50 | 51 | inline std::ofstream OpenFileForWriteAsStream(std::string const& filename) { 52 | return OpenFileForWriteAsStream(std::filesystem::u8path(filename)); 53 | } 54 | 55 | inline std::ofstream OpenFileForWriteAsStream(char const* filename) { 56 | return OpenFileForWriteAsStream(std::string(filename)); 57 | } 58 | 59 | inline FILE* OpenFileForReadAsFilePtr(std::filesystem::path const& filepath) { 60 | auto path = std::filesystem::weakly_canonical(filepath); 61 | TREELITE_CHECK(std::filesystem::exists(path)) << "Path " << filepath << " does not exist"; 62 | FILE* fp; 63 | #ifdef _WIN32 64 | fp = _wfopen(path.wstring().c_str(), L"rb"); 65 | #else 66 | fp = std::fopen(path.string().c_str(), "rb"); 67 | #endif 68 | TREELITE_CHECK(fp) << "Could not open file " << filepath; 69 | return fp; 70 | } 71 | 72 | inline FILE* OpenFileForReadAsFilePtr(std::string const& filename) { 73 | return OpenFileForReadAsFilePtr(std::filesystem::u8path(filename)); 74 | } 75 | 76 | inline FILE* OpenFileForReadAsFilePtr(char const* filename) { 77 | 
return OpenFileForReadAsFilePtr(std::string(filename)); 78 | } 79 | 80 | } // namespace treelite::detail 81 | 82 | #endif // TREELITE_DETAIL_FILE_UTILS_H_ 83 | -------------------------------------------------------------------------------- /tests/python/metadata.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Metadata for datasets and models used for testing""" 3 | 4 | import collections 5 | import os 6 | 7 | _current_dir = os.path.dirname(__file__) 8 | _dpath = os.path.abspath(os.path.join(_current_dir, os.path.pardir, "examples")) 9 | 10 | Dataset = collections.namedtuple( 11 | "Dataset", 12 | "model format dtrain dtest libname expected_prob expected_margin is_multiclass dtype", 13 | ) 14 | 15 | _dataset_db = { 16 | "mushroom": Dataset( 17 | model="mushroom.model", 18 | format="xgboost", 19 | dtrain="agaricus.train", 20 | dtest="agaricus.test", 21 | libname="agaricus", 22 | expected_prob="agaricus.test.prob", 23 | expected_margin="agaricus.test.margin", 24 | is_multiclass=False, 25 | dtype="float32", 26 | ), 27 | "dermatology": Dataset( 28 | model="dermatology.model", 29 | format="xgboost", 30 | dtrain="dermatology.train", 31 | dtest="dermatology.test", 32 | libname="dermatology", 33 | expected_prob="dermatology.test.prob", 34 | expected_margin="dermatology.test.margin", 35 | is_multiclass=True, 36 | dtype="float32", 37 | ), 38 | "letor": Dataset( 39 | model="mq2008.model", 40 | format="xgboost", 41 | dtrain="mq2008.train", 42 | dtest="mq2008.test", 43 | libname="letor", 44 | expected_prob=None, 45 | expected_margin="mq2008.test.pred", 46 | is_multiclass=False, 47 | dtype="float32", 48 | ), 49 | "toy_categorical": Dataset( 50 | model="toy_categorical_model.txt", 51 | format="lightgbm", 52 | dtrain=None, 53 | dtest="toy_categorical.test", 54 | libname="toycat", 55 | expected_prob=None, 56 | expected_margin="toy_categorical.test.pred", 57 | is_multiclass=False, 58 | dtype="float64", 59 | ), 60 
| "sparse_categorical": Dataset( 61 | model="sparse_categorical_model.txt", 62 | format="lightgbm", 63 | dtrain=None, 64 | dtest="sparse_categorical.test", 65 | libname="sparsecat", 66 | expected_prob=None, 67 | expected_margin="sparse_categorical.test.margin", 68 | is_multiclass=False, 69 | dtype="float64", 70 | ), 71 | "xgb_toy_categorical": Dataset( 72 | model="xgb_toy_categorical_model.json", 73 | format="xgboost_json", 74 | dtrain=None, 75 | dtest="xgb_toy_categorical.test", 76 | libname="xgbtoycat", 77 | expected_prob=None, 78 | expected_margin="xgb_toy_categorical.test.pred", 79 | is_multiclass=False, 80 | dtype="float32", 81 | ), 82 | } 83 | 84 | 85 | def _qualify_path(prefix, path): 86 | if path is None: 87 | return None 88 | return os.path.join(_dpath, prefix, path) 89 | 90 | 91 | dataset_db = { 92 | k: v._replace( 93 | model=_qualify_path(k, v.model), 94 | dtrain=_qualify_path(k, v.dtrain), 95 | dtest=_qualify_path(k, v.dtest), 96 | expected_prob=_qualify_path(k, v.expected_prob), 97 | expected_margin=_qualify_path(k, v.expected_margin), 98 | ) 99 | for k, v in _dataset_db.items() 100 | } 101 | -------------------------------------------------------------------------------- /docs/tutorials/import.rst: -------------------------------------------------------------------------------- 1 | Importing tree ensemble models 2 | ============================== 3 | 4 | Since the scope of Treelite is limited to **prediction** only, one must use 5 | other machine learning packages to **train** decision tree ensemble models. In 6 | this document, we will show how to import an ensemble model that had been 7 | trained elsewhere. 8 | 9 | .. contents:: Contents 10 | :local: 11 | 12 | Importing XGBoost models 13 | ------------------------ 14 | 15 | **XGBoost** (`dmlc/xgboost `_) is a fast, 16 | scalable package for gradient boosting. Both Treelite and XGBoost are hosted 17 | by the DMLC (Distributed Machine Learning Community) group. 
18 | 19 | Treelite plays well with XGBoost --- if you used XGBoost to train your ensemble 20 | model, you need only one line of code to import it. Depending on where your 21 | model is located, use :py:meth:`~treelite.frontend.from_xgboost`, 22 | :py:meth:`~treelite.frontend.load_xgboost_model`, or 23 | :py:meth:`~treelite.frontend.load_xgboost_model_legacy_binary`: 24 | 25 | * Load XGBoost model from a :py:class:`xgboost.Booster` object 26 | 27 | .. code-block:: python 28 | 29 | # bst = an object of type xgboost.Booster 30 | model = treelite.frontend.from_xgboost(bst) 31 | 32 | * Load XGBoost model from a model file 33 | 34 | .. code-block:: python 35 | 36 | # JSON format 37 | model = treelite.frontend.load_xgboost_model("my_model.json") 38 | # Legacy binary format 39 | model = treelite.frontend.load_xgboost_model_legacy_binary("my_model.model") 40 | 41 | Importing LightGBM models 42 | ------------------------- 43 | 44 | **LightGBM** (`Microsoft/LightGBM `_) is 45 | another well known machine learning package for gradient boosting. To import 46 | models generated by LightGBM, use the 47 | :py:meth:`~treelite.frontend.load_lightgbm_model` method: 48 | 49 | .. code-block:: python 50 | 51 | model = treelite.frontend.load_lightgbm_model("lightgbm_model.txt") 52 | 53 | Importing scikit-learn models 54 | ----------------------------- 55 | **Scikit-learn** (`scikit-learn/scikit-learn 56 | `_) is a Python machine learning 57 | package known for its versatility and ease of use. It supports a wide variety 58 | of models and algorithms. The following kinds of models can be imported into 59 | Treelite. 
60 | 61 | * :py:class:`sklearn.ensemble.RandomForestRegressor` 62 | * :py:class:`sklearn.ensemble.RandomForestClassifier` 63 | * :py:class:`sklearn.ensemble.ExtraTreesRegressor` 64 | * :py:class:`sklearn.ensemble.ExtraTreesClassifier` 65 | * :py:class:`sklearn.ensemble.GradientBoostingRegressor` 66 | * :py:class:`sklearn.ensemble.GradientBoostingClassifier` 67 | * :py:class:`sklearn.ensemble.HistGradientBoostingRegressor` 68 | * :py:class:`sklearn.ensemble.HistGradientBoostingClassifier` 69 | * :py:class:`sklearn.ensemble.IsolationForest` 70 | 71 | To import scikit-learn models, use 72 | :py:meth:`treelite.sklearn.import_model`: 73 | 74 | .. code-block:: python 75 | 76 | # clf is the model object generated by scikit-learn 77 | import treelite.sklearn 78 | model = treelite.sklearn.import_model(clf) 79 | 80 | How about other packages? 81 | ------------------------- 82 | If you used other packages to train your ensemble model, you'd need to specify 83 | the model programmatically: 84 | 85 | * :doc:`/tutorials/builder` 86 | -------------------------------------------------------------------------------- /src/c_api/gtil.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file gtil.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for functions for GTIL 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "./c_api_utils.h" 18 | 19 | int TreeliteGTILParseConfig(char const* config_json, TreeliteGTILConfigHandle* out) { 20 | API_BEGIN(); 21 | auto parsed_config = std::make_unique(config_json); 22 | *out = static_cast(parsed_config.release()); 23 | API_END(); 24 | } 25 | 26 | int TreeliteGTILDeleteConfig(TreeliteGTILConfigHandle handle) { 27 | API_BEGIN(); 28 | delete static_cast(handle); 29 | API_END(); 30 | } 31 | 32 | int TreeliteGTILGetOutputShape(TreeliteModelHandle model, std::uint64_t num_row, 33 | TreeliteGTILConfigHandle config, std::uint64_t const** out, std::uint64_t* out_ndim) { 34 | API_BEGIN(); 35 | auto const* model_ = static_cast(model); 36 | auto const* config_ = static_cast(config); 37 | auto& shape = treelite::c_api::ReturnValueStore::Get()->ret_uint64_vec; 38 | shape = treelite::gtil::GetOutputShape(*model_, num_row, *config_); 39 | *out = shape.data(); 40 | *out_ndim = shape.size(); 41 | API_END(); 42 | } 43 | 44 | int TreeliteGTILPredict(TreeliteModelHandle model, void const* input, char const* input_type, 45 | std::uint64_t num_row, void* output, TreeliteGTILConfigHandle config) { 46 | API_BEGIN(); 47 | auto const* model_ = static_cast(model); 48 | auto const* config_ = static_cast(config); 49 | std::string input_type_str = std::string(input_type); 50 | if (input_type_str == "float32") { 51 | treelite::gtil::Predict( 52 | *model_, static_cast(input), num_row, static_cast(output), *config_); 53 | } else if (input_type_str == "float64") { 54 | treelite::gtil::Predict(*model_, static_cast(input), num_row, 55 | static_cast(output), *config_); 56 | } else { 57 | TREELITE_LOG(FATAL) << "Unexpected type spec: " << input_type_str; 58 | } 59 | API_END(); 60 | } 61 | 62 | int 
TreeliteGTILPredictSparse(TreeliteModelHandle model, void const* data, char const* input_type, 63 | std::uint64_t const* col_ind, std::uint64_t const* row_ptr, std::uint64_t num_row, void* output, 64 | TreeliteGTILConfigHandle config) { 65 | API_BEGIN(); 66 | auto const* model_ = static_cast(model); 67 | auto const* config_ = static_cast(config); 68 | std::string input_type_str = std::string(input_type); 69 | if (input_type_str == "float32") { 70 | treelite::gtil::PredictSparse(*model_, static_cast(data), col_ind, row_ptr, 71 | num_row, static_cast(output), *config_); 72 | } else if (input_type_str == "float64") { 73 | treelite::gtil::PredictSparse(*model_, static_cast(data), col_ind, row_ptr, 74 | num_row, static_cast(output), *config_); 75 | } else { 76 | TREELITE_LOG(FATAL) << "Unexpected type spec: " << input_type_str; 77 | } 78 | API_END(); 79 | } 80 | -------------------------------------------------------------------------------- /dev/change_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script changes the version field in different parts of the code base. 3 | """ 4 | 5 | import argparse 6 | import pathlib 7 | import re 8 | from typing import Optional, TypeVar 9 | 10 | R = TypeVar("R") 11 | ROOT = pathlib.Path(__file__).parent.parent.expanduser().resolve() 12 | PY_PACKAGE = ROOT / "python" 13 | 14 | 15 | def update_cmake(major: int, minor: int, patch: int) -> None: 16 | """Change version in CMakeLists.txt""" 17 | version = f"{major}.{minor}.{patch}" 18 | with open(ROOT / "CMakeLists.txt", "r", encoding="utf-8") as fd: 19 | cmakelist = fd.read() 20 | pattern = r"project\(treelite LANGUAGES .* VERSION ([0-9]+\.[0-9]+\.[0-9]+)\)" 21 | matched = re.search(pattern, cmakelist) 22 | assert matched, "Couldn't find the version string in CMakeLists.txt." 
23 | cmakelist = cmakelist[: matched.start(1)] + version + cmakelist[matched.end(1) :] 24 | with open(ROOT / "CMakeLists.txt", "w", encoding="utf-8") as fd: 25 | fd.write(cmakelist) 26 | 27 | 28 | def update_pypkg( 29 | major: int, 30 | minor: int, 31 | patch: int, 32 | *, 33 | is_rc: bool, 34 | is_dev: bool, 35 | rc_ver: Optional[int] = None, 36 | ) -> None: 37 | """Change version in the Python package""" 38 | version = f"{major}.{minor}.{patch}" 39 | if is_rc: 40 | assert rc_ver 41 | version = version + f"rc{rc_ver}" 42 | if is_dev: 43 | version = version + ".dev0" 44 | 45 | pyver_path = PY_PACKAGE / "treelite" / "VERSION" 46 | with open(pyver_path, "w", encoding="utf-8") as fd: 47 | fd.write(version + "\n") 48 | 49 | pyprj_path = PY_PACKAGE / "pyproject.toml" 50 | with open(pyprj_path, "r", encoding="utf-8") as fd: 51 | pyprj = fd.read() 52 | matched = re.search('version = "' + r"([0-9]+\.[0-9]+\.[0-9]+.*)" + '"', pyprj) 53 | assert matched, "Couldn't find version string in pyproject.toml." 54 | pyprj = pyprj[: matched.start(1)] + version + pyprj[matched.end(1) :] 55 | with open(pyprj_path, "w", encoding="utf-8") as fd: 56 | fd.write(pyprj) 57 | 58 | 59 | def main(args: argparse.Namespace) -> None: 60 | """Perform version change in all relevant parts of the code base.""" 61 | if args.is_rc and args.is_dev: 62 | raise ValueError("A release version cannot be both RC and dev.") 63 | if args.is_rc: 64 | assert args.rc is not None, "rc field must be specified if is_rc is specified" 65 | assert args.rc >= 1, "RC version must start from 1." 
66 | else: 67 | assert args.rc is None, "is_rc must be specified in order to specify rc field" 68 | update_cmake(args.major, args.minor, args.patch) 69 | update_pypkg( 70 | args.major, 71 | args.minor, 72 | args.patch, 73 | is_rc=args.is_rc, 74 | is_dev=args.is_dev, 75 | rc_ver=args.rc, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--major", type=int, required=True) 82 | parser.add_argument("--minor", type=int, required=True) 83 | parser.add_argument("--patch", type=int, required=True) 84 | parser.add_argument("--rc", type=int) 85 | parser.add_argument("--is-rc", type=int, choices=[0, 1], default=0) 86 | parser.add_argument("--is-dev", type=int, choices=[0, 1], default=0) 87 | parsed_args = parser.parse_args() 88 | main(parsed_args) 89 | -------------------------------------------------------------------------------- /include/treelite/detail/serializer_mixins.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file serializer_mixins.h 4 | * \brief Mix-in classes for serializers 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #ifndef TREELITE_DETAIL_SERIALIZER_MIXINS_H_ 9 | #define TREELITE_DETAIL_SERIALIZER_MIXINS_H_ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | namespace treelite::detail::serializer { 21 | 22 | class StreamSerializerMixIn { 23 | public: 24 | explicit StreamSerializerMixIn(std::ostream& os) : os_(os) {} 25 | 26 | template 27 | void SerializeScalar(T* field) { 28 | WriteScalarToStream(field, os_); 29 | } 30 | 31 | void SerializeString(std::string* field) { 32 | WriteStringToStream(field, os_); 33 | } 34 | 35 | template 36 | void SerializeArray(ContiguousArray* field) { 37 | WriteArrayToStream(field, os_); 38 | } 39 | 40 | private: 41 | std::ostream& os_; 42 | }; 43 | 44 | class StreamDeserializerMixIn { 45 | public: 46 | explicit StreamDeserializerMixIn(std::istream& is) : is_(is) {} 47 | 48 | template 49 | void DeserializeScalar(T* field) { 50 | ReadScalarFromStream(field, is_); 51 | } 52 | 53 | void DeserializeString(std::string* field) { 54 | ReadStringFromStream(field, is_); 55 | } 56 | 57 | template 58 | void DeserializeArray(ContiguousArray* field) { 59 | ReadArrayFromStream(field, is_); 60 | } 61 | 62 | void SkipOptionalField() { 63 | SkipOptionalFieldInStream(is_); 64 | } 65 | 66 | private: 67 | std::istream& is_; 68 | }; 69 | 70 | class PyBufferSerializerMixIn { 71 | public: 72 | PyBufferSerializerMixIn() = default; 73 | 74 | template 75 | void SerializeScalar(T* field) { 76 | frames_.push_back(GetPyBufferFromScalar(field)); 77 | } 78 | 79 | void SerializeString(std::string* field) { 80 | frames_.push_back(GetPyBufferFromString(field)); 81 | } 82 | 83 | template 84 | void SerializeArray(ContiguousArray* field) { 85 | frames_.push_back(GetPyBufferFromArray(field)); 86 | } 87 | 88 | std::vector GetFrames() { 89 | return frames_; 90 | } 91 
| 92 | private: 93 | std::vector frames_; 94 | }; 95 | 96 | class PyBufferDeserializerMixIn { 97 | public: 98 | explicit PyBufferDeserializerMixIn(std::vector const& frames) 99 | : frames_(frames), cur_idx_(0) {} 100 | 101 | template 102 | void DeserializeScalar(T* field) { 103 | InitScalarFromPyBuffer(field, frames_[cur_idx_++]); 104 | } 105 | 106 | void DeserializeString(std::string* field) { 107 | InitStringFromPyBuffer(field, frames_[cur_idx_++]); 108 | } 109 | 110 | template 111 | void DeserializeArray(ContiguousArray* field) { 112 | InitArrayFromPyBuffer(field, frames_[cur_idx_++]); 113 | } 114 | 115 | void SkipOptionalField() { 116 | cur_idx_ += 2; // field name + content 117 | } 118 | 119 | private: 120 | std::vector const& frames_; 121 | std::size_t cur_idx_; 122 | }; 123 | 124 | } // namespace treelite::detail::serializer 125 | 126 | #endif // TREELITE_DETAIL_SERIALIZER_MIXINS_H_ 127 | -------------------------------------------------------------------------------- /src/model_concat.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2022-2023 by Contributors 3 | * \file model_concat.cc 4 | * \brief Implementation of model concatenation 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | namespace treelite { 18 | 19 | std::unique_ptr ConcatenateModelObjects(std::vector const& objs) { 20 | if (objs.empty()) { 21 | return {}; 22 | } 23 | TypeInfo const threshold_type = objs[0]->GetThresholdType(); 24 | TypeInfo const leaf_output_type = objs[0]->GetLeafOutputType(); 25 | std::unique_ptr concatenated_model = Model::Create(threshold_type, leaf_output_type); 26 | // Header 27 | concatenated_model->num_feature = objs[0]->num_feature; 28 | concatenated_model->task_type = objs[0]->task_type; 29 | concatenated_model->average_tree_output = objs[0]->average_tree_output; 30 | // Task parameters 31 | concatenated_model->num_target = objs[0]->num_target; 32 | concatenated_model->num_class = objs[0]->num_class.Clone(); 33 | concatenated_model->leaf_vector_shape = objs[0]->leaf_vector_shape.Clone(); 34 | // Model parameters 35 | concatenated_model->postprocessor = objs[0]->postprocessor; 36 | concatenated_model->sigmoid_alpha = objs[0]->sigmoid_alpha; 37 | concatenated_model->ratio_c = objs[0]->ratio_c; 38 | concatenated_model->base_scores = objs[0]->base_scores.Clone(); 39 | concatenated_model->attributes = objs[0]->attributes; 40 | 41 | std::visit( 42 | [&objs, &concatenated_model](auto&& first_model_obj) { 43 | using ModelType = std::remove_const_t>; 44 | TREELITE_CHECK(std::holds_alternative(concatenated_model->variant_)); 45 | auto& concatenated_model_concrete = std::get(concatenated_model->variant_); 46 | for (std::size_t i = 0; i < objs.size(); ++i) { 47 | TREELITE_CHECK(std::holds_alternative(objs[i]->variant_)) 48 | << "Model object at index " << i 49 | << " has a different type than the first model object (at index 0)"; 50 | TREELITE_CHECK_EQ(concatenated_model->num_target, 
objs[i]->num_target) 51 | << "Model object at index " << i 52 | << " has a different num_target than the first model object (at index 0)"; 53 | TREELITE_CHECK(concatenated_model->num_class == objs[i]->num_class) 54 | << "Model object at index " << i 55 | << " has a different num_class than the first model object (at index 0)"; 56 | TREELITE_CHECK(concatenated_model->leaf_vector_shape == objs[i]->leaf_vector_shape) 57 | << "Model object at index " << i 58 | << " has a different leaf_vector_shape than the first model object (at index 0)"; 59 | auto& casted = std::get(objs[i]->variant_); 60 | std::transform(casted.trees.begin(), casted.trees.end(), 61 | std::back_inserter(concatenated_model_concrete.trees), 62 | [](auto const& tree) { return tree.Clone(); }); 63 | concatenated_model->target_id.Extend(objs[i]->target_id); 64 | concatenated_model->class_id.Extend(objs[i]->class_id); 65 | } 66 | }, 67 | objs[0]->variant_); 68 | TREELITE_CHECK_EQ(concatenated_model->target_id.Size(), concatenated_model->GetNumTree()); 69 | TREELITE_CHECK_EQ(concatenated_model->class_id.Size(), concatenated_model->GetNumTree()); 70 | return concatenated_model; 71 | } 72 | 73 | } // namespace treelite 74 | -------------------------------------------------------------------------------- /.github/workflows/misc-tests.yml: -------------------------------------------------------------------------------- 1 | name: misc-tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - mainline 8 | - 'release_*' 9 | schedule: 10 | - cron: "0 7 * * *" # Run once daily 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | test-sdist: 25 | name: Test sdist 26 | runs-on: ubuntu-latest 27 | env: 28 | COMMIT_ID: ${{ github.sha }} 29 | AWS_ACCESS_KEY_ID: ${{
secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }} 30 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} 31 | steps: 32 | - uses: actions/checkout@v4 33 | - uses: conda-incubator/setup-miniconda@v3 34 | with: 35 | miniforge-variant: Miniforge3 36 | miniforge-version: latest 37 | conda-remove-defaults: "true" 38 | activate-environment: dev 39 | environment-file: ops/conda_env/dev.yml 40 | use-mamba: true 41 | - name: Display Conda env 42 | run: | 43 | conda info 44 | conda list 45 | - name: Test sdist 46 | run: | 47 | bash ops/test-sdist.sh 48 | test-cmake-import: 49 | name: Test using Treelite as a library, via CMake export (${{ matrix.os }}) 50 | runs-on: ${{ matrix.os }} 51 | strategy: 52 | fail-fast: false 53 | matrix: 54 | os: [ubuntu-latest, macos-latest] 55 | steps: 56 | - uses: actions/checkout@v4 57 | - uses: conda-incubator/setup-miniconda@v3 58 | with: 59 | miniforge-variant: Miniforge3 60 | miniforge-version: latest 61 | conda-remove-defaults: "true" 62 | activate-environment: dev 63 | environment-file: ops/conda_env/dev.yml 64 | use-mamba: true 65 | - name: Display Conda env 66 | run: | 67 | conda info 68 | conda list 69 | - name: Test using Treelite as a library 70 | run: | 71 | bash ops/test-cmake-import.sh 72 | test-serializer-compatibility: 73 | name: Test backward compatibility of serializers 74 | runs-on: ${{ matrix.os }} 75 | strategy: 76 | fail-fast: false 77 | matrix: 78 | os: [ubuntu-latest, macos-latest] 79 | steps: 80 | - uses: actions/checkout@v4 81 | - uses: conda-incubator/setup-miniconda@v3 82 | with: 83 | miniforge-variant: Miniforge3 84 | miniforge-version: latest 85 | activate-environment: dev 86 | environment-file: ops/conda_env/dev.yml 87 | use-mamba: true 88 | - name: Display Conda env 89 | run: | 90 | conda info 91 | conda list 92 | - name: Test compatibility 93 | run: | 94 | bash ops/test-serializer-compatibility.sh 95 | test-custom-libpath: 96 | name: Test Treelite with custom libpath 97 | runs-on: 
ubuntu-latest 98 | steps: 99 | - uses: actions/checkout@v4 100 | - uses: conda-incubator/setup-miniconda@v3 101 | with: 102 | miniforge-variant: Miniforge3 103 | miniforge-version: latest 104 | activate-environment: dev 105 | environment-file: ops/conda_env/dev.yml 106 | use-mamba: true 107 | - name: Display Conda env 108 | run: | 109 | conda info 110 | conda list 111 | - name: Test Treelite with custom libpath 112 | run: | 113 | mkdir build 114 | cd build 115 | cmake .. -GNinja -DCMAKE_INSTALL_PREFIX=/opt/treelite 116 | ninja install -v 117 | cd ../python 118 | pip install --force-reinstall -v . --config-settings use_system_libtreelite=True \ 119 | --config-settings system_libtreelite_dir=/opt/treelite/lib 120 | cd .. 121 | rm -rf build/ 122 | python -c "import treelite; print(treelite.core._LIB)" 123 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Treelite 3 | ======== 4 | 5 | **Treelite** is a universal model exchange and serialization format for 6 | decision tree forests. Treelite aims to be a small library that enables 7 | other C++ applications to exchange and store decision trees on the disk 8 | as well as the network. 9 | 10 | .. raw:: html 11 | 12 | Star 15 | Watch 18 | 19 | .. warning:: Tree compiler was migrated to TL2cgen 20 | 21 | If you are looking for a compiler to translate tree models into C code, 22 | use :doc:`TL2cgen `. 23 | To migrate existing code using Treelite 3.x, consult the page 24 | :doc:`tl2cgen:treelite-migration`. 25 | 26 | 27 | Why Treelite? 28 | ============= 29 | 30 | Universal, lightweight specification for all tree models 31 | -------------------------------------------------------- 32 | Are you designing a C++ application that needs to read and write tree models, 33 | e.g. a prediction server? 34 | Do not be overwhelmed by the variety of tree models in the wild. 
Treelite 35 | lets you convert many kinds of tree models into a **common specification**. 36 | By using Treelite as a library, your application now only needs to deal 37 | with one model specification instead of many. Treelite currently 38 | supports: 39 | 40 | * `XGBoost `_ 41 | * `LightGBM `_ 42 | * `scikit-learn `_ 43 | * :doc:`flexible builder class ` for users of other 44 | tree libraries 45 | 46 | In addition, tree libraries can directly output trained trees using the 47 | Treelite specification. For example, the random forest algorithm in 48 | `RAPIDS cuML `_ stores the random forest 49 | object using Treelite. 50 | 51 | .. raw:: html 52 | 53 |

54 | 55 |
58 | (Click to enlarge) 59 |
60 |

61 | 62 | A small library that's easy to embed in another C++ application 63 | --------------------------------------------------------------- 64 | Treelite has an up-to-date CMake build script. If your C++ 65 | application uses CMake, it is easy to embed Treelite. 66 | Treelite is currently used by the following applications: 67 | 68 | * :doc:`tl2cgen:index` 69 | * Forest Inference Library (FIL) in `RAPIDS cuML `_ 70 | * `Triton Inference Server FIL Backend `_, 71 | an optimized prediction runtime for CPUs and GPUs. 72 | 73 | Quick start 74 | =========== 75 | Install Treelite: 76 | 77 | .. code-block:: console 78 | 79 | # From PyPI 80 | pip install treelite 81 | # From Conda 82 | conda install -c conda-forge treelite 83 | 84 | Import your tree ensemble model into Treelite: 85 | 86 | .. code-block:: python 87 | 88 | import treelite 89 | model = treelite.frontend.load_xgboost_model("my_model.json") 90 | 91 | Compute predictions using :doc:`treelite-gtil-api`: 92 | 93 | .. code-block:: python 94 | 95 | X = ... # numpy array 96 | treelite.gtil.predict(model, data=X) 97 | 98 | ******** 99 | Contents 100 | ******** 101 | 102 | .. toctree:: 103 | :maxdepth: 2 104 | :titlesonly: 105 | 106 | install 107 | tutorials/index 108 | treelite-api 109 | treelite-gtil-api 110 | treelite-c-api 111 | knobs/index 112 | serialization/index 113 | treelite-doxygen 114 | 115 | 116 | ******* 117 | Indices 118 | ******* 119 | * :ref:`genindex` 120 | * :ref:`modindex` 121 | -------------------------------------------------------------------------------- /src/gtil/postprocessor.cc: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2021-2023 by Contributors 3 | * \file postprocessor.cc 4 | * \author Hyunsu Cho 5 | * \brief Functions to post-process prediction results 6 | */ 7 | #include "./postprocessor.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace treelite::gtil { 16 | 17 | namespace detail::postprocessor { 18 | 19 | template 20 | void identity(treelite::Model const&, std::int32_t, InputT*) {} 21 | 22 | template 23 | void signed_square(treelite::Model const&, std::int32_t, InputT* elem) { 24 | InputT const margin = *elem; 25 | *elem = std::copysign(margin * margin, margin); 26 | } 27 | 28 | template 29 | void hinge(treelite::Model const&, std::int32_t, InputT* elem) { 30 | *elem = (*elem > 0 ? InputT(1) : InputT(0)); 31 | } 32 | 33 | template 34 | void sigmoid(treelite::Model const& model, std::int32_t, InputT* elem) { 35 | InputT const val = *elem; 36 | *elem = InputT(1) / (InputT(1) + std::exp(-model.sigmoid_alpha * val)); 37 | } 38 | 39 | template 40 | void exponential(treelite::Model const&, std::int32_t, InputT* elem) { 41 | *elem = std::exp(*elem); 42 | } 43 | 44 | template 45 | void exponential_standard_ratio(treelite::Model const& model, std::int32_t, InputT* elem) { 46 | *elem = std::exp2(-*elem / model.ratio_c); 47 | } 48 | 49 | template 50 | void logarithm_one_plus_exp(treelite::Model const&, std::int32_t, InputT* elem) { 51 | *elem = std::log1p(std::exp(*elem)); 52 | } 53 | 54 | template 55 | void identity_multiclass(treelite::Model const&, std::int32_t, InputT*) {} 56 | 57 | template 58 | void softmax(treelite::Model const&, std::int32_t num_class, InputT* row) { 59 | float max_margin = row[0]; 60 | double norm_const = 0.0; 61 | float t; 62 | for (std::int32_t i = 1; i < num_class; ++i) { 63 | if (row[i] > max_margin) { 64 | max_margin = row[i]; 65 | } 66 | } 67 | for (std::int32_t i = 0; i < num_class; ++i) { 68 | t = std::exp(row[i] - max_margin); 69 | norm_const += t; 70 | row[i] = t; 71 | } 72 | for (std::int32_t i 
= 0; i < num_class; ++i) { 73 | row[i] /= static_cast(norm_const); 74 | } 75 | } 76 | 77 | template 78 | void multiclass_ova(treelite::Model const& model, std::int32_t num_class, InputT* row) { 79 | for (std::int32_t i = 0; i < num_class; ++i) { 80 | row[i] = InputT(1) / (InputT(1) + std::exp(-model.sigmoid_alpha * row[i])); 81 | } 82 | } 83 | 84 | } // namespace detail::postprocessor 85 | 86 | template 87 | PostProcessorFunc GetPostProcessorFunc(std::string const& name) { 88 | if (name == "identity") { 89 | return detail::postprocessor::identity; 90 | } else if (name == "signed_square") { 91 | return detail::postprocessor::signed_square; 92 | } else if (name == "hinge") { 93 | return detail::postprocessor::hinge; 94 | } else if (name == "sigmoid") { 95 | return detail::postprocessor::sigmoid; 96 | } else if (name == "exponential") { 97 | return detail::postprocessor::exponential; 98 | } else if (name == "exponential_standard_ratio") { 99 | return detail::postprocessor::exponential_standard_ratio; 100 | } else if (name == "logarithm_one_plus_exp") { 101 | return detail::postprocessor::logarithm_one_plus_exp; 102 | } else if (name == "identity_multiclass") { 103 | return detail::postprocessor::identity_multiclass; 104 | } else if (name == "softmax") { 105 | return detail::postprocessor::softmax; 106 | } else if (name == "multiclass_ova") { 107 | return detail::postprocessor::multiclass_ova; 108 | } else { 109 | TREELITE_LOG(FATAL) << "Post-processor named '" << name << "' not found"; 110 | } 111 | return nullptr; 112 | } 113 | 114 | template PostProcessorFunc GetPostProcessorFunc(std::string const&); 115 | template PostProcessorFunc GetPostProcessorFunc(std::string const&); 116 | 117 | } // namespace treelite::gtil 118 | -------------------------------------------------------------------------------- /tests/example_app/example.c: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2023 by Contributors 3 | * \file example.c 4 | * \brief Test using Treelite as a C++ library 5 | */ 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #define safe_treelite(call) \ 13 | { \ 14 | int err = (call); \ 15 | if (err == -1) { \ 16 | fprintf( \ 17 | stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #call, TreeliteGetLastError()); \ 18 | exit(1); \ 19 | } \ 20 | } 21 | 22 | TreeliteModelHandle BuildModel() { 23 | TreeliteModelHandle model; 24 | TreeliteModelBuilderHandle builder; 25 | char const* model_metadata 26 | = "{" 27 | " \"threshold_type\": \"float32\"," 28 | " \"leaf_output_type\": \"float32\"," 29 | " \"metadata\": {" 30 | " \"num_feature\": 2," 31 | " \"task_type\": \"kRegressor\"," 32 | " \"average_tree_output\": false," 33 | " \"num_target\": 1," 34 | " \"num_class\": [1]," 35 | " \"leaf_vector_shape\": [1, 1]" 36 | " }," 37 | " \"tree_annotation\": {" 38 | " \"num_tree\": 1," 39 | " \"target_id\": [0]," 40 | " \"class_id\": [0]" 41 | " }," 42 | " \"postprocessor\": {" 43 | " \"name\": \"identity\"" 44 | " }," 45 | " \"base_scores\": [0.0]" 46 | "}"; 47 | safe_treelite(TreeliteGetModelBuilder(model_metadata, &builder)); 48 | safe_treelite(TreeliteModelBuilderStartTree(builder)); 49 | safe_treelite(TreeliteModelBuilderStartNode(builder, 0)); 50 | safe_treelite(TreeliteModelBuilderNumericalTest(builder, 0, 0.0, 0, "<", 1, 2)); 51 | safe_treelite(TreeliteModelBuilderEndNode(builder)); 52 | safe_treelite(TreeliteModelBuilderStartNode(builder, 1)); 53 | safe_treelite(TreeliteModelBuilderLeafScalar(builder, -1.0)); 54 | safe_treelite(TreeliteModelBuilderEndNode(builder)); 55 | safe_treelite(TreeliteModelBuilderStartNode(builder, 2)); 56 | safe_treelite(TreeliteModelBuilderLeafScalar(builder, 1.0)); 57 | safe_treelite(TreeliteModelBuilderEndNode(builder)); 58 | safe_treelite(TreeliteModelBuilderEndTree(builder)); 59 | 60 | safe_treelite(TreeliteModelBuilderCommitModel(builder, &model)); 61 | 62 | // Clean up 
63 | safe_treelite(TreeliteDeleteModelBuilder(builder)); 64 | 65 | return model; 66 | } 67 | 68 | int main() { 69 | TreeliteModelHandle model = BuildModel(); 70 | size_t num_row = 5; 71 | size_t num_col = 2; 72 | float input[10] = {-2.0f, 0.0f, -1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 2.0f, 0.0f}; 73 | float* output; 74 | size_t out_result_size; 75 | 76 | char const* gtil_config_str 77 | = "{" 78 | " \"predict_type\": \"default\"," 79 | " \"nthread\": 2" 80 | "}"; 81 | TreeliteGTILConfigHandle gtil_config; 82 | safe_treelite(TreeliteGTILParseConfig(gtil_config_str, &gtil_config)); 83 | 84 | uint64_t const* output_shape; 85 | uint64_t output_ndim; 86 | safe_treelite( 87 | TreeliteGTILGetOutputShape(model, num_row, gtil_config, &output_shape, &output_ndim)); 88 | output = (float*)malloc(output_shape[0] * output_shape[1] * sizeof(float)); 89 | safe_treelite(TreeliteGTILPredict(model, input, "float32", num_row, output, gtil_config)); 90 | 91 | printf("TREELITE_VERSION = %s\n", TREELITE_VERSION); 92 | 93 | for (size_t i = 0; i < num_row; ++i) { 94 | printf("Input %d: [%f", (int)i, input[i * num_col]); 95 | for (size_t j = 1; j < num_col; ++j) { 96 | printf(", %f", input[i * num_col + j]); 97 | } 98 | printf("], output: %f\n", output[i]); 99 | } 100 | 101 | free(output); 102 | 103 | return 0; 104 | } 105 | -------------------------------------------------------------------------------- /src/model_loader/detail/xgboost_json/sax_adapters.cc: -------------------------------------------------------------------------------- 1 | /*!
2 | * Copyright (c) 2024 by Contributors 3 | * \file sax_adapters.cc 4 | * \brief Adapters to connect RapidJSON and nlohmann/json with the delegated handler 5 | * \author Hyunsu Cho 6 | */ 7 | 8 | #include "./sax_adapters.h" 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include "./delegated_handler.h" 16 | 17 | namespace treelite::model_loader::detail::xgboost { 18 | 19 | /****************************************************************************** 20 | * RapidJSONAdapter 21 | * ***************************************************************************/ 22 | 23 | bool RapidJSONAdapter::Null() { 24 | return handler_->Null(); 25 | } 26 | 27 | bool RapidJSONAdapter::Bool(bool b) { 28 | return handler_->Bool(b); 29 | } 30 | 31 | bool RapidJSONAdapter::Int(int i) { 32 | return handler_->Int64(static_cast(i)); 33 | } 34 | 35 | bool RapidJSONAdapter::Uint(unsigned int u) { 36 | return handler_->Uint64(static_cast(u)); 37 | } 38 | 39 | bool RapidJSONAdapter::Int64(std::int64_t i) { 40 | return handler_->Int64(i); 41 | } 42 | 43 | bool RapidJSONAdapter::Uint64(std::uint64_t u) { 44 | return handler_->Uint64(u); 45 | } 46 | 47 | bool RapidJSONAdapter::Double(double d) { 48 | return handler_->Double(d); 49 | } 50 | 51 | bool RapidJSONAdapter::RawNumber(char const* str, std::size_t length, bool copy) { 52 | TREELITE_LOG(FATAL) << "RawNumber() not implemented"; 53 | return false; 54 | } 55 | 56 | bool RapidJSONAdapter::String(char const* str, std::size_t length, bool) { 57 | return handler_->String(std::string{str, length}); 58 | } 59 | 60 | bool RapidJSONAdapter::StartObject() { 61 | return handler_->StartObject(); 62 | } 63 | 64 | bool RapidJSONAdapter::Key(char const* str, std::size_t length, bool copy) { 65 | return handler_->Key(std::string{str, length}); 66 | } 67 | 68 | bool RapidJSONAdapter::EndObject(std::size_t) { 69 | return handler_->EndObject(); 70 | } 71 | 72 | bool RapidJSONAdapter::StartArray() { 73 | return handler_->StartArray(); 74 | } 75 | 
76 | bool RapidJSONAdapter::EndArray(std::size_t) { 77 | return handler_->EndArray(); 78 | } 79 | 80 | /****************************************************************************** 81 | * NlohmannJSONAdapter 82 | * ***************************************************************************/ 83 | 84 | bool NlohmannJSONAdapter::null() { 85 | return handler_->Null(); 86 | } 87 | 88 | bool NlohmannJSONAdapter::boolean(bool val) { 89 | return handler_->Bool(val); 90 | } 91 | 92 | bool NlohmannJSONAdapter::number_integer(std::int64_t val) { 93 | return handler_->Int64(val); 94 | } 95 | 96 | bool NlohmannJSONAdapter::number_unsigned(std::uint64_t val) { 97 | return handler_->Uint64(val); 98 | } 99 | 100 | bool NlohmannJSONAdapter::number_float(double val, std::string const&) { 101 | return handler_->Double(val); 102 | } 103 | 104 | bool NlohmannJSONAdapter::string(std::string& val) { 105 | return handler_->String(val); 106 | } 107 | 108 | bool NlohmannJSONAdapter::binary(nlohmann::json::binary_t& val) { 109 | static_assert(sizeof(char) == sizeof(std::uint8_t), "char must be 1 byte"); 110 | std::string s; 111 | s.resize(val.size()); 112 | std::transform(std::begin(val), std::end(val), std::begin(s), 113 | [](std::uint8_t e) -> char { return static_cast(e); }); 114 | return handler_->String(s); 115 | } 116 | 117 | bool NlohmannJSONAdapter::start_object(std::size_t) { 118 | return handler_->StartObject(); 119 | } 120 | 121 | bool NlohmannJSONAdapter::end_object() { 122 | return handler_->EndObject(); 123 | } 124 | 125 | bool NlohmannJSONAdapter::start_array(std::size_t) { 126 | return handler_->StartArray(); 127 | } 128 | 129 | bool NlohmannJSONAdapter::end_array() { 130 | return handler_->EndArray(); 131 | } 132 | 133 | bool NlohmannJSONAdapter::key(std::string& val) { 134 | return handler_->Key(val); 135 | } 136 | 137 | bool NlohmannJSONAdapter::parse_error( 138 | std::size_t position, std::string const& last_token, nlohmann::json::exception const& ex) { 139 | 
TREELITE_LOG(ERROR) << "Parsing error at token " << position << ": " << ex.what(); 140 | return false; 141 | } 142 | 143 | } // namespace treelite::model_loader::detail::xgboost 144 | -------------------------------------------------------------------------------- /src/model_loader/detail/xgboost.cc: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2023 by Contributors 3 | * \file xgboost.cc 4 | * \brief Utility functions for XGBoost frontend 5 | * \author Hyunsu Cho 6 | */ 7 | #include "./xgboost.h" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | #include "./string_utils.h" 22 | 23 | namespace treelite::model_loader { 24 | 25 | namespace detail::xgboost { 26 | 27 | // Get correct postprocessor for prediction, depending on objective function 28 | std::string GetPostProcessor(std::string const& objective_name) { 29 | std::vector const exponential_objectives{ 30 | "count:poisson", "reg:gamma", "reg:tweedie", "survival:cox", "survival:aft"}; 31 | if (objective_name == "multi:softmax" || objective_name == "multi:softprob") { 32 | return "softmax"; 33 | } else if (objective_name == "reg:logistic" || objective_name == "binary:logistic") { 34 | return "sigmoid"; 35 | } else if (std::find( 36 | exponential_objectives.cbegin(), exponential_objectives.cend(), objective_name) 37 | != exponential_objectives.cend()) { 38 | return "exponential"; 39 | } else if (objective_name == "binary:hinge") { 40 | return "hinge"; 41 | } else if (objective_name == "reg:squarederror" || objective_name == "reg:linear" 42 | || objective_name == "reg:squaredlogerror" || objective_name == "reg:pseudohubererror" 43 | || objective_name == "binary:logitraw" || objective_name == "rank:pairwise" 44 | || objective_name == "rank:ndcg" || objective_name == "rank:map") { 45 | return "identity"; 46 | } else { 47 | TREELITE_LOG(FATAL) << 
"Unrecognized XGBoost objective: " << objective_name; 48 | return ""; 49 | } 50 | } 51 | 52 | double TransformBaseScoreToMargin(std::string const& postprocessor, double base_score) { 53 | if (postprocessor == "sigmoid") { 54 | return ProbToMargin::Sigmoid(base_score); 55 | } else if (postprocessor == "exponential") { 56 | return ProbToMargin::Exponential(base_score); 57 | } else { 58 | return base_score; 59 | } 60 | } 61 | 62 | std::vector ParseBaseScore(std::string const& str) { 63 | std::vector parsed_base_score; 64 | if (StringStartsWith(str, "[")) { 65 | // Vector base_score (from XGBoost 3.1+) 66 | rapidjson::Document doc; 67 | doc.Parse(str); 68 | TREELITE_CHECK(doc.IsArray()) << "Expected an array for base_score"; 69 | parsed_base_score.clear(); 70 | for (auto const& e : doc.GetArray()) { 71 | TREELITE_CHECK(e.IsFloat()) << "Expected a float array for base_score"; 72 | parsed_base_score.push_back(e.GetFloat()); 73 | } 74 | } else { 75 | // Scalar base_score (from XGBoost <3.1) 76 | parsed_base_score = std::vector{std::stof(str)}; 77 | } 78 | return parsed_base_score; 79 | } 80 | 81 | } // namespace detail::xgboost 82 | 83 | std::string DetectXGBoostFormat(std::string const& filename) { 84 | constexpr std::size_t nbytes = 2; 85 | char buf[nbytes] = {0}; 86 | 87 | std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filename); 88 | ifs.read(buf, nbytes); 89 | 90 | auto is_space = [](char c) -> bool { return c == ' ' || c == '\n' || c == '\r' || c == '\t'; }; 91 | 92 | // First look at the first character 93 | if (buf[0] == 'N') { 94 | // The no-op code is only used in UBJSON 95 | return "ubjson"; 96 | } else if (is_space(buf[0])) { 97 | // White-spaces are only present in JSON 98 | return "json"; 99 | } else if (buf[0] != '{') { 100 | // Otherwise, should have '{' if the file is JSON or UBJSON. 101 | return "unknown"; 102 | } 103 | 104 | // First character is '{'. Now look at the second character. 
105 | if (is_space(buf[1]) || buf[1] == '"') { 106 | // White-spaces and double quotation marks are only present in JSON 107 | return "json"; 108 | } else if (buf[1] == 'N' || buf[1] == '$' || buf[1] == '#' || buf[1] == 'i' || buf[1] == 'U' 109 | || buf[1] == 'I' || buf[1] == 'l' || buf[1] == 'L') { 110 | // The no-op code and type markers are only present in UBJSON 111 | return "ubjson"; 112 | } 113 | 114 | return "unknown"; 115 | } 116 | 117 | } // namespace treelite::model_loader 118 | -------------------------------------------------------------------------------- /cmake/ExternalLibs.cmake: -------------------------------------------------------------------------------- 1 | include(FetchContent) 2 | 3 | # RapidJSON (header-only library) 4 | add_library(rapidjson INTERFACE) 5 | target_compile_definitions(rapidjson INTERFACE -DRAPIDJSON_HAS_STDSTRING=1) 6 | find_package(RapidJSON) 7 | if(RapidJSON_FOUND) 8 | if(DEFINED RAPIDJSON_INCLUDE_DIRS) 9 | # Compatibility with 1.1.0 stable (circa 2016) 10 | set(RapidJSON_include_dir "${RAPIDJSON_INCLUDE_DIRS}") 11 | else() 12 | # Latest RapidJSON (1.1.0.post*) 13 | set(RapidJSON_include_dir "${RapidJSON_INCLUDE_DIRS}") 14 | endif() 15 | target_include_directories(rapidjson INTERFACE ${RapidJSON_include_dir}) 16 | message(STATUS "Found RapidJSON: ${RapidJSON_include_dir}") 17 | else() 18 | message(STATUS "Did not find RapidJSON in the system root. 
Fetching RapidJSON now...") 19 | FetchContent_Declare( 20 | RapidJSON 21 | GIT_REPOSITORY https://github.com/Tencent/rapidjson 22 | GIT_TAG ab1842a2dae061284c0a62dca1cc6d5e7e37e346 23 | ) 24 | FetchContent_Populate(RapidJSON) 25 | message(STATUS "RapidJSON was downloaded at ${rapidjson_SOURCE_DIR}.") 26 | target_include_directories(rapidjson INTERFACE $) 27 | endif() 28 | add_library(RapidJSON::rapidjson ALIAS rapidjson) 29 | 30 | # nlohmann/json (header-only library), to parse UBJSON 31 | find_package(nlohmann_json 3.11.3) 32 | if(NOT nlohmann_json_FOUND) 33 | message(STATUS "Did not find nlohmann/json in the system root. Fetching nlohmann/json now...") 34 | FetchContent_Declare( 35 | nlohmann_json 36 | URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz 37 | URL_HASH SHA256=d6c65aca6b1ed68e7a182f4757257b107ae403032760ed6ef121c9d55e81757d 38 | ) 39 | FetchContent_MakeAvailable(nlohmann_json) 40 | message(STATUS "nlohmann/json was downloaded at ${nlohmann_json_SOURCE_DIR}.") 41 | endif() 42 | 43 | # mdspan (header-only library) 44 | message(STATUS "Fetching mdspan...") 45 | set(MDSPAN_CXX_STANDARD 17 CACHE STRING "") 46 | FetchContent_Declare( 47 | mdspan 48 | GIT_REPOSITORY https://github.com/kokkos/mdspan.git 49 | GIT_TAG mdspan-0.6.0 50 | ) 51 | FetchContent_GetProperties(mdspan) 52 | if(NOT mdspan_POPULATED) 53 | FetchContent_Populate(mdspan) 54 | add_subdirectory(${mdspan_SOURCE_DIR} ${mdspan_BINARY_DIR} EXCLUDE_FROM_ALL) 55 | message(STATUS "mdspan was downloaded at ${mdspan_SOURCE_DIR}.") 56 | endif() 57 | if(MSVC) # workaround for MSVC 19.x: https://github.com/kokkos/mdspan/issues/276 58 | target_compile_options(mdspan INTERFACE "/permissive-") 59 | endif() 60 | 61 | # Google C++ tests 62 | if(BUILD_CPP_TEST) 63 | find_package(GTest 1.14.0) 64 | if(NOT GTest_FOUND) 65 | message(STATUS "Did not find Google Test in the system root. 
Fetching Google Test now...") 66 | FetchContent_Declare( 67 | googletest 68 | URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz 69 | ) 70 | set(gtest_force_shared_crt ${DMLC_FORCE_SHARED_CRT} CACHE BOOL "" FORCE) 71 | FetchContent_MakeAvailable(googletest) 72 | 73 | add_library(GTest::gtest ALIAS gtest) 74 | add_library(GTest::gmock ALIAS gmock) 75 | target_compile_definitions(gtest PRIVATE ${ENABLE_GNU_EXTENSION_FLAGS}) 76 | target_compile_definitions(gmock PRIVATE ${ENABLE_GNU_EXTENSION_FLAGS}) 77 | foreach(target gtest gmock) 78 | target_compile_features(${target} PUBLIC cxx_std_14) 79 | if(MSVC) 80 | set_target_properties(${target} PROPERTIES 81 | MSVC_RUNTIME_LIBRARY "${Treelite_MSVC_RUNTIME_LIBRARY}") 82 | endif() 83 | endforeach() 84 | if(IS_DIRECTORY "${googletest_SOURCE_DIR}") 85 | # Do not install gtest 86 | set_property(DIRECTORY ${googletest_SOURCE_DIR} PROPERTY EXCLUDE_FROM_ALL YES) 87 | endif() 88 | endif() 89 | endif() 90 | 91 | # fmtlib 92 | if(BUILD_CPP_TEST) 93 | find_package(fmt 10.1) 94 | if(fmt_FOUND) 95 | get_target_property(fmt_loc fmt::fmt INTERFACE_INCLUDE_DIRECTORIES) 96 | message(STATUS "Found fmtlib at ${fmt_loc}") 97 | set(FMTLIB_FROM_SYSTEM_ROOT TRUE) 98 | else() 99 | message(STATUS "Did not find fmtlib in the system root. Fetching fmtlib now...") 100 | set(FMT_INSTALL OFF CACHE BOOL "" FORCE) 101 | FetchContent_Declare( 102 | fmtlib 103 | GIT_REPOSITORY https://github.com/fmtlib/fmt.git 104 | GIT_TAG 10.1.1 105 | ) 106 | FetchContent_MakeAvailable(fmtlib) 107 | set_target_properties(fmt PROPERTIES EXCLUDE_FROM_ALL TRUE) 108 | set(FMTLIB_FROM_SYSTEM_ROOT FALSE) 109 | endif() 110 | endif() 111 | -------------------------------------------------------------------------------- /include/treelite/gtil.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2021-2023 by Contributors 3 | * \file gtil.h 4 | * \author Hyunsu Cho 5 | * \brief General Tree Inference Library (GTIL), providing a reference implementation for 6 | * predicting with decision trees. 7 | */ 8 | 9 | #ifndef TREELITE_GTIL_H_ 10 | #define TREELITE_GTIL_H_ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace treelite { 20 | 21 | class Model; 22 | 23 | namespace gtil { 24 | 25 | /*! \brief Prediction type */ 26 | enum class PredictKind : std::int8_t { 27 | /*! 28 | * \brief Usual prediction method: sum over trees and apply post-processing. 29 | * Expected output dimensions: (num_row, num_target, max_num_class) 30 | */ 31 | kPredictDefault = 0, 32 | /*! 33 | * \brief Sum over trees, but don't apply post-processing; get raw margin scores instead. 34 | * Expected output dimensions: (num_row, num_target, max_num_class) 35 | */ 36 | kPredictRaw = 1, 37 | /*! 38 | * \brief Output one (integer) leaf ID per tree. 39 | * Expected output dimensions: (num_row, num_tree) 40 | */ 41 | kPredictLeafID = 2, 42 | /*! 43 | * \brief Output one or more margin scores per tree. 44 | * Expected output dimensions: (num_row, num_tree, leaf_vector_shape[0] * leaf_vector_shape[1]) 45 | */ 46 | kPredictPerTree = 3 47 | }; 48 | 49 | /*! \brief Configuration class */ 50 | struct Configuration { 51 | int nthread{0}; // use all threads by default 52 | PredictKind pred_kind{PredictKind::kPredictDefault}; 53 | Configuration() = default; 54 | explicit Configuration(std::string const& config_json); 55 | }; 56 | 57 | /*! 58 | * \brief Predict with dense data 59 | * \param model Treelite Model object 60 | * \param input The 2D data array, laid out in row-major layout 61 | * \param num_row Number of rows in the data matrix. 62 | * \param output Pointer to buffer to store the output. Call \ref GetOutputShape to get 63 | * the amount of buffer you should allocate for this parameter. 
64 | * \param config Configuration of GTIL predictor 65 | */ 66 | template 67 | void Predict(Model const& model, InputT const* input, std::uint64_t num_row, InputT* output, 68 | Configuration const& config); 69 | 70 | /*! 71 | * \brief Predict with sparse data with CSR (compressed sparse row) layout. 72 | * 73 | * In the CSR layout, data[row_ptr[i]:row_ptr[i+1]] store the nonzero entries of row i, and 74 | * col_ind[row_ptr[i]:row_ptr[i+1]] stores the corresponding column indices. 75 | * 76 | * \param model Treelite Model object 77 | * \param data Nonzero elements in the data matrix 78 | * \param col_ind Feature indices. col_ind[i] indicates the feature index associated with data[i]. 79 | * \param row_ptr Pointer to row headers. Length is [num_row] + 1. 80 | * \param num_row Number of rows in the data matrix. 81 | * \param output Pointer to buffer to store the output. Call \ref GetOutputShape to get 82 | * the amount of buffer you should allocate for this parameter. 83 | * \param config Configuration of GTIL predictor 84 | */ 85 | template 86 | void PredictSparse(Model const& model, InputT const* data, std::uint64_t const* col_ind, 87 | std::uint64_t const* row_ptr, std::uint64_t num_row, InputT* output, 88 | Configuration const& config); 89 | 90 | /*! 91 | * \brief Given a data matrix, query the necessary shape of array to hold predictions for all 92 | * data points. 93 | * \param model Treelite Model object 94 | * \param num_row Number of rows in the input 95 | * \param config Configuration of GTIL predictor. Set this by calling \ref TreeliteGTILParseConfig. 
96 | * \return Array shape 97 | */ 98 | std::vector GetOutputShape( 99 | Model const& model, std::uint64_t num_row, Configuration const& config); 100 | 101 | extern template void Predict( 102 | Model const&, float const*, std::uint64_t, float*, Configuration const&); 103 | extern template void Predict( 104 | Model const&, double const*, std::uint64_t, double*, Configuration const&); 105 | extern template void PredictSparse(Model const&, float const*, std::uint64_t const*, 106 | std::uint64_t const*, std::uint64_t, float*, Configuration const&); 107 | extern template void PredictSparse(Model const&, double const*, std::uint64_t const*, 108 | std::uint64_t const*, std::uint64_t, double*, Configuration const&); 109 | 110 | } // namespace gtil 111 | } // namespace treelite 112 | 113 | #endif // TREELITE_GTIL_H_ 114 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Installation 3 | ============ 4 | 5 | You may choose one of three methods to install Treelite on your system: 6 | 7 | .. contents:: 8 | :local: 9 | :depth: 1 10 | 11 | Download binary releases from PyPI (Recommended) 12 | ================================================ 13 | This is probably the most convenient method. Simply type 14 | 15 | .. code-block:: console 16 | 17 | pip install treelite 18 | 19 | to install the Treelite package. The command will locate the binary release that is compatible with 20 | your current platform. Check the installation by running 21 | 22 | .. code-block:: python 23 | 24 | import treelite 25 | 26 | in an interactive Python session. This method is available only for Windows, MacOS, and Linux. 27 | For other operating systems, see :ref:`install-source`. 28 | 29 | ..
note:: Windows users need to install Visual C++ Redistributable 30 | 31 | Treelite requires DLLs from `Visual C++ Redistributable 32 | <https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist>`_ 33 | in order to function, so make sure to install it. Exception: If 34 | you have Visual Studio installed, you already have access to 35 | necessary libraries and thus don't need to install Visual C++ 36 | Redistributable. 37 | 38 | .. note:: Installing OpenMP runtime on MacOS 39 | 40 | Treelite requires the presence of OpenMP runtime. To install OpenMP runtime on a MacOS system, 41 | run the following command: 42 | 43 | .. code-block:: bash 44 | 45 | brew install libomp 46 | 47 | 48 | Download binary releases from Conda 49 | =================================== 50 | Treelite is also available on Conda. Simply type 51 | 52 | .. code-block:: console 53 | 54 | conda install -c conda-forge treelite 55 | 56 | to install the Treelite package. See https://anaconda.org/conda-forge/treelite to check the 57 | available platforms. 58 | 59 | .. _install-source: 60 | 61 | Compile Treelite from the source 62 | ================================ 63 | Installation consists of two steps: 64 | 65 | 1. Build the shared libraries from C++ code (See the note below for the list.) 66 | 2. Install the Python package. 67 | 68 | .. note:: List of libraries created 69 | 70 | ================== ===================== 71 | Operating System Main library 72 | ================== ===================== 73 | Windows ``treelite.dll`` 74 | MacOS ``libtreelite.dylib`` 75 | Linux / other UNIX ``libtreelite.so`` 76 | ================== ===================== 77 | 78 | To get started, clone the Treelite repo from GitHub. 79 | 80 | .. code-block:: bash 81 | 82 | git clone https://github.com/dmlc/treelite.git 83 | cd treelite 84 | 85 | The next step is to build the shared libraries. 86 | 87 | 1-1. Compiling shared libraries on Linux and MacOS 88 | -------------------------------------------------- 89 | Here, we use CMake to generate a Makefile: 90 | 91 | ..
code-block:: bash 92 | 93 | mkdir build 94 | cd build 95 | cmake .. 96 | 97 | Once CMake has finished running, simply invoke GNU Make to obtain the shared 98 | libraries. 99 | 100 | .. code-block:: bash 101 | 102 | make 103 | 104 | The compiled libraries will be under the ``build/`` directory. 105 | 106 | .. note:: Compiling Treelite with multithreading on MacOS 107 | 108 | Treelite requires the presence of OpenMP runtime. To install OpenMP runtime on a MacOS system, 109 | run the following command: 110 | 111 | .. code-block:: bash 112 | 113 | brew install libomp 114 | 115 | 1-2. Compiling shared libraries on Windows 116 | ------------------------------------------ 117 | We can use CMake to generate a Visual Studio project. The following snippet assumes that Visual 118 | Studio 2022 is installed. Adjust the version depending on the copy that's installed on your system. 119 | 120 | .. code-block:: dosbatch 121 | 122 | mkdir build 123 | cd build 124 | cmake .. -G"Visual Studio 17 2022" -A x64 125 | 126 | .. note:: Visual Studio 2019 or newer is required 127 | 128 | Treelite uses the C++17 standard. Ensure that you have Visual Studio version 2019 or newer. 129 | 130 | Once CMake has finished running, open the generated solution file (``treelite.sln``) in Visual Studio. 131 | From the top menu, select **Build > Build Solution**. 132 | 133 | 2. Installing Python package 134 | ---------------------------- 135 | The Python package is located in the ``python`` subdirectory. Run pip to install the Python 136 | package. The Python package will re-use the native library built in Step 1. 137 | 138 | .. code-block:: bash 139 | 140 | cd python 141 | pip install . # will re-use libtreelite.so 142 | -------------------------------------------------------------------------------- /src/c_api/model_loader.cc: -------------------------------------------------------------------------------- 1 | /*!
2 | * Copyright (c) 2023 by Contributors 3 | * \file model_loader.cc 4 | * \author Hyunsu Cho 5 | * \brief C API for frontend functions 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "./c_api_utils.h" 19 | 20 | int TreeliteLoadXGBoostModelLegacyBinary( 21 | char const* filename, [[maybe_unused]] char const* config_json, TreeliteModelHandle* out) { 22 | // config_json is unused for now 23 | API_BEGIN(); 24 | std::unique_ptr model 25 | = treelite::model_loader::LoadXGBoostModelLegacyBinary(filename); 26 | *out = static_cast(model.release()); 27 | API_END(); 28 | } 29 | 30 | int TreeliteLoadXGBoostModelLegacyBinaryFromMemoryBuffer(void const* buf, std::uint64_t len, 31 | [[maybe_unused]] char const* config_json, TreeliteModelHandle* out) { 32 | // config_json is unused for now 33 | API_BEGIN(); 34 | std::unique_ptr model 35 | = treelite::model_loader::LoadXGBoostModelLegacyBinary(buf, len); 36 | *out = static_cast(model.release()); 37 | API_END(); 38 | } 39 | 40 | int TreeliteLoadXGBoostModel( 41 | char const* filename, char const* config_json, TreeliteModelHandle* out) { 42 | TREELITE_LOG(WARNING) << "TreeliteLoadXGBoostModel() is deprecated. Please use " 43 | << "TreeliteLoadXGBoostModelJSON() instead."; 44 | return TreeliteLoadXGBoostModelJSON(filename, config_json, out); 45 | } 46 | 47 | int TreeliteLoadXGBoostModelFromString( 48 | char const* json_str, std::size_t length, char const* config_json, TreeliteModelHandle* out) { 49 | TREELITE_LOG(WARNING) << "TreeliteLoadXGBoostModelFromString() is deprecated. 
Please use " 50 | << "TreeliteLoadXGBoostModelFromJSONString() instead."; 51 | return TreeliteLoadXGBoostModelFromJSONString(json_str, length, config_json, out); 52 | } 53 | 54 | int TreeliteLoadXGBoostModelJSON( 55 | char const* filename, char const* config_json, TreeliteModelHandle* out) { 56 | API_BEGIN(); 57 | std::unique_ptr model 58 | = treelite::model_loader::LoadXGBoostModelJSON(filename, config_json); 59 | *out = static_cast(model.release()); 60 | API_END(); 61 | } 62 | 63 | int TreeliteLoadXGBoostModelFromJSONString( 64 | char const* json_str, std::size_t length, char const* config_json, TreeliteModelHandle* out) { 65 | API_BEGIN(); 66 | std::unique_ptr model = treelite::model_loader::LoadXGBoostModelFromJSONString( 67 | std::string_view{json_str, length}, config_json); 68 | *out = static_cast(model.release()); 69 | API_END(); 70 | } 71 | 72 | int TreeliteLoadXGBoostModelUBJSON( 73 | char const* filename, char const* config_json, TreeliteModelHandle* out) { 74 | API_BEGIN(); 75 | std::unique_ptr model 76 | = treelite::model_loader::LoadXGBoostModelUBJSON(filename, config_json); 77 | *out = static_cast(model.release()); 78 | API_END(); 79 | } 80 | 81 | int TreeliteLoadXGBoostModelFromUBJSONString( 82 | char const* ubjson_str, std::size_t length, char const* config_json, TreeliteModelHandle* out) { 83 | API_BEGIN(); 84 | std::unique_ptr model = treelite::model_loader::LoadXGBoostModelFromUBJSONString( 85 | std::string_view{ubjson_str, length}, config_json); 86 | *out = static_cast(model.release()); 87 | API_END(); 88 | } 89 | 90 | int TreeliteDetectXGBoostFormat(char const* filename, char const** out_str) { 91 | API_BEGIN(); 92 | std::string& ret_str = treelite::c_api::ReturnValueStore::Get()->ret_str; 93 | ret_str = treelite::model_loader::DetectXGBoostFormat(filename); 94 | *out_str = ret_str.c_str(); 95 | API_END(); 96 | } 97 | 98 | int TreeliteLoadLightGBMModel( 99 | char const* filename, [[maybe_unused]] char const* config_json, TreeliteModelHandle* 
out) { 100 | // config_json is unused for now 101 | API_BEGIN(); 102 | std::unique_ptr model = treelite::model_loader::LoadLightGBMModel(filename); 103 | *out = static_cast(model.release()); 104 | API_END(); 105 | } 106 | 107 | TREELITE_DLL int TreeliteLoadLightGBMModelFromString( 108 | char const* model_str, [[maybe_unused]] char const* config_json, TreeliteModelHandle* out) { 109 | // config_json is unused for now 110 | API_BEGIN(); 111 | std::unique_ptr model 112 | = treelite::model_loader::LoadLightGBMModelFromString(model_str); 113 | *out = static_cast(model.release()); 114 | API_END(); 115 | } 116 | -------------------------------------------------------------------------------- /python/treelite/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous utilities 3 | """ 4 | 5 | import ctypes 6 | import warnings 7 | from functools import wraps 8 | from inspect import Parameter, signature 9 | from typing import Any, Callable, TypeVar 10 | 11 | import numpy as np 12 | 13 | _CTYPES_TYPE_TABLE = { 14 | "uint32": ctypes.c_uint32, 15 | "float32": ctypes.c_float, 16 | "float64": ctypes.c_double, 17 | } 18 | 19 | 20 | _NUMPY_TYPE_TABLE = {"uint32": np.uint32, "float32": np.float32, "float64": np.float64} 21 | 22 | 23 | _PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize 24 | _PyBytes_FromStringAndSize.argtypes = (ctypes.c_char_p, ctypes.c_ssize_t) 25 | _PyBytes_FromStringAndSize.restype = ctypes.py_object 26 | 27 | 28 | def typestr_to_ctypes_type(type_info): 29 | """Obtain ctypes type corresponding to a given Type str""" 30 | return _CTYPES_TYPE_TABLE[type_info] 31 | 32 | 33 | def typestr_to_numpy_type(type_info): 34 | """Obtain NumPy type corresponding to a given Type str""" 35 | return _NUMPY_TYPE_TABLE[type_info] 36 | 37 | 38 | def c_str(string): 39 | """Convert a Python string to a UTF-8 encoded C string (``ctypes.c_char_p``)""" 40 | return ctypes.c_char_p(string.encode("utf-8")) 41 | 42 | 43 | def
bytes_from_string_and_size(ptr, size): 44 | """Copy `size` bytes from `ptr` to create a new python `bytes` object""" 45 | # Theoretically `ctypes.string_at` does this, but the `size` argument 46 | # there only takes an `int`, while python bytes object can support up to a 47 | # `ssize_t` in size. 48 | return _PyBytes_FromStringAndSize(ptr, size) 49 | 50 | 51 | def py_str(string): 52 | """Decode a UTF-8 encoded C string (bytes) back to a Python string""" 53 | return string.decode("utf-8") 54 | 55 | 56 | def c_array(ctype, values): 57 | """ 58 | Convert a Python sequence of values to a C array with elements of type ``ctype`` 59 | 60 | WARNING 61 | ------- 62 | DO NOT USE THIS FUNCTION if performance is critical. Instead, use np.array(*) 63 | with dtype option to explicitly convert type and then use 64 | ndarray.ctypes.data_as(*) to expose underlying buffer as C pointer. 65 | """ 66 | return (ctype * len(values))(*values) 67 | 68 | 69 | _T = TypeVar("_T") 70 | 71 | 72 | # Notice for `_require_keyword_args` 73 | # Authors: Olivier Grisel 74 | # Gael Varoquaux 75 | # Andreas Mueller 76 | # Lars Buitinck 77 | # Alexandre Gramfort 78 | # Nicolas Tresegnie 79 | # Sylvain Marie 80 | # License: BSD 3 clause 81 | def _require_keyword_args( 82 | error: bool, 83 | ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: 84 | """Decorator that issues a warning or an error for positional arguments 85 | 86 | Using the keyword-only argument syntax in PEP 3102, arguments after the 87 | * will issue a warning or error when passed as a positional argument. 88 | 89 | Modified from sklearn utils.validation. 90 | 91 | Parameters 92 | ---------- 93 | error : 94 | If True, raise ``TypeError``; otherwise emit a ``FutureWarning``. 95 | """ 96 | 97 | def throw_if(func: Callable[..., _T]) -> Callable[..., _T]: 98 | """Throw an error/warning if there are positional arguments after the asterisk. 99 | 100 | Parameters 101 | ---------- 102 | func : 103 | function to check arguments on.
104 | 105 | """ 106 | sig = signature(func) 107 | kwonly_args = [] 108 | all_args = [] 109 | 110 | for name, param in sig.parameters.items(): 111 | if param.kind == Parameter.POSITIONAL_OR_KEYWORD: 112 | all_args.append(name) 113 | elif param.kind == Parameter.KEYWORD_ONLY: 114 | kwonly_args.append(name) 115 | 116 | @wraps(func) 117 | def inner_f(*args: Any, **kwargs: Any) -> _T: 118 | extra_args = len(args) - len(all_args) 119 | if not all_args and extra_args > 0: # keyword argument only 120 | raise TypeError("Keyword argument is required.") 121 | 122 | if extra_args > 0: 123 | # ignore first 'self' argument for instance methods 124 | args_msg = [ 125 | f"{name}" 126 | for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:]) 127 | ] 128 | # pylint: disable=consider-using-f-string 129 | msg = "Pass `{}` as keyword args.".format(", ".join(args_msg)) 130 | if error: 131 | raise TypeError(msg) 132 | warnings.warn(msg, FutureWarning) 133 | for k, arg in zip(sig.parameters, args): 134 | kwargs[k] = arg 135 | return func(**kwargs) 136 | 137 | return inner_f 138 | 139 | return throw_if 140 | 141 | 142 | deprecate_positional_args = _require_keyword_args(False) 143 | -------------------------------------------------------------------------------- /tests/examples/toy_categorical/toy_categorical.test.pred: -------------------------------------------------------------------------------- 1 | -20.738507080549652528 2 | 1.319734514567587835 3 | -10.803203734157792226 4 | -10.001096027455963267 5 | -10.001096027455963267 6 | 0.666932720576836280 7 | -10.803203734157792226 8 | 0.666932720576836280 9 | -10.001096027455963267 10 | -21.301163574763471331 11 | -10.001096027455963267 12 | -21.301163574763471331 13 | 1.319734514567587835 14 | -21.301163574763471331 15 | -10.803203734157792226 16 | -9.234014767952478664 17 | -9.234014767952478664 18 | 1.319734514567587835 19 | -21.301163574763471331 20 | 1.319734514567587835 21 | -10.001096027455963267 22 | -20.738507080549652528 
23 | -9.234014767952478664 24 | -10.001096027455963267 25 | 0.666932720576836280 26 | -20.738507080549652528 27 | 0.666932720576836280 28 | -10.803203734157792226 29 | -20.738507080549652528 30 | 1.319734514567587835 31 | -10.001096027455963267 32 | -9.234014767952478664 33 | -9.234014767952478664 34 | -10.001096027455963267 35 | 1.319734514567587835 36 | -20.738507080549652528 37 | -9.234014767952478664 38 | -21.301163574763471331 39 | -10.001096027455963267 40 | -10.001096027455963267 41 | 1.319734514567587835 42 | 0.666932720576836280 43 | -10.001096027455963267 44 | -10.001096027455963267 45 | -10.001096027455963267 46 | -21.301163574763471331 47 | 1.319734514567587835 48 | 1.319734514567587835 49 | 1.319734514567587835 50 | -10.001096027455963267 51 | -20.738507080549652528 52 | -9.234014767952478664 53 | -10.001096027455963267 54 | -10.803203734157792226 55 | -10.001096027455963267 56 | -21.301163574763471331 57 | 1.319734514567587835 58 | -10.001096027455963267 59 | -10.001096027455963267 60 | -21.301163574763471331 61 | 1.319734514567587835 62 | -20.738507080549652528 63 | 1.319734514567587835 64 | -10.803203734157792226 65 | -10.001096027455963267 66 | -21.301163574763471331 67 | 1.319734514567587835 68 | 1.319734514567587835 69 | -21.301163574763471331 70 | -21.301163574763471331 71 | 1.319734514567587835 72 | 1.319734514567587835 73 | -10.803203734157792226 74 | -9.234014767952478664 75 | 1.319734514567587835 76 | -20.738507080549652528 77 | 0.666932720576836280 78 | -20.738507080549652528 79 | -21.301163574763471331 80 | 1.319734514567587835 81 | -10.001096027455963267 82 | 1.319734514567587835 83 | -20.738507080549652528 84 | -21.301163574763471331 85 | -10.001096027455963267 86 | -21.301163574763471331 87 | -10.001096027455963267 88 | -10.001096027455963267 89 | 1.319734514567587835 90 | -10.001096027455963267 91 | -20.738507080549652528 92 | -20.738507080549652528 93 | -10.001096027455963267 94 | -10.001096027455963267 95 | 0.666932720576836280 96 | 
-21.301163574763471331 97 | 1.319734514567587835 98 | -21.301163574763471331 99 | 1.319734514567587835 100 | 0.666932720576836280 101 | -21.301163574763471331 102 | -20.738507080549652528 103 | -20.738507080549652528 104 | -20.738507080549652528 105 | 0.666932720576836280 106 | -21.301163574763471331 107 | -21.301163574763471331 108 | -10.803203734157792226 109 | -10.001096027455963267 110 | -21.301163574763471331 111 | 1.319734514567587835 112 | 1.319734514567587835 113 | 1.319734514567587835 114 | 1.319734514567587835 115 | -10.001096027455963267 116 | -10.001096027455963267 117 | 1.319734514567587835 118 | -21.301163574763471331 119 | 1.319734514567587835 120 | 1.319734514567587835 121 | -10.803203734157792226 122 | 0.666932720576836280 123 | 1.319734514567587835 124 | -21.301163574763471331 125 | -21.301163574763471331 126 | -10.001096027455963267 127 | -21.301163574763471331 128 | 1.319734514567587835 129 | -10.001096027455963267 130 | -21.301163574763471331 131 | -9.234014767952478664 132 | 1.319734514567587835 133 | -20.738507080549652528 134 | 1.319734514567587835 135 | 0.666932720576836280 136 | -10.001096027455963267 137 | -21.301163574763471331 138 | 1.319734514567587835 139 | -20.738507080549652528 140 | -10.001096027455963267 141 | -10.001096027455963267 142 | -20.738507080549652528 143 | -21.301163574763471331 144 | 1.319734514567587835 145 | -21.301163574763471331 146 | -20.738507080549652528 147 | -10.001096027455963267 148 | -21.301163574763471331 149 | 0.666932720576836280 150 | 1.319734514567587835 151 | -10.803203734157792226 152 | 0.666932720576836280 153 | 0.666932720576836280 154 | -20.738507080549652528 155 | -20.738507080549652528 156 | -21.301163574763471331 157 | -10.803203734157792226 158 | -21.301163574763471331 159 | -21.301163574763471331 160 | -21.301163574763471331 161 | 0.666932720576836280 162 | -20.738507080549652528 163 | -21.301163574763471331 164 | -20.738507080549652528 165 | 1.319734514567587835 166 | 1.319734514567587835 
167 | -10.803203734157792226 168 | -10.001096027455963267 169 | 0.666932720576836280 170 | -21.301163574763471331 171 | 0.666932720576836280 172 | -10.001096027455963267 173 | 0.666932720576836280 174 | -21.301163574763471331 175 | 1.319734514567587835 176 | -21.301163574763471331 177 | -10.803203734157792226 178 | -9.234014767952478664 179 | -9.234014767952478664 180 | 0.666932720576836280 181 | --------------------------------------------------------------------------------