├── python
├── packager
│ ├── __init__.py
│ ├── sdist.py
│ ├── util.py
│ └── build_config.py
├── treelite
│ ├── py.typed
│ ├── VERSION
│ ├── gtil
│ │ └── __init__.py
│ ├── path_config.py
│ ├── __init__.py
│ ├── sklearn
│ │ ├── __init__.py
│ │ └── isolation_forest.py
│ ├── core.py
│ ├── libpath.py
│ └── util.py
├── README.rst
├── hatch_build.py
├── .pylintrc
└── pyproject.toml
├── tests
├── python
│ ├── __init__.py
│ └── metadata.py
├── examples
│ ├── mushroom
│ │ └── mushroom.model
│ ├── sparse_categorical
│ │ └── sparse_categorical.test.margin
│ ├── deep_lightgbm
│ │ └── model.txt
│ └── toy_categorical
│ │ └── toy_categorical.test.pred
├── ci_build
│ ├── build_via_cmake.sh
│ ├── Dockerfile.ubuntu20_amd64
│ ├── Dockerfile.ubuntu20_aarch64
│ ├── build_macos_python_wheels.sh
│ ├── entrypoint.sh
│ └── rename_whl.py
├── cpp
│ ├── test_main.cc
│ ├── CMakeLists.txt
│ ├── test_model_loader.cc
│ └── test_utils.cc
├── example_app
│ ├── CMakeLists.txt
│ ├── example.cc
│ └── example.c
└── serializer
│ ├── compatibility_tester.py
│ └── test_serializer.py
├── docs
├── .gitignore
├── _static
│ ├── deployment.png
│ └── custom.css
├── knobs
│ ├── index.rst
│ └── postprocessor.rst
├── _templates
│ └── layout.html
├── requirements.txt
├── tutorials
│ ├── index.rst
│ └── import.rst
├── treelite-doxygen.rst
├── Makefile
├── treelite-gtil-api.rst
├── make.bat
├── treelite-api.rst
├── serialization
│ └── index.rst
├── treelite-c-api.rst
├── index.rst
└── install.rst
├── .isort.cfg
├── ops
├── conda_env
│ ├── pre-commit.yml
│ └── dev.yml
├── test-linux-python-wheel.sh
├── test-macos-python-wheel.sh
├── test-cmake-import.sh
├── build-macos.sh
├── build-linux.sh
├── build-linux-aarch64.sh
├── test-sdist.sh
├── build-windows.bat
├── macos-python-coverage.sh
├── test-win-python-wheel.bat
├── build-cpack.sh
├── win-python-coverage.bat
├── cpp-python-coverage.sh
└── test-serializer-compatibility.sh
├── cmake
├── Version.cmake
├── TreeliteConfig.cmake.in
├── version.h.in
├── Doxygen.cmake
├── Utils.cmake
├── Sanitizer.cmake
└── ExternalLibs.cmake
├── .gitattributes
├── Makefile
├── include
└── treelite
│ ├── error.h
│ ├── pybuffer_frame.h
│ ├── thread_local.h
│ ├── enum
│ ├── tree_node_type.h
│ ├── task_type.h
│ ├── operator.h
│ └── typeinfo.h
│ ├── detail
│ ├── omp_exception.h
│ ├── file_utils.h
│ └── serializer_mixins.h
│ ├── c_api_error.h
│ ├── contiguous_array.h
│ └── gtil.h
├── ACKNOWLEDGMENTS.md
├── src
├── logging.cc
├── c_api
│ ├── logging.cc
│ ├── c_api_utils.h
│ ├── c_api_error.cc
│ ├── field_accessor.cc
│ ├── model.cc
│ ├── serializer.cc
│ ├── gtil.cc
│ └── model_loader.cc
├── model_loader
│ ├── detail
│ │ ├── string_utils.h
│ │ ├── xgboost.h
│ │ ├── xgboost_json
│ │ │ ├── sax_adapters.h
│ │ │ └── sax_adapters.cc
│ │ ├── lightgbm.h
│ │ └── xgboost.cc
│ └── xgboost_ubjson.cc
├── gtil
│ ├── postprocessor.h
│ ├── output_shape.cc
│ ├── config.cc
│ └── postprocessor.cc
├── enum
│ ├── typeinfo.cc
│ ├── tree_node_type.cc
│ ├── operator.cc
│ └── task_type.cc
├── model_query.cc
├── model_builder
│ └── metadata.cc
└── model_concat.cc
├── .readthedocs.yaml
├── dev
├── run_pylint.py
└── change_version.py
├── .flake8
├── .github
└── workflows
│ ├── pre-commit.yml
│ ├── windows-wheel-builder.yml
│ ├── cpack-builder.yml
│ ├── pre-commit-update.yml
│ ├── linux-wheel-builder.yml
│ ├── macos-wheel-builder.yml
│ ├── coverage-tests.yml
│ └── misc-tests.yml
├── README.md
├── .gitignore
├── .pre-commit-config.yaml
└── .clang-format
/python/packager/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/treelite/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/python/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/treelite/VERSION:
--------------------------------------------------------------------------------
1 | 4.7.0.dev0
2 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build
2 | _static
3 | doxyxml
4 | tmp
5 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile=black
3 | known_first_party=treelite
4 |
--------------------------------------------------------------------------------
/docs/_static/deployment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmlc/treelite/HEAD/docs/_static/deployment.png
--------------------------------------------------------------------------------
/tests/examples/mushroom/mushroom.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmlc/treelite/HEAD/tests/examples/mushroom/mushroom.model
--------------------------------------------------------------------------------
/ops/conda_env/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: precommit
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3.11
6 | - pre-commit
7 |
--------------------------------------------------------------------------------
/tests/examples/sparse_categorical/sparse_categorical.test.margin:
--------------------------------------------------------------------------------
1 | -0.036716360109071977
2 | -0.088410677452890038
3 | 0.21990291345117738
4 |
--------------------------------------------------------------------------------
/docs/knobs/index.rst:
--------------------------------------------------------------------------------
1 | ====================
2 | Knobs and Parameters
3 | ====================
4 |
5 | .. toctree::
6 | :maxdepth: 1
7 |
8 | postprocessor
9 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {%- block extrahead %}
4 |
5 | {% endblock %}
6 |
--------------------------------------------------------------------------------
/tests/ci_build/build_via_cmake.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 |
5 | rm -rf build
6 | mkdir build
7 | cd build
8 | cmake .. -GNinja "$@"
9 | ninja -v
10 | cd ..
11 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx>=5.2.1
2 | sphinx_rtd_theme>=1.0.0
3 | breathe
4 | autodocsumm
5 | scikit-learn
6 | matplotlib>=2.1
7 | graphviz
8 | numpy
9 | scipy
10 | sphinx-gallery
11 | pandas
12 |
--------------------------------------------------------------------------------
/python/treelite/gtil/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General Tree Inference Library (GTIL)
3 | """
4 |
5 | from .gtil import predict, predict_leaf, predict_per_tree
6 |
7 | __all__ = ["predict", "predict_leaf", "predict_per_tree"]
8 |
--------------------------------------------------------------------------------
/docs/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | =========
2 | Tutorials
3 | =========
4 |
5 | This page lists tutorials about Treelite.
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 | :caption: Contents:
10 |
11 | import
12 | builder
13 | edit
14 |
--------------------------------------------------------------------------------
/cmake/Version.cmake:
--------------------------------------------------------------------------------
# Print the project version and generate include/treelite/version.h from the
# cmake/version.h.in template.
function(write_version)
  message(STATUS "Treelite VERSION: ${PROJECT_VERSION}")
  # The output path is relative; see the configure_file() documentation for
  # how relative outputs are resolved.
  configure_file(
    ${PROJECT_SOURCE_DIR}/cmake/version.h.in
    include/treelite/version.h
  )
endfunction()
7 |
--------------------------------------------------------------------------------
/docs/_static/custom.css:
--------------------------------------------------------------------------------
1 | @import url('theme.css');
2 |
3 | .red {
4 | color: red;
5 | }
6 |
7 | .wy-side-nav-search div.version {
8 | color: #FFFFFF;
9 | }
10 |
11 | .wy-table-responsive table td,.wy-table-responsive table th {
12 | white-space: normal;
13 | }
14 |
--------------------------------------------------------------------------------
/python/treelite/path_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom hook to customize path for libtreelite.so
3 | """
4 |
5 |
def get_custom_libpath():
    """
    Return the custom location of libtreelite.so, if any.

    This default hook reports no custom location. A replacement that returns
    a value must return a directory containing libtreelite.so.
    """
    custom_dir = None
    return custom_dir
12 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Set the default behavior, in case people don't have core.autocrlf set.
2 | * text=auto
3 |
4 | # Explicitly declare text files you want to always be normalized and converted
5 | # to native line endings on checkout.
6 | *.cc text
7 | *.h text
8 | *.proto text
9 | *.txt text
10 | *.md text
11 | Makefile text
12 | LICENSE text
13 |
--------------------------------------------------------------------------------
/ops/test-linux-python-wheel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | echo "##[section]Installing Treelite into Python environment..."
6 | pip install --force-reinstall wheelhouse/*.whl
7 |
8 | echo "##[section]Running Python tests..."
9 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py
10 |
--------------------------------------------------------------------------------
/ops/test-macos-python-wheel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | echo "##[section]Installing Treelite into Python environment..."
6 | pip install --force-reinstall wheelhouse/*.whl
7 |
8 | echo "##[section]Running Python tests..."
9 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py
10 |
--------------------------------------------------------------------------------
/docs/treelite-doxygen.rst:
--------------------------------------------------------------------------------
1 | ==================================
2 | Documentation for the C++ codebase
3 | ==================================
4 |
5 | The core logic of Treelite is written in C++. Visit the link below to find
6 | the documentation for all C++ functions and classes defined in Treelite.
7 |
8 | `Documentation for C++ functions and classes <./dev>`_
9 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | ifndef LINT_LANG
2 | LINT_LANG="all"
3 | endif
4 |
5 | ifndef NPROC
6 | NPROC=1
7 | endif
8 |
9 | doxygen:
10 | cd docs; doxygen
11 |
12 | cpp-coverage:
13 | rm -rf build; mkdir build; cd build; cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug && make -j$(NPROC)
14 |
15 | all:
16 | rm -rf build; mkdir build; cd build; cmake .. && make -j$(NPROC)
17 |
--------------------------------------------------------------------------------
/cmake/TreeliteConfig.cmake.in:
--------------------------------------------------------------------------------
1 | @PACKAGE_INIT@
2 |
3 | include(CMakeFindDependencyMacro)
4 |
5 | set(USE_OPENMP @USE_OPENMP@)
6 | if(USE_OPENMP)
7 | find_dependency(OpenMP)
8 | endif()
9 |
10 | if(NOT TARGET treelite::treelite)
11 | include(${CMAKE_CURRENT_LIST_DIR}/TreeliteTargets.cmake)
12 | endif()
13 |
14 | message(STATUS "Found Treelite (found version \"${Treelite_VERSION}\")")
15 |
--------------------------------------------------------------------------------
/python/README.rst:
--------------------------------------------------------------------------------
1 | =======================
2 | Treelite Python Package
3 | =======================
4 |
5 | |PyPI version|
6 |
7 | .. |PyPI version| image:: https://badge.fury.io/py/treelite.svg
8 | :target: http://badge.fury.io/py/treelite
9 |
10 | **Treelite** is a universal model exchange and serialization format for decision tree forests.
11 |
12 | See the documentation for more details.
13 |
--------------------------------------------------------------------------------
/tests/cpp/test_main.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2020-2023 by Contributors
3 | * \file test_main.cc
4 | * \author Hyunsu Cho
5 | * \brief Launcher for C++ unit tests, using Google Test framework
6 | */
7 | #include
8 |
9 | int main(int argc, char** argv) {
10 | testing::InitGoogleTest(&argc, argv);
11 | testing::FLAGS_gtest_death_test_style = "threadsafe";
12 | return RUN_ALL_TESTS();
13 | }
14 |
--------------------------------------------------------------------------------
/cmake/version.h.in:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | */
4 | #ifndef TREELITE_VERSION_H_
5 | #define TREELITE_VERSION_H_
6 |
7 | #define TREELITE_VER_MAJOR @treelite_VERSION_MAJOR@
8 | #define TREELITE_VER_MINOR @treelite_VERSION_MINOR@
9 | #define TREELITE_VER_PATCH @treelite_VERSION_PATCH@
10 | #define TREELITE_VERSION_STR "@treelite_VERSION_MAJOR@.@treelite_VERSION_MINOR@.@treelite_VERSION_PATCH@"
11 |
12 | #endif // TREELITE_VERSION_H_
13 |
--------------------------------------------------------------------------------
/ops/conda_env/dev.yml:
--------------------------------------------------------------------------------
1 | name: dev
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3.11
6 | - numpy
7 | - scipy
8 | - pandas
9 | - pytest
10 | - pytest-cov
11 | - hypothesis
12 | - scikit-learn
13 | - coverage
14 | - codecov
15 | - ninja
16 | - lcov
17 | - cmake
18 | - llvm-openmp
19 | - cython
20 | - lightgbm
21 | - cpplint=1.6.0
22 | - pylint
23 | - awscli
24 | - python-build
25 | - pip
26 | - pip:
27 | - cibuildwheel
28 | - xgboost>=2.1.0
29 |
--------------------------------------------------------------------------------
/ops/test-cmake-import.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | # Install Treelite C++ library into the Conda env
6 | set -x
7 | rm -rf build/
8 | mkdir build
9 | cd build
10 | cmake .. -DCMAKE_INSTALL_PREFIX="$CONDA_PREFIX" -DCMAKE_INSTALL_LIBDIR="lib" -GNinja
11 | ninja install
12 |
13 | # Try compiling a sample application
14 | cd ../tests/example_app/
15 | rm -rf build/
16 | mkdir build
17 | cd build
18 | cmake .. -GNinja
19 | ninja
20 | ./cpp_example
21 | ./c_example
22 |
--------------------------------------------------------------------------------
/cmake/Doxygen.cmake:
--------------------------------------------------------------------------------
# Configure the Doxyfile from docs/Doxyfile.in and add the doc_doxygen target
# (attached to ALL) that runs Doxygen on it.
function(run_doxygen)
  # Doxygen together with its dot component is mandatory
  find_package(Doxygen REQUIRED dot)

  configure_file(
    ${treelite_SOURCE_DIR}/docs/Doxyfile.in
    ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile
    @ONLY
  )
  add_custom_target(
    doc_doxygen ALL
    COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile
    WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
    COMMENT "Generate documentation for C/C++ functions"
    VERBATIM
  )
