├── .formatter.exs ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── benchmarks └── knn.exs ├── config └── config.exs ├── images ├── scholar.png └── scholar_simplified.png ├── lib └── scholar │ ├── cluster │ ├── affinity_propagation.ex │ ├── dbscan.ex │ ├── gmm.ex │ ├── hierarchical.ex │ ├── k_means.ex │ └── optics.ex │ ├── covariance │ ├── ledoit_wolf.ex │ ├── shrunk_covariance.ex │ └── utils.ex │ ├── cross_decomposition │ └── pls_svd.ex │ ├── decomposition │ ├── pca.ex │ ├── truncated_svd.ex │ └── utils.ex │ ├── feature_extraction │ └── count_vectorizer.ex │ ├── impute │ ├── knn_imputter.ex │ └── simple_imputer.ex │ ├── integrate │ └── integrate.ex │ ├── interpolation │ ├── bezier_spline.ex │ ├── cubic_spline.ex │ └── linear.ex │ ├── linear │ ├── bayesian_ridge_regression.ex │ ├── isotonic_regression.ex │ ├── linear_helpers.ex │ ├── linear_regression.ex │ ├── logistic_regression.ex │ ├── polynomial_regression.ex │ ├── ridge_regression.ex │ └── svm.ex │ ├── manifold │ ├── mds.ex │ ├── trimap.ex │ └── tsne.ex │ ├── metrics │ ├── classification.ex │ ├── clustering.ex │ ├── distance.ex │ ├── neighbors.ex │ ├── ranking.ex │ ├── regression.ex │ └── similarity.ex │ ├── model_selection.ex │ ├── naive_bayes │ ├── bernoulli.ex │ ├── complement.ex │ ├── gaussian.ex │ └── multinomial.ex │ ├── neighbors │ ├── brute_knn.ex │ ├── kd_tree.ex │ ├── knn_classifier.ex │ ├── knn_regressor.ex │ ├── large_vis.ex │ ├── nn_descent.ex │ ├── random_projection_forest.ex │ ├── rnn_classifier.ex │ ├── rnn_regressor.ex │ └── utils.ex │ ├── options.ex │ ├── preprocessing.ex │ ├── preprocessing │ ├── binarizer.ex │ ├── max_abs_scaler.ex │ ├── min_max_scaler.ex │ ├── normalizer.ex │ ├── one_hot_encoder.ex │ ├── ordinal_encoder.ex │ ├── robust_scaler.ex │ └── standard_scaler.ex │ ├── shared.ex │ └── stats.ex ├── mix.exs ├── mix.lock ├── notebooks ├── cv_gradient_boosting_tree.livemd ├── files │ ├── knn_gyms.png │ ├── mammoth.png │ ├── rbf_and_kdtree.png │ └── rpKNN.png ├── hierarchical_clustering.livemd ├── k_means.livemd ├── k_nearest_neighbors.livemd ├── linear_regression.livemd ├── manifold_learning.livemd ├── mds.livemd └── nearest_neighbors.livemd └── test ├── data └── pima.csv ├── scholar ├── cluster │ ├── affinity_propagation_test.exs │ ├── dbscan_test.exs │ ├── gmm_test.exs │ ├── hierarchical_test.exs │ └── k_means_test.exs ├── covariance │ ├── ledoit_wolf_test.exs │ └── shrunk_covariance_test.exs ├── cross_decomposition │ └── pls_svd_test.exs ├── decomposition │ ├── pca_test.exs │ └── truncated_svd_test.exs ├── feature_extraction │ └── count_vectorizer_test.exs ├── impute │ ├── knn_imputter_test.exs │ └── simple_imputer_test.exs ├── integrate │ └── integrate_test.exs ├── interpolation │ ├── bezier_spline_test.exs │ ├── cubic_spline_test.exs │ └── linear_test.exs ├── linear │ ├── bayesian_ridge_regression_test.exs │ ├── isotonic_regression_test.exs │ ├── linear_regression_test.exs │ ├── logistic_regression_test.exs │ ├── polynomial_regression_test.exs │ ├── ridge_regression_test.exs │ └── svm_test.exs ├── manifold │ ├── mds_test.exs │ ├── trimap_test.exs │ └── tsne_test.exs ├── metrics │ ├── classification_test.exs │ ├── clustering_test.exs │ ├── distance_test.exs │ ├── neighbors_test.exs │ ├── ranking_test.exs │ ├── regression_test.exs │ └── similarity_test.exs ├── model_selection_test.exs ├── naive_bayes │ ├── bernoulli_test.exs │ ├── complement_test.exs │ ├── gaussian_test.exs │ └── multinomial_test.exs ├── neighbors │ ├── brute_knn_test.exs │ ├── 
kd_tree_test.exs │ ├── knn_classifier_test.exs │ ├── knn_regressor_test.exs │ ├── large_vis_test.exs │ ├── nn_descent_test.exs │ ├── random_projection_forest_test.exs │ ├── rnn_classifier_test.exs │ └── rnn_regressor_test.exs ├── preprocessing │ ├── binarizer_test.exs │ ├── max_abs_scaler_test.exs │ ├── min_max_scaler_test.exs │ ├── normalizer_test.exs │ ├── one_hot_encoder_test.exs │ ├── ordinal_encoder_test.exs │ ├── robust_scaler_test.exs │ └── standard_scaler_test.exs ├── preprocessing_test.exs └── stats_test.exs ├── support ├── diabetes_data_raw.csv ├── diabetes_target.csv └── scholar_case.ex └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | import_deps: [:nx], 4 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 5 | ] 6 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | main: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | include: 14 | - elixir: "1.18.1" 15 | otp: "27.1" 16 | lint: true 17 | - elixir: "1.15.8" 18 | otp: "26.1" 19 | steps: 20 | - uses: actions/checkout@v2 21 | 22 | - uses: erlef/setup-beam@v1 23 | with: 24 | otp-version: ${{matrix.otp}} 25 | elixir-version: ${{matrix.elixir}} 26 | 27 | - uses: actions/cache@v2 28 | with: 29 | path: | 30 | deps 31 | _build 32 | key: ${{ runner.os }}-mix-${{matrix.elixir}}-${{matrix.otp}}-${{ hashFiles('**/mix.lock') }} 33 | restore-keys: | 34 | ${{ runner.os }}-mix- 35 | 36 | - run: mix deps.get 37 | 38 | - run: mix format --check-formatted 39 | if: ${{ matrix.lint }} 40 | 41 | - run: mix deps.unlock --check-unused 42 | if: ${{ matrix.lint }} 43 | 44 | - run: mix deps.compile 45 | 46 | - run: mix compile --warnings-as-errors 47 | if: ${{ matrix.lint }} 48 | 49 | - run: mix test 50 | 51 | - run: USE_EXLA_AT_COMPILE_TIME=1 mix compile --force 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | scholar-*.tar 24 | 25 | # Temporary files, for example, from tests. 
26 | /tmp/ 27 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.4.0 (2025-01-15) 4 | 5 | * Require Nx `~> 0.9` 6 | * Add batching to regression metrics 7 | * Add `Scholar.Cluster.OPTICS` 8 | * Add `Scholar.Covariance.LedoitWolf` 9 | * Add `Scholar.Covariance.ShrunkCovariance` 10 | * Add `Scholar.CrossDecomposition.PLSSVD` 11 | * Add `Scholar.Decomposition.TruncatedSVD` 12 | * Add `Scholar.Impute.KNNImputter` 13 | * Add `Scholar.NaiveBayes.Bernoulli` 14 | * Add `Scholar.Preprocessing.Binarizer` 15 | * Add `Scholar.Preprocessing.RobustScaler` 16 | * Add `partial_fit/2` and `incremental_fit/2` to PCA 17 | * Split `RNN` into `Scholar.Neighbors.RadiusNNClassifier` and `Scholar.Neighbors.RadiusNNRegressor` 18 | * Unify shape checks across all APIs 19 | 20 | ## v0.3.1 (2024-06-18) 21 | 22 | ### Enhancements 23 | 24 | * Add a notebook about manifold learning 25 | * Make knn algorithm configurable on Trimap 26 | * Add `d2_pinball_score` and `d2_absolute_error_score` 27 | 28 | ## v0.3.0 (2024-05-29) 29 | 30 | ### Enhancements 31 | 32 | * Add LargeVis for visualization of large-scale and high-dimensional data in a low-dimensional (typically 2D or 3D) space 33 | * Add `Scholar.Neighbors.KDTree` and `Scholar.Neighbors.RandomProjectionForest` 34 | * Add `Scholar.Metrics.Neighbors` 35 | * Add `Scholar.Linear.BayesianRidgeRegression` 36 | * Add `Scholar.Cluster.Hierarchical` 37 | * Add `Scholar.Manifold.Trimap` 38 | * Add Mean Pinball Loss function 39 | * Add Matthews Correlation Coefficient function 40 | * Add D2 Tweedie Score function 41 | * Add Mean Tweedie Deviance function 42 | * Add Discounted Cumulative Gain function 43 | * Add Precision Recall f-score function 44 | * Add f-beta score function 45 | * Add convergence check to AffinityPropagation 46 | * Default Affinity Propagation preference to `reduce_min` and make it customizable 47 | * Move preprocessing functionality to their own modules with `fit` and `fit_transform` callbacks 48 | 49 | ### Breaking changes 50 | 51 | * Split `KNearestNeighbors` into `KNNClassifier` and `KNNRegressor` with custom algorithm support 52 | 53 | ## v0.2.1 (2023-08-30) 54 | 55 | ### Enhancements 56 | 57 | * Remove `VegaLite.Data` in favour of future use of `Tucan` 58 | * Do not use EXLA at compile time in `Metrics` 59 | 60 | ## v0.2.0 (2023-08-29) 61 | 62 | This version requires Elixir v1.14+. 
63 | 64 | ### Enhancements 65 | 66 | * Update notebooks 67 | * Add support for `:f16` and `:bf16` types in `SVD` 68 | * Add `Affinity Propagation` 69 | * Add `t-SNE` 70 | * Add `Polynomial Regression` 71 | * Replace seeds with `Random.key` 72 | * Add 'unrolling loops' option 73 | * Add support for custom optimizers in `Logistic Regression` 74 | * Add `Trapezoidal Integration` 75 | * Add `AUC-ROC`, `AUC`, and `ROC Curve` 76 | * Add `Simpson rule integration` 77 | * Unify tests 78 | * Add `Radius Nearest Neighbors` 79 | * Add `DBSCAN` 80 | * Add classification metrics: `Average Precision Score`, `Balanced Accuracy Score`, 81 | `Cohen Kappa Score`, `Brier Score Loss`, `Zero-One Loss`, `Top-k Accuracy Score` 82 | * Add regression metrics: `R2 Score`, `MSLE`, `MAPE`, `Maximum Residual Error` 83 | * Add support for axes in `Confusion Matrix` 84 | * Add support for broadcasting in `Metrics.Distances` 85 | * Update CI 86 | * Add `Gaussian Mixtures` 87 | * Add Model selection functionalities: `K-fold`, `K-fold Cross Validation`, `Grid Search` 88 | * Change structure of metrics in `Scholar` 89 | * Add a guide with `Cross-Validation` and `Grid Search` 90 | 91 | ## v0.1.0 (2023-03-29) 92 | 93 | First release. 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | [Scholar logo and badges] 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | Traditional machine learning tools built on top of Nx. Scholar implements 11 | several algorithms for classification, regression, clustering, dimensionality 12 | reduction, metrics, and preprocessing. 13 | 14 | For deep learning, see [Axon](https://github.com/elixir-nx/axon). 15 | For decision trees/forests, see [EXGBoost](https://github.com/acalejos/exgboost). 16 | 17 | ## Installation 18 | 19 | ### Mix projects 20 | 21 | Add to your `mix.exs`: 22 | 23 | ```elixir 24 | def deps do 25 | [ 26 | {:scholar, "~> 0.3.0"} 27 | ] 28 | end 29 | ``` 30 | 31 | Besides Scholar, you will most likely want to use an existing Nx compiler/backend, 32 | such as EXLA: 33 | 34 | ```elixir 35 | def deps do 36 | [ 37 | {:scholar, "~> 0.3.0"}, 38 | {:exla, ">= 0.0.0"} 39 | ] 40 | end 41 | ``` 42 | 43 | And then in your `config/config.exs` file: 44 | 45 | ```elixir 46 | import Config 47 | config :nx, :default_backend, EXLA.Backend 48 | # Client can also be set to :cuda / :rocm 49 | config :nx, :default_defn_options, [compiler: EXLA, client: :host] 50 | ``` 51 | 52 | > #### JIT required! {: .warning} 53 | > 54 | > It is important you set the `default_defn_options` as shown in the snippet above, 55 | > as many algorithms in Scholar use loops which are much more memory efficient when 56 | > JIT compiled. 57 | > 58 | > If for some reason you cannot set a default `defn` compiler, you can explicitly 59 | > JIT any function, for example: `EXLA.jit(&Scholar.Cluster.AffinityPropagation.fit/1)`. 60 | 61 | ### Notebooks 62 | 63 | To use Scholar inside code notebooks, run: 64 | 65 | ```elixir 66 | Mix.install([ 67 | {:scholar, "~> 0.3.0"}, 68 | {:exla, ">= 0.0.0"} 69 | ]) 70 | 71 | Nx.global_default_backend(EXLA.Backend) 72 | # Client can also be set to :cuda / :rocm 73 | Nx.Defn.global_default_options(compiler: EXLA, client: :host) 74 | ``` 75 | 76 | > #### JIT required! {: .warning} 77 | > 78 | > It is important you set the `Nx.Defn.global_default_options/1` as shown in the snippet 79 | > above, as many algorithms in Scholar use loops which are much more memory efficient 80 | > when JIT compiled. 81 | > 82 | > If for some reason you cannot set a default `defn` compiler, you can explicitly 83 | > JIT any function, for example: `EXLA.jit(&Scholar.Cluster.AffinityPropagation.fit/1)`. 84 | 85 | ## License 86 | 87 | Copyright (c) 2022 The Machine Learning Working Group of the Erlang Ecosystem Foundation 88 | 89 | Licensed under the Apache License, Version 2.0 (the "License"); 90 | you may not use this file except in compliance with the License. 91 | You may obtain a copy of the License at [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) 92 | 93 | Unless required by applicable law or agreed to in writing, software 94 | distributed under the License is distributed on an "AS IS" BASIS, 95 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 96 | See the License for the specific language governing permissions and 97 | limitations under the License. 
98 |
-------------------------------------------------------------------------------- /benchmarks/knn.exs: --------------------------------------------------------------------------------
1 | # mix run benchmarks/knn.exs
2 | Nx.global_default_backend(EXLA.Backend)
3 | Nx.Defn.global_default_options(compiler: EXLA)
4 |
5 | key = Nx.Random.key(System.os_time())
6 |
7 | inputs_knn = %{
8 | "100x10" => elem(Nx.Random.uniform(key, 0, 100, shape: {100, 10}), 0),
9 | "1000x10" => elem(Nx.Random.uniform(key, 0, 1000, shape: {1000, 10}), 0),
10 | "10000x10" => elem(Nx.Random.uniform(key, 0, 10000, shape: {10000, 10}), 0)
11 | }
12 |
13 | Benchee.run(
14 | %{
15 | "kdtree" => fn x ->
16 | kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 4)
17 | Scholar.Neighbors.KDTree.predict(kdtree, x)
18 | end,
19 | "brute force knn" => fn x ->
20 | model = Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 4)
21 | Scholar.Neighbors.BruteKNN.predict(model, x)
22 | end
23 | },
24 | time: 10,
25 | memory_time: 2,
26 | inputs: inputs_knn
27 | )
28 |
-------------------------------------------------------------------------------- /config/config.exs: --------------------------------------------------------------------------------
1 | import Config
2 |
3 | config :exla, :add_backend_on_inspect, config_env() != :test
4 |
5 | if System.get_env("USE_EXLA_AT_COMPILE_TIME") do
6 | config :nx, :default_backend, EXLA.Backend
7 | end
8 |
-------------------------------------------------------------------------------- /images/scholar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/images/scholar.png -------------------------------------------------------------------------------- /images/scholar_simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/images/scholar_simplified.png -------------------------------------------------------------------------------- /lib/scholar/covariance/ledoit_wolf.ex: --------------------------------------------------------------------------------
1 | defmodule Scholar.Covariance.LedoitWolf do
2 | @moduledoc """
3 | Ledoit-Wolf is a particular form of shrinkage covariance estimator,
4 | where the shrinkage coefficient is computed using O. Ledoit and M. Wolf's formula.
5 |
6 | The formula is described in "A Well-Conditioned Estimator
7 | for Large-Dimensional Covariance Matrices", Ledoit and Wolf,
8 | Journal of Multivariate Analysis, Volume 88, Issue 2,
9 | February 2004, pages 365-411.
10 | """
11 | import Nx.Defn
12 |
13 | @derive {Nx.Container, containers: [:covariance, :shrinkage, :location]}
14 | defstruct [:covariance, :shrinkage, :location]
15 |
16 | opts_schema = [
17 | assume_centered?: [
18 | default: false,
19 | type: :boolean,
20 | doc: """
21 | If `true`, data will not be centered before computation.
22 | Useful when working with data whose mean is almost, but not exactly
23 | zero.
24 | If `false`, data will be centered before computation.
25 | """
26 | ]
27 | ]
28 |
29 | @opts_schema NimbleOptions.new!(opts_schema)
30 | @doc """
31 | Estimate the shrunk Ledoit-Wolf covariance matrix.
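The estimate has the form `(1 - shrinkage) * cov + shrinkage * mu * identity`, where `cov` is the empirical covariance and `mu = trace(cov) / num_features`; the shrinkage coefficient is computed from the data using the Ledoit-Wolf formula.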
32 | 33 | ## Options 34 | 35 | #{NimbleOptions.docs(@opts_schema)} 36 | 37 | ## Return Values 38 | 39 | The function returns a struct with the following parameters: 40 | 41 | * `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix. 42 | 43 | * `:shrinkage` - Coefficient in the convex combination used for the computation of the shrunken estimate. Range is `[0, 1]`. 44 | 45 | * `:location` - Tensor of shape `{num_features,}`. 46 | Estimated location, i.e. the estimated mean. 47 | 48 | ## Examples 49 | 50 | iex> key = Nx.Random.key(0) 51 | iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.4, 0.2], [0.2, 0.8]]), shape: {50}, type: :f32) 52 | iex> model = Scholar.Covariance.LedoitWolf.fit(x) 53 | iex> model.covariance 54 | #Nx.Tensor< 55 | f32[2][2] 56 | [ 57 | [0.3557686507701874, 0.17340737581253052], 58 | [0.17340737581253052, 1.0300586223602295] 59 | ] 60 | > 61 | iex> model.shrinkage 62 | #Nx.Tensor< 63 | f32 64 | 0.15034137666225433 65 | > 66 | iex> model.location 67 | #Nx.Tensor< 68 | f32[2] 69 | [0.17184630036354065, 0.3276958167552948] 70 | > 71 | 72 | iex> key = Nx.Random.key(0) 73 | iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32) 74 | iex> model = Scholar.Covariance.LedoitWolf.fit(x) 75 | iex> model.covariance 76 | #Nx.Tensor< 77 | f32[3][3] 78 | [ 79 | [2.5945029258728027, 1.5078359842300415, 1.1623677015304565], 80 | [1.5078359842300415, 2.106797456741333, 1.1812156438827515], 81 | [1.1623677015304565, 1.1812156438827515, 1.4606266021728516] 82 | ] 83 | > 84 | iex> model.shrinkage 85 | #Nx.Tensor< 86 | f32 87 | 0.1908363401889801 88 | > 89 | iex> model.location 90 | #Nx.Tensor< 91 | f32[3] 92 | [1.1228725910186768, 0.5419300198554993, 0.8678852319717407] 93 | > 94 | 95 | iex> key = Nx.Random.key(0) 96 | iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32) 97 | iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true) 98 | iex> cov.covariance 99 | #Nx.Tensor< 100 | f32[3][3] 101 | [ 102 | [3.8574986457824707, 2.2048025131225586, 2.1504499912261963], 103 | [2.2048025131225586, 2.4572863578796387, 1.7215262651443481], 104 | [2.1504499912261963, 1.7215262651443481, 2.154898166656494] 105 | ] 106 | > 107 | """ 108 | 109 | deftransform fit(x, opts \\ []) do 110 | fit_n(x, NimbleOptions.validate!(opts, @opts_schema)) 111 | end 112 | 113 | defnp fit_n(x, opts) do 114 | {x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?]) 115 | 116 | {covariance, shrinkage} = 117 | ledoit_wolf(x) 118 | 119 | %__MODULE__{ 120 | covariance: covariance, 121 | shrinkage: shrinkage, 122 | location: location 123 | } 124 | end 125 | 126 | defnp ledoit_wolf(x) do 127 | case Nx.shape(x) do 128 | {_n, 1} -> 129 | {Nx.mean(x ** 2) |> Nx.reshape({1, 1}), 0.0} 130 | 131 | _ -> 132 | ledoit_wolf_shrinkage(x) 133 | end 134 | end 135 | 136 | defnp ledoit_wolf_shrinkage(x) do 137 | case Nx.shape(x) do 138 | {_, 1} -> 139 | 0 140 | 141 | {n} -> 142 | Nx.reshape(x, {1, n}) 143 | |> ledoit_wolf_shrinkage_complex() 144 | 145 | _ -> 146 | ledoit_wolf_shrinkage_complex(x) 147 | end 148 | end 149 | 150 | defnp ledoit_wolf_shrinkage_complex(x) do 151 | {num_samples, num_features} = Nx.shape(x) 152 | emp_cov = Scholar.Covariance.Utils.empirical_covariance(x) 153 | 154 | 
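# The shrinkage target is mu * I, where mu is the mean of the diagonal of
# the empirical covariance. Below, `delta` is the squared Frobenius distance
# (scaled by num_features) between the empirical covariance and that target,
# and `beta` estimates the variance of the empirical covariance entries.
# The shrinkage coefficient is beta / delta, with beta capped at delta so
# that the coefficient stays in [0, 1].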
emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov) 155 | mu = Nx.sum(emp_cov_trace) / num_features 156 | 157 | flatten_delta = Nx.flatten(emp_cov) 158 | 159 | indices = 160 | Nx.shape(flatten_delta) 161 | |> Nx.iota() 162 | 163 | subtract = Nx.select(Nx.remainder(indices, num_features + 1) == 0, mu, 0) 164 | 165 | delta = 166 | (flatten_delta - subtract) 167 | |> Nx.pow(2) 168 | |> Nx.sum() 169 | 170 | delta = delta / num_features 171 | 172 | x2 = Nx.pow(x, 2) 173 | 174 | beta = 175 | (Nx.dot(x2, [0], x2, [0]) / num_samples - emp_cov ** 2) 176 | |> Nx.sum() 177 | |> Nx.divide(num_features * num_samples) 178 | 179 | beta = Nx.min(beta, delta) 180 | shrinkage = beta / delta 181 | 182 | shrunk_cov = (1.0 - shrinkage) * emp_cov 183 | mask = Nx.iota(Nx.shape(shrunk_cov)) 184 | selector = Nx.remainder(mask, num_features + 1) == 0 185 | shrunk_cov = Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov) 186 | 187 | {shrunk_cov, shrinkage} 188 | end 189 | end 190 | -------------------------------------------------------------------------------- /lib/scholar/covariance/shrunk_covariance.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Covariance.ShrunkCovariance do 2 | @moduledoc """ 3 | Covariance estimator with shrinkage. 4 | """ 5 | import Nx.Defn 6 | 7 | @derive {Nx.Container, containers: [:covariance, :location]} 8 | defstruct [:covariance, :location] 9 | 10 | opts_schema = [ 11 | assume_centered?: [ 12 | default: false, 13 | type: :boolean, 14 | doc: """ 15 | If `true`, data will not be centered before computation. 16 | Useful when working with data whose mean is almost, but not exactly 17 | zero. 18 | If `false`, data will be centered before computation. 19 | """ 20 | ], 21 | shrinkage: [ 22 | default: 0.1, 23 | type: :float, 24 | doc: "Coefficient in the convex combination used for the computation 25 | of the shrunk estimate. Range is [0, 1]." 26 | ] 27 | ] 28 | 29 | @opts_schema NimbleOptions.new!(opts_schema) 30 | @doc """ 31 | Fit the shrunk covariance model to `x`. 32 | 33 | ## Options 34 | 35 | #{NimbleOptions.docs(@opts_schema)} 36 | 37 | ## Return Values 38 | 39 | The function returns a struct with the following parameters: 40 | 41 | * `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix. 42 | * `:location` - Tensor of shape `{num_features,}`. 43 | Estimated location, i.e. the estimated mean. 
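The returned estimate is computed as `(1 - shrinkage) * empirical_cov + shrinkage * mu * identity`, where `mu = trace(empirical_cov) / num_features`.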
44 |
45 | ## Examples
46 |
47 | iex> key = Nx.Random.key(0)
48 | iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
49 | iex> model = Scholar.Covariance.ShrunkCovariance.fit(x)
50 | iex> model.covariance
51 | #Nx.Tensor<
52 | f32[2][2]
53 | [
54 | [0.7721845507621765, 0.19141492247581482],
55 | [0.19141492247581482, 0.33952537178993225]
56 | ]
57 | >
58 | iex> model.location
59 | #Nx.Tensor<
60 | f32[2]
61 | [0.18202415108680725, -0.09216632694005966]
62 | >
63 | iex> key = Nx.Random.key(0)
64 | iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
65 | iex> model = Scholar.Covariance.ShrunkCovariance.fit(x, shrinkage: 0.4)
66 | iex> model.covariance
67 | #Nx.Tensor<
68 | f32[2][2]
69 | [
70 | [0.7000747323036194, 0.1276099532842636],
71 | [0.1276099532842636, 0.41163527965545654]
72 | ]
73 | >
74 | iex> model.location
75 | #Nx.Tensor<
76 | f32[2]
77 | [0.18202415108680725, -0.09216632694005966]
78 | >
79 | """
80 |
81 | deftransform fit(x, opts \\ []) do
82 | fit_n(x, NimbleOptions.validate!(opts, @opts_schema))
83 | end
84 |
85 | defnp fit_n(x, opts) do
86 | shrinkage = opts[:shrinkage]
87 |
88 | if shrinkage < 0 or shrinkage > 1 do
89 | raise ArgumentError,
90 | """
91 | expected :shrinkage option to be in [0, 1] range, \
92 | got shrinkage: #{inspect(shrinkage)}\
93 | """
94 | end
95 |
96 | {x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])
97 |
98 | covariance =
99 | Scholar.Covariance.Utils.empirical_covariance(x)
100 | |> shrunk_covariance(shrinkage)
101 |
102 | %__MODULE__{
103 | covariance: covariance,
104 | location: location
105 | }
106 | end
107 |
108 | defnp shrunk_covariance(emp_cov, shrinkage) do
109 | num_features = Nx.axis_size(emp_cov, 1)
110 | shrunk_cov = (1.0 - shrinkage) * emp_cov
111 | emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
112 | mu = Nx.sum(emp_cov_trace) / num_features
113 |
114 | mask = Nx.iota(Nx.shape(shrunk_cov))
115 | selector = Nx.remainder(mask, num_features + 1) == 0
116 |
117 | shrunk_cov + shrinkage * mu * selector
118 | end
119 | end
120 |
-------------------------------------------------------------------------------- /lib/scholar/covariance/utils.ex: --------------------------------------------------------------------------------
1 | defmodule Scholar.Covariance.Utils do
2 | @moduledoc false
3 | import Nx.Defn
4 | require Nx
5 |
6 | defn center(x, assume_centered? \\ false) do
7 | x =
8 | case Nx.shape(x) do
9 | {_} -> Nx.new_axis(x, 1)
10 | _ -> x
11 | end
12 |
13 | location =
14 | if assume_centered?
do 15 | 0 16 | else 17 | Nx.mean(x, axes: [0]) 18 | end 19 | 20 | {x - location, location} 21 | end 22 | 23 | defn empirical_covariance(x) do 24 | n = Nx.axis_size(x, 0) 25 | 26 | covariance = Nx.dot(x, [0], x, [0]) / n 27 | 28 | case Nx.shape(covariance) do 29 | {} -> Nx.reshape(covariance, {1, 1}) 30 | _ -> covariance 31 | end 32 | end 33 | 34 | defn trace(x) do 35 | x 36 | |> Nx.take_diagonal() 37 | |> Nx.sum() 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /lib/scholar/decomposition/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Decomposition.Utils do 2 | @moduledoc false 3 | import Nx.Defn 4 | require Nx 5 | 6 | defn flip_svd(u, v, u_based \\ true) do 7 | base = 8 | if u_based do 9 | u 10 | else 11 | Nx.transpose(v) 12 | end 13 | 14 | max_abs_cols_idx = base |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true) 15 | signs = base |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze() 16 | u = u * signs 17 | v = v * Nx.new_axis(signs, -1) 18 | {u, v} 19 | end 20 | end 21 | -------------------------------------------------------------------------------- /lib/scholar/feature_extraction/count_vectorizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.FeatureExtraction.CountVectorizer do 2 | @moduledoc """ 3 | A `CountVectorizer` converts already indexed collection of text documents to a matrix of token counts. 4 | """ 5 | import Nx.Defn 6 | 7 | opts_schema = [ 8 | max_token_id: [ 9 | type: :pos_integer, 10 | required: true, 11 | doc: ~S""" 12 | Maximum token id in the input tensor. 13 | """ 14 | ] 15 | ] 16 | 17 | @opts_schema NimbleOptions.new!(opts_schema) 18 | 19 | @doc """ 20 | Generates a count matrix where each row corresponds to a document in the input corpus, 21 | and each column corresponds to a unique token in the vocabulary of the corpus. 22 | 23 | The input must be a 2D tensor where: 24 | 25 | * Each row represents a document. 26 | * Each document has integer values representing tokens. 27 | 28 | The same number represents the same token in the vocabulary. Tokens should start from 0 29 | and be consecutive. Negative values are ignored, making them suitable for padding. 30 | 31 | ## Options 32 | 33 | #{NimbleOptions.docs(@opts_schema)} 34 | 35 | ## Examples 36 | 37 | iex> t = Nx.tensor([[0, 1, 2], [1, 3, 4]]) 38 | iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t)) 39 | Nx.tensor([ 40 | [1, 1, 1, 0, 0], 41 | [0, 1, 0, 1, 1] 42 | ]) 43 | 44 | With padding: 45 | 46 | iex> t = Nx.tensor([[0, 1, -1], [1, 3, 4]]) 47 | iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t)) 48 | Nx.tensor([ 49 | [1, 1, 0, 0, 0], 50 | [0, 1, 0, 1, 1] 51 | ]) 52 | """ 53 | deftransform fit_transform(tensor, opts \\ []) do 54 | fit_transform_n(tensor, NimbleOptions.validate!(opts, @opts_schema)) 55 | end 56 | 57 | @doc """ 58 | Computes the max_token_id option from given tensor. 59 | 60 | This function cannot be called inside `defn` (and it will raise 61 | if you try to do so). 
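It calls `Nx.to_number/1`, which cannot be applied to the symbolic tensors used inside `defn`.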
62 | 63 | ## Examples 64 | 65 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 66 | iex> Scholar.FeatureExtraction.CountVectorizer.max_token_id(t) 67 | 2 68 | """ 69 | def max_token_id(tensor) do 70 | tensor |> Nx.reduce_max() |> Nx.to_number() 71 | end 72 | 73 | defnp fit_transform_n(tensor, opts) do 74 | check_for_rank(tensor) 75 | counts = Nx.broadcast(0, {Nx.axis_size(tensor, 0), opts[:max_token_id] + 1}) 76 | 77 | {_, counts} = 78 | while {{i = 0, tensor}, counts}, Nx.less(i, Nx.axis_size(tensor, 0)) do 79 | {_, counts} = 80 | while {{j = 0, i, tensor}, counts}, Nx.less(j, Nx.axis_size(tensor, 1)) do 81 | index = tensor[i][j] 82 | 83 | counts = 84 | if Nx.any(Nx.less(index, 0)), 85 | do: counts, 86 | else: Nx.indexed_add(counts, Nx.stack([i, index]), 1) 87 | 88 | {{j + 1, i, tensor}, counts} 89 | end 90 | 91 | {{i + 1, tensor}, counts} 92 | end 93 | 94 | counts 95 | end 96 | 97 | defnp check_for_rank(tensor) do 98 | if Nx.rank(tensor) != 2 do 99 | raise ArgumentError, 100 | """ 101 | expected tensor to have shape {num_documents, num_tokens}, \ 102 | got tensor with shape: #{inspect(Nx.shape(tensor))}\ 103 | """ 104 | end 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /lib/scholar/linear/linear_helpers.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Linear.LinearHelpers do 2 | require Nx 3 | import Nx.Defn 4 | import Scholar.Shared 5 | 6 | @moduledoc false 7 | 8 | defp valid_column_vector?(y, n_samples) do 9 | Nx.shape(y) == {n_samples, 1} and Nx.rank(y) == 2 10 | end 11 | 12 | @doc false 13 | def flatten_column_vector(y, n_samples) do 14 | is_column_vector? = valid_column_vector?(y, n_samples) 15 | 16 | if is_column_vector? do 17 | y |> Nx.flatten() 18 | else 19 | y 20 | end 21 | end 22 | 23 | @doc false 24 | def validate_y_shape(y, n_samples, module_name) do 25 | y = flatten_column_vector(y, n_samples) 26 | is_valid_target? = Nx.rank(y) == 1 27 | 28 | if not is_valid_target? do 29 | message = 30 | "#{inspect(module_name)} expected y to have shape {n_samples}, got tensor with shape: #{inspect(Nx.shape(y))}" 31 | 32 | raise ArgumentError, message 33 | else 34 | y 35 | end 36 | end 37 | 38 | @doc false 39 | def build_sample_weights(x, opts) do 40 | x_type = to_float_type(x) 41 | {num_samples, _} = Nx.shape(x) 42 | default_sample_weights = Nx.broadcast(Nx.as_type(1.0, x_type), {num_samples}) 43 | {sample_weights, _} = Keyword.pop(opts, :sample_weights, default_sample_weights) 44 | 45 | # this is required for ridge regression 46 | sample_weights = 47 | if Nx.is_tensor(sample_weights), 48 | do: Nx.as_type(sample_weights, x_type), 49 | else: Nx.tensor(sample_weights, type: x_type) 50 | 51 | sample_weights 52 | end 53 | 54 | @doc false 55 | defn preprocess_data(x, y, sample_weights, opts) do 56 | if opts[:sample_weights_flag], 57 | do: 58 | {Nx.weighted_mean(x, sample_weights, axes: [0]), 59 | Nx.weighted_mean(y, sample_weights, axes: [0])}, 60 | else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])} 61 | end 62 | 63 | @doc false 64 | defn set_intercept(coeff, x_offset, y_offset, fit_intercept?) do 65 | if fit_intercept? do 66 | y_offset - Nx.dot(coeff, [-1], x_offset, [-1]) 67 | else 68 | Nx.tensor(0.0, type: Nx.type(coeff)) 69 | end 70 | end 71 | 72 | # Implements sample weighting by rescaling inputs and 73 | # targets by sqrt(sample_weight). 
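# Minimizing sum(w_i * (y_i - x_i . b)^2) is equivalent to ordinary least
# squares on inputs and targets scaled by sqrt(w_i), hence the square root.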
74 | @doc false 75 | defn rescale(x, y, sample_weights) do 76 | factor = Nx.sqrt(sample_weights) 77 | 78 | x_scaled = 79 | case Nx.shape(factor) do 80 | {} -> factor * x 81 | _ -> x * Nx.new_axis(factor, -1) 82 | end 83 | 84 | y_scaled = 85 | case Nx.rank(y) do 86 | 1 -> factor * y 87 | _ -> y * Nx.new_axis(factor, -1) 88 | end 89 | 90 | {x_scaled, y_scaled} 91 | end 92 | end 93 | -------------------------------------------------------------------------------- /lib/scholar/linear/linear_regression.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Linear.LinearRegression do 2 | @moduledoc """ 3 | Ordinary least squares linear regression. 4 | 5 | Time complexity of linear regression is $O((K^2) * (K+N))$ where $N$ is the number of samples 6 | and $K$ is the number of features. 7 | """ 8 | require Nx 9 | import Nx.Defn 10 | import Scholar.Shared 11 | alias Scholar.Linear.LinearHelpers 12 | 13 | @derive {Nx.Container, containers: [:coefficients, :intercept]} 14 | defstruct [:coefficients, :intercept] 15 | 16 | opts = [ 17 | sample_weights: [ 18 | type: {:custom, Scholar.Options, :weights, []}, 19 | doc: """ 20 | The weights for each observation. If not provided, 21 | all observations are assigned equal weight. 22 | """ 23 | ], 24 | fit_intercept?: [ 25 | type: :boolean, 26 | default: true, 27 | doc: """ 28 | If set to `true`, a model will fit the intercept. Otherwise, 29 | the intercept is set to `0.0`. The intercept is an independent term 30 | in a linear model. Specifically, it is the expected mean value 31 | of targets for a zero-vector on input. 32 | """ 33 | ] 34 | ] 35 | 36 | @opts_schema NimbleOptions.new!(opts) 37 | 38 | @doc """ 39 | Fits a linear regression model for sample inputs `x` and 40 | sample targets `y`. 41 | 42 | ## Options 43 | 44 | #{NimbleOptions.docs(@opts_schema)} 45 | 46 | ## Return Values 47 | 48 | The function returns a struct with the following parameters: 49 | 50 | * `:coefficients` - Estimated coefficients for the linear regression problem. 51 | 52 | * `:intercept` - Independent term in the linear model. 53 | 54 | ## Examples 55 | 56 | iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]]) 57 | iex> y = Nx.tensor([4.0, 3.0, -1.0]) 58 | iex> model = Scholar.Linear.LinearRegression.fit(x, y) 59 | iex> model.coefficients 60 | #Nx.Tensor< 61 | f32[2] 62 | [-0.49724647402763367, -0.7010394930839539] 63 | > 64 | iex> model.intercept 65 | #Nx.Tensor< 66 | f32 67 | 5.8964691162109375 68 | > 69 | """ 70 | deftransform fit(x, y, opts \\ []) do 71 | {n_samples, _} = Nx.shape(x) 72 | y = LinearHelpers.flatten_column_vector(y, n_samples) 73 | opts = NimbleOptions.validate!(opts, @opts_schema) 74 | 75 | opts = 76 | [ 77 | sample_weights_flag: opts[:sample_weights] != nil 78 | ] ++ 79 | opts 80 | 81 | sample_weights = LinearHelpers.build_sample_weights(x, opts) 82 | {n_samples, _} = Nx.shape(x) 83 | y = LinearHelpers.flatten_column_vector(y, n_samples) 84 | 85 | fit_n(x, y, sample_weights, opts) 86 | end 87 | 88 | defnp fit_n(a, b, sample_weights, opts) do 89 | a = to_float(a) 90 | b = to_float(b) 91 | 92 | {a_offset, b_offset} = 93 | if opts[:fit_intercept?] 
do 94 | LinearHelpers.preprocess_data(a, b, sample_weights, opts) 95 | else 96 | a_offset_shape = Nx.axis_size(a, 1) 97 | b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1}) 98 | b_offset_shape = Nx.axis_size(b_reshaped, 1) 99 | 100 | {Nx.broadcast(Nx.tensor(0.0, type: Nx.type(a)), {a_offset_shape}), 101 | Nx.broadcast(Nx.tensor(0.0, type: Nx.type(b)), {b_offset_shape})} 102 | end 103 | 104 | {a, b} = {a - a_offset, b - b_offset} 105 | 106 | {a, b} = 107 | if opts[:sample_weights_flag] do 108 | LinearHelpers.rescale(a, b, sample_weights) 109 | else 110 | {a, b} 111 | end 112 | 113 | {coeff, intercept} = lstsq(a, b, a_offset, b_offset, opts[:fit_intercept?]) 114 | %__MODULE__{coefficients: coeff, intercept: intercept} 115 | end 116 | 117 | @doc """ 118 | Makes predictions with the given `model` on input `x`. 119 | 120 | Output predictions have shape `{n_samples}` when train target is shaped either `{n_samples}` or `{n_samples, 1}`. 121 | Otherwise, predictions match train target shape. 122 | 123 | ## Examples 124 | 125 | iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]]) 126 | iex> y = Nx.tensor([4.0, 3.0, -1.0]) 127 | iex> model = Scholar.Linear.LinearRegression.fit(x, y) 128 | iex> Scholar.Linear.LinearRegression.predict(model, Nx.tensor([[2.0, 1.0]])) 129 | Nx.tensor( 130 | [4.200936794281006] 131 | ) 132 | """ 133 | defn predict(%__MODULE__{coefficients: coeff, intercept: intercept} = _model, x) do 134 | Nx.dot(x, [-1], coeff, [-1]) + intercept 135 | end 136 | 137 | # Implements ordinary least-squares by estimating the 138 | # solution A to the equation A.X = b. 139 | defnp lstsq(a, b, a_offset, b_offset, fit_intercept?) do 140 | pinv = Nx.LinAlg.pinv(a) 141 | coeff = Nx.dot(b, [0], pinv, [1]) 142 | intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, fit_intercept?) 143 | {coeff, intercept} 144 | end 145 | end 146 | -------------------------------------------------------------------------------- /lib/scholar/metrics/clustering.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.Clustering do 2 | @moduledoc """ 3 | Metrics related to clustering algorithms. 4 | """ 5 | 6 | import Nx.Defn 7 | import Scholar.Shared 8 | 9 | opts = [ 10 | num_clusters: [ 11 | required: true, 12 | type: :pos_integer, 13 | doc: "Number of clusters in clustering." 14 | ] 15 | ] 16 | 17 | @opts_schema NimbleOptions.new!(opts) 18 | 19 | @doc """ 20 | Compute the Silhouette Coefficient for each sample. 21 | 22 | The silhouette value is a measure of how similar an object is to its own cluster (cohesion) 23 | compared to other clusters (separation). The silhouette ranges from −1 to +1, where a high 24 | value indicates that the object is well matched to its own cluster and poorly 25 | matched to neighboring clusters. If most objects have a high value, then the 26 | clustering configuration is appropriate. If many points have a low or negative 27 | value, then the clustering configuration may have too many or too few clusters. 28 | 29 | Time complexity of silhouette score is $O(N^2)$ where $N$ is the number of samples. 
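For a sample $i$, the coefficient is $s(i) = (b(i) - a(i)) / max(a(i), b(i))$, where $a(i)$ is the mean distance from $i$ to the other members of its own cluster and $b(i)$ is the mean distance from $i$ to the members of the nearest other cluster. Samples in singleton clusters are assigned a silhouette of 0.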
30 | 31 | ## Options 32 | 33 | #{NimbleOptions.docs(@opts_schema)} 34 | 35 | ## Examples 36 | 37 | iex> x = Nx.tensor([[0, 0], [1, 0], [1, 1], [3, 3], [4, 4.5]]) 38 | iex> labels = Nx.tensor([0, 0, 0, 1, 1]) 39 | iex> Scholar.Metrics.Clustering.silhouette_samples(x, labels, num_clusters: 2) 40 | #Nx.Tensor< 41 | f32[5] 42 | [0.7647753357887268, 0.7781199216842651, 0.6754303574562073, 0.49344196915626526, 0.6627992987632751] 43 | > 44 | 45 | iex> x = Nx.tensor([[0.1, 0], [0, 1], [22, 65], [42, 3], [4.2, 51]]) 46 | iex> labels = Nx.tensor([0, 1, 2, 1, 1]) 47 | iex> Scholar.Metrics.Clustering.silhouette_samples(x, labels, num_clusters: 3) 48 | #Nx.Tensor< 49 | f32[5] 50 | [0.0, -0.9782054424285889, 0.0, -0.18546827137470245, -0.5929659008979797] 51 | > 52 | """ 53 | deftransform silhouette_samples(x, labels, opts \\ []) do 54 | silhouette_samples_n(x, labels, NimbleOptions.validate!(opts, @opts_schema)) 55 | end 56 | 57 | defnp silhouette_samples_n(x, labels, opts) do 58 | verify_num_clusters(x, opts) 59 | {inner, alone?, outer} = inner_and_outer_dist(x, labels, opts) 60 | result = (outer - inner) / Nx.max(outer, inner) 61 | Nx.select(alone?, 0, result) 62 | end 63 | 64 | @doc """ 65 | Compute the mean Silhouette Coefficient of all samples. 66 | 67 | ## Options 68 | 69 | #{NimbleOptions.docs(@opts_schema)} 70 | 71 | ## Examples 72 | 73 | iex> x = Nx.tensor([[0, 0], [1, 0], [1, 1], [3, 3], [4, 4.5]]) 74 | iex> labels = Nx.tensor([0, 0, 0, 1, 1]) 75 | iex> Scholar.Metrics.Clustering.silhouette_score(x, labels, num_clusters: 2) 76 | #Nx.Tensor< 77 | f32 78 | 0.6749133467674255 79 | > 80 | 81 | iex> x = Nx.tensor([[0.1, 0], [0, 1], [22, 65], [42, 3], [4.2, 51]]) 82 | iex> labels = Nx.tensor([0, 1, 2, 1, 1]) 83 | iex> Scholar.Metrics.Clustering.silhouette_score(x, labels, num_clusters: 3) 84 | #Nx.Tensor< 85 | f32 86 | -0.35132792592048645 87 | > 88 | """ 89 | deftransform silhouette_score(x, labels, opts \\ []) do 90 | silhouette_score_n(x, labels, NimbleOptions.validate!(opts, @opts_schema)) 91 | end 92 | 93 | defnp silhouette_score_n(x, labels, opts) do 94 | Nx.mean(silhouette_samples(x, labels, opts)) 95 | end 96 | 97 | defnp inner_and_outer_dist(x, labels, opts) do 98 | num_clusters = opts[:num_clusters] 99 | num_samples = Nx.axis_size(x, 0) 100 | inf = Nx.Constants.infinity(to_float_type(x)) 101 | pairwise_dist = Scholar.Metrics.Distance.pairwise_euclidean(x) 102 | membership_mask = Nx.reshape(labels, {num_samples, 1}) == Nx.iota({1, num_clusters}) 103 | cluster_size = membership_mask |> Nx.sum(axes: [0]) |> Nx.reshape({1, num_clusters}) 104 | dist_in_cluster = Nx.dot(pairwise_dist, membership_mask) 105 | mean_dist_in_cluster = dist_in_cluster / cluster_size 106 | 107 | alone? 
= (cluster_size == 1) |> Nx.squeeze() |> Nx.take(labels) 108 | 109 | inner_dist = 110 | (dist_in_cluster / Nx.max(cluster_size - 1, 1)) 111 | |> Nx.take_along_axis(Nx.reshape(labels, {num_samples, 1}), axis: 1) 112 | |> Nx.squeeze(axes: [1]) 113 | 114 | outer_dist = 115 | membership_mask 116 | |> Nx.select(inf, mean_dist_in_cluster) 117 | |> Nx.reduce_min(axes: [1]) 118 | 119 | {inner_dist, alone?, outer_dist} 120 | end 121 | 122 | deftransformp verify_num_clusters(x, opts) do 123 | {num_samples, _} = Nx.shape(x) 124 | 125 | unless opts[:num_clusters] <= num_samples do 126 | raise ArgumentError, 127 | "expected :num_clusters to to be a positive integer in range 1 to #{inspect(num_samples)}, got: #{inspect(opts[:num_clusters])}" 128 | end 129 | end 130 | end 131 | -------------------------------------------------------------------------------- /lib/scholar/metrics/neighbors.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.Neighbors do 2 | @moduledoc """ 3 | Metrics for evaluating the results of approximate k-nearest neighbor search algorithms. 4 | """ 5 | 6 | import Nx.Defn 7 | 8 | @doc """ 9 | Computes the recall of predicted k-nearest neighbors given the true k-nearest neighbors. 10 | Recall is defined as the average fraction of nearest neighbors the algorithm predicted correctly. 11 | 12 | ## Examples 13 | 14 | iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]]) 15 | iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_true) 16 | #Nx.Tensor< 17 | f32 18 | 1.0 19 | > 20 | 21 | iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]]) 22 | iex> neighbors_pred = Nx.tensor([[0, 1], [1, 0], [2, 0]]) 23 | iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_pred) 24 | #Nx.Tensor< 25 | f32 26 | 0.6666666865348816 27 | > 28 | """ 29 | defn recall(neighbors_true, neighbors_pred) do 30 | if Nx.rank(neighbors_true) != 2 do 31 | raise ArgumentError, 32 | """ 33 | expected true neighbors to have shape {num_samples, num_neighbors}, \ 34 | got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\ 35 | """ 36 | end 37 | 38 | if Nx.rank(neighbors_pred) != 2 do 39 | raise ArgumentError, 40 | """ 41 | expected predicted neighbors to have shape {num_samples, num_neighbors}, \ 42 | got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\ 43 | """ 44 | end 45 | 46 | if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do 47 | raise ArgumentError, 48 | """ 49 | expected true and predicted neighbors to have the same axis 0 size, \ 50 | got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\ 51 | """ 52 | end 53 | 54 | if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do 55 | raise ArgumentError, 56 | """ 57 | expected true and predicted neighbors to have the same axis 1 size, \ 58 | got #{inspect(Nx.axis_size(neighbors_true, 1))} and #{inspect(Nx.axis_size(neighbors_pred, 1))}\ 59 | """ 60 | end 61 | 62 | {n, k} = Nx.shape(neighbors_true) 63 | concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1) 64 | duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]] 65 | duplicate_mask |> Nx.sum() |> Nx.divide(n * k) 66 | end 67 | end 68 | -------------------------------------------------------------------------------- /lib/scholar/metrics/ranking.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.Ranking do 2 
| @moduledoc """ 3 | Provides metrics and calculations related to ranking quality. 4 | 5 | Ranking metrics evaluate the quality of ordered lists of items, 6 | often used in information retrieval and recommendation systems. 7 | 8 | This module currently supports the following ranking metrics: 9 | * Discounted Cumulative Gain (DCG) 10 | """ 11 | 12 | import Nx.Defn 13 | import Scholar.Shared 14 | require Nx 15 | 16 | @dcg_opts [ 17 | k: [ 18 | type: {:custom, Scholar.Options, :positive_number, []}, 19 | doc: "Truncation parameter to consider only the top-k elements." 20 | ] 21 | ] 22 | 23 | @dcg_opts_schema NimbleOptions.new!(@dcg_opts) 24 | 25 | deftransform dcg(y_true, y_score, opts \\ []) do 26 | dcg_n(y_true, y_score, NimbleOptions.validate!(opts, @dcg_opts_schema)) 27 | end 28 | 29 | @doc """ 30 | ## Options 31 | #{NimbleOptions.docs(@dcg_opts_schema)} 32 | 33 | Computes the DCG based on true relevance scores (`y_true`) and their respective predicted scores (`y_score`). 34 | """ 35 | defn dcg_n(y_true, y_score, opts) do 36 | y_true_shape = Nx.shape(y_true) 37 | y_score_shape = Nx.shape(y_score) 38 | 39 | check_shape(y_true_shape, y_score_shape) 40 | 41 | {adjusted_y_true, adjusted_y_score} = handle_ties(y_true, y_score) 42 | 43 | sorted_indices = Nx.argsort(adjusted_y_score, axis: 0, direction: :desc) 44 | sorted_y_true = Nx.take(adjusted_y_true, sorted_indices) 45 | 46 | truncated_y_true = truncate_at_k(sorted_y_true, opts) 47 | dcg_value(truncated_y_true) 48 | end 49 | 50 | defnp check_shape(y_true, y_pred) do 51 | assert_same_shape!(y_true, y_pred) 52 | end 53 | 54 | defnp handle_ties(y_true, y_score) do 55 | sorted_indices = Nx.argsort(y_score, axis: 0, direction: :desc) 56 | 57 | sorted_y_true = Nx.take(y_true, sorted_indices) 58 | sorted_y_score = Nx.take(y_score, sorted_indices) 59 | 60 | tie_sorted_indices = Nx.argsort(sorted_y_true, axis: 0, direction: :desc) 61 | adjusted_y_true = Nx.take(sorted_y_true, tie_sorted_indices) 62 | adjusted_y_score = Nx.take(sorted_y_score, tie_sorted_indices) 63 | 64 | {adjusted_y_true, adjusted_y_score} 65 | end 66 | 67 | defnp dcg_value(y_true) do 68 | float_y_true = Nx.as_type(y_true, :f32) 69 | 70 | log_tensor = 71 | y_true 72 | |> Nx.shape() 73 | |> Nx.iota() 74 | |> Nx.as_type(:f32) 75 | |> Nx.add(2.0) 76 | |> Nx.log2() 77 | 78 | div_result = Nx.divide(float_y_true, log_tensor) 79 | 80 | Nx.sum(div_result) 81 | end 82 | 83 | defnp truncate_at_k(tensor, opts) do 84 | case opts[:k] do 85 | nil -> 86 | tensor 87 | 88 | _ -> 89 | if opts[:k] > Nx.axis_size(tensor, 0) do 90 | tensor 91 | else 92 | {top_k, _rest} = Nx.split(tensor, opts[:k], axis: 0) 93 | top_k 94 | end 95 | end 96 | end 97 | 98 | @doc """ 99 | Computes the normalized discounted cumulative gain (NDCG) based on true relevance scores `y_true` and their respective predicted scores `y_score`. 
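NDCG is the DCG of the ranking induced by `y_score` divided by the DCG of the ideal ranking induced by `y_true` itself, so a perfect ranking scores 1.0. DCG sums the ranked relevances discounted logarithmically by position, $\sum_i rel_i / \log_2(i + 1)$, optionally truncated to the top `k` elements.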
100 | 101 | ## Options 102 | 103 | #{NimbleOptions.docs(@dcg_opts_schema)} 104 | 105 | ## Examples 106 | 107 | iex> true_relevance = Nx.tensor([10, 0, 0, 1, 5]) 108 | iex> scores = Nx.tensor([0.1, 0.2, 0.3, 4, 70]) 109 | iex> Scholar.Metrics.Ranking.ndcg_n(true_relevance, scores) 110 | #Nx.Tensor< 111 | f32 112 | 0.6956940293312073 113 | > 114 | iex> scores = Nx.tensor([0.05, 1.1, 1.0, 0.5, 0.0]) 115 | iex> Scholar.Metrics.Ranking.ndcg_n(true_relevance, scores) 116 | #Nx.Tensor< 117 | f32 118 | 0.4936802089214325 119 | > 120 | iex> scores = Nx.tensor([0.05, 1.1, 1.0, 0.5, 0.0]) 121 | iex> Scholar.Metrics.Ranking.ndcg_n(true_relevance, scores, k: 4) 122 | #Nx.Tensor< 123 | f32 124 | 0.352024108171463 125 | > 126 | iex> Scholar.Metrics.Ranking.ndcg_n(true_relevance, true_relevance, k: 4) 127 | #Nx.Tensor< 128 | f32 129 | 1.0 130 | > 131 | """ 132 | defn ndcg_n(y_true, y_score, opts \\ []) do 133 | dcg_n(y_true, y_score, opts) / dcg_n(y_true, y_true, opts) 134 | end 135 | end 136 | -------------------------------------------------------------------------------- /lib/scholar/neighbors/knn_regressor.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.KNNRegressor do 2 | @moduledoc """ 3 | K-Nearest Neighbors Regressor. 4 | 5 | Performs regression by computing the (weighted) mean of k-nearest neighbor labels. 6 | """ 7 | import Nx.Defn 8 | require Nx 9 | 10 | @derive {Nx.Container, keep: [:weights], containers: [:algorithm, :labels]} 11 | defstruct [:algorithm, :weights, :labels] 12 | 13 | opts = [ 14 | algorithm: [ 15 | type: :atom, 16 | default: :brute, 17 | doc: """ 18 | Algorithm used to compute the k-nearest neighbors. Possible values: 19 | 20 | * `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details. 21 | 22 | * `:kd_tree` - k-d tree. See `Scholar.Neighbors.KDTree` for more details. 23 | 24 | * `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details. 25 | 26 | * Module implementing `fit(data, opts)` and `predict(model, query)`. predict/2 must return a tuple containing indices 27 | of k-nearest neighbors of query points as well as distances between query points and their k-nearest neighbors. 28 | """ 29 | ], 30 | weights: [ 31 | type: {:in, [:uniform, :distance]}, 32 | default: :uniform, 33 | doc: """ 34 | Weight function used in prediction. Possible values: 35 | 36 | * `:uniform` - uniform weights. All points in each neighborhood are weighted equally. 37 | 38 | * `:distance` - weight points by the inverse of their distance. in this case, closer neighbors of 39 | a query point will have a greater influence than neighbors which are further away. 40 | """ 41 | ] 42 | ] 43 | 44 | @opts_schema NimbleOptions.new!(opts) 45 | 46 | @doc """ 47 | Fits a k-NN regressor model. 48 | 49 | ## Options 50 | 51 | #{NimbleOptions.docs(@opts_schema)} 52 | 53 | Algorithm-specific options (e.g. `:num_neighbors`, `:metric`) should be provided together with the regressor options. 
54 | 55 | ## Examples 56 | 57 | iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) 58 | iex> y = Nx.tensor([[1], [2], [3], [4], [5]]) 59 | iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, num_neighbors: 3) 60 | iex> model.algorithm 61 | Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3) 62 | iex> model.labels 63 | Nx.tensor([[1], [2], [3], [4], [5]]) 64 | 65 | iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) 66 | iex> y = Nx.tensor([[1], [2], [3], [4], [5]]) 67 | iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, algorithm: :kd_tree, num_neighbors: 3, metric: {:minkowski, 1}) 68 | iex> model.algorithm 69 | Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1}) 70 | iex> model.labels 71 | Nx.tensor([[1], [2], [3], [4], [5]]) 72 | 73 | iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) 74 | iex> y = Nx.tensor([[1], [2], [3], [4], [5]]) 75 | iex> key = Nx.Random.key(12) 76 | iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, algorithm: :random_projection_forest, num_neighbors: 2, num_trees: 4, key: key) 77 | iex> model.algorithm 78 | Scholar.Neighbors.RandomProjectionForest.fit(x, num_neighbors: 2, num_trees: 4, key: key) 79 | iex> model.labels 80 | Nx.tensor([[1], [2], [3], [4], [5]]) 81 | """ 82 | deftransform fit(x, y, opts) do 83 | if Nx.rank(x) != 2 do 84 | raise ArgumentError, 85 | """ 86 | expected x to have shape {num_samples, num_features_in}, \ 87 | got tensor with shape: #{inspect(Nx.shape(x))} 88 | """ 89 | end 90 | 91 | if Nx.rank(y) != 2 do 92 | raise ArgumentError, 93 | """ 94 | expected y to have shape {num_samples, num_features_out}, \ 95 | got tensor with shape: #{inspect(Nx.shape(y))} 96 | """ 97 | end 98 | 99 | if Nx.axis_size(x, 0) != Nx.axis_size(y, 0) do 100 | raise ArgumentError, 101 | """ 102 | expected x and y to have the same first dimension, \ 103 | got #{Nx.axis_size(x, 0)} and #{Nx.axis_size(y, 0)} 104 | """ 105 | end 106 | 107 | {opts, algorithm_opts} = Keyword.split(opts, [:algorithm, :weights]) 108 | opts = NimbleOptions.validate!(opts, @opts_schema) 109 | 110 | algorithm_module = 111 | case opts[:algorithm] do 112 | :brute -> 113 | Scholar.Neighbors.BruteKNN 114 | 115 | :kd_tree -> 116 | Scholar.Neighbors.KDTree 117 | 118 | :random_projection_forest -> 119 | Scholar.Neighbors.RandomProjectionForest 120 | 121 | module when is_atom(module) -> 122 | module 123 | end 124 | 125 | algorithm = algorithm_module.fit(x, algorithm_opts) 126 | 127 | %__MODULE__{ 128 | algorithm: algorithm, 129 | labels: y, 130 | weights: opts[:weights] 131 | } 132 | end 133 | 134 | @doc """ 135 | Predicts labels using a k-NN regressor model. 
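With `weights: :uniform`, the prediction is the mean of the labels of the k nearest neighbors. With `weights: :distance`, it is the mean weighted by the inverse distances `1 / d_i`; if any neighbor lies at distance 0, only the zero-distance neighbors contribute to the prediction.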
136 |
137 | ## Examples
138 |
139 | iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
140 | iex> y_train = Nx.tensor([[1], [2], [3], [4], [5]])
141 | iex> model = Scholar.Neighbors.KNNRegressor.fit(x_train, y_train, num_neighbors: 3)
142 | iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
143 | iex> Scholar.Neighbors.KNNRegressor.predict(model, x)
144 | Nx.tensor([[2.0], [2.0], [4.0]])
145 | """
146 | defn predict(model, x) do
147 | {neighbors, distances} = compute_knn(model.algorithm, x)
148 | neighbor_labels = Nx.take(model.labels, neighbors)
149 |
150 | case model.weights do
151 | :uniform ->
152 | Nx.mean(neighbor_labels, axes: [1])
153 |
154 | :distance ->
155 | weights =
156 | Scholar.Neighbors.Utils.check_weights(distances)
157 | |> Nx.new_axis(2)
158 | |> Nx.broadcast(neighbor_labels)
159 |
160 | Nx.weighted_mean(neighbor_labels, weights, axes: [1])
161 | end
162 | end
163 |
164 | deftransformp compute_knn(algorithm, x) do
165 | algorithm.__struct__.predict(algorithm, x)
166 | end
167 | end
168 |
-------------------------------------------------------------------------------- /lib/scholar/neighbors/large_vis.ex: --------------------------------------------------------------------------------
1 | defmodule Scholar.Neighbors.LargeVis do
2 | @moduledoc """
3 | LargeVis algorithm for approximate k-nearest neighbor (k-NN) graph construction.
4 |
5 | The algorithm works in the following way. First, the approximate k-NN graph is constructed
6 | using a random projection forest. Then, the graph is refined by looking at the neighbors of
7 | neighbors of every point for a fixed number of iterations. This step is called NN-expansion.
8 |
9 | ## References
10 |
11 | * [Visualizing Large-scale and High-dimensional Data](https://arxiv.org/abs/1602.00370).
12 | """
13 |
14 | import Nx.Defn
15 | import Scholar.Shared
16 | require Nx
17 | alias Scholar.Neighbors.RandomProjectionForest
18 | alias Scholar.Neighbors.Utils
19 |
20 | opts = [
21 | num_neighbors: [
22 | required: true,
23 | type: :pos_integer,
24 | doc: "The number of neighbors in the graph."
25 | ],
26 | metric: [
27 | type: {:in, [:squared_euclidean, :euclidean]},
28 | default: :euclidean,
29 | doc: "The function that measures distance between two points."
30 | ],
31 | min_leaf_size: [
32 | type: :pos_integer,
33 | doc: """
34 | The minimum number of points in every leaf.
35 | Must be at least num_neighbors.
36 | If not provided, it is set based on the number of neighbors.
37 | """
38 | ],
39 | num_trees: [
40 | type: :pos_integer,
41 | doc: """
42 | The number of trees in random projection forest.
43 | If not provided, it is set based on the dataset size.
44 | """
45 | ],
46 | num_iters: [
47 | type: :non_neg_integer,
48 | default: 3,
49 | doc: "The number of times to perform neighborhood expansion."
50 | ],
51 | key: [
52 | type: {:custom, Scholar.Options, :key, []},
53 | doc: """
54 | Used for random number generation in parameter initialization.
55 | If the key is not provided, it is set to `Nx.Random.key(System.system_time())`.
56 | """
57 | ]
58 | ]
59 |
60 | @opts_schema NimbleOptions.new!(opts)
61 |
62 | @doc """
63 | Constructs the approximate k-NN graph with LargeVis.
64 |
65 | Returns neighbor indices and distances.
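When not provided, `:min_leaf_size` defaults to `max(10, 2 * num_neighbors)` and `:num_trees` to `5 + round(:math.pow(size, 0.25))`, where `size` is the number of samples.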
66 | 67 | ## Examples 68 | 69 | iex> key = Nx.Random.key(12) 70 | iex> tensor = Nx.iota({5, 2}) 71 | iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, metric: :squared_euclidean, min_leaf_size: 2, num_trees: 3, key: key) 72 | iex> graph 73 | #Nx.Tensor< 74 | u32[5][2] 75 | [ 76 | [0, 1], 77 | [1, 0], 78 | [2, 1], 79 | [3, 2], 80 | [4, 3] 81 | ] 82 | > 83 | iex> distances 84 | #Nx.Tensor< 85 | f32[5][2] 86 | [ 87 | [0.0, 8.0], 88 | [0.0, 8.0], 89 | [0.0, 8.0], 90 | [0.0, 8.0], 91 | [0.0, 8.0] 92 | ] 93 | > 94 | """ 95 | deftransform fit(tensor, opts) do 96 | if Nx.rank(tensor) != 2 do 97 | raise ArgumentError, 98 | """ 99 | expected input tensor to have shape {num_samples, num_features}, \ 100 | got tensor with shape: #{inspect(Nx.shape(tensor))}\ 101 | """ 102 | end 103 | 104 | opts = NimbleOptions.validate!(opts, @opts_schema) 105 | k = opts[:num_neighbors] 106 | 107 | metric = 108 | case opts[:metric] do 109 | :euclidean -> &Scholar.Metrics.Distance.euclidean/2 110 | :squared_euclidean -> &Scholar.Metrics.Distance.squared_euclidean/2 111 | end 112 | 113 | min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k) 114 | 115 | size = Nx.axis_size(tensor, 0) 116 | num_trees = opts[:num_trees] || 5 + round(:math.pow(size, 0.25)) 117 | key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) 118 | 119 | fit_n( 120 | tensor, 121 | key, 122 | num_neighbors: k, 123 | metric: metric, 124 | min_leaf_size: min_leaf_size, 125 | num_trees: num_trees, 126 | num_iters: opts[:num_iters] 127 | ) 128 | end 129 | 130 | defnp fit_n(tensor, key, opts) do 131 | forest = 132 | RandomProjectionForest.fit(tensor, 133 | num_neighbors: opts[:num_neighbors], 134 | min_leaf_size: opts[:min_leaf_size], 135 | num_trees: opts[:num_trees], 136 | key: key 137 | ) 138 | 139 | {graph, _} = RandomProjectionForest.predict(forest, tensor) 140 | expand(graph, tensor, metric: opts[:metric], num_iters: opts[:num_iters]) 141 | end 142 | 143 | defn expand(graph, tensor, opts) do 144 | num_iters = opts[:num_iters] 145 | {n, k} = Nx.shape(graph) 146 | 147 | {result, _} = 148 | while { 149 | { 150 | graph, 151 | _distances = Nx.broadcast(Nx.tensor(:nan, type: to_float_type(tensor)), {n, k}) 152 | }, 153 | {tensor, iter = 0} 154 | }, 155 | iter < num_iters do 156 | {expansion_iter(graph, tensor, metric: opts[:metric]), {tensor, iter + 1}} 157 | end 158 | 159 | result 160 | end 161 | 162 | defnp expansion_iter(graph, tensor, opts) do 163 | {size, k} = Nx.shape(graph) 164 | candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k}) 165 | candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1) 166 | 167 | Utils.brute_force_search_with_candidates(tensor, tensor, candidate_indices, 168 | num_neighbors: k, 169 | metric: opts[:metric] 170 | ) 171 | end 172 | end 173 | -------------------------------------------------------------------------------- /lib/scholar/neighbors/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.Utils do 2 | @moduledoc false 3 | import Nx.Defn 4 | require Nx 5 | 6 | def metric(:cosine), do: {:ok, &Scholar.Metrics.Distance.cosine/2} 7 | 8 | def metric({:minkowski, p}) when p == :infinity or (is_number(p) and p > 0) do 9 | {:ok, &Scholar.Metrics.Distance.minkowski(&1, &2, p: p)} 10 | end 11 | 12 | def metric(metric) when is_function(metric, 2), do: {:ok, metric} 13 | 14 | def metric(metric) do 15 | {:error, 16 | "expected metric to be a 2-arity function, :cosine, 
tuple {:minkowski, p} where p is a positive number or :infinity, got: #{inspect(metric)}"} 17 | end 18 | 19 | def pairwise_metric(:cosine), do: {:ok, &Scholar.Metrics.Distance.pairwise_cosine/2} 20 | 21 | def pairwise_metric({:minkowski, p}) when p == :infinity or (is_number(p) and p > 0) do 22 | {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: p)} 23 | end 24 | 25 | def pairwise_metric(:euclidean), do: {:ok, &Scholar.Metrics.Distance.pairwise_euclidean/2} 26 | 27 | def pairwise_metric(:squared_euclidean), 28 | do: {:ok, &Scholar.Metrics.Distance.pairwise_squared_euclidean/2} 29 | 30 | def pairwise_metric(:manhattan), 31 | do: {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: 1)} 32 | 33 | def pairwise_metric(metric) when is_function(metric, 2), do: {:ok, metric} 34 | 35 | def pairwise_metric(metric) do 36 | {:error, 37 | "expected metric to be a 2-arity function, :cosine or tuple {:minkowski, p} where p is a positive number or :infinity, got: #{inspect(metric)}"} 38 | end 39 | 40 | defn brute_force_search_with_candidates(data, query, candidate_indices, opts) do 41 | k = opts[:num_neighbors] 42 | metric = opts[:metric] 43 | dim = Nx.axis_size(data, 1) 44 | {size, length} = Nx.shape(candidate_indices) 45 | 46 | x = 47 | query 48 | |> Nx.new_axis(1) 49 | |> Nx.broadcast({size, length, dim}) 50 | |> Nx.vectorize([:query, :candidates]) 51 | 52 | y = Nx.take(data, candidate_indices) |> Nx.vectorize([:query, :candidates]) 53 | distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil) 54 | 55 | distances = 56 | if length > 1 do 57 | sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true) 58 | inverse = inverse_permutation(sorted_indices) 59 | sorted = Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1) 60 | 61 | duplicate_mask = 62 | Nx.concatenate( 63 | [ 64 | Nx.broadcast(0, {size, 1}), 65 | Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]]) 66 | ], 67 | axis: 1 68 | ) 69 | |> Nx.take_along_axis(inverse, axis: 1) 70 | 71 | Nx.select(duplicate_mask, :infinity, distances) 72 | else 73 | distances 74 | end 75 | 76 | indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1) 77 | 78 | neighbor_indices = 79 | Nx.take( 80 | Nx.vectorize(candidate_indices, :samples), 81 | Nx.vectorize(indices, :samples) 82 | ) 83 | |> Nx.devectorize() 84 | |> Nx.rename(nil) 85 | 86 | neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1) 87 | 88 | {neighbor_indices, neighbor_distances} 89 | end 90 | 91 | defnp inverse_permutation(indices) do 92 | {size, length} = Nx.shape(indices) 93 | target = Nx.broadcast(Nx.u32(0), {size, length}) 94 | samples = Nx.iota({size, length, 1}, axis: 0) 95 | 96 | target_indices = 97 | Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2) 98 | |> Nx.reshape({size * length, 2}) 99 | 100 | updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length}) 101 | Nx.indexed_add(target, target_indices, updates) 102 | end 103 | 104 | defn check_weights(weights) do 105 | zero_mask = weights == 0 106 | zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights) 107 | weights = Nx.select(zero_mask, 1, weights) 108 | weights_inv = 1 / weights 109 | Nx.select(zero_rows, zero_mask, weights_inv) 110 | end 111 | end 112 | -------------------------------------------------------------------------------- /lib/scholar/options.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Options do 2 | # Useful NimbleOptions 
validations. 3 | @moduledoc false 4 | 5 | require Nx 6 | 7 | def optimizer(value) do 8 | error = 9 | {:error, 10 | "expected :optimizer to be either a valid 0-arity function in Polaris.Optimizers or a valid {init_fn, update_fn} tuple"} 11 | 12 | case value do 13 | {init_fn, update_fn} when is_function(init_fn, 1) and is_function(update_fn, 3) -> 14 | {:ok, value} 15 | 16 | atom when is_atom(atom) -> 17 | mod = Polaris.Optimizers 18 | 19 | if Code.ensure_loaded(mod) == {:module, mod} and function_exported?(mod, atom, 0) do 20 | {:ok, atom} 21 | else 22 | error 23 | end 24 | 25 | _ -> 26 | error 27 | end 28 | end 29 | 30 | def axes(axes) do 31 | # Axes are further validated by Nx, including against the tensor. 32 | if axes == nil or is_list(axes) do 33 | {:ok, axes} 34 | else 35 | {:error, "expected :axes to be a list positive integers as axis"} 36 | end 37 | end 38 | 39 | def axis(axis) do 40 | # Axis is further validated by Nx, including against the tensor. 41 | if axis == nil or is_integer(axis) or is_atom(axis) do 42 | {:ok, axis} 43 | else 44 | {:error, "expected :axis to be an integers, atom or nil"} 45 | end 46 | end 47 | 48 | def type(type) do 49 | {:ok, Nx.Type.normalize!(type)} 50 | end 51 | 52 | def positive_number(num) do 53 | if is_number(num) and num > 0 do 54 | {:ok, num} 55 | else 56 | {:error, "expected positive number, got: #{inspect(num)}"} 57 | end 58 | end 59 | 60 | def non_negative_number(num) do 61 | if is_number(num) and num >= 0 do 62 | {:ok, num} 63 | else 64 | {:error, "expected a non-negative number, got: #{inspect(num)}"} 65 | end 66 | end 67 | 68 | def non_negative_integer(num) do 69 | if is_integer(num) and num >= 0 do 70 | {:ok, num} 71 | else 72 | {:error, "expected a non-negative integer, got: #{inspect(num)}"} 73 | end 74 | end 75 | 76 | def weights(weights) do 77 | if is_nil(weights) or 78 | (Nx.is_tensor(weights) and Nx.rank(weights) in 0..1) or 79 | (is_list(weights) and Enum.all?(weights, &is_number/1)) do 80 | {:ok, weights} 81 | else 82 | {:error, "expected weights to be a flat tensor or a flat list, got: #{inspect(weights)}"} 83 | end 84 | end 85 | 86 | def multi_weights(weights) do 87 | if is_nil(weights) or 88 | (Nx.is_tensor(weights) and Nx.rank(weights) > 1) do 89 | {:ok, weights} 90 | else 91 | {:error, 92 | "expected weights to be a tensor with rank greater than 1, got: #{inspect(weights)}"} 93 | end 94 | end 95 | 96 | def key(key) do 97 | if Nx.is_tensor(key) and Nx.type(key) == {:u, 32} and Nx.shape(key) == {2} do 98 | {:ok, key} 99 | else 100 | {:error, "expected key to be a key (use Nx.Random.key/1), got: #{inspect(key)}"} 101 | end 102 | end 103 | 104 | def beta(beta) do 105 | if (is_number(beta) and beta >= 0) or (Nx.is_tensor(beta) and Nx.rank(beta) == 0) do 106 | {:ok, beta} 107 | else 108 | {:error, "expected 'beta' to be in the range [0, inf]"} 109 | end 110 | end 111 | 112 | def quantile_range(value) do 113 | case value do 114 | {q_min, q_max} 115 | when is_number(q_min) and is_number(q_max) and 0.0 < q_min and q_min < q_max and 116 | q_max < 100.0 -> 117 | {:ok, {q_min, q_max}} 118 | 119 | _ -> 120 | {:error, 121 | "expected :quantile_range to be a tuple {q_min, q_max} such that 0.0 < q_min < q_max < 100.0, got: #{inspect(value)}"} 122 | end 123 | end 124 | end 125 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing do 2 | @moduledoc """ 3 | Set of 
functions for preprocessing data. 4 | """ 5 | 6 | import Nx.Defn 7 | 8 | binarize_schema = [ 9 | type: [ 10 | type: {:custom, Scholar.Options, :type, []}, 11 | default: :f32, 12 | doc: """ 13 | Type of the resultant tensor. 14 | """ 15 | ], 16 | threshold: [ 17 | type: {:or, [:integer, :float]}, 18 | default: 0, 19 | doc: """ 20 | Feature values below or equal to this are replaced by 0, above it by 1. 21 | """ 22 | ] 23 | ] 24 | 25 | @binarize_schema NimbleOptions.new!(binarize_schema) 26 | 27 | @doc """ 28 | Standardizes the tensor by removing the mean and scaling to unit variance. 29 | 30 | It is a shortcut for `Scholar.Preprocessing.StandardScale.fit_transform/3`. 31 | See `Scholar.Preprocessing.StandardScale` for more information. 32 | 33 | ## Examples 34 | 35 | iex> Scholar.Preprocessing.standard_scale(Nx.tensor([1,2,3])) 36 | #Nx.Tensor< 37 | f32[3] 38 | [-1.2247447967529297, 0.0, 1.2247447967529297] 39 | > 40 | 41 | """ 42 | defn standard_scale(tensor, opts \\ []) do 43 | Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts) 44 | end 45 | 46 | @doc """ 47 | Scales a tensor by dividing each sample in batch by maximum absolute value in the batch. 48 | 49 | It is a shortcut for `Scholar.Preprocessing.MaxAbsScaler.fit_transform/2`. 50 | See `Scholar.Preprocessing.MaxAbsScaler` for more information. 51 | 52 | ## Examples 53 | 54 | iex> Scholar.Preprocessing.max_abs_scale(Nx.tensor([1, 2, 3])) 55 | #Nx.Tensor< 56 | f32[3] 57 | [0.3333333432674408, 0.6666666865348816, 1.0] 58 | > 59 | 60 | iex> Scholar.Preprocessing.max_abs_scale(Nx.tensor([[1, -1, 2], [3, 0, 0], [0, 1, -1], [2, 3, 1]]), axes: [0]) 61 | #Nx.Tensor< 62 | f32[4][3] 63 | [ 64 | [0.3333333432674408, -0.3333333432674408, 1.0], 65 | [1.0, 0.0, 0.0], 66 | [0.0, 0.3333333432674408, -0.5], 67 | [0.6666666865348816, 1.0, 0.5] 68 | ] 69 | > 70 | 71 | iex> Scholar.Preprocessing.max_abs_scale(42) 72 | #Nx.Tensor< 73 | f32 74 | 1.0 75 | > 76 | """ 77 | defn max_abs_scale(tensor, opts \\ []) do 78 | Scholar.Preprocessing.MaxAbsScaler.fit_transform(tensor, opts) 79 | end 80 | 81 | @doc """ 82 | Scales a tensor by a given range. 83 | 84 | It is a shortcut for `Scholar.Preprocessing.MinMaxScaler.fit_transform/2`. 85 | See `Scholar.Preprocessing.MinMaxScaler` for more information. 86 | 87 | ## Examples 88 | 89 | iex> Scholar.Preprocessing.min_max_scale(Nx.tensor([1, 2, 3])) 90 | #Nx.Tensor< 91 | f32[3] 92 | [0.0, 0.5, 1.0] 93 | > 94 | 95 | iex> Scholar.Preprocessing.min_max_scale(42) 96 | #Nx.Tensor< 97 | f32 98 | 0.0 99 | > 100 | """ 101 | defn min_max_scale(tensor, opts \\ []) do 102 | Scholar.Preprocessing.MinMaxScaler.fit_transform(tensor, opts) 103 | end 104 | 105 | @doc """ 106 | Converts a tensor into binary values based on the given threshold. 
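Values strictly greater than `:threshold` map to 1; values less than or equal to it map to 0. The result is cast to the type given by the `:type` option.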
107 | 108 | ## Options 109 | 110 | #{NimbleOptions.docs(@binarize_schema)} 111 | 112 | ## Examples 113 | 114 | iex> Scholar.Preprocessing.binarize(Nx.tensor([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]])) 115 | #Nx.Tensor< 116 | f32[3][3] 117 | [ 118 | [1.0, 0.0, 1.0], 119 | [1.0, 0.0, 0.0], 120 | [0.0, 1.0, 0.0] 121 | ] 122 | > 123 | 124 | iex> Scholar.Preprocessing.binarize(Nx.tensor([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), threshold: 1.3, type: {:u, 8}) 125 | #Nx.Tensor< 126 | u8[3][3] 127 | [ 128 | [0, 0, 1], 129 | [1, 0, 0], 130 | [0, 0, 0] 131 | ] 132 | > 133 | """ 134 | deftransform binarize(tensor, opts \\ []) do 135 | binarize_n(tensor, NimbleOptions.validate!(opts, @binarize_schema)) 136 | end 137 | 138 | defnp binarize_n(tensor, opts) do 139 | (tensor > opts[:threshold]) |> Nx.as_type(opts[:type]) 140 | end 141 | 142 | @doc """ 143 | It is a shortcut for `Scholar.Preprocessing.OrdinalEncoder.fit_transform/1`. 144 | 145 | See `Scholar.Preprocessing.OrdinalEncoder` for more information. 146 | 147 | ## Examples 148 | 149 | iex> Scholar.Preprocessing.ordinal_encode(Nx.tensor([3, 2, 4, 56, 2, 4, 2])) 150 | #Nx.Tensor< 151 | u64[7] 152 | [1, 0, 2, 3, 0, 2, 0] 153 | > 154 | """ 155 | defn ordinal_encode(tensor) do 156 | Scholar.Preprocessing.OrdinalEncoder.fit_transform(tensor) 157 | end 158 | 159 | @doc """ 160 | It is a shortcut for `Scholar.Preprocessing.OneHotEncoder.fit_transform/2`. 161 | See `Scholar.Preprocessing.OneHotEncoder` for more information. 162 | 163 | ## Examples 164 | 165 | iex> Scholar.Preprocessing.one_hot_encode(Nx.tensor([2, 0, 3, 2, 1, 1, 0]), num_categories: 4) 166 | #Nx.Tensor< 167 | u8[7][4] 168 | [ 169 | [0, 0, 1, 0], 170 | [1, 0, 0, 0], 171 | [0, 0, 0, 1], 172 | [0, 0, 1, 0], 173 | [0, 1, 0, 0], 174 | [0, 1, 0, 0], 175 | [1, 0, 0, 0] 176 | ] 177 | > 178 | """ 179 | defn one_hot_encode(tensor, opts) do 180 | Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, opts) 181 | end 182 | 183 | @doc """ 184 | Normalize samples individually to unit norm. 185 | 186 | The zero-tensors cannot be normalized and they stay the same 187 | after normalization. 188 | 189 | It is a shortcut for `Scholar.Preprocessing.Normalizer.fit_transform/2`. 190 | See `Scholar.Preprocessing.Normalizer` for more information. 191 | 192 | ## Examples 193 | 194 | iex> Scholar.Preprocessing.normalize(Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]]), axes: [1]) 195 | #Nx.Tensor< 196 | f32[3][3] 197 | [ 198 | [0.0, 0.0, 0.0], 199 | [0.4242640733718872, 0.5656854510307312, 0.7071067690849304], 200 | [-0.3713906705379486, 0.7427813410758972, 0.5570860505104065] 201 | ] 202 | > 203 | 204 | iex> Scholar.Preprocessing.normalize(Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]])) 205 | #Nx.Tensor< 206 | f32[3][3] 207 | [ 208 | [0.0, 0.0, 0.0], 209 | [0.3375263810157776, 0.4500351846218109, 0.5625439882278442], 210 | [-0.22501759231090546, 0.4500351846218109, 0.3375263810157776] 211 | ] 212 | > 213 | """ 214 | deftransform normalize(tensor, opts \\ []) do 215 | Scholar.Preprocessing.Normalizer.fit_transform(tensor, opts) 216 | end 217 | end 218 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/binarizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.Binarizer do 2 | @moduledoc """ 3 | Binarize data according to a threshold. 
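Values greater than the threshold map to 1, values less than or equal to it map to 0, and the result is returned as a `u8` tensor.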
4 | """ 5 | import Nx.Defn 6 | 7 | binarize_schema = [ 8 | threshold: [ 9 | type: :float, 10 | default: 0.0, 11 | doc: """ 12 | Feature values below or equal to this are replaced by 0, above it by 1. 13 | Threshold may not be less than 0 for operations on sparse matrices. 14 | """ 15 | ] 16 | ] 17 | 18 | @binarize_schema NimbleOptions.new!(binarize_schema) 19 | 20 | @doc """ 21 | Values greater than the threshold map to 1, while values less than 22 | or equal to the threshold map to 0. With the default threshold of 0, 23 | only positive values map to 1. 24 | ## Options 25 | #{NimbleOptions.docs(@binarize_schema)} 26 | ## Examples 27 | iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]]) 28 | iex> Scholar.Preprocessing.Binarizer.fit_transform(t, threshold: 3.0) 29 | #Nx.Tensor< 30 | u8[3][3] 31 | [ 32 | [0, 0, 0], 33 | [0, 1, 1], 34 | [0, 1, 0] 35 | ] 36 | > 37 | iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]]) 38 | iex> Scholar.Preprocessing.Binarizer.fit_transform(t,threshold: 0.4) 39 | #Nx.Tensor< 40 | u8[3][3] 41 | [ 42 | [0, 0, 0], 43 | [1, 1, 1], 44 | [0, 1, 1] 45 | ] 46 | > 47 | """ 48 | deftransform fit_transform(tensor, opts \\ []) do 49 | binarize_n(tensor, NimbleOptions.validate!(opts, @binarize_schema)) 50 | end 51 | 52 | defnp binarize_n(tensor, opts) do 53 | threshold = opts[:threshold] 54 | tensor > threshold 55 | end 56 | end 57 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/max_abs_scaler.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.MaxAbsScaler do 2 | @moduledoc """ 3 | Scales a tensor by dividing each sample in batch by the maximum absolute value in the batch. 4 | 5 | Centering and scaling happen independently on each feature by computing the relevant 6 | statistics on the samples in the training set. The maximum absolute value is then 7 | stored to be used on new samples. 8 | """ 9 | 10 | import Nx.Defn 11 | 12 | @derive {Nx.Container, containers: [:max_abs]} 13 | defstruct [:max_abs] 14 | 15 | opts_schema = [ 16 | axes: [ 17 | type: {:custom, Scholar.Options, :axes, []}, 18 | doc: """ 19 | Axes to calculate the max absolute value over. By default the absolute values 20 | are calculated between the whole tensors. 21 | """ 22 | ] 23 | ] 24 | 25 | @opts_schema NimbleOptions.new!(opts_schema) 26 | 27 | @doc """ 28 | Compute the maximum absolute value of samples to be used for later scaling. 29 | 30 | ## Options 31 | 32 | #{NimbleOptions.docs(@opts_schema)} 33 | 34 | ## Return values 35 | 36 | Returns a struct with the following parameters: 37 | 38 | * `max_abs`: the calculated maximum absolute value of samples. 39 | 40 | ## Examples 41 | 42 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 43 | iex> Scholar.Preprocessing.MaxAbsScaler.fit(t) 44 | %Scholar.Preprocessing.MaxAbsScaler{ 45 | max_abs: Nx.tensor( 46 | [ 47 | [2] 48 | ] 49 | ) 50 | } 51 | """ 52 | deftransform fit(tensor, opts \\ []) do 53 | fit_n(tensor, NimbleOptions.validate!(opts, @opts_schema)) 54 | end 55 | 56 | defnp fit_n(tensor, opts) do 57 | max_abs = 58 | Nx.abs(tensor) 59 | |> Nx.reduce_max(axes: opts[:axes], keep_axes: true) 60 | 61 | max_abs = Nx.select(max_abs == 0, 1, max_abs) 62 | 63 | %__MODULE__{max_abs: max_abs} 64 | end 65 | 66 | @doc """ 67 | Performs the standardization of the tensor using a fitted scaler. 
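Conceptually the transformation is a single division by the statistic stored during `fit/2` (a one-line sketch of the implementation):

    transformed = Nx.divide(tensor, scaler.max_abs)

Samples with a larger magnitude than anything seen during fitting will therefore fall outside `[-1, 1]`.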
68 | 69 | ## Examples 70 | 71 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 72 | iex> scaler = Scholar.Preprocessing.MaxAbsScaler.fit(t) 73 | iex> Scholar.Preprocessing.MaxAbsScaler.transform(scaler, t) 74 | #Nx.Tensor< 75 | f32[3][3] 76 | [ 77 | [0.5, -0.5, 1.0], 78 | [1.0, 0.0, 0.0], 79 | [0.0, 0.5, -0.5] 80 | ] 81 | > 82 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 83 | iex> scaler = Scholar.Preprocessing.MaxAbsScaler.fit(t) 84 | iex> new_tensor = Nx.tensor([[0.5, 1, -1], [0.3, 0.8, -1.6]]) 85 | iex> Scholar.Preprocessing.MaxAbsScaler.transform(scaler, new_tensor) 86 | #Nx.Tensor< 87 | f32[2][3] 88 | [ 89 | [0.25, 0.5, -0.5], 90 | [0.15000000596046448, 0.4000000059604645, -0.800000011920929] 91 | ] 92 | > 93 | """ 94 | defn transform(%__MODULE__{max_abs: max_abs}, tensor) do 95 | tensor / max_abs 96 | end 97 | 98 | @doc """ 99 | Standardizes the tensor by removing the mean and scaling to unit variance. 100 | 101 | ## Examples 102 | 103 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 104 | iex> Scholar.Preprocessing.MaxAbsScaler.fit_transform(t) 105 | #Nx.Tensor< 106 | f32[3][3] 107 | [ 108 | [0.5, -0.5, 1.0], 109 | [1.0, 0.0, 0.0], 110 | [0.0, 0.5, -0.5] 111 | ] 112 | > 113 | """ 114 | defn fit_transform(tensor, opts \\ []) do 115 | tensor 116 | |> fit(opts) 117 | |> transform(tensor) 118 | end 119 | end 120 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/min_max_scaler.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.MinMaxScaler do 2 | @moduledoc """ 3 | Scales a tensor by dividing each sample in batch by maximum absolute value in the batch 4 | 5 | Centering and scaling happen independently on each feature by computing the relevant 6 | statistics on the samples in the training set. Maximum absolute value then is 7 | stored to be used on new samples. 8 | """ 9 | 10 | import Nx.Defn 11 | 12 | @derive {Nx.Container, containers: [:min_data, :max_data, :min_bound, :max_bound]} 13 | defstruct [:min_data, :max_data, :min_bound, :max_bound] 14 | 15 | opts_schema = [ 16 | axes: [ 17 | type: {:custom, Scholar.Options, :axes, []}, 18 | doc: """ 19 | Axes to calculate the max absolute value over. By default the absolute values 20 | are calculated between the whole tensors. 21 | """ 22 | ], 23 | min_bound: [ 24 | type: {:or, [:integer, :float]}, 25 | default: 0, 26 | doc: """ 27 | The lower boundary of the desired range of transformed data. 28 | """ 29 | ], 30 | max_bound: [ 31 | type: {:or, [:integer, :float]}, 32 | default: 1, 33 | doc: """ 34 | The upper boundary of the desired range of transformed data. 35 | """ 36 | ] 37 | ] 38 | 39 | @opts_schema NimbleOptions.new!(opts_schema) 40 | 41 | @doc """ 42 | Compute the maximum absolute value of samples to be used for later scaling. 43 | 44 | ## Options 45 | 46 | #{NimbleOptions.docs(@opts_schema)} 47 | 48 | ## Return values 49 | 50 | Returns a struct with the following parameters: 51 | 52 | * `min_data`: the calculated minimum value of samples. 53 | 54 | * `max_data`: the calculated maximum value of samples. 55 | 56 | * `min_bound`: The lower boundary of the desired range of transformed data. 57 | 58 | * `max_bound`: The upper boundary of the desired range of transformed data. 
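`transform/2` rescales data using these statistics. A sketch of the formula, mirroring the implementation below:

    x_std = (x - min_data) / (max_data - min_data)
    x_scaled = x_std * (max_bound - min_bound) + min_bound

Features with `max_data == min_data` use a denominator of 1 to avoid division by zero.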
59 | 60 | ## Examples 61 | 62 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 63 | iex> Scholar.Preprocessing.MinMaxScaler.fit(t) 64 | %Scholar.Preprocessing.MinMaxScaler{ 65 | min_data: Nx.tensor( 66 | [ 67 | [-1] 68 | ] 69 | ), 70 | max_data: Nx.tensor( 71 | [ 72 | [2] 73 | ] 74 | ), 75 | min_bound: Nx.tensor( 76 | 0 77 | ), 78 | max_bound: Nx.tensor( 79 | 1 80 | ) 81 | } 82 | """ 83 | deftransform fit(tensor, opts \\ []) do 84 | fit_n(tensor, NimbleOptions.validate!(opts, @opts_schema)) 85 | end 86 | 87 | defnp fit_n(tensor, opts) do 88 | if opts[:max_bound] <= opts[:min_bound] do 89 | raise ArgumentError, 90 | "expected :max to be greater than :min" 91 | else 92 | reduced_max = Nx.reduce_max(tensor, axes: opts[:axes], keep_axes: true) 93 | reduced_min = Nx.reduce_min(tensor, axes: opts[:axes], keep_axes: true) 94 | 95 | %__MODULE__{ 96 | min_data: reduced_min, 97 | max_data: reduced_max, 98 | min_bound: opts[:min_bound], 99 | max_bound: opts[:max_bound] 100 | } 101 | end 102 | end 103 | 104 | @doc """ 105 | Performs the standardization of the tensor using a fitted scaler. 106 | 107 | ## Examples 108 | 109 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 110 | iex> scaler = Scholar.Preprocessing.MinMaxScaler.fit(t) 111 | iex> Scholar.Preprocessing.MinMaxScaler.transform(scaler, t) 112 | #Nx.Tensor< 113 | f32[3][3] 114 | [ 115 | [0.6666666865348816, 0.0, 1.0], 116 | [1.0, 0.3333333432674408, 0.3333333432674408], 117 | [0.3333333432674408, 0.6666666865348816, 0.0] 118 | ] 119 | > 120 | 121 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 122 | iex> scaler = Scholar.Preprocessing.MinMaxScaler.fit(t) 123 | iex> new_tensor = Nx.tensor([[0.5, 1, -1], [0.3, 0.8, -1.6]]) 124 | iex> Scholar.Preprocessing.MinMaxScaler.transform(scaler, new_tensor) 125 | #Nx.Tensor< 126 | f32[2][3] 127 | [ 128 | [0.5, 0.6666666865348816, 0.0], 129 | [0.43333330750465393, 0.5999999642372131, -0.20000000298023224] 130 | ] 131 | > 132 | """ 133 | defn transform( 134 | %__MODULE__{ 135 | min_data: min_data, 136 | max_data: max_data, 137 | min_bound: min_bound, 138 | max_bound: max_bound 139 | }, 140 | tensor 141 | ) do 142 | denominator = max_data - min_data 143 | denominator = Nx.select(denominator == 0, 1, denominator) 144 | x_std = (tensor - min_data) / denominator 145 | x_std * (max_bound - min_bound) + min_bound 146 | end 147 | 148 | @doc """ 149 | Standardizes the tensor by removing the mean and scaling to unit variance. 150 | 151 | ## Examples 152 | 153 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 154 | iex> Scholar.Preprocessing.MinMaxScaler.fit_transform(t) 155 | #Nx.Tensor< 156 | f32[3][3] 157 | [ 158 | [0.6666666865348816, 0.0, 1.0], 159 | [1.0, 0.3333333432674408, 0.3333333432674408], 160 | [0.3333333432674408, 0.6666666865348816, 0.0] 161 | ] 162 | > 163 | """ 164 | defn fit_transform(tensor, opts \\ []) do 165 | tensor 166 | |> fit(opts) 167 | |> transform(tensor) 168 | end 169 | end 170 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/normalizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.Normalizer do 2 | @moduledoc """ 3 | Implements functionality for rescaling tensor to unit norm. It enables to apply normalization along any combination of axes. 
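For instance, with the default `:euclidean` norm each non-zero sample is divided by its L2 norm, conceptually:

    # per sample x (Euclidean case); zero samples are left unchanged
    x / Nx.sqrt(Nx.sum(Nx.pow(x, 2)))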
4 | """ 5 | import Nx.Defn 6 | import Scholar.Shared 7 | 8 | normalize_schema = [ 9 | axes: [ 10 | type: {:custom, Scholar.Options, :axes, []}, 11 | doc: """ 12 | Axes to calculate the distance over. By default the distance 13 | is calculated between the whole tensors. 14 | """ 15 | ], 16 | norm: [ 17 | type: {:in, [:euclidean, :chebyshev, :manhattan]}, 18 | default: :euclidean, 19 | doc: """ 20 | The norm to use to normalize each non zero sample. 21 | Possible options are `:euclidean`, `:manhattan`, and `:chebyshev` 22 | """ 23 | ] 24 | ] 25 | 26 | @normalize_schema NimbleOptions.new!(normalize_schema) 27 | 28 | @doc """ 29 | Normalize samples individually to unit norm. 30 | 31 | The zero-tensors cannot be normalized and they stay the same 32 | after normalization. 33 | 34 | ## Options 35 | 36 | #{NimbleOptions.docs(@normalize_schema)} 37 | 38 | ## Examples 39 | 40 | iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]]) 41 | iex> Scholar.Preprocessing.Normalizer.fit_transform(t, axes: [1]) 42 | #Nx.Tensor< 43 | f32[3][3] 44 | [ 45 | [0.0, 0.0, 0.0], 46 | [0.4242640733718872, 0.5656854510307312, 0.7071067690849304], 47 | [-0.3713906705379486, 0.7427813410758972, 0.5570860505104065] 48 | ] 49 | > 50 | 51 | iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]]) 52 | iex> Scholar.Preprocessing.Normalizer.fit_transform(t) 53 | #Nx.Tensor< 54 | f32[3][3] 55 | [ 56 | [0.0, 0.0, 0.0], 57 | [0.3375263810157776, 0.4500351846218109, 0.5625439882278442], 58 | [-0.22501759231090546, 0.4500351846218109, 0.3375263810157776] 59 | ] 60 | > 61 | """ 62 | deftransform fit_transform(tensor, opts \\ []) do 63 | normalize_n(tensor, NimbleOptions.validate!(opts, @normalize_schema)) 64 | end 65 | 66 | defnp normalize_n(tensor, opts) do 67 | shape = Nx.shape(tensor) 68 | type = to_float_type(tensor) 69 | zeros = Nx.broadcast(Nx.tensor(0.0, type: type), shape) 70 | 71 | norm = 72 | case opts[:norm] do 73 | :euclidean -> 74 | Scholar.Metrics.Distance.euclidean(tensor, zeros, axes: opts[:axes]) 75 | 76 | :manhattan -> 77 | Scholar.Metrics.Distance.manhattan(tensor, zeros, axes: opts[:axes]) 78 | 79 | :chebyshev -> 80 | Scholar.Metrics.Distance.chebyshev(tensor, zeros, axes: opts[:axes]) 81 | 82 | other -> 83 | raise ArgumentError, 84 | "expected :norm to be one of: :euclidean, :manhattan, and :chebyshev, got: #{inspect(other)}" 85 | end 86 | 87 | shape_to_broadcast = unsqueezed_reduced_shape(shape, opts[:axes]) 88 | 89 | norm = 90 | Nx.select(norm == 0.0, Nx.tensor(1.0, type: type), norm) |> Nx.reshape(shape_to_broadcast) 91 | 92 | tensor / norm 93 | end 94 | 95 | deftransformp unsqueezed_reduced_shape(shape, axes) do 96 | if axes != nil do 97 | Enum.reduce(axes, shape, &put_elem(&2, &1, 1)) 98 | else 99 | Tuple.duplicate(1, Nx.rank(shape)) 100 | end 101 | end 102 | end 103 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/one_hot_encoder.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.OneHotEncoder do 2 | @moduledoc """ 3 | Implements encoder that converts integer value (substitute of categorical data in tensors) into 0-1 vector. 4 | The index of 1 in the vector is aranged in sorted manner. This means that for x < y => one_index(x) < one_index(y). 5 | 6 | Currently the module supports only 1D tensors. 
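Use `fit/2` (with the required `:num_categories` option) followed by `transform/2`, or `fit_transform/2` to do both in one step.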
7 | """ 8 | import Nx.Defn 9 | 10 | @derive {Nx.Container, containers: [:ordinal_encoder]} 11 | defstruct [:ordinal_encoder] 12 | 13 | encode_schema = [ 14 | num_categories: [ 15 | required: true, 16 | type: :pos_integer, 17 | doc: """ 18 | The number of categories to be encoded. 19 | """ 20 | ] 21 | ] 22 | 23 | @encode_schema NimbleOptions.new!(encode_schema) 24 | 25 | @doc """ 26 | Creates mapping from values into one-hot vectors. 27 | 28 | ## Options 29 | 30 | #{NimbleOptions.docs(@encode_schema)} 31 | 32 | ## Examples 33 | 34 | iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) 35 | iex> Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) 36 | %Scholar.Preprocessing.OneHotEncoder{ 37 | ordinal_encoder: %Scholar.Preprocessing.OrdinalEncoder{ 38 | categories: Nx.tensor([2, 3, 4, 56] 39 | ) 40 | } 41 | } 42 | """ 43 | deftransform fit(tensor, opts) do 44 | if Nx.rank(tensor) != 1 do 45 | raise ArgumentError, 46 | """ 47 | expected input tensor to have shape {num_samples}, \ 48 | got tensor with shape: #{inspect(Nx.shape(tensor))} 49 | """ 50 | end 51 | 52 | opts = NimbleOptions.validate!(opts, @encode_schema) 53 | 54 | fit_n(tensor, opts) 55 | end 56 | 57 | defnp fit_n(tensor, opts) do 58 | ordinal_encoder = Scholar.Preprocessing.OrdinalEncoder.fit(tensor, opts) 59 | %__MODULE__{ordinal_encoder: ordinal_encoder} 60 | end 61 | 62 | @doc """ 63 | Encode labels as a one-hot numeric tensor. All values provided to `transform/2` must be seen 64 | in `fit/2` function, otherwise an error occurs. 65 | 66 | ## Examples 67 | 68 | iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) 69 | iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) 70 | iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, tensor) 71 | #Nx.Tensor< 72 | u8[7][4] 73 | [ 74 | [0, 1, 0, 0], 75 | [1, 0, 0, 0], 76 | [0, 0, 1, 0], 77 | [0, 0, 0, 1], 78 | [1, 0, 0, 0], 79 | [0, 0, 1, 0], 80 | [1, 0, 0, 0] 81 | ] 82 | > 83 | 84 | iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) 85 | iex> encoder = Scholar.Preprocessing.OneHotEncoder.fit(tensor, num_categories: 4) 86 | iex> new_tensor = Nx.tensor([2, 3, 4, 3, 4, 56, 2]) 87 | iex> Scholar.Preprocessing.OneHotEncoder.transform(encoder, new_tensor) 88 | #Nx.Tensor< 89 | u8[7][4] 90 | [ 91 | [1, 0, 0, 0], 92 | [0, 1, 0, 0], 93 | [0, 0, 1, 0], 94 | [0, 1, 0, 0], 95 | [0, 0, 1, 0], 96 | [0, 0, 0, 1], 97 | [1, 0, 0, 0] 98 | ] 99 | > 100 | """ 101 | defn transform(%__MODULE__{ordinal_encoder: ordinal_encoder}, tensor) do 102 | num_categories = Nx.size(ordinal_encoder.categories) 103 | num_samples = Nx.size(tensor) 104 | 105 | encoded = 106 | ordinal_encoder 107 | |> Scholar.Preprocessing.OrdinalEncoder.transform(tensor) 108 | |> Nx.new_axis(1) 109 | |> Nx.broadcast({num_samples, num_categories}) 110 | 111 | encoded == Nx.iota({num_samples, num_categories}, axis: 1) 112 | end 113 | 114 | @doc """ 115 | Appl 116 | encoding on the provided tensor directly. It's equivalent to `fit/2` and then `transform/2` on the same data. 
117 | 118 | ## Examples 119 | 120 | iex> tensor = Nx.tensor([3, 2, 4, 56, 2, 4, 2]) 121 | iex> Scholar.Preprocessing.OneHotEncoder.fit_transform(tensor, num_categories: 4) 122 | #Nx.Tensor< 123 | u8[7][4] 124 | [ 125 | [0, 1, 0, 0], 126 | [1, 0, 0, 0], 127 | [0, 0, 1, 0], 128 | [0, 0, 0, 1], 129 | [1, 0, 0, 0], 130 | [0, 0, 1, 0], 131 | [1, 0, 0, 0] 132 | ] 133 | > 134 | """ 135 | deftransform fit_transform(tensor, opts) do 136 | if Nx.rank(tensor) != 1 do 137 | raise ArgumentError, 138 | """ 139 | expected input tensor to have shape {num_samples}, \ 140 | got tensor with shape: #{inspect(Nx.shape(tensor))} 141 | """ 142 | end 143 | 144 | opts = NimbleOptions.validate!(opts, @encode_schema) 145 | fit_transform_n(tensor, opts) 146 | end 147 | 148 | defnp fit_transform_n(tensor, opts) do 149 | num_samples = Nx.size(tensor) 150 | num_categories = opts[:num_categories] 151 | 152 | encoded = 153 | tensor 154 | |> Scholar.Preprocessing.OrdinalEncoder.fit_transform() 155 | |> Nx.new_axis(1) 156 | |> Nx.broadcast({num_samples, num_categories}) 157 | 158 | encoded == Nx.iota({num_samples, num_categories}, axis: 1) 159 | end 160 | end 161 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/robust_scaler.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.RobustScaler do 2 | @moduledoc ~S""" 3 | Scale features using statistics that are robust to outliers. 4 | 5 | This Scaler removes the median and scales the data according to 6 | the quantile range (defaults to IQR: Interquartile Range). 7 | The IQR is the range between the 1st quartile (25th quantile) 8 | and the 3rd quartile (75th quantile). 9 | """ 10 | 11 | import Nx.Defn 12 | 13 | @derive {Nx.Container, containers: [:medians, :iqr]} 14 | defstruct [:medians, :iqr] 15 | 16 | opts_schema = [ 17 | quantile_range: [ 18 | type: {:custom, Scholar.Options, :quantile_range, []}, 19 | default: {25.0, 75.0}, 20 | doc: """ 21 | Quantile range as a tuple {q_min, q_max} defining the range of quantiles 22 | to include. Must satisfy 0.0 < q_min < q_max < 100.0. 23 | """ 24 | ] 25 | ] 26 | 27 | @opts_schema NimbleOptions.new!(opts_schema) 28 | 29 | @doc """ 30 | Compute the median and quantiles to be used for scaling. 31 | 32 | ## Options 33 | 34 | #{NimbleOptions.docs(@opts_schema)} 35 | 36 | ## Return values 37 | 38 | Returns a struct with the following parameters: 39 | 40 | * `:iqr` - the calculated interquartile range. 41 | 42 | * `:medians` - the calculated medians of each feature across samples. 43 | 44 | ## Examples 45 | 46 | iex> Scholar.Preprocessing.RobustScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])) 47 | %Scholar.Preprocessing.RobustScaler{ 48 | medians: Nx.tensor([1, 0, 0]), 49 | iqr: Nx.tensor([1.0, 1.0, 1.5]) 50 | } 51 | """ 52 | deftransform fit(tensor, opts \\ []) do 53 | fit_n(tensor, NimbleOptions.validate!(opts, @opts_schema)) 54 | end 55 | 56 | defnp fit_n(tensor, opts) do 57 | check_for_rank(tensor) 58 | 59 | {q_min, q_max} = opts[:quantile_range] 60 | 61 | medians = Nx.median(tensor, axis: 0) 62 | 63 | sorted_tensor = Nx.sort(tensor, axis: 0) 64 | 65 | q_min = percentile(sorted_tensor, q_min) 66 | q_max = percentile(sorted_tensor, q_max) 67 | 68 | iqr = q_max - q_min 69 | 70 | %__MODULE__{medians: medians, iqr: iqr} 71 | end 72 | 73 | @doc """ 74 | Performs centering and scaling of the tensor using a fitted scaler. 
75 | 76 | ## Examples 77 | 78 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 79 | iex> scaler = Scholar.Preprocessing.RobustScaler.fit(t) 80 | %Scholar.Preprocessing.RobustScaler{ 81 | medians: Nx.tensor([1, 0, 0]), 82 | iqr: Nx.tensor([1.0, 1.0, 1.5]) 83 | } 84 | iex> Scholar.Preprocessing.RobustScaler.transform(scaler, t) 85 | #Nx.Tensor< 86 | f32[3][3] 87 | [ 88 | [0.0, -1.0, 1.3333333730697632], 89 | [1.0, 0.0, 0.0], 90 | [-1.0, 1.0, -0.6666666865348816] 91 | ] 92 | > 93 | """ 94 | defn transform(%__MODULE__{medians: medians, iqr: iqr}, tensor) do 95 | check_for_rank(tensor) 96 | scale(tensor, medians, iqr) 97 | end 98 | 99 | @doc """ 100 | Computes the scaling parameters and applies them to transform the tensor. 101 | 102 | ## Examples 103 | 104 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 105 | iex> Scholar.Preprocessing.RobustScaler.fit_transform(t) 106 | #Nx.Tensor< 107 | f32[3][3] 108 | [ 109 | [0.0, -1.0, 1.3333333730697632], 110 | [1.0, 0.0, 0.0], 111 | [-1.0, 1.0, -0.6666666865348816] 112 | ] 113 | > 114 | """ 115 | defn fit_transform(tensor, opts \\ []) do 116 | tensor 117 | |> fit(opts) 118 | |> transform(tensor) 119 | end 120 | 121 | defnp scale(tensor, medians, iqr) do 122 | (tensor - medians) / Nx.select(iqr == 0, 1.0, iqr) 123 | end 124 | 125 | defnp percentile(sorted_tensor, p) do 126 | num_rows = Nx.axis_size(sorted_tensor, 0) 127 | idx = p / 100 * (num_rows - 1) 128 | 129 | lower_idx = Nx.floor(idx) |> Nx.as_type(:s64) 130 | upper_idx = Nx.ceil(idx) |> Nx.as_type(:s64) 131 | 132 | lower_values = Nx.take(sorted_tensor, lower_idx, axis: 0) 133 | upper_values = Nx.take(sorted_tensor, upper_idx, axis: 0) 134 | 135 | weight_upper = idx - Nx.floor(idx) 136 | weight_lower = 1.0 - weight_upper 137 | lower_values * weight_lower + upper_values * weight_upper 138 | end 139 | 140 | defnp check_for_rank(tensor) do 141 | if Nx.rank(tensor) != 2 do 142 | raise ArgumentError, 143 | """ 144 | expected tensor to have shape {num_samples, num_features}, \ 145 | got tensor with shape: #{inspect(Nx.shape(tensor))}\ 146 | """ 147 | end 148 | end 149 | end 150 | -------------------------------------------------------------------------------- /lib/scholar/preprocessing/standard_scaler.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.StandardScaler do 2 | @moduledoc ~S""" 3 | Standardizes the tensor by removing the mean and scaling to unit variance. 4 | 5 | Formula for input tensor $x$: 6 | 7 | $$ 8 | z = \frac{x - \mu}{\sigma} 9 | $$ 10 | 11 | Where $\mu$ is the mean of the samples, and $\sigma$ is the standard deviation. 12 | Standardization can be helpful in cases where the data follows 13 | a Gaussian distribution (or Normal distribution) without outliers. 14 | 15 | Centering and scaling happen independently on each feature by computing the relevant 16 | statistics on the samples in the training set. Mean and standard deviation are then 17 | stored to be used on new samples. 18 | """ 19 | 20 | import Nx.Defn 21 | 22 | @derive {Nx.Container, containers: [:standard_deviation, :mean]} 23 | defstruct [:standard_deviation, :mean] 24 | 25 | opts_schema = [ 26 | axes: [ 27 | type: {:custom, Scholar.Options, :axes, []}, 28 | doc: """ 29 | Axes to calculate the distance over. By default the distance 30 | is calculated between the whole tensors. 
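In this scaler these are the axes over which the mean and standard deviation are computed.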
31 | """ 32 | ] 33 | ] 34 | 35 | @opts_schema NimbleOptions.new!(opts_schema) 36 | 37 | @doc """ 38 | Compute the standard deviation and mean of samples to be used for later scaling. 39 | 40 | ## Options 41 | 42 | #{NimbleOptions.docs(@opts_schema)} 43 | 44 | ## Return values 45 | 46 | Returns a struct with the following parameters: 47 | 48 | * `standard_deviation`: the calculated standard deviation of samples. 49 | 50 | * `mean`: the calculated mean of samples. 51 | 52 | ## Examples 53 | 54 | iex> Scholar.Preprocessing.StandardScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])) 55 | %Scholar.Preprocessing.StandardScaler{ 56 | standard_deviation: Nx.tensor( 57 | [ 58 | [1.0657403469085693] 59 | ] 60 | ), 61 | mean: Nx.tensor( 62 | [ 63 | [0.4444444477558136] 64 | ] 65 | ) 66 | } 67 | """ 68 | deftransform fit(tensor, opts \\ []) do 69 | NimbleOptions.validate!(opts, @opts_schema) 70 | fit_n(tensor, opts) 71 | end 72 | 73 | defnp fit_n(tensor, opts) do 74 | std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true) 75 | mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true) 76 | mean_reduced = Nx.select(std == 0, 0.0, mean_reduced) 77 | %__MODULE__{standard_deviation: std, mean: mean_reduced} 78 | end 79 | 80 | @doc """ 81 | Performs the standardization of the tensor using a fitted scaler. 82 | 83 | ## Examples 84 | 85 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 86 | iex> scaler = Scholar.Preprocessing.StandardScaler.fit(t) 87 | %Scholar.Preprocessing.StandardScaler{ 88 | standard_deviation: Nx.tensor( 89 | [ 90 | [1.0657403469085693] 91 | ] 92 | ), 93 | mean: Nx.tensor( 94 | [ 95 | [0.4444444477558136] 96 | ] 97 | ) 98 | } 99 | iex> Scholar.Preprocessing.StandardScaler.transform(scaler, t) 100 | #Nx.Tensor< 101 | f32[3][3] 102 | [ 103 | [0.5212860703468323, -1.3553436994552612, 1.4596009254455566], 104 | [1.4596009254455566, -0.4170288145542145, -0.4170288145542145], 105 | [-0.4170288145542145, 0.5212860703468323, -1.3553436994552612] 106 | ] 107 | > 108 | """ 109 | defn transform(%__MODULE__{standard_deviation: std, mean: mean}, tensor) do 110 | scale(tensor, std, mean) 111 | end 112 | 113 | @doc """ 114 | Standardizes the tensor by removing the mean and scaling to unit variance. 115 | 116 | ## Examples 117 | 118 | iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 119 | iex> Scholar.Preprocessing.StandardScaler.fit_transform(t) 120 | #Nx.Tensor< 121 | f32[3][3] 122 | [ 123 | [0.5212860703468323, -1.3553436994552612, 1.4596009254455566], 124 | [1.4596009254455566, -0.4170288145542145, -0.4170288145542145], 125 | [-0.4170288145542145, 0.5212860703468323, -1.3553436994552612] 126 | ] 127 | > 128 | """ 129 | defn fit_transform(tensor, opts \\ []) do 130 | tensor 131 | |> fit(opts) 132 | |> transform(tensor) 133 | end 134 | 135 | defnp scale(tensor, std, mean) do 136 | (tensor - mean) / Nx.select(std == 0, 1.0, std) 137 | end 138 | end 139 | -------------------------------------------------------------------------------- /lib/scholar/shared.ex: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Shared do 2 | @moduledoc false 3 | 4 | # Collection of private helper functions and 5 | # macros for enforcing shape/type constraints, 6 | # and doing shape calculations. 7 | 8 | import Nx.Defn 9 | require Nx 10 | 11 | @doc """ 12 | Asserts `left` has same shape as `right`. 
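Raises `ArgumentError` if the shapes differ.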
13 | """ 14 | deftransform assert_same_shape!(left, right) do 15 | left_shape = Nx.shape(left) 16 | right_shape = Nx.shape(right) 17 | 18 | unless left_shape == right_shape do 19 | raise ArgumentError, 20 | "expected tensor to have shape #{inspect(left_shape)}, got tensor with shape #{inspect(right_shape)}" 21 | end 22 | end 23 | 24 | @doc """ 25 | Asserts `tensor` has rank `target_rank`. 26 | """ 27 | deftransform assert_rank!(tensor, target_rank) do 28 | rank = Nx.rank(tensor) 29 | 30 | unless rank == target_rank do 31 | raise ArgumentError, 32 | "expected tensor to have rank #{target_rank}, got tensor with rank #{rank}" 33 | end 34 | end 35 | 36 | @doc """ 37 | Returns the floating type of `tensor`. 38 | """ 39 | deftransform to_float_type(tensor) do 40 | tensor |> Nx.type() |> Nx.Type.to_floating() 41 | end 42 | 43 | @doc """ 44 | Converts `tensor` to the floating type. 45 | """ 46 | defn to_float(tensor) do 47 | type = to_float_type(tensor) 48 | Nx.as_type(tensor, type) 49 | end 50 | 51 | deftransform validate_weights(weights, num_samples, opts \\ []) do 52 | type = opts[:type] 53 | 54 | cond do 55 | is_nil(weights) -> 56 | Nx.tensor(1.0, type: type) 57 | 58 | Nx.is_tensor(weights) and Nx.shape(weights) in [{}, {num_samples}] -> 59 | weights |> Nx.broadcast({num_samples}) |> Nx.as_type(type) 60 | 61 | is_list(weights) and length(weights) == num_samples -> 62 | Nx.tensor(weights, type: type) 63 | 64 | true -> 65 | raise ArgumentError, 66 | "invalid value for :weights option: expected list or tensor of positive numbers of size #{num_samples}, got: #{inspect(weights)}" 67 | end 68 | end 69 | 70 | deftransform valid_broadcast!(n_dims, shape1, shape2) do 71 | if tuple_size(shape1) != tuple_size(shape2) do 72 | raise ArgumentError, 73 | "expected shapes to have same rank, got #{inspect(tuple_size(shape1))} and #{inspect(tuple_size(shape2))}" 74 | end 75 | 76 | valid_broadcast(n_dims, n_dims, shape1, shape2) 77 | end 78 | 79 | deftransformp valid_broadcast(0, _n_dims, _shape1, _shape2), do: true 80 | 81 | deftransformp valid_broadcast(to_parse, n_dims, shape1, shape2) do 82 | dim1 = elem(shape1, n_dims - to_parse) 83 | dim2 = elem(shape2, n_dims - to_parse) 84 | 85 | if not (dim1 == 1 or dim2 == 1 or dim2 == dim1) do 86 | raise ArgumentError, 87 | "tensors must be broadcast compatible, got tensors with shapes #{inspect(shape1)} and #{inspect(shape2)}" 88 | end 89 | 90 | valid_broadcast(to_parse - 1, n_dims, shape1, shape2) 91 | end 92 | end 93 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.MixProject do 2 | use Mix.Project 3 | 4 | @source_url "https://github.com/elixir-nx/scholar" 5 | @version "0.4.0" 6 | 7 | def project do 8 | [ 9 | app: :scholar, 10 | name: "Scholar", 11 | version: @version, 12 | elixir: "~> 1.14", 13 | elixirc_paths: elixirc_paths(Mix.env()), 14 | deps: deps(), 15 | docs: &docs/0, 16 | package: package() 17 | ] 18 | end 19 | 20 | defp elixirc_paths(:test), do: ["lib", "test/support"] 21 | defp elixirc_paths(_), do: ["lib"] 22 | 23 | def application do 24 | [ 25 | extra_applications: [:logger] 26 | ] 27 | end 28 | 29 | defp deps do 30 | [ 31 | {:ex_doc, "~> 0.34", only: :docs}, 32 | {:nx, "~> 0.9"}, 33 | {:nimble_options, "~> 0.5.2 or ~> 1.0"}, 34 | {:exla, ">= 0.0.0", only: :test}, 35 | {:polaris, "~> 0.1"}, 36 | {:benchee, "~> 1.0", only: :dev}, 37 | {:scidata, "~> 0.1.11", only: :test} 38 | ] 39 | end 40 | 41 | defp package 
do 42 | [ 43 | maintainers: ["Mateusz Słuszniak", "Krsto Proroković"], 44 | description: "Traditional machine learning on top of Nx", 45 | licenses: ["Apache-2.0"], 46 | links: %{"GitHub" => @source_url} 47 | ] 48 | end 49 | 50 | defp docs do 51 | [ 52 | main: "readme", 53 | source_url: @source_url, 54 | assets: %{"notebooks/files" => "files"}, 55 | logo: "images/scholar_simplified.png", 56 | extra_section: "Guides", 57 | extras: [ 58 | "README.md", 59 | "notebooks/cv_gradient_boosting_tree.livemd", 60 | # "notebooks/hierarchical_clustering.livemd", 61 | "notebooks/k_means.livemd", 62 | "notebooks/k_nearest_neighbors.livemd", 63 | "notebooks/linear_regression.livemd", 64 | "notebooks/manifold_learning.livemd", 65 | "notebooks/mds.livemd", 66 | "notebooks/nearest_neighbors.livemd" 67 | ], 68 | groups_for_modules: [ 69 | Models: [ 70 | Scholar.Cluster.AffinityPropagation, 71 | Scholar.Cluster.DBSCAN, 72 | Scholar.Cluster.GaussianMixture, 73 | Scholar.Cluster.Hierarchical, 74 | Scholar.Cluster.KMeans, 75 | Scholar.Decomposition.PCA, 76 | Scholar.Integrate, 77 | Scholar.Interpolation.BezierSpline, 78 | Scholar.Interpolation.CubicSpline, 79 | Scholar.Interpolation.Linear, 80 | Scholar.Linear.BayesianRidgeRegression, 81 | Scholar.Linear.IsotonicRegression, 82 | Scholar.Linear.LinearRegression, 83 | Scholar.Linear.LogisticRegression, 84 | Scholar.Linear.PolynomialRegression, 85 | Scholar.Linear.RidgeRegression, 86 | Scholar.Linear.SVM, 87 | Scholar.Manifold.MDS, 88 | Scholar.Manifold.Trimap, 89 | Scholar.Manifold.TSNE, 90 | Scholar.NaiveBayes.Complement, 91 | Scholar.NaiveBayes.Gaussian, 92 | Scholar.NaiveBayes.Multinomial, 93 | Scholar.Neighbors.BruteKNN, 94 | Scholar.Neighbors.KDTree, 95 | Scholar.Neighbors.KNNClassifier, 96 | Scholar.Neighbors.KNNRegressor, 97 | Scholar.Neighbors.LargeVis, 98 | Scholar.Neighbors.NNDescent, 99 | Scholar.Neighbors.RadiusNNClassifier, 100 | Scholar.Neighbors.RadiusNNRegressor, 101 | Scholar.Neighbors.RandomProjectionForest 102 | ], 103 | Utilities: [ 104 | Scholar.Impute.SimpleImputer, 105 | Scholar.Metrics.Classification, 106 | Scholar.Metrics.Clustering, 107 | Scholar.Metrics.Distance, 108 | Scholar.Metrics.Neighbors, 109 | Scholar.Metrics.Ranking, 110 | Scholar.Metrics.Regression, 111 | Scholar.Metrics.Similarity, 112 | Scholar.ModelSelection, 113 | Scholar.Preprocessing, 114 | Scholar.Preprocessing.MaxAbsScaler, 115 | Scholar.Preprocessing.MinMaxScaler, 116 | Scholar.Preprocessing.Normalizer, 117 | Scholar.Preprocessing.OneHotEncoder, 118 | Scholar.Preprocessing.OrdinalEncoder, 119 | Scholar.Preprocessing.StandardScaler, 120 | Scholar.Stats 121 | ] 122 | ], 123 | before_closing_body_tag: &before_closing_body_tag/1 124 | ] 125 | end 126 | 127 | defp before_closing_body_tag(:html) do 128 | """ 129 | 130 | 131 | 132 | 142 | 143 | 144 | 145 | 161 | """ 162 | end 163 | 164 | defp before_closing_body_tag(_), do: "" 165 | end 166 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "benchee": {:hex, :benchee, "1.3.1", "c786e6a76321121a44229dde3988fc772bca73ea75170a73fd5f4ddf1af95ccf", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "76224c58ea1d0391c8309a8ecbfe27d71062878f59bd41a390266bf4ac1cc56d"}, 3 | "castore": {:hex, :castore, "0.1.22", 
"4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, 4 | "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, 5 | "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, 6 | "earmark_parser": {:hex, :earmark_parser, "1.4.42", "f23d856f41919f17cd06a493923a722d87a2d684f143a1e663c04a2b93100682", [:mix], [], "hexpm", "6915b6ca369b5f7346636a2f41c6a6d78b5af419d61a611079189233358b8b8b"}, 7 | "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, 8 | "ex_doc": {:hex, :ex_doc, "0.36.1", "4197d034f93e0b89ec79fac56e226107824adcce8d2dd0a26f5ed3a95efc36b1", [:mix], [{:earmark_parser, "~> 1.4.42", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "d7d26a7cf965dacadcd48f9fa7b5953d7d0cfa3b44fa7a65514427da44eafd89"}, 9 | "exla": {:hex, :exla, "0.9.2", "2b5cb7334f79fedc301502a793ffd10bc1ec8de2c61eebabcabf213fc98ae7e6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.8.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "e51085e196b466d235e93d9f5ea2cbf7d90315d216aa02e996f99bcaaa19c593"}, 10 | "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, 11 | "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, 12 | "makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"}, 13 | "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, 14 | "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", 
"d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, 15 | "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, 16 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, 17 | "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, 18 | "nx": {:hex, :nx, "0.9.2", "17563029c01bf749aad3c31234326d7665abd0acc33ee2acbe531a4759f29a8a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "914d74741617d8103de8ab1f8c880353e555263e1c397b8a1109f79a3716557f"}, 19 | "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, 20 | "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, 21 | "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, 22 | "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, 23 | "xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"}, 24 | } 25 | -------------------------------------------------------------------------------- /notebooks/files/knn_gyms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/notebooks/files/knn_gyms.png -------------------------------------------------------------------------------- /notebooks/files/mammoth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/notebooks/files/mammoth.png -------------------------------------------------------------------------------- /notebooks/files/rbf_and_kdtree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/notebooks/files/rbf_and_kdtree.png -------------------------------------------------------------------------------- 
/notebooks/files/rpKNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/scholar/48aa36e95dd14fa51e3ad5e23fc72c3a3eed3877/notebooks/files/rpKNN.png -------------------------------------------------------------------------------- /test/scholar/cluster/affinity_propagation_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Cluster.AffinityPropagationTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Cluster.AffinityPropagation 4 | doctest AffinityPropagation 5 | 6 | defp key do 7 | Nx.Random.key(42) 8 | end 9 | 10 | defp x do 11 | Nx.tensor([ 12 | [16, 2, 17], 13 | [17, 3, 9], 14 | [9, 16, 15], 15 | [13, 8, 8], 16 | [3, 5, 15], 17 | [19, 11, 9], 18 | [8, 15, 2], 19 | [16, 2, 2], 20 | [10, 10, 0], 21 | [8, 7, 5], 22 | [4, 11, 8], 23 | [11, 17, 7], 24 | [16, 4, 2], 25 | [13, 9, 7], 26 | [18, 16, 12], 27 | [6, 8, 6], 28 | [18, 13, 1], 29 | [2, 2, 2], 30 | [0, 1, 18], 31 | [12, 16, 18], 32 | [3, 14, 5], 33 | [2, 16, 13], 34 | [6, 6, 13], 35 | [16, 3, 5], 36 | [0, 16, 5], 37 | [4, 18, 5], 38 | [5, 8, 0], 39 | [1, 5, 15], 40 | [10, 0, 14], 41 | [13, 8, 14], 42 | [19, 2, 9], 43 | [17, 17, 0], 44 | [19, 14, 19], 45 | [9, 19, 10], 46 | [11, 4, 12], 47 | [3, 16, 19], 48 | [17, 3, 6], 49 | [9, 16, 10], 50 | [5, 17, 3], 51 | [3, 15, 17] 52 | ]) 53 | end 54 | 55 | defp x_test do 56 | Nx.tensor([ 57 | [12, 11, 15], 58 | [6, 3, 3], 59 | [8, 16, 16], 60 | [12, 2, 17], 61 | [11, 3, 17], 62 | [15, 1, 14], 63 | [0, 6, 7], 64 | [7, 9, 3], 65 | [13, 3, 16], 66 | [11, 2, 2] 67 | ]) 68 | end 69 | 70 | test "fit and compute_values" do 71 | model = AffinityPropagation.fit(x(), key: key(), preference: :median) 72 | 73 | model = AffinityPropagation.prune(model) 74 | 75 | assert model.labels == 76 | Nx.tensor( 77 | [5, 6, 0, 2, 1, 3, 7, 6, 2, 2, 2, 7, 6, 2, 3, 2, 4, 2, 1] ++ 78 | [0, 7, 0, 1, 6, 7, 7, 2, 1, 5, 5, 6, 4, 3, 0, 5, 0, 6, 0, 7, 0] 79 | ) 80 | 81 | assert model.cluster_centers == 82 | Nx.tensor([ 83 | [9.0, 16.0, 15.0], 84 | [3.0, 5.0, 15.0], 85 | [8.0, 7.0, 5.0], 86 | [18.0, 16.0, 12.0], 87 | [18.0, 13.0, 1.0], 88 | [11.0, 4.0, 12.0], 89 | [17.0, 3.0, 6.0], 90 | [5.0, 17.0, 3.0] 91 | ]) 92 | 93 | assert model.cluster_centers_indices == Nx.tensor([2, 4, 9, 14, 16, 34, 36, 38]) 94 | end 95 | 96 | test "predict with pruning" do 97 | model = AffinityPropagation.fit(x(), key: key(), preference: :median) 98 | model = AffinityPropagation.prune(model) 99 | preds = AffinityPropagation.predict(model, x_test()) 100 | assert preds == Nx.tensor([0, 2, 0, 5, 5, 5, 2, 2, 5, 2]) 101 | end 102 | 103 | test "predict without pruning" do 104 | model = AffinityPropagation.fit(x(), key: key(), preference: :median) 105 | preds = AffinityPropagation.predict(model, x_test()) 106 | assert preds == Nx.tensor([2, 9, 2, 34, 34, 34, 9, 9, 34, 9]) 107 | end 108 | end 109 | -------------------------------------------------------------------------------- /test/scholar/cluster/gmm_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Cluster.GaussianMixtureTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Cluster.GaussianMixture 4 | doctest GaussianMixture 5 | 6 | describe "invalid arguments" do 7 | test "when :num_gaussians is not provided" do 8 | x = Nx.tensor([[1, 2], [3, 4], [5, 6]]) 9 | 10 | assert_raise NimbleOptions.ValidationError, 11 | "required :num_gaussians option not found, received options: []", 12 | fn -> 13 | 
GaussianMixture.fit(x) 14 | end 15 | end 16 | 17 | test "when :num_gaussians is invalid" do 18 | x = Nx.tensor([[1, 2], [3, 4], [5, 6]]) 19 | 20 | assert_raise ArgumentError, 21 | "invalid value for :num_gaussians option: expected positive integer between 1 and 3, got: 4", 22 | fn -> 23 | GaussianMixture.fit(x, num_gaussians: 4) 24 | end 25 | 26 | assert_raise NimbleOptions.ValidationError, 27 | "invalid value for :num_gaussians option: expected positive integer, got: 2.0", 28 | fn -> 29 | GaussianMixture.fit(x, num_gaussians: 2.0) 30 | end 31 | 32 | assert_raise NimbleOptions.ValidationError, 33 | "invalid value for :num_gaussians option: expected positive integer, got: -1", 34 | fn -> 35 | GaussianMixture.fit(x, num_gaussians: -1) 36 | end 37 | end 38 | 39 | test "when training vector size is invalid" do 40 | x = Nx.tensor([5, 6]) 41 | 42 | assert_raise ArgumentError, 43 | "expected input tensor to have shape {n_samples, n_features}, got tensor with shape: {2}", 44 | fn -> GaussianMixture.fit(x, num_gaussians: 2) end 45 | end 46 | 47 | test "when :num_runs is not a positive integer" do 48 | x = Nx.tensor([[1, 2], [3, 4]]) 49 | 50 | assert_raise NimbleOptions.ValidationError, 51 | "invalid value for :num_runs option: expected positive integer, got: 0", 52 | fn -> 53 | GaussianMixture.fit(x, num_gaussians: 2, num_runs: 0) 54 | end 55 | 56 | assert_raise NimbleOptions.ValidationError, 57 | "invalid value for :num_runs option: expected positive integer, got: 10.0", 58 | fn -> 59 | GaussianMixture.fit(x, num_gaussians: 2, num_runs: 10.0) 60 | end 61 | end 62 | 63 | test "when :max_iter is not a positive integer" do 64 | x = Nx.tensor([[1, 2], [3, 4]]) 65 | 66 | assert_raise NimbleOptions.ValidationError, 67 | "invalid value for :max_iter option: expected positive integer, got: 0", 68 | fn -> 69 | GaussianMixture.fit(x, num_gaussians: 2, max_iter: 0) 70 | end 71 | 72 | assert_raise NimbleOptions.ValidationError, 73 | "invalid value for :max_iter option: expected positive integer, got: 200.0", 74 | fn -> 75 | GaussianMixture.fit(x, num_gaussians: 2, max_iter: 200.0) 76 | end 77 | end 78 | 79 | test "when :tol is not a non-negative number" do 80 | x = Nx.tensor([[1, 2], [3, 4]]) 81 | 82 | assert_raise NimbleOptions.ValidationError, 83 | "invalid value for :tol option: expected a non-negative number, got: -0.1", 84 | fn -> 85 | GaussianMixture.fit(x, num_gaussians: 2, tol: -0.1) 86 | end 87 | end 88 | end 89 | end 90 | -------------------------------------------------------------------------------- /test/scholar/covariance/ledoit_wolf_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Covariance.LedoitWolfTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Covariance.LedoitWolf 4 | doctest LedoitWolf 5 | 6 | defp key do 7 | Nx.Random.key(1) 8 | end 9 | 10 | test "fit test - all default options" do 11 | key = key() 12 | 13 | {x, _new_key} = 14 | Nx.Random.multivariate_normal( 15 | key, 16 | Nx.tensor([0.0, 0.0, 0.0]), 17 | Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), 18 | shape: {10}, 19 | type: :f32 20 | ) 21 | 22 | model = LedoitWolf.fit(x) 23 | 24 | assert_all_close( 25 | model.covariance, 26 | Nx.tensor([ 27 | [1.439786434173584, -0.0, 0.0], 28 | [-0.0, 1.439786434173584, 0.0], 29 | [0.0, 0.0, 1.439786434173584] 30 | ]), 31 | atol: 1.0e-3 32 | ) 33 | 34 | assert_all_close(model.shrinkage, Nx.tensor(1.0), atol: 1.0e-3) 35 | 36 | assert_all_close( 37 | model.location, 38 | Nx.tensor([-1.015519142150879, 
-0.4495307505130768, 0.06475571542978287]), 39 | atol: 1.0e-3 40 | ) 41 | end 42 | 43 | test "fit test - :assume_centered? is true" do 44 | key = key() 45 | 46 | {x, _new_key} = 47 | Nx.Random.multivariate_normal( 48 | key, 49 | Nx.tensor([0.0, 0.0, 0.0]), 50 | Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), 51 | shape: {10}, 52 | type: :f32 53 | ) 54 | 55 | model = LedoitWolf.fit(x, assume_centered?: true) 56 | 57 | assert_all_close( 58 | model.covariance, 59 | Nx.tensor([ 60 | [1.852303147315979, 0.0, 0.0], 61 | [0.0, 1.852303147315979, 0.0], 62 | [0.0, 0.0, 1.852303147315979] 63 | ]), 64 | atol: 1.0e-3 65 | ) 66 | 67 | assert_all_close(model.shrinkage, Nx.tensor(1.0), atol: 1.0e-3) 68 | 69 | assert_all_close(model.location, Nx.tensor([0, 0, 0]), atol: 1.0e-3) 70 | end 71 | 72 | test "fit test 2" do 73 | key = key() 74 | 75 | {x, _new_key} = 76 | Nx.Random.multivariate_normal( 77 | key, 78 | Nx.tensor([0.0, 0.0]), 79 | Nx.tensor([[2.2, 1.5], [0.7, 1.1]]), 80 | shape: {50}, 81 | type: :f32 82 | ) 83 | 84 | model = LedoitWolf.fit(x) 85 | 86 | assert_all_close( 87 | model.covariance, 88 | Nx.tensor([ 89 | [1.8378269672393799, 0.27215731143951416], 90 | [0.27215731143951416, 1.2268550395965576] 91 | ]), 92 | atol: 1.0e-3 93 | ) 94 | 95 | assert_all_close(model.shrinkage, Nx.tensor(0.38731059432029724), atol: 1.0e-3) 96 | 97 | assert_all_close(model.location, Nx.tensor([0.06882287561893463, 0.13750331103801727]), 98 | atol: 1.0e-3 99 | ) 100 | end 101 | 102 | test "fit test - 1 dim x" do 103 | key = key() 104 | 105 | {x, _new_key} = 106 | Nx.Random.multivariate_normal(key, Nx.tensor([0.0]), Nx.tensor([[0.4]]), 107 | shape: {15}, 108 | type: :f32 109 | ) 110 | 111 | x = Nx.flatten(x) 112 | 113 | model = LedoitWolf.fit(x) 114 | 115 | assert_all_close( 116 | model.covariance, 117 | Nx.tensor([ 118 | [0.5322133302688599] 119 | ]), 120 | atol: 1.0e-3 121 | ) 122 | 123 | assert_all_close(model.shrinkage, Nx.tensor(0.0), atol: 1.0e-3) 124 | 125 | assert_all_close(model.location, Nx.tensor([0.060818854719400406]), atol: 1.0e-3) 126 | end 127 | end 128 | -------------------------------------------------------------------------------- /test/scholar/covariance/shrunk_covariance_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Covariance.ShrunkCovarianceTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Covariance.ShrunkCovariance 4 | doctest ShrunkCovariance 5 | 6 | defp key do 7 | Nx.Random.key(1) 8 | end 9 | 10 | test "fit test - all default options" do 11 | key = key() 12 | 13 | {x, _new_key} = 14 | Nx.Random.multivariate_normal( 15 | key, 16 | Nx.tensor([0.0, 0.0, 0.0]), 17 | Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), 18 | shape: {10}, 19 | type: :f32 20 | ) 21 | 22 | model = ShrunkCovariance.fit(x) 23 | 24 | assert_all_close( 25 | model.covariance, 26 | Nx.tensor([ 27 | [2.0949244499206543, -0.13400490581989288, 0.5413897037506104], 28 | [-0.13400490581989288, 1.2940725088119507, 0.0621684193611145], 29 | [0.5413897037506104, 0.0621684193611145, 0.9303621053695679] 30 | ]), 31 | atol: 1.0e-3 32 | ) 33 | 34 | assert_all_close( 35 | model.location, 36 | Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]), 37 | atol: 1.0e-3 38 | ) 39 | end 40 | 41 | test "fit test - :assume_centered? 
is true" do 42 | key = key() 43 | 44 | {x, _new_key} = 45 | Nx.Random.multivariate_normal( 46 | key, 47 | Nx.tensor([0.0, 0.0, 0.0]), 48 | Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), 49 | shape: {10}, 50 | type: :f32 51 | ) 52 | 53 | model = ShrunkCovariance.fit(x, assume_centered?: true) 54 | 55 | assert_all_close( 56 | model.covariance, 57 | Nx.tensor([ 58 | [3.0643274784088135, 0.27685147523880005, 0.4822050631046295], 59 | [0.27685147523880005, 1.5171942710876465, 0.03596973791718483], 60 | [0.4822050631046295, 0.03596973791718483, 0.975387692451477] 61 | ]), 62 | atol: 1.0e-3 63 | ) 64 | 65 | assert_all_close(model.location, Nx.tensor(0), atol: 1.0e-3) 66 | end 67 | 68 | test "fit test - :shrinkage" do 69 | key = key() 70 | 71 | {x, _new_key} = 72 | Nx.Random.multivariate_normal( 73 | key, 74 | Nx.tensor([0.0, 0.0, 0.0]), 75 | Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), 76 | shape: {10}, 77 | type: :f32 78 | ) 79 | 80 | model = ShrunkCovariance.fit(x, shrinkage: 0.8) 81 | 82 | assert_all_close( 83 | model.covariance, 84 | Nx.tensor([ 85 | [1.5853726863861084, -0.029778867959976196, 0.12030883133411407], 86 | [-0.029778867959976196, 1.4074056148529053, 0.013815204612910748], 87 | [0.12030883133411407, 0.013815204612910748, 1.3265810012817383] 88 | ]), 89 | atol: 1.0e-3 90 | ) 91 | 92 | assert_all_close( 93 | model.location, 94 | Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]), 95 | atol: 1.0e-3 96 | ) 97 | end 98 | 99 | test "fit test 2" do 100 | key = key() 101 | 102 | {x, _new_key} = 103 | Nx.Random.multivariate_normal( 104 | key, 105 | Nx.tensor([0.0, 0.0]), 106 | Nx.tensor([[2.2, 1.5], [0.7, 1.1]]), 107 | shape: {50}, 108 | type: :f32 109 | ) 110 | 111 | model = ShrunkCovariance.fit(x) 112 | 113 | assert_all_close( 114 | model.covariance, 115 | Nx.tensor([ 116 | [1.9810796976089478, 0.3997809886932373], 117 | [0.3997809886932373, 1.0836023092269897] 118 | ]), 119 | atol: 1.0e-3 120 | ) 121 | 122 | assert_all_close(model.location, Nx.tensor([0.06882287561893463, 0.13750331103801727]), 123 | atol: 1.0e-3 124 | ) 125 | end 126 | 127 | test "fit test - 1 dim x" do 128 | key = key() 129 | 130 | {x, _new_key} = 131 | Nx.Random.multivariate_normal(key, Nx.tensor([0.0]), Nx.tensor([[0.4]]), 132 | shape: {15}, 133 | type: :f32 134 | ) 135 | 136 | x = Nx.flatten(x) 137 | 138 | model = ShrunkCovariance.fit(x) 139 | 140 | assert_all_close( 141 | model.covariance, 142 | Nx.tensor([ 143 | [0.5322133302688599] 144 | ]), 145 | atol: 1.0e-3 146 | ) 147 | 148 | assert_all_close(model.location, Nx.tensor([0.060818854719400406]), atol: 1.0e-3) 149 | end 150 | end 151 | -------------------------------------------------------------------------------- /test/scholar/feature_extraction/count_vectorizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.CountVectorizer do 2 | use Scholar.Case, async: true 3 | alias Scholar.FeatureExtraction.CountVectorizer 4 | doctest CountVectorizer 5 | 6 | describe "fit_transform" do 7 | test "without padding" do 8 | tesnsor = Nx.tensor([[2, 3, 0], [1, 4, 4]]) 9 | 10 | counts = 11 | CountVectorizer.fit_transform(tesnsor, 12 | max_token_id: CountVectorizer.max_token_id(tesnsor) 13 | ) 14 | 15 | expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 2]]) 16 | 17 | assert counts == expected_counts 18 | end 19 | 20 | test "with padding" do 21 | tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]]) 22 | 23 | counts = 24 | 
CountVectorizer.fit_transform(tensor, max_token_id: CountVectorizer.max_token_id(tensor)) 25 | 26 | expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 1]]) 27 | 28 | assert counts == expected_counts 29 | end 30 | end 31 | 32 | describe "max_token_id" do 33 | test "without padding" do 34 | tensor = Nx.tensor([[2, 3, 0], [1, 4, 4]]) 35 | assert CountVectorizer.max_token_id(tensor) == 4 36 | end 37 | 38 | test "with padding" do 39 | tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]]) 40 | assert CountVectorizer.max_token_id(tensor) == 4 41 | end 42 | end 43 | 44 | describe "errors" do 45 | test "wrong input rank" do 46 | assert_raise ArgumentError, 47 | "expected tensor to have shape {num_documents, num_tokens}, got tensor with shape: {3}", 48 | fn -> 49 | CountVectorizer.fit_transform(Nx.tensor([1, 2, 3]), max_token_id: 3) 50 | end 51 | end 52 | end 53 | end 54 | -------------------------------------------------------------------------------- /test/scholar/impute/knn_imputter_test.exs: -------------------------------------------------------------------------------- 1 | defmodule KNNImputterTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Impute.KNNImputter 4 | doctest KNNImputter 5 | 6 | describe "general cases" do 7 | def generate_data() do 8 | x = Nx.iota({5, 4}) 9 | x = Nx.select(Nx.equal(Nx.quotient(x, 5), 2), Nx.Constants.nan(), x) 10 | Nx.indexed_put(x, Nx.tensor([[4, 2]]), Nx.tensor([6.0])) 11 | end 12 | 13 | test "general KNN imputer" do 14 | x = generate_data() 15 | jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) 16 | jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) 17 | 18 | knn_imputter = 19 | %KNNImputter{statistics: statistics, missing_values: missing_values} = 20 | jit_fit.(x, missing_values: :nan, num_neighbors: 2) 21 | 22 | assert missing_values == :nan 23 | 24 | assert statistics == 25 | Nx.tensor([ 26 | [:nan, :nan, :nan, :nan], 27 | [:nan, :nan, :nan, :nan], 28 | [:nan, :nan, 4.0, 5.0], 29 | [2.0, 3.0, 4.0, :nan], 30 | [:nan, :nan, :nan, :nan] 31 | ]) 32 | 33 | assert jit_transform.(knn_imputter, x) == 34 | Nx.tensor([ 35 | [0.0, 1.0, 2.0, 3.0], 36 | [4.0, 5.0, 6.0, 7.0], 37 | [8.0, 9.0, 4.0, 5.0], 38 | [2.0, 3.0, 4.0, 15.0], 39 | [16.0, 17.0, 6.0, 19.0] 40 | ]) 41 | end 42 | 43 | test "general KNN imputer with different number of neighbors" do 44 | x = generate_data() 45 | jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) 46 | jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) 47 | 48 | knn_imputter = 49 | %KNNImputter{statistics: statistics, missing_values: missing_values} = 50 | jit_fit.(x, missing_values: :nan, num_neighbors: 1) 51 | 52 | assert missing_values == :nan 53 | 54 | assert statistics == 55 | Nx.tensor([ 56 | [:nan, :nan, :nan, :nan], 57 | [:nan, :nan, :nan, :nan], 58 | [:nan, :nan, 2.0, 3.0], 59 | [0.0, 1.0, 2.0, :nan], 60 | [:nan, :nan, :nan, :nan] 61 | ]) 62 | 63 | assert jit_transform.(knn_imputter, x) == 64 | Nx.tensor([ 65 | [0.0, 1.0, 2.0, 3.0], 66 | [4.0, 5.0, 6.0, 7.0], 67 | [8.0, 9.0, 2.0, 3.0], 68 | [0.0, 1.0, 2.0, 15.0], 69 | [16.0, 17.0, 6.0, 19.0] 70 | ]) 71 | end 72 | 73 | test "missing values different from :nan" do 74 | x = generate_data() 75 | x = Nx.select(Nx.is_nan(x), 19.0, x) 76 | # NaNs are replaced with 19.0 above, so 19.0 acts as the missing-value marker below 77 | jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) 78 | jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) 79 | 80 | knn_imputter = 81 | %KNNImputter{statistics: statistics, missing_values: missing_values} = 82 | jit_fit.(x, missing_values: 19.0, num_neighbors: 2) 83 | 84 | assert missing_values == 19.0 85 | 86 | 
assert statistics == 87 | Nx.tensor([ 88 | [:nan, :nan, :nan, :nan], 89 | [:nan, :nan, :nan, :nan], 90 | [:nan, :nan, 4.0, 5.0], 91 | [2.0, 3.0, 4.0, :nan], 92 | [:nan, :nan, :nan, 5.0] 93 | ]) 94 | 95 | assert jit_transform.(knn_imputter, x) == 96 | Nx.tensor([ 97 | [0.0, 1.0, 2.0, 3.0], 98 | [4.0, 5.0, 6.0, 7.0], 99 | [8.0, 9.0, 4.0, 5.0], 100 | [2.0, 3.0, 4.0, 15.0], 101 | [16.0, 17.0, 6.0, 5.0] 102 | ]) 103 | end 104 | end 105 | 106 | describe "errors" do 107 | test "invalid impute rank" do 108 | x = Nx.tensor([1, 2, 2, 3]) 109 | 110 | assert_raise ArgumentError, 111 | "wrong input rank. Expected: 2, got: 1", 112 | fn -> 113 | KNNImputter.fit(x, missing_values: 1, num_neighbors: 2) 114 | end 115 | end 116 | 117 | test "invalid n_neighbors value" do 118 | x = generate_data() 119 | 120 | jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) 121 | 122 | assert_raise NimbleOptions.ValidationError, 123 | "invalid value for :num_neighbors option: expected positive integer, got: -1", 124 | fn -> 125 | jit_fit.(x, missing_values: 1.0, num_neighbors: -1) 126 | end 127 | end 128 | end 129 | end 130 | -------------------------------------------------------------------------------- /test/scholar/impute/simple_imputer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule SimpleImputerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Impute.SimpleImputer 4 | doctest SimpleImputer 5 | 6 | describe "general cases" do 7 | def generate_data() do 8 | x = Nx.iota({5, 4}) 9 | x = Nx.select(Nx.equal(Nx.quotient(x, 5), 2), Nx.Constants.nan(), x) 10 | Nx.indexed_put(x, Nx.tensor([[4, 2]]), Nx.tensor([6.0])) 11 | end 12 | 13 | test "general test mode" do 14 | x = generate_data() 15 | 16 | simple_imputer_mode = 17 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 18 | SimpleImputer.fit(x, missing_values: :nan, strategy: :mode) 19 | 20 | assert statistics == Nx.tensor([0.0, 1.0, 6.0, 3.0]) 21 | assert missing_values == :nan 22 | 23 | assert SimpleImputer.transform(simple_imputer_mode, x) == 24 | Nx.tensor([ 25 | [0.0, 1.0, 2.0, 3.0], 26 | [4.0, 5.0, 6.0, 7.0], 27 | [8.0, 9.0, 6.0, 3.0], 28 | [0.0, 1.0, 6.0, 15.0], 29 | [16.0, 17.0, 6.0, 19.0] 30 | ]) 31 | end 32 | 33 | test "general test median" do 34 | x = generate_data() 35 | 36 | simple_imputer_median = 37 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 38 | SimpleImputer.fit(x, missing_values: :nan, strategy: :median) 39 | 40 | assert statistics == Nx.tensor([6.0, 7.0, 6.0, 11.0]) 41 | assert missing_values == :nan 42 | 43 | assert SimpleImputer.transform(simple_imputer_median, x) == 44 | Nx.tensor([ 45 | [0.0, 1.0, 2.0, 3.0], 46 | [4.0, 5.0, 6.0, 7.0], 47 | [8.0, 9.0, 6.0, 11.0], 48 | [6.0, 7.0, 6.0, 15.0], 49 | [16.0, 17.0, 6.0, 19.0] 50 | ]) 51 | end 52 | 53 | test "general test mean" do 54 | x = generate_data() 55 | 56 | simple_imputer_mean = 57 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 58 | SimpleImputer.fit(x, missing_values: :nan, strategy: :mean) 59 | 60 | assert_all_close(statistics, Nx.tensor([7.0, 8.0, 4.666666507720947, 11.0])) 61 | assert missing_values == :nan 62 | 63 | assert_all_close( 64 | SimpleImputer.transform(simple_imputer_mean, x), 65 | Nx.tensor([ 66 | [0.0, 1.0, 2.0, 3.0], 67 | [4.0, 5.0, 6.0, 7.0], 68 | [8.0, 9.0, 4.666666507720947, 11.0], 69 | [7.0, 8.0, 4.666666507720947, 15.0], 70 | [16.0, 17.0, 6.0, 19.0] 71 | ]) 72 | ) 73 | end 74 | 75 | test "general test constant value" do 76 | x = 
generate_data() 77 | 78 | simple_imputer_constant_with_zeros = 79 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 80 | SimpleImputer.fit(x, missing_values: :nan, strategy: :constant) 81 | 82 | assert statistics == Nx.tensor([0.0, 0.0, 0.0, 0.0]) 83 | assert missing_values == :nan 84 | 85 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 86 | SimpleImputer.fit(x, 87 | missing_values: :nan, 88 | strategy: :constant, 89 | fill_value: 1.37 90 | ) 91 | 92 | assert statistics == Nx.tensor([1.37, 1.37, 1.37, 1.37]) 93 | assert missing_values == :nan 94 | 95 | assert SimpleImputer.transform(simple_imputer_constant_with_zeros, x) == 96 | Nx.tensor([ 97 | [0.0, 1.0, 2.0, 3.0], 98 | [4.0, 5.0, 6.0, 7.0], 99 | [8.0, 9.0, 0.0, 0.0], 100 | [0.0, 0.0, 0.0, 15.0], 101 | [16.0, 17.0, 6.0, 19.0] 102 | ]) 103 | end 104 | end 105 | 106 | test "mode with integer type" do 107 | x = Nx.tile(Nx.tensor([1, 2, 1, 2, 1, 2]), [5, 1]) |> Nx.reshape({6, 5}) 108 | 109 | simple_imputer_constant_with_zeros = 110 | %SimpleImputer{statistics: statistics, missing_values: missing_values} = 111 | SimpleImputer.fit(x, missing_values: 1, strategy: :mode) 112 | 113 | assert statistics == Nx.tensor([2, 2, 2, 2, 2]) 114 | assert missing_values == 1 115 | 116 | assert SimpleImputer.transform(simple_imputer_constant_with_zeros, x) == 117 | Nx.broadcast(2, {6, 5}) 118 | end 119 | 120 | describe "errors" do 121 | test "Wrong impute rank" do 122 | x = Nx.tensor([1, 2, 2, 3]) 123 | 124 | assert_raise ArgumentError, 125 | "wrong input rank. Expected: 2, got: 1", 126 | fn -> 127 | SimpleImputer.fit(x, missing_values: 1, strategy: :mode) 128 | end 129 | end 130 | 131 | test "Wrong :fill_value type" do 132 | x = Nx.tensor([[1.0, 2.0, 2.0, 3.0]]) 133 | 134 | assert_raise ArgumentError, 135 | "wrong type of `:fill_value` for the given data. 
Expected: :f or :bf, got: :s", 136 | fn -> 137 | SimpleImputer.fit(x, 138 | missing_values: 1.0, 139 | strategy: :constant, 140 | fill_value: 2 141 | ) 142 | end 143 | end 144 | end 145 | end 146 | -------------------------------------------------------------------------------- /test/scholar/interpolation/bezier_spline_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Interpolation.BezierSplineTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Interpolation.BezierSpline 4 | doctest BezierSpline 5 | 6 | describe "bezier spline" do 7 | test "fit/2" do 8 | x = Nx.iota({4}) 9 | y = Nx.tensor([0, 1, 8, 1]) 10 | 11 | model = BezierSpline.fit(x, y) 12 | 13 | assert_all_close( 14 | model.coefficients, 15 | Nx.tensor([ 16 | [ 17 | [0.0, 0.0], 18 | [0.3333335816860199, -0.5111109614372253], 19 | [0.6666669845581055, -1.0222218036651611], 20 | [1.0, 1.0] 21 | ], 22 | [ 23 | [1.0, 1.0], 24 | [1.3333330154418945, 3.022221803665161], 25 | [1.6666665077209473, 7.577777862548828], 26 | [2.0, 8.0] 27 | ], 28 | [ 29 | [2.0, 8.0], 30 | [2.3333334922790527, 8.422222137451172], 31 | [2.6666667461395264, 4.711111068725586], 32 | [3.0, 1.0] 33 | ] 34 | ]) 35 | ) 36 | 37 | assert model.k == Nx.stack([x, y], axis: 1) 38 | end 39 | 40 | test "input validation error cases" do 41 | assert_raise ArgumentError, 42 | "expected x to be a tensor with shape {n}, where n > 2, got: {1, 1, 1}", 43 | fn -> 44 | BezierSpline.fit(Nx.iota({1, 1, 1}), Nx.iota({1, 1, 1})) 45 | end 46 | 47 | assert_raise ArgumentError, 48 | "expected x to be a tensor with shape {n}, where n > 2, got: {}", 49 | fn -> 50 | BezierSpline.fit(Nx.iota({}), Nx.iota({})) 51 | end 52 | 53 | assert_raise ArgumentError, 54 | "expected x to be a tensor with shape {n}, where n > 2, got: {1}", 55 | fn -> 56 | BezierSpline.fit(Nx.iota({1}), Nx.iota({1})) 57 | end 58 | 59 | assert_raise ArgumentError, "expected y to have shape {4}, got: {3}", fn -> 60 | BezierSpline.fit(Nx.iota({4}), Nx.iota({3})) 61 | end 62 | end 63 | 64 | test "predict/2" do 65 | x = Nx.iota({4}) 66 | y = Nx.tensor([0, 1, 8, 1]) 67 | 68 | model = BezierSpline.fit(x, y) 69 | 70 | assert_all_close( 71 | BezierSpline.predict(model, Nx.tensor([0, 1, 2, 3, -0.5, 0.5, 1.5, 2.5, 3.5]), 72 | max_iter: 20, 73 | eps: 1.0e-3 74 | ), 75 | Nx.tensor([ 76 | 0.0, 77 | 0.9881857633590698, 78 | 7.997480392456055, 79 | 1.0217341184616089, 80 | 7.31151374111505e-7, 81 | -0.4500003159046173, 82 | 5.083068370819092, 83 | 6.065662860870361, 84 | -4.0382304191589355 85 | ]) 86 | ) 87 | end 88 | end 89 | end 90 | -------------------------------------------------------------------------------- /test/scholar/interpolation/linear_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Interpolation.LinearTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Interpolation.Linear 4 | doctest Linear 5 | 6 | describe "linear" do 7 | test "fit/2" do 8 | x = Nx.iota({4}) 9 | y = Nx.tensor([0, 0, 1, 5]) 10 | 11 | model = Linear.fit(x, y) 12 | 13 | assert model.coefficients == 14 | Nx.tensor([ 15 | [0.0, 0.0], 16 | [1.0, -1.0], 17 | [4.0, -7.0] 18 | ]) 19 | end 20 | 21 | test "input validation error cases" do 22 | assert_raise ArgumentError, 23 | "expected x to be a tensor with shape {n}, where n > 1, got: {1, 1, 1}", 24 | fn -> 25 | Linear.fit(Nx.iota({1, 1, 1}), Nx.iota({1, 1, 1})) 26 | end 27 | 28 | assert_raise ArgumentError, 29 | "expected x to be a tensor with shape {n}, where n > 1, got: 
{}", 30 | fn -> 31 | Linear.fit(Nx.iota({}), Nx.iota({})) 32 | end 33 | 34 | assert_raise ArgumentError, 35 | "expected x to be a tensor with shape {n}, where n > 1, got: {1}", 36 | fn -> 37 | Linear.fit(Nx.iota({1}), Nx.iota({1})) 38 | end 39 | 40 | assert_raise ArgumentError, "expected y to have shape {4}, got: {3}", fn -> 41 | Linear.fit(Nx.iota({4}), Nx.iota({3})) 42 | end 43 | end 44 | 45 | test "predict/2" do 46 | x = Nx.iota({4}) 47 | y = Nx.tensor([0, 0, 1, 5]) 48 | 49 | model = Linear.fit(x, y) 50 | 51 | assert Linear.predict(model, Nx.tensor([[[-0.5], [0.5], [1.5], [2.5], [3.5]]])) == 52 | Nx.tensor([[[0.0], [0.0], [0.5], [3], [7]]]) 53 | end 54 | 55 | test "with different types" do 56 | x_s = Nx.tensor([1, 2, 3], type: :u64) 57 | y_s = Nx.tensor([1.0, 2.0, 3.0], type: :f64) 58 | target = Nx.tensor([1, 2], type: :u64) 59 | 60 | assert x_s 61 | |> Scholar.Interpolation.Linear.fit(y_s) 62 | |> Scholar.Interpolation.Linear.predict(target) == 63 | Nx.tensor([1.0, 2.0], type: :f64) 64 | end 65 | end 66 | end 67 | -------------------------------------------------------------------------------- /test/scholar/linear/bayesian_ridge_regression_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Linear.BayesianRidgeRegressionTest do 2 | import Nx.Defn 3 | use Scholar.Case, async: true 4 | alias Scholar.Linear.BayesianRidgeRegression 5 | alias Scholar.Linear.RidgeRegression 6 | doctest BayesianRidgeRegression 7 | 8 | test "toy bayesian ridge" do 9 | x = Nx.tensor([[1], [2], [6], [8], [10]]) 10 | y = Nx.tensor([1, 2, 6, 8, 10]) 11 | clf = BayesianRidgeRegression.fit(x, y) 12 | test = Nx.tensor([[1], [3], [4]]) 13 | expected = Nx.tensor([1, 3, 4]) 14 | predicted = BayesianRidgeRegression.predict(clf, test) 15 | assert_all_close(expected, predicted, atol: 1.0e-1) 16 | end 17 | 18 | test "toy bayesian ridge with column target" do 19 | x = Nx.tensor([[1], [2], [6], [8], [10]]) 20 | y = Nx.tensor([1, 2, 6, 8, 10]) 21 | model = BayesianRidgeRegression.fit(x, y) 22 | pred = BayesianRidgeRegression.predict(model, x) 23 | col_model = BayesianRidgeRegression.fit(x, y |> Nx.new_axis(-1)) 24 | col_pred = BayesianRidgeRegression.predict(col_model, x) 25 | assert model == col_model 26 | assert pred == col_pred 27 | end 28 | 29 | test "2 column target raises" do 30 | x = Nx.tensor([[1], [2], [6], [8], [10]]) 31 | y = Nx.tensor([1, 2, 6, 8, 10]) 32 | y = Nx.new_axis(y, -1) 33 | y = Nx.concatenate([y, y], axis: 1) 34 | 35 | message = 36 | "Scholar.Linear.BayesianRidgeRegression expected y to have shape {n_samples}, got tensor with shape: #{inspect(Nx.shape(y))}" 37 | 38 | assert_raise ArgumentError, 39 | message, 40 | fn -> 41 | BayesianRidgeRegression.fit(x, y) 42 | end 43 | end 44 | 45 | test "ridge vs bayesian ridge: parameters" do 46 | x = Nx.tensor([[1, 1], [3, 4], [5, 7], [4, 1], [2, 6], [3, 10], [3, 2]]) 47 | y = Nx.tensor([1, 2, 3, 2, 0, 4, 5]) 48 | brr = BayesianRidgeRegression.fit(x, y) 49 | rr = RidgeRegression.fit(x, y, alpha: Nx.to_number(brr.lambda) / Nx.to_number(brr.alpha)) 50 | assert_all_close(brr.coefficients, rr.coefficients, atol: 1.0e-2) 51 | assert_all_close(brr.intercept, rr.intercept, atol: 1.0e-2) 52 | end 53 | 54 | test "ridge vs bayesian ridge: weights" do 55 | x = Nx.tensor([[1, 1], [3, 4], [5, 7], [4, 1], [2, 6], [3, 10], [3, 2]]) 56 | y = Nx.tensor([1, 2, 3, 2, 0, 4, 5]) 57 | w = Nx.tensor([4, 3, 3, 1, 1, 2, 3]) 58 | brr = BayesianRidgeRegression.fit(x, y, sample_weights: w) 59 | 60 | rr = 61 | RidgeRegression.fit(x, 
y, 62 | alpha: Nx.to_number(brr.lambda) / Nx.to_number(brr.alpha), 63 | sample_weights: w 64 | ) 65 | 66 | assert_all_close(brr.coefficients, rr.coefficients, atol: 1.0e-2) 67 | assert_all_close(brr.intercept, rr.intercept, atol: 1.0e-2) 68 | end 69 | 70 | test "compute scores" do 71 | {x, y} = diabetes_data() 72 | eps = Nx.Constants.smallest_positive_normal(:f64) 73 | alpha = Nx.divide(1, Nx.add(Nx.variance(x), eps)) 74 | lambda = 1.0 75 | alpha_1 = 0.1 76 | alpha_2 = 0.1 77 | lambda_1 = 0.1 78 | lambda_2 = 0.1 79 | # compute score 80 | score = compute_score(x, y, alpha, lambda, alpha_1, alpha_2, lambda_1, lambda_2) 81 | 82 | brr = 83 | BayesianRidgeRegression.fit(x, Nx.flatten(y), 84 | alpha_1: alpha_1, 85 | alpha_2: alpha_2, 86 | lambda_1: lambda_1, 87 | lambda_2: lambda_2, 88 | fit_intercept?: true, 89 | compute_scores?: true, 90 | iterations: 1 91 | ) 92 | 93 | first_score = brr.scores[0] 94 | assert_all_close(score, first_score, rtol: 0.05) 95 | end 96 | 97 | defnp compute_score(x, y, alpha, lambda, alpha_1, alpha_2, lambda_1, lambda_2) do 98 | {n_samples, _} = Nx.shape(x) 99 | lambda_score = lambda_1 * Nx.log(lambda) - lambda_2 * lambda 100 | alpha_score = alpha_1 * Nx.log(alpha) - alpha_2 * alpha 101 | m = 1.0 / alpha * Nx.eye(n_samples) + 1.0 / lambda * Nx.dot(x, [-1], x, [-1]) 102 | m_inv_dot_y = Nx.LinAlg.solve(m, y) 103 | logdet = m |> Nx.LinAlg.determinant() |> Nx.log() 104 | 105 | y_score = 106 | -0.5 * 107 | (logdet + Nx.dot(y, [0], m_inv_dot_y, [0]) + n_samples * Nx.log(2 * Nx.Constants.pi())) 108 | 109 | alpha_score + lambda_score + y_score 110 | end 111 | 112 | test "constant inputs: prediction. n_features > n_samples" do 113 | key = Nx.Random.key(42) 114 | n_samples = 4 115 | n_features = 5 116 | {constant_value, new_key} = Nx.Random.uniform(key) 117 | {x, _} = Nx.Random.uniform(new_key, shape: {n_samples, n_features}, type: :f64) 118 | y = Nx.broadcast(constant_value, {n_samples}) 119 | expected = Nx.broadcast(constant_value, {n_samples}) 120 | brr = BayesianRidgeRegression.fit(x, y) 121 | predicted = BayesianRidgeRegression.predict(brr, x) 122 | assert_all_close(expected, predicted, atol: 0.01) 123 | end 124 | 125 | test "constant inputs: variance is constant" do 126 | key = Nx.Random.key(42) 127 | n_samples = 15 128 | n_features = 10 129 | {constant_value, new_key} = Nx.Random.uniform(key) 130 | {x, _} = Nx.Random.uniform(new_key, shape: {n_samples, n_features}, type: :f64) 131 | y = Nx.broadcast(constant_value, {n_samples}) 132 | brr = BayesianRidgeRegression.fit(x, y) 133 | check = Nx.less_equal(brr.sigma, 0.01) 134 | assert Nx.all(check) == Nx.u8(1) 135 | end 136 | end 137 | -------------------------------------------------------------------------------- /test/scholar/linear/logistic_regression_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Linear.LogisticRegressionTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Linear.LogisticRegression 4 | doctest LogisticRegression 5 | 6 | test "Iris Data Set - multinomial logistic regression test" do 7 | {x_train, x_test, y_train, y_test} = iris_data() 8 | 9 | model = LogisticRegression.fit(x_train, y_train, num_classes: 3) 10 | res = LogisticRegression.predict(model, x_test) 11 | accuracy = Scholar.Metrics.Classification.accuracy(res, y_test) 12 | 13 | assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1) 14 | end 15 | 16 | describe "errors" do 17 | test "when :num_classes is invalid" do 18 | x = Nx.tensor([[1, 2], [3, 4]]) 19 | y = Nx.tensor([1, 2]) 
20 | 21 | assert_raise NimbleOptions.ValidationError, 22 | "invalid value for :num_classes option: expected positive integer, got: -3", 23 | fn -> 24 | LogisticRegression.fit(x, y, num_classes: -3) 25 | end 26 | 27 | assert_raise NimbleOptions.ValidationError, 28 | "invalid value for :num_classes option: expected positive integer, got: 2.0", 29 | fn -> 30 | LogisticRegression.fit(x, y, num_classes: 2.0) 31 | end 32 | end 33 | 34 | test "when missing :num_classes option" do 35 | x = Nx.tensor([[1, 2], [3, 4]]) 36 | y = Nx.tensor([0, 1]) 37 | 38 | assert_raise NimbleOptions.ValidationError, 39 | "required :num_classes option not found, received options: []", 40 | fn -> LogisticRegression.fit(x, y) end 41 | end 42 | 43 | test "when :optimizer is invalid" do 44 | x = Nx.tensor([[1, 2], [3, 4]]) 45 | y = Nx.tensor([1, 2]) 46 | 47 | assert_raise NimbleOptions.ValidationError, 48 | "invalid value for :optimizer option: expected :optimizer to be either a valid 0-arity function in Polaris.Optimizers or a valid {init_fn, update_fn} tuple", 49 | fn -> 50 | LogisticRegression.fit(x, y, 51 | num_classes: 2, 52 | optimizer: :invalid_optimizer 53 | ) 54 | end 55 | end 56 | 57 | test "when :iterations is not a positive integer" do 58 | x = Nx.tensor([[1, 2], [3, 4]]) 59 | y = Nx.tensor([1, 2]) 60 | 61 | assert_raise NimbleOptions.ValidationError, 62 | "invalid value for :iterations option: expected positive integer, got: 0", 63 | fn -> 64 | LogisticRegression.fit(x, y, num_classes: 2, iterations: 0) 65 | end 66 | end 67 | 68 | test "when training vector size is invalid" do 69 | x = Nx.tensor([5, 6]) 70 | y = Nx.tensor([1, 2]) 71 | 72 | assert_raise ArgumentError, 73 | "expected x to have shape {n_samples, n_features}, got tensor with shape: {2}", 74 | fn -> LogisticRegression.fit(x, y, num_classes: 2) end 75 | end 76 | 77 | test "when target vector size is invalid" do 78 | x = Nx.tensor([[1, 2], [3, 4]]) 79 | y = Nx.tensor([[0, 1], [1, 0]]) 80 | 81 | assert_raise ArgumentError, 82 | "Scholar.Linear.LogisticRegression expected y to have shape {n_samples}, got tensor with shape: {2, 2}", 83 | fn -> LogisticRegression.fit(x, y, num_classes: 2) end 84 | end 85 | end 86 | 87 | describe "column target tests" do 88 | @tag :wip 89 | test "column target" do 90 | {x_train, _, y_train, _} = iris_data() 91 | 92 | model = LogisticRegression.fit(x_train, y_train, num_classes: 3) 93 | pred = LogisticRegression.predict(model, x_train) 94 | col_model = LogisticRegression.fit(x_train, y_train |> Nx.new_axis(-1), num_classes: 3) 95 | col_pred = LogisticRegression.predict(col_model, x_train) 96 | assert model == col_model 97 | assert pred == col_pred 98 | end 99 | end 100 | end 101 | -------------------------------------------------------------------------------- /test/scholar/linear/svm_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Linear.SVMTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Linear.SVM 4 | doctest SVM 5 | 6 | test "Iris Data Set - multinomial classification svm test" do 7 | {x_train, x_test, y_train, y_test} = iris_data() 8 | 9 | loss_fn = fn y_pred, y_true -> 10 | Scholar.Linear.SVM.hinge_loss(y_pred, y_true, c: 1.0, margin: 150) 11 | end 12 | 13 | model = SVM.fit(x_train, y_train, num_classes: 3, loss_fn: loss_fn) 14 | res = SVM.predict(model, x_test) 15 | 16 | accuracy = Scholar.Metrics.Classification.accuracy(res, y_test) 17 | 18 | assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1) 19 | end 20 | 21 | test "test column target" 
do 22 | {x_train, x_test, y_train, y_test} = iris_data() 23 | 24 | loss_fn = fn y_pred, y_true -> 25 | Scholar.Linear.SVM.hinge_loss(y_pred, y_true, c: 1.0, margin: 150) 26 | end 27 | 28 | col_model = SVM.fit(x_train, y_train |> Nx.new_axis(-1), num_classes: 3, loss_fn: loss_fn) 29 | res = SVM.predict(col_model, x_test) 30 | 31 | model = SVM.fit(x_train, y_train, num_classes: 3, loss_fn: loss_fn) 32 | 33 | accuracy = Scholar.Metrics.Classification.accuracy(res, y_test) 34 | 35 | assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1) 36 | assert model == col_model 37 | end 38 | 39 | test "test fit 2 columned y data" do 40 | {x_train, _, y_train, _} = iris_data() 41 | 42 | loss_fn = fn y_pred, y_true -> 43 | Scholar.Linear.SVM.hinge_loss(y_pred, y_true, c: 1.0, margin: 150) 44 | end 45 | 46 | y_train = Nx.new_axis(y_train, -1) 47 | 48 | y_train = 49 | Nx.concatenate([y_train, y_train], axis: 1) 50 | 51 | message = 52 | "Scholar.Linear.SVM expected y to have shape {n_samples}, got tensor with shape: #{inspect(Nx.shape(y_train))}" 53 | 54 | assert_raise ArgumentError, 55 | message, 56 | fn -> SVM.fit(x_train, y_train, num_classes: 3, loss_fn: loss_fn) end 57 | end 58 | end 59 | -------------------------------------------------------------------------------- /test/scholar/metrics/classification_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.ClassificationTest do 2 | use Scholar.Case, async: true 3 | 4 | alias Scholar.Metrics.Classification 5 | doctest Classification 6 | 7 | test "roc_curve - y_score with repeated elements" do 8 | y_score = Nx.tensor([0.1, 0.1, 0.2, 0.2, 0.3, 0.3]) 9 | y_true = Nx.tensor([0, 0, 1, 1, 1, 1]) 10 | distinct_value_indices = Classification.distinct_value_indices(y_score) 11 | 12 | {fpr, tpr, thresholds} = Classification.roc_curve(y_true, y_score, distinct_value_indices) 13 | assert_all_close(fpr, Nx.tensor([0.0, 0.0, 0.0, 1.0])) 14 | assert_all_close(tpr, Nx.tensor([0.0, 0.5, 1.0, 1.0])) 15 | assert_all_close(thresholds, Nx.tensor([1.3, 0.3, 0.2, 0.1])) 16 | end 17 | 18 | describe "fbeta_score" do 19 | test "equals recall when beta is infinity" do 20 | beta = Nx.tensor(:infinity) 21 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) 22 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :u32) 23 | fbeta_scores = Classification.fbeta_score(y_true, y_pred, beta, num_classes: 2) 24 | 25 | assert_all_close(fbeta_scores, Classification.recall(y_true, y_pred, num_classes: 2)) 26 | end 27 | 28 | test "equals precision when beta is 0" do 29 | beta = 0 30 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) 31 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :u32) 32 | fbeta_scores = Classification.fbeta_score(y_true, y_pred, beta, num_classes: 2) 33 | 34 | assert_all_close(fbeta_scores, Classification.precision(y_true, y_pred, num_classes: 2)) 35 | end 36 | end 37 | 38 | describe "mcc/2" do 39 | test "returns 1 for perfect predictions" do 40 | y_true = Nx.tensor([1, 0, 1, 0, 1]) 41 | y_pred = Nx.tensor([1, 0, 1, 0, 1]) 42 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([1.0], type: :f32) 43 | end 44 | 45 | test "returns -1 for completely wrong predictions" do 46 | y_true = Nx.tensor([1, 0, 1, 0, 1]) 47 | y_pred = Nx.tensor([0, 1, 0, 1, 0]) 48 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([-1.0], type: :f32) 49 | end 50 | 51 | test "returns 0 when all predictions are positive" do 52 | y_true = Nx.tensor([1, 0, 1, 0, 1]) 53 | y_pred = 
Nx.tensor([1, 1, 1, 1, 1]) 54 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32) 55 | end 56 | 57 | test "returns 0 when all predictions are negative" do 58 | y_true = Nx.tensor([1, 0, 1, 0, 1]) 59 | y_pred = Nx.tensor([0, 0, 0, 0, 0]) 60 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32) 61 | end 62 | 63 | test "computes MCC for generic case" do 64 | y_true = Nx.tensor([1, 0, 1, 0, 1]) 65 | y_pred = Nx.tensor([1, 0, 1, 1, 1]) 66 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.6123723983764648], type: :f32) 67 | end 68 | 69 | test "returns 0 when TP, TN, FP, and FN are all 0" do 70 | y_true = Nx.tensor([0, 0, 0, 0, 0]) 71 | y_pred = Nx.tensor([0, 0, 0, 0, 0]) 72 | assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32) 73 | end 74 | end 75 | end 76 | -------------------------------------------------------------------------------- /test/scholar/metrics/clustering_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.ClusteringTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Metrics.Clustering 4 | doctest Clustering 5 | 6 | defp x do 7 | Nx.tensor([ 8 | [0.8157184486105233, -0.01635270455208079, -0.18865907479231095], 9 | [-0.23329832770253972, -1.0670578022770247, -1.4104723030890731], 10 | [-1.7011914621798259, -0.5470603319369067, 0.6322584168194634], 11 | [0.9425204058473541, -1.037556635542662, -0.21666787618406214], 12 | [-0.3690472295854992, 0.09486339114365348, 0.6961098873520952], 13 | [-1.6771195334192937, -1.2952988072907938, 0.08604298953131055], 14 | [-0.15949996798407284, 0.6000955647811537, 0.8663542127627641], 15 | [-1.7895457106619646, -0.1754378695049237, -0.365205321101012], 16 | [-1.394493075734099, 0.07372952442240266, 0.050092862319408904], 17 | [0.7307430281513498, -0.9238089730338562, -1.1574739042650035], 18 | [0.8915021818008122, 0.3179163863864429, -0.08149978809223538], 19 | [-3.168762672235742, -0.9546695365525117, -0.9908465332175033], 20 | [1.4631690777205948, -0.4670129058849814, 0.6461228097855395], 21 | [0.5622446765584268, -1.3052125865380872, 2.1498777623693526], 22 | [1.520536826767848, 0.13949842485542288, 0.6555116085852244], 23 | [0.07897976746203501, -0.6255090568811114, 0.5807450128292253], 24 | [0.9392107340484556, -0.3814452707719886, 0.30010654905670114], 25 | [2.2255470282884557, 0.6745085423783157, -1.6413094715372527], 26 | [-1.1522920682636968, 1.6005950153494766, -0.3385979516219887], 27 | [0.391502390567214, -0.25562956508471424, 0.12315775252511421] 28 | ]) 29 | end 30 | 31 | defp y do 32 | Nx.tensor([2, 2, 2, 2, 0, 2, 1, 1, 2, 2, 2, 2, 2, 1, 1, 0, 1, 2, 1, 2]) 33 | end 34 | 35 | defp y_one do 36 | Nx.tensor([2, 2, 2, 2, 0, 2, 1, 3, 2, 2, 2, 2, 2, 1, 4, 0, 1, 2, 1, 2]) 37 | end 38 | 39 | test "silhouette_score test" do 40 | expected = Nx.tensor(-0.17424359382143964) 41 | 42 | assert_all_close(Clustering.silhouette_score(x(), y(), num_clusters: 3), expected) 43 | end 44 | 45 | test "silhouette_samples test" do 46 | expected = 47 | Nx.tensor([ 48 | -0.2700417935848236, 49 | 0.014513458125293255, 50 | -0.36026713252067566, 51 | -0.2003783881664276, 52 | 0.4819222092628479, 53 | -0.22008605301380157, 54 | -0.526187002658844, 55 | -0.2967069745063782, 56 | -0.3657914698123932, 57 | 0.019084271043539047, 58 | -0.27950319647789, 59 | -0.03889584541320801, 60 | -0.3125722408294678, 61 | -0.3157929480075836, 62 | -0.23939572274684906, 63 | 0.5096476674079895, 64 | -0.4214971363544464, 65 | 
0.03587301820516586, 66 | -0.16732777655124664, 67 | -0.5314690470695496 68 | ]) 69 | 70 | assert_all_close(Clustering.silhouette_samples(x(), y(), num_clusters: 3), expected) 71 | end 72 | 73 | test "silhouette_score test with one-element clusters" do 74 | expected = Nx.tensor(-0.298110549250206) 75 | 76 | assert_all_close(Clustering.silhouette_score(x(), y_one(), num_clusters: 5), expected) 77 | end 78 | 79 | test "silhouette_samples test with one-element clusters" do 80 | expected = 81 | Nx.tensor([ 82 | -0.4016384184360504, 83 | -0.05811696499586105, 84 | -0.5811521410942078, 85 | -0.2139229029417038, 86 | 0.45039641857147217, 87 | -0.5208510756492615, 88 | -0.5239397287368774, 89 | 0.0, 90 | -0.7291823625564575, 91 | 0.019084271043539047, 92 | -0.5117945671081543, 93 | -0.536014974117279, 94 | -0.7470640540122986, 95 | -0.3105389475822449, 96 | 0.0, 97 | 0.4759448170661926, 98 | -0.6119896173477173, 99 | -0.2591235637664795, 100 | -0.37083834409713745, 101 | -0.5314690470695496 102 | ]) 103 | 104 | assert_all_close(Clustering.silhouette_samples(x(), y_one(), num_clusters: 5), expected) 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /test/scholar/metrics/distance_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.DistanceTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Metrics.Distance 4 | doctest Distance 5 | 6 | defp x do 7 | Nx.tensor([ 8 | -0.165435329079628, 9 | -1.0515050888061523, 10 | 1.8801462650299072, 11 | 0.2381746470928192, 12 | 0.6978269219398499, 13 | 0.025831177830696106, 14 | 0.11569870263338089, 15 | -0.6905220150947571, 16 | -0.9335482120513916, 17 | -0.025539811700582504 18 | ]) 19 | end 20 | 21 | defp y do 22 | Nx.tensor([ 23 | 0.5898482203483582, 24 | -0.5769372582435608, 25 | 1.43277108669281, 26 | -0.024414867162704468, 27 | -1.3458243608474731, 28 | 1.669877052307129, 29 | 0.6263275742530823, 30 | 0.8154261708259583, 31 | 0.06888432800769806, 32 | 0.022759810090065002 33 | ]) 34 | end 35 | 36 | test "euclidean matches scipy" do 37 | assert_all_close(Distance.euclidean(x(), y()), Nx.tensor(3.388213202573845)) 38 | end 39 | 40 | test "squared euclidean matches scipy" do 41 | assert_all_close(Distance.squared_euclidean(x(), y()), Nx.tensor(11.479988706095714)) 42 | end 43 | 44 | test "manhattan matches scipy" do 45 | assert_all_close(Distance.manhattan(x(), y()), Nx.tensor(8.694822449237108)) 46 | end 47 | 48 | test "chebyshev matches scipy" do 49 | assert_all_close(Distance.chebyshev(x(), y()), Nx.tensor(2.043651282787323)) 50 | end 51 | 52 | test "minkowski matches scipy" do 53 | assert_all_close(Distance.minkowski(x(), y()), Nx.tensor(3.388213202573845)) 54 | end 55 | 56 | test "minkowski with p set to :infinity matches chebyshev" do 57 | assert_all_close(Distance.minkowski(x(), y(), p: :infinity), Nx.tensor(2.043651282787323)) 58 | end 59 | 60 | test "cosine matches scipy" do 61 | assert_all_close(Distance.cosine(x(), y()), Nx.tensor(0.7650632810164779)) 62 | end 63 | 64 | test "hamming matches scipy" do 65 | assert Distance.hamming(Nx.tensor([1, 0, 0]), Nx.tensor([0, 1, 0])) == 66 | Nx.tensor(0.6666666865348816) 67 | 68 | assert Distance.hamming(Nx.tensor([1, 0, 0]), Nx.tensor([1, 1, 0])) == 69 | Nx.tensor(0.3333333432674408) 70 | 71 | assert Distance.hamming(Nx.tensor([1, 0, 0]), Nx.tensor([2, 0, 0])) == 72 | Nx.tensor(0.3333333432674408) 73 | 74 | assert Distance.hamming(Nx.tensor([1, 0, 0]), Nx.tensor([3, 0, 0])) == 
75 | Nx.tensor(0.3333333432674408) 76 | end 77 | end 78 | -------------------------------------------------------------------------------- /test/scholar/metrics/neighbors_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.NeighborsTest do 2 | use ExUnit.Case, async: true 3 | alias Scholar.Metrics.Neighbors 4 | doctest Neighbors 5 | end 6 | -------------------------------------------------------------------------------- /test/scholar/metrics/ranking_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.RankingTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Metrics.Ranking 4 | 5 | describe "dcg/3" do 6 | test "computes DCG when there are no ties" do 7 | y_true = Nx.tensor([3, 2, 3, 0, 1, 2]) 8 | y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1]) 9 | 10 | result = Ranking.dcg(y_true, y_score) 11 | 12 | x = Nx.tensor([7.140995025634766]) 13 | assert x == Nx.broadcast(result, {1}) 14 | end 15 | 16 | test "computes DCG with ties" do 17 | y_true = Nx.tensor([3, 3, 3]) 18 | y_score = Nx.tensor([2.0, 2.0, 3.5]) 19 | 20 | result = Ranking.dcg(y_true, y_score) 21 | 22 | x = Nx.tensor([6.3927892607143715]) 23 | assert x == Nx.broadcast(result, {1}) 24 | end 25 | 26 | test "raises error when shapes mismatch" do 27 | y_true = Nx.tensor([3, 2, 3]) 28 | y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5]) 29 | 30 | assert_raise ArgumentError, 31 | "expected tensor to have shape {3}, got tensor with shape {4}", 32 | fn -> 33 | Ranking.dcg(y_true, y_score) 34 | end 35 | end 36 | 37 | test "computes DCG for top-k values" do 38 | y_true = Nx.tensor([3, 2, 3, 0, 1, 2]) 39 | y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1]) 40 | 41 | result = Ranking.dcg(y_true, y_score, k: 3) 42 | 43 | x = Nx.tensor([5.892789363861084]) 44 | assert x == Nx.broadcast(result, {1}) 45 | end 46 | end 47 | end 48 | -------------------------------------------------------------------------------- /test/scholar/metrics/regression_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.RegressionTest do 2 | use Scholar.Case, async: true 3 | 4 | alias Scholar.Metrics.Regression 5 | doctest Regression 6 | 7 | describe "mean_tweedie_deviance!/3" do 8 | test "raise when y_pred <= 0 and power < 0" do 9 | power = -1 10 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) 11 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) 12 | 13 | assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> 14 | Regression.mean_tweedie_deviance!(y_true, y_pred, power) 15 | end 16 | end 17 | 18 | test "raise when y_pred <= 0 and 1 <= power < 2" do 19 | power = 1 20 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) 21 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) 22 | 23 | assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> 24 | Regression.mean_tweedie_deviance!(y_true, y_pred, power) 25 | end 26 | end 27 | 28 | test "raise when y_pred <= 0 and power >= 2" do 29 | power = 2 30 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) 31 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) 32 | 33 | assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> 34 | Regression.mean_tweedie_deviance!(y_true, y_pred, power) 35 | end 36 | end 37 | 38 | test "raise when y_true < 0 and 1 <= power < 2" do 39 | power = 1 40 | y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], 
type: :s32) 41 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) 42 | 43 | assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> 44 | Regression.mean_tweedie_deviance!(y_true, y_pred, power) 45 | end 46 | end 47 | 48 | test "raise when y_true <= 0 and power >= 2" do 49 | power = 2 50 | y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) 51 | y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) 52 | 53 | assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> 54 | Regression.mean_tweedie_deviance!(y_true, y_pred, power) 55 | end 56 | end 57 | end 58 | 59 | describe "d2_tweedie_score/3" do 60 | test "equals R^2 when power is 0" do 61 | y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) 62 | y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) 63 | d2 = Regression.d2_tweedie_score(y_true, y_pred, 0) 64 | r2 = Regression.r2_score(y_true, y_pred) 65 | 66 | assert Nx.equal(d2, r2) == Nx.u8(1) 67 | end 68 | end 69 | 70 | describe "mean_pinball_loss/3" do 71 | test "mean_pinball_loss cases from sklearn" do 72 | # Test cases copied from sklearn: 73 | # https://github.com/scikit-learn/scikit-learn/blob/128e40ed593c57e8b9e57a4109928d58fa8bf359/sklearn/metrics/tests/test_regression.py#L49 74 | 75 | y_true = Nx.linspace(1, 50, n: 50) 76 | y_pred = Nx.add(y_true, 1) 77 | y_pred_2 = Nx.add(y_true, -1) 78 | 79 | assert Regression.mean_pinball_loss(y_true, y_pred) == Nx.tensor(0.5) 80 | assert Regression.mean_pinball_loss(y_true, y_pred_2) == Nx.tensor(0.5) 81 | assert Regression.mean_pinball_loss(y_true, y_pred, alpha: 0.4) == Nx.tensor(0.6) 82 | assert Regression.mean_pinball_loss(y_true, y_pred_2, alpha: 0.4) == Nx.tensor(0.4) 83 | end 84 | 85 | test "mean_pinball_loss with axes" do 86 | y_true = Nx.tensor([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]]) 87 | y_pred = Nx.tensor([[0, 0, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]]) 88 | 89 | sample_weight = 90 | Nx.tensor([[0.5, 0.5, 0.5, 1.5], [1.5, 0.5, 1.5, 1.5], [1.5, 1.5, 1.5, 1.5]]) 91 | 92 | expected_error = Nx.tensor((1 + 2 / 3) / 8) 93 | expected_raw_values_tensor = Nx.tensor([0.5, 0.33333333, 0.0, 0.0]) 94 | expected_raw_values_weighted_tensor = Nx.tensor([0.5, 0.4, 0.0, 0.0]) 95 | 96 | mpbl = Regression.mean_pinball_loss(y_true, y_pred) 97 | assert_all_close(mpbl, expected_error) 98 | # a strict == assertion would fail here due to floating-point precision, hence assert_all_close 99 | mpbl = 100 | Regression.mean_pinball_loss( 101 | y_true, 102 | y_pred, 103 | alpha: 0.5 104 | ) 105 | 106 | assert_all_close(mpbl, expected_error) 107 | mpbl = Regression.mean_pinball_loss(y_true, y_pred, alpha: 0.5, axes: [0]) 108 | assert_all_close(mpbl, expected_raw_values_tensor) 109 | 110 | mpbl = 111 | Regression.mean_pinball_loss(y_true, y_pred, 112 | alpha: 0.5, 113 | sample_weights: sample_weight, 114 | axes: [0] 115 | ) 116 | 117 | assert_all_close(mpbl, expected_raw_values_weighted_tensor) 118 | 119 | mpbl = 120 | Regression.mean_pinball_loss(y_true, y_pred, 121 | alpha: 0.5, 122 | sample_weights: sample_weight 123 | ) 124 | 125 | assert_all_close(mpbl, Nx.tensor(0.225)) 126 | end 127 | end 128 | end 129 | -------------------------------------------------------------------------------- /test/scholar/metrics/similarity_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Metrics.SimilarityTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Metrics.Similarity 4 | doctest Similarity 5 | 6 | describe "jaccard/2" do 7 | test "returns similarity according to sklearn jaccard_score function" do 8 
| x = Nx.tensor([1, 2, 3, 5, 0]) 9 | y = Nx.tensor([1, 30, 4, 8, 9]) 10 | 11 | assert Similarity.jaccard(x, y) == Nx.tensor(1 / 9) 12 | end 13 | 14 | test "returns 100% of similarity" do 15 | x = Nx.tensor([1, 2, 3]) 16 | y = Nx.tensor([1, 2, 3]) 17 | 18 | assert Similarity.jaccard(x, y) == Nx.tensor(1.0) 19 | end 20 | 21 | test "returns 0% of similarity" do 22 | x = Nx.tensor([1, 2, 3]) 23 | y = Nx.tensor([4, 5, 6]) 24 | 25 | assert Similarity.jaccard(x, y) == Nx.tensor(0.0) 26 | end 27 | 28 | test "returns 20% of similarity" do 29 | x = Nx.tensor([1, 2, 3]) 30 | y = Nx.tensor([3, 4, 5]) 31 | 32 | assert Similarity.jaccard(x, y) == Nx.tensor(0.20) 33 | end 34 | 35 | test "returns similarity when tensors have a single element" do 36 | x = Nx.tensor([1]) 37 | y = Nx.tensor([2]) 38 | 39 | assert Similarity.jaccard(x, y) == Nx.tensor(0.0) 40 | end 41 | 42 | test "returns similarity when tensor has multiple dimensions" do 43 | x = Nx.tensor([[0, 1, 1], [1, 1, 0]]) 44 | y = Nx.tensor([[1, 1, 1], [1, 0, 0]]) 45 | 46 | assert Similarity.jaccard(x, y) == Nx.tensor(1.0) 47 | end 48 | end 49 | 50 | describe "binary_jaccard/2" do 51 | test "returns similarity according to sklearn jaccard_score function" do 52 | x = Nx.tensor([1, 0, 0, 1, 1, 1]) 53 | y = Nx.tensor([0, 0, 1, 1, 1, 0]) 54 | 55 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(0.4) 56 | end 57 | 58 | test "returns 100% of similarity" do 59 | x = Nx.tensor([1, 0, 1]) 60 | y = Nx.tensor([1, 0, 1]) 61 | 62 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(1.0) 63 | end 64 | 65 | test "returns 0% of similarity" do 66 | x = Nx.tensor([1, 1, 1]) 67 | y = Nx.tensor([0, 0, 0]) 68 | 69 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(0.0) 70 | end 71 | 72 | test "returns 20% of similarity" do 73 | x = Nx.tensor([1, 0, 1, 0, 1]) 74 | y = Nx.tensor([0, 1, 1, 1, 0]) 75 | 76 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(0.2) 77 | end 78 | 79 | test "returns similarity when tensors have a single element" do 80 | x = Nx.tensor([1]) 81 | y = Nx.tensor([1]) 82 | 83 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(1.0) 84 | end 85 | 86 | test "returns similarity when tensors have scalars" do 87 | x = Nx.tensor(1) 88 | y = Nx.tensor(0) 89 | 90 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(0.0) 91 | end 92 | 93 | test "returns similarity when tensor has multiple dimensions" do 94 | x = Nx.tensor([[0, 1, 1], [1, 1, 1]]) 95 | y = Nx.tensor([[1, 1, 1], [1, 1, 1]]) 96 | 97 | assert Similarity.binary_jaccard(x, y) == Nx.tensor(5 / 6) 98 | end 99 | 100 | test "raises exception when tensors have different shapes" do 101 | x = Nx.tensor([1, 1, 0]) 102 | y = Nx.tensor([1, 0]) 103 | 104 | assert_raise ArgumentError, 105 | "expected tensor to have shape {3}, got tensor with shape {2}", 106 | fn -> 107 | Similarity.binary_jaccard(x, y) 108 | end 109 | end 110 | end 111 | end 112 | -------------------------------------------------------------------------------- /test/scholar/model_selection_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.ModelSelectionTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.ModelSelection 4 | doctest ModelSelection 5 | end 6 | -------------------------------------------------------------------------------- /test/scholar/neighbors/brute_knn_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.BruteKNNTest do 2 | use ExUnit.Case, async: true 3 | alias 
Scholar.Neighbors.BruteKNN 4 | doctest BruteKNN 5 | 6 | defp data do 7 | Nx.tensor([ 8 | [10, 15], 9 | [46, 63], 10 | [68, 21], 11 | [40, 33], 12 | [25, 54], 13 | [15, 43], 14 | [44, 58], 15 | [45, 40], 16 | [62, 69], 17 | [53, 67] 18 | ]) 19 | end 20 | 21 | defp query do 22 | Nx.tensor([ 23 | [12, 23], 24 | [55, 30], 25 | [41, 57], 26 | [64, 72], 27 | [26, 39] 28 | ]) 29 | end 30 | 31 | defp result do 32 | neighbor_indices = 33 | Nx.tensor( 34 | [ 35 | [0, 5, 3], 36 | [7, 3, 2], 37 | [6, 1, 9], 38 | [8, 9, 1], 39 | [5, 4, 3] 40 | ], 41 | type: :u64 42 | ) 43 | 44 | neighbor_distances = 45 | Nx.tensor([ 46 | [8.246211051940918, 20.2237491607666, 29.73213768005371], 47 | [14.142135620117188, 15.29705810546875, 15.81138801574707], 48 | [3.1622776985168457, 7.8102498054504395, 15.620499610900879], 49 | [3.605551242828369, 12.083045959472656, 20.124610900878906], 50 | [11.704699516296387, 15.033296585083008, 15.231546401977539] 51 | ]) 52 | 53 | {neighbor_indices, neighbor_distances} 54 | end 55 | 56 | describe "fit" do 57 | test "default" do 58 | data = data() 59 | k = 3 60 | model = BruteKNN.fit(data, num_neighbors: k) 61 | assert model.num_neighbors == 3 62 | assert model.data == data 63 | assert model.batch_size == nil 64 | end 65 | 66 | test "custom metric and batch_size" do 67 | data = data() 68 | k = 3 69 | metric = &Scholar.Metrics.Distance.minkowski/2 70 | batch_size = 2 71 | model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size) 72 | assert model.num_neighbors == k 73 | assert model.metric == metric 74 | assert model.data == data 75 | assert model.batch_size == batch_size 76 | end 77 | end 78 | 79 | describe "predict" do 80 | test "batch_size = 1" do 81 | query = query() 82 | k = 3 83 | model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1) 84 | {neighbors_true, distances_true} = result() 85 | {neighbors_pred, distances_pred} = BruteKNN.predict(model, query) 86 | assert neighbors_pred == neighbors_true 87 | assert distances_pred == distances_true 88 | end 89 | 90 | test "batch_size = 2" do 91 | query = query() 92 | k = 3 93 | model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2) 94 | {neighbors_true, distances_true} = result() 95 | {neighbors_pred, distances_pred} = BruteKNN.predict(model, query) 96 | assert neighbors_pred == neighbors_true 97 | assert distances_pred == distances_true 98 | end 99 | 100 | test "batch_size = 5" do 101 | query = query() 102 | k = 3 103 | model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5) 104 | {neighbors_true, distances_true} = result() 105 | {neighbors_pred, distances_pred} = BruteKNN.predict(model, query) 106 | assert neighbors_pred == neighbors_true 107 | assert distances_pred == distances_true 108 | end 109 | 110 | test "batch_size = 10" do 111 | query = query() 112 | k = 3 113 | model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10) 114 | {neighbors_true, distances_true} = result() 115 | {neighbors_pred, distances_pred} = BruteKNN.predict(model, query) 116 | 117 | assert neighbors_pred == 118 | neighbors_true 119 | 120 | assert distances_pred == distances_true 121 | end 122 | 123 | test "custom metric" do 124 | model = BruteKNN.fit(data(), num_neighbors: 3, batch_size: 1, metric: :cosine) 125 | assert {_, _} = BruteKNN.predict(model, query()) 126 | end 127 | end 128 | end 129 | -------------------------------------------------------------------------------- /test/scholar/neighbors/kd_tree_test.exs: -------------------------------------------------------------------------------- 1 | 
defmodule Scholar.Neighbors.KDTreeTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Neighbors.KDTree 4 | doctest KDTree 5 | 6 | defp example do 7 | Nx.tensor([ 8 | [10, 15], 9 | [46, 63], 10 | [68, 21], 11 | [40, 33], 12 | [25, 54], 13 | [15, 43], 14 | [44, 58], 15 | [45, 40], 16 | [62, 69], 17 | [53, 67] 18 | ]) 19 | end 20 | 21 | describe "fit" do 22 | test "iota" do 23 | tree = KDTree.fit(Nx.iota({5, 2})) 24 | assert tree.levels == 3 25 | assert tree.indices == Nx.u32([3, 1, 4, 0, 2]) 26 | assert tree.num_neighbors == 3 27 | end 28 | 29 | test "float" do 30 | tree = KDTree.fit(Nx.as_type(example(), :f32)) 31 | assert tree.levels == 4 32 | assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] 33 | assert tree.num_neighbors == 3 34 | end 35 | 36 | test "sample" do 37 | tree = KDTree.fit(example()) 38 | assert tree.levels == 4 39 | assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] 40 | assert tree.num_neighbors == 3 41 | end 42 | end 43 | 44 | defp x do 45 | Nx.tensor([ 46 | [3, 6, 7, 5], 47 | [9, 8, 5, 4], 48 | [4, 4, 4, 1], 49 | [9, 4, 5, 6], 50 | [6, 4, 5, 7], 51 | [4, 5, 3, 3], 52 | [4, 5, 7, 8], 53 | [9, 4, 4, 5], 54 | [8, 4, 3, 9], 55 | [2, 8, 4, 4] 56 | ]) 57 | end 58 | 59 | defp x_pred do 60 | Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) 61 | end 62 | 63 | describe "predict knn" do 64 | test "all defaults" do 65 | kdtree = KDTree.fit(x()) 66 | {indices, distances} = KDTree.predict(kdtree, x_pred()) 67 | 68 | assert indices == Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]], type: :s64) 69 | 70 | assert_all_close( 71 | distances, 72 | Nx.tensor([ 73 | [3.464101552963257, 4.582575798034668, 4.795831680297852], 74 | [4.242640495300293, 4.690415859222412, 4.795831680297852], 75 | [3.7416574954986572, 5.5677642822265625, 6.0], 76 | [3.872983455657959, 3.872983455657959, 6.164413928985596] 77 | ]) 78 | ) 79 | end 80 | 81 | test "metric set to {:minkowski, 1.5}" do 82 | kdtree = KDTree.fit(x(), metric: {:minkowski, 1.5}) 83 | {indices, distances} = KDTree.predict(kdtree, x_pred()) 84 | 85 | assert indices == Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]], type: :s64) 86 | 87 | assert_all_close( 88 | distances, 89 | Nx.tensor([ 90 | [4.065119743347168, 5.191402435302734, 5.862917423248291], 91 | [5.198591709136963, 5.591182708740234, 5.869683265686035], 92 | [4.334622859954834, 6.35192346572876, 6.9637274742126465], 93 | [4.649191856384277, 4.649191856384277, 7.664907932281494] 94 | ]) 95 | ) 96 | end 97 | 98 | test "k set to 4" do 99 | kdtree = KDTree.fit(x(), num_neighbors: 4) 100 | {indices, distances} = KDTree.predict(kdtree, x_pred()) 101 | 102 | assert indices == 103 | Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]], type: :s64) 104 | 105 | assert_all_close( 106 | distances, 107 | Nx.tensor([ 108 | [3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303], 109 | [4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145], 110 | [3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176], 111 | [3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333] 112 | ]) 113 | ) 114 | end 115 | 116 | test "float type data" do 117 | kdtree = KDTree.fit(x() |> Nx.as_type(:f64), num_neighbors: 4) 118 | {indices, distances} = KDTree.predict(kdtree, x_pred()) 119 | 120 | assert indices == 121 | Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]], type: :s64) 122 | 123 | assert_all_close( 124 | distances, 125 | Nx.tensor([ 
126 | [3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303], 127 | [4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145], 128 | [3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176], 129 | [3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333] 130 | ]) 131 | ) 132 | end 133 | end 134 | end 135 | -------------------------------------------------------------------------------- /test/scholar/neighbors/knn_regressor_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.KNNRegressorTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Neighbors.KNNRegressor 4 | doctest KNNRegressor 5 | 6 | defp x_train do 7 | Nx.tensor([ 8 | [3, 6, 7, 5], 9 | [9, 8, 5, 4], 10 | [4, 4, 4, 1], 11 | [9, 4, 5, 6], 12 | [6, 4, 5, 7], 13 | [4, 5, 3, 3], 14 | [4, 5, 7, 8], 15 | [9, 4, 4, 5], 16 | [8, 4, 3, 9], 17 | [2, 8, 4, 4] 18 | ]) 19 | end 20 | 21 | defp y_train do 22 | Nx.tensor([[0], [1], [1], [1], [1], [1], [1], [1], [0], [0]]) 23 | end 24 | 25 | defp x do 26 | Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) 27 | end 28 | 29 | describe "fit" do 30 | test "fit with default parameters" do 31 | model = KNNRegressor.fit(x_train(), y_train(), num_neighbors: 3) 32 | 33 | assert model.algorithm == Scholar.Neighbors.BruteKNN.fit(x_train(), num_neighbors: 3) 34 | assert model.labels == y_train() 35 | assert model.weights == :uniform 36 | end 37 | 38 | test "fit with k-d tree" do 39 | model = KNNRegressor.fit(x_train(), y_train(), algorithm: :kd_tree, num_neighbors: 3) 40 | 41 | assert model.algorithm == Scholar.Neighbors.KDTree.fit(x_train(), num_neighbors: 3) 42 | assert model.labels == y_train() 43 | assert model.weights == :uniform 44 | end 45 | 46 | test "fit with random projection forest" do 47 | key = Nx.Random.key(12) 48 | 49 | model = 50 | KNNRegressor.fit(x_train(), y_train(), 51 | algorithm: :random_projection_forest, 52 | num_neighbors: 3, 53 | num_trees: 4, 54 | key: key 55 | ) 56 | 57 | assert model.algorithm == 58 | Scholar.Neighbors.RandomProjectionForest.fit(x_train(), 59 | num_neighbors: 3, 60 | num_trees: 4, 61 | key: key 62 | ) 63 | 64 | assert model.labels == y_train() 65 | assert model.weights == :uniform 66 | end 67 | end 68 | 69 | describe "predict" do 70 | test "predict with default parameters" do 71 | model = KNNRegressor.fit(x_train(), y_train(), num_neighbors: 3) 72 | y_pred = KNNRegressor.predict(model, x()) 73 | assert_all_close(y_pred, Nx.tensor([[0.66666667], [0.66666667], [0.33333333], [1.0]])) 74 | end 75 | 76 | test "predict with weights set to :distance" do 77 | model = KNNRegressor.fit(x_train(), y_train(), num_neighbors: 3, weights: :distance) 78 | y_pred = KNNRegressor.predict(model, x()) 79 | assert_all_close(y_pred, Nx.tensor([[0.59648849], [0.68282796], [0.2716506], [1.0]])) 80 | end 81 | 82 | test "predict with cosine metric and weights set to :distance" do 83 | model = 84 | KNNRegressor.fit(x_train(), y_train(), 85 | num_neighbors: 3, 86 | metric: :cosine, 87 | weights: :distance 88 | ) 89 | 90 | y_pred = KNNRegressor.predict(model, x()) 91 | assert_all_close(y_pred, Nx.tensor([[0.5736568], [0.427104], [0.33561941], [1.0]])) 92 | end 93 | 94 | test "predict with 2D labels" do 95 | y = 96 | Nx.tensor([ 97 | [1, 4], 98 | [0, 3], 99 | [2, 5], 100 | [0, 3], 101 | [0, 3], 102 | [1, 4], 103 | [2, 5], 104 | [0, 3], 105 | [1, 4], 106 | [2, 5] 107 | ]) 108 | 109 | model = KNNRegressor.fit(x_train(), y, 
num_neighbors: 3) 110 | y_pred = KNNRegressor.predict(model, x()) 111 | 112 | assert_all_close( 113 | y_pred, 114 | Nx.tensor([ 115 | [1.0, 4.0], 116 | [1.6666666269302368, 4.666666507720947], 117 | [1.6666666269302368, 4.666666507720947], 118 | [1.0, 4.0] 119 | ]) 120 | ) 121 | end 122 | 123 | test "predict with 2D labels, cosine metric and weights set to :distance" do 124 | y = 125 | Nx.tensor([ 126 | [1, 4], 127 | [0, 3], 128 | [2, 5], 129 | [0, 3], 130 | [0, 3], 131 | [1, 4], 132 | [2, 5], 133 | [0, 3], 134 | [1, 4], 135 | [2, 5] 136 | ]) 137 | 138 | model = 139 | KNNRegressor.fit(x_train(), y, num_neighbors: 3, metric: :cosine, weights: :distance) 140 | 141 | y_pred = KNNRegressor.predict(model, x()) 142 | 143 | assert_all_close( 144 | y_pred, 145 | Nx.tensor([ 146 | [1.11344606, 4.11344606], 147 | [1.3915288, 4.3915288], 148 | [1.53710155, 4.53710155], 149 | [0.0, 3.0] 150 | ]) 151 | ) 152 | end 153 | end 154 | end 155 | -------------------------------------------------------------------------------- /test/scholar/neighbors/large_vis_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.LargeVisTest do 2 | use ExUnit.Case, async: true 3 | alias Scholar.Neighbors.LargeVis 4 | doctest LargeVis 5 | end 6 | -------------------------------------------------------------------------------- /test/scholar/neighbors/nn_descent_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.NNDescentTest do 2 | use ExUnit.Case, async: true 3 | alias Scholar.Neighbors.NNDescent 4 | doctest NNDescent 5 | 6 | test "every point is its own neighbor when num_neighbors is 1" do 7 | key = Nx.Random.key(12) 8 | {tensor, key} = Nx.Random.uniform(key, shape: {10, 5}) 9 | size = Nx.axis_size(tensor, 0) 10 | 11 | %NNDescent{nearest_neighbors: nearest_neighbors, distances: distances} = 12 | NNDescent.fit(tensor, 13 | num_neighbors: 1, 14 | key: key 15 | ) 16 | 17 | assert Nx.flatten(nearest_neighbors) == Nx.iota({size}, type: :s64) 18 | assert Nx.flatten(distances) == Nx.broadcast(0.0, {size}) 19 | end 20 | end 21 | -------------------------------------------------------------------------------- /test/scholar/neighbors/random_projection_forest_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.RandomProjectionForestTest do 2 | use ExUnit.Case, async: true 3 | alias Scholar.Neighbors.RandomProjectionForest 4 | doctest RandomProjectionForest 5 | 6 | defp example do 7 | Nx.tensor([ 8 | [10, 15], 9 | [46, 63], 10 | [68, 21], 11 | [40, 33], 12 | [25, 54], 13 | [15, 43], 14 | [44, 58], 15 | [45, 40], 16 | [62, 69], 17 | [53, 67] 18 | ]) 19 | end 20 | 21 | describe "fit" do 22 | test "shape" do 23 | tensor = example() 24 | 25 | forest = 26 | RandomProjectionForest.fit(tensor, num_neighbors: 2, num_trees: 4, min_leaf_size: 3) 27 | 28 | assert forest.num_neighbors == 2 29 | assert forest.depth == 1 30 | assert forest.leaf_size == 5 31 | assert forest.num_trees == 4 32 | assert forest.indices.shape == {4, 10} 33 | assert forest.data.shape == {10, 2} 34 | assert forest.hyperplanes.shape == {4, 1, 2} 35 | assert forest.medians.shape == {4, 1} 36 | end 37 | end 38 | 39 | describe "predict" do 40 | test "shape" do 41 | tensor = example() 42 | 43 | forest = 44 | RandomProjectionForest.fit(tensor, num_neighbors: 2, num_trees: 4, min_leaf_size: 3) 45 | 46 | {neighbor_indices, neighbor_distances} = 47 | 
RandomProjectionForest.predict(forest, Nx.tensor([[20, 30], [30, 50]])) 48 | 49 | assert Nx.shape(neighbor_indices) == {2, 2} 50 | assert Nx.shape(neighbor_distances) == {2, 2} 51 | end 52 | 53 | test "every point is its own neighbor when num_neighbors is 1" do 54 | key = Nx.Random.key(12) 55 | {tensor, key} = Nx.Random.uniform(key, shape: {1000, 10}) 56 | size = Nx.axis_size(tensor, 0) 57 | 58 | forest = 59 | RandomProjectionForest.fit(tensor, 60 | num_neighbors: 1, 61 | num_trees: 1, 62 | min_leaf_size: 1, 63 | key: key 64 | ) 65 | 66 | {neighbors, distances} = RandomProjectionForest.predict(forest, tensor) 67 | assert Nx.flatten(neighbors) == Nx.iota({size}, type: :u32) 68 | assert Nx.flatten(distances) == Nx.broadcast(0.0, {size}) 69 | end 70 | 71 | test "every point is its own neighbor when num_neighbors is 1 and size is power of two" do 72 | key = Nx.Random.key(12) 73 | {tensor, key} = Nx.Random.uniform(key, shape: {1024, 10}) 74 | size = Nx.axis_size(tensor, 0) 75 | 76 | forest = 77 | RandomProjectionForest.fit(tensor, 78 | num_neighbors: 1, 79 | num_trees: 1, 80 | min_leaf_size: 1, 81 | key: key 82 | ) 83 | 84 | {neighbors, distances} = RandomProjectionForest.predict(forest, tensor) 85 | assert Nx.flatten(neighbors) == Nx.iota({size}, type: :u32) 86 | assert Nx.flatten(distances) == Nx.broadcast(0.0, {size}) 87 | end 88 | end 89 | end 90 | -------------------------------------------------------------------------------- /test/scholar/neighbors/rnn_regressor_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Neighbors.RadiusNNRegressorTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Neighbors.RadiusNNRegressor 4 | doctest RadiusNNRegressor 5 | 6 | defp x do 7 | Nx.tensor([ 8 | [3, 6, 7, 5], 9 | [9, 8, 5, 4], 10 | [4, 4, 4, 1], 11 | [9, 4, 5, 6], 12 | [6, 4, 5, 7], 13 | [4, 5, 3, 3], 14 | [4, 5, 7, 8], 15 | [9, 4, 4, 5], 16 | [8, 4, 3, 9], 17 | [2, 8, 4, 4] 18 | ]) 19 | end 20 | 21 | defp y do 22 | Nx.tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 0]) 23 | end 24 | 25 | defp x_pred do 26 | Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) 27 | end 28 | 29 | describe "predict" do 30 | test "predict with weights set to :distance" do 31 | model = 32 | RadiusNNRegressor.fit(x(), y(), 33 | num_classes: 2, 34 | radius: 10, 35 | weights: :distance 36 | ) 37 | 38 | predictions = RadiusNNRegressor.predict(model, x_pred()) 39 | assert_all_close(predictions, Nx.tensor([0.69033845, 0.71773642, 0.68217609, 0.75918273])) 40 | end 41 | 42 | test "predict with weights set to :distance and with specific metric" do 43 | model = 44 | RadiusNNRegressor.fit(x(), y(), 45 | num_classes: 2, 46 | radius: 10, 47 | weights: :distance, 48 | metric: :cosine 49 | ) 50 | 51 | predictions = RadiusNNRegressor.predict(model, x_pred()) 52 | assert_all_close(predictions, Nx.tensor([0.683947, 0.54694187, 0.59806132, 0.86398641])) 53 | end 54 | 55 | test "predict with weights set to :distance and with specific metric and 2d labels" do 56 | y = 57 | Nx.tensor([ 58 | [1, 4], 59 | [0, 3], 60 | [2, 5], 61 | [0, 3], 62 | [0, 3], 63 | [1, 4], 64 | [2, 5], 65 | [0, 3], 66 | [1, 4], 67 | [2, 5] 68 | ]) 69 | 70 | model = 71 | RadiusNNRegressor.fit(x(), y, 72 | num_classes: 3, 73 | radius: 10, 74 | weights: :distance, 75 | metric: :cosine 76 | ) 77 | 78 | predictions = RadiusNNRegressor.predict(model, x_pred()) 79 | 80 | assert_all_close( 81 | predictions, 82 | Nx.tensor([ 83 | [0.99475077, 3.99475077], 84 | [1.20828527, 4.20828527], 85 | [1.15227075, 
4.15227075], 86 | [0.37743229, 3.37743229] 87 | ]) 88 | ) 89 | end 90 | end 91 | 92 | describe "errors" do 93 | test "wrong shape of x" do 94 | x = Nx.tensor([1, 2, 3, 4, 5]) 95 | y = Nx.tensor([1, 2, 3, 4, 5]) 96 | 97 | assert_raise ArgumentError, 98 | "expected input tensor to have shape {n_samples, n_features} or {num_samples, num_samples}, 99 | got tensor with shape: {5}", 100 | fn -> 101 | RadiusNNRegressor.fit(x, y, num_classes: 5) 102 | end 103 | end 104 | 105 | test "wrong shape of y" do 106 | x = Nx.tensor([[1], [2], [3], [4], [5], [6]]) 107 | y = Nx.tensor([[[1, 2, 3, 4, 5]]]) 108 | 109 | assert_raise ArgumentError, 110 | "expected labels to have shape {num_samples} or {num_samples, num_outputs}, 111 | got tensor with shape: {1, 1, 5}", 112 | fn -> 113 | RadiusNNRegressor.fit(x, y, num_classes: 5) 114 | end 115 | end 116 | 117 | test "incompatible shapes of x and y" do 118 | x = Nx.tensor([[1], [2], [3], [4], [5], [6]]) 119 | y = Nx.tensor([1, 2, 3, 4, 5]) 120 | 121 | assert_raise ArgumentError, 122 | "expected labels to have the same size of the first axis as data, 123 | got: 6 != 5", 124 | fn -> 125 | RadiusNNRegressor.fit(x, y, num_classes: 5) 126 | end 127 | end 128 | end 129 | end 130 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/binarizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.BinarizerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.Binarizer 4 | doctest Binarizer 5 | 6 | describe "binarization" do 7 | test "binarize with positive threshold" do 8 | tensor = Nx.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [-2.0, -1.0, 0.0]]) 9 | 10 | jit_binarizer = Nx.Defn.jit(&Binarizer.fit_transform/2) 11 | 12 | result = jit_binarizer.(tensor, threshold: 2.0) 13 | 14 | assert Nx.to_flat_list(result) == [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0] 15 | end 16 | 17 | test "binarize values with default threshold" do 18 | tensor = Nx.tensor([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0], [-2.0, 1.0, 0.0]]) 19 | 20 | result = Binarizer.fit_transform(tensor) 21 | 22 | assert Nx.to_flat_list(result) == [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0] 23 | end 24 | 25 | test "binarize with threshold less than 0" do 26 | tensor = Nx.tensor([[0.0, 0.5, -0.5], [-0.1, -0.2, -0.3]]) 27 | jit_binarizer = Nx.Defn.jit(&Binarizer.fit_transform/2) 28 | 29 | result = jit_binarizer.(tensor, threshold: -0.2) 30 | 31 | assert Nx.to_flat_list(result) == [1.0, 1.0, 0.0, 1.0, 0.0, 0.0] 32 | end 33 | end 34 | end 35 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/max_abs_scaler_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.MaxAbsScalerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.MaxAbsScaler 4 | 5 | doctest MaxAbsScaler 6 | 7 | describe "fit_transform/2" do 8 | test "set axes to [0]" do 9 | data = Nx.tensor([[1, -1, 2], [3, 0, 0], [0, 1, -1], [2, 3, 1]]) 10 | 11 | expected = 12 | Nx.tensor([ 13 | [0.3333333432674408, -0.3333333432674408, 1.0], 14 | [1.0, 0.0, 0.0], 15 | [0.0, 0.3333333432674408, -0.5], 16 | [0.6666666865348816, 1.0, 0.5] 17 | ]) 18 | 19 | assert_all_close(MaxAbsScaler.fit_transform(data, axes: [0]), expected) 20 | end 21 | 22 | test "works when the tensor contains only zeros" do 23 | data = Nx.broadcast(Nx.f32(0), {3, 3}) 24 | expected = data 25 | assert 
MaxAbsScaler.fit_transform(data) == expected 26 | end 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/min_max_scaler_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.MinMaxScalerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.MinMaxScaler 4 | 5 | doctest MinMaxScaler 6 | 7 | describe "fit_transform/2" do 8 | test "set axes to [0]" do 9 | data = Nx.tensor([[1, -1, 2], [3, 0, 0], [0, 1, -1], [2, 3, 1]]) 10 | 11 | expected = 12 | Nx.tensor([ 13 | [0.3333333432674408, 0.0, 1.0], 14 | [1.0, 0.25, 0.3333333432674408], 15 | [0.0, 0.5, 0.0], 16 | [0.6666666865348816, 1.0, 0.6666666865348816] 17 | ]) 18 | 19 | assert_all_close(MinMaxScaler.fit_transform(data, axes: [0]), expected) 20 | end 21 | 22 | test "set axes to [0], min_bound to 1, and max_bound to 3" do 23 | data = Nx.tensor([[1, -1, 2], [3, 0, 0], [0, 1, -1], [2, 3, 1]]) 24 | 25 | expected = 26 | Nx.tensor([ 27 | [1.6666667461395264, 1.0, 3.0], 28 | [3.0, 1.5, 1.6666667461395264], 29 | [1.0, 2.0, 1.0], 30 | [2.3333334922790527, 3.0, 2.3333334922790527] 31 | ]) 32 | 33 | assert_all_close( 34 | MinMaxScaler.fit_transform(data, axes: [0], min_bound: 1, max_bound: 3), 35 | expected 36 | ) 37 | end 38 | 39 | test "works when the tensor contains only zeros" do 40 | data = Nx.broadcast(Nx.f32(0), {3, 3}) 41 | expected = data 42 | assert MinMaxScaler.fit_transform(data) == expected 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/normalizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.NormalizerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.Normalizer 4 | 5 | doctest Normalizer 6 | end 7 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/one_hot_encoder_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.OneHotEncoderTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.OneHotEncoder 4 | 5 | doctest OneHotEncoder 6 | end 7 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/ordinal_encoder_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.OrdinalEncoderTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.OrdinalEncoder 4 | 5 | doctest OrdinalEncoder 6 | end 7 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/robust_scaler_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.RobustScalerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.RobustScaler 4 | doctest RobustScaler 5 | 6 | describe "fit_transform" do 7 | test "applies scaling to data" do 8 | data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 9 | 10 | expected = 11 | Nx.tensor([ 12 | [0.0, -1.0, 1.3333333333333333], 13 | [1.0, 0.0, 0.0], 14 | [-1.0, 1.0, -0.6666666666666666] 15 | ]) 16 | 17 | assert_all_close(RobustScaler.fit_transform(data), expected) 18 | end 19 | 20 | test "applies scaling to data with custom quantile range" do 21 | data = 
Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 22 | 23 | expected = 24 | Nx.tensor([ 25 | [0.0, -0.7142857142857142, 1.0], 26 | [0.7142857142857142, 0.0, 0.0], 27 | [-0.7142857142857142, 0.7142857142857142, -0.5] 28 | ]) 29 | 30 | assert_all_close( 31 | RobustScaler.fit_transform(data, quantile_range: {10, 80}), 32 | expected 33 | ) 34 | end 35 | 36 | test "handles constant data (all values the same)" do 37 | data = Nx.tensor([[5, 5, 5], [5, 5, 5], [5, 5, 5]]) 38 | expected = Nx.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 39 | 40 | assert_all_close(RobustScaler.fit_transform(data), expected) 41 | end 42 | 43 | test "handles already scaled data" do 44 | data = Nx.tensor([[0, -1, 1], [1, 0, 0], [-1, 1, -1]]) 45 | expected = data 46 | 47 | assert_all_close(RobustScaler.fit_transform(data), expected) 48 | end 49 | 50 | test "handles single-row tensor" do 51 | data = Nx.tensor([[1, 2, 3]]) 52 | expected = Nx.tensor([[0.0, 0.0, 0.0]]) 53 | 54 | assert_all_close(RobustScaler.fit_transform(data), expected) 55 | end 56 | 57 | test "handles single-column tensor" do 58 | data = Nx.tensor([[1], [2], [3]]) 59 | expected = Nx.tensor([[-1.0], [0.0], [1.0]]) 60 | 61 | assert_all_close(RobustScaler.fit_transform(data), expected) 62 | end 63 | 64 | test "handles data with negative values only" do 65 | data = Nx.tensor([[-5, -10, -15], [-15, -5, -20], [-10, -15, -5]]) 66 | 67 | expected = 68 | Nx.tensor([ 69 | [1.0, 0.0, 0.0], 70 | [-1.0, 1.0, -0.6666666666666666], 71 | [0.0, -1.0, 1.3333333333333333] 72 | ]) 73 | 74 | assert_all_close(RobustScaler.fit_transform(data), expected) 75 | end 76 | 77 | test "handles data with extreme outliers" do 78 | data = Nx.tensor([[1, 2, 3], [1000, 2000, 3000], [-1000, -2000, -3000]]) 79 | 80 | expected = 81 | Nx.tensor([[0.0, 0.0, 0.0], [0.999, 0.999, 0.999], [-1.001, -1.001, -1.001]]) 82 | 83 | assert_all_close( 84 | RobustScaler.fit_transform(data), 85 | expected 86 | ) 87 | end 88 | end 89 | 90 | describe "errors" do 91 | test "wrong input rank for fit" do 92 | assert_raise ArgumentError, 93 | "expected tensor to have shape {num_samples, num_features}, got tensor with shape: {1, 1, 1}", 94 | fn -> 95 | RobustScaler.fit(Nx.tensor([[[1]]])) 96 | end 97 | end 98 | 99 | test "wrong input rank for transform" do 100 | assert_raise ArgumentError, 101 | "expected tensor to have shape {num_samples, num_features}, got tensor with shape: {1, 1, 1}", 102 | fn -> 103 | RobustScaler.fit(Nx.tensor([[1]])) 104 | |> RobustScaler.transform(Nx.tensor([[[1]]])) 105 | end 106 | end 107 | 108 | test "wrong quantile range" do 109 | assert_raise NimbleOptions.ValidationError, 110 | "invalid value for :quantile_range option: expected :quantile_range to be a tuple {q_min, q_max} such that 0.0 < q_min < q_max < 100.0, got: {10, 800}", 111 | fn -> 112 | RobustScaler.fit(Nx.tensor([[[1]]]), quantile_range: {10, 800}) 113 | end 114 | end 115 | end 116 | end 117 | -------------------------------------------------------------------------------- /test/scholar/preprocessing/standard_scaler_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.Preprocessing.StandardScalerTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing.StandardScaler 4 | 5 | doctest StandardScaler 6 | 7 | describe "fit_transform/2" do 8 | test "applies standard scaling to data" do 9 | data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 10 | 11 | expected = 12 | Nx.tensor([ 13 | [0.5212860703468323, -1.3553436994552612, 
1.4596009254455566], 14 | [1.4596009254455566, -0.4170288145542145, -0.4170288145542145], 15 | [-0.4170288145542145, 0.5212860703468323, -1.3553436994552612] 16 | ]) 17 | 18 | assert_all_close(StandardScaler.fit_transform(data), expected) 19 | end 20 | 21 | test "leaves the data unchanged when the variance is zero" do 22 | data = 42.0 23 | expected = Nx.tensor(data) 24 | assert StandardScaler.fit_transform(data) == expected 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /test/scholar/preprocessing_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.PreprocessingTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Preprocessing 4 | doctest Preprocessing 5 | 6 | describe "standard_scale/2" do 7 | test "applies standard scaling to data" do 8 | data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]) 9 | 10 | expected = 11 | Nx.tensor([ 12 | [0.5212860703468323, -1.3553436994552612, 1.4596009254455566], 13 | [1.4596009254455566, -0.4170288145542145, -0.4170288145542145], 14 | [-0.4170288145542145, 0.5212860703468323, -1.3553436994552612] 15 | ]) 16 | 17 | assert_all_close(Preprocessing.standard_scale(data), expected) 18 | end 19 | 20 | test "leaves the data unchanged when the variance is zero" do 21 | data = 42.0 22 | expected = Nx.tensor(data) 23 | assert Preprocessing.standard_scale(data) == expected 24 | end 25 | end 26 | end 27 | -------------------------------------------------------------------------------- /test/scholar/stats_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Scholar.StatsTest do 2 | use Scholar.Case, async: true 3 | alias Scholar.Stats 4 | doctest Stats 5 | 6 | defp x do 7 | Nx.tensor([[3, 5, 3], [2, 6, 1], [9, 3, 2], [1, 6, 8]]) 8 | end 9 | 10 | describe "moment" do 11 | test "moment set to 3" do 12 | expected_moment = Nx.tensor([29.53125, -1.5, 18.0]) 13 | assert_all_close(Stats.moment(x(), 3), expected_moment) 14 | end 15 | 16 | test "moment set to 3 and axis set to 1" do 17 | expected_moment = Nx.tensor([0.5925924181938171, 6.0, 19.25926399230957, -12.0]) 18 | assert_all_close(Stats.moment(x(), 3, axes: [1]), expected_moment) 19 | end 20 | 21 | test "moment set to 3 and axis set to 1 and keep_axes set to true" do 22 | expected_moment = Nx.tensor([[0.5925924181938171], [6.0], [19.25926399230957], [-12.0]]) 23 | assert_all_close(Stats.moment(x(), 3, axes: [1], keep_axes: true), expected_moment) 24 | end 25 | 26 | test "moment set to 3 and axis set to nil" do 27 | expected_moment = Nx.tensor(9.438657407407414) 28 | assert_all_close(Stats.moment(x(), 3, axes: nil), expected_moment) 29 | end 30 | 31 | test "moment set to 3 and axis set to nil and keep_axes set to true" do 32 | expected_moment = Nx.tensor([[9.438657407407414]]) 33 | assert_all_close(Stats.moment(x(), 3, axes: nil, keep_axes: true), expected_moment) 34 | end 35 | 36 | test "moment set to 3 and axis set to [0, 1] and keep_axes set to true" do 37 | expected_moment = Nx.tensor([[9.438657407407414]]) 38 | assert_all_close(Stats.moment(x(), 3, axes: [0, 1], keep_axes: true), expected_moment) 39 | end 40 | end 41 | 42 | describe "skew" do 43 | test "all defaults" do 44 | expected_skew = Nx.tensor([0.97940938, -0.81649658, 0.9220734]) 45 | assert_all_close(Stats.skew(x()), expected_skew) 46 | end 47 | 48 | test "axis set to 1" do 49 | expected_skew = Nx.tensor([0.70710678, 0.59517006, 0.65201212, -0.47033046]) 50 | assert_all_close(Stats.skew(x(), 
axes: [1]), expected_skew) 51 | end 52 | 53 | test "axis set to 1 and keep_axes set to true" do 54 | expected_skew = Nx.tensor([[0.70710678], [0.59517006], [0.65201212], [-0.47033046]]) 55 | assert_all_close(Stats.skew(x(), axes: [1], keep_axes: true), expected_skew) 56 | end 57 | 58 | test "axis set to nil" do 59 | expected_skew = Nx.tensor(0.5596660882003394) 60 | assert_all_close(Stats.skew(x(), axes: nil), expected_skew) 61 | end 62 | 63 | test "axis set to nil and keep_axes set to true" do 64 | expected_skew = Nx.tensor([[0.5596660882003394]]) 65 | assert_all_close(Stats.skew(x(), axes: nil, keep_axes: true), expected_skew) 66 | end 67 | 68 | test "axis set to [0, 1] and keep_axes set to true" do 69 | expected_skew = Nx.tensor([[0.5596660882003394]]) 70 | assert_all_close(Stats.skew(x(), axes: [0, 1], keep_axes: true), expected_skew) 71 | end 72 | 73 | test "axis set to [1] and bias set to false" do 74 | expected_skew = Nx.tensor([1.73205081, 1.45786297, 1.59709699, -1.15206964]) 75 | assert_all_close(Stats.skew(x(), axes: [1], bias: false), expected_skew) 76 | end 77 | end 78 | 79 | describe "kurtosis" do 80 | test "all defaults" do 81 | expected_kurtosis = Nx.tensor([-0.79808533, -1.0, -0.83947681]) 82 | assert_all_close(Stats.kurtosis(x()), expected_kurtosis) 83 | end 84 | 85 | test "axis set to 1" do 86 | expected_kurtosis = Nx.tensor([-1.5, -1.5, -1.5, -1.5]) 87 | assert_all_close(Stats.kurtosis(x(), axes: [1]), expected_kurtosis) 88 | end 89 | 90 | test "axis set to 1 and keep_axes set to true" do 91 | expected_kurtosis = Nx.tensor([[-1.5], [-1.5], [-1.5], [-1.5]]) 92 | assert_all_close(Stats.kurtosis(x(), axes: [1], keep_axes: true), expected_kurtosis) 93 | end 94 | 95 | test "axis set to nil" do 96 | expected_kurtosis = Nx.tensor(-0.9383737228328437) 97 | assert_all_close(Stats.kurtosis(x(), axes: nil), expected_kurtosis) 98 | end 99 | 100 | test "axis set to nil and keep_axes set to true" do 101 | expected_kurtosis = Nx.tensor([[-0.9383737228328437]]) 102 | assert_all_close(Stats.kurtosis(x(), axes: nil, keep_axes: true), expected_kurtosis) 103 | end 104 | 105 | test "axis set to [0, 1] and keep_axes set to true" do 106 | expected_kurtosis = Nx.tensor([[-0.9383737228328437]]) 107 | assert_all_close(Stats.kurtosis(x(), axes: [0, 1], keep_axes: true), expected_kurtosis) 108 | end 109 | 110 | test "axis set to nil and bias set to false" do 111 | expected_kurtosis = Nx.tensor(-0.757638248501074) 112 | assert_all_close(Stats.kurtosis(x(), axes: nil, bias: false), expected_kurtosis) 113 | end 114 | 115 | test "axis set to nil and bias set to false and variant set to pearson" do 116 | expected_kurtosis = Nx.tensor(2.242361751498926) 117 | 118 | assert_all_close( 119 | Stats.kurtosis(x(), axes: nil, bias: false, variant: :pearson), 120 | expected_kurtosis 121 | ) 122 | end 123 | end 124 | end 125 | -------------------------------------------------------------------------------- /test/support/diabetes_data_raw.csv: -------------------------------------------------------------------------------- 1 | 59 2 32.1 101.0 157 93.2 38.0 4.0 4.8598 87 2 | 48 1 21.6 87.0 183 103.2 70.0 3.0 3.8918 69 3 | 72 2 30.5 93.0 156 93.6 41.0 4.0 4.6728 85 4 | 24 1 25.3 84.0 198 131.4 40.0 5.0 4.8903 89 5 | 50 1 23.0 101.0 192 125.4 52.0 4.0 4.2905 80 6 | 23 1 22.6 89.0 139 64.8 61.0 2.0 4.1897 68 7 | 36 2 22.0 90.0 160 99.6 50.0 3.0 3.9512 82 8 | 66 2 26.2 114.0 255 185.0 56.0 4.55 4.2485 92 9 | 60 2 32.1 83.0 179 119.4 42.0 4.0 4.4773 94 10 | 29 1 30.0 85.0 180 93.4 43.0 4.0 5.3845 88 11 | 22 1 18.6 97.0 
114 57.6 46.0 2.0 3.9512 83 12 | 56 2 28.0 85.0 184 144.8 32.0 6.0 3.5835 77 13 | 53 1 23.7 92.0 186 109.2 62.0 3.0 4.3041 81 14 | 50 2 26.2 97.0 186 105.4 49.0 4.0 5.0626 88 15 | 61 1 24.0 91.0 202 115.4 72.0 3.0 4.2905 73 16 | 34 2 24.7 118.0 254 184.2 39.0 7.0 5.037 81 17 | 47 1 30.3 109.0 207 100.2 70.0 3.0 5.2149 98 18 | 68 2 27.5 111.0 214 147.0 39.0 5.0 4.9416 91 19 | 38 1 25.4 84.0 162 103.0 42.0 4.0 4.4427 87 20 | 41 1 24.7 83.0 187 108.2 60.0 3.0 4.5433 78 21 | 35 1 21.1 82.0 156 87.8 50.0 3.0 4.5109 95 22 | 25 2 24.3 95.0 162 98.6 54.0 3.0 3.8501 87 23 | 25 1 26.0 92.0 187 120.4 56.0 3.0 3.9703 88 24 | 61 2 32.0 103.67 210 85.2 35.0 6.0 6.107 124 25 | 31 1 29.7 88.0 167 103.4 48.0 4.0 4.3567 78 26 | 30 2 25.2 83.0 178 118.4 34.0 5.0 4.852 83 27 | 19 1 19.2 87.0 124 54.0 57.0 2.0 4.1744 90 28 | 42 1 31.9 83.0 158 87.6 53.0 3.0 4.4659 101 29 | 63 1 24.4 73.0 160 91.4 48.0 3.0 4.6347 78 30 | 67 2 25.8 113.0 158 54.2 64.0 2.0 5.2933 104 31 | 32 1 30.5 89.0 182 110.6 56.0 3.0 4.3438 89 32 | 42 1 20.3 71.0 161 81.2 66.0 2.0 4.2341 81 33 | 58 2 38.0 103.0 150 107.2 22.0 7.0 4.6444 98 34 | 57 1 21.7 94.0 157 58.0 82.0 2.0 4.4427 92 35 | 53 1 20.5 78.0 147 84.2 52.0 3.0 3.989 75 36 | 62 2 23.5 80.33 225 112.8 86.0 2.62 4.8752 96 37 | 52 1 28.5 110.0 195 97.2 60.0 3.0 5.2417 85 38 | 46 1 27.4 78.0 171 88.0 58.0 3.0 4.8283 90 39 | 48 2 33.0 123.0 253 163.6 44.0 6.0 5.425 97 40 | 48 2 27.7 73.0 191 119.4 46.0 4.0 4.852 92 41 | 50 2 25.6 101.0 229 162.2 43.0 5.0 4.7791 114 42 | 21 1 20.1 63.0 135 69.0 54.0 3.0 4.0943 89 43 | 32 2 25.4 90.33 153 100.4 34.0 4.5 4.5326 83 44 | 54 1 24.2 74.0 204 109.0 82.0 2.0 4.1744 109 45 | 61 2 32.7 97.0 177 118.4 29.0 6.0 4.9972 87 46 | 56 2 23.1 104.0 181 116.4 47.0 4.0 4.4773 79 47 | 33 1 25.3 85.0 155 85.0 51.0 3.0 4.5539 70 48 | 27 1 19.6 78.0 128 68.0 43.0 3.0 4.4427 71 49 | 67 2 22.5 98.0 191 119.2 61.0 3.0 3.989 86 50 | 37 2 27.7 93.0 180 119.4 30.0 6.0 5.0304 88 -------------------------------------------------------------------------------- /test/support/diabetes_target.csv: -------------------------------------------------------------------------------- 1 | 1.510000000000000000e+02 2 | 7.500000000000000000e+01 3 | 1.410000000000000000e+02 4 | 2.060000000000000000e+02 5 | 1.350000000000000000e+02 6 | 9.700000000000000000e+01 7 | 1.380000000000000000e+02 8 | 6.300000000000000000e+01 9 | 1.100000000000000000e+02 10 | 3.100000000000000000e+02 11 | 1.010000000000000000e+02 12 | 6.900000000000000000e+01 13 | 1.790000000000000000e+02 14 | 1.850000000000000000e+02 15 | 1.180000000000000000e+02 16 | 1.710000000000000000e+02 17 | 1.660000000000000000e+02 18 | 1.440000000000000000e+02 19 | 9.700000000000000000e+01 20 | 1.680000000000000000e+02 21 | 6.800000000000000000e+01 22 | 4.900000000000000000e+01 23 | 6.800000000000000000e+01 24 | 2.450000000000000000e+02 25 | 1.840000000000000000e+02 26 | 2.020000000000000000e+02 27 | 1.370000000000000000e+02 28 | 8.500000000000000000e+01 29 | 1.310000000000000000e+02 30 | 2.830000000000000000e+02 31 | 1.290000000000000000e+02 32 | 5.900000000000000000e+01 33 | 3.410000000000000000e+02 34 | 8.700000000000000000e+01 35 | 6.500000000000000000e+01 36 | 1.020000000000000000e+02 37 | 2.650000000000000000e+02 38 | 2.760000000000000000e+02 39 | 2.520000000000000000e+02 40 | 9.000000000000000000e+01 41 | 1.000000000000000000e+02 42 | 5.500000000000000000e+01 43 | 6.100000000000000000e+01 44 | 9.200000000000000000e+01 45 | 2.590000000000000000e+02 46 | 5.300000000000000000e+01 47 | 1.900000000000000000e+02 48 | 
1.420000000000000000e+02 49 | 7.500000000000000000e+01 50 | 1.420000000000000000e+02 -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | Application.ensure_all_started(:exla) 2 | --------------------------------------------------------------------------------