├── .github ├── FUNDING.yml └── workflows │ └── ci.yml ├── .gitignore ├── .gitmodules ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md ├── benches └── regression.rs ├── examples ├── binary_classification.rs ├── multiclass_classification.rs └── regression.rs ├── lightgbm3-sys ├── Cargo.toml ├── README.md ├── build.rs └── src │ └── lib.rs └── src ├── booster.rs ├── dataset.rs ├── error.rs └── lib.rs /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [Mottl] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | name: Rust ${{ matrix.os }} ${{ matrix.rust }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | rust: [stable] 12 | os: [ubuntu-latest, macos-latest] #, windows-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | with: 16 | submodules: recursive 17 | - name: Setup Rust 18 | uses: actions-rs/toolchain@v1 19 | with: 20 | toolchain: ${{ matrix.rust }} 21 | components: clippy 22 | - name: Build for Mac 23 | if: matrix.os == 'macos-latest' 24 | run: | 25 | brew install cmake libomp 26 | cargo build --features=openmp --features=polars 27 | - name: Build for Ubuntu 28 | if: matrix.os == 'ubuntu-latest' 29 | run: | 30 | sudo apt-get update 31 | sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib 32 | cargo build --features=openmp --features=polars 33 | # - name: Build for Windows 34 | # if: matrix.os == 'windows-latest' 35 | # run: | 36 | # cargo build --features=openmp --features=polars 37 | - name: Run tests 38 | run: cargo test --features=polars #--features=openmp 39 | continue-on-error: ${{ matrix.rust == 'nightly' }} 40 | - name: Run Clippy 41 | uses: actions-rs/clippy-check@v1 42 | with: 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | args: 
--features=polars --features=openmp -- --no-deps 45 | format_check: 46 | name: Run Rustfmt 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v2 50 | - uses: actions-rs/toolchain@v1 51 | with: 52 | toolchain: stable 53 | components: rustfmt 54 | - run: cargo fmt -- --check 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # lightgbm3-sys build target 10 | lightgbm3-sys/target 11 | 12 | # These are backup files generated by rustfmt 13 | *.rs.bk 14 | 15 | ## File system 16 | .DS_Store 17 | desktop.ini 18 | 19 | ## Editor 20 | *.swp 21 | *.swo 22 | Session.vim 23 | .cproject 24 | .idea 25 | *.iml 26 | .vscode 27 | .project 28 | .favorites.json 29 | .settings/ 30 | .vs/ 31 | 32 | # Tarpaulin coverage files 33 | cobertura.xml 34 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lightgbm3-sys/lightgbm"] 2 | path = lightgbm3-sys/lightgbm 3 | url = https://github.com/microsoft/LightGBM.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm3" 3 | version = "1.0.8" 4 | edition = "2021" 5 | authors = [ 6 | "Dmitry Mottl ", 7 | "vaaaaanquish <6syun9@gmail.com>", 8 | "paq <89paku@gmail.com>", 9 | "Benjamin Ellis ", 10 | ] 11 | license = "MIT" 12 | repository = "https://github.com/Mottl/lightgbm3-rs" 13 | description = "Rust bindings for 
LightGBM library" 14 | documentation = "https://docs.rs/lightgbm3/" 15 | keywords = ["lightgbm", "machine-learning", "gradient-boosting"] 16 | categories = ["api-bindings", "science"] 17 | readme = "README.md" 18 | exclude = [".gitignore", ".github", ".gitmodules", "examples", "benches", "lightgbm3-sys"] 19 | 20 | [dependencies] 21 | lightgbm3-sys = { path = "lightgbm3-sys", version = "1" } 22 | serde_json = "1" 23 | polars = { version = "0.47", optional = true } 24 | 25 | [features] 26 | default = [] 27 | polars = ["dep:polars"] 28 | openmp = ["lightgbm3-sys/openmp"] 29 | gpu = ["lightgbm3-sys/gpu"] 30 | cuda = ["lightgbm3-sys/cuda"] 31 | 32 | [[bench]] 33 | name = "regression" 34 | path = "benches/regression.rs" 35 | harness = false 36 | 37 | [dev-dependencies] 38 | rand = "0.9" 39 | rand_distr = "0.5" 40 | csv = "1.3" 41 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust 2 | RUN apt update 3 | RUN apt install -y cmake libclang-dev libc++-dev gcc-multilib 4 | WORKDIR /app 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dmitry Mottl 4 | Copyright (c) 2021 vaaaaanquish 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the 
Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lightgbm3 — Rust bindings for LightGBM 2 | [![Crates.io](https://img.shields.io/crates/v/lightgbm3.svg)](https://crates.io/crates/lightgbm3) 3 | [![Docs.rs](https://docs.rs/lightgbm3/badge.svg)](https://docs.rs/lightgbm3/) 4 | [![build](https://github.com/Mottl/lightgbm3-rs/actions/workflows/ci.yml/badge.svg)](https://github.com/Mottl/lightgbm3-rs/actions) 5 | 6 | **`lightgbm3`** is based on the [`lightgbm`](https://github.com/vaaaaanquish/lightgbm-rs) crate 7 | (which is no longer maintained), but it is not backward-compatible with it. 8 | 9 | ## Installation 10 | ```shell 11 | cargo add lightgbm3 12 | ``` 13 | 14 | Since `lightgbm3` compiles `LightGBM` from source, you also need to install development libraries: 15 | 16 | #### for Linux: 17 | ``` 18 | apt install -y cmake clang libclang-dev libc++-dev gcc-multilib 19 | ``` 20 | 21 | #### for Mac: 22 | ``` 23 | brew install cmake 24 | brew install libomp # only required if you compile with "openmp" feature 25 | ``` 26 | 27 | #### for Windows: 28 | 1. Install CMake and VS Build Tools. 29 | 2. Install LLVM and set the `LIBCLANG_PATH` environment variable (e.g. `C:\Program Files\LLVM\bin`). 30 | 31 | For more details, see the official guide:
32 | 33 | - [LightGBM Installation-Guide](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html) 34 | 35 | ## Usage 36 | 37 | ### Training: 38 | ```rust 39 | use lightgbm3::{Dataset, Booster}; 40 | use serde_json::json; 41 | 42 | let features = vec![vec![1.0, 0.1, 0.2], 43 | vec![0.7, 0.4, 0.5], 44 | vec![0.9, 0.8, 0.5], 45 | vec![0.2, 0.2, 0.8], 46 | vec![0.1, 0.7, 1.0]]; 47 | let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 48 | let dataset = Dataset::from_vec_of_vec(features, labels, true).unwrap(); 49 | let params = json!{ 50 | { 51 | "num_iterations": 10, 52 | "objective": "binary", 53 | "metric": "auc", 54 | } 55 | }; 56 | let bst = Booster::train(dataset, &params).unwrap(); 57 | bst.save_file("path/to/model.lgb").unwrap(); 58 | ``` 59 | 60 | ### Inference: 61 | ```rust 62 | use lightgbm3::Booster; 63 | 64 | let bst = Booster::from_file("path/to/model.lgb").unwrap(); 65 | let features = vec![1.0, 2.0, -5.0]; 66 | let n_features = features.len(); 67 | let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0]; 68 | ``` 69 | 70 | Look in the [`./examples/`](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/) folder for more details: 71 | - [binary classification](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/binary_classification.rs) 72 | - [multiclass classification](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/multiclass_classification.rs) 73 | - [regression](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/regression.rs) 74 | 75 | ## Features 76 | **`lightgbm3`** supports the following features: 77 | - `polars` for [polars](https://github.com/pola-rs/polars) support 78 | - `openmp` for [multi-threading](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-threadless-version-not-recommended) support 79 | - `gpu` for [GPU](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-gpu-version) support 80 | - `cuda` for
[CUDA](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html#build-cuda-version) support 81 | 82 | ## Benchmarks 83 | ``` 84 | cargo bench 85 | ``` 86 | 87 | Add `--features=openmp`, `--features=gpu` or `--features=cuda` as appropriate. 88 | 89 | ## Development 90 | ``` 91 | git clone --recursive https://github.com/Mottl/lightgbm3-rs.git 92 | ``` 93 | 94 | ## Thanks 95 | Great respect to [vaaaaanquish](https://github.com/vaaaaanquish) for the LightGBM Rust package, which is unfortunately 96 | no longer supported. 97 | 98 | The implementations and documentation of the following projects served as extensive references. Thanks! 99 | 100 | - [microsoft/LightGBM](https://github.com/microsoft/LightGBM) 101 | - [davechallis/rust-xgboost](https://github.com/davechallis/rust-xgboost) 102 | -------------------------------------------------------------------------------- /benches/regression.rs: -------------------------------------------------------------------------------- 1 | use lightgbm3::{Booster, Dataset}; 2 | use rand_distr::Distribution; 3 | use serde_json::json; 4 | use std::hint::black_box; 5 | use std::time::Instant; 6 | 7 | fn generate_train_data() -> (Vec, Vec) { 8 | let mut rng = rand::rng(); 9 | let uniform = rand_distr::Uniform::::new(-5.0, 5.0).unwrap(); 10 | let normal = rand_distr::Normal::::new(0.0, 0.1).unwrap(); 11 | 12 | let mut x: Vec = vec![]; 13 | let mut y: Vec = Vec::with_capacity(100_000); 14 | 15 | for _ in 0..y.capacity() { 16 | let x1 = uniform.sample(&mut rng); 17 | let x2 = uniform.sample(&mut rng); 18 | let x3 = uniform.sample(&mut rng); 19 | let y_ = x1.sin() + x2.cos() + (x3 / 2.0).powi(2) + normal.sample(&mut rng); 20 | x.push(x1); 21 | x.push(x2); 22 | x.push(x3); 23 | y.push(y_ as f32); 24 | } 25 | (x, y) 26 | } 27 | 28 | fn main() -> std::io::Result<()> { 29 | const NUM_LEAVES: i32 = 5; 30 | 31 | let (x, y) = generate_train_data(); 32 | let train_dataset = Dataset::from_slice(&x, &y, 3, true).unwrap(); 33 | let params = json!
{ 34 | { 35 | "num_iterations": 1000, 36 | "learning_rate": 0.05, 37 | "num_leaves": NUM_LEAVES, 38 | "objective": "mse", 39 | } 40 | }; 41 | let start_time = Instant::now(); 42 | let mut booster = Booster::train(train_dataset, ¶ms).unwrap(); 43 | let train_time = start_time.elapsed().as_nanos() as f64 / 1000.0; 44 | let mut features: Vec = vec![]; 45 | #[cfg(feature = "openmp")] 46 | features.push("openmp".to_string()); 47 | #[cfg(feature = "gpu")] 48 | features.push("gpu".to_string()); 49 | #[cfg(feature = "cuda")] 50 | features.push("cuda".to_string()); 51 | if features.is_empty() { 52 | features.push("none".to_string()); 53 | } 54 | println!("Compiled features: {}", features.join(", ")); 55 | println!( 56 | "Booster train time: {:.3} us/iteration", 57 | train_time / 1000.0 58 | ); 59 | 60 | // let rmse = zip(y.iter(), y_preds.iter()).map(|(&y, &y_pred)| { 61 | // (y as f64 - y_pred).powi(2) 62 | // }).sum::() / y.len() as f64; 63 | // 64 | // println!("train rmse={}", rmse); 65 | // println!("num_iterations={}", booster.num_iterations()); 66 | 67 | // warm up CPU 68 | let _ = booster.predict(&x, 3, true).unwrap(); 69 | 70 | println!("Booster evaluation times:"); 71 | let x = [0.1, 0.5, -1.0]; 72 | for i in 1..=10 { 73 | booster.set_max_iterations(i * 50).unwrap(); 74 | let mut elapsed: u64 = 0; 75 | for _ in 0..100000 { 76 | let start_time = Instant::now(); 77 | let y_preds = booster.predict(&x, 3, true).unwrap(); 78 | let eval_time = start_time.elapsed().as_nanos() as u64; 79 | elapsed += eval_time; 80 | black_box(y_preds); 81 | } 82 | println!( 83 | "trees={:4.}, leaves={}, eval time={:.3} us/sample", 84 | i * 50, 85 | NUM_LEAVES, 86 | elapsed as f64 / 100000_f64 / 1000.0 87 | ) 88 | } 89 | 90 | Ok(()) 91 | } 92 | -------------------------------------------------------------------------------- /examples/binary_classification.rs: -------------------------------------------------------------------------------- 1 | //! 
Binary classification model training and evaluation example 2 | 3 | use lightgbm3::{Booster, Dataset, ImportanceType}; 4 | use serde_json::json; 5 | use std::iter::zip; 6 | 7 | /// Loads a .tsv file and returns a flattened vector of xs, a vector of labels 8 | /// and a number of features 9 | fn load_file(file_path: &str) -> (Vec, Vec, i32) { 10 | let rdr = csv::ReaderBuilder::new() 11 | .has_headers(false) 12 | .delimiter(b'\t') 13 | .from_path(file_path); 14 | let mut labels: Vec = Vec::new(); 15 | let mut features: Vec = Vec::new(); 16 | for result in rdr.unwrap().records() { 17 | let record = result.unwrap(); 18 | let mut record = record.into_iter(); // first field is the label, the rest are features 19 | let label = record.next().unwrap().parse::().unwrap(); 20 | labels.push(label); 21 | features.extend(record.map(|x| x.parse::().unwrap())); 22 | } 23 | let n_features = features.len() / labels.len(); // panics if the file has no rows 24 | (features, labels, n_features as i32) 25 | } 26 | 27 | fn main() -> std::io::Result<()> { 28 | let (train_features, train_labels, n_features) = 29 | load_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.train"); 30 | let (test_features, test_labels, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.test"); 32 | assert_eq!(n_features, n_features_test); 33 | let train_dataset = 34 | Dataset::from_slice(&train_features, &train_labels, n_features, true).unwrap(); 35 | 36 | let params = json!
{ 37 | { 38 | "num_iterations": 100, 39 | "objective": "binary", 40 | "metric": "auc" 41 | } 42 | }; 43 | // Train a model 44 | let booster = Booster::train(train_dataset, &params).unwrap(); 45 | // Predict probabilities 46 | let probas = booster.predict(&test_features, n_features, true).unwrap(); 47 | // Calculate accuracy 48 | let mut correct = 0; 49 | for (&label, &proba) in zip(&test_labels, &probas) { 50 | if (label == 1_f32 && proba > 0.5_f64) || (label == 0_f32 && proba <= 0.5_f64) { 51 | correct += 1; 52 | } 53 | println!("label={label}, proba={proba:.3}"); 54 | } 55 | println!("Accuracy: {} / {}\n", correct, probas.len()); 56 | 57 | println!("Feature importance:"); 58 | let feature_name = booster.feature_name().unwrap(); 59 | let feature_importance = booster.feature_importance(ImportanceType::Gain).unwrap(); 60 | for (feature, importance) in zip(&feature_name, &feature_importance) { 61 | println!("{}: {}", feature, importance); 62 | } 63 | Ok(()) 64 | } 65 | -------------------------------------------------------------------------------- /examples/multiclass_classification.rs: -------------------------------------------------------------------------------- 1 | //!
Multiclass classification model training and evaluation example 2 | 3 | use lightgbm3::{argmax, Booster, Dataset}; 4 | use serde_json::json; 5 | use std::iter::zip; 6 | 7 | /// Loads a .tsv file and returns a flattened vector of xs, a vector of labels 8 | /// and a number of features 9 | fn load_file(file_path: &str) -> (Vec, Vec, i32) { 10 | let rdr = csv::ReaderBuilder::new() 11 | .has_headers(false) 12 | .delimiter(b'\t') 13 | .from_path(file_path); 14 | let mut labels: Vec = Vec::new(); 15 | let mut features: Vec = Vec::new(); 16 | for result in rdr.unwrap().records() { 17 | let record = result.unwrap(); 18 | let mut record = record.into_iter(); // first field is the label, the rest are features 19 | let label = record.next().unwrap().parse::().unwrap(); 20 | labels.push(label); 21 | features.extend(record.map(|x| x.parse::().unwrap())); 22 | } 23 | let n_features = features.len() / labels.len(); // panics if the file has no rows 24 | (features, labels, n_features as i32) 25 | } 26 | 27 | fn main() -> std::io::Result<()> { 28 | let (train_features, train_labels, n_features) = 29 | load_file("lightgbm3-sys/lightgbm/examples/multiclass_classification/multiclass.train"); 30 | let (test_features, test_labels, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/multiclass_classification/multiclass.test"); 32 | assert_eq!(n_features, n_features_test); 33 | let train_dataset = 34 | Dataset::from_slice(&train_features, &train_labels, n_features, true).unwrap(); 35 | 36 | let params = json!
{ 37 | { 38 | "num_iterations": 100, 39 | "objective": "multiclass", 40 | "metric": "multi_logloss", 41 | "num_class": 5, 42 | } 43 | }; 44 | 45 | // Train a model 46 | let booster = Booster::train(train_dataset, &params).unwrap(); 47 | // Predict probabilities for each class 48 | let probas = booster.predict(&test_features, n_features, true).unwrap(); 49 | // Calculate accuracy 50 | let mut correct = 0; 51 | for (&label, proba) in zip(&test_labels, probas.chunks(booster.num_classes() as usize)) { 52 | let argmax_pred = argmax(proba); 53 | if label == argmax_pred as f32 { 54 | correct += 1; 55 | } 56 | println!("true={label}, pred={argmax_pred}, probas={proba:.3?}"); 57 | } 58 | println!("Accuracy: {} / {}", correct, test_labels.len()); 59 | Ok(()) 60 | } 61 | -------------------------------------------------------------------------------- /examples/regression.rs: -------------------------------------------------------------------------------- 1 | //! MSE regression model training and evaluation example 2 | 3 | use lightgbm3::{Booster, Dataset}; 4 | use serde_json::json; 5 | use std::iter::zip; 6 | 7 | /// Loads a .tsv file and returns a flattened vector of xs, a vector of ys 8 | /// and a number of features 9 | fn load_file(file_path: &str) -> (Vec, Vec, i32) { 10 | let rdr = csv::ReaderBuilder::new() 11 | .has_headers(false) 12 | .delimiter(b'\t') 13 | .from_path(file_path); 14 | let mut ys: Vec = Vec::new(); 15 | let mut xs: Vec = Vec::new(); 16 | for result in rdr.unwrap().records() { 17 | let record = result.unwrap(); 18 | let mut record = record.into_iter(); 19 | let y = record.next().unwrap().parse::().unwrap(); 20 | ys.push(y); 21 | xs.extend(record.map(|x| x.parse::().unwrap())); 22 | } 23 | let n_features = xs.len() / ys.len(); 24 | (xs, ys, n_features as i32) 25 | } 26 | 27 | fn main() -> std::io::Result<()> { 28 | let (train_xs, train_ys, n_features) =
load_file("lightgbm3-sys/lightgbm/examples/regression/regression.train"); 30 | let (test_xs, test_ys, n_features_test) = 31 | load_file("lightgbm3-sys/lightgbm/examples/regression/regression.test"); 32 | assert_eq!(n_features, n_features_test); 33 | 34 | let train_dataset = Dataset::from_slice(&train_xs, &train_ys, n_features, true).unwrap(); 35 | 36 | let params = json! { 37 | { 38 | "num_iterations": 100, 39 | "objective": "regression", 40 | "metric": "l2" 41 | } 42 | }; 43 | 44 | // Train a model 45 | let booster = Booster::train(train_dataset, &params).unwrap(); 46 | // Predict target values 47 | let y_pred = booster.predict(&test_xs, n_features, true).unwrap(); 48 | // Calculate regression metrics 49 | let mean = test_ys.iter().sum::() / test_ys.len() as f32; 50 | let var = test_ys.iter().map(|&y| (y - mean).powi(2)).sum::() / test_ys.len() as f32; 51 | let var_model = zip(&test_ys, &y_pred) 52 | .map(|(&y, &y_pred)| (y - y_pred as f32).powi(2)) 53 | .sum::() 54 | / test_ys.len() as f32; 55 | let r2 = 1.0f32 - var_model / var; 56 | println!("test mse = {var_model:.3}"); 57 | println!("test r^2 = {r2:.3}"); 58 | Ok(()) 59 | } 60 | -------------------------------------------------------------------------------- /lightgbm3-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm3-sys" 3 | version = "1.0.8" 4 | edition = "2021" 5 | authors = ["Dmitry Mottl ", "vaaaaanquish <6syun9@gmail.com>"] 6 | build = "build.rs" 7 | license = "MIT" 8 | repository = "https://github.com/Mottl/lightgbm3-rs" 9 | description = "Low-level Rust bindings for LightGBM library" 10 | categories = ["external-ffi-bindings"] 11 | readme = "README.md" 12 | exclude = ["README.md", ".gitlab-ci.yml", ".hgeol", ".gitignore", ".appveyor.yml", ".coveralls.yml", ".travis.yml", ".github", ".gitmodules", ".nuget", "**/*.md", "lightgbm/compute/doc", "lightgbm/compute/example", "lightgbm/compute/index.html",
"lightgbm/compute/perf", "lightgbm/compute/test", "lightgbm/eigen/debug", "lightgbm/eigen/demos", "lightgbm/eigen/doc", "lightgbm/eigen/failtest", "lightgbm/eigen/test", "lightgbm/examples", "lightgbm/external_libs/fast_double_parser/benchmarks", "lightgbm/external_libs/fmt/doc", "lightgbm/external_libs/fmt/test"] 13 | 14 | [dependencies] 15 | libc = "0.2" 16 | 17 | [build-dependencies] 18 | cmake = "0.1" 19 | bindgen = "0.71" 20 | doxygen-rs = "0.4" 21 | 22 | [features] 23 | openmp = [] 24 | gpu = [] 25 | cuda = [] 26 | -------------------------------------------------------------------------------- /lightgbm3-sys/README.md: -------------------------------------------------------------------------------- 1 | # lightgbm3-sys 2 | 3 | FFI bindings to [LightGBM](https://github.com/microsoft/LightGBM), generated 4 | at compile time with [bindgen](https://github.com/rust-lang/rust-bindgen). 5 | -------------------------------------------------------------------------------- /lightgbm3-sys/build.rs: -------------------------------------------------------------------------------- 1 | use cmake::Config; 2 | use std::{ 3 | env, 4 | path::{Path, PathBuf}, 5 | process::Command, 6 | }; 7 | 8 | #[derive(Debug)] 9 | struct DoxygenCallback; 10 | 11 | impl bindgen::callbacks::ParseCallbacks for DoxygenCallback { 12 | fn process_comment(&self, comment: &str) -> Option { 13 | Some(doxygen_rs::transform(comment)) 14 | } 15 | } 16 | 17 | fn main() { 18 | let target = env::var("TARGET").unwrap(); 19 | let out_dir = env::var("OUT_DIR").unwrap(); 20 | let lgbm_root = Path::new(&out_dir).join("lightgbm"); 21 | 22 | // copy source code 23 | if !lgbm_root.exists() { 24 | let status = if target.contains("windows") { 25 | Command::new("cmd") 26 | .args(&[ 27 | "/C", 28 | "echo D | xcopy /S /Y lightgbm", 29 | lgbm_root.to_str().unwrap(), 30 | ]) 31 | .status() 32 | } else { 33 | Command::new("cp") 34 | .args(&["-r", "lightgbm", lgbm_root.to_str().unwrap()]) 35 | .status() 36 | }; 37 | if let 
Some(err) = status.err() { 38 | panic!( 39 | "Failed to copy ./lightgbm to {}: {}", 40 | lgbm_root.display(), 41 | err 42 | ); 43 | } 44 | } 45 | 46 | // CMake 47 | let mut cfg = Config::new(&lgbm_root); 48 | let cfg = cfg 49 | .profile("Release") 50 | .cxxflag("-std=c++14") 51 | .define("BUILD_STATIC_LIB", "ON"); 52 | #[cfg(not(feature = "openmp"))] 53 | let cfg = cfg.define("USE_OPENMP", "OFF"); 54 | #[cfg(feature = "gpu")] 55 | let cfg = cfg.define("USE_GPU", "1"); 56 | #[cfg(feature = "cuda")] 57 | let cfg = cfg.define("USE_CUDA", "1"); 58 | let dst = cfg.build(); 59 | 60 | // bindgen build 61 | let mut clang_args = vec!["-x", "c++", "-std=c++14"]; 62 | if target.contains("apple") { 63 | clang_args.push("-mmacosx-version-min=10.12"); 64 | } 65 | let bindings = bindgen::Builder::default() 66 | .header("lightgbm/include/LightGBM/c_api.h") 67 | .allowlist_file("lightgbm/include/LightGBM/c_api.h") 68 | .clang_args(&clang_args) 69 | .clang_arg(format!("-I{}", lgbm_root.join("include").display())) 70 | .parse_callbacks(Box::new(DoxygenCallback)) 71 | .generate() 72 | .expect("Unable to generate bindings"); 73 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 74 | bindings 75 | .write_to_file(out_path.join("bindings.rs")) 76 | .unwrap_or_else(|err| panic!("Couldn't write bindings: {err}")); 77 | // link to appropriate C++ lib 78 | if target.contains("apple") { 79 | println!("cargo:rustc-link-lib=c++"); 80 | } else if target.contains("linux") { 81 | println!("cargo:rustc-link-lib=stdc++"); 82 | } 83 | #[cfg(feature = "openmp")] 84 | { 85 | println!("cargo:rustc-link-args=-fopenmp"); 86 | if target.contains("apple") { 87 | println!("cargo:rustc-link-lib=dylib=omp"); 88 | // Link to libomp 89 | // If it fails to compile in MacOS, try: 90 | // `brew install libomp` 91 | // `brew link --force libomp` 92 | #[cfg(all(target_arch = "x86_64", target_os = "macos"))] 93 | println!("cargo:rustc-link-search=/usr/local/opt/libomp/lib"); 94 | #[cfg(all(target_arch = 
"aarch64", target_os = "macos"))] 95 | println!("cargo:rustc-link-search=/opt/homebrew/opt/libomp/lib"); 96 | } else if target.contains("linux") { 97 | println!("cargo:rustc-link-lib=dylib=gomp"); 98 | } 99 | } 100 | println!("cargo:rustc-link-search={}", out_path.join("lib").display()); 101 | println!("cargo:rustc-link-search=native={}", dst.display()); 102 | if target.contains("windows") { 103 | println!("cargo:rustc-link-lib=static=lib_lightgbm"); 104 | } else { 105 | println!("cargo:rustc-link-lib=static=_lightgbm"); 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /lightgbm3-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 2 | -------------------------------------------------------------------------------- /src/booster.rs: -------------------------------------------------------------------------------- 1 | //! LightGBM booster 2 | 3 | use serde_json::Value; 4 | use std::os::raw::{c_char, c_longlong, c_void}; 5 | use std::{convert::TryInto, ffi::CString}; 6 | 7 | use crate::{dataset::DType, Dataset, Error, Result}; 8 | use lightgbm3_sys::BoosterHandle; 9 | 10 | /// Core model in LightGBM, containing functions for training, evaluating and predicting. 
11 | pub struct Booster { 12 | handle: BoosterHandle, 13 | n_features: i32, 14 | n_iterations: i32, // number of trees in the booster 15 | max_iterations: i32, // maximum number of trees for prediction 16 | n_classes: i32, 17 | } 18 | 19 | /// Prediction type 20 | /// 21 | /// 22 | enum PredictType { 23 | Normal, 24 | RawScore, 25 | } 26 | 27 | /// Type of feature importance 28 | /// 29 | /// 30 | pub enum ImportanceType { 31 | /// Number of times the feature is used in a model 32 | Split, 33 | /// Total gains of splits which use the feature 34 | Gain, 35 | } 36 | 37 | impl Booster { 38 | fn new(handle: BoosterHandle) -> Result { 39 | let mut booster = Booster { 40 | handle, 41 | n_features: 0, 42 | n_iterations: 0, 43 | max_iterations: 0, 44 | n_classes: 0, 45 | }; 46 | booster.n_features = booster.inner_num_features()?; 47 | booster.n_iterations = booster.inner_num_iterations()?; 48 | booster.max_iterations = booster.n_iterations; 49 | booster.n_classes = booster.inner_num_classes()?; 50 | Ok(booster) 51 | } 52 | 53 | /// Load model from file. 54 | pub fn from_file(filename: &str) -> Result { 55 | let filename_str = CString::new(filename).map_err(|e| Error::new(e.to_string()))?; 56 | let mut out_num_iterations = 0; 57 | let mut handle = std::ptr::null_mut(); 58 | lgbm_call!(lightgbm3_sys::LGBM_BoosterCreateFromModelfile( 59 | filename_str.as_ptr(), 60 | &mut out_num_iterations, 61 | &mut handle 62 | ))?; 63 | 64 | Booster::new(handle) 65 | } 66 | 67 | /// Load model from string. 68 | pub fn from_string(model_description: &str) -> Result { 69 | let cstring = CString::new(model_description).map_err(|e| Error::new(e.to_string()))?; 70 | let mut out_num_iterations = 0; 71 | let mut handle = std::ptr::null_mut(); 72 | lgbm_call!(lightgbm3_sys::LGBM_BoosterLoadModelFromString( 73 | cstring.as_ptr(), 74 | &mut out_num_iterations, 75 | &mut handle 76 | ))?; 77 | 78 | Booster::new(handle) 79 | } 80 | 81 | /// Save model to file.
82 | pub fn save_file(&self, filename: &str) -> Result<()> { 83 | let filename_str = CString::new(filename).map_err(|e| Error::new(e.to_string()))?; 84 | lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModel( 85 | self.handle, 86 | 0_i32, 87 | -1_i32, 88 | 0_i32, 89 | filename_str.as_ptr(), 90 | ))?; 91 | Ok(()) 92 | } 93 | 94 | /// Save model to string. This returns the same content that `save_file` writes into a file. 95 | pub fn save_string(&self) -> Result { 96 | // get necessary buffer size 97 | 98 | let mut out_size = 0_i64; 99 | lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModelToString( 100 | self.handle, 101 | 0_i32, 102 | -1_i32, 103 | 0_i32, 104 | 0, 105 | &mut out_size, 106 | std::ptr::null_mut(), 107 | ))?; 108 | 109 | // write data to buffer and convert 110 | let buffer_len: usize = out_size.try_into().map_err(|_| Error::new("size negative"))?; 111 | let mut buffer = vec![0u8; buffer_len]; 112 | lgbm_call!(lightgbm3_sys::LGBM_BoosterSaveModelToString( 113 | self.handle, 114 | 0_i32, 115 | -1_i32, 116 | 0_i32, 117 | buffer.len() as c_longlong, 118 | &mut out_size, 119 | buffer.as_mut_ptr() as *mut c_char 120 | ))?; 121 | 122 | if buffer.pop() != Some(0) { 123 | // this should never happen, unless lightgbm has a bug 124 | panic!("write out of bounds happened in lightgbm call"); 125 | } 126 | 127 | let cstring = CString::new(buffer).map_err(|e| Error::new(e.to_string()))?; 128 | cstring 129 | .into_string() 130 | .map_err(|_| Error::new("can't convert model string to unicode")) 131 | } 132 | 133 | /// Get the number of classes. 134 | pub fn num_classes(&self) -> i32 { 135 | self.n_classes 136 | } 137 | 138 | /// Get the number of features. 139 | pub fn num_features(&self) -> i32 { 140 | self.n_features 141 | } 142 | 143 | /// Get the number of iterations in the booster. 144 | pub fn num_iterations(&self) -> i32 { 145 | self.n_iterations 146 | } 147 | 148 | /// Get the maximum number of iterations used for prediction.
149 | pub fn max_iterations(&self) -> i32 { 150 | self.max_iterations 151 | } 152 | 153 | /// Sets the maximum number of iterations (trees) used for prediction. 154 | pub fn set_max_iterations(&mut self, max_iterations: i32) -> Result<()> { 155 | if max_iterations > self.n_iterations { 156 | return Err(Error::new(format!( 157 | "max_iterations for prediction ({max_iterations}) \ 158 | should not exceed the number of trees in the booster ({})", 159 | self.n_iterations 160 | ))); 161 | } 162 | self.max_iterations = max_iterations; 163 | Ok(()) 164 | } 165 | 166 | /// Trains a new model using `dataset` and `parameters`. 167 | /// 168 | /// Runs `num_iterations` rounds (default 100); stops early if no split is possible. 169 | /// 170 | /// Example 171 | /// ``` 172 | /// extern crate serde_json; 173 | /// use lightgbm3::{Dataset, Booster}; 174 | /// use serde_json::json; 175 | /// 176 | /// let xs = vec![vec![1.0, 0.1, 0.2, 0.1], 177 | /// vec![0.7, 0.4, 0.5, 0.1], 178 | /// vec![0.9, 0.8, 0.5, 0.1], 179 | /// vec![0.2, 0.2, 0.8, 0.7], 180 | /// vec![0.1, 0.7, 1.0, 0.9]]; 181 | /// let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 182 | /// let dataset = Dataset::from_vec_of_vec(xs, labels, true).unwrap(); 183 | /// let params = json!{ 184 | /// { 185 | /// "num_iterations": 3, 186 | /// "objective": "binary", 187 | /// "metric": "auc" 188 | /// } 189 | /// }; 190 | /// let bst = Booster::train(dataset, &params).unwrap(); 191 | /// ``` 192 | /// 193 | /// Full set of parameters can be found on the official LightGBM docs: 194 | /// 195 | pub fn train(dataset: Dataset, parameters: &Value) -> Result { 196 | let num_iterations: i64 = parameters["num_iterations"].as_i64().unwrap_or(100); 197 | 198 | // exchange params {"x": "y", "z": 1} => "x=y z=1" 199 | let params_string = parameters 200 | .as_object() 201 | .ok_or_else(|| Error::new("parameters should be a JSON object"))? 202 | .iter() 203 | .map(|(k, v)| format!("{}={}", k, v)) 204 | .collect::>() 205 | .join(" "); 206 | let params_cstring = CString::new(params_string).map_err(|e| Error::new(e.to_string()))?; 207 | 208 | let mut handle = std::ptr::null_mut(); 209 | lgbm_call!(lightgbm3_sys::LGBM_BoosterCreate(
210 | dataset.handle, 211 | params_cstring.as_ptr(), 212 | &mut handle 213 | ))?; 214 | 215 | let mut is_finished: i32 = 0; 216 | for _ in 0..num_iterations { 217 | lgbm_call!(lightgbm3_sys::LGBM_BoosterUpdateOneIter( 218 | handle, 219 | &mut is_finished 220 | ))?; 221 | if is_finished == 1 { 222 | break; // LightGBM reports that no further splits can be made 223 | } 224 | } 225 | Booster::new(handle) 226 | } 227 | 228 | fn real_predict( 229 | &self, 230 | flat_x: &[T], 231 | n_features: i32, 232 | is_row_major: bool, 233 | predict_type: PredictType, 234 | parameters: Option<&str>, 235 | ) -> Result> { 236 | if self.n_features <= 0 { 237 | return Err(Error::new("n_features should be greater than 0")); 238 | } 239 | if self.n_iterations <= 0 { 240 | return Err(Error::new("n_iterations should be greater than 0")); 241 | } 242 | if n_features != self.n_features { 243 | return Err(Error::new( 244 | format!("Number of features in data ({}) doesn't match the number of features in booster ({})", 245 | n_features, 246 | self.n_features) 247 | )); 248 | } 249 | if flat_x.len() % n_features as usize != 0 { 250 | return Err(Error::new(format!( 251 | "Invalid length of data: data.len()={}, n_features={}", 252 | flat_x.len(), 253 | n_features 254 | ))); 255 | } 256 | let n_rows = flat_x.len() / n_features as usize; 257 | let n_rows_i32 = i32::try_from(n_rows).map_err(|_| Error::new("Too many rows in data"))?; 258 | let params_cstring = 259 | CString::new(parameters.unwrap_or("")).map_err(|e| Error::new(e.to_string()))?; 260 | let mut out_length: c_longlong = 0; 261 | let mut out_result: Vec = vec![Default::default(); n_rows * self.n_classes as usize]; 262 | lgbm_call!(lightgbm3_sys::LGBM_BoosterPredictForMat( 263 | self.handle, 264 | flat_x.as_ptr() as *const c_void, 265 | T::get_c_api_dtype(), 266 | n_rows_i32, 267 | n_features, 268 | if is_row_major { 1_i32 } else { 0_i32 }, // is_row_major 269 | predict_type.into(), // predict_type 270 | 0_i32, // start_iteration 271 | self.max_iterations, // num_iteration, <= 0 means no limit 272 | params_cstring.as_ptr(), 273 | &mut out_length, 274 | out_result.as_mut_ptr() 275 | ))?; 276 | 277 | Ok(out_result) 278 | } 279 | 280 | ///
Get predictions given `&[f32]` or `&[f64]` slice of features. The resulting vector 281 | /// will have the size of `n_rows` by `n_classes`. 282 | pub fn predict( 283 | &self, 284 | flat_x: &[T], 285 | n_features: i32, 286 | is_row_major: bool, 287 | ) -> Result> { 288 | self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal, None) 289 | } 290 | 291 | /// Get predictions given `&[f32]` or `&[f64]` slice of features. The resulting vector 292 | /// will have the size of `n_rows` by `n_classes`. 293 | /// 294 | /// Example: 295 | /// ```compile_fail 296 | /// use serde_json::json; 297 | /// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap(); 298 | /// ``` 299 | pub fn predict_with_params( 300 | &self, 301 | flat_x: &[T], 302 | n_features: i32, 303 | is_row_major: bool, 304 | params: &str, 305 | ) -> Result> { 306 | self.real_predict( 307 | flat_x, 308 | n_features, 309 | is_row_major, 310 | PredictType::Normal, 311 | Some(params), 312 | ) 313 | } 314 | 315 | /// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector 316 | /// will have the size of `n_rows` by `n_classes`. 317 | pub fn raw_scores( 318 | &self, 319 | flat_x: &[T], 320 | n_features: i32, 321 | is_row_major: bool, 322 | ) -> Result> { 323 | self.real_predict( 324 | flat_x, 325 | n_features, 326 | is_row_major, 327 | PredictType::RawScore, 328 | None, 329 | ) 330 | } 331 | 332 | /// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector 333 | /// will have the size of `n_rows` by `n_classes`. 
334 | /// 335 | /// Example: 336 | /// ```compile_fail 337 | /// use serde_json::json; 338 | /// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap(); 339 | /// ``` 340 | pub fn raw_scores_with_params( 341 | &self, 342 | flat_x: &[T], 343 | n_features: i32, 344 | is_row_major: bool, 345 | parameters: &str, 346 | ) -> Result> { 347 | self.real_predict( 348 | flat_x, 349 | n_features, 350 | is_row_major, 351 | PredictType::RawScore, 352 | Some(parameters), 353 | ) 354 | } 355 | 356 | /// Predicts results for the given `x` and returns a vector or vectors (inner vectors will 357 | /// contain probabilities of classes per row). 358 | /// For regression the resulting inner vectors will have single element, so consider using 359 | /// predict method instead. 360 | /// 361 | /// Input data example 362 | /// ``` 363 | /// let data = vec![vec![1.0, 0.1], 364 | /// vec![0.7, 0.4], 365 | /// vec![0.1, 0.7], 366 | /// vec![0.2, 0.5]]; 367 | /// ``` 368 | /// 369 | /// Output data example for 3 classes: 370 | /// ``` 371 | /// let output = vec![vec![0.1, 0.8, 0.1], 372 | /// vec![0.7, 0.2, 0.1], 373 | /// vec![0.5, 0.4, 0.1], 374 | /// vec![0.2, 0.2, 0.6], 375 | /// ]; 376 | /// ``` 377 | pub fn predict_from_vec_of_vec( 378 | &self, 379 | x: Vec>, 380 | is_row_major: bool, 381 | ) -> Result>> { 382 | if x.is_empty() || x[0].is_empty() { 383 | return Err(Error::new("x is empty")); 384 | } 385 | let n_features = match is_row_major { 386 | true => x[0].len() as i32, 387 | false => x.len() as i32, 388 | }; 389 | let flat_x = x.into_iter().flatten().collect::>(); 390 | let pred_y = self.predict(&flat_x, n_features, is_row_major)?; 391 | 392 | Ok(pred_y 393 | .chunks(self.n_classes as usize) 394 | .map(|x| x.to_vec()) 395 | .collect()) 396 | } 397 | 398 | /// Get the number of classes. 
399 | fn inner_num_classes(&self) -> Result { 400 | let mut num_classes = 0; 401 | lgbm_call!(lightgbm3_sys::LGBM_BoosterGetNumClasses( 402 | self.handle, 403 | &mut num_classes 404 | ))?; 405 | Ok(num_classes) 406 | } 407 | 408 | /// Get the number of features. 409 | fn inner_num_features(&self) -> Result { 410 | let mut num_features = 0; 411 | lgbm_call!(lightgbm3_sys::LGBM_BoosterGetNumFeature( 412 | self.handle, 413 | &mut num_features 414 | ))?; 415 | Ok(num_features) 416 | } 417 | 418 | /// Get index of the current boosting iteration. 419 | fn inner_num_iterations(&self) -> Result { 420 | let mut cur_iteration: i32 = 0; 421 | lgbm_call!(lightgbm3_sys::LGBM_BoosterGetCurrentIteration( 422 | self.handle, 423 | &mut cur_iteration 424 | ))?; 425 | Ok(cur_iteration + 1) 426 | } 427 | 428 | /// Gets features names. 429 | pub fn feature_name(&self) -> Result> { 430 | let num_feature = self.inner_num_features()?; 431 | let feature_name_length = 64; 432 | let mut num_feature_names = 0; 433 | let mut out_buffer_len = 0; 434 | let out_strs = (0..num_feature) 435 | .map(|_| { 436 | CString::new(" ".repeat(feature_name_length)) 437 | .unwrap() 438 | .into_raw() 439 | }) 440 | .collect::>(); 441 | lgbm_call!(lightgbm3_sys::LGBM_BoosterGetFeatureNames( 442 | self.handle, 443 | num_feature, 444 | &mut num_feature_names, 445 | feature_name_length, 446 | &mut out_buffer_len, 447 | out_strs.as_ptr() as *mut *mut c_char 448 | ))?; 449 | let output: Vec = out_strs 450 | .into_iter() 451 | .map(|s| unsafe { CString::from_raw(s).into_string().unwrap() }) 452 | .collect(); 453 | Ok(output) 454 | } 455 | 456 | /// Get feature importance. 
Refer to [`ImportanceType`] 457 | pub fn feature_importance(&self, importance_type: ImportanceType) -> Result> { 458 | let num_feature = self.inner_num_features()?; 459 | let mut out_result: Vec = vec![Default::default(); num_feature as usize]; 460 | lgbm_call!(lightgbm3_sys::LGBM_BoosterFeatureImportance( 461 | self.handle, 462 | 0_i32, 463 | importance_type.into(), 464 | out_result.as_mut_ptr() 465 | ))?; 466 | Ok(out_result) 467 | } 468 | } 469 | 470 | impl Drop for Booster { 471 | fn drop(&mut self) { 472 | lgbm_call!(lightgbm3_sys::LGBM_BoosterFree(self.handle)).unwrap(); 473 | } 474 | } 475 | 476 | impl From for i32 { 477 | fn from(value: ImportanceType) -> Self { 478 | match value { 479 | ImportanceType::Split => lightgbm3_sys::C_API_FEATURE_IMPORTANCE_SPLIT as i32, 480 | ImportanceType::Gain => lightgbm3_sys::C_API_FEATURE_IMPORTANCE_GAIN as i32, 481 | } 482 | } 483 | } 484 | 485 | impl From for i32 { 486 | fn from(value: PredictType) -> Self { 487 | match value { 488 | PredictType::Normal => lightgbm3_sys::C_API_PREDICT_NORMAL as i32, 489 | PredictType::RawScore => lightgbm3_sys::C_API_PREDICT_RAW_SCORE as i32, 490 | } 491 | } 492 | } 493 | 494 | #[cfg(test)] 495 | mod tests { 496 | use super::*; 497 | use serde_json::json; 498 | use std::{fs, path::Path}; 499 | const TMP_FOLDER: &str = "./target/tmp"; 500 | 501 | fn _read_train_file() -> Result { 502 | Dataset::from_file("lightgbm3-sys/lightgbm/examples/binary_classification/binary.train") 503 | } 504 | 505 | fn _train_booster(params: &Value) -> Booster { 506 | let dataset = _read_train_file().unwrap(); 507 | Booster::train(dataset, params).unwrap() 508 | } 509 | 510 | fn _default_params() -> Value { 511 | let params = json! { 512 | { 513 | "num_iterations": 1, 514 | "objective": "binary", 515 | "metric": "auc", 516 | "data_random_seed": 0 517 | } 518 | }; 519 | params 520 | } 521 | 522 | #[test] 523 | fn predict_from_vec_of_vec() { 524 | let params = json! 
{ 525 | { 526 | "num_iterations": 10, 527 | "objective": "binary", 528 | "metric": "auc", 529 | "data_random_seed": 0 530 | } 531 | }; 532 | let bst = _train_booster(¶ms); 533 | let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]]; 534 | let result = bst.predict_from_vec_of_vec(feature, true).unwrap(); 535 | let mut normalized_result = Vec::new(); 536 | for r in &result { 537 | normalized_result.push(if r[0] > 0.5 { 1 } else { 0 }); 538 | } 539 | assert_eq!(normalized_result, vec![0, 0, 1]); 540 | } 541 | 542 | #[test] 543 | fn predict_with_params() { 544 | let params = json! { 545 | { 546 | "num_iterations": 10, 547 | "objective": "binary", 548 | "metric": "auc", 549 | "data_random_seed": 0 550 | } 551 | }; 552 | let bst = _train_booster(¶ms); 553 | // let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]]; 554 | let mut feature = [0.0; 28 * 3]; 555 | for i in 0..28 { 556 | feature[i] = 0.5; 557 | } 558 | for i in 56..feature.len() { 559 | feature[i] = 0.9; 560 | } 561 | 562 | let result = bst 563 | .predict_with_params(&feature, 28, true, "num_threads=1") 564 | .unwrap(); 565 | let mut normalized_result = Vec::new(); 566 | for r in &result { 567 | normalized_result.push(if *r > 0.5 { 1 } else { 0 }); 568 | } 569 | assert_eq!(normalized_result, vec![0, 0, 1]); 570 | } 571 | 572 | #[test] 573 | fn num_feature() { 574 | let params = _default_params(); 575 | let bst = _train_booster(¶ms); 576 | let num_feature = bst.inner_num_features().unwrap(); 577 | assert_eq!(num_feature, 28); 578 | } 579 | 580 | #[test] 581 | fn feature_importance() { 582 | let params = _default_params(); 583 | let bst = _train_booster(¶ms); 584 | let feature_importance = bst.feature_importance(ImportanceType::Gain).unwrap(); 585 | assert_eq!(feature_importance, vec![0.0; 28]); 586 | } 587 | 588 | #[test] 589 | fn feature_name() { 590 | let params = _default_params(); 591 | let bst = _train_booster(¶ms); 592 | let feature_name = bst.feature_name().unwrap(); 593 | let target = 
(0..28).map(|i| format!("Column_{}", i)).collect::>(); 594 | assert_eq!(feature_name, target); 595 | } 596 | 597 | #[test] 598 | fn save_file() { 599 | let params = _default_params(); 600 | let bst = _train_booster(¶ms); 601 | let _ = fs::create_dir(TMP_FOLDER); 602 | let filename = format!("{TMP_FOLDER}/model1.lgb"); 603 | assert!(bst.save_file(&filename).is_ok()); 604 | assert!(Path::new(&filename).exists()); 605 | assert!(Booster::from_file(&filename).is_ok()); 606 | assert!(fs::remove_file(&filename).is_ok()); 607 | } 608 | 609 | #[test] 610 | fn save_string() { 611 | let params = _default_params(); 612 | let bst = _train_booster(¶ms); 613 | let _ = fs::create_dir(TMP_FOLDER); 614 | let filename = format!("{TMP_FOLDER}/model2.lgb"); 615 | assert_eq!(bst.save_file(&filename), Ok(())); 616 | assert!(Path::new(&filename).exists()); 617 | let booster_file_content = fs::read_to_string(&filename).unwrap(); 618 | assert!(fs::remove_file(&filename).is_ok()); 619 | 620 | assert!(!booster_file_content.is_empty()); 621 | assert_eq!(Ok(booster_file_content.clone()), bst.save_string()); 622 | assert!(Booster::from_string(&booster_file_content).is_ok()); 623 | } 624 | } 625 | -------------------------------------------------------------------------------- /src/dataset.rs: -------------------------------------------------------------------------------- 1 | //! LightGBM Dataset used for training 2 | 3 | use lightgbm3_sys::{DatasetHandle, C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64}; 4 | use std::os::raw::c_void; 5 | use std::{self, ffi::CString}; 6 | 7 | #[cfg(feature = "polars")] 8 | use polars::{datatypes::DataType::Float32, prelude::*}; 9 | 10 | use crate::{Error, Result}; 11 | 12 | // a way of implementing sealed traits until they 13 | // come to rust lang. 
more details at: 14 | // https://internals.rust-lang.org/t/sealed-traits/16797 15 | mod private { 16 | pub trait Sealed {} 17 | 18 | impl Sealed for f32 {} 19 | impl Sealed for f64 {} 20 | } 21 | /// LightGBM dtype 22 | /// 23 | /// This trait is sealed as it is not intended 24 | /// to be implemented out of this crate 25 | pub trait DType: private::Sealed { 26 | fn get_c_api_dtype() -> i32; 27 | } 28 | 29 | impl DType for f32 { 30 | fn get_c_api_dtype() -> i32 { 31 | C_API_DTYPE_FLOAT32 as i32 32 | } 33 | } 34 | 35 | impl DType for f64 { 36 | fn get_c_api_dtype() -> i32 { 37 | C_API_DTYPE_FLOAT64 as i32 38 | } 39 | } 40 | 41 | /// LightGBM Dataset 42 | pub struct Dataset { 43 | pub(crate) handle: DatasetHandle, 44 | } 45 | 46 | impl Dataset { 47 | /// Creates a new Dataset object from the LightGBM's DatasetHandle. 48 | fn new(handle: DatasetHandle) -> Self { 49 | Self { handle } 50 | } 51 | 52 | /// Creates a new `Dataset` (x, labels) from flat `&[f64]` slice with a specified number 53 | /// of features (columns). 54 | /// 55 | /// `row_major` should be set to `true` for row-major order and `false` otherwise. 
56 | /// 57 | /// # Example 58 | /// ``` 59 | /// use lightgbm3::Dataset; 60 | /// 61 | /// let x = vec![vec![1.0, 0.1, 0.2], 62 | /// vec![0.7, 0.4, 0.5], 63 | /// vec![0.9, 0.8, 0.5], 64 | /// vec![0.2, 0.2, 0.8], 65 | /// vec![0.1, 0.7, 1.0]]; 66 | /// let flat_x = x.into_iter().flatten().collect::>(); 67 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 68 | /// let n_features = 3; 69 | /// let dataset = Dataset::from_slice(&flat_x, &label, n_features, true).unwrap(); 70 | /// ``` 71 | pub fn from_slice( 72 | flat_x: &[T], 73 | label: &[f32], 74 | n_features: i32, 75 | is_row_major: bool, 76 | ) -> Result { 77 | if n_features <= 0 { 78 | return Err(Error::new("number of features should be greater than 0")); 79 | } 80 | if flat_x.len() % n_features as usize != 0 { 81 | return Err(Error::new( 82 | "number of features doesn't correspond to slice size", 83 | )); 84 | } 85 | let n_rows = flat_x.len() / n_features as usize; 86 | if n_rows == 0 { 87 | return Err(Error::new("slice is empty")); 88 | } else if n_rows > i32::MAX as usize { 89 | return Err(Error::new(format!( 90 | "number of rows should be less than {}. 
Got {}", 91 | i32::MAX, 92 | n_rows 93 | ))); 94 | } 95 | let params = CString::new("").unwrap(); 96 | let label_str = CString::new("label").unwrap(); 97 | let reference = std::ptr::null_mut(); // not used 98 | let mut dataset_handle = std::ptr::null_mut(); // will point to a new DatasetHandle 99 | 100 | lgbm_call!(lightgbm3_sys::LGBM_DatasetCreateFromMat( 101 | flat_x.as_ptr() as *const c_void, 102 | T::get_c_api_dtype(), 103 | n_rows as i32, 104 | n_features, 105 | if is_row_major { 1_i32 } else { 0_i32 }, // is_row_major – 1 for row-major, 0 for column-major 106 | params.as_ptr(), 107 | reference, 108 | &mut dataset_handle 109 | ))?; 110 | 111 | lgbm_call!(lightgbm3_sys::LGBM_DatasetSetField( 112 | dataset_handle, 113 | label_str.as_ptr(), 114 | label.as_ptr() as *const c_void, 115 | n_rows as i32, 116 | C_API_DTYPE_FLOAT32 as i32 // labels should be always float32 117 | ))?; 118 | 119 | Ok(Self::new(dataset_handle)) 120 | } 121 | 122 | /// Creates a new `Dataset` (x, labels) from `Vec>` in row-major order. 123 | /// 124 | /// # Example 125 | /// ``` 126 | /// use lightgbm3::Dataset; 127 | /// 128 | /// let data = vec![vec![1.0, 0.1, 0.2], 129 | /// vec![0.7, 0.4, 0.5], 130 | /// vec![0.9, 0.8, 0.5], 131 | /// vec![0.2, 0.2, 0.8], 132 | /// vec![0.1, 0.7, 1.0]]; 133 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; // should be Vec 134 | /// let dataset = Dataset::from_vec_of_vec(data, label, true).unwrap(); 135 | /// ``` 136 | pub fn from_vec_of_vec( 137 | x: Vec>, 138 | label: Vec, 139 | is_row_major: bool, 140 | ) -> Result { 141 | if x.is_empty() || x[0].is_empty() { 142 | return Err(Error::new("x is empty")); 143 | } 144 | let n_features = match is_row_major { 145 | true => x[0].len() as i32, 146 | false => x.len() as i32, 147 | }; 148 | let x_flat = x.into_iter().flatten().collect::>(); 149 | Self::from_slice(&x_flat, &label, n_features, is_row_major) 150 | } 151 | 152 | /// Create a new `Dataset` from tab-separated-view file. 
153 | /// 154 | /// file is `tsv`. 155 | /// ```text 156 | ///