├── .formatter.exs ├── .github └── workflows │ └── precompile.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── c └── exgboost │ ├── include │ ├── booster.h │ ├── config.h │ ├── dmatrix.h │ ├── exgboost.h │ └── utils.h │ └── src │ ├── booster.c │ ├── config.c │ ├── dmatrix.c │ ├── exgboost.c │ └── utils.c ├── lib ├── exgboost.ex └── exgboost │ ├── application.ex │ ├── array_interface.ex │ ├── booster.ex │ ├── dmatrix.ex │ ├── internal.ex │ ├── nif.ex │ ├── parameters.ex │ ├── plotting.ex │ ├── plotting │ ├── style.ex │ └── styles.ex │ ├── training.ex │ └── training │ ├── callback.ex │ └── state.ex ├── mix.exs ├── mix.lock ├── notebooks ├── compiled_benchmarks.livemd ├── iris_classification.livemd ├── plotting.livemd └── quantile_prediction_interval.livemd └── test ├── data ├── another.txt ├── model.json ├── test.conf ├── testfile.txt └── train.txt ├── exgboost_test.exs ├── model.json ├── nif_test.exs ├── parameter_test.exs └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/workflows/precompile.yml: -------------------------------------------------------------------------------- 1 | name: precompile 2 | 3 | on: 4 | - push 5 | - workflow_dispatch 6 | 7 | jobs: 8 | linux: 9 | name: Linux Erlang/OTP ${{matrix.otp}} / Elixir ${{matrix.elixir}} 10 | runs-on: ubuntu-20.04 11 | env: 12 | MIX_ENV: "prod" 13 | strategy: 14 | matrix: 15 | # Elixir 1.14.5 is first version compatible with OTP 26 16 | # NIF versions change according to 17 | # https://github.com/erlang/otp/blob/dd57c853a324a9572a9e5ce227d8675ff004c6fe/erts/emulator/beam/erl_nif.h#L33 18 | otp: ["25.0", "26.0"] 19 | elixir: ["1.14.5"] 20 | steps: 21 | - uses: actions/checkout@v3 22 | - uses: erlef/setup-beam@v1 23 | with: 24 | 
otp-version: ${{matrix.otp}} 25 | elixir-version: ${{matrix.elixir}} 26 | - name: Install system dependencies 27 | run: | 28 | sudo apt-get update 29 | sudo apt-get install -y build-essential automake autoconf pkg-config bc m4 unzip zip \ 30 | gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ 31 | gcc-riscv64-linux-gnu g++-riscv64-linux-gnu 32 | - name: Mix Test 33 | run: | 34 | mix deps.get 35 | MIX_ENV=test mix test 36 | - name: Create precompiled library 37 | run: | 38 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 39 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 40 | mix elixir_make.precompile 41 | - uses: softprops/action-gh-release@v1 42 | if: startsWith(github.ref, 'refs/tags/') 43 | with: 44 | files: | 45 | cache/*.tar.gz 46 | 47 | macos: 48 | runs-on: ${{matrix.runner}} 49 | # Homebrew supports versioned Erlang/OTP but not Elixir 50 | # It's a deliberate design decision from Homebrew to 51 | # only support versioned distributions for certain packages 52 | name: Mac (${{ matrix.runner == 'macos-13' && 'Intel' || 'ARM' }}) Erlang/OTP ${{matrix.otp}} / Elixir 53 | env: 54 | MIX_ENV: "prod" 55 | strategy: 56 | matrix: 57 | runner: ["macos-13", "macos-14"] 58 | otp: ["25.0", "26.0"] 59 | elixir: ["1.14.5"] 60 | steps: 61 | - uses: actions/checkout@v3 62 | - uses: asdf-vm/actions/install@v2 63 | with: 64 | tool_versions: | 65 | erlang ${{matrix.otp}} 66 | elixir ${{matrix.elixir}} 67 | - name: Install libomp 68 | run: | 69 | brew install libomp 70 | mix local.hex --force 71 | mix local.rebar --force 72 | - name: Mix Test 73 | run: | 74 | mix deps.get 75 | MIX_ENV=test mix test 76 | - name: Create precompiled library 77 | run: | 78 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 79 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 80 | mix elixir_make.precompile 81 | - uses: softprops/action-gh-release@v1 82 | if: startsWith(github.ref, 'refs/tags/') 83 | with: 84 | files: | 85 | ${{ matrix.runner == 'macos-13' && 'cache/*x86_64*.tar.gz' || 'cache/*aarch64*.tar.gz' }} 86 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /cache 2 | /_build 3 | /cover 4 | /deps 5 | /doc 6 | /.fetch 7 | erl_crash.dump 8 | *.ez 9 | *.beam 10 | /config/*.secret.exs 11 | .elixir_ls/ 12 | .tool-versions 13 | .vscode/ 14 | checksum.exs 15 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Environment variables passed via elixir_make 2 | # ERTS_INCLUDE_DIR 3 | # MIX_APP_PATH 4 | 5 | TEMP ?= $(HOME)/.cache 6 | XGBOOST_CACHE ?= $(TEMP)/exgboost 7 | XGBOOST_GIT_REPO ?= https://github.com/dmlc/xgboost.git 8 | # 2.0.2 Release Commit 9 | XGBOOST_GIT_REV ?= 41ce8f28b269dbb7efc70e3a120af3c0bb85efe3 10 | XGBOOST_NS = xgboost-$(XGBOOST_GIT_REV) 11 | XGBOOST_DIR = $(XGBOOST_CACHE)/$(XGBOOST_NS) 12 | XGBOOST_LIB_DIR = $(XGBOOST_DIR)/build/xgboost 13 | XGBOOST_LIB_DIR_FLAG = $(XGBOOST_LIB_DIR)/exgboost.ok 14 | 15 | # Private configuration 16 | PRIV_DIR = $(MIX_APP_PATH)/priv 17 | EXGBOOST_DIR = $(realpath c/exgboost) 18 | EXGBOOST_CACHE_SO = cache/libexgboost.so 19 | EXGBOOST_CACHE_LIB_DIR = cache/lib 20 | EXGBOOST_SO = $(PRIV_DIR)/libexgboost.so 21 | EXGBOOST_LIB_DIR = $(PRIV_DIR)/lib 22 | 23 | # Build flags 24 | CFLAGS = -I$(EXGBOOST_DIR)/include -I$(XGBOOST_LIB_DIR)/include -I$(XGBOOST_DIR) -I$(ERTS_INCLUDE_DIR) -fPIC -O3 --verbose -shared -std=c11 25 | 26 | C_SRCS = $(wildcard $(EXGBOOST_DIR)/src/*.c) $(wildcard $(EXGBOOST_DIR)/include/*.h) 27 | 28 | LDFLAGS = -L$(EXGBOOST_CACHE_LIB_DIR) -lxgboost 29 | 30 | ifeq ($(shell uname -s), Darwin) 31 | POST_INSTALL = install_name_tool $(EXGBOOST_CACHE_SO) -change @rpath/libxgboost.dylib @loader_path/lib/libxgboost.dylib 32 | LDFLAGS += -flat_namespace -undefined suppress 33 | LIBXGBOOST = libxgboost.dylib 34 | ifeq ($(USE_LLVM_BREW), true) 35 | LLVM_PREFIX=$(shell brew --prefix llvm) 36 | CMAKE_FLAGS += -DCMAKE_CXX_COMPILER=$(LLVM_PREFIX)/bin/clang++ 37 | endif 38 | else 39 | LIBXGBOOST = libxgboost.so 40 | LDFLAGS += -Wl,-rpath,'$$ORIGIN/lib' 41 | LDFLAGS += -Wl,--allow-multiple-definition 42 | POST_INSTALL = $(NOOP) 43 | endif 44 | 45 | $(EXGBOOST_SO): $(EXGBOOST_CACHE_SO) 46 | @ mkdir -p 
$(PRIV_DIR) 47 | cp -a $(abspath $(EXGBOOST_CACHE_LIB_DIR)) $(EXGBOOST_LIB_DIR) ; \ 48 | cp -a $(abspath $(EXGBOOST_CACHE_SO)) $(EXGBOOST_SO) ; 49 | 50 | $(EXGBOOST_CACHE_SO): $(XGBOOST_LIB_DIR_FLAG) $(C_SRCS) 51 | @mkdir -p cache 52 | cp -a $(XGBOOST_LIB_DIR) $(EXGBOOST_CACHE_LIB_DIR) 53 | mv $(XGBOOST_LIB_DIR)/lib/$(LIBXGBOOST) $(EXGBOOST_CACHE_LIB_DIR) 54 | $(CC) $(CFLAGS) $(wildcard $(EXGBOOST_DIR)/src/*.c) $(LDFLAGS) -o $(EXGBOOST_CACHE_SO) 55 | $(POST_INSTALL) 56 | 57 | $(XGBOOST_LIB_DIR_FLAG): 58 | rm -rf $(XGBOOST_DIR) && \ 59 | mkdir -p $(XGBOOST_DIR) && \ 60 | cd $(XGBOOST_DIR) && \ 61 | git init && \ 62 | git remote add origin $(XGBOOST_GIT_REPO) && \ 63 | git fetch --depth 1 --recurse-submodules origin $(XGBOOST_GIT_REV) && \ 64 | git checkout FETCH_HEAD && \ 65 | git submodule update --init --recursive && \ 66 | sed 's|learner_parameters\["generic_param"\] = ToJson(ctx_);|&\nlearner_parameters\["default_metric"\] = String(obj_->DefaultEvalMetric());|' src/learner.cc > src/learner.cc.tmp && mv src/learner.cc.tmp src/learner.cc && \ 67 | cmake -DCMAKE_INSTALL_PREFIX=$(XGBOOST_LIB_DIR) -B build . 
$(CMAKE_FLAGS) && \ 68 | make -C build -j1 install 69 | touch $(XGBOOST_LIB_DIR_FLAG) 70 | 71 | clean: 72 | rm -rf $(EXGBOOST_CACHE_SO) 73 | rm -rf $(EXGBOOST_CACHE_LIB_DIR) 74 | rm -rf $(EXGBOOST_SO) 75 | rm -rf $(EXGBOOST_LIB_DIR) 76 | rm -rf $(XGBOOST_DIR) 77 | rm -rf $(XGBOOST_LIB_DIR_FLAG) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EXGBoost 2 | 3 | [![EXGBoost version](https://img.shields.io/hexpm/v/exgboost.svg)](https://hex.pm/packages/exgboost) 4 | [![Hex Docs](https://img.shields.io/badge/hex-docs-lightgreen.svg)](https://hexdocs.pm/exgboost/) 5 | [![Hex Downloads](https://img.shields.io/hexpm/dt/exgboost)](https://hex.pm/packages/exgboost) 6 | [![Twitter Follow](https://img.shields.io/twitter/follow/ac_alejos?style=social)](https://twitter.com/ac_alejos) 7 | 8 | Elixir bindings to the [XGBoost C API](https://xgboost.readthedocs.io/en/latest/c.html) using [Native Implemented Functions (NIFs)](https://www.erlang.org/doc/man/erl_nif.html). 9 | 10 | `EXGBoost` provides an implementation of XGBoost that works with 11 | [Nx](https://hexdocs.pm/nx/Nx.html) tensors. 12 | 13 | Xtreme Gradient Boosting (XGBoost) is an optimized distributed gradient 14 | boosting library designed to be highly efficient, flexible and portable. 15 | It implements machine learning algorithms under the [Gradient Boosting](https://en.wikipedia.org/wiki/Gradient_boosting) 16 | framework. XGBoost provides a parallel tree boosting (also known as GBDT, GBM) 17 | that solve many data science problems in a fast and accurate way. The same code 18 | runs on major distributed environment (Hadoop, SGE, MPI) and can solve problems beyond 19 | billions of examples. 
20 | 21 | ## Installation 22 | 23 | ```elixir 24 | def deps do 25 | [ 26 | {:exgboost, "~> 0.5"} 27 | ] 28 | end 29 | ``` 30 | 31 | ## API Data Structures 32 | 33 | EXGBoost's top-level `EXGBoost` API works directly and only with `Nx` tensors. However, under the hood, 34 | it leverages the structs defined in the `EXGBoost.Booster` and `EXGBoost.DMatrix` modules. These structs 35 | are wrappers around the structs defined in the XGBoost library. 36 | The two main structs used are [DMatrix](https://xgboost.readthedocs.io/en/latest/c.html#dmatrix) 37 | to represent the data matrix that will be used 38 | to train the model, and [Booster](https://xgboost.readthedocs.io/en/latest/c.html#booster) 39 | which represents the model. 40 | 41 | The top-level `EXGBoost` API does not expose the structs directly. Instead, the 42 | structs are exposed through the `EXGBoost.Booster` and `EXGBoost.DMatrix` modules. Power users 43 | might wish to use these modules directly. For example, if you wish to use the `Booster` struct 44 | directly then you can use the `EXGBoost.Booster.booster/2` function to create a `Booster` struct 45 | from a `DMatrix` and a keyword list of options. See the `EXGBoost.Booster` and `EXGBoost.DMatrix` 46 | modules source for more implementation details. 47 | 48 | ## Basic Usage 49 | 50 | ```elixir 51 | key = Nx.Random.key(42) 52 | {x, key} = Nx.Random.normal(key, 0, 1, shape: {10, 5}) 53 | {y, key} = Nx.Random.normal(key, 0, 1, shape: {10}) 54 | model = EXGBoost.train(x, y) 55 | EXGBoost.predict(model, x) 56 | ``` 57 | 58 | ## Training 59 | 60 | EXGBoost is designed to feel familiar to the users of the Python XGBoost library. `EXGBoost.train/2` is the 61 | primary entry point for training a model. It accepts a Nx tensor for the features and a Nx tensor for the labels. 62 | `EXGBoost.train/2` returns a trained`Booster` struct that can be used for prediction. 
`EXGBoost.train/2` also 63 | accepts a keyword list of options that can be used to configure the training process. See the 64 | [XGBoost documentation](https://xgboost.readthedocs.io/en/latest/parameter.html) for the full list of options. 65 | 66 | `EXGBoost.train/2` uses the `EXGBoost.Training.train/1` function to perform the actual training. `EXGBoost.Training.train/1` 67 | can be used directly if you wish to work directly with the `DMatrix` and `Booster` structs. 68 | 69 | One of the main features of `EXGBoost.train/2` is the ability for the end user to provide a custom training function 70 | that will be used to train the model. This is done by passing a function to the `:obj` option. The function must 71 | accept a `DMatrix` and a `Booster` and return a `Booster`. The function will be called at each iteration of the 72 | training process. This allows the user to implement custom training logic. For example, the user could implement 73 | a custom loss function or a custom metric function. See the [XGBoost documentation](https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html) 74 | for more information on custom loss functions and custom metric functions. 75 | 76 | Another feature of `EXGBoost.train/2` is the ability to provide a validation set for early stopping. This is done 77 | by passing a list of 3-tuples to the `:evals` option. Each 3-tuple should contain a Nx tensor for the features, a Nx tensor 78 | for the labels, and a string label for the validation set name. The validation set will be used to calculate the validation 79 | error at each iteration of the training process. If the validation error does not improve for `:early_stopping_rounds` iterations 80 | then the training process will stop. See the [XGBoost documentation](https://xgboost.readthedocs.io/en/latest/tutorials/param_tuning.html) 81 | for a more detailed explanation of early stopping. 82 | 83 | Early stopping is achieved through the use of callbacks. 
`EXGBoost.train/2` accepts a list of callbacks that will be called 84 | at each iteration of the training process. The callbacks can be used to implement custom logic. For example, the user could 85 | implement a callback that will print the validation error at each iteration of the training process or to provide a custom 86 | setup function for training. See the `EXGBoost.Training.Callback` module for more information on callbacks. 87 | 88 | Please note that callbacks are called in the order that they are provided. If you provide multiple callbacks that modify 89 | the same parameter then the last callback will trump the previous callbacks. For example, if you provide a callback that 90 | sets the `:early_stopping_rounds` parameter to 10 and then provide a callback that sets the `:early_stopping_rounds` parameter 91 | to 20 then the `:early_stopping_rounds` parameter will be set to 20. 92 | 93 | You are also able to pass parameters to be applied to the Booster model using the `:params` option. These parameters will 94 | be applied to the Booster model before training begins. This allows you to set parameters that are not available as options 95 | to `EXGBoost.train/2`. See the [XGBoost documentation](https://xgboost.readthedocs.io/en/latest/parameter.html) for a full 96 | list of parameters. 97 | 98 | ```elixir 99 | EXGBoost.train(X, 100 | y, 101 | obj: &EXGBoost.Training.train/1, 102 | evals: [{X_test, y_test, "test"}], 103 | learning_rates: fn i -> i/10 end, 104 | num_boost_round: 10, 105 | early_stopping_rounds: 3, 106 | max_depth: 3, 107 | eval_metric: [:rmse,:logloss] 108 | ) 109 | ``` 110 | 111 | ## Prediction 112 | 113 | `EXGBoost.predict/2` is the primary entry point for making predictions with a trained model. 114 | It accepts a `Booster` struct (which is the output of `EXGBoost.train/2`). 115 | `EXGBoost.predict/2` returns a Nx tensor containing the predictions. 
116 | `EXGBoost.predict/2` also accepts a keyword list of options that can be used to configure the prediction process. 117 | 118 | ```elixir 119 | preds = EXGBoost.train(X, y) |> EXGBoost.predict(X) 120 | ``` 121 | 122 | ## Serialization 123 | 124 | A Booster can be serialized to a file using `EXGBoost.write_*` and loaded from a file 125 | using `EXGBoost.read_*`. The file format can be specified using the `:format` option 126 | which can be either `:json` or `:ubj`. The default is `:json`. If the file already exists, it will NOT 127 | be overwritten by default. Boosters can either be serialized to a file or to a binary string. 128 | Boosters can be serialized in three different ways: configuration only, configuration and model, or 129 | model only. `dump` functions will serialize the Booster to a binary string. 130 | Functions named with `weights` will serialize the model's trained parameters only. This is best used when the model 131 | is already trained and only inferences/predictions are going to be performed. Functions named with `config` will 132 | serialize the configuration only. Functions that specify `model` will serialize both the model parameters 133 | and the configuration. 134 | 135 | ### Output Formats 136 | 137 | - `read`/`write` - File. 138 | - `load`/`dump` - Binary buffer. 139 | 140 | ### Output Contents 141 | 142 | - `config` - Save the configuration only. 143 | - `weights` - Save the model parameters only. Use this when you want to save the model to a format that can be ingested by other XGBoost APIs. 144 | - `model` - Save both the model parameters and the configuration. 145 | 146 | ## Plotting 147 | 148 | `EXGBoost.plot_tree/2` is the primary entry point for plotting a tree from a trained model. 149 | It accepts an `EXGBoost.Booster` struct (which is the output of `EXGBoost.train/2`). 150 | `EXGBoost.plot_tree/2` returns a VegaLite spec that can be rendered in a notebook or saved to a file. 
151 | `EXGBoost.plot_tree/2` also accepts a keyword list of options that can be used to configure the plotting process. 152 | 153 | See `EXGBoost.Plotting` for more detail on plotting. 154 | 155 | You can see available styles by running `EXGBoost.Plotting.get_styles()` or refer to the `EXGBoost.Plotting.Styles` 156 | documentation for a gallery of the styles. 157 | 158 | ## Kino & Livebook Integration 159 | 160 | `EXGBoost` integrates with [Kino](https://hexdocs.pm/kino/Kino.html) and [Livebook](https://livebook.dev/) 161 | to provide a rich interactive experience for data scientists. 162 | 163 | EXGBoost implements the `Kino.Render` protocol for `EXGBoost.Booster` structs. This allows you to render 164 | a Booster in a Livebook notebook. Under the hood, `EXGBoost` uses [Vega-Lite](https://vega.github.io/vega-lite/) 165 | and [Kino Vega-Lite](https://hexdocs.pm/kino_vega_lite/Kino.VegaLite.html) to render the Booster. 166 | 167 | See the [`Plotting in EXGBoost`](notebooks/plotting.livemd) Notebook for an example of how to use `EXGBoost` with `Kino` and `Livebook`. 168 | 169 | ## Examples 170 | 171 | See the example Notebooks in the left sidebar (under the `Pages` tab) for more examples and tutorials 172 | on how to use EXGBoost. 
173 | 174 | ## Requirements 175 | 176 | ### Precompiled Distribution 177 | 178 | We currently offer the following precompiled packages for EXGBoost: 179 | 180 | ```elixir 181 | %{ 182 | "exgboost-nif-2.16-aarch64-apple-darwin-0.5.0.tar.gz" => "sha256:c659d086d07e9c209bdffbbf982951c6109b2097c4d3008ef9af59c3050663d2", 183 | "exgboost-nif-2.16-x86_64-apple-darwin-0.5.0.tar.gz" => "sha256:05256238700456c57e279558765b54b5b5ed4147878c6861cd4c937472abbe52", 184 | "exgboost-nif-2.16-x86_64-linux-gnu-0.5.0.tar.gz" => "sha256:ad3ba6aba8c3c2821dce4afc05b66a5e529764e0cea092c5a90e826446653d99", 185 | "exgboost-nif-2.17-aarch64-apple-darwin-0.5.0.tar.gz" => "sha256:745e7e970316b569a10d76ceb711b9189360b3bf9ab5ee6133747f4355f45483", 186 | "exgboost-nif-2.17-x86_64-apple-darwin-0.5.0.tar.gz" => "sha256:73948d6f2ef298e3ca3dceeca5d8a36a2d88d842827e1168c64589e4931af8d7", 187 | "exgboost-nif-2.17-x86_64-linux-gnu-0.5.0.tar.gz" => "sha256:a0b5ff0b074a9726c69d632b2dc0214fc7b66dccb4f5879e01255eeb7b9d4282", 188 | } 189 | ``` 190 | 191 | The correct package will be downloaded and installed (if supported) when you install 192 | the dependency through Mix (as shown above), otherwise you will need to compile 193 | manually. 194 | 195 | **NOTE** If MacOS, you still need to install `libomp` even to use the precompiled libraries: 196 | 197 | `brew install libomp` 198 | 199 | ### Dev Requirements 200 | 201 | If you are contributing to the library and need to compile locally or choose to not use the precompiled libraries, you will need the following: 202 | 203 | - Make 204 | - CMake 205 | - If MacOS: `brew install libomp` 206 | 207 | When you run `mix compile`, the `xgboost` shared library will be compiled, so the first time you compile your project will take longer than subsequent compilations. 208 | 209 | You also need to set `CC_PRECOMPILER_PRECOMPILE_ONLY_LOCAL=true` before the first local compilation, otherwise you will get an error related to a missing checksum file. 
210 | 211 | ## Known Limitations 212 | 213 | - The XGBoost C API uses C function pointers to implement streaming data types. The Python ctypes library is able to pass function pointers to the C API which are then executed by XGBoost. Erlang/Elixir NIFs do not have this capability, and as such, streaming data types are not supported in EXGBoost. 214 | 215 | ## Roadmap 216 | 217 | - [ ] CUDA support 218 | - [ ] [Collective API](https://xgboost.readthedocs.io/en/latest/c.html#collective)? 219 | 220 | ## License 221 | 222 | Licensed under an [Apache-2](https://github.com/acalejos/exgboost/blob/main/LICENSE) license. 223 | -------------------------------------------------------------------------------- /c/exgboost/include/booster.h: -------------------------------------------------------------------------------- 1 | #ifndef EXGBOOST_BOOSTER_H 2 | #define EXGBOOST_BOOSTER_H 3 | 4 | #include "utils.h" 5 | 6 | ERL_NIF_TERM EXGBoosterCreate(ErlNifEnv *env, int argc, 7 | const ERL_NIF_TERM argv[]); 8 | ERL_NIF_TERM EXGBoosterBoostedRounds(ErlNifEnv *env, int argc, 9 | const ERL_NIF_TERM argv[]); 10 | ERL_NIF_TERM EXGBoosterSlice(ErlNifEnv *env, int argc, 11 | const ERL_NIF_TERM argv[]); 12 | ERL_NIF_TERM EXGBoosterSetParam(ErlNifEnv *env, int argc, 13 | const ERL_NIF_TERM argv[]); 14 | ERL_NIF_TERM EXGBoosterGetNumFeature(ErlNifEnv *env, int argc, 15 | const ERL_NIF_TERM argv[]); 16 | ERL_NIF_TERM EXGBoosterUpdateOneIter(ErlNifEnv *env, int argc, 17 | const ERL_NIF_TERM argv[]); 18 | ERL_NIF_TERM EXGBoosterBoostOneIter(ErlNifEnv *env, int argc, 19 | const ERL_NIF_TERM argv[]); 20 | ERL_NIF_TERM EXGBoosterEvalOneIter(ErlNifEnv *env, int argc, 21 | const ERL_NIF_TERM argv[]); 22 | ERL_NIF_TERM EXGBoosterGetAttrNames(ErlNifEnv *env, int argc, 23 | const ERL_NIF_TERM argv[]); 24 | ERL_NIF_TERM EXGBoosterGetAttr(ErlNifEnv *env, int argc, 25 | const ERL_NIF_TERM argv[]); 26 | ERL_NIF_TERM EXGBoosterSetAttr(ErlNifEnv *env, int argc, 27 | const ERL_NIF_TERM argv[]); 28 | 
ERL_NIF_TERM EXGBoosterSetStrFeatureInfo(ErlNifEnv *env, int argc, 29 | const ERL_NIF_TERM argv[]); 30 | ERL_NIF_TERM EXGBoosterGetStrFeatureInfo(ErlNifEnv *env, int argc, 31 | const ERL_NIF_TERM argv[]); 32 | ERL_NIF_TERM EXGBoosterFeatureScore(ErlNifEnv *env, int argc, 33 | const ERL_NIF_TERM argv[]); 34 | ERL_NIF_TERM EXGBoosterPredictFromDMatrix(ErlNifEnv *env, int argc, 35 | const ERL_NIF_TERM argv[]); 36 | ERL_NIF_TERM EXGBoosterPredictFromDense(ErlNifEnv *env, int argc, 37 | const ERL_NIF_TERM argv[]); 38 | ERL_NIF_TERM EXGBoosterPredictFromCSR(ErlNifEnv *env, int argc, 39 | const ERL_NIF_TERM argv[]); 40 | ERL_NIF_TERM EXGBoosterLoadModel(ErlNifEnv *env, int argc, 41 | const ERL_NIF_TERM argv[]); 42 | ERL_NIF_TERM EXGBoosterSaveModel(ErlNifEnv *env, int argc, 43 | const ERL_NIF_TERM argv[]); 44 | ERL_NIF_TERM EXGBoosterSerializeToBuffer(ErlNifEnv *env, int argc, 45 | const ERL_NIF_TERM argv[]); 46 | ERL_NIF_TERM EXGBoosterDeserializeFromBuffer(ErlNifEnv *env, int argc, 47 | const ERL_NIF_TERM argv[]); 48 | ERL_NIF_TERM EXGBoosterLoadModelFromBuffer(ErlNifEnv *env, int argc, 49 | const ERL_NIF_TERM argv[]); 50 | ERL_NIF_TERM EXGBoosterSaveModelToBuffer(ErlNifEnv *env, int argc, 51 | const ERL_NIF_TERM argv[]); 52 | ERL_NIF_TERM EXGBoosterSaveJsonConfig(ErlNifEnv *env, int argc, 53 | const ERL_NIF_TERM argv[]); 54 | ERL_NIF_TERM EXGBoosterLoadJsonConfig(ErlNifEnv *env, int argc, 55 | const ERL_NIF_TERM argv[]); 56 | ERL_NIF_TERM EXGBoosterDumpModelEx(ErlNifEnv *env, int argc, 57 | const ERL_NIF_TERM argv[]); 58 | #endif 59 | -------------------------------------------------------------------------------- /c/exgboost/include/config.h: -------------------------------------------------------------------------------- 1 | #ifndef EXGBOOST_CONFIG_H 2 | #define EXGBOOST_CONFIG_H 3 | 4 | #include "utils.h" 5 | 6 | /** 7 | * @brief Return the version of the XGBoost library being currently used. 
8 | * 9 | * @param env 10 | * @param argc 11 | * @param argv 12 | * @return ERL_NIF_TERM as a 3-tuple of integers: {major, minor, patch} 13 | * 14 | */ 15 | ERL_NIF_TERM EXGBoostVersion(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); 16 | 17 | /** 18 | * @brief Return the build information of the XGBoost library being currently used. 19 | * 20 | * @param env 21 | * @param argc 22 | * @param argv 23 | * @return ERL_NIF_TERM String encoded JSON object containing build flags and dependency version. 24 | */ 25 | ERL_NIF_TERM EXGBuildInfo(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); 26 | 27 | /** 28 | * @brief Set global configuration (collection of parameters that apply globally). This function accepts the list of key-value pairs representing the global-scope parameters to be configured. The list of key-value pairs are passed in as a JSON string. 29 | * 30 | * @param env 31 | * @param argc 32 | * @param argv 33 | * @return ERL_NIF_TERM 0 on success, -1 on failure. 34 | */ 35 | ERL_NIF_TERM EXGBSetGlobalConfig(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); 36 | 37 | /** 38 | * @brief Get global configuration (collection of parameters that apply globally). This function returns the list of key-value pairs representing the global-scope parameters that are currently configured. The list of key-value pairs are returned as a JSON string. 39 | * 40 | * @param env 41 | * @param argc 42 | * @param argv 43 | * @return ERL_NIF_TERM string encoded JSON object containing the global-scope parameters that are currently configured. 
44 | */ 45 | ERL_NIF_TERM EXGBGetGlobalConfig(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); 46 | 47 | #endif -------------------------------------------------------------------------------- /c/exgboost/include/dmatrix.h: -------------------------------------------------------------------------------- 1 | #ifndef EXGBOOST_DMATRIX_H 2 | #define EXGBOOST_DMATRIX_H 3 | 4 | #include "utils.h" 5 | 6 | ERL_NIF_TERM EXGDMatrixCreateFromFile(ErlNifEnv *env, int argc, 7 | const ERL_NIF_TERM argv[]); 8 | 9 | ERL_NIF_TERM EXGDMatrixCreateFromURI(ErlNifEnv *env, int argc, 10 | const ERL_NIF_TERM argv[]); 11 | 12 | ERL_NIF_TERM EXGDMatrixCreateFromMat(ErlNifEnv *env, int argc, 13 | const ERL_NIF_TERM argv[]); 14 | 15 | ERL_NIF_TERM EXGDMatrixCreateFromSparse(ErlNifEnv *env, int argc, 16 | const ERL_NIF_TERM argv[]); 17 | 18 | ERL_NIF_TERM EXGDMatrixCreateFromDense(ErlNifEnv *env, int argc, 19 | const ERL_NIF_TERM argv[]); 20 | 21 | ERL_NIF_TERM EXGDMatrixGetStrFeatureInfo(ErlNifEnv *env, int argc, 22 | const ERL_NIF_TERM argv[]); 23 | 24 | ERL_NIF_TERM EXGDMatrixSetStrFeatureInfo(ErlNifEnv *env, int argc, 25 | const ERL_NIF_TERM argv[]); 26 | 27 | ERL_NIF_TERM EXGDMatrixSetDenseInfo(ErlNifEnv *env, int argc, 28 | const ERL_NIF_TERM argv[]); 29 | 30 | ERL_NIF_TERM EXGDMatrixNumRow(ErlNifEnv *env, int argc, 31 | const ERL_NIF_TERM argv[]); 32 | 33 | ERL_NIF_TERM EXGDMatrixNumCol(ErlNifEnv *env, int argc, 34 | const ERL_NIF_TERM argv[]); 35 | 36 | ERL_NIF_TERM EXGDMatrixNumNonMissing(ErlNifEnv *env, int argc, 37 | const ERL_NIF_TERM argv[]); 38 | 39 | ERL_NIF_TERM EXGDMatrixSetInfoFromInterface(ErlNifEnv *env, int argc, 40 | const ERL_NIF_TERM argv[]); 41 | ERL_NIF_TERM EXGDMatrixSaveBinary(ErlNifEnv *env, int argc, 42 | const ERL_NIF_TERM argv[]); 43 | ERL_NIF_TERM EXGDMatrixGetFloatInfo(ErlNifEnv *env, int argc, 44 | const ERL_NIF_TERM argv[]); 45 | ERL_NIF_TERM EXGDMatrixGetUIntInfo(ErlNifEnv *env, int argc, 46 | const ERL_NIF_TERM argv[]); 47 | ERL_NIF_TERM 
EXGDMatrixGetDataAsCSR(ErlNifEnv *env, int argc, 48 | const ERL_NIF_TERM argv[]); 49 | ERL_NIF_TERM EXGDMatrixSliceDMatrix(ErlNifEnv *env, int argc, 50 | const ERL_NIF_TERM argv[]); 51 | ERL_NIF_TERM EXGProxyDMatrixCreate(ErlNifEnv *env, int argc, 52 | const ERL_NIF_TERM argv[]); 53 | ERL_NIF_TERM EXGDMatrixGetQuantileCut(ErlNifEnv *env, int argc, 54 | const ERL_NIF_TERM argv[]); 55 | #endif -------------------------------------------------------------------------------- /c/exgboost/include/exgboost.h: -------------------------------------------------------------------------------- 1 | #ifndef EXGBOOST_H 2 | #define EXGBOOST_H 3 | 4 | #include "config.h" 5 | #include "dmatrix.h" 6 | #include "booster.h" 7 | 8 | #endif -------------------------------------------------------------------------------- /c/exgboost/include/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef EXGBOOST_UTILS_H 2 | #define EXGBOOST_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | ErlNifResourceType *DMatrix_RESOURCE_TYPE; 10 | ErlNifResourceType *Booster_RESOURCE_TYPE; 11 | typedef uint64_t bst_ulong; 12 | 13 | void DMatrix_RESOURCE_TYPE_cleanup(ErlNifEnv *env, void *arg); 14 | 15 | void Booster_RESOURCE_TYPE_cleanup(ErlNifEnv *env, void *arg); 16 | 17 | // Status helpers 18 | 19 | ERL_NIF_TERM exg_error(ErlNifEnv *env, const char *msg); 20 | 21 | ERL_NIF_TERM ok_atom(ErlNifEnv *env); 22 | 23 | ERL_NIF_TERM exg_ok(ErlNifEnv *env, ERL_NIF_TERM term); 24 | 25 | ERL_NIF_TERM exg_get_binary_address(ErlNifEnv *env, int argc, 26 | const ERL_NIF_TERM argv[]); 27 | 28 | ERL_NIF_TERM exg_get_binary_from_address(ErlNifEnv *env, int argc, 29 | const ERL_NIF_TERM argv[]); 30 | 31 | ERL_NIF_TERM exg_get_int_size(ErlNifEnv *env, int argc, 32 | const ERL_NIF_TERM argv[]); 33 | 34 | // Argument helpers 35 | 36 | int exg_get_string(ErlNifEnv *env, ERL_NIF_TERM term, char **var); 37 | 38 | int exg_get_list(ErlNifEnv *env, ERL_NIF_TERM 
term, double **out); 39 | 40 | int exg_get_string_list(ErlNifEnv *env, ERL_NIF_TERM term, char ***out, 41 | unsigned *len); 42 | int exg_get_dmatrix_list(ErlNifEnv *env, ERL_NIF_TERM term, 43 | DMatrixHandle **dmats, unsigned *len); 44 | 45 | #endif -------------------------------------------------------------------------------- /c/exgboost/src/config.c: -------------------------------------------------------------------------------- 1 | #include "booster.h" 2 | 3 | ERL_NIF_TERM EXGBoostVersion(ErlNifEnv *env, int argc, 4 | const ERL_NIF_TERM argv[]) { 5 | int major, minor, patch; 6 | XGBoostVersion(&major, &minor, &patch); 7 | return exg_ok(env, enif_make_tuple3(env, enif_make_int(env, major), 8 | enif_make_int(env, minor), 9 | enif_make_int(env, patch))); 10 | } 11 | 12 | ERL_NIF_TERM EXGBuildInfo(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { 13 | char const *out = NULL; 14 | int result = -1; 15 | ERL_NIF_TERM ret = 0; 16 | // Don't need to free this since it's a pointer to a static string defined in 17 | // the xgboost config struct 18 | // https://github.com/dmlc/xgboost/blob/21d95f3d8f23873a76f8afaad0fee5fa3e00eafe/src/c_api/c_api.cc#L107 19 | if (argc != 0) { 20 | ret = exg_error(env, "Wrong number of arguments"); 21 | goto END; 22 | } 23 | result = XGBuildInfo(&out); 24 | if (result == 0) { 25 | ret = exg_ok(env, enif_make_string(env, out, ERL_NIF_LATIN1)); 26 | } else { 27 | ret = exg_error(env, XGBGetLastError()); 28 | } 29 | END: 30 | return ret; 31 | } 32 | 33 | ERL_NIF_TERM EXGBSetGlobalConfig(ErlNifEnv *env, int argc, 34 | const ERL_NIF_TERM argv[]) { 35 | char *config = NULL; 36 | int result = -1; 37 | ERL_NIF_TERM ret = 0; 38 | if (argc != 1) { 39 | ret = exg_error(env, "Wrong number of arguments"); 40 | goto END; 41 | } 42 | if (!exg_get_string(env, argv[0], &config)) { 43 | ret = exg_error(env, "Config must be a string"); 44 | goto END; 45 | } 46 | result = XGBSetGlobalConfig((char const *)config); 47 | if (result == 0) { 48 | ret = 
ok_atom(env); 49 | } else { 50 | ret = exg_error(env, XGBGetLastError()); 51 | } 52 | END: 53 | if (config != NULL) { 54 | enif_free(config); 55 | config = NULL; 56 | } 57 | return ret; 58 | } 59 | 60 | ERL_NIF_TERM EXGBGetGlobalConfig(ErlNifEnv *env, int argc, 61 | const ERL_NIF_TERM argv[]) { 62 | char *out = NULL; 63 | int result = -1; 64 | ERL_NIF_TERM ret = 0; 65 | if (argc != 0) { 66 | ret = exg_error(env, "Wrong number of arguments"); 67 | goto END; 68 | } 69 | // No need to free out, it's a pointer to a static string defined in the 70 | // xgboost config struct 71 | result = XGBGetGlobalConfig((char const **)&out); 72 | if (result == 0) { 73 | ret = exg_ok(env, enif_make_string(env, out, ERL_NIF_LATIN1)); 74 | } else { 75 | ret = exg_error(env, XGBGetLastError()); 76 | } 77 | END: 78 | return ret; 79 | } -------------------------------------------------------------------------------- /c/exgboost/src/exgboost.c: -------------------------------------------------------------------------------- 1 | #include "exgboost.h" 2 | 3 | static int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) { 4 | DMatrix_RESOURCE_TYPE = enif_open_resource_type( 5 | env, NULL, "DMatrix_RESOURCE_TYPE", DMatrix_RESOURCE_TYPE_cleanup, 6 | (ErlNifResourceFlags)(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER), NULL); 7 | Booster_RESOURCE_TYPE = enif_open_resource_type( 8 | env, NULL, "Booster_RESOURCE_TYPE", Booster_RESOURCE_TYPE_cleanup, 9 | (ErlNifResourceFlags)(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER), NULL); 10 | if (DMatrix_RESOURCE_TYPE == NULL || Booster_RESOURCE_TYPE == NULL) { 11 | return 1; 12 | } 13 | return 0; 14 | } 15 | 16 | static int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, 17 | ERL_NIF_TERM load_info) { 18 | DMatrix_RESOURCE_TYPE = enif_open_resource_type( 19 | env, NULL, "DMatrix_RESOURCE_TYPE", DMatrix_RESOURCE_TYPE_cleanup, 20 | ERL_NIF_RT_TAKEOVER, NULL); 21 | Booster_RESOURCE_TYPE = enif_open_resource_type( 22 | env, NULL, 
"Booster_RESOURCE_TYPE", Booster_RESOURCE_TYPE_cleanup, 23 | ERL_NIF_RT_TAKEOVER, NULL); 24 | if (DMatrix_RESOURCE_TYPE == NULL || Booster_RESOURCE_TYPE == NULL) { 25 | return 1; 26 | } 27 | return 0; 28 | } 29 | 30 | static ErlNifFunc nif_funcs[] = { 31 | {"get_int_size", 0, exg_get_int_size}, 32 | {"xgboost_version", 0, EXGBoostVersion}, 33 | {"xgboost_build_info", 0, EXGBuildInfo}, 34 | {"set_global_config", 1, EXGBSetGlobalConfig}, 35 | {"get_global_config", 0, EXGBGetGlobalConfig}, 36 | {"proxy_dmatrix_create", 0, EXGProxyDMatrixCreate}, 37 | {"dmatrix_create_from_file", 2, EXGDMatrixCreateFromFile, 38 | ERL_NIF_DIRTY_JOB_IO_BOUND}, 39 | {"dmatrix_create_from_uri", 1, EXGDMatrixCreateFromURI, 40 | ERL_NIF_DIRTY_JOB_IO_BOUND}, 41 | {"dmatrix_create_from_mat", 4, EXGDMatrixCreateFromMat}, 42 | {"dmatrix_create_from_sparse", 6, EXGDMatrixCreateFromSparse}, 43 | {"dmatrix_create_from_dense", 2, EXGDMatrixCreateFromDense}, 44 | {"dmatrix_set_str_feature_info", 3, EXGDMatrixSetStrFeatureInfo}, 45 | {"dmatrix_get_str_feature_info", 2, EXGDMatrixGetStrFeatureInfo}, 46 | {"dmatrix_num_row", 1, EXGDMatrixNumRow}, 47 | {"dmatrix_num_col", 1, EXGDMatrixNumCol}, 48 | {"dmatrix_num_non_missing", 1, EXGDMatrixNumNonMissing}, 49 | {"dmatrix_set_info_from_interface", 3, EXGDMatrixSetInfoFromInterface}, 50 | {"dmatrix_save_binary", 3, EXGDMatrixSaveBinary}, 51 | {"get_binary_address", 1, exg_get_binary_address}, 52 | {"get_binary_from_address", 2, exg_get_binary_from_address}, 53 | {"dmatrix_get_float_info", 2, EXGDMatrixGetFloatInfo}, 54 | {"dmatrix_get_uint_info", 2, EXGDMatrixGetUIntInfo}, 55 | {"dmatrix_get_data_as_csr", 2, EXGDMatrixGetDataAsCSR}, 56 | {"dmatrix_slice", 3, EXGDMatrixSliceDMatrix}, 57 | {"dmatrix_get_quantile_cut", 2, EXGDMatrixGetQuantileCut}, 58 | {"booster_create", 1, EXGBoosterCreate}, 59 | {"booster_boosted_rounds", 1, EXGBoosterBoostedRounds}, 60 | {"booster_set_param", 3, EXGBoosterSetParam}, 61 | {"booster_get_num_feature", 1, 
EXGBoosterGetNumFeature}, 62 | {"booster_update_one_iter", 3, EXGBoosterUpdateOneIter, 63 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 64 | {"booster_boost_one_iter", 4, EXGBoosterBoostOneIter, 65 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 66 | {"booster_eval_one_iter", 4, EXGBoosterEvalOneIter, 67 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 68 | {"booster_get_attr_names", 1, EXGBoosterGetAttrNames}, 69 | {"booster_get_attr", 2, EXGBoosterGetAttr}, 70 | {"booster_set_attr", 3, EXGBoosterSetAttr}, 71 | {"booster_set_str_feature_info", 3, EXGBoosterSetStrFeatureInfo}, 72 | {"booster_get_str_feature_info", 2, EXGBoosterGetStrFeatureInfo}, 73 | {"booster_feature_score", 2, EXGBoosterFeatureScore}, 74 | {"booster_slice", 4, EXGBoosterSlice}, 75 | {"booster_predict_from_dmatrix", 3, EXGBoosterPredictFromDMatrix, 76 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 77 | {"booster_predict_from_dense", 4, EXGBoosterPredictFromDense, 78 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 79 | {"booster_predict_from_csr", 7, EXGBoosterPredictFromCSR, 80 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 81 | {"booster_load_model", 1, EXGBoosterLoadModel, ERL_NIF_DIRTY_JOB_IO_BOUND}, 82 | {"booster_save_model", 2, EXGBoosterSaveModel, ERL_NIF_DIRTY_JOB_IO_BOUND}, 83 | // These all return binaries so they're CPU bound rather than IO bound 84 | {"booster_serialize_to_buffer", 1, EXGBoosterSerializeToBuffer, 85 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 86 | {"booster_deserialize_from_buffer", 1, EXGBoosterDeserializeFromBuffer, 87 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 88 | {"booster_save_model_to_buffer", 2, EXGBoosterSaveModelToBuffer, 89 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 90 | {"booster_load_model_from_buffer", 1, EXGBoosterLoadModelFromBuffer, 91 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 92 | {"booster_load_json_config", 2, EXGBoosterLoadJsonConfig, 93 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 94 | {"booster_dump_model", 4, EXGBoosterDumpModelEx, 95 | ERL_NIF_DIRTY_JOB_CPU_BOUND}, 96 | {"booster_save_json_config", 1, EXGBoosterSaveJsonConfig, 97 | ERL_NIF_DIRTY_JOB_CPU_BOUND}}; 98 | 
ERL_NIF_INIT(Elixir.EXGBoost.NIF, nif_funcs, load, NULL, upgrade, NULL) 99 | -------------------------------------------------------------------------------- /c/exgboost/src/utils.c: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | // Atoms 4 | ERL_NIF_TERM exg_error(ErlNifEnv *env, const char *msg) { 5 | ERL_NIF_TERM atom = enif_make_atom(env, "error"); 6 | ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); 7 | return enif_make_tuple2(env, atom, msg_term); 8 | } 9 | 10 | ERL_NIF_TERM ok_atom(ErlNifEnv *env) { return enif_make_atom(env, "ok"); } 11 | 12 | ERL_NIF_TERM exg_ok(ErlNifEnv *env, ERL_NIF_TERM term) { 13 | return enif_make_tuple2(env, ok_atom(env), term); 14 | } 15 | 16 | // Resource type helpers 17 | void DMatrix_RESOURCE_TYPE_cleanup(ErlNifEnv *env, void *arg) { 18 | DMatrixHandle handle = *((DMatrixHandle *)arg); 19 | XGDMatrixFree(handle); 20 | } 21 | 22 | void Booster_RESOURCE_TYPE_cleanup(ErlNifEnv *env, void *arg) { 23 | BoosterHandle handle = *((BoosterHandle *)arg); 24 | XGBoosterFree(handle); 25 | } 26 | 27 | // Argument helpers 28 | int exg_get_string(ErlNifEnv *env, ERL_NIF_TERM term, char **var) { 29 | unsigned len; 30 | int ret = enif_get_list_length(env, term, &len); 31 | 32 | if (!ret) { 33 | ErlNifBinary bin; 34 | ret = enif_inspect_binary(env, term, &bin); 35 | if (!ret) { 36 | return 0; 37 | } 38 | *var = (char *)enif_alloc(bin.size + 1); 39 | strncpy(*var, (const char *)bin.data, bin.size); 40 | (*var)[bin.size] = '\0'; 41 | return ret; 42 | } 43 | 44 | *var = (char *)enif_alloc(len + 1); 45 | ret = enif_get_string(env, term, *var, len + 1, ERL_NIF_LATIN1); 46 | 47 | if (ret > 0) { 48 | (*var)[ret - 1] = '\0'; 49 | } else if (ret == 0) { 50 | (*var)[0] = '\0'; 51 | } 52 | 53 | return ret; 54 | } 55 | 56 | int exg_get_list(ErlNifEnv *env, ERL_NIF_TERM term, double **out) { 57 | ERL_NIF_TERM head, tail; 58 | unsigned len = 0; 59 | int i = 0; 60 | if 
(!enif_get_list_length(env, term, &len)) { 61 | return 0; 62 | } 63 | *out = (double *)enif_alloc(len * sizeof(double)); 64 | if (*out == NULL) { /* check the allocation, not the out-parameter */ 65 | return 0; 66 | } 67 | while (enif_get_list_cell(env, term, &head, &tail)) { 68 | int ret = enif_get_double(env, head, &((*out)[i])); 69 | if (!ret) { 70 | return 0; 71 | } 72 | term = tail; 73 | i++; 74 | } 75 | return 1; 76 | } 77 | 78 | int exg_get_string_list(ErlNifEnv *env, ERL_NIF_TERM term, char ***out, 79 | unsigned *len) { 80 | ERL_NIF_TERM head, tail; 81 | int i = 0; 82 | if (!enif_get_list_length(env, term, len)) { 83 | return 0; 84 | } 85 | *out = (char **)enif_alloc(*len * sizeof(char *)); 86 | if (*out == NULL) { 87 | return 0; 88 | } 89 | while (enif_get_list_cell(env, term, &head, &tail)) { 90 | int ret = exg_get_string(env, head, &((*out)[i])); 91 | if (!ret) { 92 | return 0; 93 | } 94 | term = tail; 95 | i++; 96 | } 97 | return 1; 98 | } 99 | 100 | int exg_get_dmatrix_list(ErlNifEnv *env, ERL_NIF_TERM term, 101 | DMatrixHandle **dmats, unsigned *len) { 102 | ERL_NIF_TERM head, tail; 103 | int i = 0; 104 | if (!enif_get_list_length(env, term, len)) { 105 | return 0; 106 | } 107 | *dmats = (DMatrixHandle *)enif_alloc(*len * sizeof(DMatrixHandle)); 108 | if (NULL == *dmats) { /* check the allocation, not the out-parameter */ 109 | return 0; 110 | } 111 | while (enif_get_list_cell(env, term, &head, &tail)) { 112 | DMatrixHandle **resource = NULL; 113 | if (!enif_get_resource(env, head, DMatrix_RESOURCE_TYPE, 114 | (void *)&(resource))) { 115 | return 0; 116 | } 117 | memcpy(&((*dmats)[i]), resource, sizeof(DMatrixHandle)); 118 | term = tail; 119 | i++; 120 | } 121 | return 1; 122 | } 123 | 124 | ERL_NIF_TERM exg_get_binary_address(ErlNifEnv *env, int argc, 125 | const ERL_NIF_TERM argv[]) { 126 | ErlNifBinary bin; 127 | ERL_NIF_TERM ret = 0; 128 | if (argc != 1) { 129 | ret = exg_error(env, "exg_get_binary_address: wrong number of arguments"); 130 | goto END; 131 | } 132 | if (!enif_inspect_binary(env, argv[0], &bin)) { 133 | ret = exg_error(env, 
"exg_get_binary_address: invalid binary"); 134 | goto END; 135 | } 136 | ret = exg_ok(env, enif_make_uint64(env, (uint64_t)bin.data)); 137 | END: 138 | return ret; 139 | } 140 | 141 | ERL_NIF_TERM exg_get_binary_from_address(ErlNifEnv *env, int argc, 142 | const ERL_NIF_TERM argv[]) { 143 | ErlNifBinary out_bin; 144 | ErlNifUInt64 address = 0; 145 | ErlNifUInt64 size = 0; 146 | ERL_NIF_TERM ret = 0; 147 | if (argc != 2) { 148 | ret = exg_error(env, "exg_get_binary_from_address: wrong number of arguments"); 149 | goto END; 150 | } 151 | if (!enif_get_uint64(env, argv[0], &address)) { 152 | ret = exg_error(env, "exg_get_binary_from_address: invalid address"); 153 | goto END; 154 | } 155 | if (!enif_get_uint64(env, argv[1], &size)) { 156 | ret = exg_error(env, "exg_get_binary_from_address: invalid size"); 157 | goto END; 158 | } 159 | if (!enif_alloc_binary(size, &out_bin)) { 160 | ret = exg_error(env, "Failed to allocate binary"); 161 | goto END; 162 | } 163 | /* address arrives as an integer term; cast back to a pointer before copying */ memcpy(out_bin.data, (const void *)(uintptr_t)address, size); 164 | ret = exg_ok(env, enif_make_binary(env, &out_bin)); 165 | END: 166 | return ret; 167 | } 168 | 169 | ERL_NIF_TERM exg_get_int_size(ErlNifEnv *env, int argc, 170 | const ERL_NIF_TERM argv[]) { 171 | ERL_NIF_TERM ret = 0; 172 | if (argc != 0) { 173 | ret = exg_error(env, "exg_get_int_size doesn't take any arguments"); 174 | goto END; 175 | } 176 | int size = sizeof(int); 177 | ret = exg_ok(env, enif_make_int(env, size)); 178 | END: 179 | return ret; 180 | } -------------------------------------------------------------------------------- /lib/exgboost.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost do 2 | @moduledoc """ 3 | #{File.cwd!() |> Path.join("README.md") |> File.read!() |> then(&Regex.run(~r/.*(?P.*).*/s, &1, capture: :all_but_first)) |> hd()} 4 | """ 5 | 6 | alias EXGBoost.ArrayInterface 7 | alias EXGBoost.Booster 8 | alias EXGBoost.Internal 9 | alias EXGBoost.DMatrix 10 | alias EXGBoost.ProxyDMatrix 11 | alias 
EXGBoost.Training 12 | alias EXGBoost.Plotting 13 | 14 | @doc """ 15 | Check the build information of the xgboost library. 16 | 17 | Returns a map containing information about the build. 18 | """ 19 | @spec xgboost_build_info() :: map() 20 | @doc type: :system 21 | def xgboost_build_info, 22 | do: EXGBoost.NIF.xgboost_build_info() |> Internal.unwrap!() |> Jason.decode!() 23 | 24 | @doc """ 25 | Check the version of the xgboost library. 26 | 27 | Returns a 3-tuple in the form of `{major, minor, patch}`. 28 | """ 29 | @spec xgboost_version() :: {integer(), integer(), integer()} | {:error, String.t()} 30 | @doc type: :system 31 | def xgboost_version, do: EXGBoost.NIF.xgboost_version() |> Internal.unwrap!() 32 | 33 | @doc """ 34 | Set global configuration. 35 | 36 | Global configuration consists of a collection of parameters that can be 37 | applied in the global scope. See `Global Parameters` in `EXGBoost.Parameters` 38 | for the full list of parameters supported in the global configuration. 39 | """ 40 | @spec set_config(map()) :: :ok | {:error, String.t()} 41 | @doc type: :system 42 | def set_config(%{} = config) do 43 | config = EXGBoost.Parameters.validate_global!(config) 44 | EXGBoost.NIF.set_global_config(Jason.encode!(config)) |> Internal.unwrap!() 45 | end 46 | 47 | @doc """ 48 | Get current values of the global configuration. 49 | 50 | Global configuration consists of a collection of parameters that can be 51 | applied in the global scope. See `Global Parameters` in `EXGBoost.Parameters` 52 | for the full list of parameters supported in the global configuration. 53 | """ 54 | @spec get_config() :: map() 55 | @doc type: :system 56 | def get_config do 57 | EXGBoost.NIF.get_global_config() |> Internal.unwrap!() |> Jason.decode!() 58 | end 59 | 60 | @doc """ 61 | Train a new booster model given a data tensor and a label tensor. 62 | 63 | ## Options 64 | 65 | * `:obj` - Specify the learning task and the corresponding learning objective. 
66 | This function must accept two arguments: preds, dtrain. preds is an array of 67 | predicted real valued scores. dtrain is the training data set. This function 68 | returns gradient and second order gradient. 69 | 70 | * `:num_boost_rounds` - Number of boosting iterations. 71 | 72 | * `:evals` - A list of 3-Tuples `{x, y, label}` to use as a validation set for 73 | early-stopping. 74 | 75 | * `:early_stopping_rounds` - Activates early stopping. Target metric needs to 76 | increase/decrease (depending on metric) at least every `early_stopping_rounds` 77 | round(s) to continue training. Requires at least one item in `:evals`. If there's 78 | more than one, will use the last eval set. If there’s more than one metric in the 79 | `eval_metric` parameter given in the booster's params, the last metric will be 80 | used for early stopping. If early stopping occurs, the model will have two additional fields: 81 | 82 | 83 | - `bst.best_score` 84 | - `bst.best_iteration`. 85 | 86 | If these values are `nil` then no early stopping occurred. 87 | 88 | * `:verbose_eval` - Requires at least one item in `evals`. If `verbose_eval` is true then the evaluation metric on the validation set is printed at each boosting stage. If verbose_eval is an 89 | integer then the evaluation metric on the validation set is printed at every given `verbose_eval` boosting stage. The last boosting stage / the boosting stage found by using `early_stopping_rounds` 90 | is also printed. Example: with `verbose_eval=4` and at least one item in evals, an evaluation metric is printed every 4 boosting stages, instead of every boosting stage. 91 | 92 | * `:learning_rates` - Either an arity 1 function that accept an integer parameter epoch and returns the corresponding learning rate or a list with the same length as num_boost_rounds. 93 | 94 | * `:callbacks` - List of `EXGBoost.Training.Callback` that are called during a given event. 
It is possible to use predefined callbacks by using `EXGBoost.Training.Callback` module. 95 | Callbacks should be in the form of a keyword list where the only valid keys are `:before_training`, `:after_training`, `:before_iteration`, and `:after_iteration`. 96 | The value of each key should be a list of functions that accepts a booster and an iteration and returns a booster. The function will be called at the appropriate time with the booster and the iteration 97 | as the arguments. The function should return the booster. If the function returns a booster with a different memory address, the original booster will be replaced with the new booster. 98 | If the function returns the original booster, the original booster will be used. If the function returns a booster with the same memory address but different contents, the behavior is undefined. 99 | 100 | 101 | * `opts` - Refer to `EXGBoost.Parameters` for the full list of options. 102 | """ 103 | @spec train(Nx.Tensor.t(), Nx.Tensor.t(), Keyword.t()) :: EXGBoost.Booster.t() 104 | @doc type: :train_pred 105 | def train(x, y, opts \\ []) do 106 | x = Nx.concatenate(x) 107 | y = Nx.concatenate(y) 108 | dmat_opts = Keyword.take(opts, Internal.dmatrix_feature_opts()) 109 | dmat = DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)) 110 | Training.train(dmat, opts) 111 | end 112 | 113 | @doc """ 114 | Predict with a booster model against a tensor. 115 | 116 | The full model will be used unless `iteration_range` is specified, 117 | meaning user have to either slice the model or use the `best_iteration` 118 | attribute to get prediction from best model returned from early stopping. 119 | 120 | ## Options 121 | 122 | * `:output_margin` - Whether to output the raw untransformed margin value. 123 | 124 | * `:pred_leaf ` - When this option is on, the output will be an `Nx.Tensor` of 125 | shape {nsamples, ntrees}, where each row indicates the predicted leaf 126 | index of each sample in each tree. 
Note that the leaf index of a tree is 127 | unique per tree, but not globally, so you may find leaf 1 in both tree 1 and tree 0. 128 | 129 | * `:pred_contribs` - When this is `true` the output will be a matrix of size `{nsample, 130 | nfeats + 1}` with each record indicating the feature contributions 131 | (SHAP values) for that prediction. The sum of all feature 132 | contributions is equal to the raw untransformed margin value of the 133 | prediction. Note the final column is the bias term. 134 | 135 | * `:approx_contribs` - Approximate the contributions of each feature. Used when `pred_contribs` or 136 | `pred_interactions` is set to `true`. Changing the default of this parameter 137 | (false) is not recommended. 138 | 139 | * `:pred_interactions` - When this is `true` the output will be an `Nx.Tensor` of shape 140 | {nsamples, nfeats + 1} indicating the SHAP interaction values for 141 | each pair of features. The sum of each row (or column) of the 142 | interaction values equals the corresponding SHAP value (from 143 | pred_contribs), and the sum of the entire matrix equals the raw 144 | untransformed margin value of the prediction. Note the last row and 145 | column correspond to the bias term. 146 | 147 | * `:validate_features` - When this is `true`, validate that the Booster's and data's 148 | feature_names are identical. Otherwise, it is assumed that the 149 | feature_names are the same. 150 | 151 | * `:training` - Determines whether the prediction value is used for training. This 152 | can affect the `dart` booster, which performs dropouts during training iterations 153 | but uses all trees for inference. If you want to obtain result with dropouts, set 154 | this option to `true`. Also, the option is set to `true` when obtaining prediction for 155 | custom objective function. 156 | 157 | * `:iteration_range` - Specifies which layer of trees are used in prediction. For example, if a 158 | random forest is trained with 100 rounds. 
Specifying `iteration_range=(10, 159 | 20)`, then only the forests built during [10, 20) (half open set) rounds are 160 | used in this prediction. 161 | 162 | * `:strict_shape` - When set to `true`, output shape is invariant to whether classification is used. 163 | For both value and margin prediction, the output shape is (n_samples, 164 | n_groups), n_groups == 1 when multi-class is not used. Defaults to `false`, in 165 | which case the output shape can be (n_samples, ) if multi-class is not used. 166 | 167 | Returns an Nx.Tensor containing the predictions. 168 | """ 169 | @doc type: :train_pred 170 | def predict(%Booster{} = bst, x, opts \\ []) do 171 | x = Nx.concatenate(x) 172 | {dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts()) 173 | dmat = DMatrix.from_tensor(x, Keyword.put_new(dmat_opts, :format, :dense)) 174 | Booster.predict(bst, dmat, opts) 175 | end 176 | 177 | @doc """ 178 | Run prediction in-place, Unlike `EXGBoost.predict/2`, in-place prediction does not cache the prediction result. 179 | 180 | ## Options 181 | 182 | * `:base_margin` - Base margin used for boosting from existing model. 183 | 184 | * `:missing` - Value used for missing values. If None, defaults to `Nx.Constants.nan()`. 185 | 186 | * `:predict_type` - One of: 187 | 188 | * `"value"` - Output model prediction values. 189 | 190 | * `"margin"` - Output the raw untransformed margin value. 191 | 192 | * `:output_margin` - Whether to output the raw untransformed margin value. 193 | 194 | * `:iteration_range` - See `EXGBoost.predict/2` for details. 195 | 196 | * `:strict_shape` - See `EXGBoost.predict/2` for details. 197 | 198 | Returns an Nx.Tensor containing the predictions. 
199 | """ 200 | @doc type: :train_pred 201 | def inplace_predict(%Booster{} = boostr, data, opts \\ []) do 202 | opts = 203 | Keyword.validate!(opts, 204 | iteration_range: {0, 0}, 205 | predict_type: "value", 206 | missing: Nx.Constants.nan(), 207 | validate_features: true, 208 | base_margin: nil, 209 | strict_shape: false 210 | ) 211 | 212 | base_margin = Keyword.fetch!(opts, :base_margin) 213 | {iteration_range_left, iteration_range_right} = Keyword.fetch!(opts, :iteration_range) 214 | 215 | params = %{ 216 | type: if(Keyword.fetch!(opts, :predict_type) == "margin", do: 1, else: 0), 217 | training: false, 218 | iteration_begin: iteration_range_left, 219 | iteration_end: iteration_range_right, 220 | missing: Keyword.fetch!(opts, :missing), 221 | strict_shape: Keyword.fetch!(opts, :strict_shape), 222 | cache_id: 0 223 | } 224 | 225 | proxy = 226 | if not is_nil(base_margin) do 227 | prox = ProxyDMatrix.proxy_dmatrix() 228 | prox = DMatrix.set_params(prox, base_margin: base_margin) 229 | prox.ref 230 | else 231 | nil 232 | end 233 | 234 | case data do 235 | %Nx.Tensor{} = data -> 236 | data_interface = ArrayInterface.from_tensor(data) |> Jason.encode!() 237 | 238 | {shape, preds} = 239 | EXGBoost.NIF.booster_predict_from_dense( 240 | boostr.ref, 241 | data_interface, 242 | Jason.encode!(params), 243 | proxy 244 | ) 245 | |> Internal.unwrap!() 246 | 247 | Nx.tensor(preds) |> Nx.reshape(shape) 248 | 249 | {%Nx.Tensor{} = indptr, %Nx.Tensor{} = indices, %Nx.Tensor{} = values, ncol} -> 250 | indptr_interface = ArrayInterface.from_tensor(indptr) |> Jason.encode!() 251 | indices_interface = ArrayInterface.from_tensor(indices) |> Jason.encode!() 252 | values_interface = ArrayInterface.from_tensor(values) |> Jason.encode!() 253 | 254 | {shape, preds} = 255 | EXGBoost.NIF.booster_predict_from_csr( 256 | boostr.ref, 257 | indptr_interface, 258 | indices_interface, 259 | values_interface, 260 | ncol, 261 | Jason.encode!(params), 262 | proxy 263 | ) 264 | |> Internal.unwrap!() 
265 | 266 | Nx.tensor(preds) |> Nx.reshape(shape) 267 | 268 | data -> 269 | data = Nx.concatenate(data) 270 | data_interface = ArrayInterface.from_tensor(data) |> Jason.encode!() 271 | 272 | {shape, preds} = 273 | EXGBoost.NIF.booster_predict_from_dense( 274 | boostr.ref, 275 | data_interface, 276 | Jason.encode!(params), 277 | proxy 278 | ) 279 | |> Internal.unwrap!() 280 | 281 | Nx.tensor(preds) |> Nx.reshape(shape) 282 | end 283 | end 284 | 285 | @format_opts [ 286 | format: [ 287 | type: {:in, [:json, :ubj]}, 288 | default: :json, 289 | doc: """ 290 | The format to serialize to. Can be either `:json` or `:ubj`. 291 | """ 292 | ] 293 | ] 294 | 295 | @overwrite_opts [ 296 | overwrite: [ 297 | type: :boolean, 298 | default: false, 299 | doc: """ 300 | Whether or not to overwrite the file if it already exists. 301 | """ 302 | ] 303 | ] 304 | 305 | @load_opts [ 306 | booster: [ 307 | type: {:struct, Booster}, 308 | doc: """ 309 | The Booster to load the model into. If a Booster is provided, the model will be loaded into 310 | that Booster. Otherwise, a new Booster will be created. If a Booster is provided, model parameters 311 | will be merged with the existing Booster's parameters using Map.merge/2, where the parameters 312 | of the provided Booster take precedence. 313 | """ 314 | ] 315 | ] 316 | 317 | @write_schema NimbleOptions.new!(@format_opts ++ @overwrite_opts) 318 | @dump_schema NimbleOptions.new!(@format_opts) 319 | @load_schema NimbleOptions.new!(@load_opts) 320 | 321 | @doc """ 322 | Write a model to a file. 
323 | 324 | ## Options 325 | #{NimbleOptions.docs(@write_schema)} 326 | """ 327 | @doc type: :serialization 328 | @spec write_model(Booster.t(), String.t(), keyword()) :: :ok | {:error, String.t()} 329 | def write_model(%Booster{} = booster, path, opts \\ []) do 330 | opts = NimbleOptions.validate!(opts, @write_schema) 331 | EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :model]) 332 | end 333 | 334 | @doc """ 335 | Read a model from a file and return the Booster. 336 | """ 337 | @doc type: :serialization 338 | @spec read_model(String.t()) :: EXGBoost.Booster.t() 339 | def read_model(path) do 340 | EXGBoost.Booster.load(path, deserialize: :model) 341 | end 342 | 343 | @doc """ 344 | Dump a model to a binary encoded in the desired format. 345 | 346 | ## Options 347 | #{NimbleOptions.docs(@dump_schema)} 348 | """ 349 | @spec dump_model(Booster.t(), keyword()) :: binary() 350 | @doc type: :serialization 351 | def dump_model(%Booster{} = booster, opts \\ []) do 352 | opts = NimbleOptions.validate!(opts, @dump_schema) 353 | EXGBoost.Booster.save(booster, opts ++ [serialize: :model, to: :buffer]) 354 | end 355 | 356 | @doc """ 357 | Read a model from a buffer and return the Booster. 358 | """ 359 | @spec load_model(binary()) :: EXGBoost.Booster.t() 360 | @doc type: :serialization 361 | def load_model(buffer) do 362 | EXGBoost.Booster.load(buffer, deserialize: :model, from: :buffer) 363 | end 364 | 365 | @doc """ 366 | Write a model config to a file as a JSON-encoded string. 367 | 368 | ## Options 369 | #{NimbleOptions.docs(@write_schema)} 370 | """ 371 | @spec write_config(Booster.t(), String.t(), keyword()) :: :ok | {:error, String.t()} 372 | @doc type: :serialization 373 | def write_config(%Booster{} = booster, path, opts \\ []) do 374 | opts = NimbleOptions.validate!(opts, @write_schema) 375 | EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :config]) 376 | end 377 | 378 | @doc """ 379 | Dump a model config to a buffer as a JSON-encoded string.
380 | 381 | ## Options 382 | #{NimbleOptions.docs(@dump_schema)} 383 | """ 384 | @spec dump_config(Booster.t(), keyword()) :: binary() 385 | @doc type: :serialization 386 | def dump_config(%Booster{} = booster, opts \\ []) do 387 | opts = NimbleOptions.validate!(opts, @dump_schema) 388 | EXGBoost.Booster.save(booster, opts ++ [serialize: :config, to: :buffer]) 389 | end 390 | 391 | @doc """ 392 | Create a new Booster from a config file. The config file must be from the output of `write_config/2`. 393 | 394 | ## Options 395 | #{NimbleOptions.docs(@load_schema)} 396 | """ 397 | @spec read_config(String.t(), keyword()) :: EXGBoost.Booster.t() 398 | @doc type: :serialization 399 | def read_config(path, opts \\ []) do 400 | opts = NimbleOptions.validate!(opts, @load_schema) 401 | EXGBoost.Booster.load(path, opts ++ [deserialize: :config]) 402 | end 403 | 404 | @doc """ 405 | Create a new Booster from a config buffer. The config buffer must be from the output of `dump_config/2`. 406 | 407 | ## Options 408 | #{NimbleOptions.docs(@load_schema)} 409 | """ 410 | @spec load_config(binary(), keyword()) :: EXGBoost.Booster.t() 411 | @doc type: :serialization 412 | def load_config(buffer, opts \\ []) do 413 | opts = NimbleOptions.validate!(opts, @load_schema) 414 | EXGBoost.Booster.load(buffer, opts ++ [deserialize: :config, from: :buffer]) 415 | end 416 | 417 | @doc """ 418 | Write a model's trained parameters to a file. 419 | 420 | ## Options 421 | #{NimbleOptions.docs(@write_schema)} 422 | """ 423 | @spec write_weights(Booster.t(), String.t(), keyword()) :: :ok | {:error, String.t()} 424 | @doc type: :serialization 425 | def write_weights(%Booster{} = booster, path, opts \\ []) do 426 | opts = NimbleOptions.validate!(opts, @write_schema) 427 | EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :weights]) 428 | end 429 | 430 | @doc """ 431 | Dump a model's trained parameters to a buffer as a JSON-encoded binary.
432 | 433 | ## Options 434 | #{NimbleOptions.docs(@dump_schema)} 435 | """ 436 | @spec dump_weights(Booster.t(), keyword()) :: binary() 437 | @doc type: :serialization 438 | def dump_weights(%Booster{} = booster, opts \\ []) do 439 | opts = NimbleOptions.validate!(opts, @dump_schema) 440 | EXGBoost.Booster.save(booster, opts ++ [serialize: :weights, to: :buffer]) 441 | end 442 | 443 | @doc """ 444 | Read a model's trained parameters from a file and return the Booster. 445 | """ 446 | @spec read_weights(String.t()) :: EXGBoost.Booster.t() 447 | @doc type: :serialization 448 | def read_weights(path) do 449 | EXGBoost.Booster.load(path, deserialize: :weights) 450 | end 451 | 452 | @doc """ 453 | Read a model's trained parameters from a buffer and return the Booster. 454 | """ 455 | @spec load_weights(binary()) :: EXGBoost.Booster.t() 456 | @doc type: :serialization 457 | def load_weights(buffer) do 458 | EXGBoost.Booster.load(buffer, deserialize: :weights, from: :buffer) 459 | end 460 | 461 | @doc """ 462 | Plot a tree from a Booster model and save it to a file. 463 | 464 | ## Options 465 | * `:format` - the format to export the graphic as, must be either of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension. 466 | * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass local_npm_prefix: "assets". By default the npm packages are searched for in the current directory and globally. 467 | * `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec. 468 | * `:opts` - additional options to pass to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information.
469 | """ 470 | @doc type: :plotting 471 | def plot_tree(booster, opts \\ []) do 472 | {path, opts} = Keyword.pop(opts, :path) 473 | {save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix]) 474 | vega = Plotting.plot(booster, opts) 475 | 476 | if path != nil do 477 | VegaLite.Export.save!(vega, path, save_opts) 478 | else 479 | vega 480 | end 481 | end 482 | end 483 | -------------------------------------------------------------------------------- /lib/exgboost/application.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Application do 2 | @moduledoc false 3 | 4 | def start(_type, _args) do 5 | global_config = Application.get_all_env(:exgboost) |> Enum.into(%{}) 6 | :ok = EXGBoost.set_config(global_config) 7 | Supervisor.start_link([], strategy: :one_for_one) 8 | end 9 | end 10 | -------------------------------------------------------------------------------- /lib/exgboost/array_interface.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.ArrayInterface do 2 | @moduledoc false 3 | alias EXGBoost.Internal 4 | 5 | @typedoc """ 6 | The XGBoost C API uses and is moving towards mainly supporting the use of 7 | JSON-Encoded NumPy ArrayInterface format to pass data to and from the C API. This struct 8 | is used to represent the ArrayInterface format. 9 | 10 | If you wish to use the EXGBoost.NIF library directly, this will be the desired format 11 | to pass Nx.Tensors to the NIFs. Use of the EXGBoost.NIF library directly is not recommended 12 | unless you are familiar with the XGBoost C API and the EXGBoost.NIF library. 13 | 14 | See https://numpy.org/doc/stable/reference/arrays.interface.html for more information on 15 | the ArrayInterface protocol. 
16 | """ 17 | @type t :: %__MODULE__{ 18 | typestr: String.t(), 19 | shape: tuple(), 20 | address: pos_integer(), 21 | readonly: boolean(), 22 | tensor: Nx.Tensor.t() 23 | } 24 | 25 | @enforce_keys [:typestr, :shape, :address, :readonly] 26 | defstruct [ 27 | :typestr, 28 | :shape, 29 | :address, 30 | :readonly, 31 | :tensor, 32 | version: 3 33 | ] 34 | 35 | defimpl Jason.Encoder do 36 | def encode( 37 | %{ 38 | typestr: typestr, 39 | shape: shape, 40 | address: address, 41 | readonly: readonly, 42 | version: version 43 | }, 44 | opts 45 | ) do 46 | Jason.Encode.map( 47 | %{ 48 | typestr: typestr, 49 | shape: Tuple.to_list(shape), 50 | data: [address, readonly], 51 | version: version 52 | }, 53 | opts 54 | ) 55 | end 56 | end 57 | 58 | defimpl Inspect do 59 | import Inspect.Algebra 60 | 61 | def inspect( 62 | %{ 63 | typestr: typestr, 64 | shape: shape, 65 | address: address, 66 | readonly: readonly, 67 | version: version 68 | }, 69 | opts 70 | ) do 71 | concat([ 72 | "#ArrayInterface<", 73 | line(), 74 | to_doc( 75 | %{ 76 | typestr: typestr, 77 | shape: Tuple.to_list(shape), 78 | data: [address, readonly], 79 | version: version 80 | }, 81 | opts 82 | ), 83 | line(), 84 | ">" 85 | ]) 86 | end 87 | end 88 | 89 | def from_map(%{} = interface) do 90 | interface 91 | |> Enum.reduce([], fn 92 | {"data", [address, readonly]}, acc -> 93 | [{:address, address} | [{:readonly, readonly} | acc]] 94 | 95 | {"shape", shape}, acc -> 96 | [{:shape, List.to_tuple(shape)} | acc] 97 | 98 | {key, value}, acc -> 99 | [{String.to_existing_atom(key), value} | acc] 100 | end) 101 | |> then(&struct(__MODULE__, &1)) 102 | end 103 | 104 | @doc """ 105 | This function is used to convert Nx.Tensors to the ArrayInterface format. 
106 | 107 | Example: 108 | iex> EXGBoost.from_tensor(Nx.tensor([[1,2,3],[4,5,6]])) 109 | #ArrayInterface< 110 | %{data: [4418559984, true], shape: [2, 3], typestr: " 117 | " 121 | raise ArgumentError, 122 | "Invalid tensor type -- #{inspect(t_type)} not supported by EXGBoost" 123 | 124 | {tensor_type, type_width} -> 125 | "<#{Atom.to_string(tensor_type)}#{div(type_width, 8)}" 126 | end 127 | 128 | tensor_addr = 129 | EXGBoost.NIF.get_binary_address(Nx.to_binary(tensor)) |> EXGBoost.Internal.unwrap!() 130 | 131 | %__MODULE__{ 132 | typestr: type_char, 133 | shape: Nx.shape(tensor), 134 | address: tensor_addr, 135 | readonly: true, 136 | tensor: tensor 137 | } 138 | end 139 | 140 | @spec get_tensor(EXGBoost.ArrayInterface.t()) :: Nx.Tensor.t() 141 | def get_tensor(%__MODULE__{tensor: nil} = arr_int) do 142 | num_items = arr_int.shape |> Tuple.to_list() |> Enum.product() 143 | <<_endianess::utf8, char_code::binary-size(1), bytes::binary>> = arr_int.typestr 144 | 145 | nx_type = 146 | case char_code do 147 | "i" -> {:s, String.to_integer(bytes) * 8} 148 | other -> {String.to_existing_atom(other), String.to_integer(bytes) * 8} 149 | end 150 | 151 | tensor_bin = 152 | EXGBoost.NIF.get_binary_from_address(arr_int.address, String.to_integer(bytes) * num_items) 153 | |> Internal.unwrap!() 154 | 155 | Nx.from_binary( 156 | tensor_bin, 157 | nx_type 158 | ) 159 | |> Nx.reshape(arr_int.shape) 160 | end 161 | 162 | def get_tensor(%__MODULE__{tensor: %Nx.Tensor{} = tensor}) do 163 | tensor 164 | end 165 | end 166 | -------------------------------------------------------------------------------- /lib/exgboost/dmatrix.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.DMatrix do 2 | @moduledoc false 3 | 4 | # Internal docs for development only 5 | _docstring = """ 6 | Parameters 7 | ---------- 8 | data : 9 | Data source of DMatrix. 10 | label : 11 | Label of the training data. 12 | weight : 13 | Weight for each instance. 
14 | 15 | .. note:: 16 | 17 | For ranking task, weights are per-group. In ranking task, one weight 18 | is assigned to each group (not each data point). This is because we 19 | only care about the relative ordering of data points within each group, 20 | so it doesn't make sense to assign weights to individual data points. 21 | 22 | base_margin : 23 | Base margin used for boosting from existing model. 24 | missing : 25 | Value in the input data which needs to be present as a missing value. If 26 | None, defaults to np.nan. 27 | silent : 28 | Whether print messages during construction 29 | feature_names : 30 | Set names for features. 31 | feature_types : 32 | 33 | Set types for features. When `enable_categorical` is set to `True`, string 34 | "c" represents categorical data type while "q" represents numerical feature 35 | type. For categorical features, the input is assumed to be preprocessed and 36 | encoded by the users. The encoding can be done via 37 | :py:class:`sklearn.preprocessing.OrdinalEncoder` or pandas dataframe 38 | `.cat.codes` method. This is useful when users want to specify categorical 39 | features without having to construct a dataframe as input. 40 | 41 | nthread : 42 | Number of threads to use for loading data when parallelization is 43 | applicable. If -1, uses maximum threads available on the system. 44 | group : 45 | Group size for all ranking group. 46 | qid : 47 | Query ID for data samples, used for ranking. 48 | label_lower_bound : 49 | Lower bound for survival training. 50 | label_upper_bound : 51 | Upper bound for survival training. 52 | feature_weights : 53 | Set feature weights for column sampling. 54 | enable_categorical : 55 | 56 | .. versionadded:: 1.3.0 57 | 58 | .. note:: This parameter is experimental 59 | 60 | Experimental support of specializing for categorical features. Do not set 61 | to True unless you are interested in development. Also, JSON/UBJSON 62 | serialization format is required. 
63 | 64 | """ 65 | 66 | alias EXGBoost.ArrayInterface 67 | alias EXGBoost.Internal 68 | 69 | @enforce_keys [ 70 | :ref, 71 | :format 72 | ] 73 | defstruct [ 74 | :ref, 75 | :format 76 | ] 77 | 78 | @type t :: %__MODULE__{ 79 | ref: reference(), 80 | format: atom() 81 | } 82 | 83 | def get_float_info(dmatrix, feature) 84 | when feature in [ 85 | "label", 86 | "weight", 87 | "base_margin", 88 | "label_lower_bound", 89 | "label_upper_bound", 90 | "feature_weights" 91 | ], 92 | do: EXGBoost.NIF.dmatrix_get_float_info(dmatrix.ref, feature) |> Internal.unwrap!() 93 | 94 | def get_group(dmatrix), do: get_uint_info(dmatrix, "group") 95 | 96 | def get_uint_info(dmatrix, "group"), 97 | do: EXGBoost.NIF.dmatrix_get_uint_info(dmatrix.ref, "group_ptr") |> Internal.unwrap!() 98 | 99 | def get_num_rows(dmatrix), 100 | do: EXGBoost.NIF.dmatrix_num_row(dmatrix.ref) |> Internal.unwrap!() 101 | 102 | def get_num_cols(dmatrix), 103 | do: EXGBoost.NIF.dmatrix_num_col(dmatrix.ref) |> Internal.unwrap!() 104 | 105 | def get_num_non_missing(dmatrix), 106 | do: EXGBoost.NIF.dmatrix_num_non_missing(dmatrix.ref) |> Internal.unwrap!() 107 | 108 | def get_data(dmatrix), 109 | do: 110 | EXGBoost.NIF.dmatrix_get_data_as_csr(dmatrix.ref, Jason.encode!(%{})) |> Internal.unwrap!() 111 | 112 | def get_feature_names(dmatrix), 113 | do: 114 | EXGBoost.NIF.dmatrix_get_str_feature_info(dmatrix.ref, "feature_name") |> Internal.unwrap!() 115 | 116 | def get_feature_types(dmatrix), 117 | do: 118 | EXGBoost.NIF.dmatrix_get_str_feature_info(dmatrix.ref, "feature_type") |> Internal.unwrap!() 119 | 120 | def set_params(dmat, opts) do 121 | options = Internal.dmatrix_str_feature_opts() ++ Internal.dmatrix_meta_feature_opts() 122 | opts = Keyword.validate!(opts, options) 123 | 124 | {meta_opts, opts} = Keyword.split(opts, Internal.dmatrix_meta_feature_opts()) 125 | {str_opts, _opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts()) 126 | 127 | args = Enum.into(Keyword.merge(meta_opts, str_opts), %{}) 128 
| 129 | Enum.each(meta_opts, fn {key, value} -> 130 | data_interface = ArrayInterface.from_tensor(value) |> Jason.encode!() 131 | 132 | EXGBoost.NIF.dmatrix_set_info_from_interface( 133 | dmat.ref, 134 | Atom.to_string(key), 135 | data_interface 136 | ) 137 | end) 138 | 139 | Enum.each(str_opts, fn {key, value} -> 140 | EXGBoost.NIF.dmatrix_set_str_feature_info(dmat.ref, Atom.to_string(key), value) 141 | end) 142 | 143 | struct(dmat, args) 144 | end 145 | 146 | @doc """ 147 | Slice the DMatrix and return a new DMatrix that only contains rindex. 148 | """ 149 | def slice(dmat, %Nx.Tensor{shape: {_rows}} = r_index, opts \\ []) 150 | when is_list(opts) do 151 | opts = Keyword.validate!(opts, allow_groups: false) 152 | allow_groups = Keyword.fetch!(opts, :allow_groups) 153 | EXGBoost.NIF.dmatrix_slice(dmat.ref, Nx.to_binary(r_index), allow_groups) 154 | end 155 | 156 | @doc """ 157 | Export the quantile cuts used for training histogram-based models like `hist` and `approx`. 158 | Useful for model compression. 159 | 160 | Returns a tuple of {indptr, data} representing a CSC matrix of the cuts. 161 | """ 162 | def get_quantile_cut(%__MODULE__{} = dmat) do 163 | # https://xgboost.readthedocs.io/en/stable/c.html#_CPPv423XGDMatrixGetQuantileCutK13DMatrixHandlePKcPPKcPPKc 164 | # config – JSON configuration string. At the moment it should be an empty document, preserved for future use. 
165 | config = %{} |> Jason.encode!() 166 | 167 | {indptr, data} = 168 | EXGBoost.NIF.dmatrix_get_quantile_cut(dmat.ref, config) 169 | |> Internal.unwrap!() 170 | 171 | indptr = 172 | Jason.decode!(indptr) 173 | |> ArrayInterface.from_map() 174 | |> ArrayInterface.get_tensor() 175 | 176 | data = 177 | Jason.decode!(data) 178 | |> ArrayInterface.from_map() 179 | |> ArrayInterface.get_tensor() 180 | 181 | {indptr, data} 182 | end 183 | 184 | defimpl Inspect do 185 | import Inspect.Algebra 186 | alias EXGBoost.DMatrix 187 | 188 | def inspect(dmatrix, _opts) do 189 | {indptr, indices, data} = DMatrix.get_data(dmatrix) 190 | 191 | concat([ 192 | "#DMatrix<", 193 | line(), 194 | " {#{DMatrix.get_num_rows(dmatrix)}x#{DMatrix.get_num_cols(dmatrix)}x#{DMatrix.get_num_non_missing(dmatrix)}}", 195 | line(), 196 | if(DMatrix.get_group(dmatrix) != nil, 197 | do: " group: #{inspect(DMatrix.get_group(dmatrix))}" 198 | ), 199 | line(), 200 | " indptr: #{inspect(Nx.tensor(indptr))}", 201 | line(), 202 | " indices: #{inspect(Nx.tensor(indices))}", 203 | line(), 204 | " data: #{inspect(Nx.tensor(data))}", 205 | line(), 206 | ">" 207 | ]) 208 | end 209 | end 210 | 211 | @doc """ 212 | Create a DMatrix from a file. 213 | 214 | Refer to https://xgboost.readthedocs.io/en/latest/tutorials/external_memory.html#text-file-inputs 215 | for proper formatting of the file and the options. 216 | 217 | This function will URI encode the filepath according to the URI scheme defined in 218 | XGBoost's documentation. 
219 | """ 220 | def from_file(filepath, opts) when is_binary(filepath) and is_list(opts) do 221 | opts = 222 | Keyword.validate!(opts, 223 | label_column: nil, 224 | cacheprefix: nil, 225 | format: :dense, 226 | ext: :auto, 227 | silent: 1, 228 | data_split_mode: :row 229 | ) 230 | 231 | if not (File.exists?(filepath) and File.regular?(filepath)) do 232 | raise ArgumentError, "File must exist and be a regular file" 233 | end 234 | 235 | {file_format, opts} = Keyword.pop!(opts, :ext) 236 | {silent, opts} = Keyword.pop!(opts, :silent) 237 | {format, opts} = Keyword.pop!(opts, :format) 238 | {label_column, opts} = Keyword.pop!(opts, :label_column) 239 | {cacheprefix, opts} = Keyword.pop!(opts, :cacheprefix) 240 | {data_split_mode, opts} = Keyword.pop!(opts, :data_split_mode) 241 | 242 | unless data_split_mode in [:row, :column] do 243 | raise ArgumentError, "data_split_mode must be :row or :column" 244 | end 245 | 246 | ext = 247 | case file_format do 248 | :libsvm -> "libsvm" 249 | :csv -> "csv" 250 | :auto -> "auto" 251 | _ -> raise ArgumentError, "Invalid file format" 252 | end 253 | 254 | uri = "#{filepath}?format=#{ext}" 255 | 256 | if file_format != :csv and not is_nil(label_column) do 257 | if silent == 1 do 258 | IO.warn("label_column should only be specified for CSV files -- ignoring...") 259 | else 260 | raise ArgumentError, "label_column should only be specified for CSV files" 261 | end 262 | end 263 | 264 | if not is_nil(cacheprefix) and not File.exists?(cacheprefix) do 265 | if silent == 1 do 266 | IO.warn("cacheprefix file not found -- ignoring...") 267 | else 268 | raise ArgumentError, "cacheprefix file not found" 269 | end 270 | end 271 | 272 | uri = 273 | if not is_nil(label_column) and file_format == :csv do 274 | uri <> "&label_column=#{label_column}" 275 | else 276 | uri 277 | end 278 | 279 | uri = 280 | if not is_nil(cacheprefix) and File.exists?(cacheprefix) and File.regular?(cacheprefix) do 281 | uri <> "##{cacheprefix}" 282 | else 283 | uri 284 
| end 285 | 286 | config = %{uri: uri, silent: silent, data_split_mode: data_split_mode} |> Jason.encode!() 287 | 288 | dmat = 289 | EXGBoost.NIF.dmatrix_create_from_uri(config) 290 | |> Internal.unwrap!() 291 | 292 | set_params(%__MODULE__{ref: dmat, format: format}, opts) 293 | end 294 | 295 | def from_tensor(_tensor, _opts \\ []) 296 | 297 | def from_tensor(%Nx.Tensor{} = tensor, opts) when is_list(opts) do 298 | opts = Keyword.validate!(opts, Internal.dmatrix_feature_opts()) 299 | 300 | {config_opts, opts} = Keyword.split(opts, Internal.dmatrix_config_feature_opts()) 301 | config_opts = Keyword.validate!(config_opts, missing: Nx.Constants.nan(), nthread: 0) 302 | {format_opts, opts} = Keyword.split(opts, Internal.dmatrix_format_feature_opts()) 303 | 304 | config = Enum.into(config_opts, %{}, fn {key, value} -> {Atom.to_string(key), value} end) 305 | format = Keyword.fetch!(format_opts, :format) 306 | 307 | dmat = 308 | EXGBoost.NIF.dmatrix_create_from_dense( 309 | Jason.encode!(ArrayInterface.from_tensor(tensor)), 310 | Jason.encode!(config) 311 | ) 312 | |> Internal.unwrap!() 313 | 314 | set_params(%__MODULE__{ref: dmat, format: format}, opts) 315 | end 316 | 317 | def from_tensor(%Nx.Tensor{} = x, %Nx.Tensor{} = y) do 318 | from_tensor(x, y, []) 319 | end 320 | 321 | def from_tensor(%Nx.Tensor{shape: x_shape}, %Nx.Tensor{shape: {y_shape}}, _opts) 322 | when is_tuple(x_shape) and elem(x_shape, 0) != elem(y_shape, 0) do 323 | raise ArgumentError, 324 | "x and y must have the same number of rows, got #{elem(x_shape, 0)} and #{elem(y_shape, 0)}" 325 | end 326 | 327 | def from_tensor(%Nx.Tensor{shape: x_shape} = x, %Nx.Tensor{shape: y_shape} = y, opts) 328 | when is_tuple(x_shape) and elem(x_shape, 0) == elem(y_shape, 0) do 329 | if Keyword.has_key?(opts, :label) do 330 | raise ArgumentError, "label must not be specified as an opt if y is provided" 331 | end 332 | 333 | opts = Keyword.put_new(opts, :label, y) 334 | from_tensor(x, opts) 335 | end 336 | 337 | def 
from_csr( 338 | %Nx.Tensor{} = indptr, 339 | %Nx.Tensor{} = indices, 340 | %Nx.Tensor{} = data, 341 | n, 342 | opts \\ [] 343 | ) 344 | when is_integer(n) and n > 0 do 345 | from_csr({indptr, indices, data, n}, opts) 346 | end 347 | 348 | def from_csr( 349 | {%Nx.Tensor{} = indptr, %Nx.Tensor{} = indices, %Nx.Tensor{} = data, n}, 350 | opts \\ [] 351 | ) 352 | when is_integer(n) and n > 0 do 353 | opts = Keyword.validate!(opts, Internal.dmatrix_feature_opts()) 354 | 355 | {config_opts, opts} = Keyword.split(opts, Internal.dmatrix_config_feature_opts()) 356 | config_opts = Keyword.validate!(config_opts, missing: Nx.Constants.nan(), nthread: 0) 357 | {format_opts, opts} = Keyword.split(opts, Internal.dmatrix_format_feature_opts()) 358 | 359 | config = Enum.into(config_opts, %{}, fn {key, value} -> {Atom.to_string(key), value} end) 360 | format = Keyword.fetch!(format_opts, :format) 361 | 362 | if format not in [:csr, :csc] do 363 | raise ArgumentError, "Sparse format must be :csr or :csc" 364 | end 365 | 366 | dmat = 367 | EXGBoost.NIF.dmatrix_create_from_sparse( 368 | Jason.encode!(ArrayInterface.from_tensor(indptr)), 369 | Jason.encode!(ArrayInterface.from_tensor(indices)), 370 | Jason.encode!(ArrayInterface.from_tensor(data)), 371 | n, 372 | Jason.encode!(config), 373 | Atom.to_string(format) 374 | ) 375 | |> Internal.unwrap!() 376 | 377 | set_params(%__MODULE__{ref: dmat, format: format}, opts) 378 | end 379 | end 380 | 381 | defmodule EXGBoost.ProxyDMatrix do 382 | @moduledoc false 383 | @enforce_keys [:ref] 384 | defstruct [:ref] 385 | 386 | def proxy_dmatrix() do 387 | p_ref = EXGBoost.NIF.proxy_dmatrix_create() 388 | %__MODULE__{ref: p_ref} 389 | end 390 | 391 | def set_params(%__MODULE__{} = dmat, opts) do 392 | EXGBoost.DMatrix.set_params(dmat, opts) 393 | end 394 | end 395 | -------------------------------------------------------------------------------- /lib/exgboost/internal.ex: 
-------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Internal do 2 | @moduledoc false 3 | alias EXGBoost.Booster 4 | alias EXGBoost.DMatrix 5 | 6 | def dmatrix_feature_opts, 7 | do: 8 | dmatrix_str_feature_opts() ++ 9 | dmatrix_meta_feature_opts() ++ 10 | dmatrix_config_feature_opts() ++ dmatrix_format_feature_opts() 11 | 12 | def dmatrix_str_feature_opts, do: [:feature_name, :feature_type] 13 | 14 | def dmatrix_format_feature_opts(), do: [:format] 15 | 16 | def dmatrix_meta_feature_opts, 17 | do: [ 18 | :label, 19 | :weight, 20 | :base_margin, 21 | :group, 22 | :label_upper_bound, 23 | :label_lower_bound, 24 | :feature_weights 25 | ] 26 | 27 | def dmatrix_config_feature_opts, do: [:nthread, :missing] 28 | 29 | def validate_type!(%Nx.Tensor{} = tensor, type) do 30 | unless Nx.type(tensor) == type do 31 | raise ArgumentError, 32 | "invalid type #{inspect(Nx.type(tensor))}, vector type" <> 33 | " must be #{inspect(type)}" 34 | end 35 | end 36 | 37 | def validate_features!(%Booster{} = booster, %DMatrix{} = dmatrix) do 38 | unless DMatrix.get_num_rows(dmatrix) == 0 do 39 | booster_names = Booster.get_feature_names(booster) 40 | booster_types = Booster.get_feature_types(booster) 41 | dmatrix_names = DMatrix.get_feature_names(dmatrix) 42 | dmatrix_types = DMatrix.get_feature_types(dmatrix) 43 | 44 | if dmatrix_names == nil and booster_names != nil do 45 | raise ArgumentError, 46 | "training data did not have the following fields: #{inspect(booster_names)}" 47 | end 48 | 49 | if dmatrix_types == nil and booster_types != nil do 50 | raise ArgumentError, 51 | "training data did not have the following types: #{inspect(booster_types)}" 52 | end 53 | 54 | if booster_names != dmatrix_names do 55 | booster_name_set = MapSet.new(booster_names) 56 | dmatrix_name_set = MapSet.new(dmatrix_names) 57 | dmatrix_missing = MapSet.difference(booster_name_set, dmatrix_name_set) 58 | my_missing = MapSet.difference(dmatrix_name_set, 
booster_name_set) 59 | msg = "feature_names mismatch: #{inspect(booster_names)} #{inspect(dmatrix_names)}" 60 | 61 | msg = 62 | if MapSet.size(dmatrix_missing) != 0 do 63 | msg <> "\nexpected #{inspect(dmatrix_missing)} in input data" 64 | else 65 | msg 66 | end 67 | 68 | msg = 69 | if MapSet.size(my_missing) != 0 do 70 | msg <> "\ntraining data did not have the following fields: #{inspect(my_missing)}" 71 | else 72 | msg 73 | end 74 | 75 | raise ArgumentError, msg 76 | end 77 | end 78 | end 79 | 80 | def get_xgboost_data_type(%Nx.Tensor{} = tensor) do 81 | case Nx.type(tensor) do 82 | {:f, 32} -> 83 | {:ok, 1} 84 | 85 | {:f, 64} -> 86 | {:ok, 2} 87 | 88 | {:u, 32} -> 89 | {:ok, 3} 90 | 91 | {:u, 64} -> 92 | {:ok, 4} 93 | 94 | true -> 95 | {:error, 96 | "invalid type #{inspect(Nx.type(tensor))}\nxgboost DMatrix only supports data types of float32, float64, uint32, and uint64"} 97 | end 98 | end 99 | 100 | def set_params(_dmatrix, _opts \\ []) 101 | 102 | def set_params(%DMatrix{} = dmat, opts) do 103 | EXGBoost.DMatrix.set_params(dmat, opts) 104 | end 105 | 106 | def set_params(%Booster{} = booster, opts) do 107 | EXGBoost.Booster.set_params(booster, opts) 108 | end 109 | 110 | # Need to implement this because XGBoost expects NaN to be encoded as "NaN" without being 111 | # a string, so if we pass string NaN to XGBoost, it will fail. 112 | # This allows the user to use Nx.Constants.nan() and have it work as expected. 
113 | defimpl Jason.Encoder, for: Nx.Tensor do 114 | @binary_nans [ 115 | <<0x7FC0::16-native>>, 116 | <<0x7E00::16-native>>, 117 | <<0x7FC00000::32-native>>, 118 | <<0x7FF8000000000000::64-native>> 119 | ] 120 | def encode(%Nx.Tensor{data: %Nx.BinaryBackend{state: state}}, _opts) 121 | when state in @binary_nans, 122 | do: "NaN" 123 | 124 | def encode(%Nx.Tensor{} = tensor, _opts) do 125 | case Nx.to_binary(tensor, limit: 1) do 126 | binary when binary in @binary_nans -> 127 | "NaN" 128 | 129 | _ -> 130 | raise ArgumentError, 131 | """ 132 | JSON Encoding only implemented for NaN Tensors (Nx.Constants.nan())! 133 | 134 | This normally is only used to map the `missing` parameter during EXGBoost 135 | training when `missing` is Nx.Constants.nan() 136 | """ 137 | end 138 | end 139 | end 140 | 141 | def unwrap!({:ok, val}), do: val 142 | def unwrap!({:error, reason}), do: raise(reason) 143 | def unwrap!(:ok), do: :ok 144 | end 145 | -------------------------------------------------------------------------------- /lib/exgboost/nif.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.NIF do 2 | @moduledoc false 3 | 4 | @on_load :on_load 5 | 6 | @typedoc """ 7 | Indicator of data type. This is defined in xgboost::DataType enum class. 8 | float = 1 9 | double = 2 10 | uint32_t = 3 11 | uint64_t = 4 12 | """ 13 | @type xgboost_data_type :: 1..4 14 | @typedoc """ 15 | JSON-Encoded Array Interface as defined in the NumPy documentation. 
16 | https://numpy.org/doc/stable/reference/arrays.interface.html 17 | """ 18 | @type array_interface :: String.t() 19 | @type dmatrix_reference :: reference() 20 | @type booster_reference :: reference() 21 | @type exgboost_return_type(return_type) :: {:ok, return_type} | {:error, String.t()} 22 | 23 | def on_load do 24 | path = :filename.join([:code.priv_dir(:exgboost), "libexgboost"]) 25 | :erlang.load_nif(path, 0) 26 | end 27 | 28 | @spec get_int_size :: integer() 29 | def get_int_size, do: :erlang.nif_error(:not_implemented) 30 | 31 | @spec xgboost_version :: exgboost_return_type(tuple) 32 | @doc """ 33 | Get the version of the XGBoost library. 34 | 35 | {major, minor, patch}. 36 | 37 | ## Examples 38 | 39 | iex> EXGBoost.NIF.xgboost_version() 40 | {:ok, {2, 0, 0}} 41 | """ 42 | def xgboost_version, do: :erlang.nif_error(:not_implemented) 43 | 44 | @spec xgboost_build_info :: exgboost_return_type(String.t()) 45 | @doc """ 46 | Get compile information of the XGBoost shared library. 47 | 48 | Returns a string encoded JSON object containing build flags and dependency version. 49 | 50 | ## Examples 51 | 52 | iex> EXGBoost.NIF.xgboost_build_info() 53 | {:ok,'{"BUILTIN_PREFETCH_PRESENT":true,"DEBUG":false,"GCC_VERSION":[9,3,0],"MM_PREFETCH_PRESENT":true,"USE_CUDA":false,"USE_FEDERATED":false,"USE_NCCL":false,"USE_OPENMP":true,"USE_RMM":false}'} 54 | """ 55 | def xgboost_build_info, do: :erlang.nif_error(:not_implemented) 56 | 57 | @spec set_global_config(String.t()) :: :ok | {:error, String.t()} 58 | @doc """ 59 | Set global config for XGBoost using a string encoded flat json. 60 | 61 | Returns `:ok` if the config is set successfully. 
62 | 63 | ## Examples 64 | 65 | iex> EXGBoost.NIF.set_global_config('{"use_rmm":false,"verbosity":1}') 66 | :ok 67 | iex> EXGBoost.NIF.set_global_config('{"use_rmm":false,"verbosity": true}') 68 | {:error, 'Invalid Parameter format for verbosity expect int but value=\'true\''} 69 | """ 70 | def set_global_config(_config), do: :erlang.nif_error(:not_implemented) 71 | 72 | @spec get_global_config :: exgboost_return_type(String.t()) 73 | @doc """ 74 | Get global config for XGBoost as a string encoded flat json. 75 | 76 | Returns a string encoded flat json. 77 | 78 | ## Examples 79 | 80 | iex> EXGBoost.NIF.get_global_config() 81 | {:ok, '{"use_rmm":false,"verbosity":1}'} 82 | """ 83 | def get_global_config, do: :erlang.nif_error(:not_implemented) 84 | 85 | @deprecated "Since 0.4.0 -- Use `EXGBoost.NIF.dmatrix_create_from_uri/1` instead" 86 | @spec dmatrix_create_from_file(String.t(), Integer) :: 87 | exgboost_return_type(reference) 88 | @doc """ 89 | Create a DMatrix from a filename 90 | 91 | **WARNING** This function will break on an improper file type and parse and thus the user 92 | should take EXTREME caution when using this, and avoid calling directly. Instead 93 | use the `EXGBoost.dmatrix` function. 94 | 95 | This is set to be fixed with the 2.0.0 release of XGBoost. 96 | 97 | Refer to https://github.com/dmlc/xgboost/issues/9059 98 | 99 | """ 100 | def dmatrix_create_from_file(_file_uri, _silent), 101 | do: :erlang.nif_error(:not_implemented) 102 | 103 | def dmatrix_create_from_uri(_config), do: :erlang.nif_error(:not_implemented) 104 | 105 | @spec dmatrix_create_from_mat(binary, integer(), integer(), float()) :: 106 | exgboost_return_type(dmatrix_reference()) 107 | @doc """ 108 | Create a DMatrix from an Nx Tensor of type {:f, 32}. 109 | 110 | Returns a reference to the DMatrix. 
111 | 112 | ## Examples 113 | 114 | iex> EXGBoost.NIF.dmatrix_create_from_mat(Nx.to_binary(Nx.tensor([1.0, 2.0, 3.0, 4.0])),1,4, -1.0) 115 | {:ok, #Reference<>} 116 | iex> EXGBoost.NIF.dmatrix_create_from_mat(Nx.to_binary(Nx.tensor([1, 2, 3, 4])),1,2, -1.0) 117 | {:error, 'Data size does not match nrow and ncol'} 118 | """ 119 | def dmatrix_create_from_mat(_data, _nrow, _ncol, _missing), 120 | do: :erlang.nif_error(:not_implemented) 121 | 122 | @spec dmatrix_create_from_sparse( 123 | array_interface(), 124 | array_interface(), 125 | array_interface(), 126 | integer(), 127 | String.t(), 128 | String.t() 129 | ) :: exgboost_return_type(dmatrix_reference()) 130 | @doc """ 131 | Create a DMatrix from a Sparse matrix (CSR / CSC) 132 | 133 | Returns a reference to the DMatrix. 134 | 135 | ## Examples 136 | 137 | iex> EXGBoost.NIF.dmatrix_create_from_csr([0, 2, 3], [0, 2, 2, 0], [1, 2, 3, 4], 2, 2, -1.0) 138 | {:ok, #Reference<>} 139 | 140 | iex> EXGBoost.NIF.dmatrix_create_from_csr([0, 2, 3], [0, 2, 2, 0], [1, 2, 3, 4], 2, 2, -1.0) 141 | {:error #Reference<>} 142 | """ 143 | def dmatrix_create_from_sparse( 144 | _indptr_interface, 145 | _indices_interface, 146 | _data_interface, 147 | _n, 148 | _config, 149 | _format 150 | ), 151 | do: :erlang.nif_error(:not_implemented) 152 | 153 | @spec dmatrix_create_from_dense(array_interface(), String.t()) :: 154 | exgboost_return_type(dmatrix_reference()) 155 | @doc """ 156 | Create a DMatrix from a JSON-Encoded Array-Interface 157 | https://numpy.org/doc/stable/reference/arrays.interface.html 158 | 159 | """ 160 | def dmatrix_create_from_dense(_array_interface, _config), 161 | do: :erlang.nif_error(:not_implemented) 162 | 163 | @spec dmatrix_get_str_feature_info(dmatrix_reference(), String.t()) :: 164 | exgboost_return_type([String.t()]) 165 | def dmatrix_get_str_feature_info(_dmatrix_resource, _field), 166 | do: :erlang.nif_error(:not_implemented) 167 | 168 | @spec dmatrix_set_str_feature_info(dmatrix_reference(), String.t(), 
[String.t()]) :: 169 | :ok | {:error, String.t()} 170 | def dmatrix_set_str_feature_info(_dmatrix_resource, _field, _features), 171 | do: :erlang.nif_error(:not_implemented) 172 | 173 | @spec dmatrix_num_row(dmatrix_reference()) :: exgboost_return_type(pos_integer()) 174 | def dmatrix_num_row(_handle), do: :erlang.nif_error(:not_implemented) 175 | 176 | @spec dmatrix_num_col(dmatrix_reference()) :: exgboost_return_type(pos_integer()) 177 | def dmatrix_num_col(_handle), do: :erlang.nif_error(:not_implemented) 178 | 179 | @spec dmatrix_num_non_missing(dmatrix_reference()) :: exgboost_return_type(pos_integer()) 180 | def dmatrix_num_non_missing(_handle), do: :erlang.nif_error(:not_implemented) 181 | 182 | @spec dmatrix_set_info_from_interface( 183 | dmatrix_reference(), 184 | String.t(), 185 | array_interface() 186 | ) :: :ok | {:error, String.t()} 187 | @doc """ 188 | Set the info from an array interface 189 | Valid fields are: 190 | Set meta info from dense matrix. Valid field names are: 191 | * label 192 | * weight 193 | * base_margin 194 | * group 195 | * label_lower_bound 196 | * label_upper_bound 197 | * feature_weights 198 | """ 199 | def dmatrix_set_info_from_interface(_handle, _field, _data_interface), 200 | do: :erlang.nif_error(:not_implemented) 201 | 202 | @spec dmatrix_save_binary(dmatrix_reference(), String.t(), integer()) :: 203 | exgboost_return_type(:ok) 204 | def dmatrix_save_binary(_handle, _fname, _silent), 205 | do: :erlang.nif_error(:not_implemented) 206 | 207 | @spec get_binary_address(dmatrix_reference()) :: exgboost_return_type(integer) 208 | def get_binary_address(_handle), 209 | do: :erlang.nif_error(:not_implemented) 210 | 211 | @spec get_binary_from_address(integer(), integer()) :: exgboost_return_type(binary()) 212 | def get_binary_from_address(_address, _size), do: :erlang.nif_error(:not_implemented) 213 | 214 | @doc """ 215 | Gets a field from the DMatrix. 
Valid fields are: 216 | * label 217 | * weight 218 | * base_margin 219 | * label_lower_bound 220 | * label_upper_bound 221 | * feature_weights 222 | """ 223 | @spec dmatrix_get_float_info(dmatrix_reference(), String.t()) :: exgboost_return_type([float]) 224 | def dmatrix_get_float_info(_handle, _field), 225 | do: :erlang.nif_error(:not_implemented) 226 | 227 | @doc """ 228 | Gets a field from the DMatrix. Valid fields are: 229 | * group_ptr 230 | """ 231 | @spec dmatrix_get_uint_info(dmatrix_reference(), String.t()) :: 232 | exgboost_return_type([pos_integer()]) 233 | def dmatrix_get_uint_info(_handle, _field), 234 | do: :erlang.nif_error(:not_implemented) 235 | 236 | @doc """ 237 | Get data field from DMatrix. 238 | 239 | * config: At the moment it should be an empty document, preserved for future use. 240 | 241 | Returns 3-tuple of {indptr, indices, data} 242 | """ 243 | @spec dmatrix_get_data_as_csr(dmatrix_reference(), String.t()) :: 244 | exgboost_return_type({[pos_integer()], [pos_integer()], [float]}) 245 | def dmatrix_get_data_as_csr(_handle, _config), 246 | do: :erlang.nif_error(:not_implemented) 247 | 248 | @doc """ 249 | Create a DMatrix from a slice of rows from an existing DMatrix 250 | 251 | Expects a binary of `int`s so you should query for the size of the C int using 252 | NIF.get_int_size/0 and then use that to pack the binary, or use Nx.tensor(input, type: type) 253 | to get a binary of the correct size.
254 | """ 255 | @spec dmatrix_slice(dmatrix_reference(), binary(), 0 | 1) :: dmatrix_reference() 256 | def dmatrix_slice(_handle, _index_set, _allow_groups), do: :erlang.nif_error(:not_implemented) 257 | 258 | def dmatrix_get_quantile_cut(_handle, _config), do: :erlang.nif_error(:not_implemented) 259 | 260 | @spec booster_create([dmatrix_reference()]) :: exgboost_return_type(booster_reference()) 261 | def booster_create(_handles), do: :erlang.nif_error(:not_implemented) 262 | 263 | @spec booster_slice(booster_reference(), integer(), integer(), integer()) :: 264 | exgboost_return_type(booster_reference()) 265 | def booster_slice(_handle, _begin_layer, _end_layer, _step), 266 | do: :erlang.nif_error(:not_implemented) 267 | 268 | @spec booster_boosted_rounds(booster_reference()) :: exgboost_return_type(integer()) 269 | def booster_boosted_rounds(_handle), do: :erlang.nif_error(:not_implemented) 270 | 271 | @spec booster_set_param(booster_reference(), String.t(), String.t()) :: 272 | :ok | {:error, String.t()} 273 | def booster_set_param(_handle, _param, _value), do: :erlang.nif_error(:not_implemented) 274 | 275 | @spec booster_get_num_feature(booster_reference()) :: exgboost_return_type(pos_integer()) 276 | def booster_get_num_feature(_handle), do: :erlang.nif_error(:not_implemented) 277 | 278 | @spec booster_update_one_iter(booster_reference(), dmatrix_reference(), integer()) :: 279 | :ok | {:error, String.t()} 280 | def booster_update_one_iter(_booster_handle, _dmatrix_handle, _iteration), 281 | do: :erlang.nif_error(:not_implemented) 282 | 283 | @doc """ 284 | Update the model, by directly specify gradient and second order gradient, this can be used to replace UpdateOneIter, to support customized loss function 285 | 286 | Grad and hess must be binaries of Nx.Tensor float32 287 | """ 288 | @spec booster_boost_one_iter(booster_reference(), dmatrix_reference(), binary(), binary()) :: 289 | :ok | {:error, String.t()} 290 | def booster_boost_one_iter(_booster_handle, 
_dmatrix_handle, _grad, _hess), 291 | do: :erlang.nif_error(:not_implemented) 292 | 293 | @spec booster_eval_one_iter(booster_reference(), pos_integer(), [dmatrix_reference()], [ 294 | String.t() 295 | ]) :: exgboost_return_type(String.t()) 296 | def booster_eval_one_iter(_booster_handle, _iteration, _dmatrix_handles, _eval_names), 297 | do: :erlang.nif_error(:not_implemented) 298 | 299 | @spec booster_get_attr_names(booster_reference()) :: exgboost_return_type([String.t()]) 300 | def booster_get_attr_names(_booster_handle), do: :erlang.nif_error(:not_implemented) 301 | 302 | @spec booster_get_attr(booster_reference(), String.t()) :: exgboost_return_type(String.t()) 303 | def booster_get_attr(_booster_handle, _key), do: :erlang.nif_error(:not_implemented) 304 | 305 | @spec booster_set_attr(booster_reference(), String.t(), String.t()) :: 306 | :ok | {:error, String.t()} 307 | def booster_set_attr(_booster_handle, _key, _value), do: :erlang.nif_error(:not_implemented) 308 | 309 | @spec booster_get_str_feature_info(booster_reference(), String.t()) :: 310 | exgboost_return_type([String.t()]) 311 | def booster_get_str_feature_info(_booster_resource, _field), 312 | do: :erlang.nif_error(:not_implemented) 313 | 314 | @spec booster_set_str_feature_info(booster_reference(), String.t(), [String.t()]) :: 315 | :ok | {:error, String.t()} 316 | def booster_set_str_feature_info(_booster_resource, _field, _features), 317 | do: :erlang.nif_error(:not_implemented) 318 | 319 | @spec booster_feature_score(booster_reference(), String.t()) :: 320 | exgboost_return_type(tuple()) 321 | def booster_feature_score(_booster_resource, _config), 322 | do: :erlang.nif_error(:not_implemented) 323 | 324 | @spec booster_predict_from_dmatrix(booster_reference(), dmatrix_reference(), String.t()) :: 325 | tuple() 326 | def booster_predict_from_dmatrix(_boster, _dmatrix, _config), 327 | do: :erlang.nif_error(:not_implemented) 328 | 329 | @spec booster_predict_from_dense(booster_reference(), 
String.t(), String.t(), reference() | nil) :: 330 | tuple() 331 | def booster_predict_from_dense(_boster, _values, _config, _proxy), 332 | do: :erlang.nif_error(:not_implemented) 333 | 334 | @spec booster_predict_from_csr( 335 | booster_reference(), 336 | String.t(), 337 | String.t(), 338 | String.t(), 339 | integer(), 340 | String.t(), 341 | reference() | nil 342 | ) :: 343 | tuple() 344 | def booster_predict_from_csr(_boster, _indptr, _indices, _values, _ncols, _config, _proxy), 345 | do: :erlang.nif_error(:not_implemented) 346 | 347 | @spec proxy_dmatrix_create() :: dmatrix_reference() 348 | def proxy_dmatrix_create, do: :erlang.nif_error(:not_implemented) 349 | 350 | @spec booster_load_model(String.t()) :: 351 | exgboost_return_type(booster_reference()) 352 | def booster_load_model(_path), do: :erlang.nif_error(:not_implemented) 353 | 354 | @spec booster_save_model(booster_reference(), String.t()) :: 355 | :ok | {:error, String.t()} 356 | def booster_save_model(_handle, _path), do: :erlang.nif_error(:not_implemented) 357 | 358 | @spec booster_serialize_to_buffer(booster_reference()) :: binary() 359 | def booster_serialize_to_buffer(_handle), do: :erlang.nif_error(:not_implemented) 360 | 361 | @spec booster_deserialize_from_buffer(binary()) :: exgboost_return_type(booster_reference()) 362 | def booster_deserialize_from_buffer(_buffer), do: :erlang.nif_error(:not_implemented) 363 | 364 | @spec booster_save_model_to_buffer(booster_reference(), String.t()) :: binary() 365 | def booster_save_model_to_buffer(_handle, _config), do: :erlang.nif_error(:not_implemented) 366 | 367 | @spec booster_load_model_from_buffer(binary()) :: exgboost_return_type(booster_reference()) 368 | def booster_load_model_from_buffer(_buffer), do: :erlang.nif_error(:not_implemented) 369 | 370 | @spec booster_load_json_config(booster_reference(), String.t()) :: :ok | {:error, String.t()} 371 | def booster_load_json_config(_handle, _config), do: :erlang.nif_error(:not_implemented) 372 | 373 | 
@spec booster_save_json_config(booster_reference()) :: binary() 374 | def booster_save_json_config(_handle), do: :erlang.nif_error(:not_implemented) 375 | 376 | def booster_dump_model(_handle, _fmap, _with_stats, _format), 377 | do: :erlang.nif_error(:not_implemented) 378 | end 379 | -------------------------------------------------------------------------------- /lib/exgboost/plotting/style.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Plotting.Style do 2 | @moduledoc false 3 | defmacro __using__(_opts) do 4 | quote do 5 | @type style :: Keyword.t() 6 | import EXGBoost.Plotting.Style 7 | Module.register_attribute(__MODULE__, :styles, accumulate: true) 8 | end 9 | end 10 | 11 | def deep_merge_kw(a, b, ignore_set \\ []) do 12 | Keyword.merge(a, b, fn 13 | _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> 14 | deep_merge_kw(val_a, val_b) 15 | 16 | key, val_a, val_b -> 17 | if Keyword.has_key?(b, key) do 18 | if Keyword.has_key?(ignore_set, key) and Keyword.get(ignore_set, key) == val_b do 19 | val_a 20 | else 21 | val_b 22 | end 23 | else 24 | val_a 25 | end 26 | end) 27 | end 28 | 29 | def deep_merge_maps(b, a) do 30 | Map.merge(a, b, fn 31 | _key, val_a, val_b when is_map(val_a) and is_map(val_b) -> 32 | deep_merge_maps(val_a, val_b) 33 | 34 | key, val_a, val_b -> 35 | if Map.has_key?(b, key) do 36 | val_b 37 | else 38 | val_a 39 | end 40 | end) 41 | end 42 | 43 | defmacro style(style_name, do: body) do 44 | quote do 45 | @spec unquote(style_name)() :: style 46 | def unquote(style_name)(), do: unquote(body) 47 | Module.put_attribute(__MODULE__, :styles, {unquote(style_name), unquote(body)}) 48 | end 49 | end 50 | end 51 | -------------------------------------------------------------------------------- /lib/exgboost/plotting/styles.ex: -------------------------------------------------------------------------------- 1 | Mix.env() == :docs && 2 | defmodule EXGBoost.Plotting.Styles do 3 | @bst 
File.cwd!() |> Path.join("test/data/model.json") |> EXGBoost.read_model() 4 | 5 | @moduledoc """ 6 |
7 | #{Enum.map(EXGBoost.Plotting.get_styles(), fn {name, _style} -> """ 8 |
9 |

