├── .github ├── FUNDING.yml ├── codecov.yml └── workflows │ ├── benching.yml │ ├── checking.yml │ ├── codequality.yml │ ├── docs.yml │ └── testing.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTE.md ├── Cargo.toml ├── LICENSE-APACHE2 ├── LICENSE-MIT ├── README.md ├── algorithms ├── linfa-bayes │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ ├── winequality_bayes.rs │ │ ├── winequality_bernouilli.rs │ │ └── winequality_multinomial.rs │ └── src │ │ ├── base_nb.rs │ │ ├── bernoulli_nb.rs │ │ ├── error.rs │ │ ├── gaussian_nb.rs │ │ ├── hyperparams.rs │ │ ├── lib.rs │ │ └── multinomial_nb.rs ├── linfa-clustering │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ ├── dbscan.rs │ │ ├── gaussian_mixture.rs │ │ └── k_means.rs │ ├── examples │ │ ├── dbscan.rs │ │ ├── kmeans.rs │ │ ├── optics.rs │ │ └── optics_plot.py │ └── src │ │ ├── appx_dbscan │ │ ├── algorithm.rs │ │ ├── cells_grid │ │ │ ├── cell.rs │ │ │ ├── mod.rs │ │ │ └── tests.rs │ │ ├── clustering │ │ │ ├── mod.rs │ │ │ └── tests.rs │ │ ├── counting_tree │ │ │ ├── mod.rs │ │ │ └── tests.rs │ │ ├── hyperparams.rs │ │ ├── mod.rs │ │ └── tests.rs │ │ ├── dbscan │ │ ├── algorithm.rs │ │ ├── hyperparams.rs │ │ └── mod.rs │ │ ├── gaussian_mixture │ │ ├── algorithm.rs │ │ ├── errors.rs │ │ ├── hyperparams.rs │ │ └── mod.rs │ │ ├── k_means │ │ ├── algorithm.rs │ │ ├── errors.rs │ │ ├── hyperparams.rs │ │ ├── init.rs │ │ └── mod.rs │ │ ├── lib.rs │ │ └── optics │ │ ├── algorithm.rs │ │ ├── errors.rs │ │ ├── hyperparams.rs │ │ └── mod.rs ├── linfa-elasticnet │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ ├── elasticnet.rs │ │ ├── elasticnet_cv.rs │ │ └── multitask_elasticnet.rs │ └── src │ │ ├── algorithm.rs │ │ ├── error.rs │ │ ├── hyperparams.rs │ │ └── lib.rs ├── linfa-ensemble │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ └── bagging_iris.rs │ └── src │ │ ├── algorithm.rs │ │ ├── hyperparams.rs │ │ └── lib.rs ├── linfa-ftrl │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── ftrl.rs │ ├── examples │ │ 
└── winequality_ftrl.rs │ └── src │ │ ├── algorithm.rs │ │ ├── error.rs │ │ ├── hyperparams.rs │ │ └── lib.rs ├── linfa-hierarchical │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ └── irisflower.rs │ └── src │ │ ├── error.rs │ │ └── lib.rs ├── linfa-ica │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── fast_ica.rs │ ├── examples │ │ ├── README.md │ │ ├── fast_ica.rs │ │ └── images │ │ │ └── fast_ica.png │ └── src │ │ ├── error.rs │ │ ├── fast_ica.rs │ │ ├── hyperparams.rs │ │ └── lib.rs ├── linfa-kernel │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── inner.rs │ │ ├── lib.rs │ │ └── sparse.rs ├── linfa-linear │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── ols_bench.rs │ ├── examples │ │ ├── diabetes.rs │ │ └── glm.rs │ └── src │ │ ├── error.rs │ │ ├── float.rs │ │ ├── glm │ │ ├── distribution.rs │ │ ├── hyperparams.rs │ │ ├── link.rs │ │ └── mod.rs │ │ ├── isotonic.rs │ │ ├── lib.rs │ │ └── ols.rs ├── linfa-logistic │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ ├── logistic_cv.rs │ │ ├── winequality_logistic.rs │ │ └── winequality_multi_logistic.rs │ └── src │ │ ├── argmin_param.rs │ │ ├── error.rs │ │ ├── float.rs │ │ ├── hyperparams.rs │ │ └── lib.rs ├── linfa-nn │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── nn.rs │ ├── src │ │ ├── balltree.rs │ │ ├── distance.rs │ │ ├── heap_elem.rs │ │ ├── kdtree.rs │ │ ├── lib.rs │ │ └── linear.rs │ └── tests │ │ └── nn.rs ├── linfa-pls │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── pls.rs │ ├── examples │ │ └── pls_regression.rs │ └── src │ │ ├── errors.rs │ │ ├── hyperparams.rs │ │ ├── lib.rs │ │ ├── pls_generic.rs │ │ ├── pls_svd.rs │ │ └── utils.rs ├── linfa-preprocessing │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ ├── linear_scaler_bench.rs │ │ ├── norm_scaler_bench.rs │ │ ├── vectorizer_bench.rs │ │ └── whitening_bench.rs │ ├── examples │ │ ├── count_vectorization.rs │ │ ├── scaling.rs │ │ ├── tfidf_vectorization.rs │ │ └── whitening.rs │ └── src │ │ ├── countgrams │ │ 
├── hyperparams.rs │ │ └── mod.rs │ │ ├── error.rs │ │ ├── helpers.rs │ │ ├── lib.rs │ │ ├── linear_scaling.rs │ │ ├── norm_scaling.rs │ │ ├── tf_idf_vectorization.rs │ │ └── whitening.rs ├── linfa-reduction │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ ├── diffusion_map.rs │ │ ├── gaussian_projection.rs │ │ ├── pca.rs │ │ └── sparse_projection.rs │ └── src │ │ ├── diffusion_map │ │ ├── algorithms.rs │ │ ├── hyperparams.rs │ │ └── mod.rs │ │ ├── error.rs │ │ ├── lib.rs │ │ ├── pca.rs │ │ ├── random_projection │ │ ├── algorithms.rs │ │ ├── common.rs │ │ ├── hyperparams.rs │ │ ├── methods.rs │ │ └── mod.rs │ │ └── utils.rs ├── linfa-svm │ ├── Cargo.toml │ ├── README.md │ ├── examples │ │ ├── noisy_sin_svr.rs │ │ ├── winequality_multi_svm.rs │ │ └── winequality_svm.rs │ └── src │ │ ├── classification.rs │ │ ├── error.rs │ │ ├── hyperparams.rs │ │ ├── lib.rs │ │ ├── permutable_kernel.rs │ │ ├── regression.rs │ │ └── solver_smo.rs ├── linfa-trees │ ├── Cargo.toml │ ├── README.md │ ├── benches │ │ └── decision_tree.rs │ ├── examples │ │ └── decision_tree.rs │ ├── iris-decisiontree.svg │ └── src │ │ ├── decision_trees │ │ ├── algorithm.rs │ │ ├── hyperparams.rs │ │ ├── iter.rs │ │ ├── mod.rs │ │ └── tikz.rs │ │ └── lib.rs └── linfa-tsne │ ├── Cargo.toml │ ├── README.md │ ├── examples │ ├── iris.dat │ ├── iris_plot.plt │ ├── mnist.rs │ ├── mnist_plot.plt │ └── tsne.rs │ └── src │ ├── error.rs │ ├── hyperparams.rs │ └── lib.rs ├── build.rs ├── datasets ├── Cargo.toml ├── README.md ├── data │ ├── diabetes_data.csv.gz │ ├── diabetes_target.csv.gz │ ├── iris.csv.gz │ ├── linnerud_exercise.csv.gz │ ├── linnerud_physiological.csv.gz │ └── winequality-red.csv.gz └── src │ ├── dataset.rs │ ├── generate.rs │ └── lib.rs ├── docs └── website │ ├── config.toml │ ├── content │ ├── about.md │ ├── blog │ │ ├── _index.md │ │ └── first.md │ ├── community.md │ ├── docs.md │ ├── news │ │ ├── _index.md │ │ ├── first_release.md │ │ ├── new_website.md │ │ ├── release_021.md │ │ ├── 
release_030.md │ │ ├── release_031.md │ │ ├── release_040 │ │ │ ├── index.md │ │ │ └── tsne.png │ │ ├── release_050.md │ │ ├── release_051.md │ │ ├── release_060.md │ │ ├── release_061.md │ │ ├── release_070.md │ │ └── release_071.md │ └── snippets │ │ ├── _index.md │ │ ├── cross-validation.md │ │ ├── decision-trees.md │ │ ├── diffusion-maps.md │ │ ├── elasticnet.md │ │ ├── gaussian-naive-bayes.md │ │ ├── k-folding.md │ │ ├── multi-class.md │ │ ├── multi-targets.md │ │ ├── partial-least-squares.md │ │ ├── random-projection.md │ │ ├── sv-machines.md │ │ └── tsne.md │ ├── sass │ ├── _base.scss │ ├── desktop │ │ ├── _base.scss │ │ ├── _home.scss │ │ └── _news.scss │ ├── mobile │ │ ├── _base.scss │ │ ├── _home.scss │ │ └── _news.scss │ └── style.scss │ ├── static │ ├── code-examples.js │ └── mascot.svg │ └── templates │ ├── about.html │ ├── base.html │ ├── blog-page.html │ ├── blog.html │ ├── index.html │ ├── news-page.html │ ├── news.html │ └── page.html ├── mascot.svg └── src ├── benchmarks └── mod.rs ├── composing ├── mod.rs ├── multi_class_model.rs ├── multi_target_model.rs └── platt_scaling.rs ├── correlation.rs ├── dataset ├── impl_dataset.rs ├── impl_records.rs ├── impl_targets.rs ├── iter.rs ├── lapack_bounds.rs └── mod.rs ├── error.rs ├── lib.rs ├── metrics_classification.rs ├── metrics_clustering.rs ├── metrics_regression.rs ├── param_guard.rs ├── prelude.rs └── traits.rs /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | liberapay: linfa 2 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 1% 6 | patch: 7 | default: 8 | informational: true -------------------------------------------------------------------------------- /.github/workflows/benching.yml: 
-------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Run iai Benches 4 | 5 | jobs: 6 | testing: 7 | name: benching 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Checkout sources 12 | uses: actions/checkout@master 13 | 14 | - name: Install toolchain 15 | uses: dtolnay/rust-toolchain@stable 16 | 17 | - name: Run cargo bench iai 18 | run: cargo bench iai --all 19 | -------------------------------------------------------------------------------- /.github/workflows/checking.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Check For Build Errors 4 | 5 | jobs: 6 | check: 7 | name: check-${{ matrix.toolchain }}-${{ matrix.os }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | toolchain: 13 | - 1.82.0 14 | - stable 15 | - nightly 16 | os: 17 | - ubuntu-latest 18 | - windows-latest 19 | 20 | continue-on-error: ${{ matrix.toolchain == 'nightly' }} 21 | 22 | steps: 23 | - name: Checkout sources 24 | uses: actions/checkout@master 25 | 26 | - name: Install toolchain 27 | uses: dtolnay/rust-toolchain@master 28 | with: 29 | toolchain: ${{ matrix.toolchain }} 30 | 31 | - name: Log active toolchain 32 | run: rustup show 33 | 34 | # Check if linfa compiles by itself without uniting dependency features with other crates 35 | - name: Run cargo check on linfa 36 | run: cargo check 37 | 38 | - name: Run cargo check (no features) 39 | run: cargo check --workspace --all-targets 40 | 41 | - name: Run cargo check (with serde) 42 | run: cargo check --workspace --all-targets --features "linfa-clustering/serde linfa-ica/serde linfa-kernel/serde linfa-reduction/serde linfa-svm/serde linfa-elasticnet/serde linfa-pls/serde linfa-trees/serde linfa-nn/serde linfa-linear/serde linfa-preprocessing/serde linfa-bayes/serde linfa-logistic/serde linfa-ftrl/serde" 43 | 
-------------------------------------------------------------------------------- /.github/workflows/codequality.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | 4 | name: Codequality Lints 5 | 6 | jobs: 7 | codequality: 8 | name: codequality 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | toolchain: 13 | - stable 14 | 15 | steps: 16 | - name: Checkout sources 17 | uses: actions/checkout@master 18 | 19 | - name: Install toolchain 20 | uses: dtolnay/rust-toolchain@master 21 | with: 22 | toolchain: ${{ matrix.toolchain }} 23 | components: rustfmt, clippy 24 | 25 | - name: Run cargo fmt 26 | run: cargo fmt --all -- --check 27 | 28 | - name: Run cargo clippy 29 | run: cargo clippy --workspace --all-targets -- -D warnings 30 | 31 | coverage: 32 | needs: codequality 33 | name: coverage 34 | runs-on: ubuntu-latest 35 | if: github.event_name == 'pull_request' || github.ref == 'refs/heads/master' 36 | 37 | steps: 38 | - name: Checkout sources 39 | uses: actions/checkout@master 40 | 41 | - name: Install toolchain 42 | uses: dtolnay/rust-toolchain@stable 43 | 44 | - name: Get rustc version 45 | id: rustc-version 46 | run: echo "::set-output name=version::$(cargo --version | cut -d ' ' -f 2)" 47 | shell: bash 48 | 49 | - uses: actions/cache@v4 50 | id: tarpaulin-cache 51 | with: 52 | path: | 53 | ~/.cargo/bin/cargo-tarpaulin 54 | key: ${{ runner.os }}-cargo-${{ steps.rustc-version.outputs.version }} 55 | 56 | - name: Install tarpaulin 57 | if: steps.tarpaulin-cache.outputs.cache-hit != 'true' 58 | run: cargo install cargo-tarpaulin 59 | 60 | - name: Generate code coverage 61 | run: | 62 | cargo tarpaulin --verbose --timeout 120 --out Xml --all --release 63 | - name: Upload to codecov.io 64 | uses: codecov/codecov-action@v4 65 | with: 66 | token: ${{ secrets.CODECOV_TOKEN }} 67 | fail_ci_if_error: true 68 | 69 | -------------------------------------------------------------------------------- 
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # build and deploy on master push, otherwise just try to build the page 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | 8 | name: Build website with Zola, build rust docs and publish to GH pages 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | if: github.ref != 'refs/heads/master' && github.repository == 'rust-ml/linfa' 14 | steps: 15 | - name: 'Checkout' 16 | uses: actions/checkout@master 17 | 18 | - name: 'Build only' 19 | uses: shalzz/zola-deploy-action@master 20 | env: 21 | BUILD_DIR: docs/website/ 22 | TOKEN: ${{ secrets.TOKEN }} 23 | BUILD_ONLY: true 24 | 25 | - name: Build Documentation 26 | run: cargo doc --workspace --no-deps 27 | env: 28 | RUSTDOCFLAGS: -D warnings 29 | 30 | build_and_deploy: 31 | runs-on: ubuntu-latest 32 | if: github.ref == 'refs/heads/master' || github.repository != 'rust-ml/linfa' 33 | steps: 34 | - name: 'Checkout' 35 | uses: actions/checkout@master 36 | 37 | - name: Install Rust toolchain 38 | uses: dtolnay/rust-toolchain@stable 39 | with: 40 | components: rustfmt, rust-src 41 | 42 | - name: Build Documentation 43 | run: cargo doc --workspace --no-deps 44 | env: 45 | RUSTDOCFLAGS: -D warnings 46 | 47 | - name: Copy Rust Documentation to Zola 48 | run: cp -R "target/doc/" "docs/website/static/rustdocs/" 49 | 50 | - name: 'Build and deploy' 51 | uses: shalzz/zola-deploy-action@master 52 | env: 53 | PAGES_BRANCH: gh-pages 54 | BUILD_DIR: docs/website/ 55 | TOKEN: ${{ secrets.TOKEN }} 56 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Run Tests 4 | 5 | jobs: 6 | testing: 7 | name: testing-${{ matrix.toolchain }}-${{ matrix.os }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | fail-fast: false 11 | matrix: 
12 | toolchain: 13 | - 1.82.0 14 | - stable 15 | os: 16 | - ubuntu-latest 17 | - windows-latest 18 | 19 | steps: 20 | - name: Checkout sources 21 | uses: actions/checkout@master 22 | 23 | - name: Install toolchain 24 | uses: dtolnay/rust-toolchain@master 25 | with: 26 | toolchain: ${{ matrix.toolchain }} 27 | 28 | - name: Run cargo test 29 | run: cargo test --release --workspace 30 | 31 | testing-blas: 32 | name: testing-with-BLAS-${{ matrix.toolchain }}-${{ matrix.os }} 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | toolchain: 38 | - 1.82.0 39 | - stable 40 | os: 41 | - ubuntu-latest 42 | - windows-latest 43 | 44 | steps: 45 | - name: Checkout sources 46 | uses: actions/checkout@master 47 | 48 | - name: Install toolchain 49 | uses: dtolnay/rust-toolchain@master 50 | with: 51 | toolchain: ${{ matrix.toolchain }} 52 | 53 | - name: Run cargo test with BLAS enabled 54 | run: cargo test --release --workspace --features intel-mkl-static,linfa-ica/blas,linfa-reduction/blas,linfa-linear/blas,linfa-preprocessing/blas,linfa-pls/blas,linfa-elasticnet/blas 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | # vscode 12 | .vscode 13 | 14 | # ctags 15 | tags 16 | *.npy 17 | .idea 18 | 19 | # Python Bindings 20 | build/ 21 | __pycache__/ 22 | .pytest_cache/ 23 | dist/ 24 | *.so 25 | *.out 26 | *.egg-info 27 | *.eggs/ 28 | .venv/ 29 | .python-version 30 | poetry.lock 31 | .ipynb_checkpoints/ 32 | *.ipynb 33 | 34 | *.json 35 | 36 | # Generated artifacts of website (with 
Zola) 37 | docs/website/public/* 38 | docs/website/static/rustdocs/ 39 | 40 | # Downloaded data for the linfa-preprocessing benches 41 | 20news/ -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa" 3 | version = "0.7.1" 4 | authors = [ 5 | "Luca Palmieri ", 6 | "Lorenz Schmidt ", 7 | "Paul Körbitz ", 8 | "Yuhan Lin ", 9 | ] 10 | description = "A Machine Learning framework for Rust" 11 | edition = "2018" 12 | license = "MIT OR Apache-2.0" 13 | 14 | repository = "https://github.com/rust-ml/linfa" 15 | readme = "README.md" 16 | 17 | keywords = ["machine-learning", "linfa", "ai", "ml"] 18 | categories = ["algorithms", "mathematics", "science"] 19 | 20 | exclude = [".github/"] 21 | 22 | [features] 23 | default = [] 24 | benchmarks = ["criterion", "pprof"] 25 | netlib-static = ["blas", "ndarray-linalg/netlib-static"] 26 | netlib-system = ["blas", "ndarray-linalg/netlib-system"] 27 | 28 | openblas-static = ["blas", "ndarray-linalg/openblas-static"] 29 | openblas-system = ["blas", "ndarray-linalg/openblas-system"] 30 | 31 | intel-mkl-static = ["blas", "ndarray-linalg/intel-mkl-static"] 32 | intel-mkl-system = ["blas", "ndarray-linalg/intel-mkl-system"] 33 | 34 | blas = ["ndarray/blas"] 35 | 36 | serde = ["serde_crate", "ndarray/serde"] 37 | 38 | [dependencies] 39 | num-traits = "0.2" 40 | rand = { version = "0.8", features = ["small_rng"] } 41 | approx = "0.4" 42 | 43 | ndarray = { version = "0.15", features = ["approx"] } 44 | ndarray-linalg = { version = "0.16", optional = true } 45 | sprs = { version = "=0.11.1", default-features = false } 46 | 47 | thiserror = "1.0" 48 | 49 | criterion = { version = "0.4.0", optional = true } 50 | 51 | [dependencies.serde_crate] 52 | package = "serde" 53 | optional = true 54 | version = "1.0" 55 | default-features = false 56 | features = ["std", "derive"] 57 | 58 | [dev-dependencies] 
59 | ndarray-rand = "0.14" 60 | linfa-datasets = { path = "datasets", features = [ 61 | "winequality", 62 | "iris", 63 | "diabetes", 64 | "generate", 65 | ] } 66 | statrs = "0.16.0" 67 | 68 | [target.'cfg(not(windows))'.dependencies] 69 | pprof = { version = "0.11.0", features = [ 70 | "flamegraph", 71 | "criterion", 72 | ], optional = true } 73 | 74 | [workspace] 75 | members = ["algorithms/*", "datasets"] 76 | 77 | [profile.release] 78 | opt-level = 3 79 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 
26 | 27 | -------------------------------------------------------------------------------- /algorithms/linfa-bayes/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-bayes" 3 | version = "0.7.1" 4 | authors = ["VasanthakumarV "] 5 | description = "Collection of Naive Bayes Algorithms" 6 | edition = "2018" 7 | license = "MIT OR Apache-2.0" 8 | repository = "https://github.com/rust-ml/linfa" 9 | readme = "README.md" 10 | keywords = ["factorization", "machine-learning", "linfa", "unsupervised"] 11 | categories = ["algorithms", "mathematics", "science"] 12 | 13 | [features] 14 | serde = ["serde_crate", "ndarray/serde"] 15 | 16 | [dependencies.serde_crate] 17 | package = "serde" 18 | optional = true 19 | version = "1.0" 20 | default-features = false 21 | features = ["std", "derive"] 22 | 23 | [dependencies] 24 | ndarray = { version = "0.15" , features = ["approx"]} 25 | ndarray-stats = "0.5" 26 | thiserror = "1.0" 27 | 28 | linfa = { version = "0.7.1", path = "../.." 
} 29 | 30 | [dev-dependencies] 31 | approx = "0.4" 32 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = ["winequality"] } 33 | -------------------------------------------------------------------------------- /algorithms/linfa-bayes/examples/winequality_bayes.rs: -------------------------------------------------------------------------------- 1 | use linfa::metrics::ToConfusionMatrix; 2 | use linfa::traits::{Fit, Predict}; 3 | use linfa_bayes::{GaussianNb, Result}; 4 | 5 | fn main() -> Result<()> { 6 | // Read in the dataset and convert targets to binary data 7 | let (train, valid) = linfa_datasets::winequality() 8 | .map_targets(|x| if *x > 6 { "good" } else { "bad" }) 9 | .split_with_ratio(0.9); 10 | 11 | // Train the model 12 | let model = GaussianNb::params().fit(&train)?; 13 | 14 | // Predict the validation dataset 15 | let pred = model.predict(&valid); 16 | 17 | // Construct confusion matrix 18 | let cm = pred.confusion_matrix(&valid)?; 19 | 20 | // classes | bad | good 21 | // bad | 130 | 12 22 | // good | 7 | 10 23 | // 24 | // accuracy 0.8805031, MCC 0.45080978 25 | println!("{:?}", cm); 26 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 27 | 28 | Ok(()) 29 | } 30 | -------------------------------------------------------------------------------- /algorithms/linfa-bayes/examples/winequality_bernouilli.rs: -------------------------------------------------------------------------------- 1 | use linfa::metrics::ToConfusionMatrix; 2 | use linfa::traits::{Fit, Predict}; 3 | use linfa_bayes::{BernoulliNb, Result}; 4 | 5 | fn main() -> Result<()> { 6 | // Read in the dataset and convert targets to binary data 7 | let (train, valid) = linfa_datasets::winequality() 8 | .map_targets(|x| if *x > 6 { "good" } else { "bad" }) 9 | .split_with_ratio(0.9); 10 | 11 | // Train the model 12 | let model = BernoulliNb::params().fit(&train)?; 13 | 14 | // Predict the validation dataset 15 | let pred = model.predict(&valid); 16 | 17 | // 
Construct confusion matrix 18 | let cm = pred.confusion_matrix(&valid)?; 19 | // classes | bad | good 20 | // bad | 142 | 0 21 | // good | 17 | 0 22 | 23 | // accuracy 0.8930818, MCC 24 | println!("{:?}", cm); 25 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 26 | 27 | Ok(()) 28 | } 29 | -------------------------------------------------------------------------------- /algorithms/linfa-bayes/examples/winequality_multinomial.rs: -------------------------------------------------------------------------------- 1 | use linfa::metrics::ToConfusionMatrix; 2 | use linfa::traits::{Fit, Predict}; 3 | use linfa_bayes::{MultinomialNb, Result}; 4 | 5 | fn main() -> Result<()> { 6 | // Read in the dataset and convert targets to binary data 7 | let (train, valid) = linfa_datasets::winequality() 8 | .map_targets(|x| if *x > 6 { "good" } else { "bad" }) 9 | .split_with_ratio(0.9); 10 | 11 | // Train the model 12 | let model = MultinomialNb::params().fit(&train)?; 13 | 14 | // Predict the validation dataset 15 | let pred = model.predict(&valid); 16 | 17 | // Construct confusion matrix 18 | let cm = pred.confusion_matrix(&valid)?; 19 | // classes | bad | good 20 | // bad | 88 | 54 21 | // good | 10 | 7 22 | 23 | // accuracy 0.5974843, MCC 0.02000631 24 | println!("{:?}", cm); 25 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 26 | 27 | Ok(()) 28 | } 29 | -------------------------------------------------------------------------------- /algorithms/linfa-bayes/src/error.rs: -------------------------------------------------------------------------------- 1 | use ndarray_stats::errors::MinMaxError; 2 | use thiserror::Error; 3 | 4 | /// Simplified `Result` using [`NaiveBayesError`](crate::NaiveBayesError) as error type 5 | pub type Result = std::result::Result; 6 | 7 | /// Error variants from hyper-parameter construction or model estimation 8 | #[derive(Error, Debug)] 9 | pub enum NaiveBayesError { 10 | /// Error when performing Max operation on data 11 | 
#[error("invalid statistical operation {0}")] 12 | Stats(#[from] MinMaxError), 13 | /// Invalid smoothing parameter 14 | #[error("invalid smoothing parameter {0}")] 15 | InvalidSmoothing(f64), 16 | #[error(transparent)] 17 | BaseCrate(#[from] linfa::Error), 18 | } 19 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-clustering" 3 | version = "0.7.1" 4 | edition = "2018" 5 | authors = [ 6 | "Luca Palmieri ", 7 | "xd009642 ", 8 | "Rémi Lafage ", 9 | ] 10 | description = "A collection of clustering algorithms" 11 | license = "MIT OR Apache-2.0" 12 | 13 | repository = "https://github.com/rust-ml/linfa/" 14 | readme = "README.md" 15 | 16 | keywords = [ 17 | "clustering", 18 | "machine-learning", 19 | "linfa", 20 | "k-means", 21 | "unsupervised", 22 | ] 23 | categories = ["algorithms", "mathematics", "science"] 24 | 25 | [features] 26 | default = [] 27 | blas = [] 28 | serde = ["serde_crate", "ndarray/serde", "linfa-nn/serde"] 29 | 30 | [dependencies.serde_crate] 31 | package = "serde" 32 | optional = true 33 | version = "1.0" 34 | default-features = false 35 | features = ["std", "derive"] 36 | 37 | [dependencies] 38 | ndarray = { version = "0.15", features = ["rayon", "approx"] } 39 | linfa-linalg = { version = "0.1", default-features = false } 40 | ndarray-linalg = { version = "0.16", optional = true } 41 | ndarray-rand = "0.14" 42 | ndarray-stats = "0.5" 43 | num-traits = "0.2" 44 | rand_xoshiro = "0.6" 45 | space = "0.12" 46 | thiserror = "1.0" 47 | #partitions = "0.2.4" This one will break in a future version of Rust and has no replacement 48 | linfa = { version = "0.7.1", path = "../.." 
} 49 | linfa-nn = { version = "0.7.1", path = "../linfa-nn" } 50 | noisy_float = "0.2.0" 51 | 52 | [dev-dependencies] 53 | ndarray-npy = { version = "0.8", default-features = false } 54 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 55 | "generate", 56 | ] } 57 | criterion = "0.4.0" 58 | serde_json = "1" 59 | approx = "0.4" 60 | lax = "0.15.0" 61 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 62 | 63 | [[bench]] 64 | name = "k_means" 65 | harness = false 66 | 67 | [[bench]] 68 | name = "dbscan" 69 | harness = false 70 | 71 | [[bench]] 72 | name = "gaussian_mixture" 73 | harness = false 74 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/README.md: -------------------------------------------------------------------------------- 1 | # Clustering 2 | 3 | `linfa-clustering` aims to provide pure Rust implementations of popular clustering algorithms. 4 | 5 | ## The big picture 6 | 7 | `linfa-clustering` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | You can find a roadmap (and a selection of good first issues) 10 | [here](https://github.com/rust-ml/linfa/issues) - contributors are more than welcome! 11 | 12 | ## Current state 13 | 14 | `linfa-clustering` currently provides implementation of the following clustering algorithms, in addition to a couple of helper functions: 15 | - K-Means 16 | - DBSCAN 17 | - Approximated DBSCAN (Currently an alias for DBSCAN, due to its superior performance) 18 | - Gaussian Mixture Model 19 | 20 | 21 | Implementation choices, algorithmic details and a tutorial can be found 22 | [here](https://docs.rs/linfa-clustering). 
23 | 24 | ## BLAS/Lapack backend 25 | We found that the pure Rust implementation maintained similar performance to the BLAS/LAPACK version and have removed it with this [PR](https://github.com/rust-ml/linfa/pull/257). Thus, to reduce code complexity BLAS support has been removed for this module. 26 | 27 | ## License 28 | Dual-licensed to be compatible with the Rust project. 29 | 30 | Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms. 31 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/benches/dbscan.rs: -------------------------------------------------------------------------------- 1 | use criterion::{ 2 | black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, 3 | PlotConfiguration, 4 | }; 5 | use linfa::benchmarks::config; 6 | use linfa::prelude::{ParamGuard, Transformer}; 7 | use linfa_clustering::Dbscan; 8 | use linfa_datasets::generate; 9 | use ndarray::Array2; 10 | use ndarray_rand::rand::SeedableRng; 11 | use ndarray_rand::rand_distr::Uniform; 12 | use ndarray_rand::RandomExt; 13 | use rand_xoshiro::Xoshiro256Plus; 14 | 15 | fn dbscan_bench(c: &mut Criterion) { 16 | let mut rng = Xoshiro256Plus::seed_from_u64(40); 17 | let cluster_sizes = vec![10, 100, 1000, 10000]; 18 | 19 | let mut benchmark = c.benchmark_group("dbscan"); 20 | config::set_default_benchmark_configs(&mut benchmark); 21 | benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 22 | 23 | for cluster_size in cluster_sizes { 24 | let rng = &mut rng; 25 | benchmark.bench_with_input( 26 | BenchmarkId::new("dbscan", cluster_size), 27 | &cluster_size, 28 | move |bencher, &cluster_size| { 29 | let min_points = 4; 30 | let n_features = 3; 31 | let tolerance = 0.3; 32 | let centroids = 33 
| Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), rng); 34 | let dataset = generate::blobs(cluster_size, ¢roids, rng); 35 | 36 | bencher.iter(|| { 37 | black_box( 38 | Dbscan::params(min_points) 39 | .tolerance(tolerance) 40 | .check_unwrap() 41 | .transform(&dataset), 42 | ) 43 | }); 44 | }, 45 | ); 46 | } 47 | benchmark.finish() 48 | } 49 | 50 | #[cfg(not(target_os = "windows"))] 51 | criterion_group! { 52 | name = benches; 53 | config = config::get_default_profiling_configs(); 54 | targets = dbscan_bench 55 | } 56 | #[cfg(target_os = "windows")] 57 | criterion_group!(benches, dbscan_bench); 58 | 59 | criterion_main!(benches); 60 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/benches/gaussian_mixture.rs: -------------------------------------------------------------------------------- 1 | use criterion::{ 2 | black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, 3 | PlotConfiguration, 4 | }; 5 | use linfa::benchmarks::config; 6 | use linfa::traits::Fit; 7 | use linfa::DatasetBase; 8 | use linfa_clustering::GaussianMixtureModel; 9 | use linfa_datasets::generate; 10 | use ndarray::Array2; 11 | use ndarray_rand::rand::SeedableRng; 12 | use ndarray_rand::rand_distr::Uniform; 13 | use ndarray_rand::RandomExt; 14 | use rand_xoshiro::Xoshiro256Plus; 15 | 16 | fn gaussian_mixture_bench(c: &mut Criterion) { 17 | let mut rng = Xoshiro256Plus::seed_from_u64(40); 18 | let cluster_sizes = vec![10, 100, 1000, 10000]; 19 | 20 | let mut benchmark = c.benchmark_group("gaussian_mixture"); 21 | config::set_default_benchmark_configs(&mut benchmark); 22 | benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 23 | 24 | for cluster_size in cluster_sizes { 25 | let rng = &mut rng; 26 | benchmark.bench_with_input( 27 | BenchmarkId::new("gaussian_mixture", cluster_size), 28 | &cluster_size, 29 | move |bencher, &cluster_size| { 30 | let 
n_clusters = 4; 31 | let n_features = 3; 32 | let centroids = 33 | Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng); 34 | let dataset: DatasetBase<_, _> = 35 | (generate::blobs(cluster_size, ¢roids, rng)).into(); 36 | bencher.iter(|| { 37 | black_box( 38 | GaussianMixtureModel::params(n_clusters) 39 | .with_rng(rng.clone()) 40 | .tolerance(1e-3) 41 | .max_n_iterations(1000) 42 | .fit(&dataset) 43 | .expect("GMM fitting fail"), 44 | ) 45 | }); 46 | }, 47 | ); 48 | } 49 | benchmark.finish(); 50 | } 51 | 52 | #[cfg(not(target_os = "windows"))] 53 | criterion_group! { 54 | name = benches; 55 | config = config::get_default_profiling_configs(); 56 | targets = gaussian_mixture_bench 57 | } 58 | #[cfg(target_os = "windows")] 59 | criterion_group!(benches, gaussian_mixture_bench); 60 | 61 | criterion_main!(benches); 62 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/examples/dbscan.rs: -------------------------------------------------------------------------------- 1 | use linfa::dataset::{DatasetBase, Labels, Records}; 2 | use linfa::metrics::SilhouetteScore; 3 | use linfa::traits::Transformer; 4 | use linfa_clustering::Dbscan; 5 | use linfa_datasets::generate; 6 | use ndarray::array; 7 | use ndarray_npy::write_npy; 8 | use ndarray_rand::rand::SeedableRng; 9 | use rand_xoshiro::Xoshiro256Plus; 10 | 11 | // A routine DBScan task: build a synthetic dataset, predict clusters for it 12 | // and save both training data and predictions to disk. 
13 | fn main() { 14 | // Our random number generator, seeded for reproducibility 15 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 16 | 17 | // For each of our expected centroids, generate `n` data points around it (a "blob") 18 | let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],]; 19 | let n = 100; 20 | let dataset: DatasetBase<_, _> = generate::blobs(n, &expected_centroids, &mut rng).into(); 21 | 22 | // Configure our training algorithm 23 | let min_points = 3; 24 | 25 | println!( 26 | "Clustering #{} data points grouped in 4 clusters of {} points each", 27 | dataset.nsamples(), 28 | n 29 | ); 30 | 31 | // Infer an optimal set of centroids based on the training data distribution 32 | let cluster_memberships = Dbscan::params(min_points) 33 | .tolerance(1.) 34 | .transform(dataset) 35 | .unwrap(); 36 | 37 | // single target dataset 38 | let label_count = cluster_memberships.label_count().remove(0); 39 | 40 | println!(); 41 | println!("Result: "); 42 | for (label, count) in label_count { 43 | match label { 44 | None => println!(" - {} noise points", count), 45 | Some(i) => println!(" - {} points in cluster {}", count, i), 46 | } 47 | } 48 | println!(); 49 | 50 | let silhouette_score = cluster_memberships.silhouette_score().unwrap(); 51 | 52 | println!("Silhouette score: {}", silhouette_score); 53 | 54 | let (records, cluster_memberships) = (cluster_memberships.records, cluster_memberships.targets); 55 | 56 | // Save to disk our dataset (and the cluster label assigned to each observation) 57 | // We use the `npy` format for compatibility with NumPy 58 | write_npy("clustered_dataset.npy", &records).expect("Failed to write .npy file"); 59 | write_npy( 60 | "clustered_memberships.npy", 61 | &cluster_memberships.map(|&x| x.map(|c| c as i64).unwrap_or(-1)), 62 | ) 63 | .expect("Failed to write .npy file"); 64 | } 65 | --------------------------------------------------------------------------------
/algorithms/linfa-clustering/examples/kmeans.rs: -------------------------------------------------------------------------------- 1 | use linfa::traits::Fit; 2 | use linfa::traits::Predict; 3 | use linfa::DatasetBase; 4 | use linfa_clustering::KMeans; 5 | use linfa_datasets::generate; 6 | use ndarray::{array, Axis}; 7 | use ndarray_npy::write_npy; 8 | use ndarray_rand::rand::SeedableRng; 9 | use rand_xoshiro::Xoshiro256Plus; 10 | 11 | use linfa_nn::distance::LInfDist; 12 | 13 | // A routine K-means task: build a synthetic dataset, fit the algorithm on it 14 | // and save both training data and predictions to disk. 15 | fn main() { 16 | // Our random number generator, seeded for reproducibility 17 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 18 | 19 | // For each our expected centroids, generate `n` data points around it (a "blob") 20 | let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],]; 21 | let n = 10000; 22 | let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng)); 23 | 24 | // Configure our training algorithm 25 | let n_clusters = expected_centroids.len_of(Axis(0)); 26 | let model = KMeans::params_with(n_clusters, rng, LInfDist) 27 | .max_n_iterations(200) 28 | .tolerance(1e-5) 29 | .fit(&dataset) 30 | .expect("KMeans fitted"); 31 | 32 | // Assign each point to a cluster using the set of centroids found using `fit` 33 | let dataset = model.predict(dataset); 34 | let DatasetBase { 35 | records, targets, .. 
36 | } = dataset; 37 | 38 | // Save to disk our dataset (and the cluster label assigned to each observation) 39 | // We use the `npy` format for compatibility with NumPy 40 | write_npy("clustered_dataset.npy", &records).expect("Failed to write .npy file"); 41 | write_npy("clustered_memberships.npy", &targets.map(|&x| x as u64)) 42 | .expect("Failed to write .npy file"); 43 | } 44 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/examples/optics.rs: -------------------------------------------------------------------------------- 1 | use linfa::dataset::Records; 2 | use linfa::traits::Transformer; 3 | use linfa_clustering::Optics; 4 | use linfa_datasets::generate; 5 | use ndarray::{array, Array, Array2}; 6 | use ndarray_npy::write_npy; 7 | use ndarray_rand::rand::SeedableRng; 8 | use rand_xoshiro::Xoshiro256Plus; 9 | 10 | // A routine DBScan task: build a synthetic dataset, predict clusters for it 11 | // and save both training data and predictions to disk. 
12 | fn main() { 13 | // Our random number generator, seeded for reproducibility 14 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 15 | 16 | let expected_centroids = array![[10., 10.], [5., 5.], [20., 30.], [-20., 30.],]; 17 | let n = 100; 18 | let dataset: Array2 = generate::blobs(n, &expected_centroids, &mut rng); 19 | 20 | // Configure our training algorithm 21 | let min_points = 3; 22 | 23 | println!( 24 | "Performing Optics Analysis with #{} data points grouped in {} blobs", 25 | dataset.nsamples(), 26 | n 27 | ); 28 | 29 | // Perform OPTICS analysis with minimum points for a cluster neighborhood set to 3 30 | let analysis = Optics::params(min_points) 31 | .tolerance(3.0) 32 | .transform(dataset.view()) 33 | .unwrap(); 34 | 35 | println!(); 36 | println!("Result: "); 37 | for sample in analysis.iter() { 38 | println!("{:?}", sample); 39 | } 40 | println!(); 41 | 42 | // Save to disk our dataset (and the cluster label assigned to each observation) 43 | // We use the `npy` format for compatibility with NumPy 44 | write_npy("dataset.npy", &dataset).expect("Failed to write .npy file"); 45 | write_npy( 46 | "reachability.npy", 47 | &analysis 48 | .iter() 49 | .map(|x| x.reachability_distance().unwrap_or(f64::INFINITY)) 50 | .collect::>(), 51 | ) 52 | .expect("Failed to write .npy file"); 53 | write_npy( 54 | "indexes.npy", 55 | &analysis 56 | .iter() 57 | .map(|x| x.index() as u32) 58 | .collect::>(), 59 | ) 60 | .expect("Failed to write .npy file"); 61 | } 62 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/examples/optics_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | # This script assumes you're running from the linfa-clustering root after 5 | # running the example so the dataset and reachability npy files are in the 6 | # linfa-clustering root as well. 
7 | 8 | dataset = np.load("../dataset.npy") 9 | reachability = np.load("../reachability.npy") 10 | 11 | plot1 = plt.figure(1) 12 | plt.scatter(dataset[:, 0], dataset[:, 1]) 13 | 14 | plot2 = plt.figure(2) 15 | plt.plot(reachability) 16 | 17 | plt.show() 18 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/appx_dbscan/clustering/tests.rs: -------------------------------------------------------------------------------- 1 | use crate::AppxDbscan; 2 | 3 | use linfa::{traits::Transformer, ParamGuard}; 4 | use ndarray::Array2; 5 | 6 | #[test] 7 | fn clustering_test() { 8 | let params = AppxDbscan::params(2) 9 | .tolerance(2.0) 10 | .slack(0.1) 11 | .check() 12 | .unwrap(); 13 | let l = params.tolerance / 2_f64.sqrt(); 14 | let all_points = vec![ 15 | 2.0 * l, 16 | 2.0 * l, 17 | 2.0 * l, 18 | 2.0 * l, 19 | 2.0 * l, 20 | 2.0 * l, 21 | -5.0 * l, 22 | -5.0 * l, 23 | ]; 24 | let points = Array2::from_shape_vec((4, 2), all_points).unwrap(); 25 | let labels = params.transform(&points); 26 | assert_eq!( 27 | labels 28 | .iter() 29 | .filter(|x| x.is_some()) 30 | .map(|x| x.unwrap() as i64) 31 | .max() 32 | .unwrap_or(-1) 33 | + 1, 34 | 1 35 | ); 36 | assert_eq!(labels.iter().filter(|x| x.is_none()).count(), 1); 37 | assert_eq!( 38 | labels 39 | .iter() 40 | .filter(|x| x.is_some() && x.unwrap() == 0) 41 | .count(), 42 | 3 43 | ); 44 | } 45 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/appx_dbscan/mod.rs: -------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod cells_grid; 3 | mod clustering; 4 | mod counting_tree; 5 | mod hyperparams; 6 | 7 | pub use algorithm::*; 8 | pub use hyperparams::*; 9 | 10 | #[cfg(test)] 11 | mod tests; 12 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/dbscan/mod.rs: 
-------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod hyperparams; 3 | 4 | pub use algorithm::*; 5 | pub use hyperparams::*; 6 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/gaussian_mixture/errors.rs: -------------------------------------------------------------------------------- 1 | use crate::k_means::KMeansError; 2 | #[cfg(not(feature = "blas"))] 3 | use linfa_linalg::LinalgError; 4 | #[cfg(feature = "blas")] 5 | use ndarray_linalg::error::LinalgError; 6 | use thiserror::Error; 7 | 8 | /// An error when modeling a GMM algorithm 9 | #[derive(Error, Debug)] 10 | pub enum GmmError { 11 | /// When any of the hyperparameters are set the wrong value 12 | #[error("Invalid value encountered: {0}")] 13 | InvalidValue(String), 14 | /// Errors encountered during linear algebra operations 15 | #[error( 16 | "Linalg Error: \ 17 | Fitting the mixture model failed because some components have \ 18 | ill-defined empirical covariance (for instance caused by singleton \ 19 | or collapsed samples). Try to decrease the number of components, \ 20 | or increase reg_covar. 
Error: {0}" 21 | )] 22 | LinalgError(#[from] LinalgError), 23 | /// When a cluster has no more data point while fitting GMM 24 | #[error("Fitting failed: {0}")] 25 | EmptyCluster(String), 26 | /// When lower bound computation fails 27 | #[error("Fitting failed: {0}")] 28 | LowerBoundError(String), 29 | /// When fitting EM algorithm does not converge 30 | #[error("Fitting failed: {0}")] 31 | NotConverged(String), 32 | /// When initial KMeans fails 33 | #[error("Initial KMeans failed: {0}")] 34 | KMeansError(#[from] KMeansError), 35 | #[error(transparent)] 36 | LinfaError(#[from] linfa::error::Error), 37 | #[error(transparent)] 38 | MinMaxError(#[from] ndarray_stats::errors::MinMaxError), 39 | } 40 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/gaussian_mixture/mod.rs: -------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod errors; 3 | mod hyperparams; 4 | 5 | pub use algorithm::*; 6 | pub use errors::*; 7 | pub use hyperparams::*; 8 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/k_means/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// An error when fitting with an invalid hyperparameter 4 | #[derive(Error, Debug)] 5 | pub enum KMeansParamsError { 6 | #[error("n_clusters cannot be 0")] 7 | NClusters, 8 | #[error("n_runs cannot be 0")] 9 | NRuns, 10 | #[error("tolerance must be greater than 0")] 11 | Tolerance, 12 | #[error("max_n_iterations cannot be 0")] 13 | MaxIterations, 14 | } 15 | 16 | /// An error when modeling a KMeans algorithm 17 | #[derive(Error, Debug)] 18 | pub enum KMeansError { 19 | /// When any of the hyperparameters are set the wrong value 20 | #[error("Invalid hyperparameter: {0}")] 21 | InvalidParams(#[from] KMeansParamsError), 22 | /// When inertia computation fails 23 | 
#[error("Fitting failed: No inertia improvement (-inf)")] 24 | InertiaError, 25 | #[error(transparent)] 26 | LinfaError(#[from] linfa::error::Error), 27 | } 28 | 29 | #[derive(Error, Debug)] 30 | pub enum IncrKMeansError { 31 | /// When any of the hyperparameters are set the wrong value 32 | #[error("Invalid hyperparameter: {0}")] 33 | InvalidParams(#[from] KMeansParamsError), 34 | /// When the distance between the old and new centroids exceeds the tolerance parameter. Not an 35 | /// actual error, just there to signal that the algorithm should keep running. 36 | #[error("Algorithm has not yet converged, Keep on running the algorithm.")] 37 | NotConverged(M), 38 | #[error(transparent)] 39 | LinfaError(#[from] linfa::error::Error), 40 | } 41 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/k_means/mod.rs: -------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod errors; 3 | mod hyperparams; 4 | mod init; 5 | 6 | pub use algorithm::*; 7 | pub use errors::*; 8 | pub use hyperparams::*; 9 | pub use init::*; 10 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! `linfa-clustering` aims to provide pure Rust implementations 2 | //! of popular clustering algorithms. 3 | //! 4 | //! ## The big picture 5 | //! 6 | //! `linfa-clustering` is a crate in the `linfa` ecosystem, a wider effort to 7 | //! bootstrap a toolkit for classical Machine Learning implemented in pure Rust, 8 | //! kin in spirit to Python's `scikit-learn`. 9 | //! 10 | //! You can find a roadmap (and a selection of good first issues) 11 | //! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome! 12 | //! 13 | //! ## Current state 14 | //! 15 | //! 
Right now `linfa-clustering` provides the following clustering algorithms: 16 | //! * [K-Means](KMeans) 17 | //! * [DBSCAN](Dbscan) 18 | //! * [Approximated DBSCAN](AppxDbscan) (Currently an alias for DBSCAN, due to its superior 19 | //! performance) 20 | //! * [Gaussian-Mixture-Model](GaussianMixtureModel) 21 | //! * [OPTICS](OpticsAnalysis) 22 | //! 23 | //! Implementation choices, algorithmic details and tutorials can be found in the page dedicated to the specific algorithms. 24 | mod dbscan; 25 | mod gaussian_mixture; 26 | #[allow(clippy::new_ret_no_self)] 27 | mod k_means; 28 | mod optics; 29 | 30 | pub use dbscan::*; 31 | pub use gaussian_mixture::*; 32 | pub use k_means::*; 33 | pub use optics::*; 34 | 35 | // Approx DBSCAN is currently an alias for DBSCAN, due to the old Approx DBSCAN implementation's 36 | // lower performance and outdated dependencies 37 | 38 | use linfa_nn::distance::L2Dist; 39 | pub type AppxDbscanValidParams = DbscanValidParams; 40 | pub type AppxDbscanParams = DbscanParams; 41 | pub type AppxDbscanParamsError = DbscanParamsError; 42 | pub type AppxDbscan = Dbscan; 43 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/optics/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// An error when performing OPTICS Analysis 4 | #[derive(Error, Debug)] 5 | pub enum OpticsError { 6 | /// When any of the hyperparameters are set the wrong value 7 | #[error("Invalid value encountered: {0}")] 8 | InvalidValue(String), 9 | } 10 | -------------------------------------------------------------------------------- /algorithms/linfa-clustering/src/optics/mod.rs: -------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod errors; 3 | mod hyperparams; 4 | 5 | pub use algorithm::*; 6 | pub use errors::*; 7 | pub use hyperparams::*; 8 | 
-------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-elasticnet" 3 | version = "0.7.1" 4 | authors = [ 5 | "Paul Körbitz / Google ", 6 | "Lorenz Schmidt ", 7 | ] 8 | 9 | description = "A Machine Learning framework for Rust" 10 | edition = "2018" 11 | license = "MIT OR Apache-2.0" 12 | 13 | repository = "https://github.com/rust-ml/linfa" 14 | readme = "README.md" 15 | 16 | keywords = ["machine-learning", "linfa", "ai", "ml", "linear"] 17 | categories = ["algorithms", "mathematics", "science"] 18 | 19 | [features] 20 | default = [] 21 | serde = ["serde_crate", "ndarray/serde", "linfa/serde"] 22 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 23 | 24 | [dependencies.serde_crate] 25 | package = "serde" 26 | optional = true 27 | version = "1.0" 28 | default-features = false 29 | features = ["std", "derive"] 30 | 31 | [dependencies] 32 | ndarray = { version = "0.15", features = ["approx"] } 33 | linfa-linalg = { version = "0.1", default-features = false } 34 | ndarray-linalg = { version = "0.16", optional = true } 35 | 36 | num-traits = "0.2" 37 | approx = "0.4" 38 | thiserror = "1.0" 39 | 40 | linfa = { version = "0.7.1", path = "../.." } 41 | 42 | [dev-dependencies] 43 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 44 | "diabetes", 45 | "linnerud", 46 | ] } 47 | ndarray-rand = "0.14" 48 | rand_xoshiro = "0.6" 49 | -------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/README.md: -------------------------------------------------------------------------------- 1 | # Elastic Net 2 | 3 | `linfa-elasticnet` provides a pure Rust implementations of elastic net linear regression. 
4 | 5 | ## The Big Picture 6 | 7 | `linfa-elasticnet` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | The `linfa-elasticnet` crate provides linear regression with ridge and LASSO constraints. The solver uses coordinate descent to find an optimal solution. 12 | 13 | This library contains an elastic net implementation for linear regression models. It combines l1 and l2 penalties of the LASSO and ridge methods and therefore offers greater flexibility for feature selection. With increasing penalization certain parameters become zero, and their corresponding variables are dropped from the model. 14 | 15 | See also: 16 | * [Wikipedia on Elastic net](https://en.wikipedia.org/wiki/Elastic_net_regularization) 17 | 18 | ## BLAS/Lapack backend 19 | 20 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 21 | 22 | ## Examples 23 | 24 | There is a usage example in the `examples/` directory. To run, use: 25 | 26 | ```bash 27 | $ cargo run --example elasticnet 28 | ``` 29 | 30 |
31 | 32 | Show source code 33 | 34 | 35 | ```rust 36 | use linfa::prelude::*; 37 | use linfa_elasticnet::{ElasticNet, Result}; 38 | 39 | // load Diabetes dataset 40 | let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90); 41 | 42 | // train pure LASSO model with 0.1 penalty 43 | let model = ElasticNet::params() 44 | .penalty(0.3) 45 | .l1_ratio(1.0) 46 | .fit(&train)?; 47 | 48 | println!("intercept: {}", model.intercept()); 49 | println!("params: {}", model.hyperplane()); 50 | 51 | println!("z score: {:?}", model.z_score()); 52 | 53 | // validate 54 | let y_est = model.predict(&valid); 55 | println!("predicted variance: {}", valid.r2(&y_est)?); 56 | # Result::Ok(()) 57 | ``` 58 |
59 | -------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/examples/elasticnet.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_elasticnet::{ElasticNet, Result}; 3 | 4 | fn main() -> Result<()> { 5 | // load Diabetes dataset 6 | let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90); 7 | 8 | // train pure LASSO model with 0.3 penalty 9 | let model = ElasticNet::params() 10 | .penalty(0.3) 11 | .l1_ratio(1.0) 12 | .fit(&train)?; 13 | 14 | println!("intercept: {}", model.intercept()); 15 | println!("params: {}", model.hyperplane()); 16 | 17 | println!("z score: {:?}", model.z_score()); 18 | 19 | // validate 20 | let y_est = model.predict(&valid); 21 | println!("predicted variance: {}", valid.r2(&y_est)?); 22 | 23 | Ok(()) 24 | } 25 | -------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/examples/elasticnet_cv.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_elasticnet::{ElasticNet, Result}; 3 | 4 | fn main() -> Result<()> { 5 | // load Diabetes dataset (mutable to allow fast k-folding) 6 | let mut dataset = linfa_datasets::diabetes(); 7 | 8 | // parameters to compare 9 | let ratios = &[0.1, 0.2, 0.5, 0.7, 1.0]; 10 | 11 | // create a model for each parameter 12 | let models = ratios 13 | .iter() 14 | .map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio)) 15 | .collect::>(); 16 | 17 | // get the mean r2 validation score across all folds for each model 18 | let r2_values = 19 | dataset.cross_validate_single(5, &models, |prediction, truth| prediction.r2(&truth))?; 20 | 21 | for (ratio, r2) in ratios.iter().zip(r2_values.iter()) { 22 | println!("L1 ratio: {}, r2 score: {}", ratio, r2); 23 | } 24 | 25 | Ok(()) 26 | } 27 | 
-------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_elasticnet::{MultiTaskElasticNet, Result}; 3 | 4 | fn main() -> Result<()> { 5 | // load Diabetes dataset 6 | let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80); 7 | 8 | // train pure LASSO model with 0.1 penalty 9 | let model = MultiTaskElasticNet::params() 10 | .penalty(0.1) 11 | .l1_ratio(1.0) 12 | .fit(&train)?; 13 | 14 | println!("intercept: {}", model.intercept()); 15 | println!("params: {}", model.hyperplane()); 16 | 17 | println!("z score: {:?}", model.z_score()); 18 | 19 | // validate 20 | let y_est = model.predict(&valid); 21 | println!("predicted variance: {}", y_est.r2(&valid)?); 22 | 23 | Ok(()) 24 | } 25 | -------------------------------------------------------------------------------- /algorithms/linfa-elasticnet/src/error.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "serde")] 2 | use serde_crate::{Deserialize, Serialize}; 3 | use thiserror::Error; 4 | 5 | /// Simplified `Result` using [`ElasticNetError`](crate::ElasticNetError) as error type 6 | pub type Result = std::result::Result; 7 | 8 | #[cfg_attr( 9 | feature = "serde", 10 | derive(Serialize, Deserialize), 11 | serde(crate = "serde_crate") 12 | )] 13 | /// Error variants from hyperparameter construction or model estimation 14 | #[derive(Debug, Clone, Error)] 15 | pub enum ElasticNetError { 16 | /// The input has not enough samples 17 | #[error("not enough samples as they have to be larger than number of features")] 18 | NotEnoughSamples, 19 | /// The input is singular 20 | #[error("the data is ill-conditioned")] 21 | IllConditioned, 22 | #[error("l1 ratio should be in range [0, 1], but is {0}")] 23 | InvalidL1Ratio(f32), 24 | #[error("invalid penalty 
{0}")] 25 | InvalidPenalty(f32), 26 | #[error("invalid tolerance {0}")] 27 | InvalidTolerance(f32), 28 | #[error("the target can either be a vector (ndim=1) or a matrix (ndim=2)")] 29 | IncorrectTargetShape, 30 | #[error(transparent)] 31 | BaseCrate(#[from] linfa::Error), 32 | } 33 | -------------------------------------------------------------------------------- /algorithms/linfa-ensemble/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-ensemble" 3 | version = "0.7.0" 4 | edition = "2018" 5 | authors = [ 6 | "James Knight ", 7 | "James Kay ", 8 | ] 9 | description = "A general method for creating ensemble classifiers" 10 | license = "MIT/Apache-2.0" 11 | 12 | repository = "https://github.com/rust-ml/linfa" 13 | readme = "README.md" 14 | 15 | keywords = ["machine-learning", "linfa", "ensemble"] 16 | categories = ["algorithms", "mathematics", "science"] 17 | 18 | [features] 19 | default = [] 20 | serde = ["serde_crate", "ndarray/serde"] 21 | 22 | [dependencies.serde_crate] 23 | package = "serde" 24 | optional = true 25 | version = "1.0" 26 | default-features = false 27 | features = ["std", "derive"] 28 | 29 | [dependencies] 30 | ndarray = { version = "0.15", features = ["rayon", "approx"] } 31 | ndarray-rand = "0.14" 32 | rand = "0.8.5" 33 | 34 | linfa = { version = "0.7.1", path = "../.." } 35 | linfa-trees = { version = "0.7.1", path = "../linfa-trees" } 36 | 37 | [dev-dependencies] 38 | linfa-datasets = { version = "0.7.1", path = "../../datasets/", features = [ 39 | "iris", 40 | ] } 41 | -------------------------------------------------------------------------------- /algorithms/linfa-ensemble/README.md: -------------------------------------------------------------------------------- 1 | # Ensemble Learning 2 | 3 | `linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit. 
4 | 5 | ## The Big Picture 6 | 7 | `linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | `linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for other classifiers provided in linfa. 12 | 13 | ## Examples 14 | 15 | You can find examples in the `examples/` directory. To run a bootstrap aggregation for an ensemble of decision trees (a Random Forest) use: 16 | 17 | ```bash 18 | $ cargo run --example bagging_iris --release 19 | ``` 20 | 21 | 22 | -------------------------------------------------------------------------------- /algorithms/linfa-ensemble/examples/bagging_iris.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::{Fit, Predict, ToConfusionMatrix}; 2 | use linfa_ensemble::EnsembleLearnerParams; 3 | use linfa_trees::DecisionTree; 4 | use ndarray_rand::rand::SeedableRng; 5 | use rand::rngs::SmallRng; 6 | 7 | fn main() { 8 | // Number of models in the ensemble 9 | let ensemble_size = 100; 10 | // Proportion of training data given to each model 11 | let bootstrap_proportion = 0.7; 12 | 13 | // Load dataset 14 | let mut rng = SmallRng::seed_from_u64(42); 15 | let (train, test) = linfa_datasets::iris() 16 | .shuffle(&mut rng) 17 | .split_with_ratio(0.8); 18 | 19 | // Train ensemble learner model 20 | let model = EnsembleLearnerParams::new(DecisionTree::params()) 21 | .ensemble_size(ensemble_size) 22 | .bootstrap_proportion(bootstrap_proportion) 23 | .fit(&train) 24 | .unwrap(); 25 | 26 | // Return highest ranking predictions 27 | let final_predictions_ensemble = model.predict(&test); 28 | println!("Final Predictions: \n{:?}", final_predictions_ensemble); 29 | 30 | let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap(); 31 | 32 | println!("{:?}", cm); 33 | println!("Test 
accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}", 34 | 100.0 * cm.accuracy(), ensemble_size, bootstrap_proportion); 35 | } 36 | -------------------------------------------------------------------------------- /algorithms/linfa-ensemble/src/hyperparams.rs: -------------------------------------------------------------------------------- 1 | use linfa::{ 2 | error::{Error, Result}, 3 | ParamGuard, 4 | }; 5 | use rand::rngs::ThreadRng; 6 | use rand::Rng; 7 | 8 | #[derive(Clone, Copy, Debug, PartialEq)] 9 | pub struct EnsembleLearnerValidParams { 10 | /// The number of models in the ensemble 11 | pub ensemble_size: usize, 12 | /// The proportion of the total number of training samples that should be given to each model for training 13 | pub bootstrap_proportion: f64, 14 | /// The model parameters for the base model 15 | pub model_params: P, 16 | pub rng: R, 17 | } 18 | 19 | #[derive(Clone, Copy, Debug, PartialEq)] 20 | pub struct EnsembleLearnerParams(EnsembleLearnerValidParams); 21 | 22 | impl

EnsembleLearnerParams { 23 | pub fn new(model_params: P) -> EnsembleLearnerParams { 24 | Self::new_fixed_rng(model_params, rand::thread_rng()) 25 | } 26 | } 27 | 28 | impl EnsembleLearnerParams { 29 | pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams { 30 | Self(EnsembleLearnerValidParams { 31 | ensemble_size: 1, 32 | bootstrap_proportion: 1.0, 33 | model_params, 34 | rng, 35 | }) 36 | } 37 | 38 | pub fn ensemble_size(mut self, size: usize) -> Self { 39 | self.0.ensemble_size = size; 40 | self 41 | } 42 | 43 | pub fn bootstrap_proportion(mut self, proportion: f64) -> Self { 44 | self.0.bootstrap_proportion = proportion; 45 | self 46 | } 47 | } 48 | 49 | impl ParamGuard for EnsembleLearnerParams { 50 | type Checked = EnsembleLearnerValidParams; 51 | type Error = Error; 52 | 53 | fn check_ref(&self) -> Result<&Self::Checked> { 54 | if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 { 55 | Err(Error::Parameters(format!( 56 | "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}", 57 | self.0.bootstrap_proportion 58 | ))) 59 | } else if self.0.ensemble_size < 1 { 60 | Err(Error::Parameters(format!( 61 | "Ensemble size should be less than one, but was {}", 62 | self.0.ensemble_size 63 | ))) 64 | } else { 65 | Ok(&self.0) 66 | } 67 | } 68 | 69 | fn check(self) -> Result { 70 | self.check_ref()?; 71 | Ok(self.0) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /algorithms/linfa-ensemble/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Ensemble Learning Algorithms 2 | //! 3 | //! Ensemble methods combine the predictions of several base estimators built with a given 4 | //! learning algorithm in order to improve generalizability / robustness over a single estimator. 5 | //! 6 | //! ## Bootstrap Aggregation (aka Bagging) 7 | //! 8 | //! 
A typical example of ensemble method is Bootstrap AGgregation, which combines the predictions of 9 | //! several decision trees (see `linfa-trees`) trained on different sample subsets of the training dataset. 10 | //! 11 | //! ## Reference 12 | //! 13 | //! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html) 14 | //! 15 | //! ## Example 16 | //! 17 | //! This example shows how to train a bagging model using 100 decision trees, 18 | //! each trained on 70% of the training data (bootstrap sampling). 19 | //! 20 | //! ```no_run 21 | //! use linfa::prelude::{Fit, Predict}; 22 | //! use linfa_ensemble::EnsembleLearnerParams; 23 | //! use linfa_trees::DecisionTree; 24 | //! use ndarray_rand::rand::SeedableRng; 25 | //! use rand::rngs::SmallRng; 26 | //! 27 | //! // Load Iris dataset 28 | //! let mut rng = SmallRng::seed_from_u64(42); 29 | //! let (train, test) = linfa_datasets::iris() 30 | //! .shuffle(&mut rng) 31 | //! .split_with_ratio(0.8); 32 | //! 33 | //! // Train the model on the iris dataset 34 | //! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params()) 35 | //! .ensemble_size(100) 36 | //! .bootstrap_proportion(0.7) 37 | //! .fit(&train) 38 | //! .unwrap(); 39 | //! 40 | //! // Make predictions on the test set 41 | //! let predictions = bagging_model.predict(&test); 42 | //! ``` 43 | //! 
44 | mod algorithm; 45 | mod hyperparams; 46 | 47 | pub use algorithm::*; 48 | pub use hyperparams::*; 49 | 50 | #[cfg(test)] 51 | mod tests { 52 | use super::*; 53 | use linfa::prelude::{Fit, Predict, ToConfusionMatrix}; 54 | use linfa_trees::DecisionTree; 55 | use ndarray_rand::rand::SeedableRng; 56 | use rand::rngs::SmallRng; 57 | 58 | #[test] 59 | fn test_ensemble_learner_accuracy_on_iris_dataset() { 60 | let mut rng = SmallRng::seed_from_u64(42); 61 | let (train, test) = linfa_datasets::iris() 62 | .shuffle(&mut rng) 63 | .split_with_ratio(0.8); 64 | 65 | let model = EnsembleLearnerParams::new(DecisionTree::params()) 66 | .ensemble_size(100) 67 | .bootstrap_proportion(0.7) 68 | .fit(&train) 69 | .unwrap(); 70 | 71 | let predictions = model.predict(&test); 72 | 73 | let cm = predictions.confusion_matrix(&test).unwrap(); 74 | let acc = cm.accuracy(); 75 | assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /algorithms/linfa-ftrl/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-ftrl" 3 | version = "0.7.1" 4 | authors = ["Liudmyla Kyrashchuk "] 5 | 6 | description = "A Machine Learning framework for Rust" 7 | edition = "2018" 8 | license = "MIT OR Apache-2.0" 9 | 10 | repository = "https://github.com/rust-ml/linfa" 11 | readme = "README.md" 12 | 13 | keywords = ["machine-learning", "linfa", "ai", "ml", "ftrl"] 14 | categories = ["algorithms", "mathematics", "science"] 15 | 16 | [features] 17 | serde = ["serde_crate", "linfa/serde", "ndarray/serde", "argmin/serde1"] 18 | wasm-bindgen = ["argmin/wasm-bindgen"] 19 | 20 | [dependencies.serde_crate] 21 | package = "serde" 22 | optional = true 23 | version = "1.0" 24 | features = ["derive"] 25 | 26 | [dependencies] 27 | ndarray = { version = "0.15", features = ["serde"] } 28 | ndarray-rand = "0.14.0" 29 | argmin = { version = 
"0.9.0", default-features = false } 30 | argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] } 31 | thiserror = "1.0" 32 | rand = "0.8.5" 33 | rand_xoshiro = "0.6.0" 34 | 35 | linfa = { version = "0.7.1", path = "../.." } 36 | 37 | [dev-dependencies] 38 | criterion = "0.4.0" 39 | approx = "0.4" 40 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 41 | "winequality", 42 | ] } 43 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 44 | 45 | [[bench]] 46 | name = "ftrl" 47 | harness = false 48 | -------------------------------------------------------------------------------- /algorithms/linfa-ftrl/README.md: -------------------------------------------------------------------------------- 1 | # Follow the regularized leader 2 | 3 | `linfa-ftrl` provides a pure Rust implementations of follow the regularized leader, proximal, model. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-ftrl` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | The `linfa-ftrl` crate provides Follow The Regularized Leader - Proximal model with L1 and L2 regularization from Logistic Regression, and primarily used for CTR prediction. It actively stores z and n values, needed to calculate weights. 12 | Without L1 and L2 regularization, it is identical to online gradient descent. 13 | 14 | 15 | See also: 16 | * [Paper about Ftrl](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) 17 | 18 | ## Examples 19 | 20 | There is a usage example in the `examples/` directory. To run, use: 21 | 22 | ```bash 23 | $ cargo run --example winequality 24 | ``` 25 | 26 |

27 | 28 | Show source code 29 | 30 | 31 | ```rust 32 | use linfa::prelude::*; 33 | use linfa::dataset::{AsSingleTargets, Records}; 34 | use linfa_ftrl::{Ftrl, Result}; 35 | use rand::{rngs::SmallRng, SeedableRng}; 36 | 37 | // load Winequality dataset 38 | let (train, valid) = linfa_datasets::winequality() 39 | .map_targets(|v| if *v > 6 { true } else { false }) 40 | .split_with_ratio(0.9); 41 | 42 | let params = Ftrl::params() 43 | .alpha(0.005) 44 | .beta(1.0) 45 | .l1_ratio(0.005) 46 | .l2_ratio(1.0); 47 | 48 | let valid_params = params.clone().check_unwrap(); 49 | let mut model = Ftrl::new(valid_params, train.nfeatures()); 50 | 51 | // Bootstrap each row from the train dataset to imitate online nature of the data flow 52 | let mut rng = SmallRng::seed_from_u64(42); 53 | let mut row_iter = train.bootstrap_samples(1, &mut rng); 54 | for _ in 0..train.nsamples() { 55 | let b_dataset = row_iter.next().unwrap(); 56 | model = params.fit_with(Some(model), &b_dataset)?; 57 | } 58 | let val_predictions = model.predict(&valid); 59 | println!("valid log loss {:?}", val_predictions.log_loss(&valid.as_single_targets().to_vec())?); 60 | # Result::Ok(()) 61 | ``` 62 |
63 | -------------------------------------------------------------------------------- /algorithms/linfa-ftrl/examples/winequality_ftrl.rs: -------------------------------------------------------------------------------- 1 | use linfa::dataset::{AsSingleTargets, Records}; 2 | use linfa::prelude::*; 3 | use linfa_ftrl::{Ftrl, Result}; 4 | use rand::{rngs::SmallRng, SeedableRng}; 5 | 6 | fn main() -> Result<()> { 7 | // Read the data 8 | let (train, valid) = linfa_datasets::winequality() 9 | .map_targets(|v| *v > 6) 10 | .split_with_ratio(0.9); 11 | 12 | let params = Ftrl::params() 13 | .alpha(0.005) 14 | .beta(1.0) 15 | .l1_ratio(0.005) 16 | .l2_ratio(1.0); 17 | 18 | let valid_params = params.clone().check_unwrap(); 19 | let mut model = Ftrl::new(valid_params, train.nfeatures()); 20 | 21 | // Bootstrap each row from the train dataset to imitate online nature of the data flow 22 | let mut rng = SmallRng::seed_from_u64(42); 23 | let mut row_iter = train.bootstrap_samples(1, &mut rng); 24 | for _ in 0..train.nsamples() { 25 | let b_dataset = row_iter.next().unwrap(); 26 | model = params.fit_with(Some(model), &b_dataset)?; 27 | } 28 | let val_predictions = model.predict(&valid); 29 | println!( 30 | "valid log loss {:?}", 31 | val_predictions.log_loss(&valid.as_single_targets().to_vec())? 
32 | ); 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /algorithms/linfa-ftrl/src/error.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "serde")] 2 | use serde_crate::{Deserialize, Serialize}; 3 | use thiserror::Error; 4 | 5 | #[cfg_attr( 6 | feature = "serde", 7 | derive(Serialize, Deserialize), 8 | serde(crate = "serde_crate") 9 | )] 10 | #[derive(Error, Debug)] 11 | pub enum FtrlError { 12 | #[error("l1 ratio should be in range [0, 1], but is {0}")] 13 | InvalidL1Ratio(f32), 14 | #[error("l2 ratio should be in range [0, 1], but is {0}")] 15 | InvalidL2Ratio(f32), 16 | #[error("alpha should be positive and finite, but is {0}")] 17 | InvalidAlpha(f32), 18 | #[error("beta should be positive and finite, but is {0}")] 19 | InvalidBeta(f32), 20 | #[error("number of features must be bigger than 0, but is {0}")] 21 | InvalidNFeatures(usize), 22 | #[error(transparent)] 23 | LinfaError(#[from] linfa::error::Error), 24 | } 25 | -------------------------------------------------------------------------------- /algorithms/linfa-hierarchical/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-hierarchical" 3 | version = "0.7.1" 4 | authors = ["Lorenz Schmidt "] 5 | edition = "2018" 6 | 7 | description = "Agglomerative Hierarchical clustering" 8 | license = "MIT OR Apache-2.0" 9 | 10 | repository = "https://github.com/rust-ml/linfa" 11 | readme = "README.md" 12 | 13 | keywords = ["hierachical", "agglomerative", "clustering", "machine-learning", "linfa"] 14 | categories = ["algorithms", "mathematics", "science"] 15 | 16 | [dependencies] 17 | ndarray = { version = "0.15" } 18 | kodama = "0.2" 19 | thiserror = "1.0.25" 20 | 21 | linfa = { version = "0.7.1", path = "../.." 
} 22 | linfa-kernel = { version = "0.7.1", path = "../linfa-kernel" } 23 | 24 | [dev-dependencies] 25 | rand = "0.8" 26 | ndarray-rand = "0.14" 27 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = ["iris"] } 28 | -------------------------------------------------------------------------------- /algorithms/linfa-hierarchical/README.md: -------------------------------------------------------------------------------- 1 | # Clustering 2 | 3 | `linfa-hierarchical` provides an implementation of agglomerative hierarchical clustering. 4 | In this clustering algorithm, each point is first considered as a separate cluster. During each 5 | step, two points are merged into new clusters, until a stopping criterion is reached. The distance 6 | between the points is computed as the negative-log transform of the similarity kernel. 7 | 8 | _Documentation_: [latest](https://docs.rs/linfa). 9 | 10 | ## The big picture 11 | 12 | `linfa-hierarchical` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, a wider effort to bootstrap a toolkit for classical Machine Learning implemented in pure Rust, akin in spirit to Python's `scikit-learn`. 13 | 14 | ## Current state 15 | 16 | `linfa-hierarchical` implements agglomerative hierarchical clustering with support of the [kodama](https://docs.rs/kodama/0.2.3/kodama/) crate. 17 | 18 | ## License 19 | Dual-licensed to be compatible with the Rust project. 20 | 21 | Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms. 
22 | -------------------------------------------------------------------------------- /algorithms/linfa-hierarchical/examples/irisflower.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use linfa::traits::Transformer; 4 | use linfa_hierarchical::HierarchicalCluster; 5 | use linfa_kernel::{Kernel, KernelMethod}; 6 | 7 | fn main() -> Result<(), Box> { 8 | // load Iris plant dataset 9 | let dataset = linfa_datasets::iris(); 10 | 11 | let kernel = Kernel::params() 12 | .method(KernelMethod::Gaussian(1.0)) 13 | .transform(dataset.records().view()); 14 | 15 | let kernel = HierarchicalCluster::default() 16 | .num_clusters(3) 17 | .transform(kernel)?; 18 | 19 | for (id, target) in kernel.targets().iter().zip(dataset.targets().into_iter()) { 20 | let name = match *target { 21 | 0 => "setosa", 22 | 1 => "versicolor", 23 | 2 => "virginica", 24 | _ => unreachable!(), 25 | }; 26 | 27 | print!("({} {}) ", id, name); 28 | } 29 | println!(); 30 | 31 | Ok(()) 32 | } 33 | -------------------------------------------------------------------------------- /algorithms/linfa-hierarchical/src/error.rs: -------------------------------------------------------------------------------- 1 | //! Error definitions 2 | //! 
3 | 4 | use crate::{Criterion, Float}; 5 | use thiserror::Error; 6 | 7 | /// Simplified `Result` using [`HierarchicalError`](crate::HierarchicalError) as error type 8 | pub type Result = std::result::Result>; 9 | 10 | /// Error variants from parameter construction 11 | #[derive(Error, Debug)] 12 | pub enum HierarchicalError { 13 | /// Invalid stopping condition 14 | #[error("The stopping condition {0:?} is not valid")] 15 | InvalidStoppingCondition(Criterion), 16 | #[error(transparent)] 17 | BaseCrate(#[from] linfa::Error), 18 | } 19 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-ica" 3 | version = "0.7.1" 4 | authors = ["VasanthakumarV "] 5 | description = "A collection of Independent Component Analysis (ICA) algorithms" 6 | edition = "2018" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa" 10 | readme = "README.md" 11 | 12 | keywords = ["factorization", "machine-learning", "linfa", "unsupervised"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 16 | default = [] 17 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 18 | serde = ["serde_crate", "ndarray/serde"] 19 | 20 | [dependencies.serde_crate] 21 | package = "serde" 22 | optional = true 23 | version = "1.0" 24 | default-features = false 25 | features = ["std", "derive"] 26 | 27 | [dependencies] 28 | ndarray = { version = "0.15" } 29 | linfa-linalg = { version = "0.1", default-features = false } 30 | ndarray-linalg = { version = "0.16", optional = true } 31 | ndarray-rand = "0.14" 32 | ndarray-stats = "0.5" 33 | num-traits = "0.2" 34 | rand_xoshiro = "0.6" 35 | thiserror = "1.0" 36 | 37 | linfa = { version = "0.7.1", path = "../.." 
} 38 | 39 | [dev-dependencies] 40 | ndarray-npy = { version = "0.8", default-features = false } 41 | paste = "1.0" 42 | criterion = "0.4.0" 43 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 44 | 45 | [[bench]] 46 | name = "fast_ica" 47 | harness = false 48 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/README.md: -------------------------------------------------------------------------------- 1 | # Independent Component Analysis (ICA) 2 | 3 | `linfa-ica` aims to provide pure Rust implementations of ICA algorithms. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-ica` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | `linfa-ica` currently provides an implementation of the following factorization methods: 12 | 13 | - Fast Independent Component Analysis (FastICA) 14 | 15 | ## Examples 16 | 17 | There is an usage example in the `examples/` directory. To run, use: 18 | 19 | ```bash 20 | $ cargo run --release --example fast_ica 21 | ``` 22 | 23 | ## BLAS/Lapack backend 24 | 25 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 26 | 27 | ## License 28 | Dual-licensed to be compatible with the Rust project. 29 | 30 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 
31 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/benches/fast_ica.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::{dataset::DatasetBase, traits::Fit}; 4 | use linfa_ica::fast_ica::{FastIca, GFunc}; 5 | use ndarray::{array, concatenate}; 6 | use ndarray::{Array, Array2, Axis}; 7 | use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt}; 8 | use rand_xoshiro::Xoshiro256Plus; 9 | 10 | fn perform_ica(size: usize, gfunc: GFunc) { 11 | let sources_mixed = create_data(size); 12 | 13 | let ica = FastIca::params().gfunc(gfunc).random_state(10); 14 | 15 | ica.fit(&DatasetBase::from(sources_mixed.view())).unwrap(); 16 | } 17 | 18 | fn create_data(nsamples: usize) -> Array2 { 19 | // Creating a sine wave signal 20 | let source1 = Array::linspace(0., 8., nsamples).mapv(|x| (2f64 * x).sin()); 21 | 22 | // Creating a sawtooth signal 23 | let source2 = Array::linspace(0., 8., nsamples).mapv(|x| { 24 | let tmp = (4f64 * x).sin(); 25 | if tmp > 0. { 26 | return 1.; 27 | } 28 | -1. 
29 | }); 30 | 31 | // Column concatenating both the signals 32 | let mut sources_original = concatenate![ 33 | Axis(1), 34 | source1.insert_axis(Axis(1)), 35 | source2.insert_axis(Axis(1)) 36 | ]; 37 | 38 | // Adding noise to the signals 39 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 40 | sources_original += 41 | &Array::random_using((nsamples, 2), Uniform::new(0.0, 1.0), &mut rng).mapv(|x| x * 0.2); 42 | 43 | // Mixing the two signals 44 | let mixing = array![[1., 1.], [0.5, 2.]]; 45 | let sources_mixed = sources_original.dot(&mixing.t()); 46 | 47 | sources_mixed 48 | } 49 | 50 | fn bench(c: &mut Criterion) { 51 | for (gfunc, name) in [ 52 | (GFunc::Cube, "GFunc_Cube"), 53 | (GFunc::Logcosh(1.0), "GFunc_Logcosh"), 54 | (GFunc::Exp, "Exp"), 55 | ] { 56 | let mut group = c.benchmark_group("Fast ICA"); 57 | config::set_default_benchmark_configs(&mut group); 58 | 59 | let sizes: [usize; 3] = [1_000, 10_000, 100_000]; 60 | for size in sizes { 61 | let input = (size, gfunc); 62 | group.bench_with_input(BenchmarkId::new(name, size), &input, |b, (size, gfunc)| { 63 | b.iter(|| perform_ica(*size, *gfunc)); 64 | }); 65 | } 66 | group.finish(); 67 | } 68 | } 69 | 70 | #[cfg(not(target_os = "windows"))] 71 | criterion_group! { 72 | name = benches; 73 | config = config::get_default_profiling_configs(); 74 | targets = bench 75 | } 76 | #[cfg(target_os = "windows")] 77 | criterion_group!(benches, bench); 78 | 79 | criterion_main!(benches); 80 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## FastICA 4 | 5 | ``` 6 | cargo run --example fast_ica 7 | ``` 8 | 9 | This example creates three .npy files, we plot them using python's [matplotlib](https://matplotlib.org/) separately. 10 | 11 | ICA algorithms do not retain the ordering or the sign of the input, hence they can differ in the output. 
12 | 13 | ![fast_ica_example_plot](images/fast_ica.png) 14 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/examples/fast_ica.rs: -------------------------------------------------------------------------------- 1 | use linfa::{ 2 | dataset::DatasetBase, 3 | traits::{Fit, Predict}, 4 | }; 5 | use linfa_ica::fast_ica::{FastIca, GFunc}; 6 | use ndarray::{array, concatenate}; 7 | use ndarray::{Array, Array2, Axis}; 8 | use ndarray_npy::write_npy; 9 | use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt}; 10 | use rand_xoshiro::Xoshiro256Plus; 11 | use std::error::Error; 12 | 13 | fn main() -> Result<(), Box> { 14 | // Create sample dataset for the model 15 | // `sources_original` has the unmixed sources (we merely have it to save to disk) 16 | // `sources_mixed` is the mixed source that will be unmixed using ICA 17 | // Shape of the data will be (2000 x 2) 18 | let (sources_original, sources_mixed) = create_data(); 19 | 20 | // Fitting the model 21 | // We set the G function used in the approximation of neg-entropy as logcosh 22 | // with its alpha value as 1 23 | // `ncomponents` is not set, it will be automatically be assigned 2 from 24 | // the input 25 | let ica = FastIca::params().gfunc(GFunc::Logcosh(1.0)); 26 | let ica = ica.fit(&DatasetBase::from(sources_mixed.view()))?; 27 | 28 | // Here we unmix the data to recover back the original signals 29 | let sources_ica = ica.predict(&sources_mixed); 30 | 31 | // Saving to disk 32 | write_npy("sources_original.npy", &sources_original).expect("Failed to write .npy file"); 33 | write_npy("sources_mixed.npy", &sources_mixed).expect("Failed to write .npy file"); 34 | write_npy("sources_ica.npy", &sources_ica).expect("Failed to write .npy file"); 35 | 36 | Ok(()) 37 | } 38 | 39 | // Helper function to create two signals (sources) and mix them together 40 | // as input for the ICA model 41 | fn create_data() -> (Array2, Array2) { 42 | let nsamples = 
2000; 43 | 44 | // Creating a sine wave signal 45 | let source1 = Array::linspace(0., 8., nsamples).mapv(|x| (2f64 * x).sin()); 46 | 47 | // Creating a sawtooth signal 48 | let source2 = Array::linspace(0., 8., nsamples).mapv(|x| { 49 | let tmp = (4f64 * x).sin(); 50 | if tmp > 0. { 51 | return 1.; 52 | } 53 | -1. 54 | }); 55 | 56 | // Column concatenating both the signals 57 | let mut sources_original = concatenate![ 58 | Axis(1), 59 | source1.insert_axis(Axis(1)), 60 | source2.insert_axis(Axis(1)) 61 | ]; 62 | 63 | // Adding noise to the signals 64 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 65 | sources_original += 66 | &Array::random_using((2000, 2), Uniform::new(0.0, 1.0), &mut rng).mapv(|x| x * 0.2); 67 | 68 | // Mixing the two signals 69 | let mixing = array![[1., 1.], [0.5, 2.]]; 70 | let sources_mixed = sources_original.dot(&mixing.t()); 71 | 72 | (sources_original, sources_mixed) 73 | } 74 | -------------------------------------------------------------------------------- /algorithms/linfa-ica/examples/images/fast_ica.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/algorithms/linfa-ica/examples/images/fast_ica.png -------------------------------------------------------------------------------- /algorithms/linfa-ica/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | pub type Result = std::result::Result; 4 | 5 | /// An error when modeling FastICA algorithm 6 | #[derive(Error, Debug)] 7 | #[non_exhaustive] 8 | pub enum FastIcaError { 9 | /// When there are no samples in the provided dataset 10 | #[error("Dataset must contain at least one sample")] 11 | NotEnoughSamples, 12 | /// When any of the hyperparameters are set the wrong value 13 | #[error("Invalid value encountered: {0}")] 14 | InvalidValue(String), 15 | /// If we fail to compute any 
components of the SVD decomposition 16 | /// due to an Ill-Conditioned matrix 17 | #[error("SVD Decomposition failed, X could be an Ill-Conditioned matrix")] 18 | SvdDecomposition, 19 | #[error("tolerance should be positive but is {0}")] 20 | InvalidTolerance(f32), 21 | #[cfg(feature = "blas")] 22 | #[error("Linalg BLAS error: {0}")] 23 | LinalgBlasError(#[from] ndarray_linalg::error::LinalgError), 24 | #[error("Linalg error: {0}")] 25 | /// Errors encountered during linear algebra operations 26 | LinalgError(#[from] linfa_linalg::LinalgError), 27 | #[error(transparent)] 28 | LinfaError(#[from] linfa::error::Error), 29 | } 30 | -------------------------------------------------------------------------------- /algorithms/linfa-kernel/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-kernel" 3 | version = "0.7.1" 4 | authors = ["Lorenz Schmidt "] 5 | description = "Kernel methods for non-linear algorithms" 6 | edition = "2018" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa" 10 | readme = "README.md" 11 | 12 | keywords = ["kernel", "machine-learning", "linfa"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 16 | default = [] 17 | serde = ["serde_crate", "ndarray/serde", "sprs/serde"] 18 | 19 | [dependencies.serde_crate] 20 | package = "serde" 21 | optional = true 22 | version = "1.0" 23 | default-features = false 24 | features = ["std", "derive"] 25 | 26 | [dependencies] 27 | ndarray = "0.15" 28 | num-traits = "0.2" 29 | sprs = { version = "=0.11.1", default-features = false } 30 | 31 | linfa = { version = "0.7.1", path = "../.." 
} 32 | linfa-nn = { version = "0.7.1", path = "../linfa-nn" } 33 | -------------------------------------------------------------------------------- /algorithms/linfa-kernel/README.md: -------------------------------------------------------------------------------- 1 | # Kernel methods 2 | 3 | `linfa-kernel` provides methods for dimensionality expansion. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-kernel` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | In machine learning, kernel methods are a class of algorithms for pattern analysis, whose best known member is the [support vector machine](https://en.wikipedia.org/wiki/Support_vector_machine). They owe their name to the kernel functions, which maps the features to some higher-dimensional target space. Common examples for kernel functions are the radial basis function (euclidean distance) or polynomial kernels. 10 | 11 | ## Current State 12 | 13 | linfa-kernel currently provides an implementation of kernel methods for RBF and polynomial kernels, with sparse or dense representation. Further a k-neighbour approximation allows to reduce the kernel matrix size. 14 | 15 | Low-rank kernel approximation are currently missing, but are on the roadmap. Examples for these are the [Nyström approximation](https://www.jmlr.org/papers/volume6/drineas05a/drineas05a.pdf) or [Quasi Random Fourier Features](http://www-personal.umich.edu/~aniketde/processed_md/Stats608_Aniketde.pdf). 16 | 17 | ## License 18 | Dual-licensed to be compatible with the Rust project. 19 | 20 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 
21 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-linear" 3 | version = "0.7.1" 4 | authors = [ 5 | "Paul Körbitz / Google ", 6 | "VasanthakumarV ", 7 | ] 8 | 9 | description = "A Machine Learning framework for Rust" 10 | edition = "2018" 11 | license = "MIT OR Apache-2.0" 12 | 13 | repository = "https://github.com/rust-ml/linfa" 14 | readme = "README.md" 15 | 16 | keywords = ["machine-learning", "linfa", "ai", "ml", "linear"] 17 | categories = ["algorithms", "mathematics", "science"] 18 | 19 | [features] 20 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 21 | serde = ["serde_crate", "linfa/serde", "ndarray/serde", "argmin/serde1"] 22 | wasm-bindgen = ["argmin/wasm-bindgen"] 23 | 24 | [dependencies.serde_crate] 25 | package = "serde" 26 | optional = true 27 | version = "1.0" 28 | default-features = false 29 | features = ["std", "derive"] 30 | 31 | [dependencies] 32 | ndarray = { version = "0.15", features = ["approx"] } 33 | linfa-linalg = { version = "0.1", default-features = false } 34 | ndarray-linalg = { version = "0.16", optional = true } 35 | num-traits = "0.2" 36 | argmin = { version = "0.9.0", default-features = false } 37 | argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] } 38 | thiserror = "1.0" 39 | 40 | linfa = { version = "0.7.1", path = "../.." 
} 41 | 42 | [dev-dependencies] 43 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 44 | "diabetes", 45 | ] } 46 | approx = "0.4" 47 | criterion = "0.4.0" 48 | statrs = "0.16.0" 49 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 50 | 51 | [[bench]] 52 | name = "ols_bench" 53 | harness = false 54 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/README.md: -------------------------------------------------------------------------------- 1 | # Linear Models 2 | 3 | `linfa-linear` aims to provide pure Rust implementations of popular linear regression algorithms. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-linear` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | `linfa-linear` currently provides an implementation of the following regression algorithms: 12 | - Ordinary Least Squares 13 | - Generalized Linear Models (GLM) 14 | 15 | ## Examples 16 | 17 | There is an usage example in the `examples/` directory. To run, use: 18 | 19 | ```bash 20 | $ cargo run --example diabetes 21 | $ cargo run --example glm 22 | ``` 23 | 24 | ## BLAS/Lapack backend 25 | 26 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 27 | 28 | ## License 29 | Dual-licensed to be compatible with the Rust project. 30 | 31 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 
32 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/benches/ols_bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::traits::Fit; 4 | use linfa::Dataset; 5 | use linfa_datasets::generate::make_dataset; 6 | use linfa_linear::{LinearRegression, TweedieRegressor}; 7 | use ndarray::Ix1; 8 | use statrs::distribution::{DiscreteUniform, Laplace}; 9 | 10 | #[allow(unused_must_use)] 11 | fn perform_ols(dataset: &Dataset) { 12 | let model = LinearRegression::new(); 13 | model.fit(dataset); 14 | } 15 | 16 | #[allow(unused_must_use)] 17 | fn perform_glm(dataset: &Dataset) { 18 | let model = TweedieRegressor::params().power(0.).alpha(0.); 19 | model.fit(dataset); 20 | } 21 | 22 | fn bench(c: &mut Criterion) { 23 | let mut group = c.benchmark_group("Linfa_linear"); 24 | config::set_default_benchmark_configs(&mut group); 25 | 26 | let params: [(usize, usize); 4] = [(1_000, 5), (10_000, 5), (100_000, 5), (100_000, 10)]; 27 | 28 | let feat_distr = Laplace::new(0.5, 5.).unwrap(); 29 | let target_distr = DiscreteUniform::new(0, 5).unwrap(); 30 | 31 | let ols_id = "OLS-".to_string(); 32 | let glm_id = "GLM-".to_string(); 33 | 34 | for (size, num_feat) in params { 35 | let suffix = format!("{}Feats", num_feat); 36 | let mut func_name = ols_id.clone(); 37 | func_name.push_str(&suffix); 38 | 39 | let dataset = make_dataset(size, num_feat, 1, feat_distr, target_distr); 40 | let dataset = dataset.into_single_target(); 41 | 42 | group.bench_with_input( 43 | BenchmarkId::new(&func_name, size), 44 | &dataset, 45 | |b, dataset| { 46 | b.iter(|| perform_ols(dataset)); 47 | }, 48 | ); 49 | 50 | let mut func_name = glm_id.clone(); 51 | func_name.push_str(&suffix); 52 | group.bench_with_input( 53 | BenchmarkId::new(&func_name, size), 54 | &dataset, 55 | |b, dataset| { 56 | 
b.iter(|| perform_glm(dataset)); 57 | }, 58 | ); 59 | } 60 | group.finish(); 61 | } 62 | 63 | #[cfg(not(target_os = "windows"))] 64 | criterion_group! { 65 | name = benches; 66 | config = config::get_default_profiling_configs(); 67 | targets = bench 68 | } 69 | #[cfg(target_os = "windows")] 70 | criterion_group!(benches, bench); 71 | 72 | criterion_main!(benches); 73 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/examples/diabetes.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use linfa::traits::Fit; 4 | use linfa_linear::LinearRegression; 5 | 6 | fn main() -> Result<(), Box> { 7 | // load Diabetes dataset 8 | let dataset = linfa_datasets::diabetes(); 9 | 10 | let lin_reg = LinearRegression::new(); 11 | let model = lin_reg.fit(&dataset)?; 12 | 13 | println!("intercept: {}", model.intercept()); 14 | println!("parameters: {}", model.params()); 15 | 16 | Ok(()) 17 | } 18 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/examples/glm.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_linear::{Result, TweedieRegressor}; 3 | use ndarray::Axis; 4 | 5 | fn main() -> Result<(), f64> { 6 | // load the Diabetes dataset 7 | let dataset = linfa_datasets::diabetes(); 8 | 9 | // Here the power and alpha is set to 0 10 | // Setting the power to 0 makes it a Normal Regressioon 11 | // Setting the alpha to 0 removes any regularization 12 | // In total this is the regular old Linear Regression 13 | let lin_reg = TweedieRegressor::params().power(0.).alpha(0.); 14 | let model = lin_reg.fit(&dataset)?; 15 | 16 | // We print the learnt parameters 17 | // 18 | // intercept: 152.13349207485706 19 | // parameters: [-10.01009490755511, -239.81838728651834, 519.8493593356682, 324.3878222341785, -792.2097759223642, 
476.75394339962384, 101.07307112047873, 177.0853514839987, 751.2889123356807, 67.61902228894756] 20 | println!("intercept: {}", model.intercept); 21 | println!("parameters: {}", model.coef); 22 | 23 | // We print the Mean Absolute Error (MAE) on the training data 24 | // 25 | // Some(43.27739632065444) 26 | let ypred = model.predict(&dataset); 27 | let loss = (dataset.targets() - &ypred.insert_axis(Axis(1))) 28 | .mapv(|x| x.abs()) 29 | .mean(); 30 | 31 | println!("{:?}", loss); 32 | 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/src/error.rs: -------------------------------------------------------------------------------- 1 | //! An error when modeling a Linear algorithm 2 | use linfa::Float; 3 | use thiserror::Error; 4 | 5 | pub type Result = std::result::Result>; 6 | 7 | /// An error when modeling a Linear algorithm 8 | #[derive(Error, Debug)] 9 | #[non_exhaustive] 10 | pub enum LinearError { 11 | /// Errors encountered when using argmin's solver 12 | #[error("argmin {0}")] 13 | Argmin(#[from] argmin::core::Error), 14 | #[error(transparent)] 15 | BaseCrate(#[from] linfa::Error), 16 | #[error("At least one sample needed")] 17 | NotEnoughSamples, 18 | #[error("At least one target needed")] 19 | NotEnoughTargets, 20 | #[error("penalty should be positive, but is {0}")] 21 | InvalidPenalty(F), 22 | #[error("tweedie distribution power should not be in (0, 1), but is {0}")] 23 | InvalidTweediePower(F), 24 | #[error("some value(s) of y are out of the valid range for power value {0}")] 25 | InvalidTargetRange(F), 26 | #[error(transparent)] 27 | #[cfg(feature = "blas")] 28 | LinalgBlasError(#[from] ndarray_linalg::error::LinalgError), 29 | #[error(transparent)] 30 | LinalgError(#[from] linfa_linalg::LinalgError), 31 | } 32 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/src/float.rs: 
-------------------------------------------------------------------------------- 1 | use argmin::core::ArgminFloat; 2 | use ndarray::NdFloat; 3 | use num_traits::float::FloatConst; 4 | use num_traits::FromPrimitive; 5 | 6 | // A Float trait that captures the requirements we need for the various places 7 | // we need floats. There requirements are imposed y ndarray and argmin 8 | pub trait Float: 9 | ArgminFloat + FloatConst + NdFloat + Default + Clone + FromPrimitive + linfa::Float 10 | { 11 | const POSITIVE_LABEL: Self; 12 | const NEGATIVE_LABEL: Self; 13 | } 14 | 15 | impl Float for f32 { 16 | const POSITIVE_LABEL: Self = 1.0; 17 | const NEGATIVE_LABEL: Self = -1.0; 18 | } 19 | 20 | impl Float for f64 { 21 | const POSITIVE_LABEL: Self = 1.0; 22 | const NEGATIVE_LABEL: Self = -1.0; 23 | } 24 | -------------------------------------------------------------------------------- /algorithms/linfa-linear/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! `linfa-linear` aims to provide pure Rust implementations of popular linear regression algorithms. 3 | //! 4 | //! ## The Big Picture 5 | //! 6 | //! `linfa-linear` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning 7 | //! implemented in pure Rust, akin to Python's `scikit-learn`. 8 | //! 9 | //! ## Current state 10 | //! 11 | //! `linfa-linear` currently provides an implementation of the following regression algorithms: 12 | //! - Ordinary Least Squares 13 | //! - Generalized Linear Models (GLM) 14 | //! - Isotonic 15 | //! 16 | //! ## Examples 17 | //! 18 | //! There is an usage example in the `examples/` directory. To run, use: 19 | //! 20 | //! ```bash 21 | //! $ cargo run --features openblas --example diabetes 22 | //! $ cargo run --example glm 23 | //! 
``` 24 | 25 | mod error; 26 | mod float; 27 | mod glm; 28 | mod isotonic; 29 | mod ols; 30 | 31 | pub use error::*; 32 | pub use glm::*; 33 | pub use isotonic::*; 34 | pub use ols::*; 35 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-logistic" 3 | version = "0.7.1" 4 | authors = ["Paul Körbitz / Google "] 5 | 6 | description = "A Machine Learning framework for Rust" 7 | edition = "2018" 8 | license = "MIT OR Apache-2.0" 9 | 10 | repository = "https://github.com/rust-ml/linfa" 11 | readme = "README.md" 12 | 13 | keywords = ["machine-learning", "linfa", "ai", "ml", "linear"] 14 | categories = ["algorithms", "mathematics", "science"] 15 | 16 | [features] 17 | serde = ["serde_crate", "linfa/serde", "ndarray/serde", "argmin/serde1"] 18 | wasm-bindgen = ["argmin/wasm-bindgen"] 19 | 20 | [dependencies.serde_crate] 21 | package = "serde" 22 | optional = true 23 | version = "1.0" 24 | 25 | [dependencies] 26 | ndarray = { version = "0.15", features = ["approx"] } 27 | ndarray-stats = "0.5.0" 28 | num-traits = "0.2" 29 | argmin = { version = "0.9.0", default-features = false } 30 | argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] } 31 | thiserror = "1.0" 32 | 33 | linfa = { version = "0.7.1", path = "../.." 
} 34 | 35 | [dev-dependencies] 36 | approx = "0.4" 37 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 38 | "winequality", 39 | ] } 40 | rmp-serde = "1" 41 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/README.md: -------------------------------------------------------------------------------- 1 | # Logistic Regression 2 | 3 | ## The Big Picture 4 | 5 | `linfa-logistic` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 6 | 7 | ## Current state 8 | `linfa-logistic` provides pure Rust implementations of two-class and multinomial logistic regression models. 9 | 10 | ## Examples 11 | There are usage examples in the `examples/` directory. 12 | 13 | To run the two-class example, use: 14 | ```bash 15 | $ cargo run --example winequality_logistic 16 | ``` 17 | 18 | To run the multinomial example, use: 19 | ```bash 20 | $ cargo run --example winequality_multi_logistic 21 | ``` 22 | 23 | ## License 24 | Dual-licensed to be compatible with the Rust project. 25 | 26 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 27 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/examples/logistic_cv.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_logistic::error::Result; 3 | use linfa_logistic::LogisticRegression; 4 | 5 | fn main() -> Result<()> { 6 | // Load dataset. Mutability is needed for fast cross validation 7 | let mut dataset = 8 | linfa_datasets::winequality().map_targets(|x| if *x > 6 { "good" } else { "bad" }); 9 | 10 | // define a sequence of models to compare. 
In this case the 11 | // models will differ by the amount of l2 regularization 12 | let alphas = &[0.1, 1., 10.]; 13 | let models: Vec<_> = alphas 14 | .iter() 15 | .map(|alpha| { 16 | LogisticRegression::default() 17 | .alpha(*alpha) 18 | .max_iterations(150) 19 | }) 20 | .collect(); 21 | 22 | // use cross validation to compute the validation accuracy of each model. The 23 | // accuracy of each model will be averaged across the folds, 5 in this case 24 | let accuracies = dataset.cross_validate_single(5, &models, |prediction, truth| { 25 | Ok(prediction.confusion_matrix(truth)?.accuracy()) 26 | })?; 27 | 28 | // display the accuracy of the models along with their regularization coefficient 29 | for (alpha, accuracy) in alphas.iter().zip(accuracies.iter()) { 30 | println!("Alpha: {}, accuracy: {} ", alpha, accuracy); 31 | } 32 | 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/examples/winequality_logistic.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_logistic::LogisticRegression; 3 | 4 | use std::error::Error; 5 | 6 | fn main() -> Result<(), Box> { 7 | // everything above 6.5 is considered a good wine 8 | let (train, valid) = linfa_datasets::winequality() 9 | .map_targets(|x| if *x > 6 { "good" } else { "bad" }) 10 | .split_with_ratio(0.9); 11 | 12 | println!( 13 | "Fit Logistic Regression classifier with #{} training points", 14 | train.nsamples() 15 | ); 16 | 17 | // fit a Logistic regression model with 150 max iterations 18 | let model = LogisticRegression::default() 19 | .max_iterations(150) 20 | .fit(&train) 21 | .unwrap(); 22 | 23 | // predict and map targets 24 | let pred = model.predict(&valid); 25 | 26 | // create a confusion matrix 27 | let cm = pred.confusion_matrix(&valid).unwrap(); 28 | 29 | // Print the confusion matrix, this will print a table with four entries. 
On the diagonal are 30 | // the number of true-positive and true-negative predictions, off the diagonal are 31 | // false-positive and false-negative 32 | println!("{:?}", cm); 33 | 34 | // Calculate the accuracy and Matthew Correlation Coefficient (cross-correlation between 35 | // predicted and targets) 36 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 37 | 38 | Ok(()) 39 | } 40 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/examples/winequality_multi_logistic.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_logistic::MultiLogisticRegression; 3 | 4 | use std::error::Error; 5 | 6 | fn main() -> Result<(), Box> { 7 | let (train, valid) = linfa_datasets::winequality().split_with_ratio(0.9); 8 | 9 | println!( 10 | "Fit Multinomial Logistic Regression classifier with #{} training points", 11 | train.nsamples() 12 | ); 13 | 14 | // fit a Logistic regression model with 150 max iterations 15 | let model = MultiLogisticRegression::default() 16 | .max_iterations(50) 17 | .fit(&train) 18 | .unwrap(); 19 | 20 | // predict and map targets 21 | let pred = model.predict(&valid); 22 | 23 | // create a confusion matrix 24 | let cm = pred.confusion_matrix(&valid).unwrap(); 25 | 26 | // Print the confusion matrix, this will print a table with four entries. 
On the diagonal are 27 | // the number of true-positive and true-negative predictions, off the diagonal are 28 | // false-positive and false-negative 29 | println!("{:?}", cm); 30 | 31 | // Calculate the accuracy and Matthew Correlation Coefficient (cross-correlation between 32 | // predicted and targets) 33 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 34 | 35 | Ok(()) 36 | } 37 | -------------------------------------------------------------------------------- /algorithms/linfa-logistic/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | pub type Result = std::result::Result; 3 | 4 | #[derive(Error, Debug)] 5 | pub enum Error { 6 | #[error(transparent)] 7 | LinfaError(#[from] linfa::Error), 8 | #[error("More than two classes for logistic regression")] 9 | TooManyClasses, 10 | #[error("Fewer than two classes for logistic regression")] 11 | TooFewClasses, 12 | #[error(transparent)] 13 | ArgMinError(#[from] argmin::core::Error), 14 | #[error("Expected `x` and `y` to have same number of rows, got {0} != {1}")] 15 | MismatchedShapes(usize, usize), 16 | #[error("Values must be finite and not `Inf`, `-Inf` or `NaN`")] 17 | InvalidValues, 18 | #[error("Rows of initial parameter ({rows}) must be the same as the number of features ({n_features})")] 19 | InitialParameterFeaturesMismatch { rows: usize, n_features: usize }, 20 | #[error("Columns of initial parameter ({cols}) must be the same as the number of classes ({n_classes})")] 21 | InitialParameterClassesMismatch { cols: usize, n_classes: usize }, 22 | 23 | #[error("gradient_tolerance must be a positive, finite number")] 24 | InvalidGradientTolerance, 25 | #[error("alpha must be a positive, finite number")] 26 | InvalidAlpha, 27 | #[error("Initial parameters must be finite")] 28 | InvalidInitialParameters, 29 | } 30 | -------------------------------------------------------------------------------- 
/algorithms/linfa-logistic/src/float.rs: -------------------------------------------------------------------------------- 1 | use crate::argmin_param::ArgminParam; 2 | use argmin::core::ArgminFloat; 3 | use argmin_math::ArgminMul; 4 | use ndarray::{Dimension, Ix1, Ix2, NdFloat}; 5 | use num_traits::FromPrimitive; 6 | 7 | /// A Float trait that captures the requirements we need for the various 8 | /// places we use floats. These are basically imposed by NdArray and Argmin. 9 | pub trait Float: 10 | ArgminFloat 11 | + NdFloat 12 | + Default 13 | + Clone 14 | + FromPrimitive 15 | + ArgminMul, ArgminParam> 16 | + ArgminMul, ArgminParam> 17 | + linfa::Float 18 | { 19 | const POSITIVE_LABEL: Self; 20 | const NEGATIVE_LABEL: Self; 21 | } 22 | 23 | impl ArgminMul, ArgminParam> for f64 { 24 | fn mul(&self, other: &ArgminParam) -> ArgminParam { 25 | ArgminParam(&other.0 * *self) 26 | } 27 | } 28 | 29 | impl ArgminMul, ArgminParam> for f32 { 30 | fn mul(&self, other: &ArgminParam) -> ArgminParam { 31 | ArgminParam(&other.0 * *self) 32 | } 33 | } 34 | 35 | impl Float for f32 { 36 | const POSITIVE_LABEL: Self = 1.0; 37 | const NEGATIVE_LABEL: Self = -1.0; 38 | } 39 | 40 | impl Float for f64 { 41 | const POSITIVE_LABEL: Self = 1.0; 42 | const NEGATIVE_LABEL: Self = -1.0; 43 | } 44 | -------------------------------------------------------------------------------- /algorithms/linfa-nn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-nn" 3 | version = "0.7.1" 4 | authors = ["YuhanLiin "] 5 | edition = "2018" 6 | description = "A collection of nearest neighbour algorithms" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa/" 10 | readme = "README.md" 11 | 12 | keywords = ["nearest-neighbour", "machine-learning", "linfa"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 16 | default = [] 17 | serde = ["serde_crate", "ndarray/serde"] 18 | 19 
| [dependencies.serde_crate] 20 | package = "serde" 21 | optional = true 22 | version = "1.0" 23 | default-features = false 24 | features = ["std", "derive"] 25 | 26 | [dependencies] 27 | ndarray = { version = "0.15", features = ["approx"]} 28 | ndarray-stats = "0.5" 29 | num-traits = "0.2.0" 30 | noisy_float = "0.2.0" 31 | order-stat = "0.1.3" 32 | thiserror = "1.0" 33 | 34 | kdtree = "0.7.0" 35 | 36 | linfa = { version = "0.7.1", path = "../.." } 37 | 38 | [dev-dependencies] 39 | approx = "0.4" 40 | criterion = "0.4.0" 41 | rand_xoshiro = "0.6" 42 | ndarray-rand = "0.14" 43 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 44 | 45 | [[bench]] 46 | name = "nn" 47 | harness = false 48 | -------------------------------------------------------------------------------- /algorithms/linfa-nn/README.md: -------------------------------------------------------------------------------- 1 | # Nearest Neighbor 2 | 3 | `linfa-nn` provides a pure Rust implementation of nearest neighbor algorithms. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-nn` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | Nearest neighbor search (NNS), as a form of proximity search, is the optimization problem of finding the point in a given set that is closest (or most similar) to a given point. Closeness is typically expressed in terms of a dissimilarity function: the less similar the objects, the larger the function values. 10 | 11 | ## Current State 12 | 13 | linfa-nn currently provides the following implementations: 14 | - linear 15 | - balltree 16 | - KDTree 17 | 18 | 19 | ## License 20 | Dual-licensed to be compatible with the Rust project. 21 | 22 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 
23 | -------------------------------------------------------------------------------- /algorithms/linfa-nn/src/heap_elem.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::{Ordering, Reverse}; 2 | 3 | use linfa::Float; 4 | use noisy_float::{checkers::FiniteChecker, NoisyFloat}; 5 | 6 | #[derive(Debug, Clone)] 7 | pub(crate) struct HeapElem { 8 | pub(crate) dist: D, 9 | pub(crate) elem: T, 10 | } 11 | 12 | impl PartialEq for HeapElem { 13 | fn eq(&self, other: &Self) -> bool { 14 | self.dist.eq(&other.dist) 15 | } 16 | } 17 | impl Eq for HeapElem {} 18 | 19 | #[allow(clippy::non_canonical_partial_ord_impl)] 20 | impl PartialOrd for HeapElem { 21 | fn partial_cmp(&self, other: &Self) -> Option { 22 | self.dist.partial_cmp(&other.dist) 23 | } 24 | } 25 | 26 | impl Ord for HeapElem { 27 | fn cmp(&self, other: &Self) -> Ordering { 28 | self.dist.cmp(&other.dist) 29 | } 30 | } 31 | 32 | pub(crate) type MinHeapElem = HeapElem>, T>; 33 | 34 | impl MinHeapElem { 35 | pub(crate) fn new(dist: F, elem: T) -> Self { 36 | Self { 37 | dist: Reverse(NoisyFloat::new(dist)), 38 | elem, 39 | } 40 | } 41 | } 42 | 43 | pub(crate) type MaxHeapElem = HeapElem, T>; 44 | 45 | impl MaxHeapElem { 46 | pub(crate) fn new(dist: F, elem: T) -> Self { 47 | Self { 48 | dist: NoisyFloat::new(dist), 49 | elem, 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /algorithms/linfa-pls/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-pls" 3 | version = "0.7.1" 4 | edition = "2018" 5 | authors = ["relf "] 6 | description = "Partial Least Squares family methods" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa" 10 | readme = "README.md" 11 | 12 | keywords = ["pls", "machine-learning", "linfa", "supervised"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 
16 | default = [] 17 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 18 | serde = ["serde_crate", "ndarray/serde"] 19 | 20 | [dependencies.serde_crate] 21 | package = "serde" 22 | optional = true 23 | version = "1.0" 24 | default-features = false 25 | features = ["std", "derive"] 26 | 27 | [dependencies] 28 | ndarray = { version = "0.15" } 29 | linfa-linalg = { version = "0.1", default-features = false } 30 | ndarray-linalg = { version = "0.16", optional = true } 31 | ndarray-stats = "0.5" 32 | ndarray-rand = "0.14" 33 | num-traits = "0.2" 34 | paste = "1.0" 35 | thiserror = "1.0" 36 | linfa = { version = "0.7.1", path = "../.." } 37 | 38 | [dev-dependencies] 39 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 40 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 41 | "linnerud", 42 | ] } 43 | approx = "0.4" 44 | rand_xoshiro = "0.6" 45 | criterion = "0.4.0" 46 | statrs = "0.16.0" 47 | 48 | [[bench]] 49 | name = "pls" 50 | harness = false 51 | -------------------------------------------------------------------------------- /algorithms/linfa-pls/README.md: -------------------------------------------------------------------------------- 1 | # Partial Least Squares 2 | 3 | `linfa-pls` provides a pure Rust implementation of the partial least squares algorithm family. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-pls` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | `linfa-pls` currently provides an implementation of the following methods: 12 | 13 | - Partial Least Squares 14 | 15 | ## Examples 16 | 17 | There is an usage example in the `examples/` directory. 
To run it, use: 18 | 19 | ```bash 20 | $ cargo run --example pls_regression 21 | ``` 22 | 23 | ## BLAS/Lapack backend 24 | 25 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 26 | 27 | ## License 28 | Dual-licensed to be compatible with the Rust project. 29 | 30 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 31 | 32 | -------------------------------------------------------------------------------- /algorithms/linfa-pls/examples/pls_regression.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_pls::{PlsRegression, Result}; 3 | use ndarray::{Array, Array1, Array2}; 4 | use ndarray_rand::rand::SeedableRng; 5 | use ndarray_rand::rand_distr::StandardNormal; 6 | use ndarray_rand::RandomExt; 7 | use rand_xoshiro::Xoshiro256Plus; 8 | 9 | #[allow(clippy::many_single_char_names)] 10 | fn main() -> Result<()> { 11 | let n = 1000; 12 | let q = 3; 13 | let p = 10; 14 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 15 | 16 | // X shape (n, p) random 17 | let x: Array2 = Array::random_using((n, p), StandardNormal, &mut rng); 18 | 19 | // B shape (p, q) such that B[0, ..] = 1, B[1, ..] 
= 2; otherwise zero 20 | let mut b: Array2 = Array2::zeros((p, q)); 21 | b.row_mut(0).assign(&Array1::ones(q)); 22 | b.row_mut(1).assign(&Array1::from_elem(q, 2.)); 23 | 24 | // Y shape (n, q) such that yj = 1*x1 + 2*x2 + noise(Normal(5, 1)) 25 | let y = x.dot(&b) + Array::random_using((n, q), StandardNormal, &mut rng).mapv(|v: f64| v + 5.); 26 | 27 | let ds = Dataset::new(x, y); 28 | let pls = PlsRegression::params(3) 29 | .scale(true) 30 | .max_iterations(200) 31 | .fit(&ds)?; 32 | 33 | println!("True B (such that: Y = XB + noise)"); 34 | println!("{:?}", b); 35 | 36 | // PLS regression coefficients is an estimation of B 37 | println!("Estimated B"); 38 | println!("{:1.1}", pls.coefficients()); 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /algorithms/linfa-pls/src/errors.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "blas"))] 2 | use linfa_linalg::LinalgError; 3 | #[cfg(feature = "blas")] 4 | use ndarray_linalg::error::LinalgError; 5 | use thiserror::Error; 6 | pub type Result = std::result::Result; 7 | 8 | #[derive(Error, Debug)] 9 | pub enum PlsError { 10 | #[error("Number of samples should be greater than 1, got {0}")] 11 | NotEnoughSamplesError(usize), 12 | #[error("Number of components should be in [1, {upperbound}], got {actual}")] 13 | BadComponentNumberError { upperbound: usize, actual: usize }, 14 | #[error("The tolerance is should not be negative, NaN or inf but is {0}")] 15 | InvalidTolerance(f32), 16 | #[error("The maximal number of iterations should be positive")] 17 | ZeroMaxIter, 18 | #[error("Singular vector computation power method: max iterations ({0}) reached")] 19 | PowerMethodNotConvergedError(usize), 20 | #[error("Constant residual detected in power method")] 21 | PowerMethodConstantResidualError(), 22 | #[error(transparent)] 23 | LinalgError(#[from] LinalgError), 24 | #[error(transparent)] 25 | LinfaError(#[from] 
linfa::error::Error), 26 | #[error(transparent)] 27 | MinMaxError(#[from] ndarray_stats::errors::MinMaxError), 28 | } 29 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-preprocessing" 3 | version = "0.7.1" 4 | authors = ["Sauro98 "] 5 | 6 | description = "A Machine Learning framework for Rust" 7 | edition = "2018" 8 | license = "MIT OR Apache-2.0" 9 | 10 | repository = "https://github.com/rust-ml/linfa" 11 | readme = "README.md" 12 | 13 | keywords = ["machine-learning", "linfa", "ai", "ml", "preprocessing"] 14 | categories = ["algorithms", "mathematics", "science"] 15 | 16 | [features] 17 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 18 | serde = ["serde_crate", "ndarray/serde", "serde_regex"] 19 | 20 | [dependencies] 21 | linfa = { version = "0.7.1", path = "../.." } 22 | ndarray = { version = "0.15", features = ["approx"] } 23 | ndarray-linalg = { version = "0.16", optional = true } 24 | linfa-linalg = { version = "0.1", default-features = false } 25 | ndarray-stats = "0.5" 26 | thiserror = "1.0" 27 | approx = { version = "0.4" } 28 | ndarray-rand = { version = "0.14" } 29 | unicode-normalization = "0.1.8" 30 | regex = "1.4.5" 31 | encoding = "0.2" 32 | sprs = { version = "=0.11.1", default-features = false } 33 | 34 | serde_regex = { version = "1.1", optional = true } 35 | itertools = "0.14.0" 36 | 37 | [dependencies.serde_crate] 38 | package = "serde" 39 | optional = true 40 | version = "1.0" 41 | default-features = false 42 | features = ["std", "derive"] 43 | 44 | [dev-dependencies] 45 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 46 | "diabetes", 47 | "winequality", 48 | "generate" 49 | ] } 50 | linfa-bayes = { version = "0.7.1", path = "../linfa-bayes" } 51 | iai = "0.1" 52 | curl = "0.4.35" 53 | flate2 = "1.0.20" 54 | tar = 
"0.4.33" 55 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 56 | criterion = "0.4.0" 57 | statrs = "0.16.0" 58 | 59 | [[bench]] 60 | name = "vectorizer_bench" 61 | harness = false 62 | 63 | [[bench]] 64 | name = "linear_scaler_bench" 65 | harness = false 66 | 67 | [[bench]] 68 | name = "whitening_bench" 69 | harness = false 70 | 71 | [[bench]] 72 | name = "norm_scaler_bench" 73 | harness = false 74 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing 2 | ## The Big Picture 3 | 4 | `linfa-preprocessing` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 5 | 6 | ## Current state 7 | `linfa-preprocessing` provides a pure Rust implementation of: 8 | * Standard scaling 9 | * Min-max scaling 10 | * Max Abs Scaling 11 | * Normalization 12 | * Count vectorization 13 | * TfIdf vectorization 14 | * Whitening 15 | 16 | ## Examples 17 | 18 | There are various usage examples in the `examples/` directory. To run, use: 19 | 20 | ```bash 21 | $ cargo run --release --example count_vectorization 22 | ``` 23 | ```bash 24 | $ cargo run --release --example tfidf_vectorization 25 | ``` 26 | ```bash 27 | $ cargo run --release --example scaling 28 | ``` 29 | ```bash 30 | $ cargo run --release --example whitening 31 | ``` 32 | 33 | ## BLAS/Lapack backend 34 | 35 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 36 | 37 | ## License 38 | Dual-licensed to be compatible with the Rust project. 39 | 40 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 
41 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/benches/linear_scaler_bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::traits::{Fit, Transformer}; 4 | use linfa_datasets::generate::make_dataset; 5 | use linfa_preprocessing::linear_scaling::LinearScaler; 6 | use statrs::distribution::{DiscreteUniform, Laplace}; 7 | 8 | fn bench(c: &mut Criterion) { 9 | let mut benchmark = c.benchmark_group("liner scaler"); 10 | config::set_default_benchmark_configs(&mut benchmark); 11 | let size = 10000; 12 | let feat_distr = Laplace::new(0.5, 5.).unwrap(); 13 | let target_distr = DiscreteUniform::new(0, 5).unwrap(); 14 | 15 | for (liner_scaler, fn_name) in [ 16 | (LinearScaler::standard(), "standard scaler"), 17 | (LinearScaler::min_max(), "min max scaler"), 18 | (LinearScaler::max_abs(), "max abs scaler"), 19 | ] { 20 | for nfeatures in (10..100).step_by(10) { 21 | let dataset = make_dataset(size, nfeatures, 1, feat_distr, target_distr); 22 | benchmark.bench_function( 23 | BenchmarkId::new(fn_name, format!("{}x{}", nfeatures, size)), 24 | |bencher| { 25 | bencher.iter(|| { 26 | liner_scaler 27 | .fit(black_box(&dataset)) 28 | .unwrap() 29 | .transform(black_box(dataset.view())); 30 | }); 31 | }, 32 | ); 33 | } 34 | } 35 | } 36 | 37 | #[cfg(not(target_os = "windows"))] 38 | criterion_group! 
{ 39 | name = benches; 40 | config = config::get_default_profiling_configs(); 41 | targets = bench 42 | } 43 | #[cfg(target_os = "windows")] 44 | criterion_group!(benches, bench); 45 | 46 | criterion_main!(benches); 47 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/benches/norm_scaler_bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::traits::Transformer; 4 | use linfa_datasets::generate::make_dataset; 5 | use linfa_preprocessing::norm_scaling::NormScaler; 6 | use statrs::distribution::{DiscreteUniform, Laplace}; 7 | 8 | fn bench(c: &mut Criterion) { 9 | let mut benchmark = c.benchmark_group("norm scaler"); 10 | config::set_default_benchmark_configs(&mut benchmark); 11 | let size = 10000; 12 | let feat_distr = Laplace::new(0.5, 5.).unwrap(); 13 | let target_distr = DiscreteUniform::new(0, 5).unwrap(); 14 | 15 | for (scaler, fn_name) in [ 16 | (NormScaler::l2(), "l2 scaler"), 17 | (NormScaler::l1(), "l1 scaler"), 18 | (NormScaler::max(), "max scaler"), 19 | ] { 20 | for nfeatures in (10..100).step_by(10) { 21 | let dataset = make_dataset(size, nfeatures, 1, feat_distr, target_distr); 22 | benchmark.bench_function( 23 | BenchmarkId::new(fn_name, format!("{}x{}", nfeatures, size)), 24 | |bencher| { 25 | bencher.iter(|| { 26 | scaler.transform(black_box(dataset.view())); 27 | }); 28 | }, 29 | ); 30 | } 31 | } 32 | } 33 | 34 | #[cfg(not(target_os = "windows"))] 35 | criterion_group! 
{ 36 | name = benches; 37 | config = config::get_default_profiling_configs(); 38 | targets = bench 39 | } 40 | #[cfg(target_os = "windows")] 41 | criterion_group!(benches, bench); 42 | 43 | criterion_main!(benches); 44 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/benches/whitening_bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::traits::Fit; 4 | use linfa::traits::Transformer; 5 | use linfa_datasets::generate::make_dataset; 6 | use linfa_preprocessing::whitening::Whitener; 7 | use statrs::distribution::{DiscreteUniform, Laplace}; 8 | 9 | fn bench(c: &mut Criterion) { 10 | let mut benchmark = c.benchmark_group("whitening"); 11 | config::set_default_benchmark_configs(&mut benchmark); 12 | let size = 10000; 13 | let feat_distr = Laplace::new(0.5, 5.).unwrap(); 14 | let target_distr = DiscreteUniform::new(0, 5).unwrap(); 15 | 16 | for (whitener, fn_name) in [ 17 | (Whitener::cholesky(), "cholesky"), 18 | (Whitener::zca(), "zca"), 19 | (Whitener::pca(), "pca"), 20 | ] { 21 | for nfeatures in (10..100).step_by(10) { 22 | let dataset = make_dataset(size, nfeatures, 1, feat_distr, target_distr); 23 | benchmark.bench_function( 24 | BenchmarkId::new(fn_name, format!("{}x{}", nfeatures, size)), 25 | |bencher| { 26 | bencher.iter(|| { 27 | whitener 28 | .fit(black_box(&dataset)) 29 | .unwrap() 30 | .transform(black_box(dataset.view())); 31 | }); 32 | }, 33 | ); 34 | } 35 | } 36 | } 37 | 38 | #[cfg(not(target_os = "windows"))] 39 | criterion_group! 
{ 40 | name = benches; 41 | config = config::get_default_profiling_configs(); 42 | targets = bench 43 | } 44 | #[cfg(target_os = "windows")] 45 | criterion_group!(benches, bench); 46 | 47 | criterion_main!(benches); 48 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/examples/scaling.rs: -------------------------------------------------------------------------------- 1 | use linfa::metrics::ToConfusionMatrix; 2 | use linfa::traits::{Fit, Predict, Transformer}; 3 | use linfa_bayes::GaussianNb; 4 | use linfa_preprocessing::linear_scaling::LinearScaler; 5 | 6 | fn main() { 7 | // Read in the dataset and convert continuous target into categorical 8 | let (train, valid) = linfa_datasets::winequality() 9 | .map_targets(|x| if *x > 6 { 1 } else { 0 }) 10 | .split_with_ratio(0.7); 11 | 12 | // Fit a standard scaler to the training set 13 | let scaler = LinearScaler::standard().fit(&train).unwrap(); 14 | 15 | // Scale training and validation sets according to the fitted scaler 16 | let train = scaler.transform(train); 17 | let valid = scaler.transform(valid); 18 | 19 | // Learn a naive bayes model from the training set 20 | let model = GaussianNb::params().fit(&train).unwrap(); 21 | 22 | // compute accuracies 23 | let train_acc = model 24 | .predict(&train) 25 | .confusion_matrix(&train) 26 | .unwrap() 27 | .accuracy(); 28 | let cm = model.predict(&valid).confusion_matrix(&valid).unwrap(); 29 | let valid_acc = cm.accuracy(); 30 | println!( 31 | "Scaled model training and validation accuracies: {} - {}", 32 | train_acc, valid_acc 33 | ); 34 | println!("{:?}", cm); 35 | } 36 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/examples/whitening.rs: -------------------------------------------------------------------------------- 1 | use linfa::metrics::ToConfusionMatrix; 2 | use linfa::traits::{Fit, Predict, Transformer}; 3 | use 
linfa_bayes::GaussianNb; 4 | use linfa_preprocessing::whitening::Whitener; 5 | 6 | fn main() { 7 | // Read in the dataset and convert continuous target into categorical 8 | let (train, valid) = linfa_datasets::winequality() 9 | .map_targets(|x| if *x > 6 { 1 } else { 0 }) 10 | .split_with_ratio(0.7); 11 | 12 | // Fit a standard scaler to the training set 13 | let scaler = Whitener::pca().fit(&train).unwrap(); 14 | 15 | // Scale training and validation sets according to the fitted scaler 16 | let train = scaler.transform(train); 17 | let valid = scaler.transform(valid); 18 | 19 | // Learn a naive bayes model from the training set 20 | let model = GaussianNb::params().fit(&train).unwrap(); 21 | 22 | // compute accuracies 23 | let train_acc = model 24 | .predict(&train) 25 | .confusion_matrix(&train) 26 | .unwrap() 27 | .accuracy(); 28 | let cm = model.predict(&valid).confusion_matrix(&valid).unwrap(); 29 | let valid_acc = cm.accuracy(); 30 | println!( 31 | "Whitened model training and validation accuracies: {} - {}", 32 | train_acc, valid_acc 33 | ); 34 | println!("{:?}", cm); 35 | } 36 | -------------------------------------------------------------------------------- /algorithms/linfa-preprocessing/src/error.rs: -------------------------------------------------------------------------------- 1 | //! 
Error definitions for preprocessing 2 | use thiserror::Error; 3 | pub type Result = std::result::Result; 4 | 5 | #[derive(Error, Debug)] 6 | #[non_exhaustive] 7 | pub enum PreprocessingError { 8 | #[error("wrong measure ({0}) for scaler: {1}")] 9 | WrongMeasureForScaler(String, String), 10 | #[error("subsamples greater than total samples: {0} > {1}")] 11 | TooManySubsamples(usize, usize), 12 | #[error("not enough samples")] 13 | NotEnoughSamples, 14 | #[error("not a valid float")] 15 | InvalidFloat, 16 | #[error("minimum value for MinMax scaler cannot be greater than the maximum")] 17 | TokenizerNotSet, 18 | #[error("Tokenizer must be defined after deserializing CountVectorizer by calling force_tokenizer_redefinition")] 19 | FlippedMinMaxRange, 20 | #[error("n_gram boundaries cannot be zero (min = {0}, max = {1})")] 21 | InvalidNGramBoundaries(usize, usize), 22 | #[error("n_gram min boundary cannot be greater than max boundary (min = {0}, max = {1})")] 23 | FlippedNGramBoundaries(usize, usize), 24 | #[error("document frequencies have to be between 0 and 1 (min = {0}, max = {1})")] 25 | InvalidDocumentFrequencies(f32, f32), 26 | #[error("min document frequency cannot be greater than max document frequency (min = {0}, max = {1})")] 27 | FlippedDocumentFrequencies(f32, f32), 28 | #[error(transparent)] 29 | RegexError(#[from] regex::Error), 30 | #[error(transparent)] 31 | IoError(#[from] std::io::Error), 32 | #[error("Encoding error {0}")] 33 | EncodingError(std::borrow::Cow<'static, str>), 34 | #[cfg(feature = "blas")] 35 | #[error(transparent)] 36 | LinalgBlasError(#[from] ndarray_linalg::error::LinalgError), 37 | #[error(transparent)] 38 | LinalgError(#[from] linfa_linalg::LinalgError), 39 | #[error(transparent)] 40 | NdarrayStatsEmptyError(#[from] ndarray_stats::errors::EmptyInput), 41 | #[error(transparent)] 42 | LinfaError(#[from] linfa::error::Error), 43 | } 44 | -------------------------------------------------------------------------------- 
/algorithms/linfa-preprocessing/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Preprocessing 2 | //! ## The Big Picture 3 | //! 4 | //! `linfa-preprocessing` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 5 | //! 6 | //! ## Current state 7 | //! `linfa-preprocessing` provides a pure Rust implementation of: 8 | //! * Standard scaling 9 | //! * Min-max scaling 10 | //! * Max Abs Scaling 11 | //! * Normalization (l1, l2 and max norm) 12 | //! * Count vectorization 13 | //! * Term frequency - inverse document frequency count vectorization 14 | //! * Whitening 15 | 16 | mod countgrams; 17 | pub mod error; 18 | mod helpers; 19 | pub mod linear_scaling; 20 | pub mod norm_scaling; 21 | pub mod tf_idf_vectorization; 22 | pub mod whitening; 23 | 24 | pub use countgrams::{ 25 | CountVectorizer, CountVectorizerParams, CountVectorizerValidParams, Tokenizer, 26 | }; 27 | pub use error::{PreprocessingError, Result}; 28 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-reduction" 3 | version = "0.7.1" 4 | authors = [ 5 | "Lorenz Schmidt ", 6 | "Gabriel Bathie ", 7 | ] 8 | description = "A collection of dimensionality reduction techniques" 9 | edition = "2018" 10 | license = "MIT OR Apache-2.0" 11 | 12 | repository = "https://github.com/rust-ml/linfa" 13 | readme = "README.md" 14 | 15 | keywords = [ 16 | "reduction", 17 | "machine-learning", 18 | "linfa", 19 | "spectral", 20 | "unsupervised", 21 | ] 22 | categories = ["algorithms", "mathematics", "science"] 23 | 24 | [features] 25 | default = [] 26 | blas = ["ndarray-linalg", "linfa/ndarray-linalg"] 27 | serde = ["serde_crate", "ndarray/serde"] 28 | 29 
| [dependencies.serde_crate] 30 | package = "serde" 31 | optional = true 32 | version = "1.0" 33 | default-features = false 34 | features = ["std", "derive"] 35 | 36 | [dependencies] 37 | ndarray = { version = "0.15", features = ["approx"] } 38 | linfa-linalg = { version = "0.1" } 39 | ndarray-linalg = { version = "0.16", optional = true } 40 | ndarray-rand = "0.14" 41 | num-traits = "0.2" 42 | thiserror = "1.0" 43 | rand = { version = "0.8", features = ["small_rng"] } 44 | 45 | linfa = { version = "0.7.1", path = "../.." } 46 | linfa-kernel = { version = "0.7.1", path = "../linfa-kernel" } 47 | sprs = "=0.11.1" 48 | rand_xoshiro = "0.6.0" 49 | 50 | [dev-dependencies] 51 | ndarray-npy = { version = "0.8", default-features = false } 52 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 53 | "iris", 54 | "generate", 55 | ] } 56 | approx = { version = "0.4" } 57 | mnist = { version = "0.6.0", features = ["download"] } 58 | linfa-trees = { version = "0.7.1", path = "../linfa-trees" } 59 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/README.md: -------------------------------------------------------------------------------- 1 | # Dimensional Reduction 2 | 3 | `linfa-reduction` aims to provide pure Rust implementations of dimensional reduction algorithms. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-reduction` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | ## Current state 10 | 11 | `linfa-reduction` currently provides an implementation of the following dimensional reduction methods: 12 | - Diffusion Mapping 13 | - Principal Component Analysis (PCA) 14 | - Gaussian random projections 15 | - Sparse random projections 16 | 17 | ## Examples 18 | 19 | There is an usage example in the `examples/` directory. 
To run, use: 20 | 21 | ```bash 22 | $ cargo run --release --example diffusion_map 23 | $ cargo run --release --example pca 24 | $ cargo run --release --example gaussian_projection 25 | $ cargo run --release --example sparse_projection 26 | ``` 27 | 28 | ## BLAS/LAPACK backend 29 | 30 | See [this section](../../README.md#blaslapack-backend) to enable an external BLAS/LAPACK backend. 31 | 32 | ## License 33 | Dual-licensed to be compatible with the Rust project. 34 | 35 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 36 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/examples/diffusion_map.rs: -------------------------------------------------------------------------------- 1 | use linfa::traits::Transformer; 2 | use linfa_kernel::{Kernel, KernelMethod, KernelType}; 3 | use linfa_reduction::utils::generate_convoluted_rings2d; 4 | use linfa_reduction::{DiffusionMap, Result}; 5 | 6 | use ndarray_npy::write_npy; 7 | use rand::{rngs::SmallRng, SeedableRng}; 8 | 9 | fn main() -> Result<()> { 10 | // Our random number generator, seeded for reproducibility 11 | let mut rng = SmallRng::seed_from_u64(42); 12 | 13 | // For each our expected centroids, generate `n` data points around it (a "blob") 14 | let n = 500; 15 | 16 | // generate three convoluted rings 17 | let dataset = 18 | generate_convoluted_rings2d(&[(0.0, 3.0), (10.0, 13.0), (20.0, 23.0)], n, &mut rng); 19 | 20 | // generate sparse polynomial kernel with k = 14, c = 5 and d = 2 21 | let kernel = Kernel::params() 22 | //.method(KernelMethod::Polynomial(5.0, 2.0)) 23 | .kind(KernelType::Sparse(15)) 24 | .method(KernelMethod::Gaussian(2.0)) 25 | //.kind(KernelType::Dense) 26 | .transform(dataset.view()); 27 | 28 | let embedding = DiffusionMap::<f64>::params(2).steps(1).transform(&kernel)?; 29 | 30 | // get embedding 31 | let embedding =
embedding.embedding(); 32 | 33 | // Save to disk our dataset (and the cluster label assigned to each observation) 34 | // We use the `npy` format for compatibility with NumPy 35 | write_npy("diffusion_map_dataset.npy", &dataset).expect("Failed to write .npy file"); 36 | write_npy("diffusion_map_embedding.npy", embedding).expect("Failed to write .npy file"); 37 | 38 | Ok(()) 39 | } 40 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/examples/gaussian_projection.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, time::Instant}; 2 | 3 | use linfa::prelude::*; 4 | use linfa_reduction::random_projection::GaussianRandomProjection; 5 | use linfa_trees::{DecisionTree, SplitQuality}; 6 | 7 | use mnist::{MnistBuilder, NormalizedMnist}; 8 | use ndarray::{Array1, Array2}; 9 | use rand::SeedableRng; 10 | use rand_xoshiro::Xoshiro256Plus; 11 | 12 | /// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. 13 | fn main() -> Result<(), Box<dyn Error>> { 14 | // Parameters 15 | let train_sz = 10_000usize; 16 | let test_sz = 1_000usize; 17 | let reduced_dim = 100; 18 | let rng = Xoshiro256Plus::seed_from_u64(42); 19 | 20 | let NormalizedMnist { 21 | trn_img, 22 | trn_lbl, 23 | tst_img, 24 | tst_lbl, 25 | ..
26 | } = MnistBuilder::new() 27 | .label_format_digit() 28 | .training_set_length(train_sz as u32) 29 | .test_set_length(test_sz as u32) 30 | .download_and_extract() 31 | .finalize() 32 | .normalize(); 33 | 34 | let train_data = Array2::from_shape_vec((train_sz, 28 * 28), trn_img)?; 35 | let train_labels: Array1<usize> = 36 | Array1::from_shape_vec(train_sz, trn_lbl)?.map(|x| *x as usize); 37 | let train_dataset = Dataset::new(train_data, train_labels); 38 | 39 | let test_data = Array2::from_shape_vec((test_sz, 28 * 28), tst_img)?; 40 | let test_labels: Array1<usize> = Array1::from_shape_vec(test_sz, tst_lbl)?.map(|x| *x as usize); 41 | 42 | let params = DecisionTree::params() 43 | .split_quality(SplitQuality::Gini) 44 | .max_depth(Some(10)); 45 | 46 | println!("Training non-reduced model..."); 47 | let start = Instant::now(); 48 | let model: DecisionTree<f32, usize> = params.fit(&train_dataset)?; 49 | 50 | let end = start.elapsed(); 51 | let pred_y = model.predict(&test_data); 52 | let cm = pred_y.confusion_matrix(&test_labels)?; 53 | println!("Non-reduced model precision: {}%", 100.0 * cm.accuracy()); 54 | println!("Training time: {:.2}s\n", end.as_secs_f32()); 55 | 56 | println!("Training reduced model..."); 57 | let start = Instant::now(); 58 | // Compute the random projection and train the model on the reduced dataset.
59 | let proj = GaussianRandomProjection::<f32>::params_with_rng(rng) 60 | .target_dim(reduced_dim) 61 | .fit(&train_dataset)?; 62 | let reduced_train_ds = proj.transform(&train_dataset); 63 | let reduced_test_data = proj.transform(&test_data); 64 | let model_reduced: DecisionTree<f32, usize> = params.fit(&reduced_train_ds)?; 65 | 66 | let end = start.elapsed(); 67 | let pred_reduced = model_reduced.predict(&reduced_test_data); 68 | let cm_reduced = pred_reduced.confusion_matrix(&test_labels)?; 69 | println!( 70 | "Reduced model precision: {}%", 71 | 100.0 * cm_reduced.accuracy() 72 | ); 73 | println!("Reduction + training time: {:.2}s", end.as_secs_f32()); 74 | 75 | Ok(()) 76 | } 77 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/examples/pca.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_datasets::generate; 3 | use linfa_reduction::Pca; 4 | 5 | use ndarray::array; 6 | use ndarray_npy::write_npy; 7 | use rand::{rngs::SmallRng, SeedableRng}; 8 | 9 | // A routine PCA task: build a synthetic dataset, fit the algorithm on it 10 | // and save both training data and predictions to disk.
11 | fn main() { 12 | // Our random number generator, seeded for reproducibility 13 | let mut rng = SmallRng::seed_from_u64(42); 14 | 15 | // For each our expected centroids, generate `n` data points around it (a "blob") 16 | let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],]; 17 | let n = 10; 18 | let dataset = Dataset::from(generate::blobs(n, &expected_centroids, &mut rng)); 19 | 20 | let embedding: Pca = Pca::params(1).fit(&dataset).unwrap(); 21 | let embedding = embedding.predict(&dataset); 22 | 23 | dbg!(&embedding); 24 | 25 | // Save to disk our dataset (and the cluster label assigned to each observation) 26 | // We use the `npy` format for compatibility with NumPy 27 | write_npy("pca_dataset.npy", &dataset.records().view()).expect("Failed to write .npy file"); 28 | write_npy("pca_embedding.npy", &embedding).expect("Failed to write .npy file"); 29 | } 30 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/examples/sparse_projection.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, time::Instant}; 2 | 3 | use linfa::prelude::*; 4 | use linfa_reduction::random_projection::SparseRandomProjection; 5 | use linfa_trees::{DecisionTree, SplitQuality}; 6 | 7 | use mnist::{MnistBuilder, NormalizedMnist}; 8 | use ndarray::{Array1, Array2}; 9 | use rand::SeedableRng; 10 | use rand_xoshiro::Xoshiro256Plus; 11 | 12 | /// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. 13 | fn main() -> Result<(), Box> { 14 | // Parameters 15 | let train_sz = 10_000usize; 16 | let test_sz = 1_000usize; 17 | let reduced_dim = 100; 18 | let rng = Xoshiro256Plus::seed_from_u64(42); 19 | 20 | let NormalizedMnist { 21 | trn_img, 22 | trn_lbl, 23 | tst_img, 24 | tst_lbl, 25 | .. 
26 | } = MnistBuilder::new() 27 | .label_format_digit() 28 | .training_set_length(train_sz as u32) 29 | .test_set_length(test_sz as u32) 30 | .download_and_extract() 31 | .finalize() 32 | .normalize(); 33 | 34 | let train_data = Array2::from_shape_vec((train_sz, 28 * 28), trn_img)?; 35 | let train_labels: Array1<usize> = 36 | Array1::from_shape_vec(train_sz, trn_lbl)?.map(|x| *x as usize); 37 | let train_dataset = Dataset::new(train_data, train_labels); 38 | 39 | let test_data = Array2::from_shape_vec((test_sz, 28 * 28), tst_img)?; 40 | let test_labels: Array1<usize> = Array1::from_shape_vec(test_sz, tst_lbl)?.map(|x| *x as usize); 41 | 42 | let params = DecisionTree::params() 43 | .split_quality(SplitQuality::Gini) 44 | .max_depth(Some(10)); 45 | 46 | println!("Training non-reduced model..."); 47 | let start = Instant::now(); 48 | let model: DecisionTree<f32, usize> = params.fit(&train_dataset)?; 49 | 50 | let end = start.elapsed(); 51 | let pred_y = model.predict(&test_data); 52 | let cm = pred_y.confusion_matrix(&test_labels)?; 53 | println!("Non-reduced model precision: {}%", 100.0 * cm.accuracy()); 54 | println!("Training time: {:.2}s\n", end.as_secs_f32()); 55 | 56 | println!("Training reduced model..."); 57 | let start = Instant::now(); 58 | // Compute the random projection and train the model on the reduced dataset.
59 | let proj = SparseRandomProjection::<f32>::params_with_rng(rng) 60 | .target_dim(reduced_dim) 61 | .fit(&train_dataset)?; 62 | let reduced_train_ds = proj.transform(&train_dataset); 63 | let reduced_test_data = proj.transform(&test_data); 64 | let model_reduced: DecisionTree<f32, usize> = params.fit(&reduced_train_ds)?; 65 | 66 | let end = start.elapsed(); 67 | let pred_reduced = model_reduced.predict(&reduced_test_data); 68 | let cm_reduced = pred_reduced.confusion_matrix(&test_labels)?; 69 | println!( 70 | "Reduced model precision: {}%", 71 | 100.0 * cm_reduced.accuracy() 72 | ); 73 | println!("Reduction + training time: {:.2}s", end.as_secs_f32()); 74 | 75 | Ok(()) 76 | } 77 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/src/diffusion_map/mod.rs: -------------------------------------------------------------------------------- 1 | //! Diffusion Map 2 | //! 3 | //! The diffusion map computes an embedding of the data by applying PCA on the diffusion operator 4 | //! of the data. It transforms the data along the direction of the largest diffusion flow and is therefore 5 | //! a non-linear dimensionality reduction technique. A normalized kernel describes the high dimensional 6 | //! diffusion graph with the (i, j) entry the probability that a diffusion happens from point i to 7 | //! j. 8 | //!
9 | mod algorithms; 10 | mod hyperparams; 11 | 12 | pub use algorithms::*; 13 | pub use hyperparams::*; 14 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | pub type Result = std::result::Result; 4 | 5 | #[derive(Error, Debug)] 6 | #[non_exhaustive] 7 | pub enum ReductionError { 8 | #[error("At least 1 sample needed")] 9 | NotEnoughSamples, 10 | #[error("embedding dimension smaller {0} than feature dimension")] 11 | EmbeddingTooSmall(usize), 12 | #[error("Number of steps zero in diffusion map operator")] 13 | StepsZero, 14 | #[cfg(feature = "blas")] 15 | #[error(transparent)] 16 | LinalgBlasError(#[from] ndarray_linalg::error::LinalgError), 17 | #[error(transparent)] 18 | LinalgError(#[from] linfa_linalg::LinalgError), 19 | #[error(transparent)] 20 | LinfaError(#[from] linfa::error::Error), 21 | #[error(transparent)] 22 | NdarrayRandError(#[from] ndarray_rand::rand_distr::NormalError), 23 | #[error("Precision parameter must be in the interval (0; 1)")] 24 | InvalidPrecision, 25 | #[error("Target dimension of the projection must be positive")] 26 | NonPositiveEmbeddingSize, 27 | #[error("Target dimension {0} is larger than the number of features {1}.")] 28 | DimensionIncrease(usize, usize), 29 | } 30 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | 3 | #[macro_use] 4 | extern crate ndarray; 5 | 6 | mod diffusion_map; 7 | mod error; 8 | mod pca; 9 | pub mod random_projection; 10 | pub mod utils; 11 | 12 | pub use diffusion_map::{DiffusionMap, DiffusionMapParams, DiffusionMapValidParams}; 13 | pub use error::{ReductionError, Result}; 14 | pub use pca::{Pca, PcaParams}; 15 | 
-------------------------------------------------------------------------------- /algorithms/linfa-reduction/src/random_projection/common.rs: -------------------------------------------------------------------------------- 1 | /// Compute a safe dimension for a projection with precision `eps`, 2 | /// using the Johnson-Lindenstrauss Lemma. 3 | /// 4 | /// References: 5 | /// - [D. Achlioptas, JCSS](https://www.sciencedirect.com/science/article/pii/S0022000003000254) 6 | /// - [Li et al., SIGKDD'06](https://hastie.su.domains/Papers/Ping/KDD06_rp.pdf) 7 | pub(crate) fn johnson_lindenstrauss_min_dim(n_samples: usize, eps: f64) -> usize { 8 | let log_samples = (n_samples as f64).ln(); 9 | let value = 4. * log_samples / (eps.powi(2) / 2. - eps.powi(3) / 3.); 10 | value as usize 11 | } 12 | 13 | #[cfg(test)] 14 | mod tests { 15 | use super::*; 16 | 17 | #[test] 18 | /// Test against values computed by the scikit-learn implementation 19 | /// of `johnson_lindenstrauss_min_dim`. 20 | fn test_johnson_lindenstrauss() { 21 | assert_eq!(johnson_lindenstrauss_min_dim(100, 0.05), 15244); 22 | assert_eq!(johnson_lindenstrauss_min_dim(100, 0.1), 3947); 23 | assert_eq!(johnson_lindenstrauss_min_dim(100, 0.2), 1062); 24 | assert_eq!(johnson_lindenstrauss_min_dim(100, 0.5), 221); 25 | assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.05), 22867); 26 | assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.1), 5920); 27 | assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.2), 1594); 28 | assert_eq!(johnson_lindenstrauss_min_dim(1000, 0.5), 331); 29 | assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.05), 28194); 30 | assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.1), 7300); 31 | assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.2), 1965); 32 | assert_eq!(johnson_lindenstrauss_min_dim(5000, 0.5), 408); 33 | assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.05), 30489); 34 | assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.1), 7894); 35 | assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.2),
2125); 36 | assert_eq!(johnson_lindenstrauss_min_dim(10000, 0.5), 442); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /algorithms/linfa-reduction/src/utils.rs: -------------------------------------------------------------------------------- 1 | use ndarray::Array2; 2 | use ndarray_rand::rand::Rng; 3 | use num_traits::float::FloatConst; 4 | 5 | /// Generates a three dimension swiss roll, centered at the origin with height `height` and 6 | /// outwards speed `speed` 7 | pub fn generate_swissroll( 8 | height: f64, 9 | speed: f64, 10 | n_points: usize, 11 | rng: &mut impl Rng, 12 | ) -> Array2 { 13 | let mut roll: Array2 = Array2::zeros((n_points, 3)); 14 | 15 | for i in 0..n_points { 16 | let z = rng.gen_range(0.0..height); 17 | let phi: f64 = rng.gen_range(0.0..10.0); 18 | //let offset: f64 = rng.gen_range(-0.5..0.5); 19 | let offset = 0.0; 20 | 21 | let x = speed * phi * phi.cos() + offset; 22 | let y = speed * phi * phi.sin() + offset; 23 | 24 | roll[(i, 0)] = x; 25 | roll[(i, 1)] = y; 26 | roll[(i, 2)] = z; 27 | } 28 | roll 29 | } 30 | 31 | pub fn generate_convoluted_rings( 32 | rings: &[(f64, f64)], 33 | n_points: usize, 34 | rng: &mut impl Rng, 35 | ) -> Array2 { 36 | let n_points = (n_points as f32 / rings.len() as f32).ceil() as usize; 37 | let mut array = Array2::zeros((n_points * rings.len(), 3)); 38 | 39 | for (n, (start, end)) in rings.iter().enumerate() { 40 | // inner circle 41 | for i in 0..n_points { 42 | let r: f64 = rng.gen_range(*start..*end); 43 | let phi: f64 = rng.gen_range(0.0..(f64::PI() * 2.0)); 44 | let theta: f64 = rng.gen_range(0.0..(f64::PI() * 2.0)); 45 | 46 | let x = theta.sin() * phi.cos() * r; 47 | let y = theta.sin() * phi.sin() * r; 48 | let z = theta.cos() * r; 49 | 50 | array[(n * n_points + i, 0)] = x; 51 | array[(n * n_points + i, 1)] = y; 52 | array[(n * n_points + i, 2)] = z; 53 | } 54 | } 55 | 56 | array 57 | } 58 | 59 | pub fn generate_convoluted_rings2d( 60 | rings: 
&[(f64, f64)], 61 | n_points: usize, 62 | rng: &mut impl Rng, 63 | ) -> Array2 { 64 | let n_points = (n_points as f32 / rings.len() as f32).ceil() as usize; 65 | let mut array = Array2::zeros((n_points * rings.len(), 2)); 66 | 67 | for (n, (start, end)) in rings.iter().enumerate() { 68 | // inner circle 69 | for i in 0..n_points { 70 | let r: f64 = rng.gen_range(*start..*end); 71 | let phi: f64 = rng.gen_range(0.0..(f64::PI() * 2.0)); 72 | 73 | let x = phi.cos() * r; 74 | let y = phi.sin() * r; 75 | 76 | array[(n * n_points + i, 0)] = x; 77 | array[(n * n_points + i, 1)] = y; 78 | } 79 | } 80 | 81 | array 82 | } 83 | -------------------------------------------------------------------------------- /algorithms/linfa-svm/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-svm" 3 | version = "0.7.2" 4 | edition = "2018" 5 | authors = ["Lorenz Schmidt "] 6 | description = "Support Vector Machines" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa" 10 | readme = "README.md" 11 | 12 | keywords = ["svm", "machine-learning", "linfa", "supervised"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 16 | default = [] 17 | serde = ["serde_crate", "ndarray/serde", "linfa-kernel/serde"] 18 | 19 | [dependencies.serde_crate] 20 | package = "serde" 21 | optional = true 22 | version = "1.0" 23 | default-features = false 24 | features = ["std", "derive"] 25 | 26 | [dependencies] 27 | ndarray = { version = "0.15" } 28 | ndarray-rand = "0.14" 29 | num-traits = "0.2" 30 | thiserror = "1.0" 31 | 32 | linfa = { version = "0.7.1", path = "../.." 
} 33 | linfa-kernel = { version = "0.7.1", path = "../linfa-kernel" } 34 | 35 | [dev-dependencies] 36 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = [ 37 | "winequality", 38 | "diabetes", 39 | ] } 40 | rand_xoshiro = "0.6" 41 | approx = "0.4" 42 | -------------------------------------------------------------------------------- /algorithms/linfa-svm/README.md: -------------------------------------------------------------------------------- 1 | # Support Vector Machines 2 | 3 | `linfa-svm` provides a pure Rust implementation for support vector machines. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-svm` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 8 | 9 | Support Vector Machines are one major branch of machine learning models and offer classification or regression analysis of labeled datasets. They seek a discriminant, which seperates the data in an optimal way, e.g. have the fewest numbers of miss-classifications and maximizes the margin between positive and negative classes. A support vector contributes to the discriminant and is therefore important for the classification/regression task. The balance between the number of support vectors and model performance can be controlled with hyperparameters. 10 | 11 | ## Current State 12 | 13 | linfa-svm currently provides an implementation of SVM with Sequential Minimal Optimization: 14 | - Support Vector Classification with C/Nu/one-class 15 | - Support Vector Regression with Epsilon/Nu 16 | 17 | 18 | ## Examples 19 | 20 | There is an usage example in the `examples/` directory. To run, use: 21 | 22 | ```bash 23 | $ cargo run --release --example winequality 24 | ``` 25 | 26 | ## License 27 | Dual-licensed to be compatible with the Rust project. 28 | 29 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. 
This file may not be copied, modified, or distributed except according to those terms. -------------------------------------------------------------------------------- /algorithms/linfa-svm/examples/noisy_sin_svr.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_svm::{error::Result, Svm}; 3 | use ndarray::Array1; 4 | use ndarray_rand::{ 5 | rand::{Rng, SeedableRng}, 6 | rand_distr::Uniform, 7 | }; 8 | use rand_xoshiro::Xoshiro256Plus; 9 | 10 | /// Example inspired by https://scikit-learn.org/stable/auto_examples/svm/plot_svm_regression.html 11 | fn main() -> Result<()> { 12 | let mut rng = Xoshiro256Plus::seed_from_u64(42); 13 | let range = Uniform::new(0., 5.); 14 | let mut x: Vec = (0..40).map(|_| rng.sample(range)).collect(); 15 | x.sort_by(|a, b| a.partial_cmp(b).unwrap()); 16 | let x = Array1::from_vec(x); 17 | 18 | let mut y = x.mapv(|v| v.sin()); 19 | 20 | // add some noise 21 | y.iter_mut() 22 | .enumerate() 23 | .filter(|(i, _)| i % 5 == 0) 24 | .for_each(|(_, y)| *y = 3. * (0.5 - rng.gen::())); 25 | 26 | let x = x.into_shape((40, 1)).unwrap(); 27 | let dataset = DatasetBase::new(x, y); 28 | let model = Svm::params() 29 | .c_svr(100., Some(0.1)) 30 | .gaussian_kernel(10.) 
31 | .fit(&dataset)?; 32 | 33 | println!("{}", model); 34 | 35 | let predicted = model.predict(&dataset); 36 | let err = predicted.mean_squared_error(&dataset).unwrap(); 37 | println!("err={}", err); 38 | 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /algorithms/linfa-svm/examples/winequality_multi_svm.rs: -------------------------------------------------------------------------------- 1 | use linfa::composing::MultiClassModel; 2 | use linfa::prelude::*; 3 | use linfa_svm::{error::Result, Svm}; 4 | 5 | fn main() -> Result<()> { 6 | let (train, valid) = linfa_datasets::winequality().split_with_ratio(0.9); 7 | 8 | println!( 9 | "Fit SVM classifier with #{} training points", 10 | train.nsamples() 11 | ); 12 | 13 | let params = Svm::<_, Pr>::params() 14 | //.pos_neg_weights(5000., 500.) 15 | .gaussian_kernel(30.0); 16 | 17 | let model = train 18 | .one_vs_all()? 19 | .into_iter() 20 | .map(|(l, x)| (l, params.fit(&x).unwrap())) 21 | .collect::>(); 22 | 23 | let pred = model.predict(&valid); 24 | 25 | // create a confusion matrix 26 | let cm = pred.confusion_matrix(&valid)?; 27 | 28 | // Print the confusion matrix 29 | println!("{:?}", cm); 30 | 31 | // Calculate the accuracy and Matthew Correlation Coefficient (cross-correlation between 32 | // predicted and targets) 33 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 34 | 35 | Ok(()) 36 | } 37 | -------------------------------------------------------------------------------- /algorithms/linfa-svm/examples/winequality_svm.rs: -------------------------------------------------------------------------------- 1 | use linfa::prelude::*; 2 | use linfa_svm::{error::Result, Svm}; 3 | 4 | fn main() -> Result<()> { 5 | // everything above 6.5 is considered a good wine 6 | let (train, valid) = linfa_datasets::winequality() 7 | .map_targets(|x| *x > 6) 8 | .split_with_ratio(0.9); 9 | 10 | println!( 11 | "Fit SVM classifier with #{} training points", 12 | 
train.nsamples() 13 | ); 14 | 15 | // fit a SVM with C value 7 and 0.6 for positive and negative classes 16 | let model = Svm::<_, bool>::params() 17 | .pos_neg_weights(50000., 5000.) 18 | .gaussian_kernel(80.0) 19 | .fit(&train)?; 20 | 21 | println!("{}", model); 22 | // A positive prediction indicates a good wine, a negative, a bad one 23 | fn tag_classes(x: &bool) -> String { 24 | if *x { 25 | "good".into() 26 | } else { 27 | "bad".into() 28 | } 29 | } 30 | 31 | // map targets for validation dataset 32 | let valid = valid.map_targets(tag_classes); 33 | 34 | // predict and map targets 35 | let pred = model.predict(&valid).map(tag_classes); 36 | 37 | // create a confusion matrix 38 | let cm = pred.confusion_matrix(&valid)?; 39 | 40 | // Print the confusion matrix, this will print a table with four entries. On the diagonal are 41 | // the number of true-positive and true-negative predictions, off the diagonal are 42 | // false-positive and false-negative 43 | println!("{:?}", cm); 44 | 45 | // Calculate the accuracy and Matthew Correlation Coefficient (cross-correlation between 46 | // predicted and targets) 47 | println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc()); 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /algorithms/linfa-svm/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | pub type Result = std::result::Result; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum SvmError { 7 | #[error("Invalid epsilon {0}")] 8 | InvalidEps(f32), 9 | #[error("Negative C value {0:?} (positive, negative samples")] 10 | InvalidC((f32, f32)), 11 | #[error("Nu should be in unit range, is {0}")] 12 | InvalidNu(f32), 13 | #[error("platt scaling failed")] 14 | Platt(#[from] linfa::composing::PlattError), 15 | #[error(transparent)] 16 | BaseCrate(#[from] linfa::Error), 17 | } 18 | 
-------------------------------------------------------------------------------- /algorithms/linfa-trees/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-trees" 3 | version = "0.7.1" 4 | edition = "2018" 5 | authors = ["Moss Ebeling "] 6 | description = "A collection of tree-based algorithms" 7 | license = "MIT OR Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ml/linfa" 10 | readme = "README.md" 11 | 12 | keywords = ["machine-learning", "linfa", "trees", "supervised"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [features] 16 | default = [] 17 | serde = ["serde_crate", "ndarray/serde"] 18 | 19 | [dependencies.serde_crate] 20 | package = "serde" 21 | optional = true 22 | version = "1.0" 23 | default-features = false 24 | features = ["std", "derive"] 25 | 26 | [dependencies] 27 | ndarray = { version = "0.15" , features = ["rayon", "approx"]} 28 | ndarray-rand = "0.14" 29 | 30 | linfa = { version = "0.7.1", path = "../.." } 31 | 32 | [dev-dependencies] 33 | rand = { version = "0.8", features = ["small_rng"] } 34 | criterion = "0.4.0" 35 | approx = "0.4" 36 | linfa-datasets = { version = "0.7.1", path = "../../datasets/", features = ["iris"] } 37 | linfa = { version = "0.7.1", path = "../..", features = ["benchmarks"] } 38 | 39 | [[bench]] 40 | name = "decision_tree" 41 | harness = false 42 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/README.md: -------------------------------------------------------------------------------- 1 | # Decision tree learning 2 | 3 | `linfa-trees` provides methods for decision tree learning algorithms. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-trees` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 
8 | 9 | Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. 10 | 11 | ## Current state 12 | 13 | `linfa-trees` currently provides an implementation of single tree fitting 14 | 15 | ## Examples 16 | 17 | There is an example in the `examples/` directory showing how to use decision trees. To run, use: 18 | 19 | ```bash 20 | $ cargo run --release --example decision_tree 21 | ``` 22 | 23 | This generates the following tree: 24 | 25 |

26 | 27 |

28 | 29 | ## License 30 | Dual-licensed to be compatible with the Rust project. 31 | 32 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 33 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/benches/decision_tree.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; 2 | use linfa::benchmarks::config; 3 | use linfa::prelude::*; 4 | use linfa_trees::DecisionTree; 5 | use ndarray::{concatenate, Array, Array1, Array2, Axis}; 6 | use ndarray_rand::rand::SeedableRng; 7 | use ndarray_rand::rand_distr::{StandardNormal, Uniform}; 8 | use ndarray_rand::RandomExt; 9 | use rand::rngs::SmallRng; 10 | 11 | fn generate_blobs(means: &Array2, samples: usize, mut rng: &mut SmallRng) -> Array2 { 12 | let out = means 13 | .axis_iter(Axis(0)) 14 | .map(|mean| Array::random_using((samples, 4), StandardNormal, &mut rng) + mean) 15 | .collect::>(); 16 | let out2 = out.iter().map(|x| x.view()).collect::>(); 17 | 18 | concatenate(Axis(0), &out2).unwrap() 19 | } 20 | 21 | fn decision_tree_bench(c: &mut Criterion) { 22 | let mut rng = SmallRng::seed_from_u64(42); 23 | 24 | // Controls how many samples for each class are generated 25 | let training_set_sizes = &[100, 1000, 10000, 100000]; 26 | 27 | let n_classes = 4; 28 | let n_features = 4; 29 | 30 | // Use the default configuration 31 | let hyperparams = DecisionTree::params(); 32 | 33 | // Benchmark training time 10 times for each training sample size 34 | let mut group = c.benchmark_group("decision_tree"); 35 | config::set_default_benchmark_configs(&mut group); 36 | 37 | for n in training_set_sizes.iter() { 38 | let centroids = 39 | Array2::random_using((n_classes, n_features), Uniform::new(-30., 30.), &mut rng); 40 | 41 | let train_x = 
generate_blobs(¢roids, *n, &mut rng); 42 | #[allow(clippy::manual_repeat_n)] 43 | let train_y: Array1 = (0..n_classes) 44 | .flat_map(|x| std::iter::repeat(x).take(*n).collect::>()) 45 | .collect::>(); 46 | let dataset = DatasetBase::new(train_x, train_y); 47 | 48 | group.bench_with_input(BenchmarkId::from_parameter(n), &dataset, |b, d| { 49 | b.iter(|| hyperparams.fit(d)) 50 | }); 51 | } 52 | 53 | group.finish(); 54 | } 55 | 56 | #[cfg(not(target_os = "windows"))] 57 | criterion_group! { 58 | name = benches; 59 | config = config::get_default_profiling_configs(); 60 | targets = decision_tree_bench 61 | } 62 | #[cfg(target_os = "windows")] 63 | criterion_group!(benches, decision_tree_bench); 64 | 65 | criterion_main!(benches); 66 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/examples/decision_tree.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::Write; 3 | 4 | use ndarray_rand::rand::SeedableRng; 5 | use rand::rngs::SmallRng; 6 | 7 | use linfa::prelude::*; 8 | use linfa_trees::{DecisionTree, Result, SplitQuality}; 9 | 10 | fn main() -> Result<()> { 11 | // load Iris dataset 12 | let mut rng = SmallRng::seed_from_u64(42); 13 | 14 | let (train, test) = linfa_datasets::iris() 15 | .shuffle(&mut rng) 16 | .split_with_ratio(0.8); 17 | 18 | println!("Training model with Gini criterion ..."); 19 | let gini_model = DecisionTree::params() 20 | .split_quality(SplitQuality::Gini) 21 | .max_depth(Some(100)) 22 | .min_weight_split(1.0) 23 | .min_weight_leaf(1.0) 24 | .fit(&train)?; 25 | 26 | let gini_pred_y = gini_model.predict(&test); 27 | let cm = gini_pred_y.confusion_matrix(&test)?; 28 | 29 | println!("{:?}", cm); 30 | 31 | println!( 32 | "Test accuracy with Gini criterion: {:.2}%", 33 | 100.0 * cm.accuracy() 34 | ); 35 | 36 | let feats = gini_model.features(); 37 | println!("Features trained in this tree {:?}", feats); 38 | 39 | 
println!("Training model with entropy criterion ..."); 40 | let entropy_model = DecisionTree::params() 41 | .split_quality(SplitQuality::Entropy) 42 | .max_depth(Some(100)) 43 | .min_weight_split(10.0) 44 | .min_weight_leaf(10.0) 45 | .fit(&train)?; 46 | 47 | let entropy_pred_y = entropy_model.predict(&test); 48 | let cm = entropy_pred_y.confusion_matrix(&test)?; 49 | 50 | println!("{:?}", cm); 51 | 52 | println!( 53 | "Test accuracy with Entropy criterion: {:.2}%", 54 | 100.0 * cm.accuracy() 55 | ); 56 | 57 | let feats = entropy_model.features(); 58 | println!("Features trained in this tree {:?}", feats); 59 | 60 | let mut tikz = File::create("decision_tree_example.tex").unwrap(); 61 | tikz.write_all( 62 | gini_model 63 | .export_to_tikz() 64 | .with_legend() 65 | .to_string() 66 | .as_bytes(), 67 | ) 68 | .unwrap(); 69 | println!(" => generate Gini tree description with `latex decision_tree_example.tex`!"); 70 | 71 | Ok(()) 72 | } 73 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/src/decision_trees/iter.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::fmt::Debug; 3 | use std::iter::Iterator; 4 | 5 | use super::TreeNode; 6 | use linfa::{Float, Label}; 7 | 8 | /// Level-order (BFT) iterator of nodes in a decision tree 9 | #[derive(Debug, Clone, PartialEq)] 10 | pub struct NodeIter<'a, F, L> { 11 | queue: VecDeque<&'a TreeNode>, 12 | } 13 | 14 | impl<'a, F, L> NodeIter<'a, F, L> { 15 | pub fn new(queue: VecDeque<&'a TreeNode>) -> Self { 16 | NodeIter { queue } 17 | } 18 | } 19 | 20 | impl<'a, F: Float, L: Debug + Label> Iterator for NodeIter<'a, F, L> { 21 | type Item = &'a TreeNode; 22 | 23 | fn next(&mut self) -> Option { 24 | #[allow(clippy::manual_inspect)] 25 | self.queue.pop_front().map(|node| { 26 | node.children() 27 | .into_iter() 28 | .filter_map(|x| x.as_ref()) 29 | .for_each(|child| self.queue.push_back(child)); 
30 | node 31 | }) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/src/decision_trees/mod.rs: -------------------------------------------------------------------------------- 1 | mod algorithm; 2 | mod hyperparams; 3 | mod iter; 4 | mod tikz; 5 | 6 | pub use algorithm::*; 7 | pub use hyperparams::*; 8 | pub use iter::*; 9 | pub use tikz::*; 10 | -------------------------------------------------------------------------------- /algorithms/linfa-trees/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! # Decision tree learning 3 | //! `linfa-trees` aims to provide pure rust implementations 4 | //! of decison trees learning algorithms. 5 | //! 6 | //! # The big picture 7 | //! 8 | //! `linfa-trees` is a crate in the [linfa](https://github.com/rust-ml/linfa) ecosystem, 9 | //! an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's scikit-learn. 10 | //! 11 | //! Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. 12 | //! The goal is to create a model that predicts the value of a target variable by learning simple decision rules 13 | //! inferred from the data features. 14 | //! 15 | //! # Current state 16 | //! 17 | //! `linfa-trees` currently provides an [implementation](DecisionTree) of single-tree fitting for classification. 18 | //! 
19 | 20 | mod decision_trees; 21 | 22 | pub use decision_trees::*; 23 | pub use linfa::error::Result; 24 | -------------------------------------------------------------------------------- /algorithms/linfa-tsne/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linfa-tsne" 3 | version = "0.7.1" 4 | authors = ["Lorenz Schmidt "] 5 | edition = "2018" 6 | 7 | description = "Barnes-Hut t-distributed stochastic neighbor embedding" 8 | license = "MIT OR Apache-2.0" 9 | 10 | repository = "https://github.com/rust-ml/linfa" 11 | readme = "README.md" 12 | 13 | keywords = ["tsne", "visualization", "clustering", "machine-learning", "linfa"] 14 | categories = ["algorithms", "mathematics", "science"] 15 | 16 | [dependencies] 17 | thiserror = "1.0" 18 | ndarray = { version = "0.15" } 19 | ndarray-rand = "0.14" 20 | bhtsne = "0.4.0" 21 | pdqselect = "=0.1.0" 22 | 23 | linfa = { version = "0.7.1", path = "../.." } 24 | 25 | [dev-dependencies] 26 | rand = "0.8" 27 | approx = "0.4" 28 | 29 | linfa-datasets = { version = "0.7.1", path = "../../datasets", features = ["iris"] } 30 | linfa-reduction = { version = "0.7.1", path = "../linfa-reduction" } 31 | 32 | [target.'cfg(not(target_family = "windows"))'.dev-dependencies] 33 | mnist = { version = "0.5", features = ["download"] } 34 | -------------------------------------------------------------------------------- /algorithms/linfa-tsne/README.md: -------------------------------------------------------------------------------- 1 | # t-SNE 2 | 3 | `linfa-tsne` provides a pure Rust implementation of exact and Barnes-Hut t-SNE. 4 | 5 | ## The Big Picture 6 | 7 | `linfa-tsne` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`. 
8 | 9 | ## Current state 10 | 11 | `linfa-tsne` currently provides an implementation of the following methods: 12 | 13 | - exact solution t-SNE 14 | - Barnes-Hut t-SNE 15 | 16 | It wraps the [bhtsne](https://github.com/frjnn/bhtsne) crate, all kudos to them. 17 | 18 | ## Examples 19 | 20 | There is an usage example in the `examples/` directory. To run it, do: 21 | 22 | ```bash 23 | $ cargo run --example tsne 24 | ``` 25 | 26 | You have to install the `gnuplot` library for plotting. Also take a look at the [README](https://github.com/rust-ml/linfa#blaslapack-backend) to see what BLAS/LAPACK backends are possible. 27 | 28 | ## License 29 | Dual-licensed to be compatible with the Rust project. 30 | 31 | Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be copied, modified, or distributed except according to those terms. 32 | 33 | -------------------------------------------------------------------------------- /algorithms/linfa-tsne/examples/iris_plot.plt: -------------------------------------------------------------------------------- 1 | set style increment user 2 | set style line 1 lc rgb 'red' 3 | set style line 2 lc rgb 'blue' 4 | set style line 3 lc rgb 'green' 5 | 6 | set style data points 7 | plot 'examples/iris.dat' using 1:2:3 linecolor variable pt 7 ps 2 t '' 8 | -------------------------------------------------------------------------------- /algorithms/linfa-tsne/examples/mnist.rs: -------------------------------------------------------------------------------- 1 | // This example is disabled for windows till mnist > 0.5 is released 2 | // See https://github.com/davidMcneil/mnist/issues/10 3 | 4 | #[cfg(not(target_family = "windows"))] 5 | use linfa_tsne::Result; 6 | 7 | #[cfg(not(target_family = "windows"))] 8 | fn main() -> Result<()> { 9 | use linfa::traits::{Fit, Transformer}; 10 | use linfa::Dataset; 11 | use linfa_reduction::Pca; 12 | use linfa_tsne::TSneParams; 13 | 14 | #[cfg(not(target_family = 
"windows"))] 15 | use mnist::{Mnist, MnistBuilder}; 16 | 17 | use ndarray::Array; 18 | use std::{io::Write, process::Command}; 19 | 20 | // use 50k samples from the MNIST dataset 21 | let (trn_size, rows, cols) = (50_000usize, 28, 28); 22 | 23 | // download and extract it into a dataset 24 | let Mnist { 25 | trn_img, trn_lbl, .. 26 | } = MnistBuilder::new() 27 | .label_format_digit() 28 | .training_set_length(trn_size as u32) 29 | .download_and_extract() 30 | .finalize(); 31 | 32 | // create a dataset from it 33 | let ds = Dataset::new( 34 | Array::from_shape_vec((trn_size, rows * cols), trn_img)?.mapv(|x| (x as f64) / 255.), 35 | Array::from_shape_vec((trn_size, 1), trn_lbl)?, 36 | ); 37 | 38 | // reduce to 50 dimension without whitening 39 | let ds = Pca::params(50) 40 | .whiten(false) 41 | .fit(&ds) 42 | .unwrap() 43 | .transform(ds); 44 | 45 | // calculate a two-dimensional embedding with Barnes-Hut t-SNE 46 | let ds = TSneParams::embedding_size(2) 47 | .perplexity(50.0) 48 | .approx_threshold(0.5) 49 | .max_iter(1000) 50 | .transform(ds)?; 51 | 52 | // write out 53 | let mut f = std::fs::File::create("examples/mnist.dat").unwrap(); 54 | 55 | for (x, y) in ds.sample_iter() { 56 | f.write_all(format!("{} {} {}\n", x[0], x[1], y[0]).as_bytes()) 57 | .unwrap(); 58 | } 59 | 60 | // and plot with gnuplot 61 | #[allow(clippy::zombie_processes)] 62 | Command::new("gnuplot") 63 | .arg("-p") 64 | .arg("examples/mnist_plot.plt") 65 | .spawn() 66 | .expect( 67 | "Failed to launch gnuplot. 
Please ensure that gnuplot
Please ensure that gnuplot
version = "0.7.1" 4 | authors = ["Lorenz Schmidt "] 5 | description = "Collection of small datasets for Linfa" 6 | edition = "2018" 7 | license = "MIT OR Apache-2.0" 8 | repository = "https://github.com/rust-ml/linfa" 9 | 10 | [dependencies] 11 | linfa = { version = "0.7.1", path = ".." } 12 | ndarray = { version = "0.15" } 13 | ndarray-csv = "=0.5.1" 14 | csv = "1.1" 15 | flate2 = "1.0" 16 | ndarray-rand = { version = "0.14", optional = true } 17 | 18 | [dev-dependencies] 19 | approx = "0.4" 20 | statrs = "0.16.0" 21 | 22 | [features] 23 | default = [] 24 | diabetes = [] 25 | iris = [] 26 | winequality = [] 27 | linnerud = [] 28 | generate = ["ndarray-rand"] 29 | -------------------------------------------------------------------------------- /datasets/data/diabetes_data.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/diabetes_data.csv.gz -------------------------------------------------------------------------------- /datasets/data/diabetes_target.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/diabetes_target.csv.gz -------------------------------------------------------------------------------- /datasets/data/iris.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/iris.csv.gz -------------------------------------------------------------------------------- /datasets/data/linnerud_exercise.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/linnerud_exercise.csv.gz 
-------------------------------------------------------------------------------- /datasets/data/linnerud_physiological.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/linnerud_physiological.csv.gz -------------------------------------------------------------------------------- /datasets/data/winequality-red.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/datasets/data/winequality-red.csv.gz -------------------------------------------------------------------------------- /datasets/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | 3 | pub mod dataset; 4 | #[cfg(feature = "generate")] 5 | pub mod generate; 6 | 7 | pub use dataset::*; 8 | -------------------------------------------------------------------------------- /docs/website/config.toml: -------------------------------------------------------------------------------- 1 | # The URL the site will be built for 2 | base_url = "https://rust-ml.github.io/linfa/" 3 | 4 | # Whether to automatically compile all Sass files in the sass directory 5 | compile_sass = true 6 | 7 | # Whether to build a search index to be used later on by a JavaScript library 8 | build_search_index = true 9 | 10 | [markdown] 11 | # Whether to do syntax highlighting 12 | # Theme can be customised by setting the `highlight_theme` variable to a theme supported by Zola 13 | highlight_code = true 14 | highlight_theme = "inspired-github" 15 | 16 | [extra] 17 | # Put all your custom variables here 18 | -------------------------------------------------------------------------------- /docs/website/content/blog/_index.md: 
-------------------------------------------------------------------------------- 1 | +++ 2 | title = "List of blog posts" 3 | sort_by = "date" 4 | template = "blog.html" 5 | page_template = "blog-page.html" 6 | +++ 7 | -------------------------------------------------------------------------------- /docs/website/content/blog/first.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "My first post" 3 | date = 2019-11-27 4 | +++ 5 | 6 | This is my first blog post. 7 | -------------------------------------------------------------------------------- /docs/website/content/community.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Community" 3 | date = 2021-02-26 4 | +++ 5 | 6 | Linfa provides a framework for statistical learning to the Rust community. As such we provide reference implementations, help others implementing their algorithms, ensuring a quality standard in documentation and providing guidelines with best-practices in context of the Rust language. 7 | 8 | A welcoming, friendly and safe environment is very important to use and we are striving to make the community as inclusive and open as possible. The implementation and communication is happening in the open with the possibility for anyone to participate. 9 | 10 | ## Providing knowledge 11 | 12 | In the past we could help people in two different ways. If you have knowledge of an algorithm and would like to use Rust as a language, we can help you in the language details. On the other hand, we can also provide you with a paper as a starting point and help you on questions which may come up. 13 | 14 | If you happen to belong to the first group, then please reach out on [Zulip](https://rust-ml.zulipchat.com) or open a pull request if you wish for feedback. 
If you are looking for a topic you can either reach out on [Zulip](https://rust-ml.zulipchat.com) as well, or look into the [roadmap](https://github.com/rust-ml/linfa/issues/7) and find one which fits your interests. 15 | 16 | ## Platforms 17 | 18 | There are two ways you can reach us: 19 | * say hello on [Zulip](https://rust-ml.zulipchat.com) 20 | * or participate in our [Github issues](https://github.com/rust-ml/linfa/issues) 21 | -------------------------------------------------------------------------------- /docs/website/content/docs.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Documentation" 3 | +++ 4 | 5 | # API documentation 6 | 7 | You can find the latest API documentation [here](../rustdocs/linfa/). 8 | 9 | # Examples 10 | 11 | A good way to start is by looking at code examples. You can find them in the *examples/* folder of each sub-crate. For instance, Support Vector Machines have their examples [here](https://github.com/rust-ml/linfa/tree/master/algorithms/linfa-svm/examples). 12 | 13 | # How do contribute 14 | 15 | If you want to know how to integrate something into Linfa, take a look at the [Contribute.md](https://github.com/rust-ml/linfa/blob/master/CONTRIBUTE.md). It covers topics on the type system implementation, how to use datasets, add parameters to your algorithm, etc. 16 | 17 | # Understanding the theory 18 | 19 | A good book for someone with knowledge of undergraduate information/probability theory and linear algebra is the *Elements of Statistical Learning* by Friedman, Tibshirani and Hastie. You can find the free version [here](https://web.stanford.edu/~hastie/ElemStatLearn/). 
20 | -------------------------------------------------------------------------------- /docs/website/content/news/_index.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "News" 3 | sort_by = "date" 4 | template = "news.html" 5 | page_template = "news-page.html" 6 | +++ 7 | -------------------------------------------------------------------------------- /docs/website/content/news/first_release.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.2.0" 3 | date = "2020-11-26" 4 | +++ 5 | This release of Linfa introduced 9 new implementations and a couple of changes to the APIs. Travis support for FOSS projects was dropped, so we were forced to switch to Github Actions and we introduced a couple of traits to represent different classes of algorithms in a better way. 6 | 7 | 8 | 9 | New algorithms 10 | ----------- 11 | 12 | - Ordinary Linear Regression has been added to `linfa-linear` by [@Nimpruda] and [@paulkoerbitz] 13 | - Generalized Linear Models has been added to `linfa-linear` by [@VasanthakumarV] 14 | - Linear decision trees were added to `linfa-trees` by [@mossbanay] 15 | - Fast independent component analysis (ICA) has been added to `linfa-ica` by [@VasanthakumarV] 16 | - Principal Component Analysis and Diffusion Maps have been added to `linfa-reduction` by [@bytesnake] 17 | - Support Vector Machines has been added to `linfa-svm` by [@bytesnake] 18 | - Logistic regression has been added to `linfa-logistic` by [@paulkoerbitz] 19 | - Hierarchical agglomerative clustering has been added to `linfa-hierarchical` by [@bytesnake] 20 | - Gaussian Mixture Models has been added to `linfa-clustering` by [@relf] 21 | 22 | Changes 23 | ---------- 24 | 25 | - Common metrics for classification and regression have been added 26 | - A new dataset interface simplifies the work with targets and labels 27 | - New traits for `Transformer`, `Fit` and 
`IncrementalFit` standardizes the interface 28 | - Switched to Github Actions for better integration 29 | 30 | -------------------------------------------------------------------------------- /docs/website/content/news/new_website.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "New website" 3 | date = "2021-02-27" 4 | +++ 5 | 6 | I'm happy to announce that Linfa finally gets its own website. We are currently trying to improve the documentation in various places and a website should serve as a starting point for anyone interested in Linfa. 7 | 8 | 9 | 10 | The [Zola](https://github.com/getzola/zola) project offers an excellent static page generator. It was easy to setup, customize and powerful enough for our usecase. A big kudos to the developer over there! We also integrated it into our CI system for a continuous publication of the website and its announcements. 11 | 12 | In the future we may publish a markdown book on Statistical/Machine learning algorithms, but this may be a bit far fetched right now. If you are interested in what we are doing, say hello at [Zulip](https://rust-ml.zulipchat.com). 13 | -------------------------------------------------------------------------------- /docs/website/content/news/release_021.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.2.1" 3 | date = "2020-11-29" 4 | +++ 5 | 6 | Linfa 0.2.1 amends changes to the feature system, which made it impossible to release 0.2.0 on `crates.io`. 
7 | 8 | 9 | 10 | Changes 11 | ---------- 12 | 13 | * Use openblas-system backend for now 14 | * Fill up missing information in algorithm descriptions 15 | -------------------------------------------------------------------------------- /docs/website/content/news/release_031.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.3.1" 3 | date = "2021-03-11" 4 | +++ 5 | 6 | In this release of Linfa the documentation is extended, new examples are added and the functionality of datasets improved. No new algorithms were added. 7 | 8 | 9 | 10 | The meta-issue [#82](https://github.com/rust-ml/linfa/issues/82) gives a good overview of the necessary documentation improvements and testing/documentation/examples were considerably extended in this release. 11 | 12 | Further new functionality was added to datasets and multi-target datasets are introduced. Bootstrapping is now possible for features and samples and you can cross-validate your model with k-folding. We polished various bits in the kernel machines and simplified the interface there. 13 | 14 | The trait structure of regression metrics are simplified and the silhouette score introduced for easier testing of K-Means and other algorithms. 15 | 16 | 17 | # Changes 18 | 19 | * improve documentation in all algorithms, various commits 20 | * add a website to the infrastructure (c8acc785b) 21 | * add k-folding with and without copying (b0af80546f8) 22 | * add feature naming and pearson's cross correlation (71989627f) 23 | * improve ergonomics when handling kernels (1a7982b973) 24 | * improve TikZ generator in `linfa-trees` (9d71f603bbe) 25 | * introduce multi-target datasets (b231118629) 26 | * simplify regression metrics and add cluster metrics (d0363a1fa8ef) 27 | 28 | # Example 29 | 30 | You can now perform cross-validation with k-folding. 
@Sauro98 actually implemented two versions, one which copies the dataset into k folds and one which avoid excessive memory operations by copying only the validation dataset around. For example to test a model with 8-folding: 31 | 32 | ```rust 33 | // perform cross-validation with the F1 score 34 | let f1_runs = dataset 35 | .iter_fold(8, |v| params.fit(&v).unwrap()) 36 | .map(|(model, valid)| { 37 | let cm = model 38 | .predict(&valid) 39 | .mapv(|x| x > Pr::even()) 40 | .confusion_matrix(&valid).unwrap(); 41 | 42 | cm.f1_score() 43 | }) 44 | .collect::>(); 45 | 46 | // calculate mean and standard deviation 47 | println!("F1 score: {}±{}", 48 | f1_runs.mean().unwrap(), 49 | f1_runs.std_axis(Axis(0), 0.0), 50 | ); 51 | ``` 52 | -------------------------------------------------------------------------------- /docs/website/content/news/release_040/tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rust-ml/linfa/20b1dd2d0879ca114aa4ea24db5cfdcdd1aae186/docs/website/content/news/release_040/tsne.png -------------------------------------------------------------------------------- /docs/website/content/news/release_051.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.5.1" 3 | date = "2022-02-28" 4 | +++ 5 | 6 | Linfa's 0.5.1 release fixes errors and bugs in the previous release, as well as removing useless trait bounds on the `Dataset` type. Note that the commits for this release are located in the `0-5-1` branch of the GitHub repo. 
7 | 8 | 9 | 10 | ## Improvements 11 | 12 | * remove `Float` trait bound from many `Dataset` impls, making non-float datasets usable 13 | * fix build errors in 0.5.0 caused by breaking minor releases from dependencies 14 | * fix bug in k-means where the termination condition of the algorithm was calculated incorrectly 15 | * fix build failure when building `linfa` alone, caused by incorrect feature selection for `ndarray` 16 | -------------------------------------------------------------------------------- /docs/website/content/news/release_061.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.6.1" 3 | date = "2022-12-03" 4 | +++ 5 | 6 | Linfa's 0.6.1 release mainly consists of fixes to existing algorithms and the overall crate. The Isotonic Regression algorithm has also been added to `linfa-linear`. 7 | 8 | 9 | 10 | ## Improvements and fixes 11 | 12 | * Add constructor for `LpDist` in `linfa-nn`. 13 | * Add `Send + Sync` to trait objects returned by `linfa-nn`, which are now aliased as `NearestNeighbourBox`. 14 | * Remove `anyhow <= 1.0.48` version restriction from `linfa`. 15 | * Bump `ndarray` dependency to 0.15 16 | * Fix `serde` support for `LogisticRegression` in `linfa-logistic`. 17 | 18 | ## New algorithms 19 | 20 | Isotonic regression fits a free-form line to the training data. Unlike linear regression, which fits a straight line, isotonic regression can result in a much closer fit to the data. The algorithm has been added to `linfa-linear`. 21 | 22 | Mean absolution percentage error (MAPE) is a method of measuring the difference between two datasets, and has been added to the main `linfa` crate. 
23 | -------------------------------------------------------------------------------- /docs/website/content/news/release_070.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.7.0" 3 | date = "2023-10-15" 4 | +++ 5 | 6 | Linfa's 0.7.0 release mainly consists of improvements to Serde support. It also removes Approximate DBSCAN from `linfa-clustering` due to subpar performance and outdated dependencies. 7 | 8 | 9 | 10 | ## Improvements and fixes 11 | 12 | * Add `array_from_gz_csv` and `array_from_csv` in `linfa-datasets`. 13 | * Make Serde support in `linfa-linear`, `linfa-logistic`, and `linfa-ftrl` optional. 14 | * Add Serde support to `linfa-preprocessing` and `linfa-bayes`. 15 | * Bump `argmin` to 0.8.1. 16 | * Make licenses follow SPDX 2.1 license expression standard. 17 | 18 | ## Removals 19 | 20 | Approximate DBSCAN is an alternative implementation of the DBSCAN algorithm that trades precision for speed. However, the implementation in `linfa-clustering` is actually slower than the regular DBSCAN implementation. It also depends on the `partitions` crate, which is incompatible with current versions of Rust. Thus, we have decided to remove Approximate DBSCAN from Linfa. The Approximate DBSCAN types and APIs are now aliases to regular DBSCAN. 21 | -------------------------------------------------------------------------------- /docs/website/content/news/release_071.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Release 0.7.1" 3 | date = "2025-01-14" 4 | +++ 5 | 6 | Linfa's 0.7.1 release mainly consists of fixes to existing algorithms and the overall crate. The Random Projection algorithm has also been added to `linfa-reduction`. 
7 | 8 | 9 | 10 | ## Improvements and fixes 11 | 12 | * add `serde` support to `linfa-clustering` 13 | * add accessors for classes in `linfa-logistics` 14 | * add accessors for `Pca` attributes in `linfa-reduction` 15 | * add `wasm-bindgen`feature to use linfa in the browser 16 | * fix covariance update for `GaussianMixtureModel` in `linfa-clustering` 17 | * bump `ndarray-linalg` to 0.16 and `argmin` to 0.9.0 18 | * bump MSRV to 1.71.1 19 | 20 | ## New algorithms 21 | 22 | Random projections are a simple and computationally efficient way to reduce the dimensionality of the data by trading a controlled amount of accuracy (as additional variance) for faster processing times and smaller model sizes. 23 | 24 | The dimensions and distribution of random projections matrices are controlled so as to preserve the pairwise distances between any two samples of the dataset. 25 | 26 | See also [sklearn.random_projection](https://scikit-learn.org/stable/api/sklearn.random_projection.html) 27 | -------------------------------------------------------------------------------- /docs/website/content/snippets/_index.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Snippets" 3 | +++ 4 | -------------------------------------------------------------------------------- /docs/website/content/snippets/cross-validation.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Cross Validation" 3 | +++ 4 | ```rust 5 | // parameters to compare 6 | let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0]; 7 | 8 | // create a model for each parameter 9 | let models = ratios 10 | .iter() 11 | .map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio)) 12 | .collect::>(); 13 | 14 | // get the mean r2 validation score across 5 folds for each model 15 | let r2_values = 16 | dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?; 17 | 18 | // show the mean r2 score for each parameter 
choice 19 | for (ratio, r2) in ratios.iter().zip(r2_values.iter()) { 20 | println!("L1 ratio: {}, r2 score: {}", ratio, r2); 21 | } 22 | ``` -------------------------------------------------------------------------------- /docs/website/content/snippets/decision-trees.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Linear Decision Trees" 3 | +++ 4 | ```rust 5 | let (train, valid) = linfa_datasets::iris() 6 | .split_with_ratio(0.8); 7 | 8 | // Train model with Gini criterion 9 | let gini_model = DecisionTree::params() 10 | .split_quality(SplitQuality::Gini) 11 | .max_depth(Some(100)) 12 | .min_weight_split(1.0) 13 | .fit(&train)?; 14 | 15 | let cm = gini_model.predict(&valid) 16 | .confusion_matrix(&valid); 17 | 18 | println!("{:?}", cm); 19 | println!("Accuracy {}%", cm.accuracy() * 100.0); 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/website/content/snippets/diffusion-maps.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Diffusion Maps" 3 | +++ 4 | ```rust 5 | // generate RBF kernel with sparsity constraints 6 | let kernel = Kernel::params() 7 | .kind(KernelType::Sparse(15)) 8 | .method(KernelMethod::Gaussian(2.0)) 9 | .transform(dataset.view()); 10 | 11 | let embedding = DiffusionMap::::params(2) 12 | .steps(1) 13 | .transform(&kernel)?; 14 | 15 | // get embedding 16 | let embedding = embedding.embedding(); 17 | ``` 18 | -------------------------------------------------------------------------------- /docs/website/content/snippets/elasticnet.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Elastic Net" 3 | +++ 4 | ```rust 5 | let (train, valid) = linfa_datasets::diabetes() 6 | .split_with_ratio(0.9); 7 | 8 | // train pure LASSO model with 0.1 penalty 9 | let model = ElasticNet::params() 10 | .penalty(0.1) 11 | .l1_ratio(1.0) 12 | 
.fit(&train)?; 13 | 14 | println!("z score: {:?}", model.z_score()); 15 | 16 | // validate 17 | let y_est = model.predict(&valid); 18 | println!("predicted variance: {}", y_est.r2(&valid)?); 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/website/content/snippets/gaussian-naive-bayes.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Gaussian Naive Bayes" 3 | +++ 4 | ```rust 5 | let (train, valid) = linfa_datasets::iris() 6 | .split_with_ratio(0.8); 7 | 8 | // train the model 9 | let model = GaussianNbParams::params() 10 | .fit(&train)?; 11 | 12 | // Predict the validation dataset 13 | let pred = model.predict(&valid); 14 | 15 | // construct confusion matrix 16 | let cm = pred.confusion_matrix(&valid)?; 17 | 18 | // print confusion matrix, accuracy and precision 19 | println!("{:?}", cm); 20 | println!("accuracy {}, precision {}", 21 | cm.accuracy(), cm.precision()); 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/website/content/snippets/k-folding.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "K folding" 3 | +++ 4 | ```rust 5 | // perform cross-validation with the F1 score 6 | let f1_runs = dataset 7 | .iter_fold(8, |v| params.fit(&v).unwrap()) 8 | .map(|(model, valid)| { 9 | let cm = model 10 | .predict(&valid) 11 | .mapv(|x| x > Pr::even()) 12 | .confusion_matrix(&valid).unwrap(); 13 | 14 | cm.f1_score() 15 | }) 16 | .collect::>(); 17 | 18 | // calculate mean and standard deviation 19 | println!("F1 score: {}±{}", 20 | f1_runs.mean().unwrap(), 21 | f1_runs.std_axis(Axis(0), 0.0), 22 | ); 23 | ``` 24 | -------------------------------------------------------------------------------- /docs/website/content/snippets/multi-class.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Multi Class" 3 | 
+++ 4 | ```rust 5 | let params = Svm::<_, Pr>::params() 6 | .gaussian_kernel(30.0); 7 | 8 | // assume we have a binary decision model (here SVM) 9 | // predicting probability. We can merge them into a 10 | // multi-class model by collecting several of them 11 | // into a `MultiClassModel` 12 | let model = train 13 | .one_vs_all()? 14 | .into_iter() 15 | .map(|(l, x)| (l, params.fit(&x).unwrap())) 16 | .collect::>(); 17 | 18 | // predict multi-class label 19 | let pred = model.predict(&valid); 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/website/content/snippets/multi-targets.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Multi Targets" 3 | +++ 4 | ```rust 5 | // assume we have a dataset with multiple, 6 | // uncorrelated targets and we want to train 7 | // a single model for each target variable 8 | let model = train.target_iter() 9 | .map(|x| params.fit(&x).unwrap()) 10 | .collect::>()?; 11 | 12 | // composing `model` returns multiple targets 13 | let valid_est = model.predict(valid); 14 | println!("{}", valid_est.ntargets()); 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/website/content/snippets/partial-least-squares.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Partial least squares regression" 3 | +++ 4 | ```rust 5 | // Load linnerud dataset with 20 samples, 6 | // 3 input features, 3 output features 7 | let ds = linfa_datasets::linnerud(); 8 | 9 | // Fit PLS2 method using 2 principal components 10 | // (latent variables) 11 | let pls = PlsRegression::params(2).fit(&ds)?; 12 | 13 | // We can either apply the dimension reduction to the dataset 14 | let reduced_ds = pls.transform(ds); 15 | 16 | // ... or predict outputs given a new input sample. 
17 | let exercices = array![[14., 146., 61.], [6., 80., 60.]]; 18 | let physio_measures = pls.predict(exercices); 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/website/content/snippets/random-projection.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Gaussian Random Projection" 3 | +++ 4 | ```rust 5 | // Assume we get some training data like MNIST: 60000 samples of 28*28 images (ie dim 784) 6 | let dataset = Dataset::from(Array::::random((60000, 28 * 28), Standard)); 7 | 8 | // We can work in a reduced dimension using a Gaussian Random Projection 9 | let reduced_dim = 100; 10 | let proj = GaussianRandomProjection::::params() 11 | .target_dim(reduced_dim) 12 | .fit(&dataset)?; 13 | let reduced_ds = proj.transform(&dataset); 14 | 15 | println!("New dataset shape: {:?}", reduced_ds.records().shape()); 16 | // -> New dataset shape: [60000, 100] 17 | ``` -------------------------------------------------------------------------------- /docs/website/content/snippets/sv-machines.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Support Vector Machines" 3 | +++ 4 | ```rust 5 | // everything above 6.5 is considered a good wine 6 | let (train, valid) = linfa_datasets::winequality() 7 | .map_targets(|x| *x > 6) 8 | .split_with_ratio(0.9); 9 | 10 | // train SVM with nu=0.01 and RBF with eps=80.0 11 | let model = Svm::params() 12 | .nu_weight(0.01) 13 | .gaussian_kernel(80.0) 14 | .fit(&train)?; 15 | 16 | // print model performance and number of SVs 17 | println!("{}", model); 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/website/content/snippets/tsne.md: -------------------------------------------------------------------------------- 1 | +++ 2 | title = "Barnes-Hut t-SNE" 3 | +++ 4 | ```rust 5 | // normalize the iris dataset 6 | let ds = 
linfa_datasets::iris(); 7 | let ds = Pca::params(3).whiten(true).fit(&ds).transform(ds); 8 | 9 | // transform to two-dimensional embeddings 10 | let ds = TSne::embedding_size(2) 11 | .perplexity(10.0) 12 | .approx_threshold(0.1) 13 | .transform(ds)?; 14 | 15 | // write embedding to file 16 | let mut f = File::create("iris.dat")?; 17 | for (x, y) in ds.sample_iter() { 18 | f.write(format!("{} {} {}\n", x[0], x[1], y[0]).as_bytes())?; 19 | } 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/website/sass/_base.scss: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | } 4 | 5 | html { 6 | line-height: 1.5; 7 | font-family: sans-serif; 8 | } 9 | 10 | h1, h2 { 11 | margin-top: 0; 12 | font-size: 1.5rem; 13 | } 14 | 15 | .top { 16 | border-bottom: 1px solid $shadow-color; 17 | box-shadow: 0px 1px 2px 0px $shadow-color; 18 | padding: 0 10px; 19 | } 20 | 21 | .bottom { 22 | border-top: 1px solid $shadow-color; 23 | display: flex; 24 | 25 | margin: 10px 0; 26 | } 27 | 28 | .footer { 29 | margin-top: 10px; 30 | font-size: 13px; 31 | 32 | justify-content: center; 33 | 34 | a { 35 | text-decoration: none; 36 | } 37 | 38 | a:hover { 39 | text-decoration: underline; 40 | } 41 | } 42 | 43 | .grid { 44 | max-width: 950px; 45 | margin: 0 auto; 46 | } 47 | 48 | td, th { 49 | border: 1px solid #ddd; 50 | padding: 8px; 51 | } 52 | 53 | tr:nth-child(even){background-color: #f2f2f2;} 54 | 55 | tr:hover {background-color: #ddd;} 56 | 57 | th { 58 | padding-top: 12px; 59 | padding-bottom: 12px; 60 | text-align: left; 61 | background-color: $secondary-color; 62 | color: white; 63 | } 64 | -------------------------------------------------------------------------------- /docs/website/sass/desktop/_base.scss: -------------------------------------------------------------------------------- 1 | .container > .grid { 2 | padding: 0 10px; 3 | margin-top: 20px; 4 | } 5 | 6 | .header { 7 | 
display: flex; 8 | flex-direction: flow; 9 | align-items: center; 10 | 11 | .header-main { 12 | flex-grow: 1; 13 | } 14 | } 15 | 16 | .header-main ul { 17 | display: flex; 18 | justify-content: right; 19 | 20 | li { 21 | margin-left: 15px; 22 | list-style: none; 23 | 24 | a { 25 | color: black; 26 | text-decoration: none; 27 | font-size: 17px; 28 | } 29 | 30 | a:hover { 31 | color: $primary-color; 32 | } 33 | } 34 | } 35 | 36 | .logo { 37 | margin: 5px 0; 38 | 39 | display: flex; 40 | align-items: center; 41 | 42 | img { 43 | margin-right: 5px; 44 | } 45 | 46 | p { 47 | span { 48 | display: none; 49 | } 50 | 51 | font-size: 30px; 52 | line-height: 0px; 53 | } 54 | } 55 | 56 | .example { 57 | display: none; 58 | 59 | border: 1px solid #ddd; 60 | border-radius: 5px; 61 | 62 | margin-right: 10px; 63 | min-height: 200px; 64 | padding: 5px; 65 | 66 | } 67 | 68 | .visible { 69 | display: block; 70 | } 71 | 72 | .news_grid { 73 | grid-template-columns: 3fr 1fr; 74 | display: grid; 75 | grid-gap: 15px; 76 | } 77 | 78 | .news_grid .sitelinks a { 79 | display: block; 80 | text-decoration: none; 81 | color: $primary-color; 82 | } 83 | 84 | .news_grid .sitelinks a:hover { 85 | text-decoration: underline; 86 | } 87 | 88 | -------------------------------------------------------------------------------- /docs/website/sass/desktop/_home.scss: -------------------------------------------------------------------------------- 1 | .expose .grid { 2 | grid-template-columns: 3fr 2fr; 3 | display: grid; 4 | grid-gap: 15px; 5 | 6 | padding: 20px 10px; 7 | 8 | .introduction { 9 | p { 10 | margin-bottom: 5px; 11 | } 12 | } 13 | } 14 | 15 | .colored { 16 | background-color: #36b0ec; 17 | color: white; 18 | 19 | a { 20 | text-decoration: none; 21 | color: black; 22 | } 23 | 24 | a:hover { 25 | color: white; 26 | } 27 | 28 | p { 29 | margin-bottom: 0; 30 | } 31 | } 32 | 33 | .not-colored { 34 | background-color: white; 35 | color: black; 36 | 37 | a { 38 | text-decoration: none; 39 | color: 
gray; 40 | } 41 | 42 | a:hover { 43 | text-decoration: bold; 44 | } 45 | 46 | p { 47 | margin-bottom: 0; 48 | } 49 | } 50 | 51 | .highlights .grid { 52 | grid-template-columns: 2fr 2fr 2fr; 53 | display: grid; 54 | grid-gap: 15px; 55 | 56 | padding: 20px 10px; 57 | } 58 | -------------------------------------------------------------------------------- /docs/website/sass/desktop/_news.scss: -------------------------------------------------------------------------------- 1 | .news li { 2 | list-style-type: none; 3 | } 4 | 5 | .news a { 6 | color: #ec7235; 7 | text-decoration: none; 8 | font-size: 25px; 9 | } 10 | 11 | .news a:hover { 12 | text-decoration: underline; 13 | } 14 | 15 | .news span { 16 | color: gray; 17 | display: block; 18 | font-size: 14px; 19 | } 20 | 21 | .news-page span { 22 | color: gray; 23 | font-size: 14px; 24 | } 25 | 26 | /* 27 | .news-page code { 28 | counter-reset: line; 29 | } 30 | 31 | .news-page code span { 32 | counter-increment: line; 33 | } 34 | 35 | .news-page code span:before { 36 | content: counter(line); 37 | }*/ 38 | 39 | .news-page h1 { 40 | margin-bottom: 0; 41 | } 42 | -------------------------------------------------------------------------------- /docs/website/sass/mobile/_base.scss: -------------------------------------------------------------------------------- 1 | .section { 2 | padding: 0 10px; 3 | } 4 | 5 | .outer-table { 6 | overflow-x: scroll; 7 | } 8 | 9 | .header-main ul { 10 | padding-left: 0px; 11 | display: none; 12 | 13 | li { 14 | list-style: none; 15 | margin-bottom: 15px; 16 | 17 | a { 18 | color: black; 19 | text-decoration: none; 20 | font-size: 17px; 21 | } 22 | 23 | a:hover { 24 | color: $primary-color; 25 | } 26 | } 27 | } 28 | 29 | .logo { 30 | margin: 0px 0; 31 | 32 | display: flex; 33 | align-items: center; 34 | 35 | img { 36 | margin-right: 5px; 37 | } 38 | 39 | p { 40 | font-size: 30px; 41 | flex-grow: 1; 42 | display: flex; 43 | line-height: 0.5em; 44 | 45 | span { 46 | text-align: right; 47 | 
color: $primary-color; 48 | flex-grow: 1; 49 | } 50 | 51 | span:hover { 52 | cursor: pointer; 53 | } 54 | } 55 | } 56 | 57 | .sitelinks { 58 | display: none; 59 | } 60 | -------------------------------------------------------------------------------- /docs/website/sass/mobile/_home.scss: -------------------------------------------------------------------------------- 1 | .grid section { 2 | padding-top: 20px; 3 | } 4 | 5 | .example { 6 | display: none; 7 | 8 | border: 1px solid #ddd; 9 | 10 | padding: 5px; 11 | 12 | pre { 13 | margin: 0; 14 | } 15 | } 16 | 17 | .visible { 18 | display: block; 19 | overflow: scroll; 20 | } 21 | 22 | .expose { 23 | margin-top: 10px; 24 | } 25 | -------------------------------------------------------------------------------- /docs/website/sass/mobile/_news.scss: -------------------------------------------------------------------------------- 1 | .news ul { 2 | margin: 0; 3 | padding: 0; 4 | } 5 | 6 | .news li { 7 | list-style-type: none; 8 | } 9 | 10 | .news a { 11 | color: #ec7235; 12 | text-decoration: none; 13 | font-size: 25px; 14 | } 15 | 16 | .news a:hover { 17 | text-decoration: underline; 18 | } 19 | 20 | .news span { 21 | color: gray; 22 | display: block; 23 | font-size: 14px; 24 | } 25 | 26 | .news-page span { 27 | color: gray; 28 | font-size: 14px; 29 | } 30 | 31 | .news-page h1 { 32 | margin-bottom: 0; 33 | } 34 | -------------------------------------------------------------------------------- /docs/website/sass/style.scss: -------------------------------------------------------------------------------- 1 | $primary-color: #ec7235; 2 | $secondary-color: #36b0ec; 3 | $tetradic-color: #B0EC36; 4 | $shadow-color: #ddd; 5 | 6 | @import 'base'; 7 | 8 | @media only screen and (min-width: 700px) { 9 | @import 'desktop/base', 'desktop/news', 'desktop/home' 10 | } 11 | @media only screen and (max-width: 700px) { 12 | @import 'mobile/base', 'mobile/news', 'mobile/home' 13 | } 14 | 
-------------------------------------------------------------------------------- /docs/website/static/code-examples.js: -------------------------------------------------------------------------------- 1 | function toggle_selected(t) { 2 | let e = t.value; 3 | document.querySelector(".example.visible").classList.remove("visible") 4 | document.querySelector(`.example[data-example="${e}"]`).classList.add("visible") 5 | } 6 | 7 | document.addEventListener("DOMContentLoaded", function() { 8 | let a = document.querySelector(".logo span") 9 | a.addEventListener("click", function(t) { 10 | let obj = document.querySelector(".header-main ul"); 11 | 12 | if(obj.style.display === "" || obj.style.display === "none"){ 13 | obj.style.display = "block"; 14 | } else { 15 | obj.style.display = "none"; 16 | } 17 | }); 18 | 19 | let n = document.querySelector(".code-examples select") 20 | n.addEventListener("change", function(t) { 21 | toggle_selected(t.target) 22 | }) 23 | 24 | let ops = n.querySelectorAll("option") 25 | ops[Math.floor(Math.random() * ops.length)].selected = !0 26 | 27 | toggle_selected(n) 28 | }); 29 | -------------------------------------------------------------------------------- /docs/website/templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Linfa Toolkit 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 | 27 |
28 | 35 |
36 |
37 |
38 | 39 | 40 |
41 |
42 | {% block content %} {% endblock %} 43 |
44 |
45 |
46 | 50 |
51 | 52 | 53 | -------------------------------------------------------------------------------- /docs/website/templates/blog-page.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

5 | {{ page.title }} 6 |

7 |

{{ page.date }}

8 | {{ page.content | safe }} 9 | {% endblock content %} 10 | -------------------------------------------------------------------------------- /docs/website/templates/blog.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

5 | {{ section.title }} 6 |

7 |
    8 | {% for page in section.pages %} 9 |
  • {{ page.title }}
  • 10 | {% endfor %} 11 |
12 | {% endblock content %} 13 | -------------------------------------------------------------------------------- /docs/website/templates/news-page.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |
5 |

6 | {{ page.title }} 7 |

8 | Published on {{ page.date | date(format="%B %dth, %Y") }} 9 | {{ page.content | safe }} 10 |
11 | {% endblock content %} 12 | -------------------------------------------------------------------------------- /docs/website/templates/news.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |
5 |

6 | {{ section.title }} 7 |

8 |
    9 | {% for page in section.pages %} 10 |
  • 11 | {{ page.title }} 12 | Published on {{ page.date | date(format="%B %dth, %Y") }} 13 |

    {{ page.summary | safe }}

    14 |
  • 15 | {% endfor %} 16 |
17 |
18 | {% endblock content %} 19 | -------------------------------------------------------------------------------- /docs/website/templates/page.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |
5 |
6 | 7 |

8 | {{ page.title }} 9 |

10 | {{ page.content | safe }} 11 |
12 | 19 |
20 |
21 | {% endblock content %} 22 | -------------------------------------------------------------------------------- /src/benchmarks/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "benchmarks")] 2 | pub mod config { 3 | #[cfg(not(target_os = "windows"))] 4 | use criterion::Criterion; 5 | use criterion::{measurement::WallTime, BenchmarkGroup}; 6 | #[cfg(not(target_os = "windows"))] 7 | use pprof::criterion::{Output, PProfProfiler}; 8 | use std::time::Duration; 9 | 10 | #[cfg(not(target_os = "windows"))] 11 | pub fn get_default_profiling_configs() -> Criterion { 12 | Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))) 13 | } 14 | 15 | pub fn set_default_benchmark_configs(benchmark: &mut BenchmarkGroup) { 16 | let sample_size: usize = 200; 17 | let measurement_time: Duration = Duration::new(10, 0); 18 | let confidence_level: f64 = 0.97; 19 | let warm_up_time: Duration = Duration::new(10, 0); 20 | let noise_threshold: f64 = 0.05; 21 | 22 | benchmark 23 | .sample_size(sample_size) 24 | .measurement_time(measurement_time) 25 | .confidence_level(confidence_level) 26 | .warm_up_time(warm_up_time) 27 | .noise_threshold(noise_threshold); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/composing/mod.rs: -------------------------------------------------------------------------------- 1 | //! Composition models 2 | //! 3 | //! This module contains three composition models: 4 | //! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model 5 | //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model 6 | //! * `Platt`: calibrate a classifier (i.e. 
SVC) to predicted posterior probabilities 7 | mod multi_class_model; 8 | mod multi_target_model; 9 | pub mod platt_scaling; 10 | 11 | pub use multi_class_model::MultiClassModel; 12 | pub use multi_target_model::MultiTargetModel; 13 | pub use platt_scaling::{Platt, PlattError, PlattParams}; 14 | -------------------------------------------------------------------------------- /src/dataset/impl_records.rs: -------------------------------------------------------------------------------- 1 | use super::{DatasetBase, Records}; 2 | use ndarray::{ArrayBase, Axis, Data, Dimension}; 3 | 4 | /// Implement records for NdArrays 5 | impl, I: Dimension> Records for ArrayBase { 6 | type Elem = F; 7 | 8 | fn nsamples(&self) -> usize { 9 | self.len_of(Axis(0)) 10 | } 11 | 12 | fn nfeatures(&self) -> usize { 13 | self.len_of(Axis(1)) 14 | } 15 | } 16 | 17 | /// Implement records for a DatasetBase 18 | impl, T> Records for DatasetBase { 19 | type Elem = F; 20 | 21 | fn nsamples(&self) -> usize { 22 | self.records.nsamples() 23 | } 24 | 25 | fn nfeatures(&self) -> usize { 26 | self.records.nfeatures() 27 | } 28 | } 29 | 30 | /// Implement records for an empty dataset 31 | impl Records for () { 32 | type Elem = (); 33 | 34 | fn nsamples(&self) -> usize { 35 | 0 36 | } 37 | 38 | fn nfeatures(&self) -> usize { 39 | 0 40 | } 41 | } 42 | 43 | /// Implement records for references 44 | impl Records for &R { 45 | type Elem = R::Elem; 46 | 47 | fn nsamples(&self) -> usize { 48 | (*self).nsamples() 49 | } 50 | 51 | fn nfeatures(&self) -> usize { 52 | (*self).nfeatures() 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! Error types in Linfa 2 | //! 
3 | 4 | use thiserror::Error; 5 | 6 | use ndarray::ShapeError; 7 | #[cfg(feature = "serde")] 8 | use serde_crate::{Deserialize, Serialize}; 9 | 10 | pub type Result = std::result::Result; 11 | 12 | #[cfg_attr( 13 | feature = "serde", 14 | derive(Serialize, Deserialize), 15 | serde(crate = "serde_crate") 16 | )] 17 | #[derive(Error, Debug, Clone)] 18 | pub enum Error { 19 | #[error("invalid parameter {0}")] 20 | Parameters(String), 21 | #[error("invalid prior {0}")] 22 | Priors(String), 23 | #[error("algorithm not converged {0}")] 24 | NotConverged(String), 25 | // ShapeError doesn't implement serde traits, and deriving them remotely on a complex error 26 | // type isn't really feasible, so we skip this variant. 27 | #[cfg_attr(feature = "serde", serde(skip))] 28 | #[error("invalid ndarray shape {0}")] 29 | NdShape(#[from] ShapeError), 30 | #[error("not enough samples")] 31 | NotEnoughSamples, 32 | #[error("The number of samples do not match: {0} - {1}")] 33 | MismatchedShapes(usize, usize), 34 | } 35 | -------------------------------------------------------------------------------- /src/prelude.rs: -------------------------------------------------------------------------------- 1 | //! Linfa prelude. 2 | //! 3 | //! This module contains the most used types, type aliases, traits and 4 | //! functions that you can import easily as a group. 5 | //! 
6 | 7 | #[doc(no_inline)] 8 | pub use crate::error::Error; 9 | 10 | #[doc(no_inline)] 11 | pub use crate::traits::*; 12 | 13 | #[doc(no_inline)] 14 | pub use crate::dataset::{AsTargets, Dataset, DatasetBase, DatasetView, Float, Pr, Records}; 15 | 16 | #[doc(no_inline)] 17 | pub use crate::metrics_classification::{BinaryClassification, ConfusionMatrix, ToConfusionMatrix}; 18 | 19 | #[doc(no_inline)] 20 | pub use crate::metrics_regression::{MultiTargetRegression, SingleTargetRegression}; 21 | 22 | #[doc(no_inline)] 23 | pub use crate::metrics_clustering::SilhouetteScore; 24 | 25 | #[doc(no_inline)] 26 | pub use crate::correlation::PearsonCorrelation; 27 | 28 | #[doc(no_inline)] 29 | pub use crate::param_guard::ParamGuard; 30 | -------------------------------------------------------------------------------- /src/traits.rs: -------------------------------------------------------------------------------- 1 | //! Provide traits for different classes of algorithms 2 | //! 3 | 4 | use crate::dataset::{DatasetBase, Records}; 5 | use std::convert::From; 6 | 7 | /// Transformation algorithms 8 | /// 9 | /// A transformer takes a dataset and transforms it into a different one. It has no concept of 10 | /// state and provides therefore no method to predict new data. A typical example are kernel 11 | /// methods. 12 | /// 13 | /// It should be implemented for all algorithms, also for those which can be fitted. 14 | /// 15 | pub trait Transformer { 16 | fn transform(&self, x: R) -> T; 17 | } 18 | 19 | /// Fittable algorithms 20 | /// 21 | /// A fittable algorithm takes a dataset and creates a concept of some kind about it. For example 22 | /// in *KMeans* this would be the mean values for each class, or in *SVM* the separating 23 | /// hyperplane. It returns a model, which can be used to predict targets for new data. 
24 | pub trait Fit> { 25 | type Object; 26 | 27 | fn fit(&self, dataset: &DatasetBase) -> Result; 28 | } 29 | 30 | /// Incremental algorithms 31 | /// 32 | /// An incremental algorithm takes a former model and dataset and returns a new model with updated 33 | /// parameters. If the former model is `None`, then the function acts like `Fit::fit` and 34 | /// initializes the model first. 35 | pub trait FitWith<'a, R: Records, T, E: std::error::Error + From> { 36 | type ObjectIn: 'a; 37 | type ObjectOut: 'a; 38 | 39 | fn fit_with( 40 | &self, 41 | model: Self::ObjectIn, 42 | dataset: &'a DatasetBase, 43 | ) -> Result; 44 | } 45 | 46 | /// Predict with model 47 | /// 48 | /// This trait assumes the `PredictInplace` implementation and provides additional input/output 49 | /// combinations. 50 | /// 51 | /// # Provided implementation 52 | /// 53 | /// ```rust, ignore 54 | /// use linfa::traits::Predict; 55 | /// 56 | /// // predict targets with reference to dataset (&Dataset -> Array) 57 | /// let pred_targets = model.predict(&dataset); 58 | /// // predict targets inside dataset (Dataset -> Dataset) 59 | /// let pred_dataset = model.predict(dataset); 60 | /// // or use a record datastruct directly (Array -> Dataset) 61 | /// let pred_targets = model.predict(x); 62 | /// ``` 63 | pub trait Predict { 64 | fn predict(&self, x: R) -> T; 65 | } 66 | 67 | /// Predict with model into a mutable reference of targets. 68 | pub trait PredictInplace { 69 | /// Predict something in place 70 | fn predict_inplace<'a>(&'a self, x: &'a R, y: &mut T); 71 | 72 | /// Create targets that `predict_inplace` works with. 73 | fn default_target(&self, x: &R) -> T; 74 | } 75 | --------------------------------------------------------------------------------