// Pull in the raw FFI declarations for LightGBM's C API.
// `build.rs` generates them with bindgen at compile time and writes them to
// `$OUT_DIR/bindings.rs`, so nothing is checked into the repository.
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # lightgbm3-sys build target 10 | lightgbm3-sys/target 11 | 12 | # These are backup files generated by rustfmt 13 | *.rs.bk 14 | 15 | ## File system 16 | .DS_Store 17 | desktop.ini 18 | 19 | ## Editor 20 | *.swp 21 | *.swo 22 | Session.vim 23 | .cproject 24 | .idea 25 | *.iml 26 | .vscode 27 | .project 28 | .favorites.json 29 | .settings/ 30 | .vs/ 31 | 32 | # Tarpaulin coverage files 33 | cobertura.xml 34 | 35 | .specstory -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dmitry Mottl 4 | Copyright (c) 2021 vaaaaanquish 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /lightgbm3-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm3-sys" 3 | version = "1.0.8" 4 | edition = "2021" 5 | authors = ["Dmitry Mottl ", "vaaaaanquish <6syun9@gmail.com>"] 6 | build = "build.rs" 7 | license = "MIT" 8 | repository = "https://github.com/Mottl/lightgbm3-rs" 9 | description = "Low-level Rust bindings for LightGBM library" 10 | categories = ["external-ffi-bindings"] 11 | readme = "README.md" 12 | exclude = ["README.md", ".gitlab-ci.yml", ".hgeol", ".gitignore", ".appveyor.yml", ".coveralls.yml", ".travis.yml", ".github", ".gitmodules", ".nuget", "**/*.md", "lightgbm/compute/doc", "lightgbm/compute/example", "lightgbm/compute/index.html", "lightgbm/compute/perf", "lightgbm/compute/test", "lightgbm/eigen/debug", "lightgbm/eigen/demos", "lightgbm/eigen/doc", "lightgbm/eigen/failtest", "lightgbm/eigen/test", "lightgbm/examples", "lightgbm/external_libs/fast_double_parser/benchmarks", "lightgbm/external_libs/fmt/doc", "lightgbm/external_libs/fmt/test"] 13 | 14 | [dependencies] 15 | libc = "0.2" 16 | 17 | [build-dependencies] 18 | cmake = "0.1" 19 | bindgen = "0.71" 20 | doxygen-rs = "0.4" 21 | 22 | [features] 23 | openmp = [] 24 | gpu = [] 25 | cuda = [] 26 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm3" 3 | version = "1.0.8" 4 | edition = "2021" 5 | authors = [ 6 | "Dmitry Mottl ", 7 | "vaaaaanquish <6syun9@gmail.com>", 8 | "paq <89paku@gmail.com>", 9 
| "Benjamin Ellis ", 10 | ] 11 | license = "MIT" 12 | repository = "https://github.com/Mottl/lightgbm3-rs" 13 | description = "Rust bindings for LightGBM library" 14 | documentation = "https://docs.rs/lightgbm3/" 15 | keywords = ["lightgbm", "machine-learning", "gradient-boosting"] 16 | categories = ["api-bindings", "science"] 17 | readme = "README.md" 18 | exclude = [".gitignore", ".github", ".gitmodules", "examples", "benches", "lightgbm3-sys"] 19 | 20 | [dependencies] 21 | lightgbm3-sys = { path = "lightgbm3-sys", version = "1" } 22 | serde_json = "1" 23 | polars = { version = "0.47", optional = true } 24 | 25 | [features] 26 | default = [] 27 | polars = ["dep:polars"] 28 | openmp = ["lightgbm3-sys/openmp"] 29 | gpu = ["lightgbm3-sys/gpu"] 30 | cuda = ["lightgbm3-sys/cuda"] 31 | 32 | [[bench]] 33 | name = "regression" 34 | path = "benches/regression.rs" 35 | harness = false 36 | 37 | [dev-dependencies] 38 | rand = "0.9" 39 | rand_distr = "0.5" 40 | csv = "1.3" 41 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | name: Rust ${{ matrix.os }} ${{ matrix.rust }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | rust: [stable] 12 | os: [ubuntu-latest, macos-latest] #, windows-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | with: 16 | submodules: recursive 17 | - name: Setup Rust 18 | uses: actions-rs/toolchain@v1 19 | with: 20 | toolchain: ${{ matrix.rust }} 21 | components: clippy 22 | - name: Build for Mac 23 | if: matrix.os == 'macos-latest' 24 | run: | 25 | brew install cmake libomp 26 | cargo build --features=openmp --features=polars 27 | - name: Build for Ubuntu 28 | if: matrix.os == 'ubuntu-latest' 29 | run: | 30 | sudo apt-get update 31 | sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib 32 | cargo build 
--features=openmp --features=polars 33 | # - name: Build for Windows 34 | # if: matrix.os == 'windows-latest' 35 | # run: | 36 | # cargo build --features=openmp --features=polars 37 | - name: Run tests 38 | run: cargo test --features=polars #--features=openmp 39 | continue-on-error: ${{ matrix.rust == 'nightly' }} 40 | - name: Run Clippy 41 | uses: actions-rs/clippy-check@v1 42 | with: 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | args: --features=polars --features=openmp -- --no-deps 45 | format_check: 46 | name: Run Rustfmt 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v2 50 | - uses: actions-rs/toolchain@v1 51 | with: 52 | toolchain: stable 53 | components: rustfmt 54 | - run: cargo fmt -- --check 55 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! Functionality related to errors and error handling. 2 | 3 | use std::{ 4 | error, 5 | ffi::CStr, 6 | fmt::{self, Debug, Display}, 7 | }; 8 | 9 | #[cfg(feature = "polars")] 10 | use polars::prelude::*; 11 | 12 | /// Convenience return type for most operations which can return an `LightGBM`. 13 | pub type Result = std::result::Result; 14 | 15 | /// Wrap errors returned by the LightGBM library. 16 | #[derive(Debug, Eq, PartialEq)] 17 | pub struct Error { 18 | desc: String, 19 | } 20 | 21 | impl Error { 22 | pub(crate) fn new>(desc: S) -> Self { 23 | Self { desc: desc.into() } 24 | } 25 | 26 | /// Check the return value from an LightGBM FFI call, and return the last error message on error. 27 | /// 28 | /// Return values of 0 are treated as success, returns values of -1 are treated as errors. 29 | /// 30 | /// Meaning of any other return values are undefined, and will cause a panic. 
31 | pub(crate) fn check_return_value(ret_val: i32) -> Result<()> { 32 | match ret_val { 33 | 0 => Ok(()), 34 | -1 => Err(Self::from_lightgbm()), 35 | _ => panic!("unexpected return value '{}', expected 0 or -1", ret_val), 36 | } 37 | } 38 | 39 | /// Get the last error message from LightGBM. 40 | fn from_lightgbm() -> Self { 41 | let c_str = unsafe { CStr::from_ptr(lightgbm3_sys::LGBM_GetLastError()) }; 42 | let str_slice = c_str.to_str().unwrap(); 43 | Self::new(str_slice) 44 | } 45 | } 46 | 47 | impl error::Error for Error {} 48 | 49 | impl Display for Error { 50 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 51 | write!(f, "LightGBM error: {}", &self.desc) 52 | } 53 | } 54 | 55 | #[cfg(feature = "polars")] 56 | impl From for Error { 57 | fn from(pe: PolarsError) -> Self { 58 | Self { 59 | desc: pe.to_string(), 60 | } 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | mod tests { 66 | use super::*; 67 | 68 | #[test] 69 | fn return_value_handling() { 70 | let result = Error::check_return_value(0); 71 | assert_eq!(result, Ok(())); 72 | 73 | let result = Error::check_return_value(-1); 74 | assert_eq!(result, Err(Error::new("Everything is fine"))); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /examples/regression.rs: -------------------------------------------------------------------------------- 1 | //! 
MSE regression model training and evaluation example 2 | 3 | use lightgbm3::{Booster, Dataset}; 4 | use serde_json::json; 5 | use std::iter::zip; 6 | 7 | /// Loads a .tsv file and returns a flattened vector of xs, a vector of ys 8 | /// and a number of features 9 | fn load_file(file_path: &str) -> (Vec, Vec, i32) { 10 | let rdr = csv::ReaderBuilder::new() 11 | .has_headers(false) 12 | .delimiter(b'\t') 13 | .from_path(file_path); 14 | let mut ys: Vec = Vec::new(); 15 | let mut xs: Vec = Vec::new(); 16 | for result in rdr.unwrap().records() { 17 | let record = result.unwrap(); 18 | let mut record = record.into_iter(); 19 | let y = record.next().unwrap().parse::().unwrap(); 20 | ys.push(y); 21 | xs.extend(record.map(|x| x.parse::().unwrap())); 22 | } 23 | let n_features = xs.len() / ys.len(); 24 | (xs, ys, n_features as i32) 25 | } 26 | 27 | fn main() -> std::io::Result<()> { 28 | let (train_xs, train_ys, n_features) = 29 | load_file("lightgbm3-sys/lightgbm/examples/regression/regression.train"); 30 | let (test_xs, test_ys, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/regression/regression.test"); 32 | assert_eq!(n_features, n_features_test); 33 | 34 | let train_dataset = Dataset::from_slice(&train_xs, &train_ys, n_features, true).unwrap(); 35 | 36 | let params = json! 
{ 37 | { 38 | "num_iterations": 100, 39 | "objective": "regression", 40 | "metric": "l2" 41 | } 42 | }; 43 | 44 | // Train a model 45 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 46 | // Predicts floating point 47 | let y_pred = booster.predict(&test_xs, n_features, true).unwrap(); 48 | // Calculate regression metrics 49 | let mean = test_ys.iter().sum::() / test_ys.len() as f32; 50 | let var = test_ys.iter().map(|&y| (y - mean).powi(2)).sum::() / test_ys.len() as f32; 51 | let var_model = zip(&test_ys, &y_pred) 52 | .map(|(&y, &y_pred)| (y - y_pred as f32).powi(2)) 53 | .sum::() 54 | / test_ys.len() as f32; 55 | let r2 = 1.0f32 - var_model / var; 56 | println!("test mse = {var_model:.3}"); 57 | println!("test r^2 = {r2:.3}"); 58 | Ok(()) 59 | } 60 | -------------------------------------------------------------------------------- /examples/multiclass_classification.rs: -------------------------------------------------------------------------------- 1 | //! Multiclass classification model training and evaluation example 2 | 3 | use lightgbm3::{argmax, Booster, Dataset}; 4 | use serde_json::json; 5 | use std::iter::zip; 6 | 7 | /// Loads a .tsv file and returns a flattened vector of xs, a vector of labels 8 | /// and a number of features 9 | fn load_file(file_path: &str) -> (Vec, Vec, i32) { 10 | let rdr = csv::ReaderBuilder::new() 11 | .has_headers(false) 12 | .delimiter(b'\t') 13 | .from_path(file_path); 14 | let mut labels: Vec = Vec::new(); 15 | let mut features: Vec = Vec::new(); 16 | for result in rdr.unwrap().records() { 17 | let record = result.unwrap(); 18 | let mut record = record.into_iter(); 19 | let label = record.next().unwrap().parse::().unwrap(); 20 | labels.push(label); 21 | features.extend(record.map(|x| x.parse::().unwrap())); 22 | } 23 | let n_features = features.len() / labels.len(); 24 | (features, labels, n_features as i32) 25 | } 26 | 27 | fn main() -> std::io::Result<()> { 28 | let (train_features, train_labels, n_features) = 
29 | load_file("lightgbm3-sys/lightgbm/examples/multiclass_classification/multiclass.train"); 30 | let (test_features, test_labels, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/multiclass_classification/multiclass.test"); 32 | assert_eq!(n_features, n_features_test); 33 | let train_dataset = 34 | Dataset::from_slice(&train_features, &train_labels, n_features, true).unwrap(); 35 | 36 | let params = json! { 37 | { 38 | "num_iterations": 100, 39 | "objective": "multiclass", 40 | "metric": "multi_logloss", 41 | "num_class": 5, 42 | } 43 | }; 44 | 45 | // Train a model 46 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 47 | // Predict probabilities for each class 48 | let probas = booster.predict(&test_features, n_features, true).unwrap(); 49 | // Calculate accuracy 50 | let mut tp = 0; 51 | for (&label, proba) in zip(&test_labels, probas.chunks(booster.num_classes() as usize)) { 52 | let argmax_pred = argmax(proba); 53 | if label == argmax_pred as f32 { 54 | tp += 1; 55 | } 56 | println!("true={label}, pred={argmax_pred}, probas={proba:.3?}"); 57 | } 58 | println!( 59 | "Accuracy: {} / {}", 60 | &tp, 61 | test_features.len() / n_features as usize 62 | ); 63 | Ok(()) 64 | } 65 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! LightGBM Rust library 2 | //! 3 | //! **`lightgbm3`** supports the following features: 4 | //! - `polars` for [polars](https://github.com/pola-rs/polars) support 5 | //! - `openmp` for [multi-processing](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-threadless-version-not-recommended) support 6 | //! - `gpu` for [GPU](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-gpu-version) support 7 | //! - `cuda` for [CUDA](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-cuda-version) support 8 | //! 9 | //! 
# Examples 10 | //! ### Training: 11 | //! ```no_run 12 | //! use lightgbm3::{Dataset, Booster}; 13 | //! use serde_json::json; 14 | //! 15 | //! let features = vec![vec![1.0, 0.1, 0.2], 16 | //! vec![0.7, 0.4, 0.5], 17 | //! vec![0.9, 0.8, 0.5], 18 | //! vec![0.2, 0.2, 0.8], 19 | //! vec![0.1, 0.7, 1.0]]; 20 | //! let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 21 | //! let dataset = Dataset::from_vec_of_vec(features, labels, true).unwrap(); 22 | //! let params = json!{ 23 | //! { 24 | //! "num_iterations": 10, 25 | //! "objective": "binary", 26 | //! "metric": "auc", 27 | //! } 28 | //! }; 29 | //! let bst = Booster::train(dataset, &params).unwrap(); 30 | //! bst.save_file("path/to/model.lgb").unwrap(); 31 | //! ``` 32 | //! 33 | //! ### Inference: 34 | //! ```no_run 35 | //! use lightgbm3::{Dataset, Booster}; 36 | //! 37 | //! let bst = Booster::from_file("path/to/model.lgb").unwrap(); 38 | //! let features = vec![1.0, 2.0, -5.0]; 39 | //! let n_features = features.len(); 40 | //! let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0]; 41 | //! ``` 42 | 43 | macro_rules! 
/// Get index of the element in a slice with the maximum value.
///
/// When several elements share the maximum value, the index of the first one is
/// returned. Elements that do not compare greater than the current maximum
/// (for example `NaN` floats) are skipped.
///
/// # Panics
///
/// Panics if `xs` is empty, because an empty slice has no maximum.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    let (first, rest) = xs
        .split_first()
        .expect("argmax of an empty slice is undefined");

    // Only the first occurrence of the maximum is ever reported, so tracking one
    // index is enough; no list of tie indices is needed.
    let mut max_ix = 0;
    let mut maxval = first;
    for (i, x) in rest.iter().enumerate() {
        if x > maxval {
            maxval = x;
            // `rest` starts at the second element of `xs`.
            max_ix = i + 1;
        }
    }
    max_ix
}
load_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.train"); 30 | let (test_features, test_labels, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.test"); 32 | assert_eq!(n_features, n_features_test); 33 | let train_dataset = 34 | Dataset::from_slice(&train_features, &train_labels, n_features, true).unwrap(); 35 | 36 | let params = json! { 37 | { 38 | "num_iterations": 100, 39 | "objective": "binary", 40 | "metric": "auc" 41 | } 42 | }; 43 | // Train a model 44 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 45 | // Predict probabilities 46 | let probas = booster.predict(&test_features, n_features, true).unwrap(); 47 | // Calculate accuracy 48 | let mut tp = 0; 49 | for (&label, &proba) in zip(&test_labels, &probas) { 50 | if (label == 1_f32 && proba > 0.5_f64) || (label == 0_f32 && proba <= 0.5_f64) { 51 | tp += 1; 52 | } 53 | println!("label={label}, proba={proba:.3}"); 54 | } 55 | println!("Accuracy: {} / {}\n", &tp, probas.len()); 56 | 57 | println!("Feature importance:"); 58 | let feature_name = booster.feature_name().unwrap(); 59 | let feature_importance = booster.feature_importance(ImportanceType::Gain).unwrap(); 60 | for (feature, importance) in zip(&feature_name, &feature_importance) { 61 | println!("{}: {}", feature, importance); 62 | } 63 | Ok(()) 64 | } 65 | -------------------------------------------------------------------------------- /benches/regression.rs: -------------------------------------------------------------------------------- 1 | use lightgbm3::{Booster, Dataset}; 2 | use rand_distr::Distribution; 3 | use serde_json::json; 4 | use std::hint::black_box; 5 | use std::time::Instant; 6 | 7 | fn generate_train_data() -> (Vec, Vec) { 8 | let mut rng = rand::rng(); 9 | let uniform = rand_distr::Uniform::::new(-5.0, 5.0).unwrap(); 10 | let normal = rand_distr::Normal::::new(0.0, 0.1).unwrap(); 11 | 12 | let mut x: Vec = vec![]; 13 | let mut y: Vec = 
Vec::with_capacity(100_000); 14 | 15 | for _ in 0..y.capacity() { 16 | let x1 = uniform.sample(&mut rng); 17 | let x2 = uniform.sample(&mut rng); 18 | let x3 = uniform.sample(&mut rng); 19 | let y_ = x1.sin() + x2.cos() + (x3 / 2.0).powi(2) + normal.sample(&mut rng); 20 | x.push(x1); 21 | x.push(x2); 22 | x.push(x3); 23 | y.push(y_ as f32); 24 | } 25 | (x, y) 26 | } 27 | 28 | fn main() -> std::io::Result<()> { 29 | const NUM_LEAVES: i32 = 5; 30 | 31 | let (x, y) = generate_train_data(); 32 | let train_dataset = Dataset::from_slice(&x, &y, 3, true).unwrap(); 33 | let params = json! { 34 | { 35 | "num_iterations": 1000, 36 | "learning_rate": 0.05, 37 | "num_leaves": NUM_LEAVES, 38 | "objective": "mse", 39 | } 40 | }; 41 | let start_time = Instant::now(); 42 | let mut booster = Booster::train(train_dataset, ¶ms).unwrap(); 43 | let train_time = start_time.elapsed().as_nanos() as f64 / 1000.0; 44 | let mut features: Vec = vec![]; 45 | #[cfg(feature = "openmp")] 46 | features.push("openmp".to_string()); 47 | #[cfg(feature = "gpu")] 48 | features.push("gpu".to_string()); 49 | #[cfg(feature = "cuda")] 50 | features.push("cuda".to_string()); 51 | if features.is_empty() { 52 | features.push("none".to_string()); 53 | } 54 | println!("Compiled features: {}", features.join(", ")); 55 | println!( 56 | "Booster train time: {:.3} us/iteration", 57 | train_time / 1000.0 58 | ); 59 | 60 | // let rmse = zip(y.iter(), y_preds.iter()).map(|(&y, &y_pred)| { 61 | // (y as f64 - y_pred).powi(2) 62 | // }).sum::() / y.len() as f64; 63 | // 64 | // println!("train rmse={}", rmse); 65 | // println!("num_iterations={}", booster.num_iterations()); 66 | 67 | // warm up CPU 68 | let _ = booster.predict(&x, 3, true).unwrap(); 69 | 70 | println!("Booster evaluation times:"); 71 | let x = [0.1, 0.5, -1.0]; 72 | for i in 1..=10 { 73 | booster.set_max_iterations(i * 50).unwrap(); 74 | let mut elapsed: u64 = 0; 75 | for _ in 0..100000 { 76 | let start_time = Instant::now(); 77 | let y_preds = 
booster.predict(&x, 3, true).unwrap(); 78 | let eval_time = start_time.elapsed().as_nanos() as u64; 79 | elapsed += eval_time; 80 | black_box(y_preds); 81 | } 82 | println!( 83 | "trees={:4.}, leaves={}, eval time={:.3} us/sample", 84 | i * 50, 85 | NUM_LEAVES, 86 | elapsed as f64 / 100000_f64 / 1000.0 87 | ) 88 | } 89 | 90 | Ok(()) 91 | } 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lightgbm3 — Rust bindings for LightGBM 2 | [![Crates.io](https://img.shields.io/crates/v/lightgbm3.svg)](https://crates.io/crates/lightgbm3) 3 | [![Docs.rs](https://docs.rs/lightgbm3/badge.svg)](https://docs.rs/lightgbm3/) 4 | [![build](https://github.com/Mottl/lightgbm3-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/Mottl/lightgbm3-rs/actions) 5 | 6 | **`lightgbm3`** is based on [`lightgbm`](https://github.com/vaaaaanquish/lightgbm-rs) crate 7 | (which is unsupported by now), but it is not back-compatible with it. 8 | 9 | ## Installation 10 | ```shell 11 | cargo add lightgbm3 12 | ``` 13 | 14 | Since `lightgbm3` compiles `LightGBM` from source, you also need to install development libraries: 15 | 16 | #### for Linux: 17 | ``` 18 | apt install -y cmake clang libclang-dev libc++-dev gcc-multilib 19 | ``` 20 | 21 | #### for Mac: 22 | ``` 23 | brew install cmake 24 | brew install libomp # only required if you compile with "openmp" feature 25 | ``` 26 | 27 | #### for Windows 28 | 1. Install CMake and VS Build Tools. 29 | 2. Install LLVM and set `LIBCLANG_PATH` environment variable (i.e. `C:\Program Files\LLVM\bin`) 30 | 31 | Please see below for details. 
32 | 33 | - [LightGBM Installation-Guide](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html) 34 | 35 | ## Usage 36 | 37 | ### Training: 38 | ```rust 39 | use lightgbm3::{Dataset, Booster}; 40 | use serde_json::json; 41 | 42 | let features = vec![vec![1.0, 0.1, 0.2], 43 | vec![0.7, 0.4, 0.5], 44 | vec![0.9, 0.8, 0.5], 45 | vec![0.2, 0.2, 0.8], 46 | vec![0.1, 0.7, 1.0]]; 47 | let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 48 | let dataset = Dataset::from_vec_of_vec(features, labels, true).unwrap(); 49 | let params = json!{ 50 | { 51 | "num_iterations": 10, 52 | "objective": "binary", 53 | "metric": "auc", 54 | } 55 | }; 56 | let bst = Booster::train(dataset, &params).unwrap(); 57 | bst.save_file("path/to/model.lgb").unwrap(); 58 | ``` 59 | 60 | ### Inference: 61 | ```rust 62 | use lightgbm3::{Dataset, Booster}; 63 | 64 | let bst = Booster::from_file("path/to/model.lgb").unwrap(); 65 | let features = vec![1.0, 2.0, -5.0]; 66 | let n_features = features.len(); 67 | let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0]; 68 | ``` 69 | 70 | Look in the [`./examples/`](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/) folder for more details: 71 | - [binary classification](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/binary_classification.rs) 72 | - [multiclass classification](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/multiclass_classification.rs) 73 | - [regression](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/regression.rs) 74 | 75 | ## Features 76 | **`lightgbm3`** supports the following features: 77 | - `polars` for [polars](https://github.com/pola-rs/polars) support 78 | - `openmp` for [multi-processing](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-threadless-version-not-recommended) support 79 | - `gpu` for [GPU](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-gpu-version) support 80 | - `cuda` for 
[CUDA](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-cuda-version) support 81 | 82 | ## Benchmarks 83 | ``` 84 | cargo bench 85 | ``` 86 | 87 | Add `--features=openmp`, `--features=gpu` and `--features=cuda` appropriately. 88 | 89 | ## Development 90 | ``` 91 | git clone --recursive https://github.com/Mottl/lightgbm3-rs.git 92 | ``` 93 | 94 | ## Thanks 95 | Great respect to [vaaaaanquish](https://github.com/vaaaaanquish) for the LightGBM Rust package, which unfortunately 96 | no longer supported. 97 | 98 | Much reference was made to implementation and documentation. Thanks. 99 | 100 | - [microsoft/LightGBM](https://github.com/microsoft/LightGBM) 101 | - [davechallis/rust-xgboost](https://github.com/davechallis/rust-xgboost) 102 | -------------------------------------------------------------------------------- /lightgbm3-sys/build.rs: -------------------------------------------------------------------------------- 1 | use cmake::Config; 2 | use std::{ 3 | env, 4 | path::{Path, PathBuf}, 5 | process::Command, 6 | }; 7 | 8 | #[derive(Debug)] 9 | struct DoxygenCallback; 10 | 11 | impl bindgen::callbacks::ParseCallbacks for DoxygenCallback { 12 | fn process_comment(&self, comment: &str) -> Option { 13 | Some(doxygen_rs::transform(comment)) 14 | } 15 | } 16 | 17 | fn main() { 18 | let target = env::var("TARGET").unwrap(); 19 | let out_dir = env::var("OUT_DIR").unwrap(); 20 | let lgbm_root = Path::new(&out_dir).join("lightgbm"); 21 | 22 | // copy source code 23 | if !lgbm_root.exists() { 24 | let status = if target.contains("windows") { 25 | Command::new("cmd") 26 | .args(&[ 27 | "/C", 28 | "echo D | xcopy /S /Y lightgbm", 29 | lgbm_root.to_str().unwrap(), 30 | ]) 31 | .status() 32 | } else { 33 | Command::new("cp") 34 | .args(&["-r", "lightgbm", lgbm_root.to_str().unwrap()]) 35 | .status() 36 | }; 37 | if let Some(err) = status.err() { 38 | panic!( 39 | "Failed to copy ./lightgbm to {}: {}", 40 | lgbm_root.display(), 41 | err 42 | ); 43 | } 44 | 
} 45 | 46 | // CMake 47 | let mut cfg = Config::new(&lgbm_root); 48 | let cfg = cfg 49 | .profile("Release") 50 | .cxxflag("-std=c++14") 51 | .define("BUILD_STATIC_LIB", "ON"); 52 | #[cfg(not(feature = "openmp"))] 53 | let cfg = cfg.define("USE_OPENMP", "OFF"); 54 | #[cfg(feature = "gpu")] 55 | let cfg = cfg.define("USE_GPU", "1"); 56 | #[cfg(feature = "cuda")] 57 | let cfg = cfg.define("USE_CUDA", "1"); 58 | let dst = cfg.build(); 59 | 60 | // bindgen build 61 | let mut clang_args = vec!["-x", "c++", "-std=c++14"]; 62 | if target.contains("apple") { 63 | clang_args.push("-mmacosx-version-min=10.12"); 64 | } 65 | let bindings = bindgen::Builder::default() 66 | .header("lightgbm/include/LightGBM/c_api.h") 67 | .allowlist_file("lightgbm/include/LightGBM/c_api.h") 68 | .clang_args(&clang_args) 69 | .clang_arg(format!("-I{}", lgbm_root.join("include").display())) 70 | .parse_callbacks(Box::new(DoxygenCallback)) 71 | .generate() 72 | .expect("Unable to generate bindings"); 73 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 74 | bindings 75 | .write_to_file(out_path.join("bindings.rs")) 76 | .unwrap_or_else(|err| panic!("Couldn't write bindings: {err}")); 77 | // link to appropriate C++ lib 78 | if target.contains("apple") { 79 | println!("cargo:rustc-link-lib=c++"); 80 | } else if target.contains("linux") { 81 | println!("cargo:rustc-link-lib=stdc++"); 82 | } 83 | #[cfg(feature = "openmp")] 84 | { 85 | println!("cargo:rustc-link-args=-fopenmp"); 86 | if target.contains("apple") { 87 | println!("cargo:rustc-link-lib=dylib=omp"); 88 | // Link to libomp 89 | // If it fails to compile in MacOS, try: 90 | // `brew install libomp` 91 | // `brew link --force libomp` 92 | #[cfg(all(target_arch = "x86_64", target_os = "macos"))] 93 | println!("cargo:rustc-link-search=/usr/local/opt/libomp/lib"); 94 | #[cfg(all(target_arch = "aarch64", target_os = "macos"))] 95 | println!("cargo:rustc-link-search=/opt/homebrew/opt/libomp/lib"); 96 | } else if 
target.contains("linux") { 97 | println!("cargo:rustc-link-lib=dylib=gomp"); 98 | } 99 | } 100 | println!("cargo:rustc-link-search={}", out_path.join("lib").display()); 101 | println!("cargo:rustc-link-search=native={}", dst.display()); 102 | if target.contains("windows") { 103 | println!("cargo:rustc-link-lib=static=lib_lightgbm"); 104 | } else { 105 | println!("cargo:rustc-link-lib=static=_lightgbm"); 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/dataset.rs: -------------------------------------------------------------------------------- 1 | //! LightGBM Dataset used for training 2 | 3 | use lightgbm3_sys::{ 4 | DatasetHandle, LGBM_DatasetGetFeatureNames, LGBM_DatasetSetFeatureNames, C_API_DTYPE_FLOAT32, 5 | C_API_DTYPE_FLOAT64, 6 | }; 7 | use std::os::raw::c_void; 8 | use std::{self, ffi::CString}; 9 | 10 | #[cfg(feature = "polars")] 11 | use polars::{datatypes::DataType::Float32, prelude::*}; 12 | 13 | use crate::{Error, Result}; 14 | 15 | // a way of implementing sealed traits until they 16 | // come to rust lang. more details at: 17 | // https://internals.rust-lang.org/t/sealed-traits/16797 18 | mod private { 19 | pub trait Sealed {} 20 | 21 | impl Sealed for f32 {} 22 | impl Sealed for f64 {} 23 | } 24 | /// LightGBM dtype 25 | /// 26 | /// This trait is sealed as it is not intended 27 | /// to be implemented out of this crate 28 | pub trait DType: private::Sealed { 29 | fn get_c_api_dtype() -> i32; 30 | } 31 | 32 | impl DType for f32 { 33 | fn get_c_api_dtype() -> i32 { 34 | C_API_DTYPE_FLOAT32 as i32 35 | } 36 | } 37 | 38 | impl DType for f64 { 39 | fn get_c_api_dtype() -> i32 { 40 | C_API_DTYPE_FLOAT64 as i32 41 | } 42 | } 43 | 44 | /// LightGBM Dataset 45 | pub struct Dataset { 46 | pub(crate) handle: DatasetHandle, 47 | } 48 | 49 | impl Dataset { 50 | /// Creates a new Dataset object from the LightGBM's DatasetHandle. 
51 | fn new(handle: DatasetHandle) -> Self { 52 | Self { handle } 53 | } 54 | 55 | /// Set feature names for the dataset. 56 | /// 57 | /// This allows the model to save and display correct feature names instead of generic "Column_X". 58 | pub fn set_feature_names(&mut self, feature_names: &[String]) -> Result<()> { 59 | // 1. Verify that the number of feature names matches the dataset 60 | let (_, n_features) = self.size()?; 61 | if feature_names.len() as i32 != n_features { 62 | return Err(Error::new(format!( 63 | "Input feature names count ({}) does not match dataset feature count ({})", 64 | feature_names.len(), 65 | n_features 66 | ))); 67 | } 68 | 69 | // 2. Convert Rust Strings to CStrings (handling null-termination and memory layout) 70 | let c_names: Vec = feature_names 71 | .iter() 72 | .map(|s| { 73 | CString::new(s.as_bytes()) 74 | .map_err(|e| Error::new(format!("Invalid feature name string: {}", e))) 75 | }) 76 | .collect::>>()?; 77 | 78 | // 3. Create an array of pointers to the CString internal buffers (char**) 79 | let c_ptrs: Vec<*const std::os::raw::c_char> = c_names.iter().map(|s| s.as_ptr()).collect(); 80 | 81 | // 4. Call LightGBM C API 82 | // int LGBM_DatasetSetFeatureNames(DatasetHandle handle, const char** feature_names, int num_features); 83 | lgbm_call!(LGBM_DatasetSetFeatureNames( 84 | self.handle, 85 | c_ptrs.clone().as_mut_ptr(), 86 | feature_names.len() as i32 87 | ))?; 88 | 89 | Ok(()) 90 | } 91 | 92 | /// Get feature names from the dataset. 93 | pub fn get_feature_names(&self) -> Result> { 94 | // 1. Get the number of features 95 | let (_, n_features) = self.size()?; 96 | let len = n_features as usize; 97 | 98 | // 2. Prepare buffers 99 | // LightGBM C API requires the caller to allocate memory. 100 | // We assume a maximum feature name length of 256 bytes, which is usually sufficient. 
101 | const MAX_NAME_LEN: usize = 256; 102 | 103 | // Create 'len' buffers, each of size MAX_NAME_LEN, initialized to 0 104 | let mut name_buffers: Vec> = vec![vec![0u8; MAX_NAME_LEN]; len]; 105 | 106 | // Create an array of pointers to these buffers (char**) 107 | let mut name_ptrs: Vec<*mut std::os::raw::c_char> = name_buffers 108 | .iter_mut() 109 | .map(|buf| buf.as_mut_ptr() as *mut std::os::raw::c_char) 110 | .collect(); 111 | 112 | let mut num_features_out = 0; 113 | let mut required_len_out = 0; 114 | 115 | // 3. Call C API 116 | // int LGBM_DatasetGetFeatureNames(DatasetHandle handle, const int len, int* num_features, 117 | // const size_t max_feature_name_len, size_t* feature_name_len, char** feature_names); 118 | lgbm_call!(LGBM_DatasetGetFeatureNames( 119 | self.handle, 120 | len as i32, 121 | &mut num_features_out, 122 | MAX_NAME_LEN, 123 | &mut required_len_out, 124 | name_ptrs.as_mut_ptr() 125 | ))?; 126 | 127 | // 4. Convert C strings to Rust Strings 128 | let mut result = Vec::with_capacity(num_features_out as usize); 129 | for i in 0..num_features_out as usize { 130 | // Create CStr from pointer 131 | let c_str = unsafe { std::ffi::CStr::from_ptr(name_ptrs[i]) }; 132 | // Convert to Rust String (handle UTF-8) 133 | let str_slice = c_str.to_str().map_err(|e| { 134 | Error::new(format!( 135 | "Invalid UTF-8 in feature name at index {}: {}", 136 | i, e 137 | )) 138 | })?; 139 | result.push(str_slice.to_string()); 140 | } 141 | 142 | Ok(result) 143 | } 144 | 145 | /// Creates a new `Dataset` (x, labels) from flat `&[f64]` slice with a specified number 146 | /// of features (columns). 147 | /// 148 | /// `row_major` should be set to `true` for row-major order and `false` otherwise. 
149 | /// 150 | /// # Example 151 | /// ``` 152 | /// use lightgbm3::Dataset; 153 | /// 154 | /// let x = vec![vec![1.0, 0.1, 0.2], 155 | /// vec![0.7, 0.4, 0.5], 156 | /// vec![0.9, 0.8, 0.5], 157 | /// vec![0.2, 0.2, 0.8], 158 | /// vec![0.1, 0.7, 1.0]]; 159 | /// let flat_x = x.into_iter().flatten().collect::>(); 160 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 161 | /// let n_features = 3; 162 | /// let dataset = Dataset::from_slice(&flat_x, &label, n_features, true).unwrap(); 163 | /// ``` 164 | pub fn from_slice( 165 | flat_x: &[T], 166 | label: &[f32], 167 | n_features: i32, 168 | is_row_major: bool, 169 | ) -> Result { 170 | if n_features <= 0 { 171 | return Err(Error::new("number of features should be greater than 0")); 172 | } 173 | if flat_x.len() % n_features as usize != 0 { 174 | return Err(Error::new( 175 | "number of features doesn't correspond to slice size", 176 | )); 177 | } 178 | let n_rows = flat_x.len() / n_features as usize; 179 | if n_rows == 0 { 180 | return Err(Error::new("slice is empty")); 181 | } else if n_rows > i32::MAX as usize { 182 | return Err(Error::new(format!( 183 | "number of rows should be less than {}. 
Got {}", 184 | i32::MAX, 185 | n_rows 186 | ))); 187 | } 188 | let params = CString::new("").unwrap(); 189 | let label_str = CString::new("label").unwrap(); 190 | let reference = std::ptr::null_mut(); // not used 191 | let mut dataset_handle = std::ptr::null_mut(); // will point to a new DatasetHandle 192 | 193 | lgbm_call!(lightgbm3_sys::LGBM_DatasetCreateFromMat( 194 | flat_x.as_ptr() as *const c_void, 195 | T::get_c_api_dtype(), 196 | n_rows as i32, 197 | n_features, 198 | if is_row_major { 1_i32 } else { 0_i32 }, // is_row_major – 1 for row-major, 0 for column-major 199 | params.as_ptr(), 200 | reference, 201 | &mut dataset_handle 202 | ))?; 203 | 204 | lgbm_call!(lightgbm3_sys::LGBM_DatasetSetField( 205 | dataset_handle, 206 | label_str.as_ptr(), 207 | label.as_ptr() as *const c_void, 208 | n_rows as i32, 209 | C_API_DTYPE_FLOAT32 as i32 // labels should be always float32 210 | ))?; 211 | 212 | Ok(Self::new(dataset_handle)) 213 | } 214 | 215 | /// Creates a new `Dataset` (x, labels) from `Vec>` in row-major order. 216 | /// 217 | /// # Example 218 | /// ``` 219 | /// use lightgbm3::Dataset; 220 | /// 221 | /// let data = vec![vec![1.0, 0.1, 0.2], 222 | /// vec![0.7, 0.4, 0.5], 223 | /// vec![0.9, 0.8, 0.5], 224 | /// vec![0.2, 0.2, 0.8], 225 | /// vec![0.1, 0.7, 1.0]]; 226 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; // should be Vec 227 | /// let dataset = Dataset::from_vec_of_vec(data, label, true).unwrap(); 228 | /// ``` 229 | pub fn from_vec_of_vec( 230 | x: Vec>, 231 | label: Vec, 232 | is_row_major: bool, 233 | ) -> Result { 234 | if x.is_empty() || x[0].is_empty() { 235 | return Err(Error::new("x is empty")); 236 | } 237 | let n_features = match is_row_major { 238 | true => x[0].len() as i32, 239 | false => x.len() as i32, 240 | }; 241 | let x_flat = x.into_iter().flatten().collect::>(); 242 | Self::from_slice(&x_flat, &label, n_features, is_row_major) 243 | } 244 | 245 | /// Create a new `Dataset` from tab-separated-view file. 
246 | /// 247 | /// file is `tsv`. 248 | /// ```text 249 | ///