#{name}

10 |
11 |           
12 |           #{EXGBoost.plot_tree(@bst, style: name, height: 200, width: 300).spec |> Jason.encode!()}
13 |           
14 |         
15 |
16 | """ end) |> Enum.join("\n\n")} 17 |
18 | """ 19 | end 20 | -------------------------------------------------------------------------------- /lib/exgboost/training.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Training do 2 | @moduledoc false 3 | alias EXGBoost.Booster 4 | alias EXGBoost.DMatrix 5 | alias EXGBoost.Training.{State, Callback} 6 | 7 | @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() 8 | def train(%DMatrix{} = dmat, opts \\ []) do 9 | dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) 10 | 11 | valid_opts = [ 12 | callbacks: [], 13 | early_stopping_rounds: nil, 14 | evals: [], 15 | learning_rates: nil, 16 | num_boost_rounds: 10, 17 | obj: nil, 18 | verbose_eval: true, 19 | disable_default_eval_metric: false 20 | ] 21 | 22 | {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts)) 23 | 24 | [ 25 | callbacks: callbacks, 26 | disable_default_eval_metric: disable_default_eval_metric, 27 | early_stopping_rounds: early_stopping_rounds, 28 | evals: evals, 29 | learning_rates: learning_rates, 30 | num_boost_rounds: num_boost_rounds, 31 | obj: objective, 32 | verbose_eval: verbose_eval 33 | ] = opts |> Keyword.validate!(valid_opts) |> Enum.sort() 34 | 35 | unless is_nil(learning_rates) or is_function(learning_rates, 1) or is_list(learning_rates) do 36 | raise ArgumentError, "learning_rates must be a function/1 or a list" 37 | end 38 | 39 | if early_stopping_rounds && evals == [] do 40 | raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" 41 | end 42 | 43 | verbose_eval = 44 | case verbose_eval do 45 | true -> 1 46 | false -> 0 47 | value -> value 48 | end 49 | 50 | evals_dmats = 51 | Enum.map(evals, fn {x, y, name} -> 52 | {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} 53 | end) 54 | 55 | bst = 56 | Booster.booster( 57 | [dmat | Enum.map(evals_dmats, fn {dmat, _name} -> dmat end)], 58 | booster_params 59 | ) 60 | 61 | defaults = 62 | 
default_callbacks( 63 | bst, 64 | learning_rates, 65 | verbose_eval, 66 | evals_dmats, 67 | early_stopping_rounds, 68 | disable_default_eval_metric 69 | ) 70 | 71 | callbacks = 72 | Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> 73 | %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} 74 | end) 75 | 76 | # Validate callbacks and ensure all names are unique. 77 | Enum.each(callbacks, &Callback.validate!/1) 78 | name_counts = Enum.frequencies_by(callbacks, & &1.name) 79 | 80 | if Enum.any?(name_counts, &(elem(&1, 1) > 1)) do 81 | str = name_counts |> Enum.sort() |> Enum.map_join("\n\n", &" * #{inspect(&1)}") 82 | raise ArgumentError, "Found duplicate callback names.\n\nName counts:\n\n#{str}\n" 83 | end 84 | 85 | state = %State{ 86 | booster: bst, 87 | iteration: 0, 88 | max_iteration: num_boost_rounds, 89 | meta_vars: Map.new(callbacks, &{&1.name, &1.init_state}) 90 | } 91 | 92 | callbacks = Enum.group_by(callbacks, & &1.event, & &1.fun) 93 | 94 | state = 95 | state 96 | |> run_callbacks(callbacks, :before_training) 97 | |> run_training(callbacks, dmat, objective) 98 | |> run_callbacks(callbacks, :after_training) 99 | 100 | state.booster 101 | end 102 | 103 | defp run_callbacks(%{status: :halt} = state, _callbacks, _event), do: state 104 | 105 | defp run_callbacks(%{status: :cont} = state, callbacks, event) do 106 | Enum.reduce_while(callbacks[event] || [], state, fn callback, state -> 107 | state = callback.(state) 108 | {state.status, state} 109 | end) 110 | end 111 | 112 | defp run_training(%{status: :halt} = state, _callbacks, _dmat, _objective), do: state 113 | 114 | defp run_training(%{status: :cont} = state, callbacks, dmat, objective) do 115 | Enum.reduce_while(1..state.max_iteration, state, fn iter, state -> 116 | state = 117 | state 118 | |> run_callbacks(callbacks, :before_iteration) 119 | |> run_iteration(dmat, iter, objective) 120 | |> run_callbacks(callbacks, :after_iteration) 121 | 122 | {state.status, state} 
123 | end) 124 | end 125 | 126 | defp run_iteration(%{status: :halt} = state, _dmat, _iter, _objective), do: state 127 | 128 | defp run_iteration(%{status: :cont} = state, dmat, iter, objective) do 129 | :ok = Booster.update(state.booster, dmat, iter, objective) 130 | %{state | iteration: iter} 131 | end 132 | 133 | defp default_callbacks( 134 | bst, 135 | learning_rates, 136 | verbose_eval, 137 | evals_dmats, 138 | early_stopping_rounds, 139 | disable_default_eval_metric 140 | ) do 141 | default_callbacks = [] 142 | 143 | default_callbacks = 144 | if learning_rates do 145 | lr_scheduler = %Callback{ 146 | event: :before_iteration, 147 | fun: &Callback.lr_scheduler/1, 148 | name: :lr_scheduler, 149 | init_state: %{learning_rates: learning_rates} 150 | } 151 | 152 | [lr_scheduler | default_callbacks] 153 | else 154 | default_callbacks 155 | end 156 | 157 | default_callbacks = 158 | if verbose_eval != 0 and evals_dmats != [] do 159 | monitor_metrics = %Callback{ 160 | event: :after_iteration, 161 | fun: &Callback.monitor_metrics/1, 162 | name: :monitor_metrics, 163 | init_state: %{period: verbose_eval, filter: fn {_, _} -> true end} 164 | } 165 | 166 | [monitor_metrics | default_callbacks] 167 | else 168 | default_callbacks 169 | end 170 | 171 | default_callbacks = 172 | if early_stopping_rounds && evals_dmats != [] do 173 | [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) 174 | 175 | # This is still somewhat hacky and relies on a modification made to 176 | # XGBoost in the Makefile to dump the config to JSON. 177 | # 178 | %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = 179 | EXGBoost.dump_config(bst) |> Jason.decode!() 180 | 181 | metric_name = 182 | cond do 183 | Enum.empty?(metrics) && disable_default_eval_metric -> 184 | raise ArgumentError, 185 | "`:early_stopping_rounds` requires at least one evaluation set. 
This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evalutation metrics. Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)" 186 | 187 | Enum.empty?(metrics) -> 188 | default_metric 189 | 190 | true -> 191 | metrics |> Enum.reverse() |> hd() |> Map.fetch!("name") 192 | end 193 | 194 | early_stop = %Callback{ 195 | event: :after_iteration, 196 | fun: &Callback.early_stop/1, 197 | name: :early_stop, 198 | init_state: %{ 199 | patience: early_stopping_rounds, 200 | best: nil, 201 | since_last_improvement: 0, 202 | mode: :min, 203 | target_eval: target_eval, 204 | target_metric: metric_name 205 | } 206 | } 207 | 208 | [early_stop | default_callbacks] 209 | else 210 | default_callbacks 211 | end 212 | 213 | default_callbacks = 214 | if evals_dmats != [] do 215 | eval_metrics = %Callback{ 216 | event: :after_iteration, 217 | fun: &Callback.eval_metrics/1, 218 | name: :eval_metrics, 219 | init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end} 220 | } 221 | 222 | [eval_metrics | default_callbacks] 223 | else 224 | default_callbacks 225 | end 226 | 227 | default_callbacks 228 | end 229 | end 230 | -------------------------------------------------------------------------------- /lib/exgboost/training/callback.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Training.Callback do 2 | @moduledoc """ 3 | Callbacks are a mechanism to hook into the training process and perform custom actions. 
4 | 5 | Callbacks are structs with the following fields: 6 | * `event` - the event that triggers the callback 7 | * `fun` - the function to call when the callback is triggered 8 | * `name` - the name of the callback 9 | * `init_state` - the initial state of the callback 10 | 11 | The following events are supported: 12 | * `:before_training` - called before the training starts 13 | * `:after_training` - called after the training ends 14 | * `:before_iteration` - called before each iteration 15 | * `:after_iteration` - called after each iteration 16 | 17 | The callback function is called with the following arguments: 18 | * `state` - the current training state 19 | 20 | The callback function should return one of the following: 21 | * `{:cont, state}` - continue training with the given state 22 | * `{:halt, state}` - stop training with the given state 23 | 24 | The following callbacks are provided in the `EXGBoost.Training.Callback` module: 25 | * `lr_scheduler` - sets the learning rate for each iteration 26 | * `early_stop` - performs early stopping 27 | * `eval_metrics` - evaluates metrics on the training and evaluation sets 28 | * `eval_monitor` - prints evaluation metrics 29 | 30 | Callbacks can be added to the training process by passing them to `EXGBoost.Training.train/2`. 
31 | 32 | ## Example 33 | 34 | ```elixir 35 | # Callback to perform setup before training 36 | setup_fn = fn state -> 37 | updated_state = put_in(state, [:meta_vars,:early_stop], %{best: 1, since_last_improvement: 0, mode: :max, patience: 5}) 38 | {:cont, updated_state} 39 | end 40 | 41 | setup_callback = Callback.new(:before_training, setup_fn) 42 | ``` 43 | 44 | """ 45 | alias EXGBoost.Training.State 46 | @enforce_keys [:event, :fun] 47 | defstruct [:event, :fun, :name, :init_state] 48 | 49 | @type event :: :before_training | :after_training | :before_iteration | :after_iteration 50 | @type fun :: (State.t() -> State.t()) 51 | 52 | @valid_events [:before_training, :after_training, :before_iteration, :after_iteration] 53 | 54 | @doc """ 55 | Factory for a new callback with an initial state. 56 | """ 57 | @spec new(event :: event(), fun :: fun(), name :: atom(), init_state :: any()) :: Callback.t() 58 | def new(event, fun, name, init_state \\ %{}) 59 | when event in @valid_events and is_function(fun, 1) and is_atom(name) and not is_nil(name) do 60 | %__MODULE__{event: event, fun: fun, name: name, init_state: init_state} 61 | |> validate!() 62 | end 63 | 64 | def validate!(%__MODULE__{} = callback) do 65 | unless is_atom(callback.name) and not is_nil(callback.name) do 66 | raise "A callback must have a non-`nil` atom for a name. Found: #{callback.name}." 67 | end 68 | 69 | unless callback.event in @valid_events do 70 | raise "Callback #{callback.name} must have an event in #{@valid_events}. Found: #{callback.event}." 71 | end 72 | 73 | unless is_function(callback.fun, 1) do 74 | raise "Callback #{callback.name} must have a 1-arity function. Found: #{callback.event}." 75 | end 76 | 77 | callback 78 | end 79 | 80 | @doc """ 81 | A callback that sets the learning rate for each iteration. 82 | 83 | Requires that `learning_rates` either be a list of learning rates or a function that takes the 84 | iteration number and returns a learning rate. 
`learning_rates` must exist in the `state` that 85 | is passed to the callback. 86 | """ 87 | def lr_scheduler( 88 | %State{ 89 | booster: bst, 90 | meta_vars: %{lr_scheduler: %{learning_rates: learning_rates}}, 91 | iteration: i, 92 | status: :cont 93 | } = state 94 | ) do 95 | lr = if is_list(learning_rates), do: Enum.at(learning_rates, i), else: learning_rates.(i) 96 | boostr = EXGBoost.Booster.set_params(bst, learning_rate: lr) 97 | %{state | booster: boostr} 98 | end 99 | 100 | # TODO: Ideally this would be generalized like it is in Axon to allow generic monitoring of metrics, 101 | # but for now we'll just do early stopping 102 | 103 | @doc """ 104 | A callback function that performs early stopping. 105 | 106 | Requires that the following exist in the `state` that is passed to the callback: 107 | 108 | * `target` is the metric to monitor for early stopping. It must exist in the `metrics` that the 109 | state contains. 110 | * `mode` is either `:min` or `:max` and indicates whether the metric should be 111 | minimized or maximized. 112 | * `patience` is the number of iterations to wait for the metric to improve before stopping. 113 | * `since_last_improvement` is the number of iterations since the metric last improved. 114 | * `best` is the best value of the metric seen so far. 
115 | """ 116 | def early_stop( 117 | %State{ 118 | booster: bst, 119 | meta_vars: %{early_stop: early_stop} = meta_vars, 120 | metrics: metrics, 121 | status: :cont 122 | } = state 123 | ) do 124 | %{ 125 | best: best_score, 126 | patience: patience, 127 | target_metric: target_metric, 128 | target_eval: target_eval, 129 | mode: mode, 130 | since_last_improvement: since_last_improvement 131 | } = early_stop 132 | 133 | unless Map.has_key?(metrics, target_eval) do 134 | raise ArgumentError, 135 | "target eval_set #{inspect(target_eval)} not found in metrics #{inspect(metrics)}" 136 | end 137 | 138 | unless Map.has_key?(metrics[target_eval], target_metric) do 139 | raise ArgumentError, 140 | "target metric #{inspect(target_metric)} not found in metrics #{inspect(metrics)}" 141 | end 142 | 143 | score = metrics[target_eval][target_metric] 144 | 145 | improved? = 146 | cond do 147 | best_score == nil -> true 148 | mode == :min -> score < best_score 149 | mode == :max -> score > best_score 150 | end 151 | 152 | cond do 153 | improved? -> 154 | early_stop = %{early_stop | best: score, since_last_improvement: 0} 155 | 156 | bst = 157 | bst 158 | |> struct(best_iteration: state.iteration, best_score: score) 159 | |> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score) 160 | 161 | %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}} 162 | 163 | since_last_improvement < patience -> 164 | early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) 165 | %{state | meta_vars: %{meta_vars | early_stop: early_stop}} 166 | 167 | true -> 168 | early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) 169 | # TODO: Should this actually update the best iteration and score? 170 | # This iteration is not the best, but it is the last one, so do we want 171 | # another way to track last iteration? 
172 | bst = struct(bst, best_iteration: state.iteration, best_score: score) 173 | %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt} 174 | end 175 | end 176 | 177 | @doc """ 178 | A callback function that evaluates metrics on the training and evaluation sets. 179 | 180 | Requires that the following exist in the `state.meta_vars` that is passed to the callback: 181 | * eval_metrics: 182 | * evals: a list of evaluation sets to evaluate metrics on 183 | * filter: a function that takes a metric name and value and returns 184 | true if the metric should be included in the results 185 | """ 186 | def eval_metrics( 187 | %State{ 188 | booster: bst, 189 | iteration: iter, 190 | meta_vars: %{eval_metrics: %{evals: evals, filter: filter}}, 191 | status: :cont 192 | } = state 193 | ) do 194 | metrics = 195 | EXGBoost.Booster.eval_set(bst, evals, iter) 196 | |> Enum.reduce(%{}, fn {evname, mname, value}, acc -> 197 | Map.update(acc, evname, %{mname => value}, fn existing -> 198 | Map.put(existing, mname, value) 199 | end) 200 | end) 201 | |> Map.filter(filter) 202 | 203 | %{state | metrics: metrics} 204 | end 205 | 206 | @doc """ 207 | A callback function that prints evaluation metrics according to a period. 
208 | 209 | Requires that the following exist in the `state.meta_vars` that is passed to the callback: 210 | * monitor_metrics: 211 | * period: print metrics every `period` iterations 212 | * filter: a function that takes a metric name and value and returns 213 | true if the metric should be included in the results 214 | """ 215 | def monitor_metrics( 216 | %State{ 217 | iteration: iteration, 218 | metrics: metrics, 219 | meta_vars: %{ 220 | monitor_metrics: %{period: period, filter: filter} 221 | }, 222 | status: :cont 223 | } = state 224 | ) do 225 | if period != 0 and rem(iteration, period) == 0 do 226 | IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}") 227 | end 228 | 229 | state 230 | end 231 | end 232 | -------------------------------------------------------------------------------- /lib/exgboost/training/state.ex: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.Training.State do 2 | @moduledoc false 3 | @enforce_keys [:booster] 4 | defstruct [ 5 | :booster, 6 | iteration: 0, 7 | max_iteration: -1, 8 | meta_vars: %{}, 9 | metrics: %{}, 10 | status: :cont 11 | ] 12 | 13 | def validate!(%__MODULE__{} = state) do 14 | unless state.status in [:cont, :halt] do 15 | raise ArgumentError, 16 | "`status` must be `:cont` or `:halt`, found: `#{inspect(state.status)}`." 
17 | end 18 | 19 | state 20 | end 21 | end 22 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule EXGBoost.MixProject do 2 | use Mix.Project 3 | @version "0.5.1" 4 | 5 | def project do 6 | [ 7 | app: :exgboost, 8 | version: @version, 9 | make_precompiler: {:nif, CCPrecompiler}, 10 | make_precompiler_url: 11 | "https://github.com/acalejos/exgboost/releases/download/v#{@version}/@{artefact_filename}", 12 | make_precompiler_priv_paths: ["libexgboost.*", "lib"], 13 | # NIF Versions correspond to OTP Releases 14 | # https://github.com/erlang/otp/blob/d3aa6c044c3927f011fb76ac087d5ce0e814954c/erts/emulator/beam/erl_nif.h#L57 15 | make_precompiler_nif_versions: [ 16 | versions: ["2.15", "2.16", "2.17"] 17 | ], 18 | elixir: "~> 1.14", 19 | start_permanent: Mix.env() == :prod, 20 | compilers: [:elixir_make] ++ Mix.compilers(), 21 | deps: deps(), 22 | name: "EXGBoost", 23 | source_url: "https://github.com/acalejos/exgboost", 24 | homepage_url: "https://github.com/acalejos/exgboost", 25 | docs: docs(), 26 | package: package(), 27 | preferred_cli_env: [ 28 | docs: :docs, 29 | "hex.publish": :docs 30 | ], 31 | before_closing_body_tag: &before_closing_body_tag/1, 32 | name: "EXGBoost", 33 | description: 34 | "Elixir bindings for the XGBoost library. `EXGBoost` provides an implementation of XGBoost that works with 35 | [Nx](https://hexdocs.pm/nx/Nx.html) tensors." 
36 | ] 37 | end 38 | 39 | def application do 40 | [ 41 | extra_applications: [:logger], 42 | mod: {EXGBoost.Application, []} 43 | ] 44 | end 45 | 46 | defp deps do 47 | [ 48 | {:elixir_make, "~> 0.4", runtime: false}, 49 | {:nimble_options, "~> 1.0"}, 50 | {:nx, "~> 0.7"}, 51 | {:jason, "~> 1.3"}, 52 | {:ex_doc, "~> 0.31.0", only: :docs}, 53 | {:cc_precompiler, "~> 0.1.0", runtime: false}, 54 | {:exterval, "0.1.0"}, 55 | {:ex_json_schema, "~> 0.10.2"}, 56 | {:httpoison, "~> 2.0", runtime: false}, 57 | {:vega_lite, "~> 0.1"}, 58 | {:kino, "~> 0.11"}, 59 | {:scidata, "~> 0.1", only: :dev}, 60 | {:kino_vega_lite, "~> 0.1.9", only: :dev} 61 | ] 62 | end 63 | 64 | defp package do 65 | [ 66 | maintainers: ["Andres Alejos"], 67 | licenses: ["Apache-2.0"], 68 | links: %{"GitHub" => "https://github.com/acalejos/exgboost"}, 69 | files: [ 70 | "lib", 71 | "mix.exs", 72 | "c", 73 | "Makefile", 74 | "README.md", 75 | "LICENSE", 76 | ".formatter.exs", 77 | "checksum.exs" 78 | ] 79 | ] 80 | end 81 | 82 | defp docs do 83 | [ 84 | main: "EXGBoost", 85 | extras: [ 86 | "notebooks/compiled_benchmarks.livemd", 87 | "notebooks/iris_classification.livemd", 88 | "notebooks/quantile_prediction_interval.livemd", 89 | "notebooks/plotting.livemd" 90 | ], 91 | groups_for_extras: [ 92 | Notebooks: Path.wildcard("notebooks/*.livemd") 93 | ], 94 | groups_for_functions: [ 95 | "System / Native Config": &(&1[:type] == :system), 96 | "Training & Prediction": &(&1[:type] == :train_pred), 97 | Serialization: &(&1[:type] == :serialization), 98 | Plotting: &(&1[:type] == :plotting) 99 | ], 100 | groups_for_modules: [ 101 | Plotting: [EXGBoost.Plotting, EXGBoost.Plotting.Styles], 102 | Training: [ 103 | EXGBoost.Training, 104 | EXGBoost.Training.Callback, 105 | EXGBoost.Booster, 106 | EXGBoost.Parameters 107 | ] 108 | ], 109 | before_closing_body_tag: &before_closing_body_tag/1 110 | ] 111 | end 112 | 113 | defp before_closing_body_tag(:html) do 114 | """ 115 | 116 | 117 | 118 | 119 | 129 | 130 | 131 | 
132 | 150 | 151 | 152 | 153 | 154 | 155 | 166 | 182 | """ 183 | end 184 | 185 | defp before_closing_body_tag(_), do: "" 186 | end 187 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, 3 | "cc_precompiler": {:hex, :cc_precompiler, "0.1.9", "e8d3364f310da6ce6463c3dd20cf90ae7bbecbf6c5203b98bf9b48035592649b", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "9dcab3d0f3038621f1601f13539e7a9ee99843862e66ad62827b0c42b2f58a54"}, 4 | "certifi": {:hex, :certifi, "2.12.0", "2d1cca2ec95f59643862af91f001478c9863c2ac9cb6e2f89780bfd8de987329", [:rebar3], [], "hexpm", "ee68d85df22e554040cdb4be100f33873ac6051387baf6a8f6ce82272340ff1c"}, 5 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 6 | "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, 7 | "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, 8 | "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, 9 | "ex_doc": {:hex, 
:ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, 10 | "ex_json_schema": {:hex, :ex_json_schema, "0.10.2", "7c4b8c1481fdeb1741e2ce66223976edfb9bccebc8014f6aec35d4efe964fb71", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "37f43be60f8407659d4d0155a7e45e7f406dab1f827051d3d35858a709baf6a6"}, 11 | "exterval": {:hex, :exterval, "0.1.0", "d35ec43b0f260239f859665137fac0974f1c6a8d50bf8d52b5999c87c67c63e5", [:mix], [], "hexpm", "8ed444558a501deec6563230e3124cdf242c413a95eb9ca9f39de024ad779d7f"}, 12 | "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"}, 13 | "hackney": {:hex, :hackney, "1.20.1", "8d97aec62ddddd757d128bfd1df6c5861093419f8f7a4223823537bad5d064e2", [:rebar3], [{:certifi, "~> 2.12.0", [hex: :certifi, repo: "hexpm", optional: false]}, {:idna, "~> 6.1.0", [hex: :idna, repo: "hexpm", optional: false]}, {:metrics, "~> 1.0.0", [hex: :metrics, repo: "hexpm", optional: false]}, {:mimerl, "~> 1.1", [hex: :mimerl, repo: "hexpm", optional: false]}, {:parse_trans, "3.4.1", [hex: :parse_trans, repo: "hexpm", optional: false]}, {:ssl_verify_fun, "~> 1.1.0", [hex: :ssl_verify_fun, repo: "hexpm", optional: false]}, {:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "fe9094e5f1a2a2c0a7d10918fee36bfec0ec2a979994cff8cfe8058cd9af38e3"}, 14 | "httpoison": {:hex, :httpoison, "2.2.1", 
"87b7ed6d95db0389f7df02779644171d7319d319178f6680438167d7b69b1f3d", [:mix], [{:hackney, "~> 1.17", [hex: :hackney, repo: "hexpm", optional: false]}], "hexpm", "51364e6d2f429d80e14fe4b5f8e39719cacd03eb3f9a9286e61e216feac2d2df"}, 15 | "idna": {:hex, :idna, "6.1.1", "8a63070e9f7d0c62eb9d9fcb360a7de382448200fbbd1b106cc96d3d8099df8d", [:rebar3], [{:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "92376eb7894412ed19ac475e4a86f7b413c1b9fbb5bd16dccd57934157944cea"}, 16 | "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, 17 | "kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"}, 18 | "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"}, 19 | "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, 20 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", 
"cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 21 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.3", "d684f4bac8690e70b06eb52dad65d26de2eefa44cd19d64a8095e1417df7c8fd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "b78dc853d2e670ff6390b605d807263bf606da3c82be37f9d7f68635bd886fc9"}, 22 | "metrics": {:hex, :metrics, "1.0.1", "25f094dea2cda98213cecc3aeff09e940299d950904393b2a29d191c346a8486", [:rebar3], [], "hexpm", "69b09adddc4f74a40716ae54d140f93beb0fb8978d8636eaded0c31b6f099f16"}, 23 | "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, 24 | "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, 25 | "nimble_options": {:hex, :nimble_options, "1.1.0", "3b31a57ede9cb1502071fade751ab0c7b8dbe75a9a4c2b5bbb0943a690b63172", [:mix], [], "hexpm", "8bbbb3941af3ca9acc7835f5655ea062111c9c27bcac53e004460dfd19008a99"}, 26 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, 27 | "nx": {:hex, :nx, "0.7.2", "7f6f6584585e49ffbf81769e7ccc2d01c5639074e399c1f94adc2b509869673e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e2c0680066eec5af8b8ef00c99e9bf40a0d08d8b2bbba77f59f801ec54a3f90e"}, 28 | 
"parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, 29 | "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, 30 | "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, 31 | "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, 32 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 33 | "unicode_util_compat": {:hex, :unicode_util_compat, "0.7.0", "bc84380c9ab48177092f43ac89e4dfa2c6d62b40b8bd132b1059ecc7232f9a78", [:rebar3], [], "hexpm", "25eee6d67df61960cf6a794239566599b09e17e668d3700247bc498638152521"}, 34 | "vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"}, 35 | } 36 | -------------------------------------------------------------------------------- /notebooks/compiled_benchmarks.livemd: 
-------------------------------------------------------------------------------- 1 | # Compiled Decision Trees Benchmark 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:scidata, "~> 0.1"}, 6 | {:exgboost, "~> 0.4"}, 7 | {:mockingjay, github: "acalejos/mockingjay"}, 8 | {:nx, "~> 0.5", override: true}, 9 | {:exla, "~> 0.5"}, 10 | {:scholar, "~> 0.2"}, 11 | {:benchee, "~> 1.0"} 12 | ]) 13 | ``` 14 | 15 | ## Setup Dataset 16 | 17 | ```elixir 18 | {x, y} = Scidata.Iris.download() 19 | data = Enum.zip(x, y) |> Enum.shuffle() 20 | {train, test} = Enum.split(data, ceil(length(data) * 0.8)) 21 | {x_train, y_train} = Enum.unzip(train) 22 | {x_test, y_test} = Enum.unzip(test) 23 | 24 | x_train = Nx.tensor(x_train) 25 | y_train = Nx.tensor(y_train) 26 | 27 | x_test = Nx.tensor(x_test) 28 | y_test = Nx.tensor(y_test) 29 | ``` 30 | 31 | ## Gather Model / Prediction Functions 32 | 33 | `EXGBoost.compile/1` will convert your trained `Booster` model into a set of tensor operations which can then be run on any of the `Nx` backends. 
34 | 35 | ```elixir 36 | # Get Baseline Model (XGBoost C API) 37 | model = EXGBoost.train(x_train, y_train, num_class: 3, objective: :multi_softprob) 38 | # Get Compiled Models w/ Binary Backend 39 | Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) 40 | Nx.default_backend(Nx.BinaryBackend) 41 | gemm_predict = EXGBoost.compile(model, strategy: :gemm) 42 | gemm_jit_exla = EXLA.jit(gemm_predict) 43 | tree_trav_predict = EXGBoost.compile(model, strategy: :tree_traversal) 44 | tree_trav_jit_exla = EXLA.jit(tree_trav_predict) 45 | ptt_predict = EXGBoost.compile(model, strategy: :perfect_tree_traversal) 46 | ptt_jit_exla = EXLA.jit(ptt_predict) 47 | # Get Compiled Models w/ EXLA Backend 48 | Nx.Defn.default_options(compiler: EXLA) 49 | Nx.default_backend(EXLA.Backend) 50 | gemm_exla = EXGBoost.compile(model, strategy: :gemm) 51 | tree_trav_exla = EXGBoost.compile(model, strategy: :tree_traversal) 52 | ptt_exla = EXGBoost.compile(model, strategy: :perfect_tree_traversal) 53 | 54 | funcs = %{ 55 | "Base" => fn x -> EXGBoost.predict(model, x) end, 56 | "Compiled -- GEMM Strategy -- Binary Backend" => fn x -> gemm_predict.(x) end, 57 | "Compiled -- Tree Traversal Strategy -- Binary Backend" => fn x -> tree_trav_predict.(x) end, 58 | "Compiled -- Perfect Tree Traversal Strategy -- Binary Backend" => fn x -> ptt_predict.(x) end, 59 | "Compiled -- GEMM Strategy -- EXLA Backend" => fn x -> gemm_exla.(x) end, 60 | "Compiled -- Tree Traversal Strategy -- EXLA Backend" => fn x -> tree_trav_exla.(x) end, 61 | "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend" => fn x -> ptt_exla.(x) end, 62 | "Compiled -- GEMM Strategy -- EXLA Backend (JIT)" => fn x -> gemm_jit_exla.(x) end, 63 | "Compiled -- Tree Traversal Strategy -- EXLA Backend (JIT)" => fn x -> 64 | tree_trav_jit_exla.(x) 65 | end, 66 | "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend (JIT)" => fn x -> 67 | ptt_jit_exla.(x) 68 | end 69 | } 70 | ``` 71 | 72 | ## Run Time Benchmarks 73 | 74 | ```elixir 
75 | benches = Map.new(funcs, fn {k, v} -> {k, v.(x_train)} end) 76 | 77 | Benchee.run(benches, 78 | time: 10, 79 | memory_time: 2, 80 | warmup: 5 81 | ) 82 | ``` 83 | 84 | ## Compare Accuracies 85 | 86 | ```elixir 87 | Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) 88 | Nx.default_backend(Nx.BinaryBackend) 89 | 90 | accuracies = 91 | Enum.reduce(funcs, %{}, fn {name, pred_fn}, acc -> 92 | accuracy = 93 | pred_fn.(x_test) 94 | |> Nx.argmax(axis: -1) 95 | |> then(&Scholar.Metrics.Classification.accuracy(y_test, &1)) 96 | |> Nx.to_number() 97 | 98 | Map.put(acc, name, accuracy) 99 | end) 100 | ``` 101 | -------------------------------------------------------------------------------- /notebooks/iris_classification.livemd: -------------------------------------------------------------------------------- 1 | # Iris Classification with Gradient Boosting 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:exgboost, "~> 0.5"}, 6 | {:nx, "~> 0.5"}, 7 | {:scidata, "~> 0.1"}, 8 | {:scholar, "~> 0.1"} 9 | ]) 10 | ``` 11 | 12 | ## Data 13 | 14 | We'll be working with the Iris flower dataset. The Iris dataset consists of features corresponding to measurements of 3 different species of the Iris flower. Overall we have 150 examples, each with 4 featurse and a numeric label mapping to 1 of the 3 species. We can download this dataset using [Scidata](https://github.com/elixir-nx/scidata): 15 | 16 | ```elixir 17 | {x, y} = Scidata.Iris.download() 18 | :ok 19 | ``` 20 | 21 | Scidata doesn't provide train-test splits for Iris. Instead, we'll need to shuffle the original dataset and split manually. We'll save 20% of the dataset for testing: 22 | 23 | ```elixir 24 | data = Enum.zip(x, y) |> Enum.shuffle() 25 | {train, test} = Enum.split(data, ceil(length(data) * 0.8)) 26 | :ok 27 | ``` 28 | 29 | EXGBoost requires inputs to be [Nx](https://github.com/elixir-nx/nx) tensors. 
The conversion for this example is rather easy as we can just wrap both features and labels in a call to `Nx.tensor/1`: 30 | 31 | ```elixir 32 | {x_train, y_train} = Enum.unzip(train) 33 | {x_test, y_test} = Enum.unzip(test) 34 | 35 | x_train = Nx.tensor(x_train) 36 | y_train = Nx.tensor(y_train) 37 | 38 | x_test = Nx.tensor(x_test) 39 | y_test = Nx.tensor(y_test) 40 | 41 | x_train 42 | ``` 43 | 44 | ```elixir 45 | y_train 46 | ``` 47 | 48 | We now have both train and test sets consisting of features and labels. Time to train a booster! 49 | 50 | ## Training 51 | 52 | The simplest way to train a booster is using the top-level `EXGBoost.train/2` function. This function expects input features and labels, as well as some optional training configuration parameters. 53 | 54 | This example is a multi-class classification problem with 3 output classes. We need to configure EXGBoost to train this booster as a multi-class classifier by specifying a different training objective. We also need to specify the number of output classes: 55 | 56 | ```elixir 57 | booster = 58 | EXGBoost.train(x_train, y_train, 59 | num_class: 3, 60 | objective: :multi_softprob, 61 | num_boost_rounds: 10000, 62 | evals: [{x_train, y_train, "training"}] 63 | ) 64 | ``` 65 | 66 | And that's it! Now we can test our booster. 67 | 68 | ## Testing 69 | 70 | To get predictions from a trained booster, we can just call `EXGBoost.predict/2`. You'll notice for this problem that the booster outputs a tensor of shape `{30, 3}` where the 2nd dimension represents output probabilities for each class. We can obtain a discrete prediction for use in our accuracy measurement by computing the `argmax` along the last dimension: 71 | 72 | ```elixir 73 | preds = EXGBoost.predict(booster, x_test) |> Nx.argmax(axis: -1) 74 | Scholar.Metrics.Classification.accuracy(y_test, preds) 75 | ``` 76 | 77 | And that's it! We've successfully trained a booster on the Iris dataset with `EXGBoost`. 
78 | -------------------------------------------------------------------------------- /notebooks/plotting.livemd: -------------------------------------------------------------------------------- 1 | # Plotting in EXGBoost 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:exgboost, "~> 0.5"}, 6 | {:scidata, "~> 0.1"}, 7 | {:kino_vega_lite, "~> 0.1"} 8 | ]) 9 | 10 | # This assumes you launch this livebook from its location in the exgboost/notebooks folder 11 | ``` 12 | 13 | ## Introduction 14 | 15 | Much of the utility from decision trees comes from their intuitiveness and ability to inform decisions outside of the confines of a black-box model. A decision tree can be easily translated to a series of actions that can be taken on behalf of the stakeholder to achieve the desired outcome. This makes them especially useful in business decisions, where people might still want to have the final say but be as informed as possible. Additionally, tabular data is still quite popular in the business domain, which conforms to the required input for decision trees. 16 | 17 | Decision trees can be used for both regression and classification tasks, but classification tends to be what is most associated with decision trees. 18 | 19 | 20 | 21 | This notebook will go over some of the details of the `EXGBoost.Plotting` module, including using preconfigured styles, custom styling, as well as customizing the entire visualization. 22 | 23 | ## Plotting APIs 24 | 25 | There are 2 main APIs exposed to control plotting in `EXGBoost`: 26 | 27 | * Top-level API (`EXGBoost.plot_tree/2`) 28 | 29 | * Using predefined styles 30 | * Defining custom styles 31 | * Mix of the first 2 32 | 33 | * `EXGBoost.Plotting` module API 34 | 35 | * Use the Vega `data` spec defined in `EXGBoost.get_data_spec/2` 36 | * Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means 37 | 38 | We will walk through each of these in detail.
39 | 40 | Regardless of which API you choose to use, it is helpful to understand how the plotting module works (although the higher-level the API you choose to work with, the less important it becomes). 41 | 42 | ## Implementation Details 43 | 44 | The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** 45 | 46 | Vega is a plotting library built on top of the very powerful [D3](https://d3js.org/) JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a [reduced schema](https://vega.github.io/schema/vega-lite/v5.json) compared to the [full Vega spec](https://vega.github.io/schema/vega/v5.json). `EXGBoost.Plotting` leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API. 47 | 48 | For these reasons, unfortunately we could not just implement plotting for `EXGBoost` as a composable Vega-Lite pipeline. This makes working dynamically with the spec a bit more unwieldy, but much care was taken to still make the high-level plotting API extensible, and if needed you can go straight to defining your own JSON spec. 49 | 50 | ## Setup Data 51 | 52 | We will still be using the Iris dataset for this notebook, but if you want more details about the process of training and evaluating a model please check out the `Iris Classification with Gradient Boosting` notebook. 53 | 54 | So let's proceed by setting up the Iris dataset.
55 | 56 | ```elixir 57 | {x, y} = Scidata.Iris.download() 58 | data = Enum.zip(x, y) |> Enum.shuffle() 59 | {train, test} = Enum.split(data, ceil(length(data) * 0.8)) 60 | {x_train, y_train} = Enum.unzip(train) 61 | {x_test, y_test} = Enum.unzip(test) 62 | 63 | x_train = Nx.tensor(x_train) 64 | y_train = Nx.tensor(y_train) 65 | 66 | x_test = Nx.tensor(x_test) 67 | y_test = Nx.tensor(y_test) 68 | ``` 69 | 70 | ## Train Your Booster 71 | 72 | Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (*Note that we need to set `evals` to use early stopping*). 73 | 74 | You will notice that `EXGBoost` also provides an implementation for `Kino.Render` so that `EXGBoost.Booster`s are rendered as a plot by default. 75 | 76 | ```elixir 77 | booster = 78 | EXGBoost.train( 79 | x_train, 80 | y_train, 81 | num_class: 3, 82 | objective: :multi_softprob, 83 | num_boost_rounds: 10, 84 | evals: [{x_train, y_train, "training"}], 85 | verbose_eval: false, 86 | early_stopping_rounds: 1 87 | ) 88 | ``` 89 | 90 | You'll notice that the plot doesn't display any labels to the features in the splits, and instead only shows features labelled as "f2" etc. If you provide feature labels during training, your plot will show the splits using the feature labels. 91 | 92 | ```elixir 93 | booster = 94 | EXGBoost.train(x_train, y_train, 95 | num_class: 3, 96 | objective: :multi_softprob, 97 | num_boost_rounds: 10, 98 | evals: [{x_train, y_train, "training"}], 99 | verbose_eval: false, 100 | feature_name: ["sepal length", "sepal width", "petal length", "petal width"], 101 | early_stopping_rounds: 1 102 | ) 103 | ``` 104 | 105 | ## Top-Level API 106 | 107 | `EXGBoost.plot_tree/2` is the quickest way to customize the output of the plot. 108 | 109 | This API uses [Vega `Mark`s](https://vega.github.io/vega/docs/marks/) to describe the plot. 
Each of the following `Mark` options accepts any of the valid keys from their respective `Mark` type as described in the Vega documentation. 110 | 111 | **Please note that these are passed as a `Keyword`, and as such the keys must be atoms rather than strings as the Vega docs show. Valid options for this API are `snake_cased` atoms as opposed to the `camelCased` strings the Vega docs describe, so if you wish to pass `"fontSize"` as the Vega docs show, you would instead pass it as `font_size:` in this API.** 112 | 113 | The plot is composed of the following parts: 114 | 115 | * Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (such as `:style` and `:depth`). 116 | * `:leaves`: `Mark` specifying the leaf nodes of the tree 117 | * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) 118 | * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) 119 | * `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree 120 | * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) 121 | * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) 122 | * `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count 123 | * `:yes` 124 | * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) 125 | * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) 126 | * `:no` 127 | * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) 128 | * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) 129 | 130 | `EXGBoost.plot_tree/2` defaults to outputting a `VegaLite` struct. If you pass the `:path` option it will save to a file instead.
131 | 132 | If you want to add any marks to the underlying plot you will have to use the lower-level `EXGBoost.Plotting` API, as the top-level API is only capable of customizing these marks. 133 | 134 | 135 | 136 | ### Top-Level Keys 137 | 138 | 139 | 140 | `EXGBoost` supports changing the direction of the plots through the `:rankdir` option. Available directions are `[:tb, :bt, :lr, :rl]`, with top-to-bottom (`:tb`) being the default. 141 | 142 | ```elixir 143 | EXGBoost.plot_tree(booster, rankdir: :bt) 144 | ``` 145 | 146 | By default, plotting only shows one (the first) tree, but seeing as a `Booster` is really an ensemble of trees you can choose which tree to plot through the `:index` option, or set to `nil` to have a dropdown box to select the tree. 147 | 148 | ```elixir 149 | EXGBoost.plot_tree(booster, rankdir: :lr, index: nil) 150 | ``` 151 | 152 | You'll also notice that the plot is interactive, with support for scrolling, zooming, and collapsing sections of the tree. If you click on a split node you will toggle the visibility of its descendants, and the rest of the tree will fill the canvas. 153 | 154 | You can also use the `:depth` option to programmatically set the max depth to display in the tree: 155 | 156 | ```elixir 157 | EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3) 158 | ``` 159 | 160 | One way to affect the canvas size is by controlling the padding.
161 | 162 | You can add padding to all side by specifying an integer for the `:padding` option 163 | 164 | ```elixir 165 | EXGBoost.plot_tree(booster, rankdir: :rl, index: 4, depth: 3, padding: 50) 166 | ``` 167 | 168 | Or specify padding for each side: 169 | 170 | ```elixir 171 | EXGBoost.plot_tree(booster, 172 | rankdir: :lr, 173 | index: 4, 174 | depth: 3, 175 | padding: [top: 5, bottom: 25, left: 50, right: 10] 176 | ) 177 | ``` 178 | 179 | You can also specify the canvas size using the `:width` and `:height` options: 180 | 181 | ```elixir 182 | EXGBoost.plot_tree(booster, 183 | rankdir: :lr, 184 | index: 4, 185 | depth: 3, 186 | width: 500, 187 | height: 500 188 | ) 189 | ``` 190 | 191 | But do note that changing the padding of a canvas does change the size, even if you specify the size using `:height` and `:width` 192 | 193 | ```elixir 194 | EXGBoost.plot_tree(booster, 195 | rankdir: :lr, 196 | index: 4, 197 | depth: 3, 198 | width: 500, 199 | height: 500, 200 | padding: 10 201 | ) 202 | ``` 203 | 204 | You can change the dimensions of all nodes through the `:node_height` and `:node_width` options: 205 | 206 | ```elixir 207 | EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3, node_width: 60, node_height: 60) 208 | ``` 209 | 210 | Or change the space between nodes using the `:space_between` option. 
211 | 212 | **Note that the size of the accompanying nodes and/or text will change to accomodate the new `:space_between` option while trying to maintain the canvas size.** 213 | 214 | ```elixir 215 | EXGBoost.plot_tree( 216 | booster, 217 | rankdir: :lr, 218 | index: 4, 219 | depth: 3, 220 | space_between: [nodes: 200] 221 | ) 222 | ``` 223 | 224 | So if you want to add the space between while not changing the size of the nodes you might need to manually adjust the canvas size: 225 | 226 | ```elixir 227 | EXGBoost.plot_tree( 228 | booster, 229 | rankdir: :lr, 230 | index: 4, 231 | depth: 3, 232 | space_between: [nodes: 200], 233 | height: 800 234 | ) 235 | ``` 236 | 237 | ```elixir 238 | EXGBoost.plot_tree( 239 | booster, 240 | rankdir: :lr, 241 | index: 4, 242 | depth: 3, 243 | space_between: [levels: 200] 244 | ) 245 | ``` 246 | 247 | ### Mark Options 248 | 249 | The options controlling the appearance of individual marks all conform to a similar API. You can refer to the options and pre-defined defaults for a subset of the allowed options, but you can also pass other options so long as they are allowed by the Vega Mark spec (as defined [here](#cell-y5oxrrri4daa6xt5)) 250 | 251 | ```elixir 252 | EXGBoost.plot_tree( 253 | booster, 254 | rankdir: :bt, 255 | index: 4, 256 | depth: 3, 257 | space_between: [levels: 200], 258 | yes: [ 259 | text: [font_size: 18, fill: :teal] 260 | ], 261 | no: [ 262 | text: [font_size: 20] 263 | ], 264 | node_width: 100 265 | ) 266 | ``` 267 | 268 | Most marks accept an `:opacity` option that you can use to effectively hide the mark: 269 | 270 | ```elixir 271 | EXGBoost.plot_tree( 272 | booster, 273 | rankdir: :lr, 274 | index: 4, 275 | depth: 3, 276 | splits: [ 277 | text: [opacity: 0], 278 | rect: [opacity: 0], 279 | children: [opacity: 1] 280 | ] 281 | ) 282 | ``` 283 | 284 | And `text` marks accept normal text options such as `:fill`, `:font_size`, and `:font`: 285 | 286 | ```elixir 287 | EXGBoost.plot_tree( 288 | booster, 289 | 
node_width: 250, 290 | splits: [ 291 | text: [font: "Helvetica Neue", font_size: 20, fill: "orange"] 292 | ], 293 | space_between: [levels: 20] 294 | ) 295 | ``` 296 | 297 | ### Styles 298 | 299 | There are a set of provided pre-configured settings for the top-level API that you may optionally use. You can refer to the `EXGBoost.Plotting.Styles` docs to see a gallery of each style in action. You can specify a style with the `:style` option in `EXGBoost.plot_tree/2`. 300 | 301 | You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, but you are free to specify any other allowed keys and they will be merged with the style. Any options passed explicitly to the function **do** take precedence over the style options. 302 | 303 | For example, let's look at the `:solarized_dark` style: 304 | 305 | ```elixir 306 | EXGBoost.Plotting.solarized_dark() |> Keyword.take([:background, :height]) |> IO.inspect() 307 | EXGBoost.plot_tree(booster, style: :solarized_dark) 308 | ``` 309 | 310 | You can see that it defines a background color of `#002b36` but does not restrict what the height must be. 311 | 312 | ```elixir 313 | EXGBoost.plot_tree(booster, style: :solarized_dark, background: "white", height: 200) 314 | ``` 315 | 316 | We specified both `:background` and `:height` here, and the background specified in the option supersedes the one from the style. 
317 | 318 | You can also always get the style specification as a `Keyword` which can be passed to `EXGBoost.plot_tree/2` manually, making any needed changes yourself, like so: 319 | 320 | ```elixir 321 | custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") 322 | EXGBoost.plot_tree(booster, style: custom_style) 323 | ``` 324 | 325 | You can also programmatically check which styles are available: 326 | 327 | ```elixir 328 | EXGBoost.Plotting.get_styles() 329 | ``` 330 | 331 | ### Configuration 332 | 333 | You can also set defaults for the top-level API using an `Application` configuration for `EXGBoost` under the `:plotting` key. Since the defaults are collected from your configuration file at compile-time, anything you set during runtime, even if you set it to the Application environment, will not be registered as defaults. 334 | 335 | For example, if you just want to change the default pre-configured style you can do: 336 | 337 | 338 | 339 | ```elixir 340 | Mix.install([ 341 | {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, 342 | ], 343 | config: 344 | [ 345 | exgboost: [ 346 | plotting: [ 347 | style: :solarized_dark, 348 | ]] 349 | ], 350 | lockfile: :exgboost) 351 | ``` 352 | 353 | You can also make one-off changes to any of the settings with this method. In effect, this turns into a default custom style. 
**Just make sure to set `style: nil` to ensure that the `style` option doesn't supersede any of your settings.** Here's an example of that: 354 | 355 | 356 | 357 | ```elixir 358 | default_style = 359 | [ 360 | style: nil, 361 | background: "#3f3f3f", 362 | leaves: [ 363 | # Foreground 364 | text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "normal"], 365 | # Comment 366 | rect: [fill: "#7f9f7f", stroke: "#7f9f7f"] 367 | ], 368 | splits: [ 369 | # Foreground 370 | text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "bold"], 371 | # Comment 372 | rect: [fill: "#7f9f7f", stroke: "#7f9f7f"], 373 | # Selection 374 | children: [fill: "#2b2b2b", stroke: "#2b2b2b"] 375 | ], 376 | yes: [ 377 | # Green 378 | text: [fill: "#7f9f7f"], 379 | # Selection 380 | path: [stroke: "#2b2b2b"] 381 | ], 382 | no: [ 383 | # Red 384 | text: [fill: "#cc9393"], 385 | # Selection 386 | path: [stroke: "#2b2b2b"] 387 | ] 388 | ] 389 | 390 | Mix.install([ 391 | {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, 392 | ], 393 | config: 394 | [ 395 | exgboost: [ 396 | plotting: default_style, 397 | ] 398 | ] 399 | ) 400 | ``` 401 | 402 | **NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** 403 | 404 | At any point, you can check what your default settings are by using `EXGBoost.Plotting.get_defaults/0` 405 | 406 | ```elixir 407 | EXGBoost.Plotting.get_defaults() 408 | ``` 409 | 410 | ## Low-Level API 411 | 412 | If you find yourself needing more granular control over your plots, you can reach towards the `EXGBoost.Plotting` module. This module houses the `EXGBoost.Plotting.plot/2` function, which is what is used under the hood from the `EXGBoost.plot_tree/2` top-level API. This module also has the `get_data_spec/2` function, as well as the `to_tabular/1` function, both of which can be used to specify your own Vega specification. 
Lastly, the module also houses all of the pre-configured styles, which are 0-arity functions which output the `Keyword`s containing their respective style's options that can be passed to the plotting APIs. 413 | 414 | Let's briefly go over the `to_tabular/1` and `get_data_spec/2` functions: 415 | 416 | 417 | 418 | The `to_tabular/1` function is used to convert a `Booster`, which is formatted as a tree structure, to a tabular format which can be ingested specifically by the [Vega Stratify transform](https://vega.github.io/vega/docs/transforms/stratify/). It returns a list of "nodes", which are just `Map`s with info about each node in the tree. 419 | 420 | ```elixir 421 | EXGBoost.Plotting.to_tabular(booster) |> hd 422 | ``` 423 | 424 | You can use this function if you want to have complete control over the visualization, and just want a bit of a head start with respect to data transformation for converting the `Booster` into a more digestible format. 425 | 426 | 427 | 428 | The `get_data_spec/2` function is used if you want to use the provided [Vega data specification](https://vega.github.io/vega/docs/data/). This is for those who want to only focus on implementing their own [Vega Marks](https://vega.github.io/vega/docs/marks/), and want to leverage the data transformation pipeline that powers the top-level API. 
429 | 430 | The data transformation used is the following pipeline: 431 | 432 | `to_tabular/1` -> [Filter](https://vega.github.io/vega/docs/transforms/filter/) (by tree index) -> [Stratify](https://vega.github.io/vega/docs/transforms/stratify/) -> [Tree](https://vega.github.io/vega/docs/transforms/tree/) 433 | 434 | ```elixir 435 | EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt) 436 | ``` 437 | 438 | The Vega fields which are not included with `get_data_spec/2` and are included in `plot/2` are: 439 | 440 | * [Marks](https://vega.github.io/vega/docs/marks/) 441 | * [Scales](https://vega.github.io/vega/docs/scales/) 442 | * [Signals](https://vega.github.io/vega/docs/signals/) 443 | 444 | You can make a completely valid plot using only the Data from `get_data_spec/2` and adding the marks you need. 445 | -------------------------------------------------------------------------------- /test/data/another.txt: -------------------------------------------------------------------------------- 1 | another test file for the thing -------------------------------------------------------------------------------- /test/data/model.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acalejos/exgboost/782d00f9293def811dd4f79af5c215a03b484f46/test/data/model.json -------------------------------------------------------------------------------- /test/data/test.conf: -------------------------------------------------------------------------------- 1 | # General Parameters, see comment for each definition 2 | # can be gbtree or gblinear 3 | booster = gbtree 4 | # choose logistic regression loss function for binary classification 5 | objective = binary:logistic 6 | 7 | # Tree Booster Parameters 8 | # step size shrinkage 9 | eta = 1.0 10 | # minimum loss reduction required to make a further partition 11 | gamma = 1.0 12 | # minimum sum of instance weight(hessian) needed in a child 13 | min_child_weight = 1 14 | # 
maximum depth of a tree 15 | max_depth = 3 16 | 17 | # Task Parameters 18 | # the number of round to do boosting 19 | num_round = 2 20 | # 0 means do not save any model except the final round model 21 | save_period = 0 22 | # The path of training data 23 | data = "/home/semafore/exgboost/test/data/testfile.txt" 24 | # The path of validation data, used to monitor training process, here [test] sets name of the validation set 25 | #eval[test] = "agaricus.txt.test" 26 | # The path of test data 27 | #test:data = "agaricus.txt.test" -------------------------------------------------------------------------------- /test/data/testfile.txt: -------------------------------------------------------------------------------- 1 | 0.5,0.2,0.1,0.2,0 2 | 0.1,0.3,0.5,0.1,1 3 | 0.3,0.5,0.1,0.1,0 4 | 0.2,0.2,0.2,0.4,1 -------------------------------------------------------------------------------- /test/data/train.txt: -------------------------------------------------------------------------------- 1 | 1 101:1.2 102:0.03 2 | 0 1:2.1 10001:300 10002:400 3 | 0 0:1.3 1:0.3 4 | 1 0:0.01 1:0.3 5 | 0 0:0.2 1:0.3 -------------------------------------------------------------------------------- /test/exgboost_test.exs: -------------------------------------------------------------------------------- 1 | defmodule EXGBoostTest do 2 | alias EXGBoost.DMatrix 3 | alias EXGBoost.Booster 4 | use ExUnit.Case, async: true 5 | doctest EXGBoost 6 | 7 | setup do 8 | %{key: Nx.Random.key(42)} 9 | end 10 | 11 | test "dmatrix_from_tensor", context do 12 | nrows = :rand.uniform(10) 13 | ncols = :rand.uniform(10) 14 | {tensor, _new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 15 | dmatrix = EXGBoost.DMatrix.from_tensor(tensor, format: :dense) 16 | assert DMatrix.get_num_rows(dmatrix) == nrows 17 | assert DMatrix.get_num_cols(dmatrix) == ncols 18 | assert DMatrix.get_num_non_missing(dmatrix) == nrows * ncols 19 | assert DMatrix.get_feature_names(dmatrix) == [] 20 | assert 
DMatrix.get_feature_types(dmatrix) == [] 21 | assert DMatrix.get_group(dmatrix) == [] 22 | 23 | {_indptr, _indices, data} = DMatrix.get_data(dmatrix) 24 | assert length(data) == nrows * ncols 25 | end 26 | 27 | test "train_booster", context do 28 | nrows = :rand.uniform(10) 29 | ncols = :rand.uniform(10) 30 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 31 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 32 | num_boost_round = 10 33 | booster = EXGBoost.train(x, y, num_boost_rounds: num_boost_round, tree_method: :hist) 34 | assert Booster.get_boosted_rounds(booster) == num_boost_round 35 | end 36 | 37 | test "quantile cut", context do 38 | nrows = :rand.uniform(10) 39 | ncols = :rand.uniform(10) 40 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 41 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 42 | num_boost_round = 10 43 | dmat = DMatrix.from_tensor(x, y, format: :dense) 44 | 45 | _booster = 46 | EXGBoost.Training.train(dmat, num_boost_rounds: num_boost_round, tree_method: :hist) 47 | 48 | {indptr, data} = DMatrix.get_quantile_cut(dmat) 49 | end 50 | 51 | test "booster params" do 52 | x = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 53 | y = Nx.tensor([0, 1, 2]) 54 | num_boost_round = 10 55 | 56 | booster = 57 | EXGBoost.train(x, y, 58 | num_boost_rounds: num_boost_round, 59 | tree_method: :hist, 60 | obj: :multi_softprob, 61 | num_class: 3 62 | ) 63 | 64 | assert Booster.get_boosted_rounds(booster) == num_boost_round 65 | end 66 | 67 | test "train with container" do 68 | x = {Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])} 69 | y = {Nx.tensor([0, 1, 2])} 70 | num_boost_round = 10 71 | 72 | booster = 73 | EXGBoost.train(x, y, 74 | num_boost_rounds: num_boost_round, 75 | tree_method: :hist, 76 | objective: :multi_softprob, 77 | num_class: 3 78 | ) 79 | 80 | assert Booster.get_boosted_rounds(booster) == num_boost_round 81 | end 82 | 83 | test "predict", context do 84 | nrows 
= :rand.uniform(10) 85 | ncols = :rand.uniform(10) 86 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 87 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 88 | num_boost_round = 10 89 | booster = EXGBoost.train(x, y, num_boost_rounds: num_boost_round, tree_method: :hist) 90 | dmat_preds = EXGBoost.predict(booster, x) 91 | inplace_preds_no_proxy = EXGBoost.inplace_predict(booster, x) 92 | # TODO: Test inplace_predict with proxy 93 | # inplace_preds_with_proxy = EXGBoost.inplace_predict(booster, x, base_margin: true) 94 | assert dmat_preds.shape == y.shape 95 | assert inplace_preds_no_proxy.shape == y.shape 96 | end 97 | 98 | test "predict with container", context do 99 | nrows = :rand.uniform(10) 100 | ncols = :rand.uniform(10) 101 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 102 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 103 | num_boost_round = 10 104 | booster = EXGBoost.train({x}, {y}, num_boost_rounds: num_boost_round, tree_method: :hist) 105 | dmat_preds = EXGBoost.predict(booster, {x}) 106 | inplace_preds_no_proxy = EXGBoost.inplace_predict(booster, {x}) 107 | # TODO: Test inplace_predict with proxy 108 | # inplace_preds_with_proxy = EXGBoost.inplace_predict(booster, x, base_margin: true) 109 | assert dmat_preds.shape == y.shape 110 | assert inplace_preds_no_proxy.shape == y.shape 111 | end 112 | 113 | test "train with learning rates", context do 114 | nrows = :rand.uniform(10) 115 | ncols = :rand.uniform(10) 116 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 117 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 118 | num_boost_round = 10 119 | lrs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1] 120 | lrs_fun = fn i -> i / 10 end 121 | 122 | EXGBoost.train(x, y, 123 | num_boost_rounds: num_boost_round, 124 | tree_method: :hist, 125 | learning_rates: lrs 126 | ) 127 | 128 | EXGBoost.train(x, y, 129 | 
num_boost_rounds: num_boost_round, 130 | tree_method: :hist, 131 | learning_rates: lrs_fun 132 | ) 133 | end 134 | 135 | test "train with early stopping", context do 136 | nrows = :rand.uniform(10) 137 | ncols = :rand.uniform(10) 138 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 139 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 140 | 141 | {booster, _} = 142 | ExUnit.CaptureIO.with_io(fn -> 143 | EXGBoost.train(x, y, 144 | num_boost_rounds: 10, 145 | early_stopping_rounds: 1, 146 | evals: [{x, y, "validation"}], 147 | tree_method: :hist, 148 | eval_metric: [:rmse, :logloss] 149 | ) 150 | end) 151 | 152 | refute is_nil(booster.best_iteration) 153 | refute is_nil(booster.best_score) 154 | 155 | # If no eval metric is provided, the default metric is used. If the default 156 | # metric is disabled, an error is raised. 157 | assert_raise ArgumentError, 158 | fn -> 159 | ExUnit.CaptureIO.with_io(fn -> 160 | EXGBoost.train(x, y, 161 | disable_default_eval_metric: true, 162 | num_boost_rounds: 10, 163 | early_stopping_rounds: 1, 164 | evals: [{x, y, "validation"}], 165 | tree_method: :hist 166 | ) 167 | end) 168 | end 169 | 170 | refute is_nil(booster.best_iteration) 171 | refute is_nil(booster.best_score) 172 | end 173 | 174 | test "eval with multiple metrics", context do 175 | nrows = :rand.uniform(10) 176 | ncols = :rand.uniform(10) 177 | {x, new_key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols}) 178 | {y, _new_key} = Nx.Random.normal(new_key, 0, 1, shape: {nrows}) 179 | num_boost_round = 10 180 | 181 | booster = 182 | EXGBoost.train(x, y, 183 | num_boost_rounds: num_boost_round, 184 | tree_method: :hist, 185 | eval_metric: :rmse 186 | ) 187 | 188 | dmat = DMatrix.from_tensor(x, y, format: :dense) 189 | [{_ev_name, metric_name, _metric_value}] = Booster.eval(booster, dmat) 190 | 191 | assert metric_name == "rmse" 192 | 193 | Booster.set_params(booster, eval_metric: :logloss) 194 | 195 | metric_results = 
Booster.eval(booster, dmat) 196 | 197 | assert length(metric_results) == 2 198 | end 199 | 200 | test "save and load model to and from file", context do 201 | nrows = :rand.uniform(10) 202 | ncols = :rand.uniform(10) 203 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 204 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 205 | num_boost_round = 10 206 | 207 | booster = 208 | EXGBoost.train(x, y, 209 | num_boost_rounds: num_boost_round, 210 | tree_method: :hist, 211 | eval_metric: :rmse 212 | ) 213 | 214 | EXGBoost.write_model(booster, "test") 215 | assert File.exists?("test.json") 216 | bst = EXGBoost.read_model("test.json") 217 | assert is_struct(bst, EXGBoost.Booster) 218 | File.rm!("test.json") 219 | end 220 | 221 | test "save and load weights to and from file", context do 222 | nrows = :rand.uniform(10) 223 | ncols = :rand.uniform(10) 224 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 225 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 226 | num_boost_round = 10 227 | 228 | booster = 229 | EXGBoost.train(x, y, 230 | num_boost_rounds: num_boost_round, 231 | tree_method: :hist, 232 | eval_metric: :rmse 233 | ) 234 | 235 | EXGBoost.write_weights(booster, "test") 236 | assert File.exists?("test.json") 237 | bst = EXGBoost.read_weights("test.json") 238 | assert is_struct(bst, EXGBoost.Booster) 239 | File.rm!("test.json") 240 | end 241 | 242 | test "save and load config to and from file", context do 243 | nrows = :rand.uniform(10) 244 | ncols = :rand.uniform(10) 245 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 246 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 247 | num_boost_round = 10 248 | 249 | booster = 250 | EXGBoost.train(x, y, 251 | num_boost_rounds: num_boost_round, 252 | tree_method: :hist, 253 | eval_metric: :rmse 254 | ) 255 | 256 | EXGBoost.write_config(booster, "test") 257 | assert 
File.exists?("test.json") 258 | bst = EXGBoost.read_config("test.json") 259 | assert is_struct(bst, EXGBoost.Booster) 260 | File.rm!("test.json") 261 | end 262 | 263 | test "serialize and deserialize model to and from buffer", context do 264 | nrows = :rand.uniform(10) 265 | ncols = :rand.uniform(10) 266 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 267 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 268 | num_boost_round = 10 269 | 270 | booster = 271 | EXGBoost.train(x, y, 272 | num_boost_rounds: num_boost_round, 273 | tree_method: :hist, 274 | eval_metric: :rmse 275 | ) 276 | 277 | buffer = EXGBoost.dump_model(booster) 278 | assert is_binary(buffer) 279 | bst = EXGBoost.load_model(buffer) 280 | assert is_struct(bst, EXGBoost.Booster) 281 | end 282 | 283 | test "serialize and deserialize weights to and from buffer", context do 284 | nrows = :rand.uniform(10) 285 | ncols = :rand.uniform(10) 286 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 287 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 288 | num_boost_round = 10 289 | 290 | booster = 291 | EXGBoost.train(x, y, 292 | num_boost_rounds: num_boost_round, 293 | tree_method: :hist, 294 | eval_metric: :rmse 295 | ) 296 | 297 | buffer = EXGBoost.dump_weights(booster) 298 | assert is_binary(buffer) 299 | bst = EXGBoost.load_weights(buffer) 300 | assert is_struct(bst, EXGBoost.Booster) 301 | end 302 | 303 | test "serialize and deserialize config to and from buffer", context do 304 | nrows = :rand.uniform(10) 305 | ncols = :rand.uniform(10) 306 | {x, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows, ncols}) 307 | {y, _new_key} = Nx.Random.normal(context[:key], 0, 1, shape: {nrows}) 308 | num_boost_round = 10 309 | 310 | booster = 311 | EXGBoost.train(x, y, 312 | num_boost_rounds: num_boost_round, 313 | tree_method: :hist, 314 | eval_metric: :rmse 315 | ) 316 | 317 | buffer = 
EXGBoost.dump_config(booster) 318 | assert is_binary(buffer) 319 | config = EXGBoost.load_config(buffer) 320 | assert is_map(config) 321 | end 322 | 323 | test "array interface get tensor" do 324 | tensor = Nx.tensor([[1, 2, 3], [4, 5, 6]]) 325 | array_interface = EXGBoost.ArrayInterface.from_tensor(tensor) 326 | # Set this to nil so we can test the get_tensor reconstruction 327 | array_interface = struct(array_interface, tensor: nil) 328 | 329 | assert EXGBoost.ArrayInterface.get_tensor(array_interface) == tensor 330 | end 331 | 332 | describe "errors" do 333 | setup %{key: key0} do 334 | {nrows, ncols} = {10, 10} 335 | {x, key1} = Nx.Random.normal(key0, 0, 1, shape: {nrows, ncols}) 336 | {y, _key2} = Nx.Random.normal(key1, 0, 1, shape: {nrows}) 337 | %{x: x, y: y} 338 | end 339 | 340 | test "duplicate callback names result in an error", %{x: x, y: y} do 341 | # This callback's name is the same as one of the default callbacks. 342 | custom_callback = EXGBoost.Training.Callback.new(:before_training, & &1, :monitor_metrics) 343 | 344 | assert_raise ArgumentError, 345 | """ 346 | Found duplicate callback names. 
347 | 348 | Name counts: 349 | 350 | * {:eval_metrics, 1} 351 | 352 | * {:monitor_metrics, 2} 353 | """, 354 | fn -> 355 | EXGBoost.train(x, y, 356 | callbacks: [custom_callback], 357 | eval_metric: [:rmse, :logloss], 358 | evals: [{x, y, "validation"}] 359 | ) 360 | end 361 | end 362 | 363 | test "callback with bad function results in helpful error", %{x: x, y: y} do 364 | bad_fun = fn state -> %{state | status: :bad_status} end 365 | bad_callback = EXGBoost.Training.Callback.new(:before_training, bad_fun, :bad_callback) 366 | 367 | assert_raise ArgumentError, 368 | "`status` must be `:cont` or `:halt`, found: `:bad_status`.", 369 | fn -> EXGBoost.train(x, y, callbacks: [bad_callback]) end 370 | end 371 | end 372 | end 373 | -------------------------------------------------------------------------------- /test/model.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acalejos/exgboost/782d00f9293def811dd4f79af5c215a03b484f46/test/model.json -------------------------------------------------------------------------------- /test/nif_test.exs: -------------------------------------------------------------------------------- 1 | defmodule NifTest do 2 | use ExUnit.Case, async: true 3 | import EXGBoost.Internal 4 | import EXGBoost.ArrayInterface, only: [from_tensor: 1] 5 | 6 | test "exgboost_version" do 7 | assert EXGBoost.NIF.xgboost_version() |> unwrap!() != :error 8 | end 9 | 10 | test "build_info" do 11 | assert EXGBoost.NIF.xgboost_build_info() |> unwrap!() != :error 12 | end 13 | 14 | test "set_global_config" do 15 | assert EXGBoost.NIF.set_global_config('{"use_rmm":false,"verbosity":1}') == :ok 16 | 17 | assert EXGBoost.NIF.set_global_config('{"use_rmm":false,"verbosity": true}') == 18 | {:error, 'Invalid Parameter format for verbosity expect int but value=\'true\''} 19 | end 20 | 21 | test "get_global_config" do 22 | assert EXGBoost.NIF.get_global_config() |> unwrap!() != :error 23 | end 24 | 25 
| test "dmatrix_create_from_uri" do 26 | config = Jason.encode!(%{uri: "test/data/train.txt?format=libsvm"}) 27 | assert EXGBoost.NIF.dmatrix_create_from_uri(config) |> unwrap!() != :error 28 | end 29 | 30 | test "dmatrix_create_from_sparse" do 31 | config = Jason.encode!(%{"missing" => 0.0}) 32 | indptr = Nx.tensor([0, 22]) 33 | ncols = 127 34 | 35 | indices = 36 | Nx.tensor([ 37 | 1, 38 | 9, 39 | 19, 40 | 21, 41 | 24, 42 | 34, 43 | 36, 44 | 39, 45 | 42, 46 | 53, 47 | 56, 48 | 65, 49 | 69, 50 | 77, 51 | 86, 52 | 88, 53 | 92, 54 | 95, 55 | 102, 56 | 106, 57 | 117, 58 | 122 59 | ]) 60 | 61 | data = 62 | Nx.tensor([ 63 | 1.0, 64 | 1.0, 65 | 1.0, 66 | 1.0, 67 | 1.0, 68 | 1.0, 69 | 1.0, 70 | 1.0, 71 | 1.0, 72 | 1.0, 73 | 1.0, 74 | 1.0, 75 | 1.0, 76 | 1.0, 77 | 1.0, 78 | 1.0, 79 | 1.0, 80 | 1.0, 81 | 1.0, 82 | 1.0, 83 | 1.0, 84 | 1.0 85 | ]) 86 | 87 | assert EXGBoost.NIF.dmatrix_create_from_sparse( 88 | from_tensor(indptr) |> Jason.encode!(), 89 | from_tensor(indices) |> Jason.encode!(), 90 | from_tensor(data) |> Jason.encode!(), 91 | ncols, 92 | config, 93 | "csr" 94 | ) 95 | |> unwrap!() != 96 | :error 97 | 98 | assert EXGBoost.NIF.dmatrix_create_from_sparse( 99 | from_tensor(indptr) |> Jason.encode!(), 100 | from_tensor(indices) |> Jason.encode!(), 101 | from_tensor(data) |> Jason.encode!(), 102 | ncols, 103 | config, 104 | "csc" 105 | ) 106 | |> unwrap!() != 107 | :error 108 | 109 | {status, _} = 110 | EXGBoost.NIF.dmatrix_create_from_sparse( 111 | from_tensor(indptr) |> Jason.encode!(), 112 | from_tensor(indices) |> Jason.encode!(), 113 | from_tensor(data) |> Jason.encode!(), 114 | ncols, 115 | config, 116 | "csa" 117 | ) 118 | 119 | assert status == :error 120 | end 121 | 122 | test "test_dmatrix_create_from_dense" do 123 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 124 | array_interface = from_tensor(mat) |> Jason.encode!() 125 | 126 | config = Jason.encode!(%{"missing" => -1.0}) 127 | 128 | assert EXGBoost.NIF.dmatrix_create_from_dense(array_interface, 
config) 129 | |> unwrap!() != 130 | :error 131 | end 132 | 133 | test "test_dmatrix_set_str_feature_info" do 134 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 135 | array_interface = from_tensor(mat) |> Jason.encode!() 136 | 137 | config = Jason.encode!(%{"missing" => -1.0}) 138 | 139 | dmat = 140 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 141 | |> unwrap!() 142 | 143 | assert EXGBoost.NIF.dmatrix_set_str_feature_info(dmat, 'feature_name', [ 144 | 'name', 145 | 'color', 146 | 'length' 147 | ]) == :ok 148 | end 149 | 150 | test "test_dmatrix_get_str_feature_info" do 151 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 152 | array_interface = from_tensor(mat) |> Jason.encode!() 153 | 154 | config = Jason.encode!(%{"missing" => -1.0}) 155 | 156 | dmat = 157 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 158 | |> unwrap!() 159 | 160 | EXGBoost.NIF.dmatrix_set_str_feature_info(dmat, 'feature_name', ['name', 'color', 'length']) 161 | 162 | assert EXGBoost.NIF.dmatrix_get_str_feature_info(dmat, 'feature_name') |> unwrap!() 163 | end 164 | 165 | test "dmatrix_num_row" do 166 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 167 | array_interface = from_tensor(mat) |> Jason.encode!() 168 | 169 | config = Jason.encode!(%{"missing" => -1.0}) 170 | 171 | dmat = 172 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 173 | |> unwrap!() 174 | 175 | assert EXGBoost.NIF.dmatrix_num_row(dmat) |> unwrap! == 2 176 | end 177 | 178 | test "dmatrix_num_col" do 179 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 180 | array_interface = from_tensor(mat) |> Jason.encode!() 181 | 182 | config = Jason.encode!(%{"missing" => -1.0}) 183 | 184 | dmat = 185 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 186 | |> unwrap!() 187 | 188 | assert EXGBoost.NIF.dmatrix_num_col(dmat) |> unwrap! 
== 3 189 | end 190 | 191 | test "dmatrix_num_non_missing" do 192 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 193 | array_interface = from_tensor(mat) |> Jason.encode!() 194 | 195 | config = Jason.encode!(%{"missing" => -1.0}) 196 | 197 | dmat = 198 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 199 | |> unwrap!() 200 | 201 | assert EXGBoost.NIF.dmatrix_num_non_missing(dmat) |> unwrap! == 6 202 | end 203 | 204 | test "dmatrix_set_info_from_interface" do 205 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 206 | array_interface = from_tensor(mat) |> Jason.encode!() 207 | labels = Nx.tensor([1.0, 0.0]) 208 | 209 | config = Jason.encode!(%{"missing" => -1.0}) 210 | 211 | dmat = 212 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 213 | |> unwrap!() 214 | 215 | label_interface = from_tensor(labels) |> Jason.encode!() 216 | 217 | assert EXGBoost.NIF.dmatrix_set_info_from_interface( 218 | dmat, 219 | 'label', 220 | label_interface 221 | ) == 222 | :ok 223 | 224 | assert EXGBoost.NIF.dmatrix_set_info_from_interface( 225 | dmat, 226 | 'unsupported', 227 | label_interface 228 | ) != :ok 229 | end 230 | 231 | test "dmatrix_save_binary" do 232 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 233 | array_interface = from_tensor(mat) |> Jason.encode!() 234 | labels = Nx.tensor([1.0, 0.0]) 235 | 236 | config = Jason.encode!(%{"missing" => -1.0}) 237 | 238 | dmat = 239 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 240 | |> unwrap!() 241 | 242 | interface = from_tensor(labels) |> Jason.encode!() 243 | 244 | EXGBoost.NIF.dmatrix_set_info_from_interface(dmat, 'label', interface) 245 | 246 | path = Path.join(System.tmp_dir!(), "test.buffer") |> String.to_charlist() 247 | assert EXGBoost.NIF.dmatrix_save_binary(dmat, path, 1) == :ok 248 | end 249 | 250 | test "dmatrix_get_float_info" do 251 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 252 | array_interface = from_tensor(mat) |> Jason.encode!() 253 | 
weights = Nx.tensor([1.0, 0.0]) 254 | 255 | config = Jason.encode!(%{"missing" => -1.0}) 256 | 257 | dmat = 258 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 259 | |> unwrap!() 260 | 261 | interface = from_tensor(weights) |> Jason.encode!() 262 | EXGBoost.NIF.dmatrix_set_info_from_interface(dmat, 'feature_weights', interface) 263 | 264 | assert EXGBoost.NIF.dmatrix_get_float_info(dmat, 'feature_weights') |> unwrap!() == 265 | Nx.to_list(weights) 266 | end 267 | 268 | test "dmatrix_get_data_as_csr" do 269 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 270 | array_interface = from_tensor(mat) |> Jason.encode!() 271 | 272 | config = Jason.encode!(%{"missing" => -1.0}) 273 | 274 | dmat = 275 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 276 | |> unwrap!() 277 | 278 | assert EXGBoost.NIF.dmatrix_get_data_as_csr(dmat, Jason.encode!(%{})) |> unwrap!() != :error 279 | end 280 | 281 | test "dmatrix_slice" do 282 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 283 | array_interface = from_tensor(mat) |> Jason.encode!() 284 | 285 | config = Jason.encode!(%{"missing" => -1.0}) 286 | 287 | dmat = 288 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 289 | |> unwrap!() 290 | 291 | # We do this because the C API uses non fixed-width types so we need to know the size they're expecting from int 292 | c_int_size = EXGBoost.NIF.get_int_size() |> unwrap!() 293 | tensor_size = c_int_size * 8 294 | 295 | dmatrix = 296 | EXGBoost.NIF.dmatrix_slice( 297 | dmat, 298 | Nx.to_binary(Nx.tensor([0, 1], type: {:s, tensor_size})), 299 | 1 300 | ) 301 | |> unwrap!() 302 | 303 | assert EXGBoost.NIF.dmatrix_num_row(dmatrix) |> unwrap!() == 2 304 | 305 | {status, _e} = 306 | EXGBoost.NIF.dmatrix_slice( 307 | dmat, 308 | Nx.to_binary(Nx.tensor([0, 1], type: {:s, tensor_size})), 309 | 2 310 | ) 311 | 312 | assert status == :error 313 | 314 | {status, _e} = EXGBoost.NIF.dmatrix_slice(dmat, Nx.to_binary(Nx.tensor([1.5, 
1.6])), 2) 315 | 316 | assert status == :error 317 | end 318 | 319 | test "booster_create" do 320 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 321 | mat2 = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 322 | array_interface = from_tensor(mat) |> Jason.encode!() 323 | array_interface2 = from_tensor(mat2) |> Jason.encode!() 324 | 325 | config = Jason.encode!(%{"missing" => -1.0}) 326 | 327 | dmat = 328 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 329 | |> unwrap!() 330 | 331 | dmat2 = 332 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface2, config) 333 | |> unwrap!() 334 | 335 | assert EXGBoost.NIF.booster_create([dmat]) |> unwrap!() != :error 336 | assert EXGBoost.NIF.booster_create([]) |> unwrap!() != :error 337 | assert EXGBoost.NIF.booster_create([dmat, dmat2]) |> unwrap!() != :error 338 | end 339 | 340 | test "booster_get_num_feature" do 341 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 342 | array_interface = from_tensor(mat) |> Jason.encode!() 343 | 344 | config = Jason.encode!(%{"missing" => -1.0}) 345 | 346 | dmat = 347 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 348 | |> unwrap!() 349 | 350 | booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!() 351 | assert EXGBoost.NIF.booster_get_num_feature(booster) |> unwrap!() == 3 352 | end 353 | 354 | test "test_booster_set_str_feature_info" do 355 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 356 | array_interface = from_tensor(mat) |> Jason.encode!() 357 | 358 | config = Jason.encode!(%{"missing" => -1.0}) 359 | 360 | dmat = 361 | EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config) 362 | |> unwrap!() 363 | 364 | booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!() 365 | 366 | assert EXGBoost.NIF.booster_set_str_feature_info(booster, 'feature_name', [ 367 | 'name', 368 | 'color', 369 | 'length' 370 | ]) == :ok 371 | end 372 | 373 | test "test_booster_get_str_feature_info" do 374 | mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 
5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()

    EXGBoost.NIF.booster_set_str_feature_info(booster, 'feature_name', ['name', 'color', 'length'])

    # The names set above should be retrievable; unwrap! raises on an error tuple.
    assert EXGBoost.NIF.booster_get_str_feature_info(booster, 'feature_name') |> unwrap!()
  end

  # Fixed typo in the test name: "boster" -> "booster".
  test "test_booster_feature_score" do
    # TODO: Make more robust test. This will just return an empty list
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    # `config` is deliberately rebound: the first JSON configures the DMatrix,
    # this one configures the feature-score query.
    config = Jason.encode!(%{"importance_type" => "weight"})
    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()

    assert EXGBoost.NIF.booster_feature_score(booster, config) |> unwrap!() != :error
  end

  test "save model" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    # Save in both on-disk formats (JSON and UBJSON) under the system temp dir;
    # the NIF takes paths as charlists.
    json_file = Path.join(System.tmp_dir!(), "model.json") |> String.to_charlist()
    ubj_file = Path.join(System.tmp_dir!(), "model.ubj") |> String.to_charlist()
    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    assert EXGBoost.NIF.booster_save_model(booster, json_file) |> unwrap!() == :ok
    assert EXGBoost.NIF.booster_save_model(booster, ubj_file) |> unwrap!() == :ok
    assert File.exists?(json_file) and File.regular?(json_file)
    assert File.exists?(ubj_file) and
File.regular?(ubj_file)
    assert File.rm(json_file) == :ok
    assert File.rm(ubj_file) == :ok
  end

  test "load model" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    # Round-trip: save a model in both formats, then load each back.
    json_file = Path.join(System.tmp_dir!(), "model.json") |> String.to_charlist()
    ubj_file = Path.join(System.tmp_dir!(), "model.ubj") |> String.to_charlist()
    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    assert EXGBoost.NIF.booster_save_model(booster, json_file) |> unwrap!() == :ok
    assert EXGBoost.NIF.booster_save_model(booster, ubj_file) |> unwrap!() == :ok
    assert File.exists?(json_file) and File.regular?(json_file)
    assert File.exists?(ubj_file) and File.regular?(ubj_file)
    assert EXGBoost.NIF.booster_load_model(json_file) |> unwrap!() != :error
    assert EXGBoost.NIF.booster_load_model(ubj_file) |> unwrap!() != :error
    # Clean up the temp files, matching the "save model" test above;
    # previously this test leaked both files.
    assert File.rm(json_file) == :ok
    assert File.rm(ubj_file) == :ok
  end

  test "booster serialize" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    assert EXGBoost.NIF.booster_serialize_to_buffer(booster) |> unwrap!() != :error
  end

  test "booster deserialize" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    booster =
EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    buffer = EXGBoost.NIF.booster_serialize_to_buffer(booster) |> unwrap!()
    # Removed a duplicate, unasserted booster_deserialize_from_buffer/1 call
    # whose result was discarded right before this assertion.
    assert EXGBoost.NIF.booster_deserialize_from_buffer(buffer) |> unwrap!() != :error
  end

  test "save booster config" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    assert EXGBoost.NIF.booster_save_json_config(booster) |> unwrap!() != :error
  end

  test "load booster config" do
    mat = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    array_interface = from_tensor(mat) |> Jason.encode!()

    config = Jason.encode!(%{"missing" => -1.0})

    dmat =
      EXGBoost.NIF.dmatrix_create_from_dense(array_interface, config)
      |> unwrap!()

    # Round-trip the booster's JSON config through save and load.
    booster = EXGBoost.NIF.booster_create([dmat]) |> unwrap!()
    buf = EXGBoost.NIF.booster_save_json_config(booster) |> unwrap!()
    assert EXGBoost.NIF.booster_load_json_config(booster, buf) |> unwrap!() != :error
  end
end
--------------------------------------------------------------------------------
/test/parameter_test.exs:
--------------------------------------------------------------------------------
defmodule ParameterTest do
  use ExUnit.Case, async: true
  alias EXGBoost.Booster

  setup do
    # Deterministic Nx PRNG key shared with every test via the context.
    %{key: Nx.Random.key(42)}
  end

  test "tree booster", context do
    num_class = 10
    # NOTE(review): :rand.uniform/1 is unseeded here, so the data shape varies
    # between runs (1..10 rows/cols) — presumably intentional fuzzing; confirm.
    nrows = :rand.uniform(10)
    ncols = :rand.uniform(10)
    {x, key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols})
    {y, _key} = Nx.Random.randint(key, 0, num_class, shape: {nrows})

    num_boost_round = 10

params = [
      device: :cpu,
      num_boost_rounds: num_boost_round,
      tree_method: :hist,
      obj: :multi_softprob,
      num_class: num_class,
      # A broad mix of evaluation metrics, including the parameterized tuple
      # forms ({metric, value}).
      eval_metric: [
        :rmse,
        :rmsle,
        :mae,
        :mape,
        :logloss,
        :error,
        :auc,
        :merror,
        :mlogloss,
        :gamma_nloglik,
        :inv_map,
        {:tweedie_nloglik, 1.5},
        {:error, 0.2},
        {:ndcg, 3},
        {:map, 2},
        {:inv_ndcg, 3}
      ],
      max_depth: 3,
      eta: 0.3,
      gamma: 0.1,
      min_child_weight: 1,
      subsample: 0.8,
      colsample_by: [tree: 0.8, node: 0.8, level: 0.8],
      lambda: 1,
      alpha: 0,
      grow_policy: :lossguide,
      max_leaves: 0,
      max_bin: 128,
      predictor: :cpu_predictor,
      num_parallel_tree: 1,
      monotone_constraints: [],
      interaction_constraints: []
    ]

    booster = EXGBoost.train(x, y, params)
    assert Booster.get_boosted_rounds(booster) == num_boost_round

    # eta: 2 is rejected by parameter validation.
    assert_raise NimbleOptions.ValidationError, fn -> EXGBoost.train(x, y, eta: 2) end

    # Misspelled updater names must be rejected.
    assert_raise NimbleOptions.ValidationError, fn ->
      EXGBoost.train(x, y, updater: [:grow_colmakerm, :grow_histmakerm])
    end

    # An unrecognized colsample_by key must be rejected.
    assert_raise NimbleOptions.ValidationError, fn ->
      EXGBoost.train(x, y, colsample_by: [nottree: 1])
    end
  end

  test "linear booster", context do
    num_class = 10
    nrows = :rand.uniform(10)
    ncols = :rand.uniform(10)
    {x, key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols})
    {y, _key} = Nx.Random.randint(key, 0, num_class, shape: {nrows})

    num_boost_round = 10

    # First gblinear configuration: coordinate-descent updater with the greedy
    # feature selector. NOTE(review): device: {:gpu, 0} assumes a GPU is
    # available where this test runs — confirm for CI.
    params = [
      booster: :gblinear,
num_boost_rounds: num_boost_round,
      lambda: 0.1,
      alpha: 0.1,
      updater: :coord_descent,
      feature_selector: :greedy,
      top_k: 1,
      device: {:gpu, 0}
    ]

    booster = EXGBoost.train(x, y, params)
    assert Booster.get_boosted_rounds(booster) == num_boost_round

    # Second gblinear configuration: shotgun updater with shuffle selection.
    params = [
      booster: :gblinear,
      num_boost_rounds: num_boost_round,
      lambda: 0.1,
      alpha: 0.1,
      updater: :shotgun,
      feature_selector: :shuffle,
      top_k: 1
    ]

    booster = EXGBoost.train(x, y, params)
    assert Booster.get_boosted_rounds(booster) == num_boost_round

    # The shotgun updater does not support the greedy selector.
    # TODO Right now this is an ArgumentError, but it should be a NimbleOptions.ValidationError
    assert_raise ArgumentError, fn ->
      EXGBoost.train(x, y, booster: :gblinear, updater: :shotgun, feature_selector: :greedy)
    end
  end

  test "dart booster", context do
    num_class = 10
    nrows = :rand.uniform(10)
    ncols = :rand.uniform(10)
    {x, key} = Nx.Random.normal(context.key, 0, 1, shape: {nrows, ncols})
    {y, _key} = Nx.Random.randint(key, 0, num_class, shape: {nrows})

    num_boost_round = 10

    # Same tree-parameter surface as the "tree booster" test, plus the
    # DART-specific rate_drop/one_drop options.
    params = [
      booster: :dart,
      num_boost_rounds: num_boost_round,
      tree_method: :hist,
      obj: :multi_softprob,
      num_class: num_class,
      eval_metric: [
        :rmse,
        :rmsle,
        :mae,
        :mape,
        :logloss,
        :error,
        :auc,
        :merror,
        :mlogloss,
        :gamma_nloglik,
        :inv_map,
        {:tweedie_nloglik, 1.5},
        {:error, 0.2},
        {:ndcg, 3},
        {:map, 2},
        {:inv_ndcg, 3}
      ],
      max_depth: 3,
      eta: 0.3,
      gamma: 0.1,
      min_child_weight: 1,
      subsample: 0.8,
      colsample_by: [tree: 0.8, node: 0.8, level: 0.8],
      lambda: 1,
      alpha: 0,
      grow_policy: :lossguide,
      max_leaves: 0,
      max_bin: 128,
      predictor: :cpu_predictor,
      num_parallel_tree: 1,
      monotone_constraints: [],
      interaction_constraints: [],
      rate_drop: 0.2,
      one_drop: 1
    ]

    booster = EXGBoost.train(x, y, params)
    assert Booster.get_boosted_rounds(booster) == num_boost_round

    # eta: 1.5 is rejected for dart as well.
    assert_raise NimbleOptions.ValidationError, fn ->
      EXGBoost.train(x, y, booster: :dart, eta: 1.5)
    end

assert_raise NimbleOptions.ValidationError, fn ->
      # A negative max_depth is rejected by validation.
      EXGBoost.train(x, y, booster: :dart, max_depth: -1)
    end

    assert_raise NimbleOptions.ValidationError, fn ->
      # rate_drop: 2 is outside the accepted range and is rejected.
      EXGBoost.train(x, y, booster: :dart, rate_drop: 2)
    end
  end
end
--------------------------------------------------------------------------------
/test/test_helper.exs:
--------------------------------------------------------------------------------
# Boots the ExUnit test runner for the suite.
ExUnit.start()
--------------------------------------------------------------------------------