endfunction()
13 |
--------------------------------------------------------------------------------
/ops/build-macos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | if [[ $# -ne 2 ]]; then
6 | echo "Usage: $0 [platform_id] [commit ID]"
7 | exit 1
8 | fi
9 |
10 | platform_id=$1
11 | commit_id=$2
12 |
13 | echo "##[section]Building MacOS Python wheels..."
14 | tests/ci_build/build_macos_python_wheels.sh ${platform_id} ${commit_id}
15 |
16 | echo "##[section]Uploading MacOS Python wheels to S3..."
17 | python -m awscli s3 cp wheelhouse/treelite-*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true
18 |
--------------------------------------------------------------------------------
/python/treelite/__init__.py:
--------------------------------------------------------------------------------
1 | """Treelite module"""
2 |
3 | import pathlib
4 |
5 | from . import frontend, gtil, model_builder, sklearn
6 | from .core import TreeliteError
7 | from .model import Model
8 |
9 | VERSION_FILE = pathlib.Path(__file__).parent / "VERSION"
10 | with open(VERSION_FILE, "r", encoding="UTF-8") as _f:
11 | __version__ = _f.read().strip()
12 |
13 | __all__ = [
14 | "Model",
15 | "frontend",
16 | "gtil",
17 | "sklearn",
18 | "model_builder",
19 | "TreeliteError",
20 | "__version__",
21 | ]
22 |
--------------------------------------------------------------------------------
/include/treelite/error.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2022 by Contributors
3 | * \file error.h
4 | * \brief Exception class used throughout the Treelite codebase
5 | * \author Hyunsu Cho
6 | */
7 | #ifndef TREELITE_ERROR_H_
8 | #define TREELITE_ERROR_H_
9 |
#include <stdexcept>
#include <string>

namespace treelite {

/*!
 * \brief Exception class that will be thrown by Treelite
 */
struct Error : public std::runtime_error {
  /*!
   * \brief Construct an error carrying the given message
   * \param s Message forwarded to std::runtime_error
   */
  explicit Error(std::string const& s) : std::runtime_error(s) {}
};

}  // namespace treelite
23 |
24 | #endif // TREELITE_ERROR_H_
25 |
--------------------------------------------------------------------------------
/ops/build-linux.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | TAG=manylinux2014_x86_64
6 |
7 | export CIBW_BUILD=cp38-manylinux_x86_64
8 | export CIBW_ARCHS=x86_64
9 | export CIBW_BUILD_VERBOSITY=3
10 | export CIBW_MANYLINUX_X86_64_IMAGE=manylinux2014
11 |
12 | echo "##[section]Building Python wheel (amd64) for Treelite..."
13 | python -m cibuildwheel python --output-dir wheelhouse
14 | python tests/ci_build/rename_whl.py wheelhouse ${COMMIT_ID} ${TAG}
15 |
16 | echo "##[section]Uploading Python wheel (amd64)..."
17 | python -m awscli s3 cp wheelhouse/*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true
18 |
--------------------------------------------------------------------------------
/ops/build-linux-aarch64.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | TAG=manylinux2014_aarch64
6 |
7 | export CIBW_BUILD=cp38-manylinux_aarch64
8 | export CIBW_ARCHS=aarch64
9 | export CIBW_BUILD_VERBOSITY=3
10 | export CIBW_MANYLINUX_AARCH64_IMAGE=manylinux2014
11 |
12 | echo "##[section]Building Python wheel (aarch64) for Treelite..."
13 | python -m cibuildwheel python --output-dir wheelhouse
14 | python tests/ci_build/rename_whl.py wheelhouse ${COMMIT_ID} ${TAG}
15 |
16 | echo "##[section]Uploading Python wheel (aarch64)..."
17 | python -m awscli s3 cp wheelhouse/*.whl s3://treelite-wheels/ --acl public-read --region us-west-2 || true
18 |
--------------------------------------------------------------------------------
/include/treelite/pybuffer_frame.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file pybuffer_frame.h
4 | * \brief Data structure to enable zero-copy exchange in Python
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_PYBUFFER_FRAME_H_
9 | #define TREELITE_PYBUFFER_FRAME_H_
10 |
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | namespace treelite {
17 |
18 | using PyBufferFrame = TreelitePyBufferFrame;
19 |
20 | static_assert(std::is_pod::value, "PyBufferFrame must be a POD type");
21 |
22 | } // namespace treelite
23 |
24 | #endif // TREELITE_PYBUFFER_FRAME_H_
25 |
--------------------------------------------------------------------------------
/ops/test-sdist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | echo "##[section]Building a source distribution..."
6 | python -m build --sdist python/ --outdir .
7 |
8 | echo "##[section]Testing the source distribution..."
9 | python -m pip install --force-reinstall -v treelite-*.tar.gz
10 | python -m pytest -v -rxXs --fulltrace --durations=0 tests/python/test_sklearn_integration.py
11 |
12 | # Deploy source distribution to S3
13 | for file in ./treelite-*.tar.gz
14 | do
15 | mv "${file}" "${file%.tar.gz}+${COMMIT_ID}.tar.gz"
16 | done
17 | python -m awscli s3 cp treelite-*.tar.gz s3://treelite-wheels/ --acl public-read --region us-west-2 || true
18 |
--------------------------------------------------------------------------------
/ops/build-windows.bat:
--------------------------------------------------------------------------------
1 | echo ##[section]Generating Visual Studio solution...
2 | mkdir build
3 | cd build
4 | cmake .. -G"Visual Studio 17 2022" -A x64 -DBUILD_CPP_TEST=ON
5 | if %errorlevel% neq 0 exit /b %errorlevel%
6 |
7 | echo ##[section]Building Visual Studio solution...
8 | cmake --build . --config Release -- /m
9 | if %errorlevel% neq 0 exit /b %errorlevel%
10 | cd ..
11 |
12 | echo ##[section]Running C++ tests...
13 | .\build\treelite_cpp_test.exe
14 | if %errorlevel% neq 0 exit /b %errorlevel%
15 |
16 | echo ##[section]Packaging Python wheel for Treelite...
17 | cd python
18 | pip wheel --no-deps -v . --wheel-dir dist/
19 | if %errorlevel% neq 0 exit /b %errorlevel%
20 |
--------------------------------------------------------------------------------
/ACKNOWLEDGMENTS.md:
--------------------------------------------------------------------------------
1 | # Acknowledgments
2 |
3 | This page acknowledges those who provided great help in building the initial
4 | version of Treelite:
5 |
6 | * Treelite builds upon Hyunsu's earlier work at **Paul G. Allen School of
7 | Computer Science and Engineering, University of Washington at Seattle**, which was performed under the
8 | guidance of **Carlos Guestrin** and **Arvind Krishnamurthy**.
9 | * **Tianqi Chen** (member of DMLC) offered many great ideas for the conception
10 | of the project. He also provided feedback pertaining to API design.
11 | * **Mu Li** (member of DMLC) provided advice and resources to develop Treelite
12 | into a full-fledged open source project.
13 |
--------------------------------------------------------------------------------
/python/treelite/sklearn/__init__.py:
--------------------------------------------------------------------------------
1 | """Model loader to ingest scikit-learn models into Treelite"""
2 |
3 | from .exporter import export_model
4 | from .importer import import_model
5 |
6 |
def import_model_with_model_builder(sklearn_model):
    """
    Removed entry point that always raises :py:class:`NotImplementedError`.

    This function was removed in Treelite 4.0; please use
    :py:meth:`~treelite.sklearn.import_model` instead. The ``sklearn_model``
    argument is not used.
    """
    removal_notice = (
        "treelite.sklearn.import_model_with_model_builder() was removed in Treelite 4.0. "
        "Please use treelite.sklearn.import_model() instead."
    )
    raise NotImplementedError(removal_notice)
16 |
17 |
18 | __all__ = ["import_model", "export_model", "import_model_with_model_builder"]
19 |
--------------------------------------------------------------------------------
/include/treelite/thread_local.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2021 by Contributors
3 | * \file thread_local.h
4 | * \brief Helper class for thread-local storage
5 | * \author Hyunsu Cho
6 | */
7 | #ifndef TREELITE_THREAD_LOCAL_H_
8 | #define TREELITE_THREAD_LOCAL_H_
9 |
namespace treelite {

/*!
 * \brief A thread-local storage
 * \tparam T Type of the object to store
 */
template <typename T>
class ThreadLocalStore {
 public:
  /*! \return Pointer to the calling thread's singleton instance of T */
  static T* Get() {
    // Function-local thread_local object: one instance per thread
    static thread_local T inst;
    return &inst;
  }
};

}  // namespace treelite
27 |
28 | #endif // TREELITE_THREAD_LOCAL_H_
29 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python3 -msphinx
7 | SPHINXPROJ = treelite
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/python/packager/sdist.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for building sdist
3 | """
4 |
5 | import logging
6 | import pathlib
7 |
8 | from .util import copy_with_logging, copytree_with_logging
9 |
10 |
def copy_cpp_src_tree(
    cpp_src_dir: pathlib.Path, target_dir: pathlib.Path, logger: logging.Logger
) -> None:
    """Copy the C++ source tree into the build directory.

    The ``src``, ``include`` and ``cmake`` subdirectories of ``cpp_src_dir`` are
    passed to ``copytree_with_logging`` with the same-named subdirectory of
    ``target_dir`` as destination. ``CMakeLists.txt`` and ``LICENSE`` are then
    passed to ``copy_with_logging`` with ``target_dir`` itself as destination.
    """
    source_subdirs = ("src", "include", "cmake")
    top_level_files = ("CMakeLists.txt", "LICENSE")

    for dirname in source_subdirs:
        copytree_with_logging(
            cpp_src_dir.joinpath(dirname), target_dir.joinpath(dirname), logger=logger
        )

    for name in top_level_files:
        copy_with_logging(cpp_src_dir / name, target_dir, logger=logger)
25 |
--------------------------------------------------------------------------------
/src/logging.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017-2023 by Contributors
3 | * \file logging.cc
4 | * \author Hyunsu Cho
5 | * \brief logging facility for treelite
6 | */
7 |
8 | #include
9 |
10 | namespace treelite {
11 |
12 | void LogMessage::Log(std::string const& msg) {
13 | LogCallbackRegistry const* registry = LogCallbackRegistryStore::Get();
14 | auto callback = registry->GetCallbackLogInfo();
15 | callback(msg.c_str());
16 | }
17 |
18 | void LogMessageWarning::Log(std::string const& msg) {
19 | LogCallbackRegistry const* registry = LogCallbackRegistryStore::Get();
20 | auto callback = registry->GetCallbackLogWarning();
21 | callback(msg.c_str());
22 | }
23 |
24 | } // namespace treelite
25 |
--------------------------------------------------------------------------------
/docs/treelite-gtil-api.rst:
--------------------------------------------------------------------------------
1 | =====================================
2 | General Tree Inference Library (GTIL)
3 | =====================================
4 |
5 | GTIL is a **reference implementation** of a prediction runtime for all Treelite models. It has the following goals:
6 |
7 | * **Universal coverage**: GTIL shall support all tree ensemble models that can be represented as Treelite objects.
8 | * **Accessible code**: GTIL should be written in an easy-to-read style that can be understood by a first-time contributor. We prefer code legibility to performance optimization.
9 | * **Correct output**: As a reference implementation, GTIL should produce correct prediction outputs.
10 |
11 | .. automodule:: treelite.gtil
12 | :members:
13 |
--------------------------------------------------------------------------------
/src/c_api/logging.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file logging.cc
4 | * \author Hyunsu Cho
5 | * \brief C API for logging functions
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 |
12 | int TreeliteRegisterLogCallback(void (*callback)(char const*)) {
13 | API_BEGIN();
14 | auto* registry = treelite::LogCallbackRegistryStore::Get();
15 | registry->RegisterCallBackLogInfo(callback);
16 | API_END();
17 | }
18 |
19 | int TreeliteRegisterWarningCallback(void (*callback)(char const*)) {
20 | API_BEGIN();
21 | auto* registry = treelite::LogCallbackRegistryStore::Get();
22 | registry->RegisterCallBackLogWarning(callback);
23 | API_END();
24 | }
25 |
--------------------------------------------------------------------------------
/tests/ci_build/Dockerfile.ubuntu20_amd64:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 | LABEL maintainer "DMLC"
3 |
4 | ENV GOSU_VERSION 1.13
5 | ENV DEBIAN_FRONTEND noninteractive
6 |
7 | RUN \
8 | apt-get update && \
9 | apt-get install -y build-essential git wget unzip tar cmake ninja-build
10 |
11 | # Install lightweight sudo (not bound to TTY)
12 | RUN set -ex; \
13 | wget -nv -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
14 | chmod +x /usr/local/bin/gosu && \
15 | gosu nobody true
16 |
17 | # Default entry-point to use if running locally
18 | # It will preserve attributes of created files
19 | COPY entrypoint.sh /scripts/
20 |
21 | WORKDIR /workspace
22 | ENTRYPOINT ["/scripts/entrypoint.sh"]
23 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | submodules:
9 | include: all
10 |
11 | # Set the version of Python and other tools you might need
12 | build:
13 | os: ubuntu-22.04
14 | tools:
15 | python: "3.10"
16 | apt_packages:
17 | - graphviz
18 | - cmake
19 | - g++
20 | - doxygen
21 | - ninja-build
22 |
23 | # Build documentation in the docs/ directory with Sphinx
24 | sphinx:
25 | configuration: docs/conf.py
26 |
27 | # Optionally declare the Python requirements required to build your docs
28 | python:
29 | install:
30 | - requirements: docs/requirements.txt
31 |
--------------------------------------------------------------------------------
/tests/ci_build/Dockerfile.ubuntu20_aarch64:
--------------------------------------------------------------------------------
1 | FROM arm64v8/ubuntu:20.04
2 | LABEL maintainer "DMLC"
3 |
4 | ENV GOSU_VERSION 1.13
5 | ENV DEBIAN_FRONTEND noninteractive
6 |
7 | RUN \
8 | apt-get update && \
9 | apt-get install -y build-essential git wget unzip tar cmake ninja-build
10 |
11 | # Install lightweight sudo (not bound to TTY)
12 | RUN set -ex; \
13 | wget -nv -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-arm64" && \
14 | chmod +x /usr/local/bin/gosu && \
15 | gosu nobody true
16 |
17 | # Default entry-point to use if running locally
18 | # It will preserve attributes of created files
19 | COPY entrypoint.sh /scripts/
20 |
21 | WORKDIR /workspace
22 | ENTRYPOINT ["/scripts/entrypoint.sh"]
23 |
--------------------------------------------------------------------------------
/dev/run_pylint.py:
--------------------------------------------------------------------------------
1 | """Wrapper for Pylint"""
2 |
3 | import os
4 | import pathlib
5 | import subprocess
6 | import sys
7 |
# Repository root: the parent of the directory that contains this script
ROOT_PATH = pathlib.Path(__file__).parent.parent.expanduser().resolve()
# Directory that main() exports through PYTHONPATH
PYPKG_PATH = ROOT_PATH / "python"
# Pylint configuration file that main() passes via --rcfile
PYLINTRC_PATH = PYPKG_PATH / ".pylintrc"
11 |
12 |
def main():
    """Wrapper for Pylint. Add treelite to PYTHONPATH so that pylint doesn't error out

    The package directory is prepended to any PYTHONPATH already present in the
    environment instead of replacing it, so entries supplied by the caller stay
    visible to Pylint. Command-line arguments are forwarded to Pylint unchanged.

    Raises
    ------
    subprocess.CalledProcessError
        If Pylint exits with a non-zero status (``check=True``)
    """
    new_env = os.environ.copy()
    inherited = new_env.get("PYTHONPATH")
    if inherited:
        # Our package path goes first so it wins over inherited entries
        new_env["PYTHONPATH"] = os.pathsep.join([str(PYPKG_PATH), inherited])
    else:
        new_env["PYTHONPATH"] = str(PYPKG_PATH)

    # sys.argv[1:]: List of source files to check
    subprocess.run(
        ["pylint", "-rn", "-sn", "--rcfile", str(PYLINTRC_PATH)] + sys.argv[1:],
        check=True,
        env=new_env,
    )


if __name__ == "__main__":
    main()
28 |
--------------------------------------------------------------------------------
/python/hatch_build.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom hook to customize the behavior of Hatchling.
3 | Here, we customize the tag of the generated wheels.
4 | """
5 |
6 | import sysconfig
7 | from typing import Any, Dict
8 |
9 | from hatchling.builders.hooks.plugin.interface import BuildHookInterface
10 |
11 |
def get_tag() -> str:
    """Build the wheel tag for the current platform

    The platform string reported by sysconfig has every "-" and "." replaced
    with "_" and is appended to the fixed prefix "py3-none-".
    """
    platform_name = sysconfig.get_platform()
    for separator in ("-", "."):
        platform_name = platform_name.replace(separator, "_")
    return "py3-none-" + platform_name
16 |
17 |
class CustomBuildHook(BuildHookInterface):
    """Hatchling build hook that overrides the tag of the generated wheel"""

    def initialize(self, version: str, build_data: Dict[str, Any]) -> None:
        """This step occurs immediately before each build.

        Stores the tag computed by get_tag() under the "tag" key of build_data.
        """
        wheel_tag = get_tag()
        build_data["tag"] = wheel_tag
24 |
--------------------------------------------------------------------------------
/python/packager/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for implementing PEP 517 backend
3 | """
4 |
5 | import logging
6 | import pathlib
7 | import shutil
8 |
9 |
def copytree_with_logging(
    src: pathlib.Path, dest: pathlib.Path, logger: logging.Logger
) -> None:
    """Log the source and destination at INFO level, then run shutil.copytree()"""
    src_text, dest_text = str(src), str(dest)
    logger.info("Copying %s -> %s", src_text, dest_text)
    shutil.copytree(src, dest)
16 |
17 |
def copy_with_logging(
    src: pathlib.Path, dest: pathlib.Path, logger: logging.Logger
) -> None:
    """Log the copy at INFO level, then run shutil.copy()

    When dest is an existing directory, the logged destination is the path of
    the file inside that directory; otherwise it is dest itself.
    """
    logged_dest = dest / src.name if dest.is_dir() else dest
    logger.info("Copying %s -> %s", str(src), str(logged_dest))
    shutil.copy(src, dest)
27 |
--------------------------------------------------------------------------------
/ops/macos-python-coverage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | conda --version
6 | python --version
7 |
8 | # Run coverage test
9 | echo "##[section]Building Treelite..."
10 | set -x
11 | rm -rf build/
12 | mkdir build
13 | cd build
14 | cmake .. -DTEST_COVERAGE=ON -DBUILD_CPP_TEST=ON -GNinja
15 | ninja -v
16 | cd ..
17 |
18 | ./build/treelite_cpp_test
19 | PYTHONPATH=./python python -m pytest --cov=treelite -v -rxXs \
20 | --fulltrace --durations=0 tests/python
21 | lcov --directory . --capture --output-file coverage.info
22 | lcov --remove coverage.info '*dmlccore*' --output-file coverage.info
23 | lcov --remove coverage.info '*fmtlib*' --output-file coverage.info
24 | lcov --remove coverage.info '*/usr/*' --output-file coverage.info
25 | lcov --remove coverage.info '*googletest*' --output-file coverage.info
26 | codecov
27 |
--------------------------------------------------------------------------------
/ops/test-win-python-wheel.bat:
--------------------------------------------------------------------------------
REM CI step: install the built wheel, run the scikit-learn integration tests,
REM then upload the wheel. Reads COMMIT_ID and WORKING_DIR from the environment.
echo ##[section]Installing Treelite into Python environment...
REM Delayed expansion is needed so !errorlevel! is evaluated on every loop iteration
setlocal enabledelayedexpansion
python tests\ci_build\rename_whl.py python\dist %COMMIT_ID% win_amd64
if %errorlevel% neq 0 exit /b %errorlevel%
for /R %%i in (python\\dist\\*.whl) DO (
  python -m pip install --force-reinstall "%%i"
  if !errorlevel! neq 0 exit /b !errorlevel!
)

echo ##[section]Running Python tests...
mkdir temp
python -m pytest --basetemp="%WORKING_DIR%\temp" -v -rxXs --fulltrace --durations=0 tests\python\test_sklearn_integration.py
if %errorlevel% neq 0 exit /b %errorlevel%

echo ##[section]Uploading Python wheels...
REM Best-effort upload: when the copy fails, `cd .` runs and succeeds, so the step does not fail
for /R %%i in (python\\dist\\*.whl) DO (
  python -m awscli s3 cp "%%i" s3://treelite-wheels/ --acl public-read --region us-west-2 || cd .
)
19 |
--------------------------------------------------------------------------------
/include/treelite/enum/tree_node_type.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file tree_node_type.h
4 | * \brief Define enum type TreeNodeType
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_ENUM_TREE_NODE_TYPE_H_
9 | #define TREELITE_ENUM_TREE_NODE_TYPE_H_
10 |
11 | #include
12 | #include
13 |
namespace treelite {

/*! \brief Tree node type */
enum class TreeNodeType : std::int8_t {
  kLeafNode = 0,            /*!< string form: "leaf_node" */
  kNumericalTestNode = 1,   /*!< string form: "numerical_test_node" */
  kCategoricalTestNode = 2  /*!< string form: "categorical_test_node" */
};

/*!
 * \brief Get string representation of TreeNodeType
 * \param type Node type to convert
 * \return String form of the node type
 */
std::string TreeNodeTypeToString(TreeNodeType type);

/*!
 * \brief Get TreeNodeType from string
 * \param name String form of a node type
 * \return Corresponding node type
 */
TreeNodeType TreeNodeTypeFromString(std::string const& name);

}  // namespace treelite
30 |
31 | #endif // TREELITE_ENUM_TREE_NODE_TYPE_H_
32 |
--------------------------------------------------------------------------------
/include/treelite/enum/task_type.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file task_type.h
4 | * \brief Define enum type TaskType
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_ENUM_TASK_TYPE_H_
9 | #define TREELITE_ENUM_TASK_TYPE_H_
10 |
11 | #include
12 | #include
13 |
namespace treelite {

/*!
 * \brief Enum type representing the task type.
 *
 * The enumerator values are spelled out explicitly and the underlying type is
 * std::uint8_t.
 */
enum class TaskType : std::uint8_t {
  kBinaryClf = 0,
  kRegressor = 1,
  kMultiClf = 2,
  kLearningToRank = 3,
  kIsolationForest = 4
};

/*!
 * \brief Get string representation of TaskType
 * \param type Task type to convert
 * \return String form of the task type
 */
std::string TaskTypeToString(TaskType type);

/*!
 * \brief Get TaskType from string
 * \param str String form of a task type
 * \return Corresponding task type
 */
TaskType TaskTypeFromString(std::string const& str);

}  // namespace treelite
34 |
35 | #endif // TREELITE_ENUM_TASK_TYPE_H_
36 |
--------------------------------------------------------------------------------
/tests/example_app/CMakeLists.txt:
--------------------------------------------------------------------------------
# Example application that consumes an installed Treelite through find_package(),
# built once as a C++ program and once as a C program.
cmake_minimum_required(VERSION 3.16)
project(example_app LANGUAGES C CXX)

# Let find_package() also search the active Conda environment, if there is one
if(DEFINED ENV{CONDA_PREFIX})
  set(CMAKE_PREFIX_PATH "$ENV{CONDA_PREFIX};${CMAKE_PREFIX_PATH}")
  message(STATUS "Detected Conda environment, CMAKE_PREFIX_PATH set to: ${CMAKE_PREFIX_PATH}")
else()
  message(STATUS "No Conda environment detected")
endif()

# REQUIRED: configuration stops if the Treelite package cannot be located
find_package(Treelite REQUIRED)

add_executable(cpp_example example.cc)
target_link_libraries(cpp_example PRIVATE treelite::treelite)

add_executable(c_example example.c)
target_link_libraries(c_example PRIVATE treelite::treelite)

# The C++ example is compiled as C++17
set_target_properties(cpp_example PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED YES
)

# The C example is compiled as C99
set_target_properties(c_example PROPERTIES
  C_STANDARD 99
  C_STANDARD_REQUIRED YES
)
28 |
--------------------------------------------------------------------------------
/ops/build-cpack.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Build Treelite inside the CI container, package it with CPack, stamp the
# tarball name with the commit ID and architecture, then upload it.
#
# Usage: build-cpack.sh {amd64,aarch64}
# Requires the environment variable COMMIT_ID.

set -euo pipefail

if [[ -z "${COMMIT_ID:-}" ]]
then
  echo "Make sure to set environment variable COMMIT_ID"
  exit 1
fi

if [[ "$#" -lt 1 ]]
then
  echo "Usage: $0 {amd64,aarch64}"
  exit 2
fi

arch="$1"

echo "##[section] Building Treelite for ${arch}..."
tests/ci_build/ci_build.sh "ubuntu20_${arch}" tests/ci_build/build_via_cmake.sh

echo "##[section] Packing CPack for ${arch}..."
tests/ci_build/ci_build.sh "ubuntu20_${arch}" bash -c "cd build/ && cpack -G TGZ"
for tgz in build/treelite-*-Linux.tar.gz
do
  # treelite-X-Linux.tar.gz -> treelite-X+<commit>-Linux-<arch>.tar.gz
  mv -v "${tgz}" "${tgz%-Linux.tar.gz}+${COMMIT_ID}-Linux-${arch}.tar.gz"
done

echo "##[section]Uploading CPack for ${arch}..."
# `s3 cp` takes exactly one source path, so upload the tarballs one at a time.
# The upload stays best-effort: a failure must not fail the build.
for tgz in build/*.tar.gz
do
  python -m awscli s3 cp "${tgz}" s3://treelite-cpack/ --acl public-read --region us-west-2 || true
done
31 |
--------------------------------------------------------------------------------
/ops/win-python-coverage.bat:
--------------------------------------------------------------------------------
REM CI step: build Treelite with Visual Studio, run the Python tests with
REM coverage enabled, then submit the coverage report to CodeCov.
echo ##[section]Generating Visual Studio solution...
mkdir build
cd build
cmake .. -G"Visual Studio 17 2022" -A x64
if %errorlevel% neq 0 exit /b %errorlevel%

echo ##[section]Building Visual Studio solution...
cmake --build . --config Release -- /m
if %errorlevel% neq 0 exit /b %errorlevel%
cd ..

echo ##[section]Running Python tests...
mkdir temp
set "PYTHONPATH=./python"
set "PYTEST_TMPDIR=%USERPROFILE%\AppData\Local\Temp\pytest_temp"
REM Create the directory only when it is missing so a re-run does not print an error
if not exist "%PYTEST_TMPDIR%" mkdir "%PYTEST_TMPDIR%"
REM Reuse PYTEST_TMPDIR instead of spelling the same path out a second time
python -m pytest --basetemp="%PYTEST_TMPDIR%" --cov=treelite --cov-report xml -v -rxXs --fulltrace --durations=0 tests\python
if %errorlevel% neq 0 exit /b %errorlevel%

echo ##[section]Submitting code coverage data to CodeCov...
python -m codecov -f coverage.xml
if %errorlevel% neq 0 exit /b %errorlevel%
23 |
--------------------------------------------------------------------------------
/src/model_loader/detail/string_utils.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file string_utils.h
4 | * \brief Helper functions for manipulating strings
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_
9 | #define SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_
10 |
11 | #include
12 | #include
13 |
14 | namespace treelite::model_loader::detail {
15 |
/*!
 * \brief Test whether a string begins with a given prefix
 * \param str String to inspect
 * \param prefix Prefix to look for; an empty prefix always matches
 * \return true if the leading characters of str equal prefix
 */
inline bool StringStartsWith(std::string const& str, std::string const& prefix) {
  // Compares at most prefix.size() leading characters of str against prefix;
  // a str shorter than prefix yields a non-zero result
  return str.compare(0, prefix.size(), prefix) == 0;
}
19 |
/*!
 * \brief Remove trailing '\n', '\r' and ' ' characters from a string, in place
 * \param s String to trim; becomes empty if it holds only those characters
 */
inline void StringTrimFromEnd(std::string& s) {
  std::string::size_type const last_kept = s.find_last_not_of("\n\r ");
  // Nothing to keep: erase from the start. Otherwise erase everything after last_kept.
  std::string::size_type const cut_from = (last_kept == std::string::npos) ? 0 : last_kept + 1;
  s.erase(cut_from);
}
26 |
27 | } // namespace treelite::model_loader::detail
28 |
29 | #endif // SRC_MODEL_LOADER_DETAIL_STRING_UTILS_H_
30 |
--------------------------------------------------------------------------------
/include/treelite/enum/operator.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file operator.h
4 | * \brief Define enum type Operator
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_ENUM_OPERATOR_H_
9 | #define TREELITE_ENUM_OPERATOR_H_
10 |
11 | #include
12 | #include
13 |
namespace treelite {

/*! \brief Type of comparison operators used in numerical test nodes */
enum class Operator : std::int8_t {
  kNone, /*!< no operator; OperatorToString() yields an empty string for it */
  kEQ, /*!< operator == */
  kLT, /*!< operator < */
  kLE, /*!< operator <= */
  kGT, /*!< operator > */
  kGE, /*!< operator >= */
};

/*!
 * \brief Get string representation of Operator
 * \param type Operator to convert
 * \return Symbol of the operator, e.g. "<="
 */
std::string OperatorToString(Operator type);

/*!
 * \brief Get Operator from string
 * \param name Symbol of an operator, e.g. "<="
 * \return Corresponding operator
 */
Operator OperatorFromString(std::string const& name);

}  // namespace treelite
33 |
34 | #endif // TREELITE_ENUM_OPERATOR_H_
35 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Run from the directory that contains this script; undone by popd at :end
pushd %~dp0

REM Command file for Sphinx documentation

REM SPHINXBUILD may be set by the caller; fall back to sphinx-build otherwise
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=treelite

REM No target given: show the Sphinx help text
if "%1" == "" goto help

REM Probe for the Sphinx executable; errorlevel 9009 means the command was not found
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

REM Forward the requested target (%1) to Sphinx make-mode
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
37 |
--------------------------------------------------------------------------------
/src/gtil/postprocessor.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2021-2023 by Contributors
3 | * \file postprocessor.h
4 | * \author Hyunsu Cho
5 | * \brief Functions to post-process prediction results
6 | */
7 |
8 | #ifndef SRC_GTIL_POSTPROCESSOR_H_
9 | #define SRC_GTIL_POSTPROCESSOR_H_
10 |
11 | #include
12 | #include
13 |
namespace treelite {

class Model;

namespace gtil {

/*!
 * \brief Pointer to a post-processing function
 * \tparam InputT Element type of the buffer passed as the last argument
 */
template <typename InputT>
using PostProcessorFunc = void (*)(treelite::Model const&, std::int32_t, InputT*);

/*!
 * \brief Look up a post-processor by name
 * \tparam InputT Element type of the buffer the post-processor operates on
 * \param name Name of the post-processor
 * \return Function pointer to the post-processor
 */
template <typename InputT>
PostProcessorFunc<InputT> GetPostProcessorFunc(std::string const& name);

// Explicit instantiations are provided by a source file.
// NOTE(review): float and double are assumed to be the two instantiated types -- confirm
// against the definitions in the corresponding .cc file.
extern template PostProcessorFunc<float> GetPostProcessorFunc<float>(std::string const& name);
extern template PostProcessorFunc<double> GetPostProcessorFunc<double>(std::string const& name);

}  // namespace gtil
}  // namespace treelite
31 |
32 | #endif // SRC_GTIL_POSTPROCESSOR_H_
33 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | filename = *.py, *.pyx, *.pxd
3 | exclude =
4 | *.egg,
5 | .git,
6 | __pycache__,
7 | build/,
8 | cpp,
9 | docs,
10 | tests/cython/
11 |
12 | # Cython Rules ignored:
13 | # E999: invalid syntax (works for Python, not Cython)
14 | # E225: Missing whitespace around operators (breaks cython casting syntax like <int>)
15 | # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*)
16 | # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax)
17 | # W503: line break before binary operator (breaks lines that start with a pointer)
18 | # W504: line break after binary operator (breaks lines that end with a pointer)
19 |
20 | extend-ignore =
21 | # handled by black
22 | E501, W503, E203
23 | # imported but unused
24 | F401
25 | # redefinition of unused
26 | F811
27 | # E203 whitespace before ':'
28 | # https://github.com/psf/black/issues/315
29 | E203
30 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 |
10 | permissions:
11 | contents: read # to fetch code (actions/checkout)
12 |
13 | defaults:
14 | run:
15 | shell: bash -l {0}
16 |
17 | concurrency:
18 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
19 | cancel-in-progress: true
20 |
21 | jobs:
22 | pre-commit:
23 | runs-on: ubuntu-latest
24 | steps:
25 | - uses: actions/checkout@v4
26 | - uses: conda-incubator/setup-miniconda@v3
27 | with:
28 | miniforge-variant: Miniforge3
29 | miniforge-version: latest
30 | conda-remove-defaults: "true"
31 | activate-environment: precommit
32 | environment-file: ops/conda_env/pre-commit.yml
33 | use-mamba: true
34 | - name: Run pre-commit checks
35 | run: |
36 | pre-commit install
37 | pre-commit run --all-files --color always
38 |
--------------------------------------------------------------------------------
/src/c_api/c_api_utils.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file c_api_utils.h
4 | * \author Hyunsu Cho
5 | * \brief Utilities for the C API implementation: storage for values returned to callers
6 | */
7 | #ifndef SRC_C_API_C_API_UTILS_H_
8 | #define SRC_C_API_C_API_UTILS_H_
9 |
10 | #include
11 | #include
12 | #include
13 |
14 | #include
15 | #include
16 |
namespace treelite::c_api {

/*! \brief When returning a complex object from a C API function, we
 *         store the object here and then return a pointer. The
 *         storage is thread-local static storage. */
struct ReturnValueEntry {
  // One slot per kind of value handed back to callers
  std::string ret_str;
  std::vector ret_uint32_vec;
  std::vector ret_uint64_vec;
  std::vector ret_frames;
};
/*! \brief Store through which C API functions access a ReturnValueEntry */
using ReturnValueStore = ThreadLocalStore;

}  // namespace treelite::c_api
31 |
32 | #endif // SRC_C_API_C_API_UTILS_H_
33 |
--------------------------------------------------------------------------------
/ops/cpp-python-coverage.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Coverage job: build Treelite with coverage instrumentation, run the C++ and
# Python test suites, then collect lcov data and submit it to CodeCov.

set -euo pipefail

echo "##[section]Installing lcov and Ninja..."
# -y: answer prompts automatically; without it apt-get aborts when it needs
# confirmation and no terminal is attached
sudo apt-get install -y lcov ninja-build

echo "##[section]Building Treelite..."
mkdir build/
cd build/
cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug -DBUILD_CPP_TEST=ON -GNinja
ninja install -v
cd ..

echo "##[section]Running Google C++ tests..."
./build/treelite_cpp_test

echo "##[section]Running Python integration tests..."
# Import the package straight from the source tree
export PYTHONPATH='./python'
python -m pytest --cov=treelite -v -rxXs --fulltrace --durations=0 tests/python tests/serializer

echo "##[section]Collecting coverage data..."
lcov --directory . --capture --output-file coverage.info
# Drop system headers and fetched dependencies from the report
lcov --remove coverage.info '*/usr/*' --output-file coverage.info
lcov --remove coverage.info '*/build/_deps/*' --output-file coverage.info

echo "##[section]Submitting code coverage data to CodeCov..."
# NOTE(review): this downloads a script and executes it without verification;
# consider a pinned, checksum-verified uploader. Failure here is deliberately non-fatal.
bash <(curl -s https://codecov.io/bash) -X gcov || echo "Codecov did not collect coverage reports"
29 |
--------------------------------------------------------------------------------
/python/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | ignore-paths=tests/cython,docs
3 | extension-pkg-whitelist=numpy
4 |
5 | load-plugins=pylint.extensions.no_self_use
6 |
7 | disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object,attribute-defined-outside-init,unbalanced-tuple-unpacking,too-many-lines,duplicate-code,too-many-arguments
8 |
9 | dummy-variables-rgx=(unused|)_.*
10 |
11 | [BASIC]
12 |
13 | # Enforce naming convention
14 | const-naming-style=UPPER_CASE
15 | class-naming-style=PascalCase
16 | function-naming-style=snake_case
17 | method-naming-style=snake_case
18 | attr-naming-style=snake_case
19 | argument-naming-style=snake_case
20 | variable-naming-style=snake_case
21 | class-attribute-naming-style=snake_case
22 |
23 | # Allow single-letter variables
24 | variable-rgx=[a-zA-Z_][a-z0-9_]{0,30}$
25 | argument-rgx=[a-zA-Z_][a-z0-9_]{0,30}$
26 |
27 | [TYPECHECK]
28 | generated-members=np.float32,np.uintc,np.uintp,np.uint32
29 |
30 | [MESSAGES CONTROL]
31 | # globally disable pylint checks (comma separated)
32 | disable=fixme
33 |
--------------------------------------------------------------------------------
/src/c_api/c_api_error.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017-2023 by Contributors
3 | * \file c_api_error.cc
4 | * \author Hyunsu Cho
5 | * \brief C error handling
6 | */
7 | #include
8 |
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | namespace treelite::c_api {
15 |
16 | struct APIErrorEntry {
17 | std::string last_error;
18 | std::string version_str;
19 | };
20 |
21 | using APIErrorStore = ThreadLocalStore;
22 |
23 | } // namespace treelite::c_api
24 |
25 | char const* TreeliteGetLastError() {
26 | return treelite::c_api::APIErrorStore::Get()->last_error.c_str();
27 | }
28 |
29 | void TreeliteAPISetLastError(char const* msg) {
30 | treelite::c_api::APIErrorStore::Get()->last_error = msg;
31 | }
32 |
33 | char const* TreeliteQueryTreeliteVersion() {
34 | auto& version_str = treelite::c_api::APIErrorStore::Get()->version_str;
35 | version_str = TREELITE_VERSION_STR;
36 | return version_str.c_str();
37 | }
38 |
// Global string "TREELITE_VERSION_" followed by the version (adjacent string literals are
// concatenated at compile time).
// NOTE(review): presumably present so the version can be found in the built binary -- confirm.
char const* TREELITE_VERSION = "TREELITE_VERSION_" TREELITE_VERSION_STR;
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Treelite
2 |
3 | 
4 | [Documentation Status](http://treelite.readthedocs.io/en/latest/?badge=latest)
5 | [Code Coverage](https://codecov.io/gh/dmlc/treelite)
6 | [License](./LICENSE)
7 | [PyPI](https://pypi.python.org/pypi/treelite/)
8 | [Conda](https://anaconda.org/conda-forge/treelite)
9 |
10 | [Documentation](https://treelite.readthedocs.io/en/latest) |
11 | [Installation](http://treelite.readthedocs.io/en/latest/install.html) |
12 | [Release Notes](NEWS.md) |
13 | [Acknowledgements](ACKNOWLEDGMENTS.md)
14 |
15 | **Treelite** is a universal model exchange and serialization format for
16 | decision tree forests. Treelite aims to be a small library that enables
17 | other C++ applications to exchange and store decision trees on disk
18 | as well as over the network.
19 |
--------------------------------------------------------------------------------
/tests/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
# C++ unit test executable. Sources are attached further down with target_sources().
add_executable(treelite_cpp_test)
set_target_properties(treelite_cpp_test
  PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED ON)
target_link_libraries(treelite_cpp_test
  PRIVATE objtreelite rapidjson
  GTest::gtest GTest::gmock fmt::fmt-header-only std::mdspan)
# Place the binary directly in the top-level build directory
set_output_directory(treelite_cpp_test ${PROJECT_BINARY_DIR})

if(MSVC)
  target_compile_options(treelite_cpp_test PRIVATE
    /utf-8 -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
endif()

# Coverage instrumentation is rejected outright when building with MSVC
if(TEST_COVERAGE)
  if(MSVC)
    message(FATAL_ERROR "Test coverage not available on Windows")
  endif()
  target_compile_options(treelite_cpp_test PUBLIC -g3 --coverage)
  target_link_options(treelite_cpp_test PUBLIC --coverage)
endif()

target_sources(treelite_cpp_test
  PRIVATE
  test_main.cc
  test_gtil.cc
  test_model_builder.cc
  test_model_concat.cc
  test_model_loader.cc
  test_serializer.cc
  test_utils.cc
)

# Lets the tests include headers relative to src/
target_include_directories(treelite_cpp_test
  PRIVATE ${PROJECT_SOURCE_DIR}/src/
)
38 |
--------------------------------------------------------------------------------
/src/enum/typeinfo.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017-2023 by Contributors
3 | * \file typeinfo.cc
4 | * \author Hyunsu Cho
5 | * \brief Utilities for TypeInfo enum
6 | */
7 |
8 | #include
9 |
10 | #include
11 | #include
12 |
13 | namespace treelite {
14 |
15 | std::string TypeInfoToString(treelite::TypeInfo info) {
16 | switch (info) {
17 | case treelite::TypeInfo::kInvalid:
18 | return "invalid";
19 | case treelite::TypeInfo::kUInt32:
20 | return "uint32";
21 | case treelite::TypeInfo::kFloat32:
22 | return "float32";
23 | case treelite::TypeInfo::kFloat64:
24 | return "float64";
25 | default:
26 | TREELITE_LOG(FATAL) << "Unrecognized type";
27 | return "";
28 | }
29 | }
30 |
31 | TypeInfo TypeInfoFromString(std::string const& str) {
32 | if (str == "uint32") {
33 | return TypeInfo::kUInt32;
34 | } else if (str == "float32") {
35 | return TypeInfo::kFloat32;
36 | } else if (str == "float64") {
37 | return TypeInfo::kFloat64;
38 | } else {
39 | TREELITE_LOG(FATAL) << "Unrecognized type: " << str;
40 | return TypeInfo::kInvalid;
41 | }
42 | }
43 |
44 | } // namespace treelite
45 |
--------------------------------------------------------------------------------
/src/enum/tree_node_type.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file tree_node_type.cc
4 | * \author Hyunsu Cho
5 | * \brief Utilities for TreeNodeType enum
6 | */
7 |
8 | #include
9 |
10 | #include
11 | #include
12 |
13 | namespace treelite {
14 |
15 | std::string TreeNodeTypeToString(TreeNodeType type) {
16 | switch (type) {
17 | case TreeNodeType::kLeafNode:
18 | return "leaf_node";
19 | case TreeNodeType::kNumericalTestNode:
20 | return "numerical_test_node";
21 | case TreeNodeType::kCategoricalTestNode:
22 | return "categorical_test_node";
23 | default:
24 | return "";
25 | }
26 | }
27 |
28 | TreeNodeType TreeNodeTypeFromString(std::string const& name) {
29 | if (name == "leaf_node") {
30 | return TreeNodeType::kLeafNode;
31 | } else if (name == "numerical_test_node") {
32 | return TreeNodeType::kNumericalTestNode;
33 | } else if (name == "categorical_test_node") {
34 | return TreeNodeType::kCategoricalTestNode;
35 | } else {
36 | TREELITE_LOG(FATAL) << "Unknown split type: " << name;
37 | return TreeNodeType::kLeafNode;
38 | }
39 | }
40 |
41 | } // namespace treelite
42 |
--------------------------------------------------------------------------------
/python/treelite/sklearn/isolation_forest.py:
--------------------------------------------------------------------------------
1 | """Utility functions for loading IsolationForest models"""
2 |
3 | import numpy as np
4 |
5 |
def harmonic(number):
    """Approximate the n-th harmonic number as ln(n) plus the Euler-Mascheroni constant"""
    log_term = np.log(number)
    return log_term + np.euler_gamma
9 |
10 |
def expected_depth(n_remainder):
    """Calculates the expected isolation depth for a remainder of uniform points

    Returns 0.0 when n_remainder is at most 1 and 1.0 when it equals 2. Any
    other value goes through 2 * harmonic(n - 1) - 2 * (n - 1) / n, converted
    to a Python float.
    """
    if n_remainder <= 1:
        return 0.0
    if n_remainder != 2:
        one_less = n_remainder - 1
        return float(2 * harmonic(one_less) - 2 * one_less / n_remainder)
    return 1.0
18 |
19 |
def calculate_depths(isolation_depths, tree, curr_node, curr_depth):
    """Fill in an array of isolation depths for a scikit-learn isolation forest model

    Walks the subtree rooted at curr_node recursively. A node whose
    children_left entry is -1 is treated as a leaf, and its slot in
    isolation_depths receives curr_depth plus expected_depth() of the node's
    sample count. Slots of internal nodes are left untouched.
    """
    if tree.children_left[curr_node] == -1:
        leaf_samples = tree.n_node_samples[curr_node]
        isolation_depths[curr_node] = curr_depth + expected_depth(leaf_samples)
        return

    child_depth = curr_depth + 1
    # Left subtree first, then right
    for children in (tree.children_left, tree.children_right):
        calculate_depths(isolation_depths, tree, children[curr_node], child_depth)
33 |
--------------------------------------------------------------------------------
/src/model_loader/detail/xgboost.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2020-2023 by Contributors
3 | * \file xgboost.h
4 | * \brief Helper functions for loading XGBoost models
5 | * \author William Hicks
6 | */
7 | #ifndef SRC_MODEL_LOADER_DETAIL_XGBOOST_H_
8 | #define SRC_MODEL_LOADER_DETAIL_XGBOOST_H_
9 |
10 | #include
11 | #include
12 | #include
13 |
14 | namespace treelite::model_loader::detail::xgboost {
15 |
/*! \brief Conversions of a base score from probability into margin score */
struct ProbToMargin {
  /*! \brief Compute -log(1 / base_score - 1) */
  static double Sigmoid(double base_score) {
    double const inverse_odds = 1.0 / base_score - 1.0;
    return -std::log(inverse_odds);
  }
  /*! \brief Compute log(base_score) */
  static double Exponential(double base_score) {
    double const margin = std::log(base_score);
    return margin;
  }
};
24 |
// Get the name of the post-processor that matches the given objective function
std::string GetPostProcessor(std::string const& objective_name);

// Transform base score from probability into margin score
// NOTE(review): presumably dispatches on `postprocessor` to one of the ProbToMargin
// functions -- confirm in the implementation
double TransformBaseScoreToMargin(std::string const& postprocessor, double base_score);

// Parse base score from its string form. The unit tests exercise a bracketed,
// comma-separated list such as "[5.2008224E-1,4.665861E-1]" and a single number
// such as "4.9333417E-1".
std::vector ParseBaseScore(std::string const& str);

// Kind of a feature; the numeric values are fixed explicitly
enum FeatureType { kNumerical = 0, kCategorical = 1 };
35 |
36 | } // namespace treelite::model_loader::detail::xgboost
37 |
38 | #endif // SRC_MODEL_LOADER_DETAIL_XGBOOST_H_
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Python compiled
15 | *.pyc
16 |
17 | # Compiled Java classes
18 | *.class
19 | /runtime/java/treelite4j/target/
20 |
21 | # Auxiliary Java files
22 | /runtime/java/treelite4j/.classpath
23 | /runtime/java/treelite4j/.project
24 | /runtime/java/treelite4j/.settings/
25 | /runtime/java/treelite4j/coverage.info
26 |
27 | # Compiled Dynamic libraries
28 | *.so
29 | *.dylib
30 | *.dll
31 |
32 | # Fortran module files
33 | *.mod
34 | *.smod
35 |
36 | # Compiled Static libraries
37 | *.lai
38 | *.la
39 | *.a
40 | *.lib
41 |
42 | # build folder
43 | /build/
44 | /python/dist/
45 | /python/build/
46 | /python/treelite.egg-info/
47 | /runtime/python/dist/
48 | /runtime/python/build/
49 | /runtime/python/treelite_runtime.egg-info/
50 |
51 | # Executables
52 | *.exe
53 | *.out
54 | *.app
55 | /treelite
56 |
57 | # Mac system file
58 | .DS_Store
59 |
60 | # Python wheel binaries
61 | /python/dist/
62 |
63 | # Python cache
64 | __pycache__
65 |
66 | # symbol directories
67 | *.dSYM/
68 |
69 | # Project files
70 | /runtime/python/.idea/
71 | /python/.idea/
72 | /tests/python/.idea/
73 | /.idea/
74 | /.hypothesis
75 | /lint.py
76 |
--------------------------------------------------------------------------------
/src/enum/operator.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file operator.cc
4 | * \author Hyunsu Cho
5 | * \brief Utilities for Operator enum
6 | */
7 |
8 | #include
9 |
10 | #include
11 | #include
12 |
13 | namespace treelite {
14 |
15 | /*! \brief Get string representation of Operator */
16 | std::string OperatorToString(Operator op) {
17 | switch (op) {
18 | case Operator::kEQ:
19 | return "==";
20 | case Operator::kLT:
21 | return "<";
22 | case Operator::kLE:
23 | return "<=";
24 | case Operator::kGT:
25 | return ">";
26 | case Operator::kGE:
27 | return ">=";
28 | default:
29 | return "";
30 | }
31 | }
32 |
33 | /*! \brief Get Operator from string */
34 | Operator OperatorFromString(std::string const& name) {
35 | if (name == "==") {
36 | return Operator::kEQ;
37 | } else if (name == "<") {
38 | return Operator::kLT;
39 | } else if (name == "<=") {
40 | return Operator::kLE;
41 | } else if (name == ">") {
42 | return Operator::kGT;
43 | } else if (name == ">=") {
44 | return Operator::kGE;
45 | } else {
46 | TREELITE_LOG(FATAL) << "Unknown operator: " << name;
47 | return Operator::kNone;
48 | }
49 | }
50 |
51 | } // namespace treelite
52 |
--------------------------------------------------------------------------------
/tests/cpp/test_model_loader.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file test_model_loader.cc
4 | * \author Hyunsu Cho
5 | * \brief C++ tests for model loader
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 |
12 | #include
13 |
14 | #include "model_loader/detail/string_utils.h"
15 | #include "model_loader/detail/xgboost.h"
16 |
17 | TEST(ModelLoader, StringTrim) {
18 | std::string s{"foobar\r\n"};  // input ends with a CR LF pair
19 | treelite::model_loader::detail::StringTrimFromEnd(s);  // expected to modify s itself
20 | EXPECT_EQ(s, "foobar");  // the trailing "\r\n" must be gone
21 | }
22 |
23 | TEST(ModelLoader, StringStartsWith) {
24 | std::string s{"foobar"};
25 | EXPECT_TRUE(treelite::model_loader::detail::StringStartsWith(s, "foo"));  // "foo" is a prefix of s
26 | }
27 |
28 | TEST(ModelLoader, XGBoostBaseScore) {
29 | {  // bracketed, comma-separated input: one parsed element per listed number is expected
30 | std::string s{"[5.2008224E-1,4.665861E-1]"};
31 | std::vector parsed = treelite::model_loader::detail::xgboost::ParseBaseScore(s);
32 | std::vector expected{5.2008224E-1, 4.665861E-1};
33 | EXPECT_EQ(parsed, expected);
34 | }
35 |
36 | {  // bare scalar input: a single-element result is expected
37 | std::string s{"4.9333417E-1"};
38 | std::vector parsed = treelite::model_loader::detail::xgboost::ParseBaseScore(s);
39 | std::vector expected{4.9333417E-1};
40 | EXPECT_EQ(parsed, expected);
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/cmake/Utils.cmake:
--------------------------------------------------------------------------------
1 | # Point the runtime, library and archive output directories of <target> at <dir>; the same directory is used for the default, Debug and Release properties
2 | function(set_output_directory target dir)
3 | set_target_properties(${target} PROPERTIES
4 | RUNTIME_OUTPUT_DIRECTORY ${dir} # for executable
5 | RUNTIME_OUTPUT_DIRECTORY_DEBUG ${dir}
6 | RUNTIME_OUTPUT_DIRECTORY_RELEASE ${dir}
7 | LIBRARY_OUTPUT_DIRECTORY ${dir} # for shared library
8 | LIBRARY_OUTPUT_DIRECTORY_DEBUG ${dir}
9 | LIBRARY_OUTPUT_DIRECTORY_RELEASE ${dir}
10 | ARCHIVE_OUTPUT_DIRECTORY ${dir} # for static library
11 | ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${dir}
12 | ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${dir}
13 | )
14 | endfunction(set_output_directory)
15 |
16 | # Choose default configurations: trim the stock four-entry configuration list to Debug;Release, or fall back to a Release build type when neither a build type nor a configuration list is set
17 | function(set_default_configuration_release)
18 | if(CMAKE_CONFIGURATION_TYPES STREQUAL "Debug;Release;MinSizeRel;RelWithDebInfo") # multiconfig generator?
19 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release" CACHE STRING "" FORCE) # overwrite the cached list
20 | elseif(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
21 | message(STATUS "Setting build type to 'Release' as none was specified.")
22 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE )
23 | endif()
24 | endfunction(set_default_configuration_release)
25 |
--------------------------------------------------------------------------------
/.github/workflows/windows-wheel-builder.yml:
--------------------------------------------------------------------------------
1 | name: windows-wheel-builder
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 | schedule:
10 | - cron: "0 7 * * *" # Run once daily
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | defaults:
16 | run:
17 | shell: cmd /C CALL {0}
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | env:
24 | COMMIT_ID: ${{ github.sha }}
25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }}
26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
27 |
28 | jobs:
29 | windows-wheel-builder:
30 | name: Build and test Python wheels (Windows)
31 | runs-on: windows-latest
32 | steps:
33 | - uses: actions/checkout@v4
34 | - uses: conda-incubator/setup-miniconda@v3
35 | with:
36 | miniforge-variant: Miniforge3
37 | miniforge-version: latest
38 | conda-remove-defaults: "true"
39 | activate-environment: dev # name matches ops/conda_env/dev.yml, as in the other wheel builders
40 | environment-file: ops/conda_env/dev.yml
41 | use-mamba: true
42 | - name: Build wheel
43 | run: |
44 | call ops/build-windows.bat
45 | - name: Test wheel
46 | run: |
47 | call ops/test-win-python-wheel.bat
48 |
--------------------------------------------------------------------------------
/src/enum/task_type.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file task_type.cc
4 | * \author Hyunsu Cho
5 | * \brief Utilities for TaskType enum
6 | */
7 |
8 | #include
9 |
10 | #include
11 | #include
12 |
13 | namespace treelite {
14 |
15 | std::string TaskTypeToString(TaskType type) {
16 | // Each enumerator is spelled exactly like its C++ identifier; any other value yields ""
17 | if (type == TaskType::kBinaryClf) {
18 | return "kBinaryClf";
19 | } else if (type == TaskType::kRegressor) {
20 | return "kRegressor";
21 | } else if (type == TaskType::kMultiClf) {
22 | return "kMultiClf";
23 | } else if (type == TaskType::kLearningToRank) {
24 | return "kLearningToRank";
25 | } else if (type == TaskType::kIsolationForest) {
26 | return "kIsolationForest";
27 | } else {
28 | return "";
29 | }
30 | }
31 |
32 | TaskType TaskTypeFromString(std::string const& str) {  // accepts the enumerator names, e.g. "kRegressor"
33 | if (str == "kBinaryClf") {
34 | return TaskType::kBinaryClf;
35 | } else if (str == "kRegressor") {
36 | return TaskType::kRegressor;
37 | } else if (str == "kMultiClf") {
38 | return TaskType::kMultiClf;
39 | } else if (str == "kLearningToRank") {
40 | return TaskType::kLearningToRank;
41 | } else if (str == "kIsolationForest") {
42 | return TaskType::kIsolationForest;
43 | } else {
44 | TREELITE_LOG(FATAL) << "Unknown task type: " << str;  // any other spelling is rejected
45 | return TaskType::kBinaryClf; // to avoid compiler warning
46 | }
47 | }
48 |
49 | } // namespace treelite
50 |
--------------------------------------------------------------------------------
/tests/example_app/example.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file example.cc
4 | * \brief Test using Treelite as a C++ library
5 | */
6 | #include
7 | #include
8 | #include
9 |
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 |
17 | int main(void) {  // builds a one-tree model through the model builder API and prints its tree count
18 | std::cout << "TREELITE_VERSION = " << TREELITE_VERSION << std::endl;
19 | auto builder = treelite::model_builder::GetModelBuilder(treelite::TypeInfo::kFloat32,
20 | treelite::TypeInfo::kFloat32,
21 | treelite::model_builder::Metadata{2, treelite::TaskType::kRegressor, false, 1, {1}, {1, 1}},
22 | treelite::model_builder::TreeAnnotation{1, {0}, {0}},
23 | treelite::model_builder::PostProcessorFunc{"identity"}, std::vector{0.0});
24 | builder->StartTree();  // the single tree consists of nodes 0, 1 and 2
25 | builder->StartNode(0);  // node 0: numerical test (the trailing 1, 2 presumably name its children)
26 | builder->NumericalTest(0, 0.0, true, treelite::Operator::kLT, 1, 2);
27 | builder->EndNode();
28 | builder->StartNode(1);  // node 1: leaf with scalar output -1.0
29 | builder->LeafScalar(-1.0);
30 | builder->EndNode();
31 | builder->StartNode(2);  // node 2: leaf with scalar output 1.0
32 | builder->LeafScalar(1.0);
33 | builder->EndNode();
34 | builder->EndTree();
35 |
36 | auto model = builder->CommitModel();
37 | std::cout << model->GetNumTree() << std::endl;  // print the number of trees in the committed model
38 | return 0;
39 | }
40 |
--------------------------------------------------------------------------------
/src/c_api/field_accessor.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file field_accessor.cc
4 | * \author Hyunsu Cho
5 | * \brief C API for accessing fields in Treelite model
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 |
12 | int TreeliteGetHeaderField(
13 | TreeliteModelHandle model, char const* name, TreelitePyBufferFrame* out_frame) {
14 | API_BEGIN();
15 | auto* model_ = static_cast(model);  // recover the model object behind the opaque handle
16 | *out_frame = model_->GetHeaderField(name);  // the header field named `name` is written to *out_frame
17 | API_END();
18 | }
19 |
20 | int TreeliteGetTreeField(TreeliteModelHandle model, uint64_t tree_id, char const* name,
21 | TreelitePyBufferFrame* out_frame) {
22 | API_BEGIN();
23 | auto* model_ = static_cast(model);  // recover the model object behind the opaque handle
24 | *out_frame = model_->GetTreeField(tree_id, name);  // field `name` of tree `tree_id` is written to *out_frame
25 | API_END();
26 | }
27 |
28 | int TreeliteSetHeaderField(
29 | TreeliteModelHandle model, char const* name, TreelitePyBufferFrame frame) {
30 | API_BEGIN();
31 | auto* model_ = static_cast(model);  // recover the model object behind the opaque handle
32 | model_->SetHeaderField(name, frame);  // forward the caller-supplied frame for header field `name`
33 | API_END();
34 | }
35 |
36 | int TreeliteSetTreeField(
37 | TreeliteModelHandle model, uint64_t tree_id, char const* name, TreelitePyBufferFrame frame) {
38 | API_BEGIN();
39 | auto* model_ = static_cast(model);  // recover the model object behind the opaque handle
40 | model_->SetTreeField(tree_id, name, frame);  // forward the frame for field `name` of tree `tree_id`
41 | API_END();
42 | }
43 |
--------------------------------------------------------------------------------
/.github/workflows/cpack-builder.yml:
--------------------------------------------------------------------------------
1 | name: cpack-builder
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 | schedule:
10 | - cron: "0 7 * * *" # Run once daily
11 |
12 | permissions:
13 | contents: read
14 |
15 | defaults:
16 | run:
17 | shell: bash -l {0}
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | env:
24 | COMMIT_ID: ${{ github.sha }}
25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }}
26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
27 |
28 | jobs:
29 | cpack-builder:
30 | name: Build CPack
31 | runs-on: ${{ matrix.runner }}
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | include:
36 | - arch: aarch64
37 | runner: ubuntu-24.04-arm
38 | - arch: amd64
39 | runner: ubuntu-24.04
40 | steps:
41 | - uses: actions/checkout@v4
42 | - uses: conda-incubator/setup-miniconda@v3
43 | with:
44 | miniforge-variant: Miniforge3
45 | miniforge-version: latest
46 | conda-remove-defaults: "true"
47 | activate-environment: dev
48 | environment-file: ops/conda_env/dev.yml
49 | use-mamba: true
50 | - name: Build CPack
51 | run: |
52 | bash ops/build-cpack.sh ${{ matrix.arch }}
53 |
--------------------------------------------------------------------------------
/include/treelite/enum/typeinfo.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017-2023 by Contributors
3 | * \file typeinfo.h
4 | * \brief Defines enum type TypeInfo
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_ENUM_TYPEINFO_H_
9 | #define TREELITE_ENUM_TYPEINFO_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #include
17 |
18 | namespace treelite {
19 |
20 | /*! \brief Types used by thresholds and leaf outputs */
21 | enum class TypeInfo : std::uint8_t { kInvalid = 0, kUInt32 = 1, kFloat32 = 2, kFloat64 = 3 };
22 |
23 | /*! \brief Get string representation of TypeInfo */
24 | std::string TypeInfoToString(treelite::TypeInfo info);
25 |
26 | /*! \brief Get TypeInfo from string */
27 | TypeInfo TypeInfoFromString(std::string const& str);
28 |
29 | /*!
30 | * \brief Convert a template type into a type info
31 | * \tparam template type to be converted
32 | * \return TypeInfo corresponding to the template type arg; TypeInfo::kInvalid when no branch matches
33 | */
34 | template
35 | inline TypeInfo TypeInfoFromType() {
36 | if (std::is_same_v) {
37 | return TypeInfo::kUInt32;
38 | } else if (std::is_same_v) {
39 | return TypeInfo::kFloat32;
40 | } else if (std::is_same_v) {
41 | return TypeInfo::kFloat64;
42 | } else {
43 | return TypeInfo::kInvalid;
44 | }
45 | }
46 |
47 | } // namespace treelite
48 |
49 | #endif // TREELITE_ENUM_TYPEINFO_H_
50 |
--------------------------------------------------------------------------------
/src/gtil/output_shape.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file output_shape.cc
4 | * \author Hyunsu Cho
5 | * \brief Compute output shape for GTIL, so that callers can allocate sufficient space
6 | * to hold outputs.
7 | */
8 | #include
9 | #include
10 | #include
11 |
12 | #include
13 | #include
14 |
15 | namespace treelite::gtil {
16 |
17 | std::vector GetOutputShape(  // output dimensions for predicting num_row rows with the given config
18 | Model const& model, std::uint64_t num_row, Configuration const& config) {
19 | auto const num_tree = model.GetNumTree();
20 | auto const max_num_class = static_cast(  // largest entry among the first num_target class counts
21 | *std::max_element(model.num_class.Data(), model.num_class.Data() + model.num_target));
22 | switch (config.pred_kind) {
23 | case PredictKind::kPredictDefault:
24 | case PredictKind::kPredictRaw:
25 | if (model.num_target > 1) {
26 | return {num_row, static_cast(model.num_target), max_num_class};
27 | } else {
28 | return {num_row, 1, max_num_class};
29 | }
30 | case PredictKind::kPredictLeafID:
31 | return {num_row, num_tree};  // shape: (rows, trees)
32 | case PredictKind::kPredictPerTree:
33 | return {num_row, num_tree,
34 | static_cast(model.leaf_vector_shape[0]) * model.leaf_vector_shape[1]};
35 | default:
36 | TREELITE_LOG(FATAL) << "Unsupported prediction type: " << static_cast(config.pred_kind);
37 | return {};  // the rejected value is the prediction kind, hence the wording of the message above
38 | }
39 | }
40 |
41 | } // namespace treelite::gtil
42 |
--------------------------------------------------------------------------------
/tests/examples/deep_lightgbm/model.txt:
--------------------------------------------------------------------------------
1 | tree
2 | version=v4
3 | num_class=1
4 | num_tree_per_iteration=1
5 | label_index=0
6 | max_feature_idx=0
7 | objective=regression
8 | feature_names=this
9 | feature_infos=[0:100]
10 | tree_sizes=1119
11 |
12 | Tree=0
13 | num_leaves=32
14 | num_cat=0
15 | split_feature=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
16 | split_gain=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
17 | threshold=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
18 | decision_type=2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
19 | left_child=1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 -31
20 | right_child=-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -32
21 | leaf_value=31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 0 1
22 | leaf_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
23 | leaf_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
24 | internal_value=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
25 | internal_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
26 | internal_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
27 | is_linear=0
28 | shrinkage=1
29 |
30 |
31 | end of trees
32 |
33 | feature_importances:
34 | this=31
35 |
36 | pandas_categorical:null
37 |
--------------------------------------------------------------------------------
/src/model_query.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2024 by Contributors
3 | * \file model_query.cc
4 | * \author Hyunsu Cho
5 | * \brief Methods for querying various properties of tree models
6 | */
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | #include
14 |
15 | namespace {
16 |
17 | template <typename ThresholdType, typename LeafOutputType>
18 | std::uint32_t GetDepth(treelite::Tree<ThresholdType, LeafOutputType> const& tree) {
19 | // Count the levels of the tree (nodes on the longest root-to-leaf path; a lone leaf gives 1).
20 | // Walk level by level: a single running counter over a depth-first walk cannot be restored
21 | // when the walk backtracks into a sibling subtree, so it misreports one-sided (chain) trees.
22 | std::vector<int> frontier(1, 0);  // nodes of the current level; starts with the root (ID 0)
23 | std::uint32_t depth = 0;
24 | while (!frontier.empty()) {
25 | ++depth;
26 | std::vector<int> next_level;
27 | for (int node_id : frontier) {
28 | if (!tree.IsLeaf(node_id)) {
29 | next_level.push_back(tree.LeftChild(node_id));
30 | next_level.push_back(tree.RightChild(node_id));
31 | }
32 | }
33 | // Leaves contribute no children, so the walk stops after the deepest populated level
34 | frontier.swap(next_level);
35 | }
36 | return depth;
37 | }
38 |
39 | } // anonymous namespace
40 |
41 | namespace treelite {
42 |
43 | std::vector Model::GetTreeDepth() const {  // one GetDepth() result per tree, in the order of the tree list
44 | return std::visit(
45 | [](auto&& concrete_model) {  // invoked with whichever alternative variant_ currently holds
46 | std::vector depth;
47 | depth.reserve(concrete_model.trees.size());  // final size is known up front
48 | for (auto const& tree : concrete_model.trees) {
49 | depth.push_back(GetDepth(tree));
50 | }
51 | return depth;
52 | },
53 | variant_);
54 | }
55 |
56 | } // namespace treelite
57 |
--------------------------------------------------------------------------------
/tests/ci_build/build_macos_python_wheels.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | if [[ $# -ne 2 ]]; then
7 | echo "Usage: $0 [platform_id] [commit ID]"
8 | exit 1
9 | fi
10 |
11 | platform_id=$1
12 | shift
13 | commit_id=$1
14 | shift
15 |
16 | if [[ "$platform_id" == macosx_* ]]; then
17 | if [[ "$platform_id" == macosx_arm64 ]]; then
18 | # MacOS, Apple Silicon
19 | wheel_tag=macosx_12_0_arm64
20 | cpython_ver=310
21 | cibw_archs=arm64
22 | export MACOSX_DEPLOYMENT_TARGET=12.0
23 | elif [[ "$platform_id" == macosx_x86_64 ]]; then
24 | # MacOS, Intel
25 | wheel_tag=macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64
26 | cpython_ver=310
27 | cibw_archs=x86_64
28 | export MACOSX_DEPLOYMENT_TARGET=10.15
29 | else
30 | echo "Platform not supported: $platform_id"
31 | exit 3
32 | fi
33 | # Set up environment variables to configure cibuildwheel
34 | export CIBW_BUILD=cp${cpython_ver}-${platform_id}
35 | export CIBW_ARCHS=${cibw_archs}
36 | export CIBW_TEST_SKIP='*-macosx_arm64'
37 | export CIBW_BUILD_VERBOSITY=3
38 | else
39 | echo "Platform not supported: $platform_id"
40 | exit 2
41 | fi
42 |
43 | # Tell delocate-wheel to not vendor libomp.dylib into the wheel
44 | export CIBW_REPAIR_WHEEL_COMMAND_MACOS="delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} --exclude libomp.dylib"
45 |
46 | python -m cibuildwheel python --output-dir wheelhouse  # build the package in ./python into ./wheelhouse
47 | python tests/ci_build/rename_whl.py wheelhouse "${commit_id}" "${wheel_tag}"  # quoted so each value stays one argument
48 |
--------------------------------------------------------------------------------
/docs/treelite-api.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Treelite API
3 | ============
4 |
5 | API of Treelite Python package.
6 |
7 | .. contents::
8 | :local:
9 |
10 | Model loaders
11 | -------------
12 |
13 | .. automodule:: treelite.frontend
14 | :members:
15 | :member-order: bysource
16 |
17 | Scikit-learn importer
18 | ---------------------
19 |
20 | .. automodule:: treelite.sklearn
21 | :members:
22 | :member-order: bysource
23 |
24 | Model builder
25 | -------------
26 |
27 | .. automodule:: treelite.model_builder
28 | :members:
29 | :member-order: bysource
30 |
31 | Model class
32 | -----------
33 |
34 | .. autoclass:: treelite.Model
35 | :members:
36 | :member-order: bysource
37 | :exclude-members: load, from_xgboost, from_xgboost_json, from_lightgbm
38 |
39 |
40 | .. _field_accessors:
41 |
42 | Field accessors (Advanced)
43 | --------------------------
44 | Using field accessors, users can query and modify the value of fields in a :py:class:`~treelite.Model` object.
45 | See :doc:`/tutorials/edit` for more details.
46 |
47 | .. note:: Modifying a field is an unsafe operation
48 |
49 | Treelite does not prevent users from assigning an invalid value to a field. Setting an invalid value may
50 | cause undefined behavior. Always consult :doc:`the model spec </serialization/index>` to carefully examine
51 | model invariants and constraints on fields. For example, most tree fields must have an array of length ``num_nodes``.
52 |
53 | .. autoclass:: treelite.model.HeaderAccessor
54 | :members:
55 |
56 | .. autoclass:: treelite.model.TreeAccessor
57 | :members:
58 |
--------------------------------------------------------------------------------
/tests/ci_build/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This script is a wrapper that creates, inside the container, the same user as the
4 | # one running ci_build.sh outside the container. It also sets the home directory of
5 | # that user to /home/<user name> and copies the contents of /root into it.
6 | # Do not run this script manually; it only makes sense when it is
7 | # called by ci_build.sh.
8 |
9 | set -e
10 |
11 | COMMAND=("$@")
12 |
13 | if ! touch /this_is_writable_file_system; then
14 | echo "You can't write to your filesystem!"
15 | echo "If you are in Docker you should check you do not have too many images" \
16 | "with too many files in them. Docker has some issue with it."
17 | exit 1
18 | else
19 | rm /this_is_writable_file_system
20 | fi
21 |
22 | if [[ -n $CI_BUILD_UID ]] && [[ -n $CI_BUILD_GID ]]; then
23 | groupadd -o -g "${CI_BUILD_GID}" "${CI_BUILD_GROUP}"
24 | useradd -o -m -g "${CI_BUILD_GID}" -u "${CI_BUILD_UID}" \
25 | "${CI_BUILD_USER}"
26 | export HOME="/home/${CI_BUILD_USER}"
27 | shopt -s dotglob
28 | cp -r /root/* "$HOME/"
29 | chown -R "${CI_BUILD_UID}:${CI_BUILD_GID}" "$HOME"
30 |
31 | # Allows project-specific customization
32 | if [[ -e "/workspace/.pre_entry.sh" ]]; then
33 | gosu "${CI_BUILD_UID}:${CI_BUILD_GID}" /workspace/.pre_entry.sh
34 | fi
35 |
36 | # Enable passwordless sudo capabilities for the user
37 | chown root:"${CI_BUILD_GID}" "$(which gosu)"
38 | chmod +s "$(which gosu)"; sync
39 |
40 | exec gosu "${CI_BUILD_UID}:${CI_BUILD_GID}" "${COMMAND[@]}"
41 | else
42 | exec "${COMMAND[@]}"
43 | fi
44 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit-update.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit-update
2 |
3 | on:
4 | workflow_dispatch:
5 | schedule:
6 | - cron: "0 20 * * 1" # Run once weekly
7 |
8 | permissions:
9 | pull-requests: write
10 | contents: write
11 |
12 | defaults:
13 | run:
14 | shell: bash -l {0}
15 |
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
18 | cancel-in-progress: true
19 |
20 | jobs:
21 | pre-commit-update:
22 | name: Auto-update pre-commit hooks
23 | runs-on: ubuntu-latest
24 | steps:
25 | - uses: actions/create-github-app-token@v1
26 | id: generate-token
27 | with:
28 | app-id: ${{ secrets.TOKEN_GENERATOR_APP_ID }}
29 | private-key: ${{ secrets.TOKEN_GENERATOR_APP_PRIVATE_KEY }}
30 | - uses: actions/checkout@v4
31 | - uses: conda-incubator/setup-miniconda@v3
32 | with:
33 | miniforge-variant: Miniforge3
34 | miniforge-version: latest
35 | conda-remove-defaults: "true"
36 | activate-environment: precommit
37 | environment-file: ops/conda_env/pre-commit.yml
38 | use-mamba: true
39 | - name: Update pre-commit hooks
40 | run: |
41 | pre-commit autoupdate
42 | - name: Create Pull Request
43 | uses: peter-evans/create-pull-request@v7
44 | if: github.ref == 'refs/heads/mainline'
45 | with:
46 | branch: create-pull-request/pre-commit-update
47 | base: mainline
48 | title: "[CI] Update pre-commit hooks"
49 | commit-message: "[CI] Update pre-commit hooks"
50 | token: ${{ steps.generate-token.outputs.token }}
51 |
--------------------------------------------------------------------------------
/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "hatchling>=1.12.1"
4 | ]
5 | backend-path = ["."]
6 | build-backend = "packager.pep517"
7 |
8 | [project]
9 | name = "treelite"
10 | version = "4.7.0.dev0"
11 | authors = [
12 | {name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu"}
13 | ]
14 | description = "Treelite: Universal model exchange format for decision tree forests"
15 | readme = {file = "README.rst", content-type = "text/x-rst"}
16 | requires-python = ">=3.8"
17 | license = {text = "Apache-2.0"}
18 | classifiers = [
19 | "License :: OSI Approved :: Apache Software License",
20 | "Development Status :: 5 - Production/Stable",
21 | "Operating System :: OS Independent",
22 | "Programming Language :: Python",
23 | "Programming Language :: Python :: 3",
24 | "Programming Language :: Python :: 3.8",
25 | "Programming Language :: Python :: 3.9",
26 | "Programming Language :: Python :: 3.10"
27 | ]
28 | dependencies = [
29 | "numpy",
30 | "scipy",
31 | "packaging"
32 | ]
33 |
34 | [project.urls]
35 | documentation = "https://treelite.readthedocs.io/en/latest/"
36 | repository = "https://github.com/dmlc/treelite"
37 |
38 | [project.optional-dependencies]
39 | scikit-learn = ["scikit-learn"]
40 | testing = ["scikit-learn", "pytest", "hypothesis", "pandas"]
41 |
42 | [tool.mypy]
43 | plugins = "numpy.typing.mypy_plugin"
44 |
45 | [tool.hatch.build.targets.wheel.hooks.custom]
46 |
47 | [tool.ruff]
48 | line-length = 120
49 |
50 | # this should be set to the oldest version of python treelite supports
51 | target-version = "py38"
52 |
53 | [tool.ruff.lint]
54 | select = [
55 | # numpy 2.0 deprecations/removals
56 | "NPY201",
57 | ]
58 |
--------------------------------------------------------------------------------
/src/gtil/config.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file config.cc
4 | * \author Hyunsu Cho
5 | * \brief Configuration handling logic for GTIL
6 | */
7 | #include
8 |
9 | #include
10 | #include
11 |
12 | #include
13 |
14 | namespace treelite::gtil {
15 |
16 | Configuration::Configuration(std::string const& config_json) {  // fills pred_kind and nthread from JSON
17 | rapidjson::Document parsed_config;
18 | parsed_config.Parse(config_json);
19 |
20 | if (parsed_config.IsObject()) {
21 | auto itr = parsed_config.FindMember("predict_type");  // required string field
22 | if (itr != parsed_config.MemberEnd() && itr->value.IsString()) {
23 | auto value = std::string(itr->value.GetString());
24 | if (value == "default") {
25 | this->pred_kind = PredictKind::kPredictDefault;
26 | } else if (value == "raw") {
27 | this->pred_kind = PredictKind::kPredictRaw;
28 | } else if (value == "leaf_id") {
29 | this->pred_kind = PredictKind::kPredictLeafID;
30 | } else if (value == "score_per_tree") {
31 | this->pred_kind = PredictKind::kPredictPerTree;
32 | } else {
33 | TREELITE_LOG(FATAL) << "Unknown prediction type: " << value;
34 | }
35 | } else {  // the field is absent or is not a string
36 | TREELITE_LOG(FATAL) << "The field \"predict_type\" must be specified";
37 | }
38 | itr = parsed_config.FindMember("nthread");  // optional field; nthread is left untouched when absent
39 | if (itr != parsed_config.MemberEnd()) {
40 | TREELITE_CHECK(itr->value.IsInt()) << "nthread must be an integer";
41 | this->nthread = itr->value.GetInt();
42 | }
43 | } else {  // the parsed document is not a JSON object
44 | TREELITE_LOG(FATAL) << "The JSON string must be a valid JSON object";
45 | }
46 | }
47 |
48 | } // namespace treelite::gtil
49 |
--------------------------------------------------------------------------------
/include/treelite/detail/omp_exception.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2022-2023 by Contributors
3 | * \file omp_exception.h
4 | * \author Hyunsu Cho
5 | * \brief Utility to propagate exceptions thrown inside an OpenMP block
6 | */
7 | #ifndef TREELITE_DETAIL_OMP_EXCEPTION_H_
8 | #define TREELITE_DETAIL_OMP_EXCEPTION_H_
9 |
10 | #include
11 | #include
12 |
13 | #include
14 |
15 | namespace treelite {
16 |
17 | /*!
18 | * \brief Captures an exception thrown inside Run() so that Rethrow() can raise it again later
19 | */
20 | class OMPException {
21 | private:
22 | // holds the first captured exception; stays empty while nothing has been caught
23 | std::exception_ptr omp_exception_;
24 | // acquired in every catch block before omp_exception_ is inspected or assigned
25 | std::mutex mutex_;
26 |
27 | public:
28 | /*!
29 | * \brief Call f(params...) and store the exception if it throws treelite::Error or std::exception
30 | */
31 | template
32 | void Run(Function f, Parameters... params) {
33 | try {
34 | f(params...);
35 | } catch (treelite::Error& ex) {
36 | std::lock_guard lock(mutex_);
37 | if (!omp_exception_) {  // keep only the first exception; later ones are dropped
38 | omp_exception_ = std::current_exception();
39 | }
40 | } catch (std::exception& ex) {
41 | std::lock_guard lock(mutex_);
42 | if (!omp_exception_) {  // same first-one-wins rule as above
43 | omp_exception_ = std::current_exception();
44 | }
45 | }
46 | }
47 |
48 | /*!
49 | * \brief Rethrow the stored exception, if any; intended to be called from the main thread
50 | */
51 | void Rethrow() {
52 | if (this->omp_exception_) {
53 | std::rethrow_exception(this->omp_exception_);
54 | }
55 | }
56 | };
57 |
58 | } // namespace treelite
59 |
60 | #endif // TREELITE_DETAIL_OMP_EXCEPTION_H_
61 |
--------------------------------------------------------------------------------
/.github/workflows/linux-wheel-builder.yml:
--------------------------------------------------------------------------------
1 | name: linux-wheel-builder
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 | schedule:
10 | - cron: "0 7 * * *" # Run once daily
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | defaults:
16 | run:
17 | shell: bash -l {0}
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | env:
24 | COMMIT_ID: ${{ github.sha }}
25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }}
26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
27 |
28 | jobs:
29 | linux-wheel-builder:
30 | name: Build and test Python wheels (Linux)
31 | runs-on: ${{ matrix.os }}
32 | strategy:
33 | matrix:
34 | include:
35 | - os: ubuntu-latest
36 | build-script: ops/build-linux.sh
37 | - os: ubuntu-22.04-arm
38 | build-script: ops/build-linux-aarch64.sh
39 |
40 | steps:
41 | - uses: actions/checkout@v4
42 | - uses: conda-incubator/setup-miniconda@v3
43 | with:
44 | miniforge-variant: Miniforge3
45 | miniforge-version: latest
46 | conda-remove-defaults: "true"
47 | activate-environment: dev
48 | environment-file: ops/conda_env/dev.yml
49 | use-mamba: true
50 |
51 | - name: Display Conda env
52 | run: |
53 | conda info
54 | conda list
55 |
56 | - name: Build wheel
57 | run: |
58 | bash ${{ matrix.build-script }}
59 |
60 | - name: Test wheel
61 | run: |
62 | bash ops/test-linux-python-wheel.sh
63 |
--------------------------------------------------------------------------------
/ops/test-serializer-compatibility.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -euo pipefail
4 |
5 | echo "##[section]Building Treelite..."
6 | mkdir -p build
7 | cd build
8 | cmake .. -GNinja
9 | ninja
10 | cd ..
11 |
12 | CURRENT_VERSION=$(cat python/treelite/VERSION)
13 |
14 | echo "##[section]Testing serialization: 3.9 -> ${CURRENT_VERSION}"
15 | pip install --force-reinstall treelite==3.9.0 treelite_runtime==3.9.0
16 | python tests/serializer/compatibility_tester.py --task save --checkpoint-path checkpoint.bin \
17 | --model-pickle-path model.pkl --expected-treelite-version 3.9.0
18 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \
19 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \
20 | --expected-treelite-version ${CURRENT_VERSION}
21 |
22 | echo "##[section]Testing serialization: 4.3.0 -> ${CURRENT_VERSION}"
23 | pip install --force-reinstall treelite==4.3.0
24 | python tests/serializer/compatibility_tester.py --task save --checkpoint-path checkpoint.bin \
25 | --model-pickle-path model.pkl --expected-treelite-version 4.3.0
26 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \
27 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \
28 | --expected-treelite-version ${CURRENT_VERSION}
29 |
30 | echo "##[section]Testing serialization: ${CURRENT_VERSION} -> ${CURRENT_VERSION}"
31 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task save \
32 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \
33 | --expected-treelite-version ${CURRENT_VERSION}
34 | PYTHONPATH=./python/ python tests/serializer/compatibility_tester.py --task load \
35 | --checkpoint-path checkpoint.bin --model-pickle-path model.pkl \
36 | --expected-treelite-version ${CURRENT_VERSION}
37 |
--------------------------------------------------------------------------------
/.github/workflows/macos-wheel-builder.yml:
--------------------------------------------------------------------------------
1 | name: macos-wheel-builder
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 | schedule:
10 | - cron: "0 7 * * *" # Run once daily
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | defaults:
16 | run:
17 | shell: bash -l {0}
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | env:
24 | COMMIT_ID: ${{ github.sha }}
25 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_IAM_S3_UPLOADER }}
26 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
27 |
28 | jobs:
29 | macos-wheel-builder:
30 | name: Build and test Python wheels (MacOS)
31 | runs-on: ${{ matrix.os }}
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | include:
36 | - os: macos-15-intel
37 | platform_id: macosx_x86_64
38 | - os: macos-14
39 | platform_id: macosx_arm64
40 | env:
41 | CIBW_PLATFORM_ID: ${{ matrix.cibw_platform_id }}
42 | steps:
43 | - uses: actions/checkout@v4
44 | - uses: conda-incubator/setup-miniconda@v3
45 | with:
46 | miniforge-variant: Miniforge3
47 | miniforge-version: latest
48 | conda-remove-defaults: "true"
49 | activate-environment: dev
50 | environment-file: ops/conda_env/dev.yml
51 | use-mamba: true
52 | - name: Display Conda env
53 | run: |
54 | conda info
55 | conda list
56 | - name: Build wheel
57 | run: |
58 | bash ops/build-macos.sh ${{ matrix.platform_id }} ${{ github.sha }}
59 | - name: Test wheel
60 | run: |
61 | bash ops/test-macos-python-wheel.sh
62 |
--------------------------------------------------------------------------------
/include/treelite/c_api_error.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017-2023 by Contributors
3 | * \file c_api_error.h
4 | * \author Hyunsu Cho
5 | * \brief Error handling for C API.
6 | */
7 | #ifndef TREELITE_C_API_ERROR_H_
8 | #define TREELITE_C_API_ERROR_H_
9 |
10 | #include
11 |
12 | /*! \brief Macro to guard beginning and end section of all functions */
13 | #define API_BEGIN() try {
14 | /*! \brief Every function starts with API_BEGIN();
15 | and finishes with API_END() or API_END_HANDLE_ERROR */
16 | #define API_END() \
17 | } \
18 | catch (std::exception & _except_) { \
19 | return TreeliteAPIHandleException(_except_); \
20 | } \
21 | return 0
22 | /*!
23 | * \brief Every function starts with API_BEGIN();
24 | * and finishes with API_END() or API_END_HANDLE_ERROR()
25 | * "Finalize" contains procedure to cleanup states when an error happens
26 | */
27 | #define API_END_HANDLE_ERROR(Finalize) \
28 | } \
29 | catch (std::exception & _except_) { \
30 | Finalize; \
31 | return TreeliteAPIHandleException(_except_); \
32 | } \
33 | return 0
34 |
35 | /*!
36 | * \brief Set the last error message needed by C API
37 | * \param msg Error message to set.
38 | */
39 | void TreeliteAPISetLastError(char const* msg);
40 | /*!
41 | * \brief handle Exception thrown out
42 | * \param e Exception object
43 | * \return The return value of API after exception is handled
44 | */
45 | inline int TreeliteAPIHandleException(std::exception const& e) {
46 | TreeliteAPISetLastError(e.what());
47 | return -1;
48 | }
49 | #endif // TREELITE_C_API_ERROR_H_
50 |
--------------------------------------------------------------------------------
/tests/cpp/test_utils.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file test_utils.cc
4 | * \author Hyunsu Cho
5 | * \brief C++ tests for utility functions
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | #include
14 |
15 | #include
16 |
17 | TEST(FileUtils, StreamIO) {
18 | std::string s{"Hello world"};
19 | std::string s2;
20 | std::filesystem::path tmpdir = std::filesystem::temp_directory_path();
21 | std::filesystem::path filepath = tmpdir / std::filesystem::u8path("ななひら.txt");
22 |
23 | {
24 | std::ofstream ofs = treelite::detail::OpenFileForWriteAsStream(filepath);
25 | ofs.write(s.data(), s.length());
26 | }
27 | {
28 | std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filepath);
29 | s2.resize(s.length());
30 | ifs.read(s2.data(), s.length());
31 | ASSERT_EQ(s, s2);
32 | }
33 |
34 | std::filesystem::remove(filepath);
35 | }
36 |
37 | TEST(FileUtils, OpenFileForReadAsFilePtr) {
38 | std::string s{"Hello world"};
39 | std::string s2;
40 | std::filesystem::path tmpdir = std::filesystem::temp_directory_path();
41 | std::filesystem::path filepath = tmpdir / std::filesystem::u8path("ななひら.txt");
42 |
43 | {
44 | std::ofstream ofs(filepath, std::ios::out | std::ios::binary);
45 | ASSERT_TRUE(ofs);
46 | ofs.exceptions(std::ios::failbit | std::ios::badbit);
47 | ofs.write(s.data(), s.length());
48 | }
49 | {
50 | FILE* fp = treelite::detail::OpenFileForReadAsFilePtr(filepath);
51 | ASSERT_TRUE(fp);
52 | s2.resize(s.length());
53 | ASSERT_EQ(std::fread(s2.data(), sizeof(char), s.length(), fp), s.length());
54 | ASSERT_EQ(s, s2);
55 | ASSERT_EQ(std::fclose(fp), 0);
56 | }
57 |
58 | std::filesystem::remove(filepath);
59 | }
60 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v6.0.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-added-large-files
11 | args: ["--maxkb=4000"]
12 | - repo: https://github.com/psf/black
13 | rev: 25.12.0
14 | hooks:
15 | - id: black
16 | - repo: https://github.com/pycqa/isort
17 | rev: 7.0.0
18 | hooks:
19 | - id: isort
20 | args: ["--profile", "black", "--filter-files"]
21 | - repo: https://github.com/MarcoGorelli/cython-lint
22 | rev: v0.18.1
23 | hooks:
24 | - id: cython-lint
25 | - id: double-quote-cython-strings
26 | - repo: https://github.com/PyCQA/flake8
27 | rev: 7.3.0
28 | hooks:
29 | - id: flake8
30 | args: [--config=.flake8]
31 | files: .*$
32 | types: [file]
33 | types_or: [python]
34 | additional_dependencies: [flake8-force]
35 | - repo: https://github.com/pocc/pre-commit-hooks
36 | rev: v1.3.5
37 | hooks:
38 | - id: clang-format
39 | args: ["-i", "--style=file:.clang-format"]
40 | language: python
41 | additional_dependencies: [clang-format>=15.0]
42 | types_or: [c++]
43 | - id: cpplint
44 | language: python
45 | args: [
46 | "--linelength=100", "--recursive",
47 | "--filter=-build/c++11,-build/include,-build/namespaces_literals,-runtime/references,-build/include_order,+build/include_what_you_use",
48 | "--root=include"]
49 | additional_dependencies: [cpplint==1.6.1]
50 | types_or: [c++]
51 | - repo: https://github.com/pre-commit/mirrors-mypy
52 | rev: v1.19.0
53 | hooks:
54 | - id: mypy
55 | additional_dependencies: [types-setuptools]
56 | - repo: https://github.com/astral-sh/ruff-pre-commit
57 | rev: v0.14.8
58 | hooks:
59 | - id: ruff
60 | args: ["--config", "python/pyproject.toml"]
61 |
--------------------------------------------------------------------------------
/cmake/Sanitizer.cmake:
--------------------------------------------------------------------------------
1 | # Set appropriate compiler and linker flags for sanitizers.
2 | #
3 | # Usage of this module:
4 | # enable_sanitizers("address;leak")
5 |
# Append the compiler flag(s) for one sanitizer to SAN_COMPILE_FLAGS.
#
# `sanitizer` is matched, in order, against "address", "thread", "leak" and
# "undefined"; any other value aborts the configure step with a fatal error.
macro(enable_sanitizer sanitizer)
  if(${sanitizer} MATCHES "address")
    set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=address")

  elseif(${sanitizer} MATCHES "thread")
    set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=thread")

  elseif(${sanitizer} MATCHES "leak")
    set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=leak")

  elseif(${sanitizer} MATCHES "undefined")
    set(SAN_COMPILE_FLAGS "${SAN_COMPILE_FLAGS} -fsanitize=undefined -fno-sanitize-recover=undefined")

  else()
    message(FATAL_ERROR "Sanitizer ${sanitizer} not supported.")
  endif()
endmacro()
24 |
# Enable every sanitizer named in the list SANITIZERS (names are lower-cased first).
#
# The thread sanitizer may not be combined with any other sanitizer in the list;
# such a combination aborts the configure step. The resulting flags are appended
# to both CMAKE_CXX_FLAGS and CMAKE_C_FLAGS.
macro(enable_sanitizers SANITIZERS)
  # Check sanitizers compatibility.
  foreach(_san ${SANITIZERS})
    string(TOLOWER ${_san} _san)
    if(_san MATCHES "thread")
      if(${_use_other_sanitizers})
        # _san is "thread" in this branch, so name the conflict generically
        # instead of interpolating ${_san} a second time.
        message(FATAL_ERROR
          "thread sanitizer is not compatible with other sanitizers.")
      endif()
      set(_use_thread_sanitizer 1)
    else()
      if(${_use_thread_sanitizer})
        message(FATAL_ERROR
          "${_san} sanitizer is not compatible with thread sanitizer.")
      endif()
      set(_use_other_sanitizers 1)
    endif()
  endforeach()

  message(STATUS "Sanitizers: ${SANITIZERS}")

  foreach(_san ${SANITIZERS})
    string(TOLOWER ${_san} _san)
    enable_sanitizer(${_san})
  endforeach()
  message(STATUS "Sanitizers compile flags: ${SAN_COMPILE_FLAGS}")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SAN_COMPILE_FLAGS}")
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SAN_COMPILE_FLAGS}")
endmacro()
54 |
--------------------------------------------------------------------------------
/python/packager/build_config.py:
--------------------------------------------------------------------------------
1 | """Build configuration"""
2 |
3 | import dataclasses
4 | from typing import Any, Dict, List, Optional
5 |
6 |
@dataclasses.dataclass
class BuildConfiguration:  # pylint: disable=R0902
    """Configurations used when building libtreelite

    Every field except ``system_libtreelite_dir`` is a boolean switch.
    ``update()`` fills the fields from the ``config_settings`` mapping of a
    PEP 517 frontend, and ``get_cmake_args()`` turns the switches that are
    forwarded to CMake into ``-D<OPTION>=ON|OFF`` arguments.
    """

    # Whether to enable OpenMP
    use_openmp: bool = True
    # Whether to hide C++ symbols
    hide_cxx_symbols: bool = True
    # Whether to use the Treelite library that's installed in the system prefix
    use_system_libtreelite: bool = False
    # Manually configure path for the Treelite library.
    # Only applicable when use_system_libtreelite=True
    system_libtreelite_dir: str = ""

    def _set_config_setting(self, config_settings: Dict[str, Any]) -> None:
        """Copy each entry of ``config_settings`` into the field of the same name.

        ``system_libtreelite_dir`` is stored as given. Every other field is a
        switch: a ``bool`` is stored unchanged, and any other value is enabled
        only if its string form is ``true``, ``1`` or ``on`` (case-insensitive).

        Raises
        ------
        ValueError
            If a key does not name a field of this dataclass. Without this
            check a misspelled option would be stored as a stray attribute and
            silently ignored.
        """
        valid_fields = {x.name for x in dataclasses.fields(self)}
        for field_name, config_value in config_settings.items():
            if field_name not in valid_fields:
                raise ValueError(
                    f"Unknown build option '{field_name}'. "
                    f"Valid options: {sorted(valid_fields)}"
                )
            if field_name == "system_libtreelite_dir" or isinstance(
                config_value, bool
            ):
                setattr(self, field_name, config_value)
            else:
                # str() keeps non-string values (e.g. the int 1) from crashing
                # on the .lower() call
                setattr(
                    self,
                    field_name,
                    str(config_value).lower() in ("true", "1", "on"),
                )

    def update(self, config_settings: Optional[Dict[str, Any]]) -> None:
        """Parse config_settings from Pip (or other PEP 517 frontend)

        ``None`` is accepted and leaves the configuration untouched.
        """
        if config_settings is not None:
            self._set_config_setting(config_settings)

    def get_cmake_args(self) -> List[str]:
        """Convert build configuration to CMake args

        ``use_system_libtreelite`` and ``system_libtreelite_dir`` are not
        forwarded as ``-D`` options; all remaining fields are emitted in
        declaration order.
        """
        not_forwarded = ("use_system_libtreelite", "system_libtreelite_dir")
        cmake_args = []
        for field in dataclasses.fields(self):
            if field.name in not_forwarded:
                continue
            cmake_option = field.name.upper()
            cmake_value = "ON" if getattr(self, field.name) is True else "OFF"
            cmake_args.append(f"-D{cmake_option}={cmake_value}")
        return cmake_args
48 |
--------------------------------------------------------------------------------
/tests/ci_build/rename_whl.py:
--------------------------------------------------------------------------------
1 | """Rename a Python wheel"""
2 |
3 | import argparse
4 | import glob
5 | import os
6 | from contextlib import contextmanager
7 |
8 |
@contextmanager
def cd(path):  # pylint: disable=C0103
    """Run the ``with`` body inside ``path``, then return to the previous directory

    Yields the normalized form of ``path``. The previous working directory is
    restored even if the body raises.
    """
    target = os.path.normpath(path)
    previous = os.getcwd()
    os.chdir(target)
    print(f"cd {target}")
    try:
        yield target
    finally:
        os.chdir(previous)
20 |
21 |
def main(args):
    """Rename every wheel directly inside ``args.wheel_dir``

    Each ``*.whl`` file is renamed to
    ``{pkg_name}-{version}+{commit_id}-py3-none-{platform_tag}.whl``, where the
    package name and version are the first two dash-separated parts of the
    current file name and the rest comes from ``args``.

    Raises
    ------
    ValueError
        If ``args.wheel_dir`` is not a directory, or a wheel file name does not
        consist of exactly five dash-separated parts.
    """
    if not os.path.isdir(args.wheel_dir):
        raise ValueError("wheel_dir argument must be a directory")

    with cd(args.wheel_dir):
        # glob.glob() already returns a list, so later renames cannot disturb
        # the iteration
        for whl_path in glob.glob("*.whl"):
            basename = os.path.basename(whl_path)
            tokens = basename.split("-")
            # Explicit check instead of `assert`, which is stripped under -O
            if len(tokens) != 5:
                raise ValueError(
                    f"Unexpected wheel file name '{basename}': expected 5 "
                    f"dash-separated parts but found {len(tokens)}"
                )
            pkg_name, version = tokens[0], tokens[1]
            new_name = (
                f"{pkg_name}-{version}+{args.commit_id}"
                f"-py3-none-{args.platform_tag}.whl"
            )
            print(f"Renaming {basename} to {new_name}...")
            os.rename(basename, new_name)
46 |
47 |
if __name__ == "__main__":
    # Adjacent string literals are joined without a separator, so the first
    # sentence must end with an explicit space.
    DESCRIPTION = (
        "Script to rename wheel(s) using a commit ID and platform tag. "
        "Note: This script will not recurse into subdirectories."
    )
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument("wheel_dir", type=str, help="Directory containing wheels")
    parser.add_argument("commit_id", type=str, help="Hash of current git commit")
    parser.add_argument(
        "platform_tag", type=str, help="Platform tag, PEP 425 compliant"
    )
    parsed_args = parser.parse_args()
    main(parsed_args)
61 |
--------------------------------------------------------------------------------
/docs/serialization/index.rst:
--------------------------------------------------------------------------------
1 | Notes on Serialization
2 | ======================
3 |
4 | Treelite model objects can be serialized in three ways:
5 |
6 | * Byte sequence. A tree model can be serialized to a byte sequence in memory. Later,
7 | we can recover the same tree model by deserializing it from the byte sequence.
8 | * Files. Tree models can be converted into Treelite checkpoint files that can be later
9 | read back.
10 | * `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_, to enable
11 | zero-copy serialization in the Python programming environment. When pickling a Python
12 | object containing a Treelite model object, we can convert the Treelite model into
13 | a byte sequence without physically making copies in memory.
14 |
15 | We make certain guarantees about compatibility of serialization. It is possible to
16 | exchange serialized tree models between two different Treelite versions, as follows:
17 |
18 | .. |tick| unicode:: U+2714
19 | .. |cross| unicode:: U+2718
20 |
21 | +----------------------+--------------+--------------+--------------------+---------------+
22 | | | To: ``=3.9`` | To: ``=4.0`` | To: ``>=4.1,<5.0`` | To: ``>=5.0`` |
23 | +----------------------+--------------+--------------+--------------------+---------------+
24 | | From: ``=3.9`` | |tick| | |tick| | |tick| | |cross| |
25 | +----------------------+--------------+--------------+--------------------+---------------+
26 | | From: ``=4.0`` | |cross| | |tick| | |tick| | |tick| |
27 | +----------------------+--------------+--------------+--------------------+---------------+
28 | | From: ``>=4.1,<5.0`` | |cross| | |tick| | |tick| | |tick| |
29 | +----------------------+--------------+--------------+--------------------+---------------+
30 | | From: ``>=5.0`` | |cross| | |cross| | |cross| | |tick| |
31 | +----------------------+--------------+--------------+--------------------+---------------+
32 |
33 | .. toctree::
34 | :maxdepth: 1
35 |
36 | v4
37 | v3
38 |
--------------------------------------------------------------------------------
/docs/knobs/postprocessor.rst:
--------------------------------------------------------------------------------
1 | List of postprocessor functions
2 | ===============================
3 | When predicting with tree ensemble models, we sum the margin scores from individual trees and apply a postprocessor
4 | function to transform the sum into a final prediction. This function is also known as the link function.
5 |
6 | Currently, Treelite supports the following postprocessor functions.
7 |
8 | Element-wise postprocessor functions
9 | ------------------------------------
10 | * ``identity``: The identity function. Do not apply any transformation to the margin score vector.
11 | * ``signed_square``: Apply the function ``f(x) = sign(x) * (x**2)`` element-wise to the margin score vector.
12 | * ``hinge``: Apply the function ``f(x) = (1 if x > 0 else 0)`` element-wise to the margin score vector.
13 | * ``sigmoid``: Apply the sigmoid function ``f(x) = 1/(1+exp(-sigmoid_alpha * x))`` element-wise to the margin score
14 | vector, to transform margin scores into probability scores in the range ``[0, 1]``. The ``sigmoid_alpha`` parameter
15 | can be configured by the user.
16 | * ``exponential``: Apply the exponential function (``exp``) element-wise to the margin score vector.
17 | * ``exponential_standard_ratio``: Apply the function ``f(x) = exp2(-x / ratio_c)`` element-wise to the margin score
18 | vector. The ``ratio_c`` parameter can be configured by the user.
19 | * ``logarithm_one_plus_exp``: Apply the function ``f(x) = log(1 + exp(x))`` element-wise to the margin score vector.
20 |
21 | Row-wise postprocessor functions
22 | --------------------------------
23 | * ``identity_multiclass``: The identity function. Do not apply any transformation to the margin score vector.
24 | * ``softmax``: Apply the softmax function ``f(x) = exp(x) / sum(exp(x))`` to the margin score vector, to transform the
25 | margin scores into probability scores in the range ``[0, 1]``. Adding up the transformed scores for all classes
26 | will yield 1.
27 | * ``multiclass_ova``: Apply the sigmoid function ``f(x) = 1/(1+exp(-sigmoid_alpha * x))`` element-wise to the margin
28 | scores. The ``sigmoid_alpha`` parameter can be configured by the user.
29 |
--------------------------------------------------------------------------------
/tests/serializer/compatibility_tester.py:
--------------------------------------------------------------------------------
1 | """Script to test backward compatibility of serializer"""
2 |
3 | import argparse
4 | import pickle
5 |
6 | import lightgbm as lgb
7 | import numpy as np
8 | from packaging.version import parse as parse_version
9 | from sklearn.datasets import load_iris
10 |
11 | import treelite
12 |
13 |
def _fetch_data():
    """Return the iris dataset as a ``(features, labels)`` pair"""
    features, labels = load_iris(return_X_y=True)
    return features, labels
17 |
18 |
def _train_model(X, y):
    """Fit a small LightGBM classifier (3 estimators, depth 3, seed 0) on ``X``, ``y``"""
    hyperparams = {"max_depth": 3, "random_state": 0, "n_estimators": 3}
    classifier = lgb.LGBMClassifier(**hyperparams)
    classifier.fit(X, y)
    return classifier
23 |
24 |
def save(args):
    """Train a model, then write its pickle and its Treelite checkpoint

    The pickled classifier goes to ``args.model_pickle_path`` and the
    serialized Treelite model to ``args.checkpoint_path``.
    """
    features, labels = _fetch_data()
    classifier = _train_model(features, labels)
    with open(args.model_pickle_path, "wb") as pickle_file:
        pickle.dump(classifier, pickle_file)
    # From Treelite 4.0 on, the LightGBM importer is reached via
    # treelite.frontend; older versions expose it on treelite.Model.
    at_least_v4 = parse_version(treelite.__version__) >= parse_version("4.0")
    if at_least_v4:
        importer = treelite.frontend.from_lightgbm
    else:
        importer = treelite.Model.from_lightgbm
    importer(classifier.booster_).serialize(args.checkpoint_path)
36 |
37 |
def load(args):
    """Reload the pickle and the checkpoint, then check that predictions agree

    The pickled classifier's ``predict_proba`` output, reshaped to
    ``(n_rows, 1, -1)``, is compared with the GTIL prediction of the
    deserialized Treelite model to 5 decimal places.
    """
    features, _ = _fetch_data()
    # NOTE: pickle.load can execute arbitrary code; only feed it files
    # written by the "save" task of this script.
    with open(args.model_pickle_path, "rb") as pickle_file:
        classifier = pickle.load(pickle_file)
    restored_model = treelite.Model.deserialize(args.checkpoint_path)
    n_rows = features.shape[0]
    reference = classifier.predict_proba(features).reshape((n_rows, 1, -1))
    actual = treelite.gtil.predict(restored_model, features)
    np.testing.assert_almost_equal(actual, reference, decimal=5)
    print("Test passed!")
48 |
49 |
def main(args):
    """Verify the running Treelite version, then dispatch on ``args.task``

    Raises ``ValueError`` when ``treelite.__version__`` differs from
    ``args.expected_treelite_version``. A task other than ``save`` or ``load``
    is a no-op.
    """
    running = treelite.__version__
    expected = args.expected_treelite_version
    if running != expected:
        raise ValueError(
            f"Expected Treelite {expected} " f"but running Treelite {running}"
        )
    handlers = {"save": save, "load": load}
    handler = handlers.get(args.task)
    if handler is not None:
        handler(args)
61 |
62 |
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--task", type=str, choices=["save", "load"], required=True)
    # The remaining options are all mandatory free-form strings
    for flag in (
        "--checkpoint-path",
        "--model-pickle-path",
        "--expected-treelite-version",
    ):
        cli.add_argument(flag, type=str, required=True)
    main(cli.parse_args())
71 |
--------------------------------------------------------------------------------
/python/treelite/core.py:
--------------------------------------------------------------------------------
1 | """Interface with native lib"""
2 |
3 | import ctypes
4 | import os
5 | import sys
6 | import warnings
7 |
8 | from .libpath import TreeliteLibraryNotFound, find_lib_path
9 | from .util import py_str
10 |
11 |
class TreeliteError(Exception):
    """Exception raised when a call into the Treelite native library fails"""
14 |
15 |
@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
def _log_callback(msg: bytes) -> None:
    """Print a log message received from the native library"""
    text = py_str(msg)
    print(text)
20 |
21 |
@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
def _warn_callback(msg: bytes) -> None:
    """Re-issue a warning received from the native library as a Python warning"""
    text = py_str(msg)
    warnings.warn(text)
26 |
27 |
def _load_lib():
    """Locate the native Treelite library, load it and register the callbacks

    Returns the ctypes handle of the first path reported by
    ``find_lib_path()``, or ``None`` when no path is reported. Raises
    ``TreeliteError`` if registering either callback returns nonzero.
    """
    candidates = [str(path) for path in find_lib_path()]
    if len(candidates) == 0:
        # Building docs
        return None  # type: ignore
    on_windows = sys.platform == "win32"
    if on_windows and sys.version_info >= (3, 8):
        # pylint: disable=no-member
        # Make <base_prefix>/Library/bin a DLL search directory, if it exists
        dll_dir = os.path.join(os.path.normpath(sys.base_prefix), "Library", "bin")
        if os.path.isdir(dll_dir):
            os.add_dll_directory(dll_dir)

    lib = ctypes.cdll.LoadLibrary(candidates[0])
    lib.TreeliteGetLastError.restype = ctypes.c_char_p
    # The callback objects are stored as attributes of the library handle
    lib.log_callback = _log_callback
    lib.warn_callback = _warn_callback
    registrations = (
        ("TreeliteRegisterLogCallback", lib.log_callback),
        ("TreeliteRegisterWarningCallback", lib.warn_callback),
    )
    for register_name, callback in registrations:
        # Look up each registration function only when it is about to be called
        if getattr(lib, register_name)(callback) != 0:
            raise TreeliteError(py_str(lib.TreeliteGetLastError()))
    return lib
49 |
50 |
# Load the Treelite library globally.
# A missing library is tolerated only when Sphinx has been imported (docs
# build); in every other case the error propagates.
try:
    _LIB = _load_lib()  # pylint: disable=invalid-name
except TreeliteLibraryNotFound:
    if "sphinx" not in sys.modules:
        raise
    _LIB = None  # pylint: disable=invalid-name
60 |
61 |
def _check_call(ret: int) -> None:
    """Turn a nonzero C API return code into a Python exception

    Wrap every C API call with this function.

    Parameters
    ----------
    ret :
        return value from API calls; 0 means success

    Raises
    ------
    TreeliteError
        If ``ret`` is nonzero. The message is the UTF-8 decoded result of
        ``TreeliteGetLastError()``.
    """
    if ret == 0:
        return
    message = _LIB.TreeliteGetLastError().decode("utf-8")
    raise TreeliteError(message)
75 |
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | Language: Cpp
3 | Standard: c++17
4 | BasedOnStyle: Google
5 | TabWidth: 2
6 | IndentWidth: 2
7 | ColumnLimit: 100
8 | UseTab: Never
9 |
10 | AccessModifierOffset: -1
11 | AlignAfterOpenBracket: DontAlign
12 | AlignConsecutiveAssignments: None
13 | AlignConsecutiveDeclarations: None
14 | AlignEscapedNewlines: Left
15 | AlignTrailingComments: false
16 | AllowAllArgumentsOnNextLine: true
17 | AllowAllConstructorInitializersOnNextLine: true
18 | AllowAllParametersOfDeclarationOnNextLine: true
19 | AllowShortBlocksOnASingleLine: Empty
20 | AllowShortCaseLabelsOnASingleLine: false
21 | AllowShortFunctionsOnASingleLine: Empty
22 | AllowShortIfStatementsOnASingleLine: Never
23 | AllowShortLambdasOnASingleLine: All
24 | AllowShortLoopsOnASingleLine: false
25 | AlwaysBreakTemplateDeclarations: Yes
26 | BinPackArguments: true
27 | BinPackParameters: true
28 | BreakBeforeBinaryOperators: All
29 | BreakBeforeBraces: Attach
30 | BreakBeforeTernaryOperators: true
31 | BreakConstructorInitializers: BeforeColon
32 | BreakInheritanceList: AfterColon
33 | BreakStringLiterals: true
34 | CompactNamespaces: false
35 | Cpp11BracedListStyle: true
36 | DerivePointerAlignment: false
37 | IndentWrappedFunctionNames: false
38 | InsertBraces: true
39 | IndentCaseLabels: false
40 | KeepEmptyLinesAtTheStartOfBlocks: false
41 | NamespaceIndentation: None
42 | PointerAlignment: Left
43 | ReflowComments: true
44 |
45 | SortIncludes: CaseInsensitive
46 | IncludeBlocks: Regroup
47 | IncludeCategories:
48 | # Headers in <> without extension.
49 | - Regex: '<([A-Za-z0-9\Q/-_\E])+>'
50 | Priority: 1
51 | # Headers in <> from Treelite
52 | - Regex: '<(treelite)\/'
53 | Priority: 2
54 | # Headers in <> from external libraries.
55 | - Regex: '<(rapidjson|nlohmann|gtest|fmt)\/'
56 | Priority: 3
57 | # Headers in "" with extension.
58 | - Regex: '"([A-Za-z0-9.\Q/-_\E])+"'
59 | Priority: 4
60 |
61 | SpaceAfterCStyleCast: false
62 | SpaceAfterTemplateKeyword: true
63 | SpaceBeforeAssignmentOperators: true
64 | SpaceBeforeCpp11BracedList: false
65 | SpaceBeforeCtorInitializerColon: true
66 | SpaceBeforeInheritanceColon: true
67 | SpaceBeforeParens: ControlStatements
68 | SpaceBeforeRangeBasedForLoopColon: true
69 | SpaceInEmptyParentheses: false
70 | SpacesInAngles: false
71 | SpacesInCStyleCastParentheses: false
72 | SpacesInContainerLiterals: false
73 | SpacesInParentheses: false
74 | SpacesInSquareBrackets: false
75 | QualifierAlignment: Right
76 |
--------------------------------------------------------------------------------
/include/treelite/contiguous_array.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file contiguous_array.h
4 | * \brief A simple array container, with owned or non-owned (externally allocated) buffer
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_CONTIGUOUS_ARRAY_H_
9 | #define TREELITE_CONTIGUOUS_ARRAY_H_
10 |
11 | #include
12 | #include
13 |
14 | namespace treelite {
15 |
/*!
 * \brief A simple array container, with owned or non-owned (externally allocated) buffer
 * \tparam T Element type; must be POD (enforced by the static_assert below)
 *
 * Only declarations are given here; the member functions are defined elsewhere.
 */
template <typename T>
class ContiguousArray {
 public:
  ContiguousArray();
  ~ContiguousArray();
  // NOTE: use Clone to make deep copy; copy constructors disabled
  ContiguousArray(ContiguousArray const&) = delete;
  ContiguousArray& operator=(ContiguousArray const&) = delete;
  explicit ContiguousArray(std::vector<T> const& other);
  ContiguousArray& operator=(std::vector<T> const& other);
  ContiguousArray(ContiguousArray&& other) noexcept;
  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
  inline ContiguousArray Clone() const;
  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
  inline T* Data();
  inline T const* Data() const;
  inline T* End();
  inline T const* End() const;
  inline T& Back();
  inline T const& Back() const;
  inline std::size_t Size() const;
  inline bool Empty() const;
  inline void Reserve(std::size_t newsize);
  inline void Resize(std::size_t newsize);
  inline void Resize(std::size_t newsize, T t);
  inline void Clear();
  inline void PushBack(T t);
  inline void Extend(std::vector<T> const& other);
  inline void Extend(ContiguousArray const& other);
  inline std::vector<T> AsVector() const;
  inline bool operator==(ContiguousArray const& other);
  /* Unsafe access, no bounds checking */
  inline T& operator[](std::size_t idx);
  inline T const& operator[](std::size_t idx) const;
  /* Safe access, with bounds checking */
  inline T& at(std::size_t idx);
  inline T const& at(std::size_t idx) const;
  /* Safe access, with bounds checking + check against non-existent node (<0) */
  inline T& at(int idx);
  inline T const& at(int idx) const;
  static_assert(std::is_pod<T>::value, "T must be POD");

 private:
  T* buffer_;
  std::size_t size_;
  std::size_t capacity_;
  bool owned_buffer_;
};
64 |
65 | } // namespace treelite
66 |
67 | #include
68 |
69 | #endif // TREELITE_CONTIGUOUS_ARRAY_H_
70 |
--------------------------------------------------------------------------------
/src/model_loader/detail/xgboost_json/sax_adapters.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2024 by Contributors
3 | * \file sax_adapters.h
4 | * \brief Adapters to connect RapidJSON and nlohmann/json with the delegated handler
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_
9 | #define SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 |
19 | #include
20 |
21 | namespace treelite::model_loader::detail::xgboost {
22 |
23 | class DelegatedHandler;
24 |
25 | /*!
26 | * \brief Adapter for SAX parser from RapidJSON
27 | */
28 | class RapidJSONAdapter {
29 | public:
30 | explicit RapidJSONAdapter(std::shared_ptr handler)
31 | : handler_{std::move(handler)} {}
32 | bool Null();
33 | bool Bool(bool b);
34 | bool Int(int i);
35 | bool Uint(unsigned u);
36 | bool Int64(std::int64_t i);
37 | bool Uint64(std::uint64_t u);
38 | bool Double(double d);
39 | bool RawNumber(char const* str, std::size_t length, bool copy);
40 | bool String(char const* str, std::size_t length, bool copy);
41 | bool StartObject();
42 | bool Key(char const* str, std::size_t length, bool copy);
43 | bool EndObject(std::size_t);
44 | bool StartArray();
45 | bool EndArray(std::size_t);
46 |
47 | private:
48 | std::shared_ptr handler_;
49 | };
50 |
51 | /*!
52 | * \brief Adapter for SAX parser from nlohmann/json
53 | */
54 | class NlohmannJSONAdapter {
55 | public:
56 | explicit NlohmannJSONAdapter(std::shared_ptr handler)
57 | : handler_{std::move(handler)} {}
58 | bool null();
59 | bool boolean(bool val);
60 | bool number_integer(std::int64_t val);
61 | bool number_unsigned(std::uint64_t val);
62 | bool number_float(double val, std::string const&);
63 | bool string(std::string& val);
64 | bool binary(nlohmann::json::binary_t& val);
65 | bool start_object(std::size_t);
66 | bool end_object();
67 | bool start_array(std::size_t);
68 | bool end_array();
69 | bool key(std::string& val);
70 | bool parse_error(
71 | std::size_t position, std::string const& last_token, nlohmann::json::exception const& ex);
72 |
73 | private:
74 | std::shared_ptr handler_;
75 | };
76 |
77 | } // namespace treelite::model_loader::detail::xgboost
78 |
79 | #endif // SRC_MODEL_LOADER_DETAIL_XGBOOST_JSON_SAX_ADAPTERS_H_
80 |
--------------------------------------------------------------------------------
/src/model_loader/detail/lightgbm.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2020-2023 by Contributors
3 | * \file lightgbm.h
4 | * \brief Helper functions for loading LightGBM models
5 | * \author Hyunsu Cho
6 | */
7 | #ifndef SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_
8 | #define SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_
9 |
10 | #include
11 |
12 | #include
13 |
14 | #include "./string_utils.h"
15 |
16 | namespace treelite::model_loader::detail::lightgbm {
17 |
18 | /*!
19 | * \brief Canonicalize the name of an objective function.
20 | *
21 | * Some objective functions have many aliases. We use the canonical name to avoid confusion.
22 | *
23 | * @param obj_name Name of an objective function
24 | * @return Canonical name
25 | */
26 | std::string CanonicalObjective(std::string const& obj_name) {
27 | if (obj_name == "regression" || obj_name == "regression_l2" || obj_name == "l2"
28 | || obj_name == "mean_squared_error" || obj_name == "mse" || obj_name == "l2_root"
29 | || obj_name == "root_mean_squared_error" || obj_name == "rmse") {
30 | return "regression";
31 | } else if (obj_name == "regression_l1" || obj_name == "l1" || obj_name == "mean_absolute_error"
32 | || obj_name == "mae") {
33 | return "regression_l1";
34 | } else if (obj_name == "mape" || obj_name == "mean_absolute_percentage_error") {
35 | return "mape";
36 | } else if (obj_name == "multiclass" || obj_name == "softmax") {
37 | return "multiclass";
38 | } else if (obj_name == "multiclassova" || obj_name == "multiclass_ova" || obj_name == "ova"
39 | || obj_name == "ovr") {
40 | return "multiclassova";
41 | } else if (obj_name == "cross_entropy" || obj_name == "xentropy") {
42 | return "cross_entropy";
43 | } else if (obj_name == "cross_entropy_lambda" || obj_name == "xentlambda") {
44 | return "cross_entropy_lambda";
45 | } else if (obj_name == "rank_xendcg" || obj_name == "xendcg" || obj_name == "xe_ndcg"
46 | || obj_name == "xe_ndcg_mart" || obj_name == "xendcg_mart") {
47 | return "rank_xendcg";
48 | } else if (obj_name == "huber" || obj_name == "fair" || obj_name == "poisson"
49 | || obj_name == "quantile" || obj_name == "gamma" || obj_name == "tweedie"
50 | || obj_name == "binary" || obj_name == "lambdarank" || obj_name == "custom") {
51 | // These objectives have no aliases
52 | return obj_name;
53 | } else {
54 | TREELITE_LOG(FATAL) << "Unknown objective name: \"" << obj_name << "\"";
55 | return "";
56 | }
57 | }
58 |
59 | } // namespace treelite::model_loader::detail::lightgbm
60 |
61 | #endif // SRC_MODEL_LOADER_DETAIL_LIGHTGBM_H_
62 |
--------------------------------------------------------------------------------
/docs/treelite-c-api.rst:
--------------------------------------------------------------------------------
1 | ==============
2 | Treelite C API
3 | ==============
4 |
5 | Treelite exposes a set of C functions to enable interfacing with a variety of
6 | languages. This page will be most useful for:
7 |
8 | * those writing a new
9 | `language binding `_ (glue
10 | code).
11 | * those wanting to incorporate functions of Treelite into their own native
12 | libraries.
13 |
14 | **We recommend the Python API for everyday uses.**
15 |
16 | .. note:: Use of C and C++ in Treelite
17 |
18 | The core logic of Treelite is written in C++ to take advantage of higher
19 | abstractions. We provide a C-only interface here, as many more programming
20 | languages bind with C than with C++. See
21 | `this page `_ for
22 | more details.
23 |
24 | .. contents:: Contents
25 | :local:
26 |
27 | Model loader interface
28 | ----------------------
29 | Use the following functions to load decision tree ensemble models from a file.
30 | Treelite supports multiple model file formats.
31 |
32 | .. doxygengroup:: model_loader
33 | :project: treelite
34 | :content-only:
35 |
36 | Model loader interface for scikit-learn models
37 | ----------------------------------------------
38 | Use the following functions to load decision tree ensemble models from a scikit-learn
39 | model object.
40 |
41 | .. doxygengroup:: sklearn
42 | :project: treelite
43 | :content-only:
44 |
45 | Model builder interface
46 | -----------------------
47 | Use the following functions to incrementally build decision tree ensemble
48 | models.
49 |
50 | .. doxygengroup:: model_builder
51 | :project: treelite
52 | :content-only:
53 |
54 | Model manager interface
55 | -----------------------
56 |
57 | .. doxygengroup:: model_manager
58 | :project: treelite
59 | :content-only:
60 |
61 | Serializer
62 | ----------
63 |
64 | .. doxygengroup:: serializer
65 | :project: treelite
66 | :content-only:
67 |
68 | Getters and setters for the model object
69 | ----------------------------------------
70 |
71 | .. doxygengroup:: accessor
72 | :project: treelite
73 | :content-only:
74 |
75 | General Tree Inference Library (GTIL)
76 | -------------------------------------
77 |
78 | .. doxygengroup:: gtil
79 | :project: treelite
80 | :content-only:
81 |
82 | Handle types
83 | ------------
84 | Treelite uses C++ classes to define its internal data structures. In order to
85 | pass C++ objects to C functions, *opaque handles* are used. Opaque handles
86 | are ``void*`` pointers that store raw memory addresses.
87 |
88 | .. doxygengroup:: opaque_handles
89 | :project: treelite
90 | :content-only:
91 |
--------------------------------------------------------------------------------
/.github/workflows/coverage-tests.yml:
--------------------------------------------------------------------------------
1 | name: coverage-tests
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - mainline
8 | - 'release_*'
9 | schedule:
10 | - cron: "0 7 * * *" # Run once daily
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | defaults:
16 | run:
17 | shell: bash -l {0}
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | env:
24 | CODECOV_TOKEN: afe9868c-2c27-4853-89fa-4bc5d3d2b255
25 |
26 | jobs:
27 | cpp-python-coverage:
28 | name: Run Python and C++ tests with test coverage (Linux)
29 | runs-on: ubuntu-latest
30 | steps:
31 | - uses: actions/checkout@v4
32 | - uses: conda-incubator/setup-miniconda@v3
33 | with:
34 | miniforge-variant: Miniforge3
35 | miniforge-version: latest
36 | conda-remove-defaults: "true"
37 | activate-environment: dev
38 | environment-file: ops/conda_env/dev.yml
39 | use-mamba: true
40 | - name: Display Conda env
41 | run: |
42 | conda info
43 | conda list
44 | - name: Run tests with test coverage computation
45 | run: |
46 | bash ops/cpp-python-coverage.sh
47 | win-python-coverage:
48 | name: Run Python and C++ tests with test coverage (Windows)
49 | runs-on: windows-latest
50 | steps:
51 | - uses: actions/checkout@v4
52 | - uses: conda-incubator/setup-miniconda@v3
53 | with:
54 | miniforge-variant: Miniforge3
55 | miniforge-version: latest
56 | conda-remove-defaults: "true"
57 | activate-environment: dev
58 | environment-file: ops/conda_env/dev.yml
59 | use-mamba: true
60 | - name: Display Conda env
61 | shell: cmd /C CALL {0}
62 | run: |
63 | conda info
64 | conda list
65 | - name: Run tests with test coverage computation
66 | shell: cmd /C CALL {0}
67 | run: |
68 | call ops/win-python-coverage.bat
69 | macos-python-coverage:
70 | name: Run Python and C++ tests with test coverage (MacOS)
71 | runs-on: macos-latest
72 | steps:
73 | - uses: actions/checkout@v4
74 | - uses: conda-incubator/setup-miniconda@v3
75 | with:
76 | miniforge-variant: Miniforge3
77 | miniforge-version: latest
78 | conda-remove-defaults: "true"
79 | activate-environment: dev
80 | environment-file: ops/conda_env/dev.yml
81 | use-mamba: true
82 | - name: Display Conda env
83 | run: |
84 | conda info
85 | conda list
86 | - name: Run tests with test coverage computation
87 | run: |
88 | bash ops/macos-python-coverage.sh
89 |
--------------------------------------------------------------------------------
/tests/serializer/test_serializer.py:
--------------------------------------------------------------------------------
1 | """Test for serialization, via buffer protocol"""
2 |
3 | import ctypes
4 | from typing import List
5 |
6 | import numpy as np
7 | import pytest
8 | from sklearn.datasets import load_iris
9 | from sklearn.ensemble import (
10 | ExtraTreesClassifier,
11 | GradientBoostingClassifier,
12 | RandomForestClassifier,
13 | )
14 |
15 | import treelite
16 | from treelite.core import _LIB, _check_call
17 | from treelite.model import _numpy2pybuffer, _pybuffer2numpy, _TreelitePyBufferFrame
18 | from treelite.util import c_array
19 |
20 |
def treelite_deserialize(frames: List[np.ndarray]) -> treelite.Model:
    """Deserialize a model from PyBuffer frames"""
    py_buffers = []
    for frame in frames:
        py_buffers.append(_numpy2pybuffer(frame))
    model_handle = ctypes.c_void_p()
    # The C function fills in model_handle; _check_call inspects its return value.
    status = _LIB.TreeliteDeserializeModelFromPyBuffer(
        c_array(_TreelitePyBufferFrame, py_buffers),
        ctypes.c_size_t(len(py_buffers)),
        ctypes.byref(model_handle),
    )
    _check_call(status)
    return treelite.Model(handle=model_handle)
33 |
34 |
def treelite_serialize(
    model: treelite.Model,
) -> List[np.ndarray]:
    """Serialize a model to PyBuffer frames, returned as NumPy arrays"""
    frame_ptr = ctypes.POINTER(_TreelitePyBufferFrame)()
    frame_count = ctypes.c_size_t()
    # Both out-parameters are filled in by the C function.
    status = _LIB.TreeliteSerializeModelToPyBuffer(
        model.handle,
        ctypes.byref(frame_ptr),
        ctypes.byref(frame_count),
    )
    _check_call(status)
    arrays = []
    for idx in range(frame_count.value):
        arrays.append(_pybuffer2numpy(frame_ptr[idx]))
    return arrays
49 |
50 |
@pytest.mark.parametrize(
    "clazz", [RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier]
)
def test_serialize_as_buffer(clazz):
    """Round-trip a scikit-learn model through PyBuffer frames and check nothing is lost"""
    features, labels = load_iris(return_X_y=True)
    hyperparams = {"max_depth": 5, "random_state": 0, "n_estimators": 10}
    if clazz == GradientBoostingClassifier:
        hyperparams["init"] = "zero"
    estimator = clazz(**hyperparams)
    estimator.fit(features, labels)
    n_rows = features.shape[0]
    reference_prob = estimator.predict_proba(features).reshape((n_rows, 1, -1))

    # Predictions made by the deserialized model must match scikit-learn's
    original_model = treelite.sklearn.import_model(estimator)
    first_frames = treelite_serialize(original_model)
    restored_model = treelite_deserialize(first_frames)
    restored_prob = treelite.gtil.predict(restored_model, features)
    np.testing.assert_almost_equal(restored_prob, reference_prob, decimal=5)

    # Serializing the restored model must reproduce the very same frames
    second_frames = treelite_serialize(restored_model)
    assert len(first_frames) == len(second_frames)
    for before, after in zip(first_frames, second_frames):
        assert np.array_equal(before, after)
76 |
--------------------------------------------------------------------------------
/src/model_builder/metadata.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file metadata.cc
4 | * \brief C++ API for constructing Model metadata
5 | * \author Hyunsu Cho
6 | */
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | #include
15 | #include
16 | #include
17 |
18 | namespace treelite::model_builder {
19 |
// Store the per-tree annotation and verify that both target_id and class_id contain exactly
// num_tree entries; a length mismatch trips the corresponding TREELITE_CHECK_EQ.
TreeAnnotation::TreeAnnotation(std::int32_t num_tree, std::vector const& target_id,
    std::vector const& class_id)
    : num_tree{num_tree}, target_id{target_id}, class_id{class_id} {
  TREELITE_CHECK_EQ(target_id.size(), num_tree)
      << "target_id field must have length equal to num_tree (" << num_tree << ")";
  TREELITE_CHECK_EQ(class_id.size(), num_tree)
      << "class_id field must have length equal to num_tree (" << num_tree << ")";
}
28 |
// Construct a post-processor with the given name and an empty configuration, by delegating
// to the two-argument constructor.
PostProcessorFunc::PostProcessorFunc(std::string const& name) : PostProcessorFunc(name, {}) {}
30 |
// Construct a post-processor from its name and configuration map; both are copied into the
// object without validation.
PostProcessorFunc::PostProcessorFunc(
    std::string const& name, std::map const& config)
    : name(name), config(config) {}
34 |
35 | Metadata::Metadata(std::int32_t num_feature, TaskType task_type, bool average_tree_output,
36 | std::int32_t num_target, std::vector const& num_class,
37 | std::array const& leaf_vector_shape)
38 | : num_feature(num_feature),
39 | task_type(task_type),
40 | average_tree_output(average_tree_output),
41 | num_target(num_target),
42 | num_class(num_class),
43 | leaf_vector_shape(leaf_vector_shape) {
44 | TREELITE_CHECK_GT(num_target, 0) << "num_target must be at least 1";
45 | TREELITE_CHECK_EQ(num_class.size(), num_target)
46 | << "num_class field must have length equal to num_target (" << num_target << ")";
47 | if (!std::all_of(num_class.begin(), num_class.end(), [](std::int32_t e) { return e >= 1; })) {
48 | TREELITE_LOG(FATAL) << "All elements in num_class field must be at least 1.";
49 | }
50 | TREELITE_CHECK(leaf_vector_shape[0] == 1 || leaf_vector_shape[0] == num_target)
51 | << "leaf_vector_shape[0] must be either 1 or num_target (" << num_target << "). "
52 | << "Currently given: leaf_vector_shape[1] = " << leaf_vector_shape[1];
53 | std::int32_t const max_num_class = *std::max_element(num_class.begin(), num_class.end());
54 | TREELITE_CHECK(leaf_vector_shape[1] == 1 || leaf_vector_shape[1] == max_num_class)
55 | << "leaf_vector_shape[1] must be either 1 or max_num_class (" << max_num_class << "). "
56 | << "Currently given: leaf_vector_shape[1] = " << leaf_vector_shape[1];
57 | }
58 |
59 | } // namespace treelite::model_builder
60 |
--------------------------------------------------------------------------------
/python/treelite/libpath.py:
--------------------------------------------------------------------------------
1 | """Find the path to Treelite dynamic library files."""
2 |
3 | import os
4 | import pathlib
5 | import sys
6 | from typing import List
7 |
8 | from .path_config import get_custom_libpath
9 |
10 |
class TreeliteLibraryNotFound(Exception):
    """Raised when the Treelite dynamic library cannot be located"""
13 |
14 |
def find_lib_path() -> List[pathlib.Path]:
    """Locate the Treelite dynamic library on disk.

    Returns
    -------
    lib_path
        Every candidate location at which the library file actually exists

    Raises
    ------
    RuntimeError
        If ``sys.platform`` is not one of the recognized platforms
    TreeliteLibraryNotFound
        If no candidate exists and the ``TREELITE_BUILD_DOC`` environment variable is unset
        or empty
    """
    package_dir = pathlib.Path(__file__).expanduser().absolute().parent
    base_prefix = pathlib.Path(sys.base_prefix)
    search_dirs = [
        # Library shipped inside the installed package
        package_dir / "lib",
        # Editable installation: build directory two levels above the package
        package_dir.parent.parent / "build",
        # Library provided by the system prefix
        base_prefix.expanduser().resolve() / "lib",
    ]
    override = get_custom_libpath()  # pylint: disable=assignment-from-none
    if override:
        # A custom location takes precedence over every built-in candidate
        search_dirs.insert(0, pathlib.Path(override).expanduser().resolve())

    if sys.platform == "win32":
        # On Windows, Conda may install libs in different paths
        search_dirs += [
            base_prefix / "bin",
            base_prefix / "Library",
            base_prefix / "Library" / "bin",
            base_prefix / "Library" / "lib",
        ]
        lib_name = "treelite.dll"
    elif sys.platform.startswith(("linux", "freebsd", "emscripten", "OS400")):
        lib_name = "libtreelite.so"
    elif sys.platform == "darwin":
        lib_name = "libtreelite.dylib"
    elif sys.platform == "cygwin":
        lib_name = "cygtreelite.dll"
    else:
        raise RuntimeError(f"Unrecognized platform: {sys.platform}")

    candidates = [directory / lib_name for directory in search_dirs]
    found = [path for path in candidates if path.exists() and path.is_file()]

    # TREELITE_BUILD_DOC is defined by sphinx conf.
    if found or os.environ.get("TREELITE_BUILD_DOC", False):
        return found

    link = "https://treelite.readthedocs.io/en/latest/install.html"
    msg = "".join(
        [
            "Cannot find Treelite Library in the candidate path. ",
            "List of candidates:\n- ",
            "\n- ".join(str(path) for path in candidates),
            "\nTreelite Python package path: ",
            str(package_dir),
            "\nsys.base_prefix: ",
            sys.base_prefix,
            "\nSee: ",
            link,
            " for installing Treelite.",
        ]
    )
    raise TreeliteLibraryNotFound(msg)
76 |
--------------------------------------------------------------------------------
/src/c_api/model.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file model.cc
4 | * \author Hyunsu Cho
5 | * \brief C API for functions to query and modify model objects
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 |
12 | #include
13 | #include
14 | #include
15 |
16 | #include "./c_api_utils.h"
17 |
// Dump the model behind `handle` as a JSON string. `pretty_print` is forwarded to DumpAsJSON
// as a boolean (non-zero means true). The text is kept in the ReturnValueStore's ret_str and
// *out_json_str points into that stored string.
int TreeliteDumpAsJSON(TreeliteModelHandle handle, int pretty_print, char const** out_json_str) {
  API_BEGIN();
  auto* model_ = static_cast(handle);
  std::string& ret_str = treelite::c_api::ReturnValueStore::Get()->ret_str;
  ret_str = model_->DumpAsJSON(pretty_print != 0);
  // Hand out a pointer into the stored string rather than a copy
  *out_json_str = ret_str.c_str();
  API_END();
}
26 |
// Report the model's threshold type as a string. The string is kept in the ReturnValueStore's
// ret_str and *out_str points into it.
int TreeliteGetInputType(TreeliteModelHandle model, char const** out_str) {
  API_BEGIN();
  auto const* model_ = static_cast(model);
  auto& type_str = treelite::c_api::ReturnValueStore::Get()->ret_str;
  type_str = treelite::TypeInfoToString(model_->GetThresholdType());
  *out_str = type_str.c_str();
  API_END();
}
35 |
// Report the model's leaf output type as a string. The string is kept in the
// ReturnValueStore's ret_str and *out_str points into it.
int TreeliteGetOutputType(TreeliteModelHandle model, char const** out_str) {
  API_BEGIN();
  auto const* model_ = static_cast(model);
  auto& type_str = treelite::c_api::ReturnValueStore::Get()->ret_str;
  type_str = treelite::TypeInfoToString(model_->GetLeafOutputType());
  *out_str = type_str.c_str();
  API_END();
}
44 |
// Write the number of trees in the model (as returned by GetNumTree) to *out.
int TreeliteQueryNumTree(TreeliteModelHandle model, std::size_t* out) {
  API_BEGIN();
  auto const* model_ = static_cast(model);
  *out = model_->GetNumTree();
  API_END();
}
51 |
// Write the model's num_feature field to *out.
int TreeliteQueryNumFeature(TreeliteModelHandle model, int* out) {
  API_BEGIN();
  auto const* model_ = static_cast(model);
  *out = model_->num_feature;
  API_END();
}
58 |
// Concatenate the `len` models referred to by objs[0..len) into one new model.
// Ownership of the new model is released to the caller through *out.
// NOTE(review): presumably the caller frees it with TreeliteFreeModel -- confirm.
int TreeliteConcatenateModelObjects(
    TreeliteModelHandle const* objs, std::size_t len, TreeliteModelHandle* out) {
  API_BEGIN();
  std::vector model_objs(len, nullptr);
  // Convert every opaque handle into a typed model pointer
  std::transform(objs, objs + len, model_objs.begin(),
      [](TreeliteModelHandle e) { return static_cast(e); });
  auto concatenated_model = ConcatenateModelObjects(model_objs);
  *out = static_cast(concatenated_model.release());
  API_END();
}
69 |
// Destroy the model object behind `handle` with `delete`.
int TreeliteFreeModel(TreeliteModelHandle handle) {
  API_BEGIN();
  delete static_cast(handle);
  API_END();
}
75 |
// Fetch the result of the model's GetTreeDepth(). The values are kept in the
// ReturnValueStore's ret_uint32_vec; *out points at that storage and *out_len is its length.
int TreeliteGetTreeDepth(TreeliteModelHandle model, std::uint32_t** out, std::size_t* out_len) {
  API_BEGIN();
  auto* model_ = static_cast(model);
  auto& ret_depth = treelite::c_api::ReturnValueStore::Get()->ret_uint32_vec;
  ret_depth = model_->GetTreeDepth();
  *out = ret_depth.data();
  *out_len = ret_depth.size();
  API_END();
}
85 |
--------------------------------------------------------------------------------
/src/c_api/serializer.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file serializer.cc
4 | * \author Hyunsu Cho
5 | * \brief C API for functions to serialize model objects
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | #include "./c_api_utils.h"
19 |
// Serialize the model behind `handle` into the file named `filename`, using the output stream
// returned by OpenFileForWriteAsStream.
int TreeliteSerializeModelToFile(TreeliteModelHandle handle, char const* filename) {
  API_BEGIN();
  std::ofstream ofs = treelite::detail::OpenFileForWriteAsStream(filename);
  auto* model_ = static_cast(handle);
  model_->SerializeToStream(ofs);
  API_END();
}
27 |
// Deserialize a model from the file named `filename`. Ownership of the new model is released
// to the caller through *out.
int TreeliteDeserializeModelFromFile(char const* filename, TreeliteModelHandle* out) {
  API_BEGIN();
  std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filename);
  std::unique_ptr model = treelite::Model::DeserializeFromStream(ifs);
  *out = static_cast(model.release());
  API_END();
}
35 |
// Serialize the model behind `handle` into a byte sequence. The bytes are kept in the
// ReturnValueStore's ret_str; *out_bytes points into it and *out_bytes_len is its length.
int TreeliteSerializeModelToBytes(
    TreeliteModelHandle handle, char const** out_bytes, std::size_t* out_bytes_len) {
  API_BEGIN();
  std::ostringstream oss;
  oss.exceptions(std::ios::failbit | std::ios::badbit);  // Throw exception on failure
  auto* model_ = static_cast(handle);
  model_->SerializeToStream(oss);

  std::string& ret_str = treelite::c_api::ReturnValueStore::Get()->ret_str;
  ret_str = oss.str();
  // The payload is binary, so the explicit length (not a terminator) delimits it
  *out_bytes = ret_str.data();
  *out_bytes_len = ret_str.length();
  API_END();
}
50 |
// Deserialize a model from the `bytes_len` bytes starting at `bytes`. The bytes are copied
// into a string stream before parsing. Ownership of the new model is released to the caller
// through *out.
int TreeliteDeserializeModelFromBytes(
    char const* bytes, std::size_t bytes_len, TreeliteModelHandle* out) {
  API_BEGIN();
  std::istringstream iss(std::string(bytes, bytes_len));
  iss.exceptions(std::ios::failbit | std::ios::badbit);  // Throw exception on failure
  std::unique_ptr model = treelite::Model::DeserializeFromStream(iss);
  *out = static_cast(model.release());
  API_END();
}
60 |
// Serialize the model behind `handle` into a sequence of PyBuffer frames. The frames are kept
// in the ReturnValueStore's ret_frames; *out_frames points at the first frame and
// *out_num_frames is the frame count. An empty result yields nullptr and 0.
int TreeliteSerializeModelToPyBuffer(
    TreeliteModelHandle handle, TreelitePyBufferFrame** out_frames, size_t* out_num_frames) {
  API_BEGIN();
  auto* model_ = static_cast(handle);
  std::vector& ret_frames
      = treelite::c_api::ReturnValueStore::Get()->ret_frames;
  ret_frames = model_->SerializeToPyBuffer();
  if (ret_frames.empty()) {
    // Avoid indexing element 0 of an empty vector
    *out_frames = nullptr;
    *out_num_frames = 0;
  } else {
    *out_frames = &ret_frames[0];
    *out_num_frames = ret_frames.size();
  }
  API_END();
}
77 |
// Deserialize a model from `num_frames` PyBuffer frames starting at `frames`. The frame
// descriptors are copied into a vector first. Ownership of the new model is released to the
// caller through *out.
int TreeliteDeserializeModelFromPyBuffer(
    TreelitePyBufferFrame* frames, size_t num_frames, TreeliteModelHandle* out) {
  API_BEGIN();
  std::vector frames_(frames, frames + num_frames);
  auto model = treelite::Model::DeserializeFromPyBuffer(frames_);
  *out = static_cast(model.release());
  API_END();
}
86 |
--------------------------------------------------------------------------------
/src/model_loader/xgboost_ubjson.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2024 by Contributors
3 | * \file xgboost_ubjson.cc
4 | * \brief Model loader for XGBoost model (UBJSON)
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | #include
14 | #include
15 | #include
16 |
17 | #include
18 |
19 | #include "detail/xgboost_json/delegated_handler.h"
20 | #include "detail/xgboost_json/sax_adapters.h"
21 |
22 | namespace {
23 |
24 | template
25 | std::unique_ptr ParseStream(
26 | InputType&& input_stream, nlohmann::json const& parsed_config);
27 |
28 | } // anonymous namespace
29 |
30 | namespace treelite::model_loader {
31 |
// Load an XGBoost model stored as UBJSON in the file `filename`. `config_json` is parsed as
// JSON and handed to ParseStream together with the opened file stream.
std::unique_ptr LoadXGBoostModelUBJSON(
    std::string const& filename, std::string const& config_json) {
  nlohmann::json parsed_config = nlohmann::json::parse(config_json);
  std::ifstream ifs = treelite::detail::OpenFileForReadAsStream(filename);
  return ParseStream(ifs, parsed_config);
}
38 |
// Load an XGBoost model from an in-memory UBJSON byte string. `config_json` is parsed as JSON
// and handed to ParseStream together with the byte string.
std::unique_ptr LoadXGBoostModelFromUBJSONString(
    std::string_view ubjson_str, std::string const& config_json) {
  nlohmann::json parsed_config = nlohmann::json::parse(config_json);
  return ParseStream(ubjson_str, parsed_config);
}
44 |
45 | } // namespace treelite::model_loader
46 |
47 | namespace {
48 |
// Parse a UBJSON-encoded XGBoost model from `input_stream` and build the model object.
// `parsed_config` may carry a boolean "allow_unknown_field" entry, which is copied into the
// handler configuration; any other content of the config is ignored here.
template
std::unique_ptr ParseStream(
    InputType&& input_stream, nlohmann::json const& parsed_config) {
  treelite::model_loader::detail::xgboost::HandlerConfig handler_config;
  if (parsed_config.is_object()) {
    auto itr = parsed_config.find("allow_unknown_field");
    // Only honor the setting when it is present and really a boolean
    if (itr != parsed_config.end() && itr->is_boolean()) {
      handler_config.allow_unknown_field = itr->template get();
    }
  }

  std::shared_ptr handler
      = treelite::model_loader::detail::xgboost::DelegatedHandler::create(handler_config);
  auto adapter
      = std::make_unique(handler);
  // SAX-parse the input as UBJSON, feeding events to the handler through the adapter; the
  // check fails if sax_parse reports failure
  TREELITE_CHECK(nlohmann::json::sax_parse(
      input_stream, adapter.get(), nlohmann::json::input_format_t::ubjson));

  treelite::model_loader::detail::xgboost::ParsedXGBoostModel parsed = handler->get_result();
  auto model = parsed.builder->CommitModel();

  // Apply Dart weights: scale every leaf value of tree i by weight_drop[i]
  if (!parsed.weight_drop.empty()) {
    auto& trees = std::get>(model->variant_).trees;
    TREELITE_CHECK_EQ(trees.size(), parsed.weight_drop.size());
    for (std::size_t i = 0; i < trees.size(); ++i) {
      for (int nid = 0; nid < trees[i].num_nodes; ++nid) {
        if (trees[i].IsLeaf(nid)) {
          trees[i].SetLeaf(
              nid, static_cast(parsed.weight_drop[i] * trees[i].LeafValue(nid)));
        }
      }
    }
  }
  return model;
}
85 |
86 | } // anonymous namespace
87 |
--------------------------------------------------------------------------------
/include/treelite/detail/file_utils.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file file_utils.h
4 | * \brief Helper functions for manipulating files
5 | * \author Hyunsu Cho
6 | */
7 |
8 | #ifndef TREELITE_DETAIL_FILE_UTILS_H_
9 | #define TREELITE_DETAIL_FILE_UTILS_H_
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #include
17 |
18 | namespace treelite::detail {
19 |
20 | inline std::ifstream OpenFileForReadAsStream(std::filesystem::path const& filepath) {
21 | auto path = std::filesystem::weakly_canonical(filepath);
22 | TREELITE_CHECK(std::filesystem::exists(path)) << "Path " << filepath << " does not exist";
23 | std::ifstream ifs(path, std::ios::in | std::ios::binary);
24 | TREELITE_CHECK(ifs) << "Could not open file " << filepath;
25 | ifs.exceptions(std::ios::badbit); // Throw exceptions on error
26 | // We don't throw on failbit, since we sometimes want to read
27 | // until the end of file in a loop of form `while
28 | // (std::getline(...))`, which will set failbit.
29 | return ifs;
30 | }
31 |
/*! \brief Overload taking the file name as a UTF-8 encoded string. */
inline std::ifstream OpenFileForReadAsStream(std::string const& filename) {
  std::filesystem::path const as_path = std::filesystem::u8path(filename);
  return OpenFileForReadAsStream(as_path);
}
35 |
36 | inline std::ifstream OpenFileForReadAsStream(char const* filename) {
37 | return OpenFileForReadAsStream(std::string(filename));
38 | }
39 |
40 | inline std::ofstream OpenFileForWriteAsStream(std::filesystem::path const& filepath) {
41 | auto path = std::filesystem::weakly_canonical(filepath);
42 | TREELITE_CHECK(path.has_filename()) << "Cannot write to a directory; please specify a file";
43 | TREELITE_CHECK(std::filesystem::exists(path.parent_path()))
44 | << "Path " << path.parent_path() << " does not exist";
45 | std::ofstream ofs(path, std::ios::out | std::ios::binary);
46 | TREELITE_CHECK(ofs) << "Could not open file " << filepath;
47 | ofs.exceptions(std::ios::failbit | std::ios::badbit); // Throw exceptions on error
48 | return ofs;
49 | }
50 |
/*! \brief Overload taking the file name as a UTF-8 encoded string. */
inline std::ofstream OpenFileForWriteAsStream(std::string const& filename) {
  std::filesystem::path const as_path = std::filesystem::u8path(filename);
  return OpenFileForWriteAsStream(as_path);
}
54 |
55 | inline std::ofstream OpenFileForWriteAsStream(char const* filename) {
56 | return OpenFileForWriteAsStream(std::string(filename));
57 | }
58 |
59 | inline FILE* OpenFileForReadAsFilePtr(std::filesystem::path const& filepath) {
60 | auto path = std::filesystem::weakly_canonical(filepath);
61 | TREELITE_CHECK(std::filesystem::exists(path)) << "Path " << filepath << " does not exist";
62 | FILE* fp;
63 | #ifdef _WIN32
64 | fp = _wfopen(path.wstring().c_str(), L"rb");
65 | #else
66 | fp = std::fopen(path.string().c_str(), "rb");
67 | #endif
68 | TREELITE_CHECK(fp) << "Could not open file " << filepath;
69 | return fp;
70 | }
71 |
/*! \brief Overload taking the file name as a UTF-8 encoded string. */
inline FILE* OpenFileForReadAsFilePtr(std::string const& filename) {
  std::filesystem::path const as_path = std::filesystem::u8path(filename);
  return OpenFileForReadAsFilePtr(as_path);
}
75 |
76 | inline FILE* OpenFileForReadAsFilePtr(char const* filename) {
77 | return OpenFileForReadAsFilePtr(std::string(filename));
78 | }
79 |
80 | } // namespace treelite::detail
81 |
82 | #endif // TREELITE_DETAIL_FILE_UTILS_H_
83 |
--------------------------------------------------------------------------------
/tests/python/metadata.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Metadata for datasets and models used for testing"""
3 |
4 | import collections
5 | import os
6 |
# Directory holding this module, and the sibling "examples" directory one level up
_current_dir = os.path.dirname(__file__)
_dpath = os.path.abspath(os.path.join(_current_dir, os.path.pardir, "examples"))

# Record describing one example: the model file plus its data and expected-output files
Dataset = collections.namedtuple(
    "Dataset",
    [
        "model",
        "format",
        "dtrain",
        "dtest",
        "libname",
        "expected_prob",
        "expected_margin",
        "is_multiclass",
        "dtype",
    ],
)
14 |
# Raw table of example datasets, keyed by dataset name. File entries are bare file names
# here; None marks a file that the dataset does not provide. The public ``dataset_db``
# further down turns the file names into full paths.
_dataset_db = {
    "mushroom": Dataset(
        model="mushroom.model",
        format="xgboost",
        dtrain="agaricus.train",
        dtest="agaricus.test",
        libname="agaricus",
        expected_prob="agaricus.test.prob",
        expected_margin="agaricus.test.margin",
        is_multiclass=False,
        dtype="float32",
    ),
    "dermatology": Dataset(
        model="dermatology.model",
        format="xgboost",
        dtrain="dermatology.train",
        dtest="dermatology.test",
        libname="dermatology",
        expected_prob="dermatology.test.prob",
        expected_margin="dermatology.test.margin",
        is_multiclass=True,
        dtype="float32",
    ),
    "letor": Dataset(
        model="mq2008.model",
        format="xgboost",
        dtrain="mq2008.train",
        dtest="mq2008.test",
        libname="letor",
        expected_prob=None,
        expected_margin="mq2008.test.pred",
        is_multiclass=False,
        dtype="float32",
    ),
    "toy_categorical": Dataset(
        model="toy_categorical_model.txt",
        format="lightgbm",
        dtrain=None,
        dtest="toy_categorical.test",
        libname="toycat",
        expected_prob=None,
        expected_margin="toy_categorical.test.pred",
        is_multiclass=False,
        dtype="float64",
    ),
    "sparse_categorical": Dataset(
        model="sparse_categorical_model.txt",
        format="lightgbm",
        dtrain=None,
        dtest="sparse_categorical.test",
        libname="sparsecat",
        expected_prob=None,
        expected_margin="sparse_categorical.test.margin",
        is_multiclass=False,
        dtype="float64",
    ),
    "xgb_toy_categorical": Dataset(
        model="xgb_toy_categorical_model.json",
        format="xgboost_json",
        dtrain=None,
        dtest="xgb_toy_categorical.test",
        libname="xgbtoycat",
        expected_prob=None,
        expected_margin="xgb_toy_categorical.test.pred",
        is_multiclass=False,
        dtype="float32",
    ),
}
83 |
84 |
def _qualify_path(prefix, path):
    """Return ``path`` located under the ``prefix`` example directory; None stays None"""
    return None if path is None else os.path.join(_dpath, prefix, path)
89 |
90 |
# Public table: same entries as _dataset_db, with every file field expanded to a path inside
# the example directory named after the dataset key (None entries stay None).
dataset_db = {
    name: entry._replace(
        **{
            field: _qualify_path(name, getattr(entry, field))
            for field in ("model", "dtrain", "dtest", "expected_prob", "expected_margin")
        }
    )
    for name, entry in _dataset_db.items()
}
101 |
--------------------------------------------------------------------------------
/docs/tutorials/import.rst:
--------------------------------------------------------------------------------
1 | Importing tree ensemble models
2 | ==============================
3 |
4 | Since the scope of Treelite is limited to **prediction** only, one must use
5 | other machine learning packages to **train** decision tree ensemble models. In
6 | this document, we will show how to import an ensemble model that had been
7 | trained elsewhere.
8 |
9 | .. contents:: Contents
10 | :local:
11 |
12 | Importing XGBoost models
13 | ------------------------
14 |
15 | **XGBoost** (`dmlc/xgboost `_) is a fast,
16 | scalable package for gradient boosting. Both Treelite and XGBoost are hosted
17 | by the DMLC (Distributed Machine Learning Community) group.
18 |
19 | Treelite plays well with XGBoost --- if you used XGBoost to train your ensemble
20 | model, you need only one line of code to import it. Depending on where your
21 | model is located, use :py:meth:`~treelite.frontend.from_xgboost`,
22 | :py:meth:`~treelite.frontend.load_xgboost_model`, or
23 | :py:meth:`~treelite.frontend.load_xgboost_model_legacy_binary`:
24 |
25 | * Load XGBoost model from a :py:class:`xgboost.Booster` object
26 |
27 | .. code-block:: python
28 |
29 | # bst = an object of type xgboost.Booster
30 | model = treelite.frontend.from_xgboost(bst)
31 |
32 | * Load XGBoost model from a model file
33 |
34 | .. code-block:: python
35 |
36 | # JSON format
37 | model = treelite.frontend.load_xgboost_model("my_model.json")
38 | # Legacy binary format
39 | model = treelite.frontend.load_xgboost_model_legacy_binary("my_model.model")
40 |
41 | Importing LightGBM models
42 | -------------------------
43 |
44 | **LightGBM** (`Microsoft/LightGBM `_) is
45 | another well known machine learning package for gradient boosting. To import
46 | models generated by LightGBM, use the
47 | :py:meth:`~treelite.frontend.load_lightgbm_model` method:
48 |
49 | .. code-block:: python
50 |
51 | model = treelite.frontend.load_lightgbm_model("lightgbm_model.txt")
52 |
53 | Importing scikit-learn models
54 | -----------------------------
55 | **Scikit-learn** (`scikit-learn/scikit-learn
56 | `_) is a Python machine learning
57 | package known for its versatility and ease of use. It supports a wide variety
58 | of models and algorithms. The following kinds of models can be imported into
59 | Treelite.
60 |
61 | * :py:class:`sklearn.ensemble.RandomForestRegressor`
62 | * :py:class:`sklearn.ensemble.RandomForestClassifier`
63 | * :py:class:`sklearn.ensemble.ExtraTreesRegressor`
64 | * :py:class:`sklearn.ensemble.ExtraTreesClassifier`
65 | * :py:class:`sklearn.ensemble.GradientBoostingRegressor`
66 | * :py:class:`sklearn.ensemble.GradientBoostingClassifier`
67 | * :py:class:`sklearn.ensemble.HistGradientBoostingRegressor`
68 | * :py:class:`sklearn.ensemble.HistGradientBoostingClassifier`
69 | * :py:class:`sklearn.ensemble.IsolationForest`
70 |
71 | To import scikit-learn models, use
72 | :py:meth:`treelite.sklearn.import_model`:
73 |
74 | .. code-block:: python
75 |
76 | # clf is the model object generated by scikit-learn
77 | import treelite.sklearn
78 | model = treelite.sklearn.import_model(clf)
79 |
80 | How about other packages?
81 | -------------------------
82 | If you used other packages to train your ensemble model, you'd need to specify
83 | the model programmatically:
84 |
85 | * :doc:`/tutorials/builder`
86 |
--------------------------------------------------------------------------------
/src/c_api/gtil.cc:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2023 by Contributors
3 | * \file gtil.cc
4 | * \author Hyunsu Cho
5 | * \brief C API for functions for GTIL
6 | */
7 |
8 | #include
9 | #include
10 | #include
11 |
12 | #include
13 | #include
14 | #include
15 | #include
16 |
17 | #include "./c_api_utils.h"
18 |
19 | int TreeliteGTILParseConfig(char const* config_json, TreeliteGTILConfigHandle* out) {  // Build a config object from the string `config_json` and hand it back through `*out` as an opaque handle.
20 | API_BEGIN();
21 | auto parsed_config = std::make_unique(config_json);  // held by a unique_ptr until released below. NOTE(review): the make_unique template argument is missing in this copy (likely lost in extraction) -- restore from upstream.
22 | *out = static_cast(parsed_config.release());  // release() gives up ownership, so the object is not freed here; TreeliteGTILDeleteConfig in this file deletes such a handle.
23 | API_END();  // NOTE(review): no explicit return statement; presumably API_END() supplies the int status -- confirm in the macro definition.
24 | }
25 |
26 | int TreeliteGTILDeleteConfig(TreeliteGTILConfigHandle handle) {  // Destroy the object referred to by the opaque config handle `handle`.
27 | API_BEGIN();
28 | delete static_cast(handle);  // cast the handle back to a typed pointer, then free it. NOTE(review): the cast's target type is missing in this copy (likely lost in extraction) -- restore from upstream.
29 | API_END();  // NOTE(review): no explicit return statement; presumably API_END() supplies the int status -- confirm in the macro definition.
30 | }
31 |
32 | int TreeliteGTILGetOutputShape(TreeliteModelHandle model, std::uint64_t num_row,  // Query treelite::gtil::GetOutputShape for `model`, `num_row` and `config`, and expose the result as a raw array.
33 | TreeliteGTILConfigHandle config, std::uint64_t const** out, std::uint64_t* out_ndim) {  // `*out` receives a pointer to the shape entries, `*out_ndim` the number of entries.
34 | API_BEGIN();
35 | auto const* model_ = static_cast(model);  // opaque handles are cast back to typed pointers. NOTE(review): cast target types are missing in this copy (likely lost in extraction) -- restore from upstream.
36 | auto const* config_ = static_cast(config);
37 | auto& shape = treelite::c_api::ReturnValueStore::Get()->ret_uint64_vec;  // a reference to the vector held by ReturnValueStore, not a local copy
38 | shape = treelite::gtil::GetOutputShape(*model_, num_row, *config_);  // overwrites the stored vector with the new shape
39 | *out = shape.data();  // points into the stored vector; presumably valid only until that vector is next overwritten -- TODO confirm lifetime
40 | *out_ndim = shape.size();
41 | API_END();  // NOTE(review): no explicit return statement; presumably API_END() supplies the int status -- confirm in the macro definition.
42 | }
43 |
44 | int TreeliteGTILPredict(TreeliteModelHandle model, void const* input, char const* input_type,  // Call treelite::gtil::Predict on `input`, choosing the branch from the type name in `input_type`.
45 | std::uint64_t num_row, void* output, TreeliteGTILConfigHandle config) {  // accepted `input_type` values are "float32" and "float64"; results are written through `output`.
46 | API_BEGIN();
47 | auto const* model_ = static_cast(model);  // opaque handles are cast back to typed pointers. NOTE(review): cast target types are missing in this copy (likely lost in extraction) -- restore from upstream.
48 | auto const* config_ = static_cast(config);
49 | std::string input_type_str = std::string(input_type);  // std::string so that == below compares contents rather than pointers
50 | if (input_type_str == "float32") {
51 | treelite::gtil::Predict(  // both `input` and `output` are cast for this branch; presumably to float pointers -- confirm against upstream
52 | *model_, static_cast(input), num_row, static_cast(output), *config_);
53 | } else if (input_type_str == "float64") {
54 | treelite::gtil::Predict(*model_, static_cast(input), num_row,  // same call as above with this branch's casts; presumably double pointers -- confirm against upstream
55 | static_cast(output), *config_);
56 | } else {
57 | TREELITE_LOG(FATAL) << "Unexpected type spec: " << input_type_str;  // any other type name is reported via a FATAL log
58 | }
59 | API_END();  // NOTE(review): no explicit return statement; presumably API_END() supplies the int status -- confirm in the macro definition.
60 | }
61 |
62 | int TreeliteGTILPredictSparse(TreeliteModelHandle model, void const* data, char const* input_type,
63 | std::uint64_t const* col_ind, std::uint64_t const* row_ptr, std::uint64_t num_row, void* output,
64 | TreeliteGTILConfigHandle config) {
65 | API_BEGIN();
66 | auto const* model_ = static_cast(model);
67 | auto const* config_ = static_cast(config);
68 | std::string input_type_str = std::string(input_type);
69 | if (input_type_str == "float32") {
70 | treelite::gtil::PredictSparse(*model_, static_cast(data), col_ind, row_ptr,
71 | num_row, static_cast(output), *config_);
72 | } else if (input_type_str == "float64") {
73 | treelite::gtil::PredictSparse(*model_, static_cast