├── .github ├── FUNDING.yml └── workflows │ └── ci.yml ├── .gitignore ├── .gitmodules ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md ├── examples ├── binary_classification │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── multiclass_classification │ ├── Cargo.toml │ └── src │ │ └── main.rs └── regression │ ├── Cargo.toml │ └── src │ └── main.rs ├── lightgbm-sys ├── .cargo │ └── config ├── Cargo.toml ├── README.md ├── build.rs ├── src │ └── lib.rs └── wrapper.h ├── src ├── booster.rs ├── dataset.rs ├── error.rs └── lib.rs └── test └── test_from_file.input /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [vaaaaanquish] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | name: Rust ${{ matrix.os }} ${{ matrix.rust }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | rust: 12 | - stable 13 | - beta 14 | - nightly 15 | os: [ubuntu-latest, macos-latest] 16 | steps: 17 | - uses: actions/checkout@v2 18 | with: 19 | submodules: recursive 20 | - name: Setup Rust 21 | uses: actions-rs/toolchain@v1 22 | with: 23 | toolchain: ${{ matrix.rust }} 24 | components: clippy 25 | - name: Build for OS X 26 | if: matrix.os == 'macos-latest' 27 | run: | 28 | brew install cmake 29 | brew install libomp 30 | cargo build --all-features 31 | - name: Build for ubuntu 32 | if: matrix.os == 'ubuntu-latest' 33 | run: | 34 | sudo apt-get update 35 | sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib 36 | cargo build --all-features 37 | - name: Run tests 38 | run: cargo test --all-features 39 | continue-on-error: ${{ matrix.rust == 'nightly' }} 40 | - name: Run Clippy 41 | uses: actions-rs/clippy-check@v1 42 | with: 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | args: --all-features 45 | 
format_check: 46 | name: Run Rustfmt 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v2 50 | - uses: actions-rs/toolchain@v1 51 | with: 52 | toolchain: stable 53 | components: rustfmt 54 | - run: cargo fmt -- --check 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | # lightgbm-sys build target 13 | lightgbm-sys/target 14 | 15 | # example 16 | examples/binary_classification/target/ 17 | examples/multiclass_classification/target/ 18 | examples/regression/target/ 19 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lightgbm-sys/lightgbm"] 2 | path = lightgbm-sys/lightgbm 3 | url = https://github.com/vaaaaanquish/LightGBM.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm" 3 | version = "0.2.3" 4 | authors = ["vaaaaanquish <6syun9@gmail.com>"] 5 | license = "MIT" 6 | repository = "https://github.com/vaaaaanquish/LightGBM" 7 | description = "Machine learning using LightGBM" 8 | readme = "README.md" 9 | exclude = [".gitignore", ".gitmodules", "examples", "lightgbm-sys"] 10 | 11 | [dependencies] 12 | lightgbm-sys = { path = "lightgbm-sys", version = "0.3.0" } 13 | libc = "0.2.81" 14 | derive_builder = "0.5.1" 15 | serde_json = "1.0.59" 16 | polars = 
{version = "0.16.0", optional = true} 17 | 18 | 19 | [features] 20 | default = [] 21 | dataframe = ["polars"] 22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust:1.49.0 2 | RUN apt update 3 | RUN apt install -y cmake libclang-dev libc++-dev gcc-multilib 4 | WORKDIR /app 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 vaaaaanquish 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lightgbm-rs 2 | LightGBM Rust binding 3 | 4 | 5 | # Requirements 6 | 7 | You need an environment that can build LightGBM. 8 | 9 | ``` 10 | # linux 11 | apt install -y cmake libclang-dev libc++-dev gcc-multilib 12 | 13 | # OS X 14 | brew install cmake libomp 15 | ``` 16 | 17 | On Windows 18 | 19 | 1. Install CMake and VS Build Tools. 20 | 1. Install LLVM and set the environment variable `LIBCLANG_PATH` to PATH_TO_LLVM_BINARY (example: `C:\Program Files\LLVM\bin`) 21 | 22 | Please see below for details. 23 | 24 | - [LightGBM Installation-Guide](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html) 25 | 26 | # Usage 27 | 28 | Example of training a LightGBM model. 29 | ``` 30 | extern crate serde_json; 31 | use lightgbm::{Dataset, Booster}; 32 | use serde_json::json; 33 | 34 | let data = vec![vec![1.0, 0.1, 0.2, 0.1], 35 | vec![0.7, 0.4, 0.5, 0.1], 36 | vec![0.9, 0.8, 0.5, 0.1], 37 | vec![0.2, 0.2, 0.8, 0.7], 38 | vec![0.1, 0.7, 1.0, 0.9]]; 39 | let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 40 | let dataset = Dataset::from_mat(data, label).unwrap(); 41 | let params = json!{ 42 | { 43 | "num_iterations": 3, 44 | "objective": "binary", 45 | "metric": "auc" 46 | } 47 | }; 48 | let bst = Booster::train(dataset, &params).unwrap(); 49 | ``` 50 | 51 | Please see `./examples` for details.
52 | 53 | |example|link| 54 | |---|---| 55 | |binary classification|[link](https://github.com/vaaaaanquish/lightgbm-rs/blob/main/examples/binary_classification/src/main.rs)| 56 | |multiclass classification|[link](https://github.com/vaaaaanquish/lightgbm-rs/blob/main/examples/multiclass_classification/src/main.rs)| 57 | |regression|[link](https://github.com/vaaaaanquish/lightgbm-rs/blob/main/examples/regression/src/main.rs)| 58 | 59 | 60 | 61 | # Develop 62 | 63 | ``` 64 | git clone --recursive https://github.com/vaaaaanquish/lightgbm-rs 65 | ``` 66 | 67 | ``` 68 | docker build -t lgbmrs . 69 | docker run -it -v $PWD:/app lgbmrs bash 70 | 71 | # cargo build 72 | ``` 73 | 74 | 75 | # Thanks 76 | 77 | This crate draws heavily on the implementation and documentation of the following projects. Thanks. 78 | 79 | - [microsoft/LightGBM](https://github.com/microsoft/LightGBM) 80 | - [davechallis/rust-xgboost](https://github.com/davechallis/rust-xgboost) 81 | -------------------------------------------------------------------------------- /examples/binary_classification/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm-example-binary-classification" 3 | version = "0.1.0" 4 | authors = ["vaaaaanquish <6syun9@gmail.com>"] 5 | publish = false 6 | 7 | [dependencies] 8 | lightgbm = { path = "../../" } 9 | csv = "1.1.5" 10 | itertools = "0.9.0" 11 | serde_json = "1.0.59" 12 | -------------------------------------------------------------------------------- /examples/binary_classification/src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate csv; 2 | extern crate itertools; 3 | extern crate lightgbm; 4 | extern crate serde_json; 5 | 6 | use itertools::zip; 7 | use lightgbm::{Booster, Dataset}; 8 | use serde_json::json; 9 | 10 | fn load_file(file_path: &str) -> (Vec>, Vec) { 11 | let rdr = csv::ReaderBuilder::new() 12 | .has_headers(false) 13 | .delimiter(b'\t') 14 |
.from_path(file_path); 15 | let mut labels: Vec = Vec::new(); 16 | let mut features: Vec> = Vec::new(); 17 | for result in rdr.unwrap().records() { 18 | let record = result.unwrap(); 19 | let label = record[0].parse::().unwrap(); 20 | let feature: Vec = record 21 | .iter() 22 | .map(|x| x.parse::().unwrap()) 23 | .collect::>()[1..] 24 | .to_vec(); 25 | labels.push(label); 26 | features.push(feature); 27 | } 28 | (features, labels) 29 | } 30 | 31 | fn main() -> std::io::Result<()> { 32 | let (train_features, train_labels) = 33 | load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.train"); 34 | let (test_features, test_labels) = 35 | load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.test"); 36 | let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap(); 37 | 38 | let params = json! { 39 | { 40 | "num_iterations": 100, 41 | "objective": "binary", 42 | "metric": "auc" 43 | } 44 | }; 45 | 46 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 47 | let result = booster.predict(test_features).unwrap(); 48 | 49 | let mut tp = 0; 50 | for (label, pred) in zip(&test_labels, &result[0]) { 51 | if (*label == 1_f32 && *pred > 0.5_f64) || (*label == 0_f32 && *pred <= 0.5_f64) { 52 | tp += 1; 53 | } 54 | println!("{}, {}", label, pred) 55 | } 56 | println!("feature importance"); 57 | let feature_name = booster.feature_name().unwrap(); 58 | let feature_importance = booster.feature_importance().unwrap(); 59 | for (feature, importance) in zip(&feature_name, &feature_importance) { 60 | println!("{}: {}", feature, importance); 61 | } 62 | println!("result: {} / {}", &tp, result[0].len()); 63 | Ok(()) 64 | } 65 | -------------------------------------------------------------------------------- /examples/multiclass_classification/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm-example-multiclass-classification" 3 | version = 
"0.1.0" 4 | authors = ["vaaaaanquish <6syun9@gmail.com>"] 5 | publish = false 6 | 7 | [dependencies] 8 | lightgbm = { path = "../../" } 9 | csv = "1.1.5" 10 | itertools = "0.9.0" 11 | serde_json = "1.0.59" 12 | -------------------------------------------------------------------------------- /examples/multiclass_classification/src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate csv; 2 | extern crate itertools; 3 | extern crate lightgbm; 4 | extern crate serde_json; 5 | 6 | use itertools::zip; 7 | use lightgbm::{Booster, Dataset}; 8 | use serde_json::json; 9 | 10 | fn load_file(file_path: &str) -> (Vec>, Vec) { 11 | let rdr = csv::ReaderBuilder::new() 12 | .has_headers(false) 13 | .delimiter(b'\t') 14 | .from_path(file_path); 15 | let mut labels: Vec = Vec::new(); 16 | let mut features: Vec> = Vec::new(); 17 | for result in rdr.unwrap().records() { 18 | let record = result.unwrap(); 19 | let label = record[0].parse::().unwrap(); 20 | let feature: Vec = record 21 | .iter() 22 | .map(|x| x.parse::().unwrap()) 23 | .collect::>()[1..] 
24 | .to_vec(); 25 | labels.push(label); 26 | features.push(feature); 27 | } 28 | (features, labels) 29 | } 30 | 31 | fn argmax(xs: &[T]) -> usize { 32 | if xs.len() == 1 { 33 | 0 34 | } else { 35 | let mut maxval = &xs[0]; 36 | let mut max_ixs: Vec = vec![0]; 37 | for (i, x) in xs.iter().enumerate().skip(1) { 38 | if x > maxval { 39 | maxval = x; 40 | max_ixs = vec![i]; 41 | } else if x == maxval { 42 | max_ixs.push(i); 43 | } 44 | } 45 | max_ixs[0] 46 | } 47 | } 48 | 49 | fn main() -> std::io::Result<()> { 50 | let (train_features, train_labels) = load_file( 51 | "../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.train", 52 | ); 53 | let (test_features, test_labels) = 54 | load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.test"); 55 | let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap(); 56 | 57 | let params = json! { 58 | { 59 | "num_iterations": 100, 60 | "objective": "multiclass", 61 | "metric": "multi_logloss", 62 | "num_class": 5, 63 | } 64 | }; 65 | 66 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 67 | let result = booster.predict(test_features).unwrap(); 68 | 69 | let mut tp = 0; 70 | for (label, pred) in zip(&test_labels, &result) { 71 | let argmax_pred = argmax(&pred); 72 | if *label == argmax_pred as f32 { 73 | tp += 1; 74 | } 75 | println!("{}, {}, {:?}", label, argmax_pred, &pred); 76 | } 77 | println!("{} / {}", &tp, result.len()); 78 | Ok(()) 79 | } 80 | -------------------------------------------------------------------------------- /examples/regression/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm-example-regression" 3 | version = "0.1.0" 4 | authors = ["vaaaaanquish <6syun9@gmail.com>"] 5 | publish = false 6 | 7 | [dependencies] 8 | lightgbm = { path = "../../" } 9 | csv = "1.1.5" 10 | itertools = "0.9.0" 11 | serde_json = "1.0.59" 12 | 
-------------------------------------------------------------------------------- /examples/regression/src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate csv; 2 | extern crate itertools; 3 | extern crate lightgbm; 4 | extern crate serde_json; 5 | 6 | use itertools::zip; 7 | use lightgbm::{Booster, Dataset}; 8 | use serde_json::json; 9 | 10 | fn load_file(file_path: &str) -> (Vec>, Vec) { 11 | let rdr = csv::ReaderBuilder::new() 12 | .has_headers(false) 13 | .delimiter(b'\t') 14 | .from_path(file_path); 15 | let mut labels: Vec = Vec::new(); 16 | let mut features: Vec> = Vec::new(); 17 | for result in rdr.unwrap().records() { 18 | let record = result.unwrap(); 19 | let label = record[0].parse::().unwrap(); 20 | let feature: Vec = record 21 | .iter() 22 | .map(|x| x.parse::().unwrap()) 23 | .collect::>()[1..] 24 | .to_vec(); 25 | labels.push(label); 26 | features.push(feature); 27 | } 28 | (features, labels) 29 | } 30 | 31 | fn main() -> std::io::Result<()> { 32 | let (train_features, train_labels) = 33 | load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.train"); 34 | let (test_features, test_labels) = 35 | load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.test"); 36 | let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap(); 37 | 38 | let params = json! 
{ 39 | { 40 | "num_iterations": 100, 41 | "objective": "regression", 42 | "metric": "l2" 43 | } 44 | }; 45 | 46 | let booster = Booster::train(train_dataset, ¶ms).unwrap(); 47 | let result = booster.predict(test_features).unwrap(); 48 | 49 | let mut tp = 0; 50 | for (label, pred) in zip(&test_labels, &result[0]) { 51 | if (*label == 1_f32 && *pred > 0.5_f64) || (*label == 0_f32 && *pred <= 0.5_f64) { 52 | tp += 1; 53 | } 54 | println!("{}, {}", label, pred) 55 | } 56 | println!("{} / {}", &tp, result[0].len()); 57 | Ok(()) 58 | } 59 | -------------------------------------------------------------------------------- /lightgbm-sys/.cargo/config: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-C", "link-args=-fopenmp"] 3 | -------------------------------------------------------------------------------- /lightgbm-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightgbm-sys" 3 | version = "0.3.0" 4 | authors = ["vaaaaanquish <6syun9@gmail.com>"] 5 | build = "build.rs" 6 | license = "MIT" 7 | repository = "https://github.com/vaaaaanquish/LightGBM" 8 | description = "Native bindings to the LightGBM library" 9 | readme = "README.md" 10 | exclude = ["README.md", ".gitlab-ci.yml", ".hgeol", ".gitignore", ".appveyor.yml", ".coveralls.yml", ".travis.yml", ".github", ".gitmodules", ".nuget", "**/*.md", "lightgbm/compute/doc", "lightgbm/compute/example", "lightgbm/compute/index.html", "lightgbm/compute/perf", "lightgbm/compute/test", "lightgbm/eigen/debug", "lightgbm/eigen/demos", "lightgbm/eigen/doc", "lightgbm/eigen/failtest", "lightgbm/eigen/test", "lightgbm/examples", "lightgbm/external_libs/fast_double_parser/benchmarks", "lightgbm/external_libs/fmt/doc", "lightgbm/external_libs/fmt/test"] 11 | 12 | [dependencies] 13 | libc = "0.2.81" 14 | 15 | [build-dependencies] 16 | bindgen = "0.56.0" 17 | cmake = "0.1" 18 | 
-------------------------------------------------------------------------------- /lightgbm-sys/README.md: -------------------------------------------------------------------------------- 1 | # lightgbm-sys 2 | 3 | --- 4 | 5 | FFI bindings to [LightGBM](https://github.com/vaaaaanquish/LightGBM), generated at compile time with [bindgen](https://github.com/rust-lang/rust-bindgen). 6 | -------------------------------------------------------------------------------- /lightgbm-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | extern crate cmake; 3 | 4 | use cmake::Config; 5 | use std::env; 6 | use std::path::{Path, PathBuf}; 7 | use std::process::Command; 8 | 9 | fn main() { 10 | let target = env::var("TARGET").unwrap(); 11 | let out_dir = env::var("OUT_DIR").unwrap(); 12 | let lgbm_root = Path::new(&out_dir).join("lightgbm"); 13 | 14 | // copy source code 15 | if !lgbm_root.exists() { 16 | let status = if target.contains("windows") { 17 | Command::new("cmd") 18 | .args(&[ 19 | "/C", 20 | "echo D | xcopy /S /Y lightgbm", 21 | lgbm_root.to_str().unwrap(), 22 | ]) 23 | .status() 24 | } else { 25 | Command::new("cp") 26 | .args(&["-r", "lightgbm", lgbm_root.to_str().unwrap()]) 27 | .status() 28 | }; 29 | if let Some(err) = status.err() { 30 | panic!( 31 | "Failed to copy ./lightgbm to {}: {}", 32 | lgbm_root.display(), 33 | err 34 | ); 35 | } 36 | } 37 | 38 | // CMake 39 | let dst = Config::new(&lgbm_root) 40 | .profile("Release") 41 | .uses_cxx11() 42 | .define("BUILD_STATIC_LIB", "ON") 43 | .build(); 44 | 45 | // bindgen build 46 | let bindings = bindgen::Builder::default() 47 | .header("wrapper.h") 48 | .clang_args(&["-x", "c++", "-std=c++11"]) 49 | .clang_arg(format!("-I{}", lgbm_root.join("include").display())) 50 | .generate() 51 | .expect("Unable to generate bindings"); 52 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 53 | bindings 54 | 
.write_to_file(out_path.join("bindings.rs")) 55 | .expect("Couldn't write bindings."); 56 | 57 | // link to appropriate C++ lib 58 | if target.contains("apple") { 59 | println!("cargo:rustc-link-lib=c++"); 60 | println!("cargo:rustc-link-lib=dylib=omp"); 61 | } else if target.contains("linux") { 62 | println!("cargo:rustc-link-lib=stdc++"); 63 | println!("cargo:rustc-link-lib=dylib=gomp"); 64 | } 65 | 66 | println!("cargo:rustc-link-search={}", out_path.join("lib").display()); 67 | println!("cargo:rustc-link-search=native={}", dst.display()); 68 | if target.contains("windows") { 69 | println!("cargo:rustc-link-lib=static=lib_lightgbm"); 70 | } else { 71 | println!("cargo:rustc-link-lib=static=_lightgbm"); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /lightgbm-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | #![allow(clippy::redundant_static_lifetimes)] 5 | #![allow(clippy::missing_safety_doc)] 6 | #![allow(clippy::upper_case_acronyms)] 7 | 8 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 9 | -------------------------------------------------------------------------------- /lightgbm-sys/wrapper.h: -------------------------------------------------------------------------------- 1 | #include 2 | -------------------------------------------------------------------------------- /src/booster.rs: -------------------------------------------------------------------------------- 1 | use libc::{c_char, c_double, c_longlong, c_void}; 2 | use std; 3 | use std::ffi::CString; 4 | 5 | use serde_json::Value; 6 | 7 | use lightgbm_sys; 8 | 9 | use crate::{Dataset, Error, Result}; 10 | 11 | /// Core model in LightGBM, containing functions for training, evaluating and predicting. 
12 | pub struct Booster { 13 | handle: lightgbm_sys::BoosterHandle, 14 | } 15 | 16 | impl Booster { 17 | fn new(handle: lightgbm_sys::BoosterHandle) -> Self { 18 | Booster { handle } 19 | } 20 | 21 | /// Init from model file. 22 | pub fn from_file(filename: &str) -> Result { 23 | let filename_str = CString::new(filename).unwrap(); 24 | let mut out_num_iterations = 0; 25 | let mut handle = std::ptr::null_mut(); 26 | lgbm_call!(lightgbm_sys::LGBM_BoosterCreateFromModelfile( 27 | filename_str.as_ptr() as *const c_char, 28 | &mut out_num_iterations, 29 | &mut handle 30 | ))?; 31 | 32 | Ok(Booster::new(handle)) 33 | } 34 | 35 | /// Create a new Booster model with given Dataset and parameters. 36 | /// 37 | /// Example 38 | /// ``` 39 | /// extern crate serde_json; 40 | /// use lightgbm::{Dataset, Booster}; 41 | /// use serde_json::json; 42 | /// 43 | /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], 44 | /// vec![0.7, 0.4, 0.5, 0.1], 45 | /// vec![0.9, 0.8, 0.5, 0.1], 46 | /// vec![0.2, 0.2, 0.8, 0.7], 47 | /// vec![0.1, 0.7, 1.0, 0.9]]; 48 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 49 | /// let dataset = Dataset::from_mat(data, label).unwrap(); 50 | /// let params = json!{ 51 | /// { 52 | /// "num_iterations": 3, 53 | /// "objective": "binary", 54 | /// "metric": "auc" 55 | /// } 56 | /// }; 57 | /// let bst = Booster::train(dataset, ¶ms).unwrap(); 58 | /// ``` 59 | pub fn train(dataset: Dataset, parameter: &Value) -> Result { 60 | // get num_iterations 61 | let num_iterations: i64 = if parameter["num_iterations"].is_null() { 62 | 100 63 | } else { 64 | parameter["num_iterations"].as_i64().unwrap() 65 | }; 66 | 67 | // exchange params {"x": "y", "z": 1} => "x=y z=1" 68 | let params_string = parameter 69 | .as_object() 70 | .unwrap() 71 | .iter() 72 | .map(|(k, v)| format!("{}={}", k, v)) 73 | .collect::>() 74 | .join(" "); 75 | let params_cstring = CString::new(params_string).unwrap(); 76 | 77 | let mut handle = std::ptr::null_mut(); 78 | 
lgbm_call!(lightgbm_sys::LGBM_BoosterCreate( 79 | dataset.handle, 80 | params_cstring.as_ptr() as *const c_char, 81 | &mut handle 82 | ))?; 83 | 84 | let mut is_finished: i32 = 0; 85 | for _ in 1..num_iterations { 86 | lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter( 87 | handle, 88 | &mut is_finished 89 | ))?; 90 | } 91 | Ok(Booster::new(handle)) 92 | } 93 | 94 | /// Predict results for given data. 95 | /// 96 | /// Input data example 97 | /// ``` 98 | /// let data = vec![vec![1.0, 0.1, 0.2], 99 | /// vec![0.7, 0.4, 0.5], 100 | /// vec![0.1, 0.7, 1.0]]; 101 | /// ``` 102 | /// 103 | /// Output data example 104 | /// ``` 105 | /// let output = vec![vec![1.0, 0.109, 0.433]]; 106 | /// ``` 107 | pub fn predict(&self, data: Vec>) -> Result>> { 108 | let data_length = data.len(); 109 | let feature_length = data[0].len(); 110 | let params = CString::new("").unwrap(); 111 | let mut out_length: c_longlong = 0; 112 | let flat_data = data.into_iter().flatten().collect::>(); 113 | 114 | // get num_class 115 | let mut num_class = 0; 116 | lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumClasses( 117 | self.handle, 118 | &mut num_class 119 | ))?; 120 | 121 | let out_result: Vec = vec![Default::default(); data_length * num_class as usize]; 122 | 123 | lgbm_call!(lightgbm_sys::LGBM_BoosterPredictForMat( 124 | self.handle, 125 | flat_data.as_ptr() as *const c_void, 126 | lightgbm_sys::C_API_DTYPE_FLOAT64 as i32, 127 | data_length as i32, 128 | feature_length as i32, 129 | 1_i32, 130 | 0_i32, 131 | 0_i32, 132 | -1_i32, 133 | params.as_ptr() as *const c_char, 134 | &mut out_length, 135 | out_result.as_ptr() as *mut c_double 136 | ))?; 137 | 138 | // reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class 139 | let reshaped_output = if num_class > 1 { 140 | out_result 141 | .chunks(num_class as usize) 142 | .map(|x| x.to_vec()) 143 | .collect() 144 | } else { 145 | vec![out_result] 146 | }; 147 | Ok(reshaped_output) 148 | } 149 | 150 | /// Get Feature Num. 
151 | pub fn num_feature(&self) -> Result { 152 | let mut out_len = 0; 153 | lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumFeature( 154 | self.handle, 155 | &mut out_len 156 | ))?; 157 | Ok(out_len) 158 | } 159 | 160 | /// Get Feature Names. 161 | pub fn feature_name(&self) -> Result> { 162 | let num_feature = self.num_feature()?; 163 | let feature_name_length = 32; 164 | let mut num_feature_names = 0; 165 | let mut out_buffer_len = 0; 166 | let out_strs = (0..num_feature) 167 | .map(|_| { 168 | CString::new(" ".repeat(feature_name_length)) 169 | .unwrap() 170 | .into_raw() as *mut c_char 171 | }) 172 | .collect::>(); 173 | lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames( 174 | self.handle, 175 | feature_name_length as i32, 176 | &mut num_feature_names, 177 | num_feature as u64, 178 | &mut out_buffer_len, 179 | out_strs.as_ptr() as *mut *mut c_char 180 | ))?; 181 | let output: Vec = out_strs 182 | .into_iter() 183 | .map(|s| unsafe { CString::from_raw(s).into_string().unwrap() }) 184 | .collect(); 185 | Ok(output) 186 | } 187 | 188 | // Get Feature Importance 189 | pub fn feature_importance(&self) -> Result> { 190 | let num_feature = self.num_feature()?; 191 | let out_result: Vec = vec![Default::default(); num_feature as usize]; 192 | lgbm_call!(lightgbm_sys::LGBM_BoosterFeatureImportance( 193 | self.handle, 194 | 0_i32, 195 | 0_i32, 196 | out_result.as_ptr() as *mut c_double 197 | ))?; 198 | Ok(out_result) 199 | } 200 | 201 | /// Save model to file. 
202 | pub fn save_file(&self, filename: &str) -> Result<()> { 203 | let filename_str = CString::new(filename).unwrap(); 204 | lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModel( 205 | self.handle, 206 | 0_i32, 207 | -1_i32, 208 | 0_i32, 209 | filename_str.as_ptr() as *const c_char 210 | ))?; 211 | Ok(()) 212 | } 213 | } 214 | 215 | impl Drop for Booster { 216 | fn drop(&mut self) { 217 | lgbm_call!(lightgbm_sys::LGBM_BoosterFree(self.handle)).unwrap(); 218 | } 219 | } 220 | 221 | #[cfg(test)] 222 | mod tests { 223 | use super::*; 224 | use serde_json::json; 225 | use std::fs; 226 | use std::path::Path; 227 | 228 | fn _read_train_file() -> Result { 229 | Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") 230 | } 231 | 232 | fn _train_booster(params: &Value) -> Booster { 233 | let dataset = _read_train_file().unwrap(); 234 | Booster::train(dataset, ¶ms).unwrap() 235 | } 236 | 237 | fn _default_params() -> Value { 238 | let params = json! { 239 | { 240 | "num_iterations": 1, 241 | "objective": "binary", 242 | "metric": "auc", 243 | "data_random_seed": 0 244 | } 245 | }; 246 | params 247 | } 248 | 249 | #[test] 250 | fn predict() { 251 | let params = json! 
{ 252 | { 253 | "num_iterations": 10, 254 | "objective": "binary", 255 | "metric": "auc", 256 | "data_random_seed": 0 257 | } 258 | }; 259 | let bst = _train_booster(¶ms); 260 | let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]]; 261 | let result = bst.predict(feature).unwrap(); 262 | let mut normalized_result = Vec::new(); 263 | for r in &result[0] { 264 | normalized_result.push(if r > &0.5 { 1 } else { 0 }); 265 | } 266 | assert_eq!(normalized_result, vec![0, 0, 1]); 267 | } 268 | 269 | #[test] 270 | fn num_feature() { 271 | let params = _default_params(); 272 | let bst = _train_booster(¶ms); 273 | let num_feature = bst.num_feature().unwrap(); 274 | assert_eq!(num_feature, 28); 275 | } 276 | 277 | #[test] 278 | fn feature_importance() { 279 | let params = _default_params(); 280 | let bst = _train_booster(¶ms); 281 | let feature_importance = bst.feature_importance().unwrap(); 282 | assert_eq!(feature_importance, vec![0.0; 28]); 283 | } 284 | 285 | #[test] 286 | fn feature_name() { 287 | let params = _default_params(); 288 | let bst = _train_booster(¶ms); 289 | let feature_name = bst.feature_name().unwrap(); 290 | let target = (0..28).map(|i| format!("Column_{}", i)).collect::>(); 291 | assert_eq!(feature_name, target); 292 | } 293 | 294 | #[test] 295 | fn save_file() { 296 | let params = _default_params(); 297 | let bst = _train_booster(¶ms); 298 | assert_eq!(bst.save_file(&"./test/test_save_file.output"), Ok(())); 299 | assert!(Path::new("./test/test_save_file.output").exists()); 300 | let _ = fs::remove_file("./test/test_save_file.output"); 301 | } 302 | 303 | #[test] 304 | fn from_file() { 305 | let _ = Booster::from_file(&"./test/test_from_file.input"); 306 | } 307 | } 308 | -------------------------------------------------------------------------------- /src/dataset.rs: -------------------------------------------------------------------------------- 1 | use libc::{c_char, c_void}; 2 | use lightgbm_sys; 3 | use std; 4 | use std::ffi::CString; 5 | 
6 | #[cfg(feature = "dataframe")] 7 | use polars::prelude::*; 8 | 9 | use crate::{Error, Result}; 10 | 11 | /// Dataset used throughout LightGBM for training. 12 | /// 13 | /// # Examples 14 | /// 15 | /// ## from mat 16 | /// 17 | /// ``` 18 | /// use lightgbm::Dataset; 19 | /// 20 | /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], 21 | /// vec![0.7, 0.4, 0.5, 0.1], 22 | /// vec![0.9, 0.8, 0.5, 0.1], 23 | /// vec![0.2, 0.2, 0.8, 0.7], 24 | /// vec![0.1, 0.7, 1.0, 0.9]]; 25 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 26 | /// let dataset = Dataset::from_mat(data, label).unwrap(); 27 | /// ``` 28 | /// 29 | /// ## from file 30 | /// 31 | /// ``` 32 | /// use lightgbm::Dataset; 33 | /// 34 | /// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train").unwrap(); 35 | /// ``` 36 | pub struct Dataset { 37 | pub(crate) handle: lightgbm_sys::DatasetHandle, 38 | } 39 | 40 | #[link(name = "c")] 41 | impl Dataset { 42 | fn new(handle: lightgbm_sys::DatasetHandle) -> Self { 43 | Self { handle } 44 | } 45 | 46 | /// Create a new `Dataset` from dense array in row-major order. 
47 | /// 48 | /// Example 49 | /// ``` 50 | /// use lightgbm::Dataset; 51 | /// 52 | /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], 53 | /// vec![0.7, 0.4, 0.5, 0.1], 54 | /// vec![0.9, 0.8, 0.5, 0.1], 55 | /// vec![0.2, 0.2, 0.8, 0.7], 56 | /// vec![0.1, 0.7, 1.0, 0.9]]; 57 | /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; 58 | /// let dataset = Dataset::from_mat(data, label).unwrap(); 59 | /// ``` 60 | pub fn from_mat(data: Vec>, label: Vec) -> Result { 61 | let data_length = data.len(); 62 | let feature_length = data[0].len(); 63 | let params = CString::new("").unwrap(); 64 | let label_str = CString::new("label").unwrap(); 65 | let reference = std::ptr::null_mut(); // not use 66 | let mut handle = std::ptr::null_mut(); 67 | let flat_data = data.into_iter().flatten().collect::>(); 68 | 69 | lgbm_call!(lightgbm_sys::LGBM_DatasetCreateFromMat( 70 | flat_data.as_ptr() as *const c_void, 71 | lightgbm_sys::C_API_DTYPE_FLOAT64 as i32, 72 | data_length as i32, 73 | feature_length as i32, 74 | 1_i32, 75 | params.as_ptr() as *const c_char, 76 | reference, 77 | &mut handle 78 | ))?; 79 | 80 | lgbm_call!(lightgbm_sys::LGBM_DatasetSetField( 81 | handle, 82 | label_str.as_ptr() as *const c_char, 83 | label.as_ptr() as *const c_void, 84 | data_length as i32, 85 | lightgbm_sys::C_API_DTYPE_FLOAT32 as i32 86 | ))?; 87 | 88 | Ok(Self::new(handle)) 89 | } 90 | 91 | /// Create a new `Dataset` from file. 92 | /// 93 | /// file is `tsv`. 94 | /// ```text 95 | ///