├── .github ├── .gitignore └── workflows │ ├── pkgdown.yaml │ ├── test-coverage.yaml │ ├── rhub.yaml │ └── R-CMD-check.yaml ├── _pkgdown.yml ├── tests ├── testthat │ ├── .gitignore │ ├── _snaps │ │ ├── pretraining.md │ │ └── hardhat_interfaces.md │ ├── test_translations.R │ ├── test-dials.R │ ├── helper-tensor.R │ ├── setup.R │ ├── test-loss.R │ ├── test-plot.R │ ├── test-hardhat_scenarios.R │ ├── test-missing_values.R │ ├── test-hardhat_interfaces.R │ ├── test-explain.R │ ├── test-mask-type.R │ ├── test-hardhat_multi-outcome.R │ ├── test-hardhat_hierarchical.R │ ├── test-pretraining.R │ ├── test-model.R │ └── test-parsnip.R ├── testthat.R └── spelling.R ├── vignettes ├── .gitignore ├── ames_fit.png ├── ames_fit_vip.png ├── ames_fit_vip_.png ├── ames_pretrain.png ├── vanillia_model.png ├── vis_miss_ames.png ├── ames_mas_vnr_hist.png ├── ames_missing_fit.png ├── ames_pretrain_vip.png ├── pretrained_model.png ├── pretraining_loss.png ├── ames_missing_fit_vip.png ├── ames_pretrain_vip_.png ├── ames_pretrain_vip__.png ├── ames_missing_pretrain.png ├── ames_missing_pretrain_vip.png ├── ames_missing_pretrain_vip_.png ├── tidymodels-interface.Rmd ├── interpretation.Rmd └── aum_loss.Rmd ├── LICENSE ├── inst ├── po │ └── fr │ │ └── LC_MESSAGES │ │ └── R-tabnet.mo └── WORDLIST ├── man ├── figures │ ├── README-model-fit-1.png │ ├── README-model-explain-1.png │ ├── README-step-explain-1.png │ └── README-step-pretrain-1.png ├── pipe.Rd ├── get_tau.Rd ├── nn_aum_loss.Rd ├── check_compliant_node.Rd ├── node_to_df.Rd ├── sparsemax.Rd ├── nn_prune_head.Rd ├── min_grid.tabnet.Rd ├── tabnet_non_tunable.Rd ├── autoplot.tabnet_fit.Rd ├── entmax15.Rd ├── tabnet_explain.Rd ├── tabnet_params.Rd ├── autoplot.tabnet_explain.Rd ├── tabnet_nn.Rd ├── tabnet_pretrain.Rd ├── tabnet_fit.Rd ├── tabnet_config.Rd └── tabnet.Rd ├── cran-comments.md ├── .gitignore ├── .Rbuildignore ├── R ├── utils-pipe.R ├── package.R ├── explain.R ├── loss.R ├── dials.R ├── plot.R └── utils.R ├── LICENSE.md ├── DESCRIPTION ├── NAMESPACE ├── NEWS.md ├── po └── R-tabnet.pot ├── README.Rmd └── README.md /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | template: 2 | bootstrap: 5 -------------------------------------------------------------------------------- /tests/testthat/.gitignore: -------------------------------------------------------------------------------- 1 | Rplots.pdf 2 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | *_files 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2020 2 | COPYRIGHT HOLDER: RStudio, PBC 3 | -------------------------------------------------------------------------------- /vignettes/ames_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_fit.png -------------------------------------------------------------------------------- /vignettes/ames_fit_vip.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_fit_vip.png -------------------------------------------------------------------------------- /vignettes/ames_fit_vip_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_fit_vip_.png -------------------------------------------------------------------------------- /vignettes/ames_pretrain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_pretrain.png -------------------------------------------------------------------------------- /vignettes/vanillia_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/vanillia_model.png -------------------------------------------------------------------------------- /vignettes/vis_miss_ames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/vis_miss_ames.png -------------------------------------------------------------------------------- /vignettes/ames_mas_vnr_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_mas_vnr_hist.png -------------------------------------------------------------------------------- /vignettes/ames_missing_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_missing_fit.png -------------------------------------------------------------------------------- /vignettes/ames_pretrain_vip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_pretrain_vip.png -------------------------------------------------------------------------------- /vignettes/pretrained_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/pretrained_model.png -------------------------------------------------------------------------------- /vignettes/pretraining_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/pretraining_loss.png -------------------------------------------------------------------------------- /inst/po/fr/LC_MESSAGES/R-tabnet.mo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/inst/po/fr/LC_MESSAGES/R-tabnet.mo -------------------------------------------------------------------------------- /man/figures/README-model-fit-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/man/figures/README-model-fit-1.png -------------------------------------------------------------------------------- /vignettes/ames_missing_fit_vip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_missing_fit_vip.png -------------------------------------------------------------------------------- 
/vignettes/ames_pretrain_vip_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_pretrain_vip_.png -------------------------------------------------------------------------------- /vignettes/ames_pretrain_vip__.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_pretrain_vip__.png -------------------------------------------------------------------------------- /vignettes/ames_missing_pretrain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_missing_pretrain.png -------------------------------------------------------------------------------- /man/figures/README-model-explain-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/man/figures/README-model-explain-1.png -------------------------------------------------------------------------------- /man/figures/README-step-explain-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/man/figures/README-step-explain-1.png -------------------------------------------------------------------------------- /man/figures/README-step-pretrain-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/man/figures/README-step-pretrain-1.png -------------------------------------------------------------------------------- /vignettes/ames_missing_pretrain_vip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_missing_pretrain_vip.png -------------------------------------------------------------------------------- /vignettes/ames_missing_pretrain_vip_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/tabnet/HEAD/vignettes/ames_missing_pretrain_vip_.png -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(tabnet) 3 | 4 | if (Sys.getenv("TORCH_TEST", unset = 0) == 1) 5 | test_check("tabnet") 6 | -------------------------------------------------------------------------------- /tests/spelling.R: -------------------------------------------------------------------------------- 1 | if(requireNamespace('spelling', quietly = TRUE)) 2 | spelling::spell_check_test(vignettes = TRUE, error = FALSE, 3 | skip_on_cran = TRUE) 4 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## R CMD check results 2 | 3 | 0 errors | 0 warnings | 1 note 4 | 5 | * This is a new release. 6 | 7 | Note 1 : Example duration is inherent to fitting a model with the underlying torch framework. 
8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | docs 3 | inst/doc 4 | .Rhistory 5 | .venv 6 | activate 7 | .V8history 8 | /doc/ 9 | /Meta/ 10 | revdep 11 | revdep/ 12 | tabnet.Rcheck 13 | ..Rcheck 14 | tabnet_*.tar.gz 15 | tabnet.Rproj 16 | po/glossary.csv 17 | inst/IMPORTLIST 18 | -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^tabnet\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^README\.Rmd$ 4 | ^\.github$ 5 | ^LICENSE\.md$ 6 | ^_pkgdown\.yml$ 7 | ^docs$ 8 | ^pkgdown$ 9 | ^cran-comments\.md$ 10 | ^CRAN-RELEASE$ 11 | ^.V8* 12 | ^doc$ 13 | ^Meta$ 14 | ^CRAN-SUBMISSION$ 15 | ^revdep$ 16 | ^vignettes/*_files$ 17 | -------------------------------------------------------------------------------- /R/utils-pipe.R: -------------------------------------------------------------------------------- 1 | #' Pipe operator 2 | #' 3 | #' See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details. 4 | #' 5 | #' @name %>% 6 | #' @rdname pipe 7 | #' @keywords internal 8 | #' @export 9 | #' @importFrom magrittr %>% 10 | #' @importFrom zeallot %<-% 11 | #' @usage lhs \%>\% rhs 12 | #' 13 | #' @return Returns `rhs(lhs)`. 14 | NULL 15 | -------------------------------------------------------------------------------- /man/pipe.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils-pipe.R 3 | \name{\%>\%} 4 | \alias{\%>\%} 5 | \title{Pipe operator} 6 | \usage{ 7 | lhs \%>\% rhs 8 | } 9 | \value{ 10 | Returns \code{rhs(lhs)}. 11 | } 12 | \description{ 13 | See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details. 14 | } 15 | \keyword{internal} 16 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/pretraining.md: -------------------------------------------------------------------------------- 1 | # print module works 2 | 3 | An `nn_module` containing 13,190 parameters. 
4 | 5 | -- Modules --------------------------------------- 6 | * initial_bn: #146 parameters 7 | * embedder: #283 parameters 8 | * embedder_na: #0 parameters 9 | * masker: #0 parameters 10 | * encoder: #10,304 parameters 11 | * decoder: #2,456 parameters 12 | 13 | -- Parameters ------------------------------------ 14 | * .check: Float [1:1] 15 | 16 | -------------------------------------------------------------------------------- /inst/WORDLIST: -------------------------------------------------------------------------------- 1 | AUM 2 | Ames 3 | Arik 4 | Bugfixes 5 | Eleonora 6 | Explicitely 7 | FNR 8 | FPR 9 | GLU 10 | Giunchiglia 11 | Hillman 12 | Interpretability 13 | Interpretable 14 | Lifecycle 15 | MNAR 16 | OOM 17 | Pfister 18 | Pretrain 19 | Sercan 20 | Sparsemax 21 | TabNet 22 | TabNet's 23 | Tsallis 24 | XGBoost 25 | ai 26 | al 27 | ames 28 | arXiv 29 | autoassociative 30 | autograd 31 | beeing 32 | callout 33 | classif 34 | cli 35 | config 36 | cpu 37 | cuda 38 | dataloading 39 | detailled 40 | doi 41 | dreamquark 42 | entmax 43 | et 44 | explainability 45 | ggplot 46 | interpretable 47 | mse 48 | nn 49 | num 50 | orginal 51 | overfit 52 | overfits 53 | pre 54 | pretrain 55 | pretrained 56 | pretraining 57 | reusage 58 | softmax 59 | sparsemax 60 | subprocesses 61 | th 62 | tibble 63 | tidymodels 64 | tunable 65 | zeallot 66 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/hardhat_interfaces.md: -------------------------------------------------------------------------------- 1 | # print module works even after a reload from disk 2 | 3 | An `nn_module` containing 10,742 parameters. 4 | 5 | -- Modules --------------------------------------- 6 | * embedder: #283 parameters 7 | * embedder_na: #0 parameters 8 | * tabnet: #10,458 parameters 9 | 10 | -- Parameters ------------------------------------ 11 | * .check: Float [1:1] 12 | 13 | --- 14 | 15 | An `nn_module` containing 10,742 parameters. 16 | 17 | -- Modules --------------------------------------- 18 | * embedder: #283 parameters 19 | * embedder_na: #0 parameters 20 | * tabnet: #10,458 parameters 21 | 22 | -- Parameters ------------------------------------ 23 | * .check: Float [1:1] 24 | 25 | -------------------------------------------------------------------------------- /man/get_tau.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/mask-type.R 3 | \name{get_tau} 4 | \alias{get_tau} 5 | \title{Optimal threshold (tau) computation for 1.5-entmax} 6 | \usage{ 7 | get_tau(input, dim = -1L, k = NULL) 8 | } 9 | \arguments{ 10 | \item{input}{The input tensor to compute thresholds over.} 11 | 12 | \item{dim}{The dimension along which to apply 1.5-entmax. Default is -1.} 13 | 14 | \item{k}{The number of largest elements to partial-sort over. For optimal 15 | performance, should be slightly bigger than the expected number of 16 | non-zeros in the solution. If the solution is more than k-sparse, 17 | this function is recursively called with a 2*k schedule. If \code{NULL}, 18 | full sorting is performed from the beginning. Default is NULL.} 19 | } 20 | \value{ 21 | The threshold value for each vector, with all but the \code{dim} 22 | dimension intact. 
23 | } 24 | \description{ 25 | Optimal threshold (tau) computation for 1.5-entmax 26 | } 27 | -------------------------------------------------------------------------------- /man/nn_aum_loss.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/loss.R 3 | \name{nn_aum_loss} 4 | \alias{nn_aum_loss} 5 | \title{AUM loss} 6 | \usage{ 7 | nn_aum_loss() 8 | } 9 | \description{ 10 | Creates a criterion that measures the Area under the \eqn{Min(FPR, FNR)} (AUM) between each 11 | element in the input \eqn{pred_tensor} and target \eqn{label_tensor}. 12 | } 13 | \details{ 14 | This is used for measuring the error of a binary reconstruction within highly unbalanced dataset, 15 | where the goal is optimizing the ROC curve. Note that the targets \eqn{label_tensor} should be factor 16 | level of the binary outcome, i.e. with values \code{1L} and \code{2L}. 17 | } 18 | \examples{ 19 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 20 | loss <- nn_aum_loss() 21 | input <- torch::torch_randn(4, 6, requires_grad = TRUE) 22 | target <- input > 1.5 23 | output <- loss(input, target) 24 | output$backward() 25 | \dontshow{\}) # examplesIf} 26 | } 27 | -------------------------------------------------------------------------------- /R/package.R: -------------------------------------------------------------------------------- 1 | .onLoad <- function(...) { 2 | vctrs::s3_register("parsnip::multi_predict", "_tabnet_fit") 3 | vctrs::s3_register("vip::vi_model", "tabnet_fit") 4 | vctrs::s3_register("vip::vi_model", "tabnet_pretrain") 5 | vctrs::s3_register("ggplot2::autoplot", "tabnet_fit") 6 | vctrs::s3_register("ggplot2::autoplot", "tabnet_pretrain") 7 | vctrs::s3_register("ggplot2::autoplot", "tabnet_explain") 8 | vctrs::s3_register("torch::nn_prune_head", "tabnet_fit") 9 | vctrs::s3_register("torch::nn_prune_head", "tabnet_pretrain") 10 | vctrs::s3_register("tune::min_grid", "tabnet") 11 | } 12 | 13 | 14 | globalVariables(c("batch_size", 15 | "checkpoint", 16 | "dataset", 17 | "epoch", 18 | "has_checkpoint", 19 | "loss", 20 | "mask_agg", 21 | "mean_loss", 22 | "row_number", 23 | "rowname", 24 | "step", 25 | "value", 26 | "variable", 27 | "..")) 28 | -------------------------------------------------------------------------------- /man/check_compliant_node.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils.R 3 | \name{check_compliant_node} 4 | \alias{check_compliant_node} 5 | \title{Check that Node object names are compliant} 6 | \usage{ 7 | check_compliant_node(node) 8 | } 9 | \arguments{ 10 | \item{node}{the Node object, or a dataframe ready to be parsed by \code{data.tree::as.Node()}} 11 | } 12 | \value{ 13 | node if it is compliant, else an Error with the column names to fix 14 | } 15 | \description{ 16 | Check that Node object names are compliant 17 | } 18 | \examples{ 19 | \dontshow{if ((require("data.tree") || require("dplyr"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 20 | library(dplyr) 21 | library(data.tree) 22 | data(starwars) 23 | starwars_tree <- starwars \%>\% 24 | mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/")) 25 | 26 | # pre as.Node() check 27 | try(check_compliant_node(starwars_tree)) 28 | 29 | # post as.Node() check 30 | 
check_compliant_node(as.Node(starwars_tree)) 31 | \dontshow{\}) # examplesIf} 32 | } 33 | -------------------------------------------------------------------------------- /man/node_to_df.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils.R 3 | \name{node_to_df} 4 | \alias{node_to_df} 5 | \title{Turn a Node object into predictor and outcome.} 6 | \usage{ 7 | node_to_df(x, drop_last_level = TRUE) 8 | } 9 | \arguments{ 10 | \item{x}{Node object} 11 | 12 | \item{drop_last_level}{TRUE unused} 13 | } 14 | \value{ 15 | a named list of x and y, being respectively the predictor data-frame and the outcomes data-frame, 16 | as expected inputs for \code{hardhat::mold()} function. 17 | } 18 | \description{ 19 | Turn a Node object into predictor and outcome. 20 | } 21 | \examples{ 22 | \dontshow{if ((require("data.tree") || require("dplyr"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 23 | library(dplyr) 24 | library(data.tree) 25 | data(starwars) 26 | starwars_tree <- starwars \%>\% 27 | mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/")) \%>\% 28 | as.Node() 29 | node_to_df(starwars_tree)$x \%>\% head() 30 | node_to_df(starwars_tree)$y \%>\% head() 31 | \dontshow{\}) # examplesIf} 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2020 RStudio, PBC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /man/sparsemax.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/mask-type.R 3 | \name{sparsemax} 4 | \alias{sparsemax} 5 | \alias{sparsemax15} 6 | \title{Sparsemax} 7 | \usage{ 8 | sparsemax(dim = -1L) 9 | 10 | sparsemax15(dim = -1L, k = NULL) 11 | } 12 | \arguments{ 13 | \item{dim}{The dimension along which to apply sparsemax.} 14 | 15 | \item{k}{The number of largest elements to partial-sort input over. For optimal 16 | performance, \code{k} should be slightly bigger than the expected number of 17 | non-zeros in the solution. 
If the solution is more than k-sparse, 18 | this function is recursively called with a 2*k schedule. If \code{NULL}, full 19 | sorting is performed from the beginning.} 20 | } 21 | \value{ 22 | The projection result, such that \eqn{\sum_{dim} P = 1 \forall dim} elementwise. 23 | } 24 | \description{ 25 | Normalizing sparse transform (a la softmax). 26 | } 27 | \details{ 28 | Solves the projection: 29 | 30 | \eqn{\min_P ||input - P||_2 \text{ s.t. } P \geq0, \sum(P) ==1} 31 | } 32 | \examples{ 33 | input <- torch::torch_randn(10, 5, requires_grad = TRUE) 34 | # create a top3 alpha=1.5 sparsemax on last input dimension 35 | nn_sparsemax <- sparsemax15(dim=1, k=3) 36 | result <- nn_sparsemax(input) 37 | print(result) 38 | } 39 | -------------------------------------------------------------------------------- /man/nn_prune_head.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/hardhat.R 3 | \name{nn_prune_head.tabnet_fit} 4 | \alias{nn_prune_head.tabnet_fit} 5 | \alias{nn_prune_head.tabnet_pretrain} 6 | \title{Prune top layer(s) of a tabnet network} 7 | \usage{ 8 | \method{nn_prune_head}{tabnet_fit}(x, head_size) 9 | 10 | \method{nn_prune_head}{tabnet_pretrain}(x, head_size) 11 | } 12 | \arguments{ 13 | \item{x}{nn_network to prune} 14 | 15 | \item{head_size}{number of nn_layers to prune, should be less than 2} 16 | } 17 | \value{ 18 | a tabnet network with the top nn_layer removed 19 | } 20 | \description{ 21 | Prune \code{head_size} last layers of a tabnet network in order to 22 | use the pruned module as a sequential embedding module. 23 | } 24 | \examples{ 25 | \dontshow{if ((torch::torch_is_installed())) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 26 | data("ames", package = "modeldata") 27 | x <- ames[,-which(names(ames) == "Sale_Price")] 28 | y <- ames$Sale_Price 29 | # pretrain a tabnet model on ames dataset 30 | ames_pretrain <- tabnet_pretrain(x, y, epoch = 2, checkpoint_epochs = 1) 31 | # prune classification head to get an embedding model 32 | pruned_pretrain <- torch::nn_prune_head(ames_pretrain, 1) 33 | \dontshow{\}) # examplesIf} 34 | } 35 | -------------------------------------------------------------------------------- /tests/testthat/test_translations.R: -------------------------------------------------------------------------------- 1 | test_that("early stopping message get translated in french", { 2 | # skip on linux on ci due to missing language in image 3 | testthat::skip_if((testthat:::on_ci() && testthat:::system_os() == "linux")) 4 | testthat::skip_on_cran() 5 | withr::with_language(lang = "fr", 6 | expect_error( 7 | tabnet_fit(attrix, attriy, epochs = 200, verbose=TRUE, 8 | early_stopping_monitor="cross_validation_loss", 9 | early_stopping_tolerance=1e-7, early_stopping_patience=3, learn_rate = 0.2), 10 | regexp = "n'est pas une m" 11 | ) 12 | ) 13 | }) 14 | 15 | test_that("scheduler message translated in french", { 16 | # skip on linux on ci due to missing language in image 17 | testthat::skip_if((testthat:::on_ci() && testthat:::system_os() == "linux")) 18 | testthat::skip_on_cran() 19 | withr::with_language(lang = "fr", 20 | expect_error( 21 | fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "multiplicative", 22 | lr_decay = 0.1, step_size = 1), 23 | regexp = "Seule les planifications \"step\" et" 24 | ) 25 | ) 26 | }) 27 | -------------------------------------------------------------------------------- 
/man/min_grid.tabnet.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/parsnip.R 3 | \name{min_grid.tabnet} 4 | \alias{min_grid.tabnet} 5 | \title{Determine the minimum set of model fits} 6 | \usage{ 7 | \method{min_grid}{tabnet}(x, grid, ...) 8 | } 9 | \arguments{ 10 | \item{x}{A model specification.} 11 | 12 | \item{grid}{A tibble with tuning parameter combinations.} 13 | 14 | \item{...}{Not currently used.} 15 | } 16 | \value{ 17 | A tibble with the minimum tuning parameters to fit and an additional 18 | list column with the parameter combinations used for prediction. 19 | } 20 | \description{ 21 | \code{min_grid()} determines exactly what models should be fit in order to 22 | evaluate the entire set of tuning parameter combinations. This is for 23 | internal use only and the API may change in the near future. 24 | } 25 | \details{ 26 | \code{fit_max_value()} can be used in other packages to implement a \code{min_grid()} 27 | method. 28 | } 29 | \examples{ 30 | library(dials) 31 | library(tune) 32 | library(parsnip) 33 | 34 | tabnet_spec <- tabnet(decision_width = tune(), attention_width = tune()) \%>\% 35 | set_mode("regression") \%>\% 36 | set_engine("torch") 37 | 38 | tabnet_grid <- 39 | tabnet_spec \%>\% 40 | extract_parameter_set_dials() \%>\% 41 | grid_regular(levels = 3) 42 | 43 | min_grid(tabnet_spec, tabnet_grid) 44 | 45 | } 46 | \keyword{internal} 47 | -------------------------------------------------------------------------------- /man/tabnet_non_tunable.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/dials.R 3 | \name{cat_emb_dim} 4 | \alias{cat_emb_dim} 5 | \alias{checkpoint_epochs} 6 | \alias{drop_last} 7 | \alias{encoder_activation} 8 | \alias{lr_scheduler} 9 | \alias{mlp_activation} 10 | \alias{mlp_hidden_multiplier} 11 | \alias{num_independent_decoder} 12 | \alias{num_shared_decoder} 13 | \alias{optimizer} 14 | \alias{penalty} 15 | \alias{verbose} 16 | \alias{virtual_batch_size} 17 | \title{Non-tunable parameters for the tabnet model} 18 | \usage{ 19 | cat_emb_dim(range = NULL, trans = NULL) 20 | 21 | checkpoint_epochs(range = NULL, trans = NULL) 22 | 23 | drop_last(range = NULL, trans = NULL) 24 | 25 | encoder_activation(range = NULL, trans = NULL) 26 | 27 | lr_scheduler(range = NULL, trans = NULL) 28 | 29 | mlp_activation(range = NULL, trans = NULL) 30 | 31 | mlp_hidden_multiplier(range = NULL, trans = NULL) 32 | 33 | num_independent_decoder(range = NULL, trans = NULL) 34 | 35 | num_shared_decoder(range = NULL, trans = NULL) 36 | 37 | optimizer(range = NULL, trans = NULL) 38 | 39 | penalty(range = NULL, trans = NULL) 40 | 41 | verbose(range = NULL, trans = NULL) 42 | 43 | virtual_batch_size(range = NULL, trans = NULL) 44 | } 45 | \arguments{ 46 | \item{range}{unused} 47 | 48 | \item{trans}{unused} 49 | } 50 | \description{ 51 | Non-tunable parameters for the tabnet model 52 | } 53 | -------------------------------------------------------------------------------- /man/autoplot.tabnet_fit.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot.R 3 | \name{autoplot.tabnet_fit} 4 | \alias{autoplot.tabnet_fit} 5 | \alias{autoplot.tabnet_pretrain} 6 | \title{Plot tabnet_fit model loss along epochs} 7 | \usage{ 8 | 
autoplot.tabnet_fit(object, ...) 9 | 10 | autoplot.tabnet_pretrain(object, ...) 11 | } 12 | \arguments{ 13 | \item{object}{A \code{tabnet_fit} or \code{tabnet_pretrain} object as a result of 14 | \code{\link[=tabnet_fit]{tabnet_fit()}} or \code{\link[=tabnet_pretrain]{tabnet_pretrain()}}.} 15 | 16 | \item{...}{not used.} 17 | } 18 | \value{ 19 | A \code{ggplot} object. 20 | } 21 | \description{ 22 | Plot tabnet_fit model loss along epochs 23 | } 24 | \details{ 25 | Plot the training loss along epochs, and validation loss along epochs if any. 26 | A dot is added on epochs where model snapshot is available, helping 27 | the choice of \code{from_epoch} value for later model training resume. 28 | } 29 | \examples{ 30 | \dontshow{if ((torch::torch_is_installed() && require("modeldata"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 31 | \dontrun{ 32 | library(ggplot2) 33 | data("attrition", package = "modeldata") 34 | attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, valid_split=0.2, epoch=11) 35 | 36 | # Plot the model loss over epochs 37 | autoplot(attrition_fit) 38 | } 39 | \dontshow{\}) # examplesIf} 40 | } 41 | -------------------------------------------------------------------------------- /man/entmax15.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/mask-type.R 3 | \name{entmax} 4 | \alias{entmax} 5 | \alias{entmax15} 6 | \title{Alpha-entmax} 7 | \usage{ 8 | entmax(dim = -1) 9 | 10 | entmax15(dim = -1L, k = NULL) 11 | } 12 | \arguments{ 13 | \item{dim}{The dimension along which to apply 1.5-entmax.} 14 | 15 | \item{k}{The number of largest elements to partial-sort input over. For optimal 16 | performance, should be slightly bigger than the expected number of 17 | non-zeros in the solution. If the solution is more than k-sparse, 18 | this function is recursively called with a 2*k schedule. If \code{NULL}, full 19 | sorting is performed from the beginning.} 20 | } 21 | \value{ 22 | The projection result P of the same shape as input, such that 23 | \eqn{\sum_{dim} P = 1 \forall dim} elementwise. 24 | } 25 | \description{ 26 | With alpha = 1.5 and normalizing sparse transform (a la softmax). 27 | } 28 | \details{ 29 | Solves the optimization problem: 30 | \eqn{\max_p - H_{1.5}(P) \text{ s.t. } P \geq 0, \sum(P) == 1} 31 | where \eqn{H_{1.5}(P)} is the Tsallis alpha-entropy with \eqn{\alpha=1.5}. 32 | } 33 | \examples{ 34 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 35 | \dontrun{ 36 | input <- torch::torch_randn(10,5, requires_grad = TRUE) 37 | # create a top3 alpha=1.5 entmax on last input dimension 38 | nn_entmax <- entmax15(dim=-1L, k = 3) 39 | result <- nn_entmax(input) 40 | } 41 | \dontshow{\}) # examplesIf} 42 | } 43 | -------------------------------------------------------------------------------- /tests/testthat/test-dials.R: -------------------------------------------------------------------------------- 1 | test_that("Check we can use hardhat:::extract_parameter_set_dials() with {dial} tune()ed parameter", { 2 | 3 | model <- tabnet(batch_size = tune(), learn_rate = tune(), epochs = tune(), 4 | momentum = tune(), penalty = tune(), rate_step_size = tune()) %>% 5 | parsnip::set_mode("regression") %>% 6 | parsnip::set_engine("torch") 7 | 8 | wf <- workflows::workflow() %>% 9 | workflows::add_model(model) %>% 10 | workflows::add_formula(Sale_Price ~ .) 
11 | 12 | expect_no_error( 13 | wf %>% hardhat::extract_parameter_set_dials() 14 | ) 15 | }) 16 | 17 | test_that("Check we can use hardhat:::extract_parameter_set_dials() with {tabnet} tune()ed parameter", { 18 | 19 | model <- tabnet(num_steps = tune(), num_shared = tune(), mask_type = tune(), 20 | feature_reusage = tune(), attention_width = tune()) %>% 21 | parsnip::set_mode("regression") %>% 22 | parsnip::set_engine("torch") 23 | 24 | wf <- workflows::workflow() %>% 25 | workflows::add_model(model) %>% 26 | workflows::add_formula(Sale_Price ~ .) 27 | 28 | expect_no_error( 29 | wf %>% hardhat::extract_parameter_set_dials() 30 | ) 31 | }) 32 | 33 | test_that("Check non supported tune()ed parameter raise an explicit error", { 34 | 35 | model <- tabnet(cat_emb_dim = tune(), checkpoint_epochs = 0) %>% 36 | parsnip::set_mode("regression") %>% 37 | parsnip::set_engine("torch") 38 | 39 | wf <- workflows::workflow() %>% 40 | workflows::add_model(model) %>% 41 | workflows::add_formula(Sale_Price ~ .) 42 | 43 | expect_error( 44 | wf %>% hardhat::extract_parameter_set_dials(), 45 | regexp = "cannot be used as a .* parameter yet" 46 | ) 47 | }) 48 | 49 | -------------------------------------------------------------------------------- /man/tabnet_explain.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/explain.R 3 | \name{tabnet_explain} 4 | \alias{tabnet_explain} 5 | \alias{tabnet_explain.default} 6 | \alias{tabnet_explain.tabnet_fit} 7 | \alias{tabnet_explain.tabnet_pretrain} 8 | \alias{tabnet_explain.model_fit} 9 | \title{Interpretation metrics from a TabNet model} 10 | \usage{ 11 | tabnet_explain(object, new_data) 12 | 13 | \method{tabnet_explain}{default}(object, new_data) 14 | 15 | \method{tabnet_explain}{tabnet_fit}(object, new_data) 16 | 17 | \method{tabnet_explain}{tabnet_pretrain}(object, new_data) 18 | 19 | \method{tabnet_explain}{model_fit}(object, new_data) 20 | } 21 | \arguments{ 22 | \item{object}{a TabNet fit object} 23 | 24 | \item{new_data}{a data.frame to obtain interpretation metrics.} 25 | } 26 | \value{ 27 | Returns a list with 28 | \itemize{ 29 | \item \code{M_explain}: the aggregated feature importance masks as detailed in 30 | TabNet's paper. 31 | \item \code{masks} a list containing the masks for each step. 
32 | } 33 | } 34 | \description{ 35 | Interpretation metrics from a TabNet model 36 | } 37 | \examples{ 38 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 39 | 40 | set.seed(2021) 41 | 42 | n <- 256 43 | x <- data.frame( 44 | x = rnorm(n), 45 | y = rnorm(n), 46 | z = rnorm(n) 47 | ) 48 | 49 | y <- x$x 50 | 51 | fit <- tabnet_fit(x, y, epochs = 10, 52 | num_steps = 1, 53 | batch_size = 512, 54 | attention_width = 1, 55 | num_shared = 1, 56 | num_independent = 1) 57 | 58 | 59 | ex <- tabnet_explain(fit, x) 60 | 61 | \dontshow{\}) # examplesIf} 62 | } 63 | -------------------------------------------------------------------------------- /man/tabnet_params.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/dials.R 3 | \name{attention_width} 4 | \alias{attention_width} 5 | \alias{decision_width} 6 | \alias{feature_reusage} 7 | \alias{momentum} 8 | \alias{mask_type} 9 | \alias{num_independent} 10 | \alias{num_shared} 11 | \alias{num_steps} 12 | \title{Parameters for the tabnet model} 13 | \usage{ 14 | attention_width(range = c(8L, 64L), trans = NULL) 15 | 16 | decision_width(range = c(8L, 64L), trans = NULL) 17 | 18 | feature_reusage(range = c(1, 2), trans = NULL) 19 | 20 | momentum(range = c(0.01, 0.4), trans = NULL) 21 | 22 | mask_type(values = c("sparsemax", "entmax")) 23 | 24 | num_independent(range = c(1L, 5L), trans = NULL) 25 | 26 | num_shared(range = c(1L, 5L), trans = NULL) 27 | 28 | num_steps(range = c(3L, 10L), trans = NULL) 29 | } 30 | \arguments{ 31 | \item{range}{the default range for the parameter value} 32 | 33 | \item{trans}{whether to apply a transformation to the parameter} 34 | 35 | \item{values}{possible values for factor parameters 36 | 37 | These functions are used with \code{tune} grid functions to generate 38 | candidates.} 39 | } 40 | \value{ 41 | A \code{dials} parameter to be used when tuning TabNet models. 42 | } 43 | \description{ 44 | Parameters for the tabnet model 45 | } 46 | \examples{ 47 | \dontshow{if ((require("dials") && require("parsnip") && torch::torch_is_installed())) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 48 | model <- tabnet(attention_width = tune(), feature_reusage = tune(), 49 | momentum = tune(), penalty = tune(), rate_step_size = tune()) \%>\% 50 | parsnip::set_mode("regression") \%>\% 51 | parsnip::set_engine("torch") 52 | \dontshow{\}) # examplesIf} 53 | } 54 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown 13 | 14 | jobs: 15 | pkgdown: 16 | runs-on: ubuntu-latest 17 | # Only restrict concurrency for non-PR jobs 18 | concurrency: 19 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 20 | env: 21 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 22 | TORCH_INSTALL: 1 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | 27 | - name: Set up Quarto 28 | uses: quarto-dev/quarto-actions/setup@v2 29 | with: 30 | tinytex: true 31 | env: 32 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 33 | 34 | - name: Setup Pandoc 35 | uses: r-lib/actions/setup-pandoc@v2 36 | 37 | - name: Setup R 38 | uses: r-lib/actions/setup-r@v2 39 | with: 40 | use-public-rspm: true 41 | 42 | - name: Setup R dependencies 43 | uses: r-lib/actions/setup-r-dependencies@v2 44 | with: 45 | extra-packages: any::pkgdown, local::. 46 | needs: website 47 | install-quarto: true 48 | 49 | - name: Build site 50 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = TRUE) 51 | shell: Rscript {0} 52 | 53 | - name: Deploy to GitHub pages 🚀 54 | if: github.event_name != 'pull_request' 55 | uses: JamesIves/github-pages-deploy-action@v4 56 | with: 57 | clean: false 58 | branch: gh-pages 59 | folder: docs 60 | -------------------------------------------------------------------------------- /tests/testthat/helper-tensor.R: -------------------------------------------------------------------------------- 1 | # as per https://github.com/mlverse/torch/blob/main/tests/testthat/helper-tensor.R 2 | Sys.setenv(KMP_DUPLICATE_LIB_OK = TRUE) 3 | # torch_zeros(1, names="hello") # trigger warning about named tensors 4 | 5 | skip_if_not_test_examples <- function() { 6 | if (Sys.getenv("TEST_EXAMPLES", unset = "0") != "1") { 7 | skip("Not testing examples/readme. Set the env var TEST_EXAMPLES = 1.") 8 | } 9 | } 10 | 11 | skip_if_cuda_not_available <- function() { 12 | if (!cuda_is_available()) { 13 | skip("A GPU is not available for testing.") 14 | } 15 | } 16 | 17 | skip_if_not_m1_mac <- function() { 18 | if (!grepl("darwin", R.version$os)) { 19 | skip("Not on MacOS") 20 | } 21 | 22 | if (R.version$arch != "aarch64") { 23 | skip("Not an M1 Mac") 24 | } 25 | } 26 | 27 | expect_equal_to_tensor <- function(object, expected, ...) { 28 | expect_equal(torch::as_array(object), torch::as_array(expected), ...) 29 | } 30 | 31 | expect_not_equal_to_tensor <- function(object, expected) { 32 | expect_false(isTRUE(all.equal(torch::as_array(object), torch::as_array(expected)))) 33 | } 34 | 35 | expect_no_error <- function(object, ...) { 36 | expect_error(object, NA, ...) 37 | } 38 | 39 | expect_tensor <- function(object) { 40 | expect_true(torch:::is_torch_tensor(object)) 41 | expect_no_error(torch::as_array(object$to(device = "cpu"))) 42 | } 43 | 44 | expect_equal_to_r <- function(object, expected, ...) { 45 | expect_equal(torch::as_array(object$cpu()), expected, ...) 
46 | } 47 | 48 | expect_tensor_shape <- function(object, expected) { 49 | expect_tensor(object) 50 | expect_equal(object$shape, expected) 51 | } 52 | 53 | expect_undefined_tensor <- function(object) { 54 | # TODO 55 | } 56 | 57 | expect_identical_modules <- function(object, expected) { 58 | expect_identical( 59 | attr(object, "module"), 60 | attr(expected, "module") 61 | ) 62 | } 63 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/master/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | 9 | name: test-coverage 10 | 11 | jobs: 12 | test-coverage: 13 | runs-on: ['self-hosted', 'gce', 'gpu'] 14 | 15 | container: 16 | image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04' 17 | options: '--gpus all --runtime=nvidia' 18 | 19 | timeout-minutes: 120 20 | 21 | env: 22 | RSPM: https://packagemanager.rstudio.com/cran/__linux__/focal/latest 23 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 24 | TORCH_INSTALL: 1 25 | TORCH_TEST: 1 26 | DEBIAN_FRONTEND: 'noninteractive' 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | 31 | - run: | 32 | apt-get update -y 33 | apt-get install -y sudo software-properties-common dialog apt-utils tzdata 34 | 35 | - uses: r-lib/actions/setup-r@v2 36 | 37 | - uses: r-lib/actions/setup-r-dependencies@v2 38 | with: 39 | extra-packages: | 40 | any::covr 41 | 42 | - name: Test coverage 43 | run: | 44 | covr::codecov( 45 | quiet = FALSE, 46 | clean = FALSE, 47 | install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package") 48 | ) 49 | shell: Rscript {0} 50 | 51 | - name: Show testthat output 52 | if: always() 53 | run: | 54 | ## -------------------------------------------------------------------- 55 | find ${{ runner.temp }}/package -name 'testthat.Rout*' -exec cat '{}' \; || true 56 | shell: bash 57 | 58 | - name: Upload test results 59 | if: failure() 60 | uses: actions/upload-artifact@v3 61 | with: 62 | name: coverage-test-failures 63 | path: ${{ runner.temp }}/package 64 | 65 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: tabnet 2 | Title: Fit 'TabNet' Models for Classification and Regression 3 | Version: 0.8.0 4 | Authors@R: c( 5 | person("Daniel", "Falbel", , "daniel@rstudio.com", role = "aut"), 6 | person(, "RStudio", role = "cph"), 7 | person("Christophe", "Regouby", , "christophe.regouby@free.fr", role = c("cre", "ctb")), 8 | person("Egill", "Fridgeirsson", role = "ctb"), 9 | person("Philipp", "Haarmeyer", role = "ctb"), 10 | person("Sven", "Verweij", role = "ctb", 11 | comment = c(ORCID = "0000-0002-5573-3952")) 12 | ) 13 | Description: Implements the 'TabNet' model by Sercan O. Arik et al. (2019) 14 | with 'Coherent Hierarchical Multi-label 15 | Classification Networks' by Giunchiglia et al. and 16 | provides a consistent interface for fitting and creating predictions. 17 | It's also fully compatible with the 'tidymodels' ecosystem. 
18 | License: MIT + file LICENSE 19 | URL: https://mlverse.github.io/tabnet/, https://github.com/mlverse/tabnet 20 | BugReports: https://github.com/mlverse/tabnet/issues 21 | Depends: 22 | R (>= 3.6) 23 | Imports: 24 | coro, 25 | data.tree, 26 | dials, 27 | dplyr, 28 | ggplot2, 29 | hardhat (>= 1.3.0), 30 | magrittr, 31 | Matrix, 32 | methods, 33 | parsnip, 34 | progress, 35 | purrr, 36 | rlang, 37 | stats, 38 | stringr, 39 | tibble, 40 | tidyr, 41 | torch (>= 0.4.0), 42 | tune, 43 | utils, 44 | vctrs, 45 | withr, 46 | zeallot 47 | Suggests: 48 | cli, 49 | knitr, 50 | modeldata, 51 | patchwork, 52 | quarto, 53 | recipes, 54 | rmarkdown, 55 | rsample, 56 | spelling, 57 | testthat (>= 3.0.0), 58 | tidymodels, 59 | tidyverse, 60 | vip, 61 | visdat, 62 | workflows, 63 | xgboost, 64 | yardstick 65 | VignetteBuilder: knitr 66 | Config/testthat/edition: 3 67 | Config/testthat/parallel: false 68 | Config/testthat/start-first: interface, explain, params 69 | Encoding: UTF-8 70 | Roxygen: list(markdown = TRUE) 71 | RoxygenNote: 7.3.2 72 | Language: en-US 73 | -------------------------------------------------------------------------------- /tests/testthat/setup.R: -------------------------------------------------------------------------------- 1 | # Run before any test 2 | suppressPackageStartupMessages(library(recipes)) 3 | suppressPackageStartupMessages(library(ggplot2)) 4 | suppressPackageStartupMessages(library(data.tree)) 5 | 6 | 7 | # ames small data 8 | utils::data("ames", package = "modeldata") 9 | ids <- sample(nrow(ames), 256) 10 | small_ames <- ames[ids,] 11 | x <- ames[ids,-which(names(ames) == "Sale_Price")] 12 | y <- ames[ids,]$Sale_Price 13 | 14 | # ames common models 15 | ames_pretrain <- tabnet_pretrain(x, y, epoch = 2, checkpoint_epochs = 1) 16 | ames_pretrain_vsplit <- tabnet_pretrain(x, y, epochs = 3, valid_split=.2, 17 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 18 | ames_fit <- tabnet_fit(x, y, epochs = 5 , checkpoint_epochs = 2) 19 | ames_fit_vsplit <- tabnet_fit(x, y, tabnet_model=ames_pretrain_vsplit, epochs = 3, 20 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 21 | 22 | # attrition small data 23 | utils::data("attrition", package = "modeldata") 24 | ids <- sample(nrow(attrition), 256) 25 | 26 | # attrition common models 27 | attrix <- attrition[ids,-which(names(attrition) == "Attrition")] 28 | attri_mult_x <- attrix[-which(names(attrix) == "JobSatisfaction")] 29 | 30 | attriy <- attrition[ids,]$Attrition 31 | 32 | attr_pretrained <- tabnet_pretrain(attrix, attriy, epochs = 12) 33 | attr_pretrained_vsplit <- tabnet_pretrain(attrix, attriy, epochs = 12, valid_split=0.3) 34 | attr_fitted <- tabnet_fit(attrix, attriy, epochs = 12) 35 | attr_fitted_vsplit <- tabnet_fit(attrix, attriy, epochs = 12, valid_split=0.3) 36 | 37 | # data.tree Node dataset 38 | utils::data("acme", package = "data.tree") 39 | acme_df <- data.tree::ToDataFrameTypeCol(acme, acme$attributesAll) %>% 40 | select(-starts_with("level_")) 41 | 42 | attrition_tree <- attrition %>% 43 | tibble::rowid_to_column() %>% 44 | mutate(pathString = paste("attrition", Department, JobRole, rowid, sep = "/")) %>% 45 | select(-Department, -JobRole, -rowid) %>% 46 | data.tree::as.Node() 47 | 48 | # Run after all tests 49 | withr::defer(testthat::teardown_env()) 50 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do 
not edit by hand 2 | 3 | S3method(min_grid,tabnet) 4 | S3method(multi_predict,"_tabnet_fit") 5 | S3method(nn_prune_head,tabnet_fit) 6 | S3method(nn_prune_head,tabnet_pretrain) 7 | S3method(predict,tabnet_fit) 8 | S3method(print,tabnet_fit) 9 | S3method(print,tabnet_pretrain) 10 | S3method(tabnet_explain,default) 11 | S3method(tabnet_explain,model_fit) 12 | S3method(tabnet_explain,tabnet_fit) 13 | S3method(tabnet_explain,tabnet_pretrain) 14 | S3method(tabnet_fit,Node) 15 | S3method(tabnet_fit,data.frame) 16 | S3method(tabnet_fit,default) 17 | S3method(tabnet_fit,formula) 18 | S3method(tabnet_fit,recipe) 19 | S3method(tabnet_pretrain,Node) 20 | S3method(tabnet_pretrain,data.frame) 21 | S3method(tabnet_pretrain,default) 22 | S3method(tabnet_pretrain,formula) 23 | S3method(tabnet_pretrain,recipe) 24 | S3method(update,tabnet) 25 | export("%>%") 26 | export(attention_width) 27 | export(cat_emb_dim) 28 | export(check_compliant_node) 29 | export(checkpoint_epochs) 30 | export(decision_width) 31 | export(drop_last) 32 | export(encoder_activation) 33 | export(entmax) 34 | export(entmax15) 35 | export(feature_reusage) 36 | export(lr_scheduler) 37 | export(mask_type) 38 | export(mlp_activation) 39 | export(mlp_hidden_multiplier) 40 | export(momentum) 41 | export(nn_aum_loss) 42 | export(node_to_df) 43 | export(num_independent) 44 | export(num_independent_decoder) 45 | export(num_shared) 46 | export(num_shared_decoder) 47 | export(num_steps) 48 | export(optimizer) 49 | export(penalty) 50 | export(sparsemax) 51 | export(sparsemax15) 52 | export(tabnet) 53 | export(tabnet_config) 54 | export(tabnet_explain) 55 | export(tabnet_fit) 56 | export(tabnet_nn) 57 | export(tabnet_pretrain) 58 | export(verbose) 59 | export(virtual_batch_size) 60 | importFrom(dplyr,filter) 61 | importFrom(dplyr,last_col) 62 | importFrom(dplyr,mutate) 63 | importFrom(dplyr,mutate_all) 64 | importFrom(dplyr,mutate_if) 65 | importFrom(dplyr,select) 66 | importFrom(dplyr,starts_with) 67 | importFrom(dplyr,where) 68 | importFrom(magrittr,"%>%") 69 | importFrom(parsnip,multi_predict) 70 | importFrom(rlang,.data) 71 | importFrom(stats,predict) 72 | importFrom(stats,update) 73 | importFrom(tidyr,replace_na) 74 | importFrom(torch,nn_prune_head) 75 | importFrom(tune,min_grid) 76 | importFrom(zeallot,"%<-%") 77 | -------------------------------------------------------------------------------- /man/autoplot.tabnet_explain.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot.R 3 | \name{autoplot.tabnet_explain} 4 | \alias{autoplot.tabnet_explain} 5 | \title{Plot tabnet_explain mask importance heatmap} 6 | \usage{ 7 | autoplot.tabnet_explain( 8 | object, 9 | type = c("mask_agg", "steps"), 10 | quantile = 1, 11 | ... 12 | ) 13 | } 14 | \arguments{ 15 | \item{object}{A \code{tabnet_explain} object as a result of \code{\link[=tabnet_explain]{tabnet_explain()}}.} 16 | 17 | \item{type}{a character value. Either \code{"mask_agg"} the default, for a single 18 | heatmap of aggregated mask importance per predictor along the dataset, 19 | or \code{"steps"} for one heatmap at each mask step.} 20 | 21 | \item{quantile}{numerical value between 0 and 1. Provides quantile clipping of the 22 | mask values} 23 | 24 | \item{...}{not used.} 25 | } 26 | \value{ 27 | A \code{ggplot} object. 
28 | } 29 | \description{ 30 | Plot tabnet_explain mask importance heatmap 31 | } 32 | \details{ 33 | Plot the \code{tabnet_explain} object mask importance per variable along the predicted dataset. 34 | \code{type="mask_agg"} outputs a single heatmap of mask aggregated values, 35 | \code{type="steps"} provides a plot faceted along the \code{n_steps} masks present in the model. 36 | \code{quantile=.995} may be used for strong outlier clipping, in order to better highlight 37 | low values. \code{quantile=1}, the default, does not clip any values. 38 | } 39 | \examples{ 40 | \dontshow{if ((torch::torch_is_installed() && require("modeldata"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 41 | \dontrun{ 42 | library(ggplot2) 43 | data("attrition", package = "modeldata") 44 | 45 | ## Single-outcome binary classification of `Attrition` in `attrition` dataset 46 | attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, epoch=11) 47 | attrition_explain <- tabnet_explain(attrition_fit, attrition) 48 | # Plot the model aggregated mask interpretation heatmap 49 | autoplot(attrition_explain) 50 | 51 | ## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset, 52 | data("ames", package = "modeldata") 53 | x <- ames[,-which(names(ames) \%in\% c("Sale_Price", "Pool_Area"))] 54 | y <- ames[, c("Sale_Price", "Pool_Area")] 55 | ames_fit <- tabnet_fit(x, y, epochs = 1, verbose=TRUE) 56 | ames_explain <- tabnet_explain(ames_fit, x) 57 | autoplot(ames_explain, quantile = 0.99) 58 | } 59 | \dontshow{\}) # examplesIf} 60 | } 61 | -------------------------------------------------------------------------------- /man/tabnet_nn.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tab-network.R 3 | \name{tabnet_nn} 4 | \alias{tabnet_nn} 5 | \title{TabNet Model Architecture} 6 | \usage{ 7 | tabnet_nn( 8 | input_dim, 9 | output_dim, 10 | n_d = 8, 11 | n_a = 8, 12 | n_steps = 3, 13 | gamma = 1.3, 14 | cat_idxs = c(), 15 | cat_dims = c(), 16 | cat_emb_dim = 1, 17 | n_independent = 2, 18 | n_shared = 2, 19 | epsilon = 1e-15, 20 | virtual_batch_size = 128, 21 | momentum = 0.02, 22 | mask_type = "sparsemax", 23 | mask_topk = NULL 24 | ) 25 | } 26 | \arguments{ 27 | \item{input_dim}{Initial number of features.} 28 | 29 | \item{output_dim}{Dimension of network output. Examples: one for regression, 2 for 30 | binary classification, etc.
Vector of those dimensions in case of multi-output.} 31 | 32 | \item{n_d}{Dimension of the prediction layer (usually between 4 and 64).} 33 | 34 | \item{n_a}{Dimension of the attention layer (usually between 4 and 64).} 35 | 36 | \item{n_steps}{Number of successive steps in the network (usually between 3 and 10).} 37 | 38 | \item{gamma}{Scaling factor for attention updates (usually between 1 and 2).} 39 | 40 | \item{cat_idxs}{Index of each categorical column in the dataset.} 41 | 42 | \item{cat_dims}{Number of categories in each categorical column.} 43 | 44 | \item{cat_emb_dim}{Size of the embedding of categorical features if int, all categorical 45 | features will have same embedding size if list of int, every corresponding feature will have 46 | specific size.} 47 | 48 | \item{n_independent}{Number of independent GLU layer in each GLU block of the encoder.} 49 | 50 | \item{n_shared}{Number of shared GLU layer in each GLU block of the encoder.} 51 | 52 | \item{epsilon}{Avoid log(0), this should be kept very low.} 53 | 54 | \item{virtual_batch_size}{Batch size for Ghost Batch Normalization.} 55 | 56 | \item{momentum}{Numerical value between 0 and 1 which will be used for momentum in all batch norm.} 57 | 58 | \item{mask_type}{Either "sparsemax", "entmax" or "entmax15": the sparse masking function to use.} 59 | 60 | \item{mask_topk}{the mask top-k value for k-sparsity selection in the mask for \code{sparsemax} and \code{entmax15}. 61 | defaults to 1/4 of last \code{input_dim} if \code{NULL}. See \link{entmax15} for details.} 62 | } 63 | \description{ 64 | This is a \code{nn_module} representing the TabNet architecture from 65 | \href{https://arxiv.org/abs/1908.07442}{Attentive Interpretable Tabular Deep Learning}. 66 | } 67 | -------------------------------------------------------------------------------- /tests/testthat/test-loss.R: -------------------------------------------------------------------------------- 1 | test_that("nn_unsupervised_loss is working as expected", { 2 | 3 | unsup_loss <- tabnet:::nn_unsupervised_loss() 4 | 5 | # the poor-guy expect_r6_class(x, class) 6 | expect_true(all(c("nn_weighted_loss","nn_loss","nn_module") %in% class(unsup_loss))) 7 | 8 | y_pred <- torch::torch_rand(3,5, requires_grad = TRUE) 9 | embedded_x <- torch::torch_rand(3,5) 10 | obfuscation_mask <- torch::torch_bernoulli(embedded_x, p = 0.5) 11 | output <- unsup_loss(y_pred, embedded_x, obfuscation_mask) 12 | output$backward() 13 | 14 | expect_tensor(output) 15 | expect_equal_to_r(output >= 0, TRUE) 16 | expect_false(rlang::is_null(output$grad_fn)) 17 | expect_equal(output$dim(), 0) 18 | }) 19 | 20 | 21 | test_that("nn_aum_loss works as expected with 1-dim label", { 22 | 23 | aum_loss <- tabnet::nn_aum_loss() 24 | 25 | # the poor-guy expect_r6_class(x, class) 26 | expect_true(all(c("nn_mse_loss","nn_loss","nn_module") %in% class(aum_loss))) 27 | 28 | # 1-dim label 29 | label_tensor <- torch::torch_tensor(attrition$Attrition) 30 | pred_tensor <- torch::torch_rand(label_tensor$shape, requires_grad = TRUE) 31 | output <- aum_loss(pred_tensor, label_tensor) 32 | output$backward() 33 | 34 | expect_tensor(output) 35 | expect_equal_to_r(output >= 0, TRUE) 36 | expect_false(rlang::is_null(output$grad_fn)) 37 | expect_equal(output$dim(), 0) 38 | 39 | }) 40 | 41 | 42 | test_that("nn_aum_loss works as expected with 2-dim label", { 43 | 44 | aum_loss <- tabnet::nn_aum_loss() 45 | label_tensor <- torch::torch_tensor(attrition$Attrition)$unsqueeze(-1) 46 | pred_tensor <- torch::torch_rand(label_tensor$shape, 
requires_grad = TRUE) 47 | output <- aum_loss(pred_tensor, label_tensor) 48 | output$backward() 49 | 50 | expect_tensor(output) 51 | expect_equal_to_r(output >= 0, TRUE) 52 | expect_false(rlang::is_null(output$grad_fn)) 53 | expect_equal(output$dim(), 0) 54 | }) 55 | 56 | 57 | test_that("nn_aum_loss works as expected with {n, 2} shape prediction", { 58 | 59 | aum_loss <- tabnet::nn_aum_loss() 60 | label_tensor <- torch::torch_tensor(attrition$Attrition) 61 | pred_tensor <- torch::torch_rand(c(label_tensor$shape, 2), requires_grad = TRUE) 62 | output <- aum_loss(pred_tensor, label_tensor) 63 | output$backward() 64 | 65 | 66 | expect_tensor(output) 67 | expect_equal_to_r(output >= 0, TRUE) 68 | expect_false(rlang::is_null(output$grad_fn)) 69 | expect_equal(output$dim(), 0) 70 | 71 | }) 72 | 73 | 74 | -------------------------------------------------------------------------------- /vignettes/tidymodels-interface.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Fitting tabnet with tidymodels" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Fitting tabnet with tidymodels} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, include = FALSE} 11 | knitr::opts_chunk$set( 12 | collapse = TRUE, 13 | comment = "#>", 14 | eval = FALSE 15 | ) 16 | ``` 17 | 18 | ```{r setup} 19 | library(tabnet) 20 | library(tidymodels) 21 | library(modeldata) 22 | ``` 23 | 24 | In this vignette, we show how to create a TabNet model using the tidymodels interface. 25 | 26 | We are going to use the `lending_club` dataset available 27 | in the `modeldata` package. 28 | 29 | First let's split our dataset into training and testing so we can later access performance of our model: 30 | 31 | ```{r} 32 | set.seed(123) 33 | data("lending_club", package = "modeldata") 34 | split <- initial_split(lending_club, strata = Class) 35 | train <- training(split) 36 | test <- testing(split) 37 | ``` 38 | 39 | We now define our pre-processing steps. Note that TabNet handles categorical variables, so we don't need to do any kind of transformation to them. Normalizing the numeric variables is a good idea though. 40 | 41 | ```{r} 42 | rec <- recipe(Class ~ ., train) %>% 43 | step_normalize(all_numeric()) 44 | ``` 45 | 46 | Next, we define our model. We are going to train for 50 epochs with a batch size of 128. There are other hyperparameters but, we are going to use the defaults. 
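The batch size mentioned above can also be set explicitly on the model specification. A minimal sketch (not evaluated here):

```{r, eval=FALSE}
# sketch: pass the batch size explicitly instead of relying on the default
tabnet(epochs = 50, batch_size = 128) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")
```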
47 | 48 | ```{r} 49 | mod <- tabnet(epochs = 50) %>% 50 | set_engine("torch", verbose = TRUE) %>% 51 | set_mode("classification") 52 | ``` 53 | 54 | We also define our `workflow` object: 55 | 56 | ```{r} 57 | wf <- workflow() %>% 58 | add_model(mod) %>% 59 | add_recipe(rec) 60 | ``` 61 | 62 | We can now define our cross-validation strategy: 63 | 64 | ```{r} 65 | folds <- vfold_cv(train, v = 5) 66 | ``` 67 | 68 | And finally, fit the model: 69 | 70 | ```{r} 71 | fit_rs <- wf %>% fit_resamples(folds) 72 | ``` 73 | 74 | After a few minutes we can get the results: 75 | 76 | ```{r} 77 | collect_metrics(fit_rs) 78 | ``` 79 | 80 | ``` 81 | # A tibble: 3 × 6 82 | .metric .estimator mean n std_err .config 83 | 84 | 1 accuracy binary 0.945 5 0.000869 Preprocessor1_Model1 85 | 2 brier_class binary 0.0535 5 0.00122 Preprocessor1_Model1 86 | 3 roc_auc binary 0.611 5 0.0153 Preprocessor1_Model1 87 | ``` 88 | 89 | And finally, we can verify the results in our test set: 90 | 91 | ```{r} 92 | model <- wf %>% fit(train) 93 | model %>% 94 | augment( test) %>% 95 | roc_auc(Class, .pred_good, event_level = "second") 96 | ``` 97 | 98 | ``` 99 | # A tibble: 1 x 3 100 | .metric .estimator .estimate 101 | 102 | 1 roc_auc binary 0.710 103 | ``` 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /.github/workflows/rhub.yaml: -------------------------------------------------------------------------------- 1 | # R-hub's generic GitHub Actions workflow file. It's canonical location is at 2 | # https://github.com/r-hub/actions/blob/v1/workflows/rhub.yaml 3 | # You can update this file to a newer version using the rhub2 package: 4 | # 5 | # rhub::rhub_setup() 6 | # 7 | # It is unlikely that you need to modify this file manually. 8 | 9 | name: R-hub 10 | run-name: "${{ github.event.inputs.id }}: ${{ github.event.inputs.name || format('Manually run by {0}', github.triggering_actor) }}" 11 | 12 | on: 13 | workflow_dispatch: 14 | inputs: 15 | config: 16 | description: 'A comma separated list of R-hub platforms to use.' 17 | type: string 18 | default: 'linux,windows,macos' 19 | name: 20 | description: 'Run name. You can leave this empty now.' 21 | type: string 22 | id: 23 | description: 'Unique ID. You can leave this empty now.' 
24 | type: string 25 | 26 | jobs: 27 | 28 | setup: 29 | runs-on: ubuntu-latest 30 | outputs: 31 | containers: ${{ steps.rhub-setup.outputs.containers }} 32 | platforms: ${{ steps.rhub-setup.outputs.platforms }} 33 | 34 | steps: 35 | # NO NEED TO CHECKOUT HERE 36 | - uses: r-hub/actions/setup@v1 37 | with: 38 | config: ${{ github.event.inputs.config }} 39 | id: rhub-setup 40 | 41 | linux-containers: 42 | needs: setup 43 | if: ${{ needs.setup.outputs.containers != '[]' }} 44 | runs-on: ubuntu-latest 45 | name: ${{ matrix.config.label }} 46 | strategy: 47 | fail-fast: false 48 | matrix: 49 | config: ${{ fromJson(needs.setup.outputs.containers) }} 50 | container: 51 | image: ${{ matrix.config.container }} 52 | 53 | steps: 54 | - uses: r-hub/actions/checkout@v1 55 | - uses: r-hub/actions/platform-info@v1 56 | with: 57 | token: ${{ secrets.RHUB_TOKEN }} 58 | job-config: ${{ matrix.config.job-config }} 59 | - uses: r-hub/actions/setup-deps@v1 60 | with: 61 | token: ${{ secrets.RHUB_TOKEN }} 62 | job-config: ${{ matrix.config.job-config }} 63 | - uses: r-hub/actions/run-check@v1 64 | with: 65 | token: ${{ secrets.RHUB_TOKEN }} 66 | job-config: ${{ matrix.config.job-config }} 67 | 68 | other-platforms: 69 | needs: setup 70 | if: ${{ needs.setup.outputs.platforms != '[]' }} 71 | runs-on: ${{ matrix.config.os }} 72 | name: ${{ matrix.config.label }} 73 | strategy: 74 | fail-fast: false 75 | matrix: 76 | config: ${{ fromJson(needs.setup.outputs.platforms) }} 77 | 78 | steps: 79 | - uses: r-hub/actions/checkout@v1 80 | - uses: r-hub/actions/setup-r@v1 81 | with: 82 | job-config: ${{ matrix.config.job-config }} 83 | token: ${{ secrets.RHUB_TOKEN }} 84 | - uses: r-hub/actions/platform-info@v1 85 | with: 86 | token: ${{ secrets.RHUB_TOKEN }} 87 | job-config: ${{ matrix.config.job-config }} 88 | - uses: r-hub/actions/setup-deps@v1 89 | with: 90 | job-config: ${{ matrix.config.job-config }} 91 | token: ${{ secrets.RHUB_TOKEN }} 92 | - uses: r-hub/actions/run-check@v1 93 | with: 94 | job-config: ${{ matrix.config.job-config }} 95 | token: ${{ secrets.RHUB_TOKEN }} 96 | -------------------------------------------------------------------------------- /tests/testthat/test-plot.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("Autoplot with unsupervised training, w and wo valid_split", { 3 | 4 | expect_no_error( 5 | print(autoplot(attr_pretrained)) 6 | ) 7 | 8 | expect_no_error( 9 | print(autoplot(attr_pretrained_vsplit)) 10 | ) 11 | 12 | }) 13 | 14 | test_that("Autoplot with supervised training, w and wo valid_split", { 15 | 16 | expect_no_error( 17 | print(autoplot(attr_fitted)) 18 | ) 19 | 20 | expect_no_error( 21 | print(autoplot(attr_fitted_vsplit)) 22 | ) 23 | 24 | }) 25 | 26 | test_that("Autoplot a model without checkpoint", { 27 | 28 | tabnet_pretrain <- tabnet_pretrain(attrix, attriy, epochs = 3) 29 | expect_no_error( 30 | print(autoplot(tabnet_pretrain)) 31 | ) 32 | 33 | tabnet_pretrain <- tabnet_pretrain(attrix, attriy, epochs = 3, valid_split=0.3) 34 | expect_no_error( 35 | print(autoplot(tabnet_pretrain)) 36 | ) 37 | 38 | tabnet_fit <- tabnet_fit(attrix, attriy, epochs = 3) 39 | expect_no_error( 40 | print(autoplot(tabnet_fit)) 41 | ) 42 | 43 | tabnet_fit <- tabnet_fit(attrix, attriy, epochs = 3, valid_split = 0.3) 44 | expect_no_error( 45 | print(autoplot(tabnet_fit)) 46 | ) 47 | 48 | }) 49 | 50 | test_that("Autoplot of pretrain then fit scenario, pretrain without checkpoints, fit without valid", { 51 | 52 | tabnet_fit <- tabnet_fit(attrix, 
attriy, tabnet_model = attr_pretrained_vsplit, epochs = 12) 53 | 54 | expect_no_error( 55 | print(autoplot(tabnet_fit)) 56 | ) 57 | 58 | fit_no_checkpoint <- tabnet_fit(Sale_Price ~., data = small_ames, epochs = 2, valid_split = 0.2, checkpoint_epoch = 3, batch_size = 64) 59 | expect_no_error( 60 | print(autoplot(fit_no_checkpoint)) 61 | ) 62 | fit_with_checkpoint <- tabnet_fit(Sale_Price ~., data = small_ames, tabnet_model = fit_no_checkpoint, epochs = 2, checkpoint_epoch = 1) 63 | expect_warning( 64 | print(autoplot(fit_with_checkpoint)), 65 | "Removed 2 rows containing missing values" 66 | ) 67 | 68 | }) 69 | 70 | test_that("Autoplot of tabnet_explain works for pretrain and fitted model", { 71 | 72 | explain_pretrain <- tabnet_explain(attr_pretrained_vsplit, attrix) 73 | explain_fit <- tabnet_explain(attr_fitted_vsplit, attrix) 74 | 75 | expect_no_error( 76 | print(autoplot(explain_pretrain)) 77 | ) 78 | 79 | expect_no_error( 80 | print(autoplot(explain_pretrain, type = "steps")) 81 | ) 82 | 83 | expect_no_error( 84 | print(autoplot(explain_pretrain, type = "steps", quantile = 0.99)), 85 | 86 | ) 87 | 88 | expect_no_error( 89 | print(autoplot(explain_fit)) 90 | ) 91 | 92 | expect_no_error( 93 | print(autoplot(explain_fit, type = "steps")) 94 | ) 95 | 96 | expect_no_error( 97 | print(autoplot(explain_fit, type = "steps", quantile = 0.99)) 98 | ) 99 | 100 | }) 101 | 102 | test_that("Autoplot of multi-outcome regression explainer", { 103 | 104 | x <- small_ames[,-which(names(ames) %in% c("Sale_Price", "Pool_Area"))] 105 | y <- small_ames[, c("Sale_Price", "Pool_Area")] 106 | ames_fit <- tabnet_fit(x, y, epochs = 5, verbose=TRUE) 107 | ames_explain <- tabnet_explain(ames_fit, ames) 108 | 109 | expect_no_error( 110 | print(autoplot(ames_explain)) 111 | ) 112 | 113 | expect_no_error( 114 | print(autoplot(ames_explain, type = "steps", quantile = 0.99)) 115 | ) 116 | 117 | }) 118 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag. 2 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | name: R-CMD-check 14 | 15 | jobs: 16 | R-CMD-check: 17 | runs-on: ${{ matrix.config.os }} 18 | 19 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | config: 25 | - {os: windows-latest, r: 'release'} 26 | # on m1 the R version is whicherver is installed in the runner machine. 
27 | - {os: macOS, r: 'release', version: cpu-m1, runner: [self-hosted, macOS, ARM64]} 28 | - {os: ubuntu-22.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/jammy/latest"} 29 | - {os: ubuntu-22.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/jammy/latest"} 30 | 31 | env: 32 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 33 | RSPM: ${{ matrix.config.rspm }} 34 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 35 | TORCH_INSTALL: 1 36 | TORCH_TEST: 1 37 | PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0 38 | 39 | steps: 40 | - uses: actions/checkout@v4 41 | 42 | - uses: r-lib/actions/setup-r@v2 43 | with: 44 | r-version: ${{ matrix.config.r }} 45 | 46 | - uses: r-lib/actions/setup-pandoc@v2 47 | 48 | - name: Install wget on macOS for quarto download 49 | if: runner.os == 'macOS' 50 | run: 51 | brew install wget 52 | 53 | - uses: quarto-dev/quarto-actions/setup@v2 54 | with: 55 | version: 1.7.30 56 | 57 | - uses: r-lib/actions/setup-r-dependencies@v2 58 | with: 59 | extra-packages: any::rcmdcheck, local::. 60 | needs: check 61 | 62 | - uses: r-lib/actions/check-r-package@v2 63 | with: 64 | error-on: '"error"' 65 | args: 'c("--no-multiarch", "--no-manual", "--as-cran")' 66 | 67 | GPU: 68 | runs-on: ['self-hosted', 'gce', 'gpu'] 69 | name: 'gpu' 70 | 71 | container: 72 | image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04' 73 | options: '--gpus all --runtime=nvidia' 74 | 75 | timeout-minutes: 120 76 | 77 | env: 78 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 79 | RSPM: 'https://packagemanager.rstudio.com/cran/__linux__/focal/latest' 80 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 81 | TORCH_INSTALL: 1 82 | TORCH_TEST: 1 83 | DEBIAN_FRONTEND: 'noninteractive' 84 | 85 | steps: 86 | - uses: actions/checkout@v4 87 | 88 | - run: | 89 | apt-get update -y 90 | apt-get install -y sudo software-properties-common dialog apt-utils tzdata libpng-dev 91 | 92 | - uses: r-lib/actions/setup-r@v2 93 | 94 | - uses: r-lib/actions/setup-pandoc@v2 95 | 96 | - uses: quarto-dev/quarto-actions/setup@v2 97 | with: 98 | version: 1.7.30 99 | 100 | - uses: r-lib/actions/setup-r-dependencies@v2 101 | with: 102 | extra-packages: any::rcmdcheck, local::. 
103 | needs: check 104 | 105 | - uses: r-lib/actions/check-r-package@v2 106 | with: 107 | error-on: '"error"' 108 | args: 'c("--no-multiarch", "--no-manual", "--as-cran")' 109 | -------------------------------------------------------------------------------- /tests/testthat/test-hardhat_scenarios.R: -------------------------------------------------------------------------------- 1 | test_that("Supervised training can continue with a additional fit, with or wo from_epoch=", { 2 | 3 | fit_2 <- tabnet_fit(x, y, tabnet_model = ames_fit, epochs = 1) 4 | 5 | expect_equal(fit_2$fit$config$epoch, 1) 6 | expect_length(fit_2$fit$metrics, 6) 7 | expect_identical(ames_fit$fit$metrics[[1]]$train, fit_2$fit$metrics[[1]]$train) 8 | expect_identical(ames_fit$fit$metrics[[5]]$train, fit_2$fit$metrics[[5]]$train) 9 | 10 | expect_no_error( 11 | fit_3 <- tabnet_fit(x, y, tabnet_model = ames_fit, from_epoch = 2, epoch = 1 ) 12 | ) 13 | expect_equal(fit_3$fit$config$epoch, 1) 14 | expect_length(fit_3$fit$metrics, 3) 15 | expect_identical(ames_fit$fit$metrics[[1]]$train, fit_2$fit$metrics[[1]]$train) 16 | expect_identical(ames_fit$fit$metrics[[2]]$train, fit_2$fit$metrics[[2]]$train) 17 | 18 | }) 19 | 20 | test_that("we can change the tabnet_options between training epoch", { 21 | 22 | fit_2 <- tabnet_fit(x, y, ames_fit, epochs = 1, penalty = 0.003, learn_rate = 0.002) 23 | 24 | expect_equal(fit_2$fit$config$epoch, 1) 25 | expect_length(fit_2$fit$metrics, 6) 26 | expect_equal(fit_2$fit$config$learn_rate, 0.002) 27 | 28 | }) 29 | 30 | test_that("epoch counter is valid for retraining from a checkpoint", { 31 | 32 | tmp <- tempfile("model", fileext = "rds") 33 | withr::local_file(saveRDS(ames_fit, tmp)) 34 | 35 | fit1 <- readRDS(tmp) 36 | fit_2 <- tabnet_fit(x, y, ames_fit, epochs = 12, verbose=T) 37 | 38 | expect_equal(fit_2$fit$config$epoch, 12) 39 | expect_length(fit_2$fit$metrics, 17) 40 | expect_lte(mean(fit_2$fit$metrics[[17]]$train), mean(fit_2$fit$metrics[[1]]$train)) 41 | 42 | }) 43 | 44 | test_that("trying to continue training with different dataset raise error", { 45 | 46 | pretrain_1 <- tabnet_pretrain(x, y, epochs = 1) 47 | 48 | expect_error( 49 | pretrain_2 <- tabnet_fit(attrix, y, tabnet_model=pretrain_1, epochs = 1), 50 | regexp = "Model dimensions" 51 | ) 52 | 53 | fit_1 <- tabnet_fit(x, y, epochs = 1) 54 | 55 | expect_error( 56 | fit_2 <- tabnet_fit(attrix, y, tabnet_model=fit_1, epochs = 1), 57 | regexp = "Model dimensions" 58 | ) 59 | 60 | expect_error( 61 | fit_2 <- tabnet_fit(x, attriy, tabnet_model=fit_1, epochs = 1), 62 | regexp = "Model dimensions" 63 | ) 64 | 65 | }) 66 | 67 | test_that("Supervised training can continue unsupervised training, with or wo from_epoch=", { 68 | 69 | expect_no_error( 70 | tabnet_fit(x, y, tabnet_model = ames_pretrain, epoch = 1) 71 | ) 72 | 73 | expect_no_error( 74 | tabnet_fit(Attrition ~ ., data = attrition, tabnet_model = attr_pretrained, epochs = 1) 75 | ) 76 | 77 | expect_no_error( 78 | tabnet_fit(x, y, tabnet_model = ames_pretrain, from_epoch = 1, epoch = 1 ) 79 | ) 80 | 81 | }) 82 | 83 | test_that("Supervised training can continue unsupervised training, with a Libtorch optimizer", { 84 | testthat::skip_if(!torch_has_optim_ignite()) 85 | 86 | expect_no_error( 87 | tabnet_pretrain(x, y, epoch = 1, config = tabnet_config( 88 | optimizer = torch::optim_ignite_adamw) 89 | ) 90 | ) 91 | 92 | expect_no_error( 93 | tabnet_fit(Attrition ~ ., data = attrition, tabnet_model = attr_pretrained, epochs = 1, 94 | optimizer = torch::optim_ignite_adamw 95 | ) 96 | ) 97 | 
}) 98 | 99 | 100 | test_that("serialization of tabnet_pretrain with saveRDS just works", { 101 | 102 | fit <- tabnet_fit(x, y, ames_pretrain, epoch = 1, learn_rate = 1e-12) 103 | 104 | tmp <- tempfile("model", fileext = "rds") 105 | withr::local_file(saveRDS(ames_pretrain, tmp)) 106 | 107 | pretrain2 <- readRDS(tmp) 108 | fit2 <- tabnet_fit(x, y, pretrain2, epoch = 1, learn_rate = 1e-12) 109 | 110 | expect_equal( 111 | predict(fit, ames), 112 | predict(fit2, ames), 113 | tolerance = 20 114 | ) 115 | 116 | expect_equal(as.numeric(fit2$fit$network$.check), 1) 117 | 118 | }) 119 | -------------------------------------------------------------------------------- /R/explain.R: -------------------------------------------------------------------------------- 1 | 2 | #' Interpretation metrics from a TabNet model 3 | #' 4 | #' @param object a TabNet fit object 5 | #' @param new_data a data.frame to obtain interpretation metrics. 6 | #' 7 | #' @return 8 | #' 9 | #' Returns a list with 10 | #' 11 | #' * `M_explain`: the aggregated feature importance masks as detailed in 12 | #' TabNet's paper. 13 | #' * `masks` a list containing the masks for each step. 14 | #' 15 | #' @examplesIf torch::torch_is_installed() 16 | #' 17 | #' set.seed(2021) 18 | #' 19 | #' n <- 256 20 | #' x <- data.frame( 21 | #' x = rnorm(n), 22 | #' y = rnorm(n), 23 | #' z = rnorm(n) 24 | #' ) 25 | #' 26 | #' y <- x$x 27 | #' 28 | #' fit <- tabnet_fit(x, y, epochs = 10, 29 | #' num_steps = 1, 30 | #' batch_size = 512, 31 | #' attention_width = 1, 32 | #' num_shared = 1, 33 | #' num_independent = 1) 34 | #' 35 | #' 36 | #' ex <- tabnet_explain(fit, x) 37 | #' 38 | #' 39 | #' @export 40 | tabnet_explain <- function(object, new_data) { 41 | UseMethod("tabnet_explain") 42 | } 43 | 44 | #' @export 45 | #' @rdname tabnet_explain 46 | tabnet_explain.default <- function(object, new_data) { 47 | type_error("{.fn tabnet_explain} is not defined for a {.type {class(object)[1]}}.") 48 | } 49 | 50 | #' @export 51 | #' @rdname tabnet_explain 52 | tabnet_explain.tabnet_fit <- function(object, new_data) { 53 | if (inherits(new_data, "Node")) { 54 | new_data_df <- node_to_df(new_data)$x 55 | } else { 56 | new_data_df <- new_data 57 | } 58 | # Enforces column order, type, column names, etc 59 | processed <- hardhat::forge(new_data_df, object$blueprint, outcomes = FALSE) 60 | data <- resolve_data(processed$predictors, y = rep(1, nrow(processed$predictors))) 61 | device <- get_device_from_config(object$fit$config) 62 | data <- to_device(data, device) 63 | output <- explain_impl(object$fit$network, data$x, data$x_na_mask) 64 | 65 | # convert stuff to matrix with colnames 66 | nms <- colnames(processed$predictors) 67 | output$M_explain <- convert_to_df(output$M_explain, nms) 68 | output$masks <- lapply(output$masks, convert_to_df, nms = nms) 69 | class(output) <- "tabnet_explain" 70 | output 71 | } 72 | 73 | #' @export 74 | #' @rdname tabnet_explain 75 | tabnet_explain.tabnet_pretrain <- tabnet_explain.tabnet_fit 76 | 77 | #' @export 78 | #' @rdname tabnet_explain 79 | tabnet_explain.model_fit <- function(object, new_data) { 80 | tabnet_explain(parsnip::extract_fit_engine(object), new_data) 81 | } 82 | 83 | convert_to_df <- function(x, nms) { 84 | x <- as.data.frame(as.matrix(x$to(device = "cpu")$detach())) 85 | colnames(x) <- nms 86 | tibble::as_tibble(x) 87 | } 88 | 89 | explain_impl <- function(network, x, x_na_mask) { 90 | curr_device <- network$.check$device 91 | withr::defer({ 92 | network$to(device = curr_device) 93 | }) 94 | 
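  # run the forward masks on whichever device holds the input `x`; the deferred
  # call above restores the network to its original device on exit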
network$to(device=x$device) 95 | # NULLing values to avoid a R-CMD Check Note 96 | M_explain_emb_dim <- masks_emb_dim <- NULL 97 | c(M_explain_emb_dim, masks_emb_dim) %<-% network$forward_masks(x, x_na_mask) 98 | 99 | # summarize the categorical embeddedings into 1 column 100 | # per variable 101 | M_explain <- sum_embedding_masks( 102 | mask = M_explain_emb_dim, 103 | input_dim = network$input_dim, 104 | cat_idx = network$cat_idxs, 105 | cat_emb_dim = network$cat_emb_dim 106 | ) 107 | 108 | masks <- lapply( 109 | masks_emb_dim, 110 | FUN = sum_embedding_masks, 111 | input_dim = network$input_dim, 112 | cat_idx = network$cat_idxs, 113 | cat_emb_dim = network$cat_emb_dim 114 | ) 115 | 116 | list(M_explain = M_explain$to(device="cpu"), masks = to_device(masks, "cpu")) 117 | } 118 | 119 | compute_feature_importance <- function(network, x, x_na_mask) { 120 | out <- explain_impl(network, x, x_na_mask) 121 | m <- as.numeric(out$M_explain$sum(dim = 1)$detach()$to(device = "cpu")) 122 | m/sum(m) 123 | } 124 | 125 | # sum embeddings taking their sizes into account. 126 | sum_embedding_masks <- function(mask, input_dim, cat_idx, cat_emb_dim) { 127 | sizes <- rep(1, input_dim) 128 | sizes[cat_idx] <- cat_emb_dim 129 | 130 | splits <- mask$split_with_sizes(sizes, dim = 2) 131 | splits <- lapply(splits, torch::torch_sum, dim = 2, keepdim = TRUE) 132 | 133 | torch::torch_cat(splits, dim = 2) 134 | } 135 | 136 | vi_model.tabnet_fit <- function(object, ...) { 137 | tib <- object$fit$importances 138 | names(tib) <- c("Variable", "Importance") 139 | tib 140 | } 141 | 142 | vi_model.tabnet_pretrain <- vi_model.tabnet_fit 143 | -------------------------------------------------------------------------------- /man/tabnet_pretrain.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/hardhat.R 3 | \name{tabnet_pretrain} 4 | \alias{tabnet_pretrain} 5 | \alias{tabnet_pretrain.default} 6 | \alias{tabnet_pretrain.data.frame} 7 | \alias{tabnet_pretrain.formula} 8 | \alias{tabnet_pretrain.recipe} 9 | \alias{tabnet_pretrain.Node} 10 | \title{Tabnet model} 11 | \usage{ 12 | tabnet_pretrain(x, ...) 13 | 14 | \method{tabnet_pretrain}{default}(x, ...) 15 | 16 | \method{tabnet_pretrain}{data.frame}( 17 | x, 18 | y, 19 | tabnet_model = NULL, 20 | config = tabnet_config(), 21 | ..., 22 | from_epoch = NULL 23 | ) 24 | 25 | \method{tabnet_pretrain}{formula}( 26 | formula, 27 | data, 28 | tabnet_model = NULL, 29 | config = tabnet_config(), 30 | ..., 31 | from_epoch = NULL 32 | ) 33 | 34 | \method{tabnet_pretrain}{recipe}( 35 | x, 36 | data, 37 | tabnet_model = NULL, 38 | config = tabnet_config(), 39 | ..., 40 | from_epoch = NULL 41 | ) 42 | 43 | \method{tabnet_pretrain}{Node}( 44 | x, 45 | tabnet_model = NULL, 46 | config = tabnet_config(), 47 | ..., 48 | from_epoch = NULL 49 | ) 50 | } 51 | \arguments{ 52 | \item{x}{Depending on the context: 53 | \itemize{ 54 | \item A \strong{data frame} of predictors. 55 | \item A \strong{matrix} of predictors. 56 | \item A \strong{recipe} specifying a set of preprocessing steps 57 | created from \code{\link[recipes:recipe]{recipes::recipe()}}. 58 | \item A \strong{Node} where tree leaves will be left out, 59 | and attributes will be used as predictors. 60 | } 61 | 62 | The predictor data should be standardized (e.g. centered or scaled). 63 | The model treats categorical predictors internally thus, you don't need to 64 | make any treatment. 
65 | The model treats missing values internally thus, you don't need to make any 66 | treatment.} 67 | 68 | \item{...}{Model hyperparameters. 69 | Any hyperparameters set here will update those set by the config argument. 70 | See \code{\link[=tabnet_config]{tabnet_config()}} for a list of all possible hyperparameters.} 71 | 72 | \item{y}{(optional) When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome} 73 | 74 | \item{tabnet_model}{A pretrained \code{tabnet_model} object to continue the fitting on. 75 | if \code{NULL} (the default) a brand new model is initialized.} 76 | 77 | \item{config}{A set of hyperparameters created using the \code{tabnet_config} function. 78 | If no argument is supplied, this will use the default values in \code{\link[=tabnet_config]{tabnet_config()}}.} 79 | 80 | \item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch. 81 | Default is last available checkpoint for restored model, or last epoch for in-memory model.} 82 | 83 | \item{formula}{A formula specifying the outcome terms on the left-hand side, 84 | and the predictor terms on the right-hand side.} 85 | 86 | \item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as: 87 | \itemize{ 88 | \item A \strong{data frame} containing both the predictors and the outcome. 89 | }} 90 | } 91 | \value{ 92 | A TabNet model object. It can be used for serialization, predictions, or further fitting. 93 | } 94 | \description{ 95 | Pretrain the \href{https://arxiv.org/abs/1908.07442}{TabNet: Attentive Interpretable Tabular Learning} model 96 | on the predictor data exclusively (unsupervised training). 97 | } 98 | \section{outcome}{ 99 | 100 | 101 | Outcome value are accepted here only for consistent syntax with \code{tabnet_fit}, but 102 | by design the outcome, if present, is ignored during pre-training. 103 | } 104 | 105 | \section{pre-training from a previous model}{ 106 | 107 | 108 | When providing a parent \code{tabnet_model} parameter, the model pretraining resumes from that model weights 109 | at the following epoch: 110 | \itemize{ 111 | \item last pretrained epoch for a model already in torch context 112 | \item Last model checkpoint epoch for a model loaded from file 113 | \item the epoch related to a checkpoint matching or preceding the \code{from_epoch} value if provided 114 | The model pretraining metrics append on top of the parent metrics in the returned TabNet model. 115 | } 116 | } 117 | 118 | \section{Threading}{ 119 | 120 | 121 | TabNet uses \code{torch} as its backend for computation and \code{torch} uses all 122 | available threads by default. 123 | 124 | You can control the number of threads used by \code{torch} with: 125 | 126 | \if{html}{\out{
}}\preformatted{torch::torch_set_num_threads(1) 127 | torch::torch_set_num_interop_threads(1) 128 | }\if{html}{\out{
}} 129 | } 130 | 131 | \examples{ 132 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 133 | data("ames", package = "modeldata") 134 | pretrained <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1) 135 | \dontshow{\}) # examplesIf} 136 | } 137 | -------------------------------------------------------------------------------- /tests/testthat/test-missing_values.R: -------------------------------------------------------------------------------- 1 | test_that("pretrain accepts missing value in predictors and (unused) outcome", { 2 | 3 | data("attrition", package = "modeldata") 4 | ids <- sample(nrow(attrition), 256) 5 | 6 | x <- attrition[ids,-which(names(attrition) == "Attrition")] 7 | y <- attrition[ids,]$Attrition 8 | y_missing <- y 9 | y_missing[1] <- NA 10 | 11 | # numerical missing 12 | x_missing <- x 13 | x_missing[1,"Age"] <- NA 14 | 15 | expect_no_error( 16 | miss_pretrain <- tabnet_pretrain(x_missing, y, epochs = 1) 17 | ) 18 | 19 | # categorical missing 20 | x_missing <- x 21 | x_missing[1,"BusinessTravel"] <- NA 22 | 23 | expect_no_error( 24 | miss_pretrain <- tabnet_pretrain(x_missing, y, epochs = 1) 25 | ) 26 | 27 | # no error when missing in outcome 28 | expect_no_error( 29 | miss_pretrain <- tabnet_pretrain(x, y_missing, epochs = 1) 30 | ) 31 | 32 | }) 33 | 34 | 35 | test_that("fit accept missing value in predictor, not in outcome", { 36 | 37 | data("attrition", package = "modeldata") 38 | ids <- sample(nrow(attrition), 256) 39 | 40 | x <- attrition[ids,-which(names(attrition) == "Attrition")] 41 | y <- attrition[ids,]$Attrition 42 | y_missing <- y 43 | y_missing[1] <- NA 44 | 45 | # numerical missing 46 | x_missing <- x 47 | x_missing[1,"Age"] <- NA 48 | 49 | expect_no_error( 50 | miss_fit <- tabnet_fit(x_missing, y, epochs = 1) 51 | ) 52 | 53 | # categorical missing 54 | x_missing <- x 55 | x_missing[1,"BusinessTravel"] <- NA 56 | 57 | expect_no_error( 58 | miss_fit <- tabnet_fit(x_missing, y, epochs = 1) 59 | ) 60 | 61 | # missing in outcome 62 | expect_error( 63 | miss_fit <- tabnet_fit(x, y_missing, epochs = 1), 64 | regexp = "missing" 65 | ) 66 | 67 | }) 68 | 69 | test_that("fit accept missing value in `Node` predictor", { 70 | # fix to https://github.com/mlverse/tabnet/issues/125 71 | library(data.tree) 72 | data(starwars, package = "dplyr") 73 | 74 | starwars_tree <- starwars %>% 75 | rename(`_name` = "name", `_height` = "height") %>% 76 | # iconv translation ensure compatibility on macintosh 77 | mutate(pathString = paste("StarWars_characters", species, homeworld, `_name`, sep = "/") %>% 78 | iconv(from = "UTF-8", to = "ASCII//TRANSLIT")) %>% 79 | as.Node() 80 | 81 | expect_no_error( 82 | miss_fit <- tabnet_fit(starwars_tree, epochs = 1, cat_emb_dim = 2) 83 | ) 84 | 85 | expect_no_error( 86 | miss_pred <- predict(miss_fit, starwars_tree) 87 | ) 88 | 89 | }) 90 | 91 | test_that("predict data-frame accept missing value in predictor", { 92 | 93 | data("attrition", package = "modeldata") 94 | ids <- sample(nrow(attrition), 256) 95 | 96 | x <- attrition[ids,-which(names(attrition) == "Attrition")] 97 | y <- attrition[ids,]$Attrition 98 | # 99 | fit <- tabnet_fit(x, y, epochs = 1) 100 | 101 | # numerical missing 102 | x_missing <- x 103 | x_missing[1,"Age"] <- NA 104 | 105 | # predict with numerical missing 106 | expect_no_error( 107 | predict(fit, x_missing), 108 | ) 109 | # categorical missing 110 | x_missing <- x 111 | x_missing[1,"BusinessTravel"] <- NA 112 | 113 | # predict with categorical 
missing 114 | expect_no_error( 115 | predict(fit, x_missing) 116 | ) 117 | 118 | }) 119 | 120 | test_that("inference works with missings in the response vector", { 121 | 122 | data("attrition", package = "modeldata") 123 | ids <- sample(nrow(attrition), 256) 124 | 125 | rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% 126 | step_normalize(all_numeric(), -all_outcomes()) 127 | fit <- tabnet_fit(rec, attrition, epochs = 1, valid_split = 0.25, 128 | verbose = TRUE) 129 | # predict with empty vector 130 | attrition[["EnvironmentSatisfaction"]] <-NA 131 | expect_no_error( 132 | predict(fit, attrition) 133 | ) 134 | 135 | # predict with wrong class 136 | attrition[["EnvironmentSatisfaction"]] <-NA_character_ 137 | expect_no_error( 138 | predict(fit, attrition) 139 | ) 140 | 141 | # predict with list column 142 | attrition[["EnvironmentSatisfaction"]] <- list(NA) 143 | expect_no_error( 144 | predict(fit, attrition) 145 | ) 146 | 147 | }) 148 | 149 | test_that("explain works with missings in predictors", { 150 | 151 | data("attrition", package = "modeldata") 152 | ids <- sample(nrow(attrition), 256) 153 | 154 | x <- attrition[ids,-which(names(attrition) == "Attrition")] 155 | y <- attrition[ids,]$Attrition 156 | # 157 | fit <- tabnet_fit(x, y, epochs = 1) 158 | 159 | # numerical missing 160 | x_missing <- x 161 | x_missing[1,"Age"] <- NA 162 | 163 | # explain with numerical missing 164 | expect_no_error( 165 | tabnet_explain(fit, x_missing) 166 | ) 167 | # categorical missing 168 | x_missing <- x 169 | x_missing[1,"BusinessTravel"] <- NA 170 | 171 | # explain with categorical missing 172 | expect_no_error( 173 | tabnet_explain(fit, x_missing) 174 | ) 175 | }) 176 | -------------------------------------------------------------------------------- /R/loss.R: -------------------------------------------------------------------------------- 1 | #' Self-supervised learning loss 2 | #' 3 | #' Creates a criterion that measures the Autoassociative self-supervised learning loss between each 4 | #' element in the input \eqn{y_pred} and target \eqn{embedded_x} on the values masked by \eqn{obfuscation_mask}. 5 | #' 6 | #' @noRd 7 | nn_unsupervised_loss <- torch::nn_module( 8 | "nn_unsupervised_loss", 9 | inherit = torch::nn_cross_entropy_loss, 10 | 11 | initialize = function(eps = 1e-9){ 12 | super$initialize() 13 | self$eps = eps 14 | }, 15 | 16 | forward = function(y_pred, embedded_x, obfuscation_mask){ 17 | errors <- y_pred - embedded_x 18 | reconstruction_errors <- torch::torch_mul(errors, obfuscation_mask) ^ 2 19 | batch_stds <- torch::torch_std(embedded_x, dim = 1) ^ 2 + self$eps 20 | 21 | # compute the number of obfuscated variables to reconstruct 22 | nb_reconstructed_variables <- torch::torch_sum(obfuscation_mask, dim = 2) 23 | 24 | # take the mean of the reconstructed variable errors 25 | features_loss <- torch::torch_matmul(reconstruction_errors, 1 / batch_stds) / (nb_reconstructed_variables + self$eps) 26 | loss <- torch::torch_mean(features_loss, dim = 1) 27 | loss 28 | } 29 | ) 30 | 31 | 32 | #' AUM loss 33 | #' 34 | #' Creates a criterion that measures the Area under the \eqn{Min(FPR, FNR)} (AUM) between each 35 | #' element in the input \eqn{pred_tensor} and target \eqn{label_tensor}. 36 | #' 37 | #' This is used for measuring the error of a binary reconstruction within highly unbalanced dataset, 38 | #' where the goal is optimizing the ROC curve. Note that the targets \eqn{label_tensor} should be factor 39 | #' level of the binary outcome, i.e. with values `1L` and `2L`. 
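#'
#' A minimal sketch (illustrative only, assuming a two-level factor outcome `y`):
#' `label_tensor <- torch::torch_tensor(as.integer(y))` yields the required
#' `1L` / `2L` target values before calling the loss.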
40 | #' 41 | #' @examplesIf torch::torch_is_installed() 42 | #' loss <- nn_aum_loss() 43 | #' input <- torch::torch_randn(4, 6, requires_grad = TRUE) 44 | #' target <- input > 1.5 45 | #' output <- loss(input, target) 46 | #' output$backward() 47 | #' @export 48 | nn_aum_loss <- torch::nn_module( 49 | "nn_aum_loss", 50 | inherit = torch::nn_mse_loss, 51 | initialize = function(){ 52 | super$initialize() 53 | self$roc_aum <- tibble::tibble() 54 | }, 55 | forward = function(pred_tensor, label_tensor){ 56 | # thanks to https://tdhock.github.io/blog/2024/auto-grad-overhead/ 57 | is_positive <- label_tensor == label_tensor$max() 58 | is_negative <- is_positive$bitwise_not() 59 | # manage case when prediction error is null (prevent division by 0) 60 | if(as.logical(torch::torch_sum(is_positive) == 0) || as.logical(torch::torch_sum(is_negative) == 0)){ 61 | return(torch::torch_sum(pred_tensor*0)) 62 | } 63 | 64 | # pred tensor may be [prediction, case_wts] when add_case_weight() is used. We keep only prediction 65 | if (pred_tensor$ndim > label_tensor$ndim) { 66 | pred_tensor <- pred_tensor$slice(dim = 2, 0, 1)$squeeze(2) 67 | } 68 | 69 | # nominal case 70 | fn_diff <- -1L * is_positive 71 | fp_diff <- is_negative$to(dtype = torch::torch_long()) 72 | fp_denom <- torch::torch_sum(is_negative) # or 1 for AUM based on count instead of rate 73 | fn_denom <- torch::torch_sum(is_positive) # or 1 for AUM based on count instead of rate 74 | sorted_pred_ids <- torch::torch_argsort(pred_tensor, dim = 1, descending = TRUE)$squeeze(-1) 75 | 76 | sorted_fp_cum <- fp_diff[sorted_pred_ids]$cumsum(dim = 1) / fp_denom 77 | sorted_fn_cum <- -fn_diff[sorted_pred_ids]$flip(1)$cumsum(dim = 1)$flip(1) / fn_denom 78 | sorted_thresh_gr <- -pred_tensor[sorted_pred_ids] 79 | sorted_dedup <- sorted_thresh_gr$diff(dim = 1) != 0 80 | # pad to replace removed last element 81 | padding <- sorted_dedup$slice(dim = 1, 0, 1) # torch_tensor 1 w same dtype, same shape, same device 82 | sorted_fp_end <- torch::torch_cat(c(sorted_dedup, padding)) 83 | sorted_fn_end <- torch::torch_cat(c(padding, sorted_dedup)) 84 | uniq_thresh_gr <- sorted_thresh_gr[sorted_fp_end] 85 | uniq_fp_after <- sorted_fp_cum[sorted_fp_end] 86 | uniq_fn_before <- sorted_fn_cum[sorted_fn_end] 87 | if (pred_tensor$ndim == 1) { 88 | FPR <- torch::torch_cat(c(padding$logical_not(), uniq_fp_after)) # FPR with trailing 0 89 | FNR <- torch::torch_cat(c(uniq_fn_before, padding$logical_not())) # FNR with leading 0 90 | self$roc_aum <- list( 91 | FPR = FPR, 92 | FNR = FNR, 93 | TPR = 1 - FNR, 94 | "min(FPR,FNR)" = torch::torch_minimum(FNR, FPR), # full-range min(FNR, FPR) 95 | constant_range_low = torch::torch_cat(c(torch::torch_tensor(-Inf), uniq_thresh_gr)), 96 | constant_range_high = torch::torch_cat(c(uniq_thresh_gr, torch::torch_tensor(Inf))) 97 | ) %>% purrr::map_dfc(torch::as_array) 98 | } 99 | min_FPR_FNR <- torch::torch_minimum(uniq_fp_after[1:-2], uniq_fn_before[2:N]) 100 | constant_range_gr <- uniq_thresh_gr$diff() # range splits leading to {FPR, FNR } errors (see roc_aum row) 101 | torch::torch_sum(min_FPR_FNR * constant_range_gr, dim = 1) 102 | 103 | } 104 | ) 105 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # tabnet 0.8.0 2 | 3 | ## New features 4 | 5 | * messaging is now improved with {cli} 6 | * add optimal threshold and support size into new 1.5 alpha `entmax15()` and `sparsemax15()` 7 | `mask_types`. 
Add an optional `mask_topk` config parameter. (#180) 8 | * `optimizer` now defaults to `torch_ignite_adam` when available. 9 | This results in 30% faster pretraining and fitting tasks (#178). 10 | * add `nn_aum_loss()` function optimizing the area under the $Min(FPR,FNR)$ curve for cases of 11 | unbalanced binary classification (#178). 12 | * add a vignette on imbalanced binary classification with `nn_aum_loss()` (#178). 13 | 14 | ## Bugfixes 15 | 16 | * config parameters now merge correctly for torch loss or torch optimizer generators. 17 | * `nn_unsupervised_loss()` is now a proper loss function. 18 | 19 | # tabnet 0.7.0 20 | 21 | ## Bugfixes 22 | 23 | * Remove long-running example raising a Note. 24 | * fix `tabnet_pretrain` failing with `value_error("Can't convert data of class: 'NULL'")` in R 4.5 25 | * fix `tabnet_pretrain` wrongly used instead of `tabnet_fit` in the Missing data predictor vignette 26 | * improve message related to case_weights not being used as predictors. 27 | * improve function documentation consistency before translation. 28 | * fix the "'...' is not an exported object from 'namespace:dials'" error when using tune() on tabnet parameters. (#160 @cphaarmeyer) 29 | 30 | # tabnet 0.6.0 31 | 32 | ## New features 33 | 34 | * parsnip models now allow transparently passing case weights through `workflows::add_case_weights()` parameters (#151) 35 | * parsnip models now support `tabnet_model` and `from_epoch` parameters (#143) 36 | 37 | ## Bugfixes 38 | 39 | * Adapt `tune::finalize_workflow()` test to {parsnip} v1.2 breaking change. (#155) 40 | * `autoplot()` now positions the "has_checkpoint" points correctly when a `tabnet_fit()` is continuing a previous training using `tabnet_model =`. (#150) 41 | * Explicitly warn that the `tabnet_model` option will not be used in `tabnet_pretrain()` tasks. (#150) 42 | 43 | # tabnet 0.5.0 44 | 45 | ## New features 46 | 47 | * {tabnet} now allows hierarchical multi-label classification through a {data.tree} hierarchical `Node` dataset. (#126) 48 | * `tabnet_pretrain()` now allows different GLU blocks in the encoder and decoder GLU layers through the `tabnet_config()` parameters `num_independent_decoder` and `num_shared_decoder` (#129) 49 | * Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120) 50 | * use zeallot internally with %<-% for code readability (#133) 51 | * add FR translation (#131) 52 | 53 | # tabnet 0.4.0 54 | 55 | ## New features 56 | 57 | * Add explicit legend in `autoplot.tabnet_fit()` (#67) 58 | * Improve unsupervised vignette content. (#67) 59 | * `tabnet_pretrain()` now allows missing values in predictors. (#68) 60 | * `tabnet_explain()` now works for `tabnet_pretrain` models. (#68) 61 | * Allow missing values in predictors for unsupervised training. (#68) 62 | * Improve performance of `random_obfuscator()` torch_nn module. (#68) 63 | * Add support for early stopping (#69) 64 | * `tabnet_fit()` and `predict()` now allow **missing values** in predictors. (#76) 65 | * `tabnet_config()` now supports a `num_workers=` parameter to control parallel dataloading (#83) 66 | * Add a vignette on missing data (#83) 67 | * `tabnet_config()` now has a flag `skip_importance` to skip calculating feature importance (@egillax, #91) 68 | * Export and document `tabnet_nn` 69 | * Added `min_grid.tabnet` method for `tune` (@cphaarmeyer, #107) 70 | * Added `tabnet_explain()` method for parsnip models (@cphaarmeyer, #108) 71 | * `tabnet_fit()` and `predict()` now allow **multi-outcome**, all numeric or all factors but not mixed. 
(#118) 72 | 73 | ## Bugfixes 74 | 75 | * `tabnet_explain()` is now correctly handling missing values in predictors. (#77) 76 | * `dataloader` can now use `num_workers>0` (#83) 77 | * new default values for `batch_size` and `virtual_batch_size` improves performance on mid-range devices. 78 | * add default `engine="torch"` to tabnet parsnip model (#114) 79 | * fix `autoplot()` warnings turned into errors with {ggplot2} v3.4 (#113) 80 | 81 | 82 | # tabnet 0.3.0 83 | 84 | * Added an `update` method for tabnet models to allow the correct usage of `finalize_workflow` (#60). 85 | 86 | # tabnet 0.2.0 87 | 88 | ## New features 89 | 90 | * Allow model fine-tuning through passing a pre-trained model to `tabnet_fit()` (@cregouby, #26) 91 | * Explicit error in case of missing values (@cregouby, #24) 92 | * Better handling of larger datasets when running `tabnet_explain()`. 93 | * Add `tabnet_pretrain()` for unsupervised pretraining (@cregouby, #29) 94 | * Add `autoplot()` of model loss among epochs (@cregouby, #36) 95 | * Added a `config` argument to `fit() / pretrain()` so one can pass a pre-made config list. (#42) 96 | * In `tabnet_config()`, new `mask_type` option with `entmax` additional to default `sparsemax` (@cmcmaster1, #48) 97 | * In `tabnet_config()`, `loss` now also takes function (@cregouby, #55) 98 | 99 | ## Bugfixes 100 | 101 | * Fixed bug in GPU training. (#22) 102 | * Fixed memory leaks when using custom autograd function. 103 | * Batch predictions to avoid OOM error. 104 | 105 | ## Internal improvements 106 | 107 | * Added GPU CI. (#22) 108 | 109 | # tabnet 0.1.0 110 | 111 | * Added a `NEWS.md` file to track changes to the package. 112 | -------------------------------------------------------------------------------- /tests/testthat/test-hardhat_interfaces.R: -------------------------------------------------------------------------------- 1 | test_that("Training regression for data.frame and formula", { 2 | 3 | expect_no_error( 4 | fit <- tabnet_fit(x, y, epochs = 1) 5 | ) 6 | 7 | expect_no_error( 8 | fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1) 9 | ) 10 | 11 | expect_no_error( 12 | predict(fit, x) 13 | ) 14 | 15 | expect_no_error( 16 | fit <- tabnet_fit(x, y, epochs = 2, verbose = TRUE) 17 | ) 18 | }) 19 | 20 | test_that("Training classification for data.frame", { 21 | 22 | expect_no_error( 23 | fit <- tabnet_fit(attrix, attriy, epochs = 1) 24 | ) 25 | 26 | expect_no_error( 27 | predict(fit, attrix, type = "prob") 28 | ) 29 | 30 | expect_no_error( 31 | predict(fit, attrix) 32 | ) 33 | 34 | }) 35 | 36 | test_that("works with validation split", { 37 | 38 | expect_no_error( 39 | fit <- tabnet_fit(attrix, attriy, epochs = 1, valid_split = 0.5) 40 | ) 41 | 42 | expect_no_error( 43 | fit <- tabnet_fit(attrix, attriy, epochs = 1, valid_split = 0.5, verbose = TRUE) 44 | ) 45 | 46 | }) 47 | 48 | 49 | test_that("can train from a recipe", { 50 | 51 | rec <- recipe(Attrition ~ ., data = attrition) %>% 52 | step_normalize(all_numeric(), -all_outcomes()) 53 | 54 | expect_no_error( 55 | fit <- tabnet_fit(rec, attrition[1:256,], epochs = 1, valid_split = 0.25, 56 | verbose = TRUE) 57 | ) 58 | 59 | expect_no_error( 60 | predict(fit, attrition) 61 | ) 62 | 63 | }) 64 | 65 | test_that("serialization with saveRDS just works", { 66 | 67 | predictions <- predict(ames_fit, ames) 68 | 69 | tmp <- tempfile("model", fileext = "rds") 70 | withr::local_file(saveRDS(ames_fit, tmp)) 71 | 72 | # rm(fit) 73 | gc() 74 | 75 | fit2 <- readRDS(tmp) 76 | 77 | expect_equal( 78 | predictions, 79 | 
predict(fit2, ames) 80 | ) 81 | 82 | expect_equal(as.numeric(fit2$fit$network$.check), 1) 83 | 84 | }) 85 | 86 | test_that("checkpoints works for inference", { 87 | 88 | expect_no_error( 89 | fit <- tabnet_fit(x, y, epochs = 3, checkpoint_epochs = 1) 90 | ) 91 | 92 | expect_no_error( 93 | p1 <- predict(fit, x, epoch = 1) 94 | ) 95 | 96 | expect_no_error( 97 | p2 <- predict(fit, x, epoch = 2) 98 | ) 99 | 100 | expect_no_error( 101 | p3 <- predict(fit, x, epoch = 3) 102 | ) 103 | 104 | expect_equal(p3, predict(fit, x)) 105 | 106 | }) 107 | 108 | test_that("print module works even after a reload from disk", { 109 | 110 | testthat::skip_on_os("linux") 111 | testthat::skip_on_os("windows") 112 | 113 | withr::with_options(new = c(cli.width = 50), 114 | expect_snapshot_output(ames_fit)) 115 | 116 | tmp <- tempfile("model", fileext = "rds") 117 | withr::local_file(saveRDS(ames_fit, tmp)) 118 | fit2 <- readRDS(tmp) 119 | 120 | withr::with_options(new = c(cli.width = 50), 121 | expect_snapshot_output(fit2)) 122 | 123 | }) 124 | 125 | 126 | test_that("num_workers works for pretrain, fit an predict", { 127 | 128 | expect_no_error( 129 | tabnet_pretrain(x, y, epochs = 1, num_workers=1L, 130 | batch_size=128, virtual_batch_size=64) 131 | ) 132 | expect_no_error( 133 | tabnet_pretrain(x, y, epochs = 1, num_workers=1L, valid_split=0.2, 134 | batch_size=128, virtual_batch_size=64) 135 | ) 136 | 137 | expect_no_error( 138 | tabnet_fit(x, y, epochs = 1, num_workers=1L, 139 | batch_size=128, virtual_batch_size=64) 140 | ) 141 | 142 | expect_no_error( 143 | tabnet_fit(x, y, epochs = 1, num_workers=1L, valid_split=0.2, 144 | batch_size=128, virtual_batch_size=64) 145 | ) 146 | 147 | expect_no_error( 148 | predict(ames_fit, x, num_workers=1L, 149 | batch_size=128, virtual_batch_size=64) 150 | ) 151 | 152 | }) 153 | 154 | 155 | test_that("we can prune head of tabnet pretrain and tabnet fit models", { 156 | 157 | expect_no_error(pruned_pretrain <- torch::nn_prune_head(ames_pretrain, 1)) 158 | test_that("decoder has been removed from the list of modules", { 159 | expect_equal(all(stringr::str_detect("decoder", names(pruned_pretrain$children))),FALSE) 160 | }) 161 | 162 | 163 | expect_no_error(pruned_fit <- torch::nn_prune_head(ames_fit, 1)) 164 | test_that("decoder has been removed from the list of modules", { 165 | expect_equal(all(stringr::str_detect("final_mapping", names(pruned_pretrain$children))),FALSE) 166 | }) 167 | 168 | 169 | }) 170 | 171 | test_that("we can prune head of restored models from disk", { 172 | testthat::skip_on_os("linux") 173 | testthat::skip_on_os("windows") 174 | 175 | tmp <- tempfile("model", fileext = "rds") 176 | withr::local_file(saveRDS(ames_pretrain, tmp)) 177 | ames_pretrain2 <- readRDS(tmp) 178 | expect_no_error(pruned_pretrain <- torch::nn_prune_head(ames_pretrain2, 1)) 179 | test_that("decoder has been removed from the list of modules", { 180 | expect_equal(all(stringr::str_detect("decoder", names(pruned_pretrain$children))),FALSE) 181 | }) 182 | 183 | 184 | tmp <- tempfile("model", fileext = "rds") 185 | withr::local_file(saveRDS(ames_fit, tmp)) 186 | ames_fit2 <- readRDS(tmp) 187 | expect_no_error(pruned_fit <- torch::nn_prune_head(ames_fit2, 1)) 188 | test_that("decoder has been removed from the list of modules", { 189 | expect_equal(all(stringr::str_detect("final_mapping", names(pruned_pretrain$children))),FALSE) 190 | }) 191 | 192 | }) 193 | -------------------------------------------------------------------------------- /tests/testthat/test-explain.R: 
-------------------------------------------------------------------------------- 1 | test_that("explain provides correct result with data.frame", { 2 | 3 | set.seed(2022) 4 | torch::torch_manual_seed(2022) 5 | 6 | n <- 2000 7 | x <- data.frame( 8 | x = rnorm(n), 9 | y = rnorm(n), 10 | z = rnorm(n) 11 | ) 12 | 13 | y <- x$x 14 | 15 | fit <- tabnet_fit(x, y, epochs = 15, 16 | num_steps = 1, 17 | batch_size = 512, 18 | attention_width = 1, 19 | num_shared = 1, 20 | num_independent = 1) 21 | 22 | expect_equal(which.max(fit$fit$importances$importance), 1) 23 | expect_equal(fit$fit$importances$variables, colnames(x)) 24 | 25 | ex <- tabnet_explain(fit, x) 26 | 27 | expect_length(ex, 2) 28 | expect_length(ex[[2]], 1) 29 | expect_equal(nrow(ex[[1]]), nrow(x)) 30 | expect_equal(nrow(ex[[2]][[1]]), nrow(x)) 31 | 32 | }) 33 | 34 | test_that("explain works for dataframe, formula and recipe", { 35 | 36 | # data.frame, regression 37 | expect_no_error( 38 | tabnet_explain(ames_pretrain_vsplit, new_data=small_ames) 39 | ) 40 | 41 | expect_no_error( 42 | tabnet_explain(ames_fit_vsplit, new_data=small_ames) 43 | ) 44 | 45 | # data.frame, classification 46 | expect_no_error( 47 | tabnet_explain(attr_pretrained_vsplit, attrix) 48 | ) 49 | expect_no_error( 50 | tabnet_explain(attr_fitted_vsplit, attrix) 51 | ) 52 | 53 | 54 | # formula 55 | tabnet_pretrain <- tabnet_pretrain(Sale_Price ~., data=small_ames, epochs = 3, valid_split=.2, 56 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 57 | expect_no_error( 58 | tabnet_explain(tabnet_pretrain, new_data=small_ames) 59 | ) 60 | 61 | tabnet_fit <- tabnet_fit(Sale_Price ~., data=small_ames, tabnet_model=tabnet_pretrain, epochs = 3, 62 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 63 | expect_no_error( 64 | tabnet_explain(tabnet_fit, new_data=small_ames) 65 | ) 66 | 67 | # recipe 68 | rec <- recipe(Sale_Price ~., data = small_ames) %>% 69 | step_zv(all_predictors()) %>% 70 | step_normalize(all_numeric_predictors()) 71 | 72 | tabnet_pretrain <- tabnet_pretrain(rec, data=small_ames, epochs = 3, valid_split=.2, 73 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 74 | expect_no_error( 75 | tabnet_explain(tabnet_pretrain, new_data=small_ames) 76 | ) 77 | 78 | tabnet_fit <- tabnet_fit(rec, data=small_ames, tabnet_model=tabnet_pretrain, epochs = 3, 79 | num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1) 80 | expect_no_error( 81 | tabnet_explain(tabnet_fit, new_data=small_ames) 82 | ) 83 | }) 84 | 85 | test_that("support for vip on tabnet_fit and tabnet_pretrain", { 86 | 87 | skip_if_not_installed("vip") 88 | 89 | n <- 1000 90 | x <- data.frame( 91 | x = runif(n), 92 | y = runif(n), 93 | z = runif(n) 94 | ) 95 | 96 | y <- x$x 97 | 98 | pretrain <- tabnet_pretrain(x, y, epochs = 1, 99 | num_steps = 1, 100 | batch_size = 512, 101 | attention_width = 1, 102 | num_shared = 1, 103 | num_independent = 1) 104 | 105 | fit <- tabnet_fit(x, y, epochs = 1, 106 | num_steps = 1, 107 | batch_size = 512, 108 | attention_width = 1, 109 | num_shared = 1, 110 | num_independent = 1) 111 | 112 | expect_no_error(vip::vip(pretrain)) 113 | expect_no_error(vip::vip(fit)) 114 | 115 | }) 116 | 117 | 118 | test_that("Importance is skipped if skip_importance flag is used", { 119 | 120 | set.seed(2022) 121 | torch::torch_manual_seed(2022) 122 | 123 | n <- 1000 124 | x <- data.frame( 125 | x = rnorm(n), 126 | y = rnorm(n), 127 | z = rnorm(n) 128 | ) 129 | 130 | y <- x$x 131 | 132 | fit <- tabnet_fit(x, 
y, epochs = 15, 133 | num_steps = 1, 134 | batch_size = 512, 135 | attention_width = 1, 136 | num_shared = 1, 137 | num_independent = 1, 138 | skip_importance = TRUE) 139 | 140 | expect_equal(fit$fit$importances, NULL) 141 | 142 | fit <- tabnet_fit(x, y, epochs = 15, 143 | num_steps = 1, 144 | batch_size = 512, 145 | attention_width = 1, 146 | num_shared = 1, 147 | num_independent = 1, 148 | skip_importance = FALSE) 149 | 150 | 151 | expect_equal(which.max(fit$fit$importances$importance), 1) 152 | expect_equal(fit$fit$importances$variables, colnames(x)) 153 | 154 | }) 155 | 156 | test_that("explain works for parsnip model", { 157 | 158 | model <- tabnet() %>% 159 | parsnip::set_mode("regression") %>% 160 | parsnip::set_engine("torch") 161 | fit <- model %>% 162 | parsnip::fit(Sale_Price ~ ., data = small_ames) 163 | 164 | expect_no_error( 165 | tabnet_explain(fit, new_data = small_ames), 166 | ) 167 | 168 | }) 169 | 170 | test_that("explain works for multi-outcome classification model", { 171 | 172 | fit <- tabnet_fit(x, data.frame(y = y, z = y + 1), epochs = 1) 173 | 174 | expect_no_error(tabnet_explain(fit, new_data = x)) 175 | 176 | }) 177 | -------------------------------------------------------------------------------- /po/R-tabnet.pot: -------------------------------------------------------------------------------- 1 | msgid "" 2 | msgstr "" 3 | "Project-Id-Version: tabnet 0.7.0.9000\n" 4 | "POT-Creation-Date: 2025-05-12 00:22+0200\n" 5 | "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" 6 | "Last-Translator: FULL NAME \n" 7 | "Language-Team: LANGUAGE \n" 8 | "Language: \n" 9 | "MIME-Version: 1.0\n" 10 | "Content-Type: text/plain; charset=UTF-8\n" 11 | "Content-Transfer-Encoding: 8bit\n" 12 | 13 | #: dials.R:3 14 | msgid "" 15 | "Package {.pkg dials} is needed for this function to work. Please install it." 16 | msgstr "" 17 | 18 | #: dials.R:142 19 | msgid "{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet." 20 | msgstr "" 21 | 22 | #: explain.R:47 23 | msgid "{.fn tabnet_explain} is not defined for a {.type {class(object)[1]}}." 24 | msgstr "" 25 | 26 | #: hardhat.R:111 27 | msgid "{.fn tabnet_fit} is not defined for a {.type {class(x)[1])}}." 28 | msgstr "" 29 | 30 | #: hardhat.R:268 31 | msgid "{.fn tabnet_pretrain} is not defined for a {.type {class(x)[1])}}." 32 | msgstr "" 33 | 34 | #: hardhat.R:347 35 | msgid "{.var {tabnet_model}} is not recognised as a proper TabNet model" 36 | msgstr "" 37 | 38 | #: hardhat.R:353 39 | msgid "The model was trained for less than {.val {from_epoch}} epochs" 40 | msgstr "" 41 | 42 | #: hardhat.R:365 43 | msgid "Found missing values in the {.var {names(outcomes)}} outcome column." 44 | msgstr "" 45 | 46 | #: hardhat.R:375 hardhat.R:484 47 | msgid "Model dimensions don't match." 48 | msgstr "" 49 | 50 | #: hardhat.R:400 51 | msgid "" 52 | "No model serialized weight can be found in {.var {tabnet_model}}, check the " 53 | "model history" 54 | msgstr "" 55 | 56 | #: hardhat.R:408 57 | msgid "Using {.fn tabnet_pretrain} from a model is not currently supported." 
58 | msgstr "" 59 | 60 | #: hardhat.R:409 61 | msgid "Pretraining will start from a new network initialization" 62 | msgstr "" 63 | 64 | #: hardhat.R:446 65 | msgid "The model was trained for less than {.val {epoch}} epochs" 66 | msgstr "" 67 | 68 | #: hardhat.R:522 69 | msgid "" 70 | "Mixed multi-outcome type {.type {unique(purrr::map_chr(outcome_ptype, " 71 | "~class(.x)[[1]]))}} is not supported" 72 | msgstr "" 73 | 74 | #: hardhat.R:530 75 | msgid "Unknown outcome type {.type {class(outcome_ptype)}}" 76 | msgstr "" 77 | 78 | #: hardhat.R:537 79 | msgid "Outcome is factor and the prediction type is {.type {type}}." 80 | msgstr "" 81 | 82 | #: hardhat.R:540 83 | msgid "Outcome is numeric and the prediction type is {.type {type}}." 84 | msgstr "" 85 | 86 | #: model.R:256 87 | msgid "{.val {loss}} is not a valid loss for outcome of type {.type {dtype}}" 88 | msgstr "" 89 | 90 | #: model.R:267 91 | msgid "" 92 | "{.val {early_stopping_monitor}} is not a valid early-stopping metric to " 93 | "monitor with {.val valid_split = {valid_split}}" 94 | msgstr "" 95 | 96 | #: model.R:523 pretraining.R:121 97 | msgid "{.var optimizer} must be resolved into a torch optimizer generator." 98 | msgstr "" 99 | 100 | #: model.R:536 pretraining.R:135 101 | msgid "" 102 | "Currently only the {.str step} and {.str reduce_on_plateau} scheduler are " 103 | "supported." 104 | msgstr "" 105 | 106 | #: model.R:583 pretraining.R:182 107 | #, c-format 108 | msgid "[Epoch %03d] Loss: %3f" 109 | msgstr "" 110 | 111 | #: model.R:585 pretraining.R:184 112 | #, c-format 113 | msgid "[Epoch %03d] Loss: %3f, Valid loss: %3f" 114 | msgstr "" 115 | 116 | #: model.R:601 pretraining.R:199 117 | #, c-format 118 | msgid "Early-stopping at epoch {.val epoch}" 119 | msgstr "" 120 | 121 | #: model.R:626 pretraining.R:224 122 | msgid "" 123 | "Computing importances for a dataset with size {.val {train_ds$.length()}}. \n" 124 | " This can consume too much memory. We are going to use a sample of " 125 | "size 1e5. \n" 126 | " You can disable this message by using the " 127 | "`importance_sample_size` argument." 128 | msgstr "" 129 | 130 | #: parsnip.R:474 131 | msgid "" 132 | "Package {.pkg parsnip} is needed for this function to work. Please install " 133 | "it." 134 | msgstr "" 135 | 136 | #: parsnip.R:559 137 | msgid "parsnip" 138 | msgstr "" 139 | 140 | #: tab-network.R:238 tab-network.R:405 141 | msgid "{.var n_steps} should be a positive integer." 142 | msgstr "" 143 | 144 | #: tab-network.R:240 tab-network.R:407 145 | msgid "{.var n_shared} and {.var n_independant} can't be both zero." 146 | msgstr "" 147 | 148 | #: tab-network.R:463 149 | msgid "" 150 | "Please choose either {.val sparsemax}, {.val sparsemax15}, {.val entmax} or " 151 | "{.val entmax15} as {.var mask_type}" 152 | msgstr "" 153 | 154 | #: tab-network.R:627 155 | msgid "" 156 | "{.var cat_emb_dim} length must be 1 or the number of categorical predictors, \n" 157 | " got length {.val {length(self$cat_emb_dims)}} for {.val " 158 | "{length(cat_dims)}} \n" 159 | " categorical predictors" 160 | msgstr "" 161 | 162 | #: utils.R:68 163 | msgid "" 164 | "The provided hierarchical object is not recognized with a valid format that " 165 | "can be checked" 166 | msgstr "" 167 | 168 | #: utils.R:72 169 | msgid "" 170 | "The attributes or colnames in the provided hierarchical object use the " 171 | "following reserved names:\n" 172 | " {.vars {actual_names[actual_names %in% reserved_names]}}. 
\n" 173 | " Please change those names as they will lead to unexpected " 174 | "tabnet behavior." 175 | msgstr "" 176 | 177 | #: utils.R:154 utils.R:157 178 | msgid "" 179 | "Currently only {.val adam} is supported as character for {.var optimizer}." 180 | msgstr "" 181 | -------------------------------------------------------------------------------- /R/dials.R: -------------------------------------------------------------------------------- 1 | check_dials <- function() { 2 | if (!requireNamespace("dials", quietly = TRUE)) 3 | runtime_error("Package {.pkg dials} is needed for this function to work. Please install it.") 4 | } 5 | 6 | 7 | #' Parameters for the tabnet model 8 | #' 9 | #' @param range the default range for the parameter value 10 | #' @param trans whether to apply a transformation to the parameter 11 | #' @param values possible values for factor parameters 12 | #' 13 | #' These functions are used with `tune` grid functions to generate 14 | #' candidates. 15 | #' 16 | #' @rdname tabnet_params 17 | #' @return A `dials` parameter to be used when tuning TabNet models. 18 | #' @export 19 | #' @examplesIf (require("dials") && require("parsnip") && torch::torch_is_installed()) 20 | #' model <- tabnet(attention_width = tune(), feature_reusage = tune(), 21 | #' momentum = tune(), penalty = tune(), rate_step_size = tune()) %>% 22 | #' parsnip::set_mode("regression") %>% 23 | #' parsnip::set_engine("torch") 24 | #' 25 | attention_width <- function(range = c(8L, 64L), trans = NULL) { 26 | check_dials() 27 | dials::new_quant_param( 28 | type = "integer", 29 | range = range, 30 | inclusive = c(TRUE, TRUE), 31 | trans = trans, 32 | label = c(attention_width = "Width of the attention embedding for each mask"), 33 | finalize = NULL 34 | ) 35 | } 36 | 37 | #' @rdname tabnet_params 38 | #' @export 39 | decision_width <- function(range = c(8L, 64L), trans = NULL) { 40 | check_dials() 41 | dials::new_quant_param( 42 | type = "integer", 43 | range = range, 44 | inclusive = c(TRUE, TRUE), 45 | trans = trans, 46 | label = c(decision_width = "Width of the decision prediction layer"), 47 | finalize = NULL 48 | ) 49 | } 50 | 51 | 52 | #' @rdname tabnet_params 53 | #' @export 54 | feature_reusage <- function(range = c(1, 2), trans = NULL) { 55 | check_dials() 56 | dials::new_quant_param( 57 | type = "double", 58 | range = range, 59 | inclusive = c(TRUE, TRUE), 60 | trans = trans, 61 | label = c(feature_reusage = "Coefficient for feature reusage in the masks"), 62 | finalize = NULL 63 | ) 64 | } 65 | 66 | #' @rdname tabnet_params 67 | #' @export 68 | momentum <- function(range = c(0.01, 0.4), trans = NULL) { 69 | check_dials() 70 | dials::new_quant_param( 71 | type = "double", 72 | range = range, 73 | inclusive = c(TRUE, TRUE), 74 | trans = trans, 75 | label = c(momentum = "Momentum for batch normalization"), 76 | finalize = NULL 77 | ) 78 | } 79 | 80 | 81 | #' @rdname tabnet_params 82 | #' @export 83 | mask_type <- function(values = c("sparsemax", "entmax")) { 84 | check_dials() 85 | dials::new_qual_param( 86 | type = "character", 87 | values = values, 88 | label = c(mask_type = "Final layer of feature selector, either 'sparsemax' or 'entmax'"), 89 | finalize = NULL 90 | ) 91 | } 92 | 93 | #' @rdname tabnet_params 94 | #' @export 95 | num_independent <- function(range = c(1L, 5L), trans = NULL) { 96 | check_dials() 97 | dials::new_quant_param( 98 | type = "integer", 99 | range = range, 100 | inclusive = c(TRUE, TRUE), 101 | trans = trans, 102 | label = c(num_independent = "Number of independent Gated 
Linear Units layers at each step"), 103 | finalize = NULL 104 | ) 105 | } 106 | 107 | #' @rdname tabnet_params 108 | #' @export 109 | num_shared <- function(range = c(1L, 5L), trans = NULL) { 110 | check_dials() 111 | dials::new_quant_param( 112 | type = "integer", 113 | range = range, 114 | inclusive = c(TRUE, TRUE), 115 | trans = trans, 116 | label = c(num_shared = "Number of shared Gated Linear Units at each step"), 117 | finalize = NULL 118 | ) 119 | } 120 | 121 | #' @rdname tabnet_params 122 | #' @export 123 | num_steps <- function(range = c(3L, 10L), trans = NULL) { 124 | check_dials() 125 | dials::new_quant_param( 126 | type = "integer", 127 | range = range, 128 | inclusive = c(TRUE, TRUE), 129 | trans = trans, 130 | label = c(num_steps = "Number of steps in the architecture"), 131 | finalize = NULL 132 | ) 133 | } 134 | 135 | #' Non-tunable parameters for the tabnet model 136 | #' 137 | #' @param range unused 138 | #' @param trans unused 139 | #' @rdname tabnet_non_tunable 140 | #' @export 141 | cat_emb_dim <- function(range = NULL, trans = NULL) { 142 | cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.") 143 | } 144 | 145 | #' @rdname tabnet_non_tunable 146 | #' @export 147 | checkpoint_epochs <- cat_emb_dim 148 | 149 | #' @rdname tabnet_non_tunable 150 | #' @export 151 | drop_last <- cat_emb_dim 152 | 153 | #' @rdname tabnet_non_tunable 154 | #' @export 155 | encoder_activation <- cat_emb_dim 156 | 157 | #' @rdname tabnet_non_tunable 158 | #' @export 159 | lr_scheduler <- cat_emb_dim 160 | 161 | #' @rdname tabnet_non_tunable 162 | #' @export 163 | mlp_activation <- cat_emb_dim 164 | 165 | #' @rdname tabnet_non_tunable 166 | #' @export 167 | mlp_hidden_multiplier <- cat_emb_dim 168 | 169 | #' @rdname tabnet_non_tunable 170 | #' @export 171 | num_independent_decoder <- cat_emb_dim 172 | 173 | #' @rdname tabnet_non_tunable 174 | #' @export 175 | num_shared_decoder <- cat_emb_dim 176 | 177 | #' @rdname tabnet_non_tunable 178 | #' @export 179 | optimizer <- cat_emb_dim 180 | 181 | #' @rdname tabnet_non_tunable 182 | #' @export 183 | penalty <- cat_emb_dim 184 | 185 | #' @rdname tabnet_non_tunable 186 | #' @export 187 | verbose <- cat_emb_dim 188 | 189 | #' @rdname tabnet_non_tunable 190 | #' @export 191 | virtual_batch_size <- cat_emb_dim 192 | -------------------------------------------------------------------------------- /tests/testthat/test-mask-type.R: -------------------------------------------------------------------------------- 1 | 2 | test_that(".sparsemax_threshold_and_support works as expected with default values", { 3 | input <- torch::torch_randn(10, 5) 4 | expect_no_error( 5 | result <- tabnet:::.sparsemax_threshold_and_support(input) 6 | ) 7 | expect_type(result, "list") 8 | expect_length(result, 2) 9 | 10 | tau <- result[[1]] 11 | support_size <- result[[2]] 12 | expect_tensor(tau) 13 | expect_tensor_shape(tau, c(input$shape[1], 1)) 14 | expect_tensor(support_size) 15 | }) 16 | 17 | 18 | test_that(".sparsemax_threshold_and_support works as expected with k < input$size(dim)", { 19 | input <- torch::torch_randn(10, 5) 20 | dim <- 1L 21 | k <- 3 22 | expect_no_error( 23 | result <- tabnet:::.sparsemax_threshold_and_support(input, dim, k) 24 | ) 25 | expect_type(result, "list") 26 | expect_length(result, 2) 27 | 28 | tau <- result[[1]] 29 | support_size <- result[[2]] 30 | expect_tensor(tau) 31 | expect_tensor_shape(tau, c(1, input$shape[2])) 32 | expect_tensor(support_size) 33 | expect_tensor_shape(support_size, c(1, input$shape[2])) 34 | 
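# note: reducing over dim = 1 collapses the row dimension, so tau and support_size keep one value per column, which is why the c(1, ncol) shapes are expected above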
35 | }) 36 | 37 | test_that(".sparsemax_threshold_and_support works as expected with k >= input$size(dim)", { 38 | input <- torch::torch_randn(10, 5) 39 | dim <- -2L 40 | k <- 7 41 | expect_no_error( 42 | result <- tabnet:::.sparsemax_threshold_and_support(input, dim, k) 43 | ) 44 | expect_type(result, "list") 45 | expect_length(result, 2) 46 | 47 | tau <- result[[1]] 48 | support_size <- result[[2]] 49 | expect_tensor(tau) 50 | expect_tensor_shape(tau, c(1, input$shape[2])) 51 | expect_tensor(support_size) 52 | expect_tensor_shape(support_size, c(1, input$shape[2])) 53 | }) 54 | 55 | 56 | 57 | test_that(".entmax_threshold_and_support works as expected with default values", { 58 | input <- torch::torch_randn(10, 5) 59 | expect_no_error( 60 | result <- tabnet:::.entmax_threshold_and_support(input) 61 | ) 62 | expect_type(result, "list") 63 | expect_length(result, 2) 64 | 65 | tau_star <- result[[1]] 66 | support_size <- result[[2]] 67 | expect_tensor(tau_star) 68 | expect_tensor_shape(tau_star, c(input$shape[1], 1)) 69 | expect_tensor(support_size) 70 | expect_tensor_shape(support_size, c(input$shape[1], 1)) 71 | }) 72 | 73 | 74 | test_that(".entmax_threshold_and_support works as expected with k < input$size(dim)", { 75 | input <- torch::torch_randn(10, 5) 76 | dim <- 1L 77 | k <- 3 78 | expect_no_error( 79 | result <- tabnet:::.entmax_threshold_and_support(input, dim, k) 80 | ) 81 | expect_type(result, "list") 82 | expect_length(result, 2) 83 | 84 | tau_star <- result[[1]] 85 | support_size <- result[[2]] 86 | expect_tensor(tau_star) 87 | expect_tensor_shape(tau_star, c(1, input$shape[2])) 88 | expect_tensor(support_size) 89 | expect_tensor_shape(support_size, c(1, input$shape[2])) 90 | }) 91 | 92 | 93 | test_that(".entmax_threshold_and_support works as expected with k >= input$size(dim)", { 94 | input <- torch::torch_randn(10, 5) 95 | dim <- 2L 96 | k <- 12 97 | expect_no_error( 98 | result <- tabnet:::.entmax_threshold_and_support(input, dim, k) 99 | ) 100 | expect_type(result, "list") 101 | expect_length(result, 2) 102 | 103 | tau_star <- result[[1]] 104 | support_size <- result[[2]] 105 | expect_tensor(tau_star) 106 | expect_tensor_shape(tau_star, c(input$shape[1], 1)) 107 | expect_tensor(support_size) 108 | expect_tensor_shape(support_size, c(input$shape[1], 1)) 109 | }) 110 | 111 | 112 | test_that("fit works with entmax mask-type", { 113 | 114 | rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% 115 | step_normalize(all_numeric(), -all_outcomes()) 116 | 117 | expect_no_error( 118 | tabnet_fit(rec, attrition, epochs = 1, valid_split = 0.25, verbose = TRUE, 119 | config = tabnet_config( mask_type = "entmax")) 120 | ) 121 | }) 122 | 123 | 124 | test_that("sparsemax_15_function works as a proper autograd", { 125 | 126 | input = torch::torch_rand(10,2, requires_grad = TRUE) 127 | 128 | expect_no_error( 129 | output <- tabnet:::sparsemax_function(input, 2L, 3) 130 | ) 131 | expect_no_error( 132 | output$backward 133 | ) 134 | }) 135 | 136 | test_that("entmax_15_function works as a proper autograd", { 137 | 138 | input = torch::torch_rand(10,2, requires_grad = TRUE) 139 | 140 | expect_no_error( 141 | output <- tabnet:::entmax_15_function(input, 2L, 3) 142 | ) 143 | expect_no_error( 144 | output$backward 145 | ) 146 | }) 147 | 148 | 149 | test_that("fit works with sparsemax15 mask-type", { 150 | 151 | rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% 152 | step_normalize(all_numeric(), -all_outcomes()) 153 | 154 | expect_no_error( 155 | 
tabnet_fit(rec, attrition, epochs = 1, valid_split = 0.25, verbose = TRUE, 156 | config = tabnet_config( mask_type = "sparsemax15")) 157 | ) 158 | expect_no_error( 159 | tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, 160 | config = tabnet_config( mask_type = "sparsemax15", mask_topk = 12)) 161 | ) 162 | }) 163 | 164 | test_that("fit works with entmax15 mask-type", { 165 | 166 | rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% 167 | step_normalize(all_numeric(), -all_outcomes()) 168 | 169 | expect_no_error( 170 | tabnet_fit(rec, attrition, epochs = 1, valid_split = 0.25, verbose = TRUE, 171 | config = tabnet_config( mask_type = "entmax15")) 172 | ) 173 | expect_no_error( 174 | tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, 175 | config = tabnet_config( mask_type = "entmax15", mask_topk = 12)) 176 | ) 177 | }) 178 | -------------------------------------------------------------------------------- /R/plot.R: -------------------------------------------------------------------------------- 1 | #' Plot tabnet_fit model loss along epochs 2 | #' 3 | #' @param object A `tabnet_fit` or `tabnet_pretrain` object as a result of 4 | #' [tabnet_fit()] or [tabnet_pretrain()]. 5 | #' @param ... not used. 6 | #' @return A `ggplot` object. 7 | #' @details 8 | #' Plot the training loss along epochs, and validation loss along epochs if any. 9 | #' A dot is added on epochs where model snapshot is available, helping 10 | #' the choice of `from_epoch` value for later model training resume. 11 | #' 12 | #' @examplesIf (torch::torch_is_installed() && require("modeldata")) 13 | #' \dontrun{ 14 | #' library(ggplot2) 15 | #' data("attrition", package = "modeldata") 16 | #' attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, valid_split=0.2, epoch=11) 17 | #' 18 | #' # Plot the model loss over epochs 19 | #' autoplot(attrition_fit) 20 | #'} 21 | #' @importFrom rlang .data 22 | #' 23 | autoplot.tabnet_fit <- function(object, ...) { 24 | 25 | collect_metrics <- tibble::enframe(object$fit$metrics, name = "epoch") %>% 26 | tidyr::unnest_wider(value) %>% 27 | dplyr::mutate_if(is.list, ~purrr::map_dbl(.x, mean)) %>% 28 | dplyr::select_if(function(x) {!all(is.na(x))} ) %>% 29 | tidyr::pivot_longer(cols = !dplyr::matches("epoch|checkpoint"), 30 | names_to = "dataset", values_to = "loss") 31 | 32 | 33 | p <- ggplot2::ggplot(collect_metrics, ggplot2::aes(x = epoch, y = loss, color = dataset)) + 34 | ggplot2::geom_line() + 35 | ggplot2::scale_y_log10() + 36 | ggplot2::guides(colour = ggplot2::guide_legend("Dataset", order=1, override.aes = list(size = 1.7, shape = " ")), 37 | size = ggplot2::guide_legend("has checkpoint", order = 2, override.aes = list(size = 3, color = "#F8766D"), 38 | label.theme = ggplot2::element_text(colour = "#FFFFFF"))) + 39 | ggplot2::theme(legend.position = "bottom") + 40 | ggplot2::labs(y="Mean loss (log scale)") 41 | 42 | if ("checkpoint" %in% names(collect_metrics)) { 43 | checkpoints <- collect_metrics %>% 44 | dplyr::filter(checkpoint == TRUE, dataset == "train") %>% 45 | dplyr::select(-checkpoint) %>% 46 | dplyr::mutate(size = 2) 47 | p + 48 | ggplot2::geom_point(data = checkpoints, ggplot2::aes(x = epoch, y = loss, color = dataset, size = .data$size )) 49 | } else { 50 | p 51 | } 52 | } 53 | 54 | #' @rdname autoplot.tabnet_fit 55 | autoplot.tabnet_pretrain <- autoplot.tabnet_fit 56 | 57 | #' Plot tabnet_explain mask importance heatmap 58 | #' 59 | #' @param object A `tabnet_explain` object as a result of [tabnet_explain()]. 
60 | #' @param type a character value. Either `"mask_agg"`, the default, for a single
61 | #' heatmap of aggregated mask importance per predictor along the dataset,
62 | #' or `"steps"` for one heatmap at each mask step.
63 | #' @param quantile numerical value between 0 and 1. Provides quantile clipping of the
64 | #' mask values.
65 | #' @param ... not used.
66 | #' @return A `ggplot` object.
67 | #' @details
68 | #' Plot the `tabnet_explain` object mask importance per variable along the predicted dataset.
69 | #' `type="mask_agg"` outputs a single heatmap of aggregated mask values,
70 | #' `type="steps"` provides a plot faceted along the `n_steps` masks present in the model.
71 | #' `quantile=.995` may be used for strong outlier clipping, in order to better highlight
72 | #' low values. `quantile=1`, the default, does not clip any values.
73 | #'
74 | #' @examplesIf (torch::torch_is_installed() && require("modeldata"))
75 | #' \dontrun{
76 | #' library(ggplot2)
77 | #' data("attrition", package = "modeldata")
78 | #'
79 | #' ## Single-outcome binary classification of `Attrition` in `attrition` dataset
80 | #' attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, epochs=11)
81 | #' attrition_explain <- tabnet_explain(attrition_fit, attrition)
82 | #' # Plot the model aggregated mask interpretation heatmap
83 | #' autoplot(attrition_explain)
84 | #'
85 | #' ## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset,
86 | #' data("ames", package = "modeldata")
87 | #' x <- ames[,-which(names(ames) %in% c("Sale_Price", "Pool_Area"))]
88 | #' y <- ames[, c("Sale_Price", "Pool_Area")]
89 | #' ames_fit <- tabnet_fit(x, y, epochs = 1, verbose=TRUE)
90 | #' ames_explain <- tabnet_explain(ames_fit, x)
91 | #' autoplot(ames_explain, quantile = 0.99)
92 | #' }
93 | autoplot.tabnet_explain <- function(object, type = c("mask_agg", "steps"), quantile = 1, ...)
{ 94 | type <- match.arg(type) 95 | 96 | if (type == "steps") { 97 | .data <- object$masks %>% 98 | purrr::imap_dfr(~dplyr::mutate( 99 | .x, 100 | step = sprintf("Step %d", .y), 101 | rowname = dplyr::row_number() 102 | )) %>% 103 | tidyr::pivot_longer(-c(rowname, step), names_to = "variable", values_to = "mask_agg") %>% 104 | dplyr::group_by(step) %>% 105 | dplyr::mutate(mask_agg = quantile_clip(mask_agg, probs=quantile)) %>% 106 | dplyr::ungroup() 107 | } else { 108 | 109 | .data <- object$M_explain %>% 110 | dplyr::mutate(rowname = dplyr::row_number()) %>% 111 | tidyr::pivot_longer(-rowname, names_to = "variable", values_to = "mask_agg") %>% 112 | dplyr::mutate(mask_agg = quantile_clip(mask_agg, probs=quantile), 113 | step = "mask_aggregate") 114 | } 115 | 116 | p <- ggplot2::ggplot(.data, ggplot2::aes(x = rowname, y = variable, fill = mask_agg)) + 117 | ggplot2::geom_tile() + 118 | ggplot2::scale_fill_viridis_c() + 119 | ggplot2::facet_wrap(~step) + 120 | ggplot2::theme_minimal() 121 | p 122 | } 123 | 124 | quantile_clip <- function(x, probs) { 125 | quantile <- quantile(x, probs = probs) 126 | purrr::map_dbl(x, ~min(.x, quantile)) 127 | } 128 | -------------------------------------------------------------------------------- /tests/testthat/test-hardhat_multi-outcome.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("Training multi-output regression from data.frame", { 3 | 4 | expect_no_error( 5 | fit <- tabnet_fit(x, data.frame(y = y, z = y + 1), epochs = 1) 6 | ) 7 | 8 | expect_no_error( 9 | result <- predict(fit, x) 10 | ) 11 | expect_equal(ncol(result), 2) 12 | 13 | }) 14 | 15 | test_that("Training multi-output regression from formula", { 16 | 17 | expect_no_error( 18 | fit <- tabnet_fit(Sale_Price + Latitude + Longitude ~ ., small_ames, epochs = 1) 19 | ) 20 | 21 | expect_no_error( 22 | result <- predict(fit, ames) 23 | ) 24 | expect_equal(ncol(result), 3) 25 | 26 | }) 27 | 28 | test_that("Training multi-output regression from recipe", { 29 | 30 | rec <- recipe(Sale_Price + Latitude + Longitude ~ ., data = small_ames) %>% 31 | step_zv(all_predictors()) %>% 32 | step_normalize(all_numeric(), -all_outcomes()) 33 | 34 | expect_no_error( 35 | fit <- tabnet_fit(rec, small_ames, epochs = 1) 36 | ) 37 | 38 | expect_no_error( 39 | result <- predict(fit, ames) 40 | ) 41 | expect_equal(ncol(result), 3) 42 | 43 | }) 44 | 45 | test_that("Training multilabel classification from data.frame", { 46 | 47 | expect_no_error( 48 | fit <- tabnet_fit(attri_mult_x, data.frame(y = attriy, z = attriy, sat = attrix$JobSatisfaction), 49 | epochs = 1) 50 | ) 51 | 52 | expect_no_error( 53 | result <- predict(fit, attri_mult_x, type = "prob") 54 | ) 55 | 56 | expect_equal(ncol(result), 3) 57 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 58 | # we get back outcomes vars with a `.pred_` prefix 59 | expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) 60 | 61 | # result columns are tibbles of resp 2, 2, 4 columns 62 | expect_true(all(purrr::map_lgl(result, tibble::is_tibble))) 63 | expect_equal(purrr::map_dbl(result, ncol), outcome_nlevels, ignore_attr = TRUE) 64 | 65 | expect_no_error( 66 | result <- predict(fit, attri_mult_x) 67 | ) 68 | expect_equal(ncol(result), 3) 69 | expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(outcome_nlevels)) 70 | 71 | }) 72 | 73 | test_that("Training multilabel classification from formula", { 74 | 75 | expect_no_error( 76 | fit <- 
tabnet_fit(Attrition + JobSatisfaction ~ ., data = attrition[ids,], 77 | epochs = 1) 78 | ) 79 | 80 | expect_no_error( 81 | result <- predict(fit, attri_mult_x, type = "prob") 82 | ) 83 | 84 | expect_equal(ncol(result), 2) 85 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 86 | # we get back outcomes vars with a `.pred_` prefix 87 | expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) 88 | 89 | # result columns are tibbles of resp 2, 2, 4 columns 90 | expect_true(all(purrr::map_lgl(result, tibble::is_tibble))) 91 | expect_equal(purrr::map_dbl(result, ncol), outcome_nlevels, ignore_attr = TRUE) 92 | 93 | }) 94 | 95 | test_that("Training multilabel classification from recipe", { 96 | 97 | rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids,]) %>% 98 | step_zv(all_predictors()) %>% 99 | step_normalize(all_numeric(), -all_outcomes()) 100 | 101 | expect_no_error( 102 | fit <- tabnet_fit(rec, attrition[ids,], epochs = 1, valid_split = 0.25, 103 | verbose = TRUE) 104 | ) 105 | 106 | expect_no_error( 107 | result <- predict(fit, attrition) 108 | ) 109 | expect_equal(ncol(result), 2) 110 | 111 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 112 | expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(outcome_nlevels)) 113 | 114 | }) 115 | 116 | test_that("Training multilabel classification from data.frame with validation split", { 117 | 118 | expect_no_error( 119 | fit <- tabnet_fit(attri_mult_x, data.frame(y=attriy, z=attriy, sat=attrix$JobSatisfaction), 120 | valid_split = 0.2, epochs = 1) 121 | ) 122 | 123 | expect_no_error( 124 | result <- predict(fit, attri_mult_x, type = "prob") 125 | ) 126 | 127 | expect_equal(ncol(result), 3) 128 | 129 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 130 | # we get back outcomes vars with a `.pred_` prefix 131 | expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) 132 | 133 | # result columns are tibbles of resp 2, 2, 4 columns 134 | expect_true(all(purrr::map_lgl(result, tibble::is_tibble))) 135 | expect_equal(purrr::map_dbl(result, ncol), outcome_nlevels, ignore_attr = TRUE) 136 | 137 | expect_no_error( 138 | result <- predict(fit, attri_mult_x) 139 | ) 140 | expect_equal(ncol(result), 3) 141 | 142 | # we get back outcomes vars with a `.pred_class_` prefix 143 | expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(fit$blueprint$ptypes$outcomes)) 144 | }) 145 | 146 | 147 | test_that("Training multilabel mixed output fails with explicit error", { 148 | 149 | attri_multi_x <- attrix[-which(names(attrix) == "PercentSalaryHike")] 150 | expect_error( 151 | fit <- tabnet_fit(attri_multi_x, data.frame(y = attriy, hik = attrix$PercentSalaryHike), epochs = 1), 152 | "Mixed multi-outcome type" 153 | ) 154 | expect_error( 155 | fit <- tabnet_fit(Attrition + PercentSalaryHike ~ ., data = attrition[ids,], epochs = 1), 156 | "Mixed multi-outcome type" 157 | ) 158 | rec <- recipe(Attrition + PercentSalaryHike ~ ., data = attrition[ids,]) %>% 159 | step_normalize(all_numeric(), -all_outcomes()) 160 | expect_error( 161 | fit <- tabnet_fit(rec, data = attrition[ids,], epochs = 1), 162 | "Mixed multi-outcome type" 163 | ) 164 | }) 165 | 166 | test_that("Training multi-output regression fails for matrix", { 167 | 168 | expect_error( 169 | fit <- tabnet_fit(x, matrix(rnorm( 2 * length(y)), ncol = 2), epochs = 1), 170 | "All columns of `y` must have 
unique names" 171 | ) 172 | 173 | expect_error( 174 | fit <- tabnet_fit(x, matrix(factor(runif( 2 * length(y)) < 0.5) , ncol = 2), epochs = 1), 175 | "All columns of `y` must have unique names" 176 | ) 177 | 178 | }) 179 | -------------------------------------------------------------------------------- /man/tabnet_fit.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/hardhat.R 3 | \name{tabnet_fit} 4 | \alias{tabnet_fit} 5 | \alias{tabnet_fit.default} 6 | \alias{tabnet_fit.data.frame} 7 | \alias{tabnet_fit.formula} 8 | \alias{tabnet_fit.recipe} 9 | \alias{tabnet_fit.Node} 10 | \title{Tabnet model} 11 | \usage{ 12 | tabnet_fit(x, ...) 13 | 14 | \method{tabnet_fit}{default}(x, ...) 15 | 16 | \method{tabnet_fit}{data.frame}( 17 | x, 18 | y, 19 | tabnet_model = NULL, 20 | config = tabnet_config(), 21 | ..., 22 | from_epoch = NULL, 23 | weights = NULL 24 | ) 25 | 26 | \method{tabnet_fit}{formula}( 27 | formula, 28 | data, 29 | tabnet_model = NULL, 30 | config = tabnet_config(), 31 | ..., 32 | from_epoch = NULL, 33 | weights = NULL 34 | ) 35 | 36 | \method{tabnet_fit}{recipe}( 37 | x, 38 | data, 39 | tabnet_model = NULL, 40 | config = tabnet_config(), 41 | ..., 42 | from_epoch = NULL, 43 | weights = NULL 44 | ) 45 | 46 | \method{tabnet_fit}{Node}( 47 | x, 48 | tabnet_model = NULL, 49 | config = tabnet_config(), 50 | ..., 51 | from_epoch = NULL 52 | ) 53 | } 54 | \arguments{ 55 | \item{x}{Depending on the context: 56 | \itemize{ 57 | \item A \strong{data frame} of predictors. 58 | \item A \strong{matrix} of predictors. 59 | \item A \strong{recipe} specifying a set of preprocessing steps 60 | created from \code{\link[recipes:recipe]{recipes::recipe()}}. 61 | \item A \strong{Node} where tree will be used as hierarchical outcome, 62 | and attributes will be used as predictors. 63 | } 64 | 65 | The predictor data should be standardized (e.g. centered or scaled). 66 | The model treats categorical predictors internally thus, you don't need to 67 | make any treatment. 68 | The model treats missing values internally thus, you don't need to make any 69 | treatment.} 70 | 71 | \item{...}{Model hyperparameters. 72 | Any hyperparameters set here will update those set by the config argument. 73 | See \code{\link[=tabnet_config]{tabnet_config()}} for a list of all possible hyperparameters.} 74 | 75 | \item{y}{When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome 76 | specified as: 77 | \itemize{ 78 | \item A \strong{data frame} with 1 or many numeric column (regression) or 1 or many categorical columns (classification) . 79 | \item A \strong{matrix} with 1 column. 80 | \item A \strong{vector}, either numeric or categorical. 81 | }} 82 | 83 | \item{tabnet_model}{A previously fitted \code{tabnet_model} object to continue the fitting on. 84 | if \code{NULL} (the default) a brand new model is initialized.} 85 | 86 | \item{config}{A set of hyperparameters created using the \code{tabnet_config} function. 87 | If no argument is supplied, this will use the default values in \code{\link[=tabnet_config]{tabnet_config()}}.} 88 | 89 | \item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch. 90 | Default is last available checkpoint for restored model, or last epoch for in-memory model.} 91 | 92 | \item{weights}{Unused. 
Placeholder for hardhat::importance_weights() variables.}
93 |
94 | \item{formula}{A formula specifying the outcome terms on the left-hand side,
95 | and the predictor terms on the right-hand side.}
96 |
97 | \item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as:
98 | \itemize{
99 | \item A \strong{data frame} containing both the predictors and the outcome.
100 | }}
101 | }
102 | \value{
103 | A TabNet model object. It can be used for serialization, predictions, or further fitting.
104 | }
105 | \description{
106 | Fits the \href{https://arxiv.org/abs/1908.07442}{TabNet: Attentive Interpretable Tabular Learning} model
107 | }
108 | \section{Fitting a pre-trained model}{
109 |
110 |
111 | When providing a parent \code{tabnet_model} parameter, the model fitting resumes from that model's weights
112 | at the following epoch:
113 | \itemize{
114 | \item the last fitted epoch, for a model already in the torch context
115 | \item the last model checkpoint epoch, for a model loaded from file
116 | \item the epoch related to a checkpoint matching or preceding the \code{from_epoch} value, if provided.
117 | The model fitting metrics are appended to the parent metrics in the returned TabNet model.
118 | }
119 | }
120 |
121 | \section{Multi-outcome}{
122 |
123 |
124 | TabNet allows multi-outcome prediction, which is usually named \href{https://en.wikipedia.org/wiki/Multi-label_classification}{multi-label classification}
125 | or multi-output regression when outcomes are numerical.
126 | Multi-outcome currently expects outcomes to be either all numeric or all categorical.
127 | }
128 |
129 | \section{Threading}{
130 |
131 |
132 | TabNet uses \code{torch} as its backend for computation and \code{torch} uses all
133 | available threads by default.
134 |
135 | You can control the number of threads used by \code{torch} with:
136 |
137 | \if{html}{\out{
}}\preformatted{torch::torch_set_num_threads(1) 138 | torch::torch_set_num_interop_threads(1) 139 | }\if{html}{\out{
}} 140 | } 141 | 142 | \examples{ 143 | \dontshow{if ((torch::torch_is_installed() && require("modeldata"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 144 | \dontrun{ 145 | data("ames", package = "modeldata") 146 | data("attrition", package = "modeldata") 147 | 148 | ## Single-outcome regression using formula specification 149 | fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 4) 150 | 151 | ## Single-outcome classification using data-frame specification 152 | attrition_x <- attrition[ids,-which(names(attrition) == "Attrition")] 153 | fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 4, verbose = TRUE) 154 | 155 | ## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset using formula, 156 | ames_fit <- tabnet_fit(Sale_Price + Pool_Area ~ ., data = ames, epochs = 4, valid_split = 0.2) 157 | 158 | ## Multi-label classification on `Attrition` and `JobSatisfaction` in 159 | ## `attrition` dataset using recipe 160 | library(recipes) 161 | rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition) \%>\% 162 | step_normalize(all_numeric(), -all_outcomes()) 163 | 164 | attrition_fit <- tabnet_fit(rec, data = attrition, epochs = 4, valid_split = 0.2) 165 | 166 | ## Hierarchical classification on `acme` 167 | data(acme, package = "data.tree") 168 | 169 | acme_fit <- tabnet_fit(acme, epochs = 4, verbose = TRUE) 170 | 171 | # Note: Model's number of epochs should be increased for publication-level results. 172 | } 173 | \dontshow{\}) # examplesIf} 174 | } 175 | -------------------------------------------------------------------------------- /tests/testthat/test-hardhat_hierarchical.R: -------------------------------------------------------------------------------- 1 | test_that("C-HMCNN get_constr_output works ", { 2 | x <- torch::torch_rand(c(2,4)) 3 | R <- torch::torch_tril(torch::torch_zeros(c(4,4))$bernoulli(p = 0.2) + torch::torch_diag(rep(1,4)))$to(dtype = torch::torch_bool()) 4 | expect_no_error( 5 | constr_output <- get_constr_output(x, R) 6 | ) 7 | expect_tensor_shape( 8 | constr_output, x$shape 9 | ) 10 | # expect_equal( 11 | # constr_output$dtype, torch_tensor(0.1)$dtype 12 | # ) 13 | 14 | R <- torch::torch_zeros(c(4,4))$to(dtype = torch::torch_bool()) 15 | expect_equal_to_tensor( 16 | get_constr_output(x, R), torch::torch_zeros_like(x) 17 | ) 18 | }) 19 | 20 | test_that("C-HMCNN max_constraint_output works ", { 21 | output <- torch::torch_rand(c(3, 5)) 22 | labels <- torch::torch_diag(rep(1,5))[1:3, ]$to(dtype = torch::torch_bool()) 23 | ancestor <- torch::torch_tril(torch::torch_zeros(c(5, 5))$bernoulli(p = 0.2) )$to(dtype = torch::torch_bool()) 24 | 25 | expect_no_error( 26 | MC_output <- max_constraint_output(output, labels, ancestor) 27 | ) 28 | expect_tensor_shape( 29 | MC_output, output$shape 30 | ) 31 | # max_constraint_output is not identity 32 | expect_not_equal_to_tensor( 33 | MC_output, output 34 | ) 35 | # max_constraint_output provides more than 35% null values 36 | expect_gte( 37 | as.matrix(torch::torch_sum(MC_output == 0), device="cpu"), .30 * output$shape[1] * output$shape[2] 38 | ) 39 | }) 40 | 41 | test_that("node_to_df works ", { 42 | expect_no_error( 43 | node_to_df(acme) 44 | ) 45 | expect_no_error( 46 | attrition_df <- node_to_df(attrition_tree) 47 | ) 48 | # node_to_df removes first and last level of the hierarchy 49 | outcome_levels <- paste0("level_", seq(2, attrition_tree$height - 1)) 50 | expect_equal(names(attrition_df$y), outcome_levels) 51 | 52 | # node_to_df do not shuffle outcome 
rows 53 | df <- tibble(pred_1 = seq(1,26), pred_2 = seq(26,1), 54 | level_2 = factor(LETTERS[1:26]), level_3 = factor(letters[26:1])) 55 | df_node_df <- df %>% 56 | mutate(pathString = paste("synth", level_2, level_3, level_3, sep = "/")) %>% 57 | select(-level_2, -level_3) %>% 58 | as.Node() %>% 59 | node_to_df() 60 | 61 | expect_equal(df_node_df$y %>% as_tibble(), df %>% select(starts_with("level_"))) 62 | expect_equal(df_node_df$x %>% as_tibble(), df %>% select(starts_with("pred_"))) 63 | 64 | }) 65 | 66 | 67 | test_that("Training hierarchical classification for {data.tree} Node", { 68 | 69 | expect_no_error( 70 | fit <- tabnet_fit(acme, epochs = 1) 71 | ) 72 | expect_no_error( 73 | result <- predict(fit, acme_df, type = "prob") 74 | ) 75 | 76 | expect_equal(ncol(result), 3) 77 | outcome_levels <-levels(fit$blueprint$ptypes$outcomes[[1]]) 78 | # we get back outcomes vars with a `.pred_` prefix 79 | expect_equal(stringr::str_remove(names(result), ".pred_"), outcome_levels) 80 | expect_no_error( 81 | result <- predict(fit, acme_df) 82 | ) 83 | expect_equal(ncol(result), 1) 84 | 85 | expect_no_error( 86 | fit <- tabnet_fit(attrition_tree, epochs = 1) 87 | ) 88 | expect_no_error( 89 | result <- predict(fit, attrition_tree, type = "prob") 90 | ) 91 | 92 | expect_equal(ncol(result), 2) # 2 outcomes levels_ 93 | 94 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 95 | # we get back outcomes vars with a `.pred_` prefix 96 | expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) 97 | 98 | # result columns are tibbles of resp 2, 2, 4 columns 99 | expect_true(all(purrr::map_lgl(result, tibble::is_tibble))) 100 | expect_equal(unname(purrr::map_dbl(result, ncol)), unname(outcome_nlevels), ignore_attr = TRUE) 101 | 102 | }) 103 | 104 | test_that("Training hierarchical classification for {data.tree} Node with validation split", { 105 | 106 | expect_no_error( 107 | fit <- tabnet_fit(attrition_tree, valid_split = 0.2, epochs = 1) 108 | ) 109 | 110 | expect_no_error( 111 | result <- predict(fit, attrition_tree, type = "prob") 112 | ) 113 | 114 | expect_equal(ncol(result), 2) # 2 outcomes levels_ 115 | 116 | outcome_nlevels <- purrr::map_dbl(fit$blueprint$ptypes$outcomes, ~length(levels(.x))) 117 | # we get back outcomes vars with a `.pred_` prefix 118 | expect_equal(stringr::str_remove(names(result), ".pred_"), names(outcome_nlevels)) 119 | 120 | # result columns are tibbles of resp 2, 2, 4 columns 121 | expect_true(all(purrr::map_lgl(result, tibble::is_tibble))) 122 | expect_equal(unname(purrr::map_dbl(result, ncol)), unname(outcome_nlevels), ignore_attr = TRUE) 123 | 124 | expect_no_error( 125 | result <- predict(fit, attrition_tree) 126 | ) 127 | expect_equal(ncol(result), 2) # 2 outcomes levels_ 128 | 129 | # we get back outcomes vars with a `.pred_class_` prefix 130 | expect_equal(stringr::str_remove(names(result), ".pred_class_"), names(fit$blueprint$ptypes$outcomes)) 131 | }) 132 | 133 | test_that("hierarchical classification for {data.tree} Node is explainable", { 134 | 135 | fit <- tabnet_fit(attrition_tree, epochs = 1) 136 | 137 | expect_no_error( 138 | explain <- tabnet_explain(fit, attrition_tree) 139 | ) 140 | 141 | expect_no_error( 142 | autoplot(explain) 143 | ) 144 | 145 | }) 146 | 147 | test_that("we properly check non-compliant colnames", { 148 | 149 | # try to use starwars dataset with two forbidden column name 150 | starwars_tree <- starwars %>% 151 | mutate(pathString = paste("tree", species, homeworld, `name`, sep = 
"/")) 152 | expect_error( 153 | check_compliant_node(starwars_tree) 154 | ,"reserved names") 155 | 156 | # augment acme dataset with a forbidden column name with no impact on predictor is ok 157 | acme$Do(function(x) { 158 | x$level_4 <- as.character(data.tree::Aggregate(node = x, 159 | attribute = "p", 160 | aggFun = sum)) 161 | }, 162 | traversal = "post-order") 163 | expect_no_error(check_compliant_node(acme)) 164 | 165 | expect_no_error(tabnet_fit(acme, epochs = 1)) 166 | 167 | # augment acme dataset with a used forbidden column name raise error 168 | acme$Do(function(x) { 169 | x$level_3 <- data.tree::Aggregate(node = x, 170 | attribute = "p", 171 | aggFun = sum) 172 | }, 173 | traversal = "post-order") 174 | expect_error( 175 | check_compliant_node(acme) 176 | ,"reserved names") 177 | 178 | expect_error( 179 | tabnet_fit(acme, epochs = 1) 180 | ,"reserved names") 181 | 182 | }) 183 | 184 | -------------------------------------------------------------------------------- /tests/testthat/test-pretraining.R: -------------------------------------------------------------------------------- 1 | test_that("transpose_metrics is not adding an unnamed entry on top of the list", { 2 | 3 | metrics <- list(loss = 1, loss = 2, loss = 3, loss = 4) 4 | 5 | expect_no_error( 6 | tabnet:::transpose_metrics(metrics) 7 | ) 8 | 9 | expect_equal( 10 | tabnet:::transpose_metrics(metrics), 11 | list(loss = c(1, 2, 3, 4)) 12 | ) 13 | 14 | }) 15 | 16 | test_that("Unsupervised training with default config, data.frame and formula", { 17 | 18 | expect_no_error( 19 | fit <- tabnet_pretrain(x, y, epochs = 1) 20 | ) 21 | expect_s3_class( fit, "tabnet_pretrain") 22 | expect_equal(length(fit), 3) 23 | expect_equal(names(fit), c("fit", "serialized_net", "blueprint")) 24 | expect_equal(length(fit$fit), 5) 25 | expect_equal(names(fit$fit), c("network", "metrics", "config", "checkpoints", "importances")) 26 | expect_equal(length(fit$fit$metrics), 1) 27 | 28 | expect_no_error( 29 | fit <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1) 30 | ) 31 | expect_s3_class( fit, "tabnet_pretrain") 32 | expect_equal(length(fit), 3) 33 | expect_equal(names(fit), c("fit", "serialized_net", "blueprint")) 34 | expect_equal(length(fit$fit), 5) 35 | expect_equal(names(fit$fit), c("network", "metrics", "config", "checkpoints", "importances")) 36 | expect_equal(length(fit$fit$metrics), 1) 37 | 38 | }) 39 | 40 | test_that("Unsupervised training with pretraining_ratio", { 41 | 42 | expect_no_error( 43 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, pretraining_ratio=0.2) 44 | ) 45 | 46 | }) 47 | 48 | test_that("Unsupervised training prevent predict with an explicit message", { 49 | 50 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, pretraining_ratio=0.2) 51 | 52 | expect_error( 53 | predict(pretrain, attrix, type = "prob"), 54 | regexp = "tabnet_pretrain" 55 | ) 56 | 57 | expect_error( 58 | predict(pretrain, attrix), 59 | regexp = "tabnet_pretrain" 60 | ) 61 | 62 | }) 63 | 64 | test_that("pretraining with `tabnet_model= ` parameter raise a warning", { 65 | 66 | expect_warning( 67 | fit <- tabnet_pretrain(x, y, epochs = 1, tabnet_model = ames_pretrain) 68 | ) 69 | expect_s3_class( fit, "tabnet_pretrain") 70 | expect_equal( length(fit), length(ames_pretrain)) 71 | expect_equal( length(fit$fit$metrics), 1) 72 | }) 73 | 74 | test_that("errors when using an argument that do not exist", { 75 | 76 | expect_error( 77 | pretrain <- tabnet_pretrain(x, y, pretraining_ratiosas = 1-1e5), 78 | regexp = "unused argument" 79 | ) 80 
| 81 | }) 82 | 83 | test_that("works with validation split", { 84 | 85 | expect_no_error( 86 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, valid_split = 0.2) 87 | ) 88 | 89 | expect_no_error( 90 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, valid_split = 0.2, verbose = TRUE) 91 | ) 92 | 93 | }) 94 | 95 | test_that("works with categorical embedding dimension as list", { 96 | 97 | config <- tabnet_config(cat_emb_dim=c(1,1,2,2,1,1,1,2,1,1,1,2,2,2)) 98 | 99 | expect_no_error( 100 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, valid_split = 0.2, config=config) 101 | ) 102 | }) 103 | 104 | test_that("explicit error message when categorical embedding dimension vector has wrong size", { 105 | 106 | config <- tabnet_config(cat_emb_dim=c(1,1,2,2)) 107 | 108 | expect_error( 109 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, valid_split = 0.2, config=config), 110 | regexp = "number of categorical predictors" 111 | ) 112 | }) 113 | 114 | test_that("can train from a recipe", { 115 | 116 | rec <- recipe(Attrition ~ ., data = attrition) %>% 117 | step_normalize(all_numeric(), -all_outcomes()) 118 | 119 | expect_no_error( 120 | pretrain <- tabnet_pretrain(rec, attrition, epochs = 1, verbose = TRUE) 121 | ) 122 | 123 | }) 124 | 125 | test_that("lr scheduler step works", { 126 | 127 | expect_no_error( 128 | fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "step", 129 | lr_decay = 0.1, step_size = 1) 130 | ) 131 | 132 | sc_fn <- function(optimizer) { 133 | torch::lr_step(optimizer, step_size = 1, gamma = 0.1) 134 | } 135 | 136 | expect_no_error( 137 | fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = sc_fn, 138 | lr_decay = 0.1, step_size = 1) 139 | ) 140 | 141 | }) 142 | 143 | test_that("lr scheduler reduce_on_plateau works", { 144 | 145 | expect_no_error( 146 | fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "reduce_on_plateau", 147 | lr_decay = 0.1, step_size = 1) 148 | ) 149 | 150 | sc_fn <- function(optimizer) { 151 | torch::lr_reduce_on_plateau(optimizer, factor = 0.1, patience = 10) 152 | } 153 | 154 | expect_no_error( 155 | fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = sc_fn, 156 | lr_decay = 0.1, step_size = 1) 157 | ) 158 | 159 | }) 160 | 161 | test_that("checkpoints works", { 162 | 163 | expect_no_error( 164 | pretrain <- tabnet_pretrain(x, y, epochs = 3, checkpoint_epochs = 1) 165 | ) 166 | 167 | expect_length( pretrain$fit$checkpoints, 3 ) 168 | 169 | # expect_equal( pretrain$fit$checkpoints[[3]], pretrain$serialized_net ) 170 | 171 | }) 172 | 173 | test_that("print module works", { 174 | 175 | testthat::local_edition(3) 176 | testthat::skip_on_os("linux") 177 | testthat::skip_on_os("windows") 178 | 179 | expect_no_error( 180 | fit <- tabnet_pretrain(x, y, epochs = 1) 181 | ) 182 | 183 | withr::with_options(new = c(cli.width = 50), 184 | expect_snapshot_output(fit)) 185 | 186 | }) 187 | 188 | test_that("num_independent_decoder and num_shared_decoder change the network number of parameters", { 189 | 190 | expect_no_error( 191 | pretrain <- tabnet_pretrain(attrix, attriy, epochs = 1, 192 | num_independent_decoder = 3, num_shared_decoder = 2) 193 | ) 194 | expect_gt( torch:::get_parameter_count(pretrain$fit$network), 195 | torch:::get_parameter_count(attr_pretrained$fit$network) 196 | ) 197 | }) 198 | 199 | test_that("num_independent_decoder and num_shared_decoder do not change the network number of parameters for fit", { 200 | 201 | expect_no_error( 202 | config <- tabnet_config(epochs = 1, 203 | num_independent_decoder = 
3, num_shared_decoder = 2) 204 | ) 205 | expect_no_error( 206 | attr_fit <- tabnet_fit(attrix, attriy, config = config) 207 | ) 208 | expect_equal( torch:::get_parameter_count(attr_fit$fit$network), 209 | torch:::get_parameter_count(attr_fitted$fit$network) 210 | ) 211 | }) 212 | 213 | -------------------------------------------------------------------------------- /tests/testthat/test-model.R: -------------------------------------------------------------------------------- 1 | if (torch::cuda_is_available()) { 2 | device <- "cuda" 3 | } else { 4 | device <- "cpu" 5 | } 6 | 7 | 8 | test_that("resolve_data works through a dataloader", { 9 | data("ames", package = "modeldata") 10 | 11 | x <- ames[-which(names(ames) == "Sale_Price")] 12 | y <- ames[,"Sale_Price"] 13 | # dataset are R6 class and shall be instantiated 14 | train_ds <- torch::dataset( 15 | initialize = function() {}, 16 | .getbatch = function(batch) {tabnet:::resolve_data(x[batch,], y[batch,])}, 17 | .length = function() {nrow(x)} 18 | )() 19 | expect_no_error( 20 | train_ds$.getbatch(batch = 1:2) 21 | ) 22 | # dataloader 23 | train_dl <- torch::dataloader( 24 | train_ds, 25 | batch_size = 2000 , 26 | drop_last = TRUE, 27 | shuffle = FALSE #, 28 | # num_workers = 0L 29 | ) 30 | expect_no_error( 31 | coro::loop(for (batch in train_dl) { 32 | expect_tensor_shape(batch$x, c(2000, 73)) 33 | expect_true(batch$x$dtype == torch::torch_float()) 34 | expect_tensor_shape(batch$x_na_mask, c(2000, 73)) 35 | expect_true(batch$x_na_mask$dtype == torch::torch_bool()) 36 | expect_tensor_shape(batch$y, c(2000, 1)) 37 | expect_true(batch$y$dtype == torch::torch_float()) 38 | expect_tensor_shape(batch$cat_idx, 40) 39 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 40 | expect_equal_to_r(batch$output_dim, 1L) 41 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 42 | expect_tensor_shape(batch$input_dim, 1) 43 | expect_true(batch$input_dim$dtype == torch::torch_long()) 44 | expect_tensor_shape(batch$cat_dims, 40) 45 | expect_true(batch$cat_dims$dtype == torch::torch_long()) 46 | 47 | }) 48 | ) 49 | 50 | }) 51 | 52 | test_that("resolve_data works through a dataloader without nominal variables", { 53 | n <- 1000 54 | x <- data.frame( 55 | x = rnorm(n), 56 | y = rnorm(n), 57 | z = rnorm(n) 58 | ) 59 | 60 | y <- x[,"x", drop = FALSE] 61 | # dataset are R6 class and shall be instanciated 62 | train_ds <- torch::dataset( 63 | initialize = function() {}, 64 | .getbatch = function(batch) {tabnet:::resolve_data(x[batch,], y[batch,])}, 65 | .length = function() {nrow(x)} 66 | )() 67 | expect_no_error( 68 | train_ds$.getbatch(batch = 1:2) 69 | ) 70 | 71 | # dataloader 72 | train_dl <- torch::dataloader( 73 | train_ds, 74 | batch_size = 2000 , 75 | drop_last = TRUE, 76 | shuffle = FALSE #, 77 | # num_workers = 0L 78 | ) 79 | expect_no_error( 80 | coro::loop(for (batch in train_dl) { 81 | expect_tensor_shape(batch$x, c(2000, 3)) 82 | expect_true(batch$x$dtype == torch::torch_float()) 83 | expect_tensor_shape(batch$x_na_mask, c(2000, 3)) 84 | expect_true(batch$x_na_mask$dtype == torch::torch_bool()) 85 | expect_tensor_shape(batch$y, c(2000, 1)) 86 | expect_true(batch$y$dtype == torch::torch_float()) 87 | expect_tensor_shape(batch$cat_idx, 0) 88 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 89 | expect_equal_to_r(batch$output_dim, 1L) 90 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 91 | expect_tensor_shape(batch$input_dim, 1) 92 | expect_true(batch$input_dim$dtype == torch::torch_long()) 93 | 
expect_tensor_shape(batch$cat_dims, 0) 94 | expect_true(batch$cat_dims$dtype == torch::torch_long()) 95 | 96 | }) 97 | ) 98 | 99 | }) 100 | 101 | test_that("resolve_data works for multioutput regression", { 102 | data("ames", package = "modeldata") 103 | 104 | x <- ames[-which(names(ames) %in% c("Sale_Price", "Lot_Area"))] 105 | y <- ames[,c("Sale_Price", "Lot_Area")] 106 | # dataset are R6 class and shall be instantiated 107 | train_ds <- torch::dataset( 108 | initialize = function() {}, 109 | .getbatch = function(batch) {tabnet:::resolve_data(x[batch,], y[batch,])}, 110 | .length = function() {nrow(x)} 111 | )() 112 | expect_error( 113 | train_ds$.getbatch(batch = 1:2), 114 | NA 115 | ) 116 | # dataloader 117 | train_dl <- torch::dataloader( 118 | train_ds, 119 | batch_size = 2000 , 120 | drop_last = TRUE, 121 | shuffle = FALSE #, 122 | # num_workers = 0L 123 | ) 124 | expect_no_error( 125 | coro::loop(for (batch in train_dl) { 126 | expect_tensor_shape(batch$x, c(2000, 72)) 127 | expect_true(batch$x$dtype == torch::torch_float()) 128 | expect_tensor_shape(batch$x_na_mask, c(2000, 72)) 129 | expect_true(batch$x_na_mask$dtype == torch::torch_bool()) 130 | expect_tensor_shape(batch$y, c(2000, 2)) 131 | expect_true(batch$y$dtype == torch::torch_float()) 132 | expect_tensor_shape(batch$cat_idx, 40) 133 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 134 | expect_equal_to_r(batch$output_dim, 2L) 135 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 136 | expect_tensor_shape(batch$input_dim, 1) 137 | expect_true(batch$input_dim$dtype == torch::torch_long()) 138 | expect_tensor_shape(batch$cat_dims, 40) 139 | expect_true(batch$cat_dims$dtype == torch::torch_long()) 140 | 141 | }) 142 | ) 143 | 144 | }) 145 | test_that("resolve_data works for multioutput classification", { 146 | 147 | x <- attrix[-which(names(attrix) == "JobSatisfaction")] 148 | y <- data.frame(y = attriy, z = attriy, sat = attrix$JobSatisfaction) 149 | # dataset are R6 class and shall be instantiated 150 | train_ds <- torch::dataset( 151 | initialize = function() {}, 152 | .getbatch = function(batch) {tabnet:::resolve_data(x[batch,], y[batch,])}, 153 | .length = function() {nrow(x)} 154 | )() 155 | expect_error( 156 | train_ds$.getbatch(batch = 1:2), 157 | NA 158 | ) 159 | # dataloader 160 | train_dl <- torch::dataloader( 161 | train_ds, 162 | batch_size = 2000 , 163 | drop_last = TRUE, 164 | shuffle = FALSE #, 165 | # num_workers = 0L 166 | ) 167 | expect_no_error( 168 | coro::loop(for (batch in train_dl) { 169 | expect_tensor_shape(batch$x, c(2000, 72)) 170 | expect_true(batch$x$dtype == torch::torch_float()) 171 | expect_tensor_shape(batch$x_na_mask, c(2000, 72)) 172 | expect_true(batch$x_na_mask$dtype == torch::torch_bool()) 173 | expect_tensor_shape(batch$y, c(2000, 2)) 174 | expect_true(batch$y$dtype == torch::torch_float()) 175 | expect_tensor_shape(batch$cat_idx, 40) 176 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 177 | expect_equal_to_r(batch$output_dim, 2L) 178 | expect_true(batch$cat_idx$dtype == torch::torch_long()) 179 | expect_tensor_shape(batch$input_dim, 1) 180 | expect_true(batch$input_dim$dtype == torch::torch_long()) 181 | expect_tensor_shape(batch$cat_dims, 40) 182 | expect_true(batch$cat_dims$dtype == torch::torch_long()) 183 | 184 | }) 185 | ) 186 | 187 | }) 188 | -------------------------------------------------------------------------------- /vignettes/interpretation.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: 
"Interpretation tools" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Interpretation tools} 6 | %\VignetteEncoding{UTF-8} 7 | %\VignetteEngine{knitr::rmarkdown} 8 | editor_options: 9 | markdown: 10 | wrap: 72 11 | --- 12 | 13 | ```{r, include = FALSE} 14 | knitr::opts_chunk$set( 15 | collapse = TRUE, 16 | comment = "#>", 17 | eval = torch::torch_is_installed(), 18 | out.width = "100%", 19 | out.height = "300px", 20 | fig.width = 14 21 | ) 22 | ``` 23 | 24 | ```{r setup} 25 | library(tabnet) 26 | library(tidyverse, warn.conflicts = FALSE) 27 | set.seed(1) 28 | # You may want to add this for complete reproducibility 29 | # torch::torch_manual_seed(1) 30 | ``` 31 | 32 | TabNet is one of the first deep-learning interpretable model. Thanks to the underlying 33 | neural-network architecture, TabNet uses feature-selection masks that can help identify 34 | which predictor is used at each step. 35 | 36 | The paper also defines an aggregate measure that combines the masks of 37 | each step into a single measure. 38 | 39 | Note that the selection of important features in the masks is done 40 | instance-wise. Thus, you can identify, for each observation, which 41 | predictors were considered relevant. 42 | 43 | ## Experiments 44 | 45 | To show how to use the interpretation tools of `tabnet`, we are going to 46 | perform 2 experiments using synthetic datasets very similar to those 47 | that were used in the paper. 48 | 49 | ### Datasets 50 | 51 | First, let's define the functions that we will use to generate data: 52 | 53 | - `make_syn2` will generate a dataset with 10 columns, but only 54 | columns 3-6 are used to calculate the `y` response vector. This is 55 | similar to *Syn2* in the paper. 56 | 57 | - `make_syn4` will generate a dataset with 10 columns too. The 58 | response vector depends on column 10: if the value is greater than 59 | 0, we use columns 1-2 to compute the logits, otherwise we use 60 | columns 5-6. 61 | 62 | ```{r} 63 | logit_to_y <- function(logits) { 64 | p <- exp(logits)/(1 + exp(logits)) 65 | y <- factor(ifelse(p > 0.5, "yes", "no"), levels = c("yes", "no")) 66 | y 67 | } 68 | 69 | make_random_x <- function(n) { 70 | x <- as.data.frame(lapply(1:10, function(x) rnorm(n))) 71 | names(x) <- sprintf("V%02d", 1:10) 72 | x 73 | } 74 | 75 | make_syn2 <- function(n = 5000) { 76 | x <- make_random_x(n) 77 | logits <- rowSums(x[,3:6]) 78 | x$y <- logit_to_y(logits) 79 | x 80 | } 81 | 82 | make_syn4 <- function(n = 5000) { 83 | x <- make_random_x(n) 84 | logits <- ifelse( 85 | x[,10] > 0, 86 | rowSums(x[,1:2]), 87 | rowSums(x[,5:6]) 88 | ) 89 | 90 | x$y <- logit_to_y(logits) 91 | x 92 | } 93 | ``` 94 | 95 | Now let's generate the datasets: 96 | 97 | ```{r} 98 | syn2 <- make_syn2() 99 | syn4 <- make_syn4() 100 | ``` 101 | 102 | ### Syn2 103 | 104 | Let's fit a TabNet model to the `syn2` dataset and analyze the 105 | interpretation metrics. 106 | 107 | ```{r} 108 | fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 45, learn_rate = 0.06, device = "cpu") 109 | ``` 110 | 111 | In the feature importance plot we can see that, as expected, features 112 | `V03-V06` are by far the most important ones. 113 | 114 | ```{r} 115 | #| fig.alt: "A variable importance plot of the fitted model on syn2 dataset showing V03 then V06, V04, v10 and V5 as the 5 most important features, in that order." 116 | vip::vip(fit_syn2) 117 | ``` 118 | 119 | Now let's visualize the aggregated-masks plot. In this figure we see 120 | each observation on the x axis and each variable on the y axis. 
The
121 | colors represent the importance of the feature in predicting the value
122 | for each observation.
123 |
124 | ```{r}
125 | #| fig.alt: "A tabnet explanation plot of the fitted model on syn2 dataset. The plot shows numerous important observations in V03, then V06 and V04. The other variables are shown with low importance or only sparse observations with importance."
126 | library(tidyverse)
127 | ex_syn2 <- tabnet_explain(fit_syn2, syn2)
128 |
129 | autoplot(ex_syn2, quantile = 0.99)
130 | ```
131 |
132 | We can see that the region between V03 and V06 concentrates most of the
133 | higher intensity colors, and the other variables are close to 0. This is
134 | expected because those are the variables that we considered when
135 | building the dataset.
136 |
137 | Next, we can visualize the attention masks for each step in the
138 | architecture.
139 |
140 | ```{r}
141 | #| fig.alt: "3 tabnet explanation plots, one for each step of the fitted model on syn2 dataset. The step 1 plot shows numerous important observations with V02 and V03 having high importance, the step 2 plot highlights the importance of V03 and V06, and the third step plot highlights V03 and V10 as important variables."
142 | autoplot(ex_syn2, type="steps")
143 | ```
144 |
145 | We see that the first step captures a lot of noise, but the other 2
146 | steps focus specifically on the important features.
147 |
148 | ### Syn4
149 |
150 | Now let's analyze the results for the Syn4 dataset. This dataset is a
151 | little more complicated for TabNet because there's a strong interaction
152 | between the variables. Depending on V10, different variables are used to
153 | create the response variable and we expect to see this in the masks.
154 |
155 | First we fit the model for 50 epochs.
156 |
157 | ```{r}
158 | fit_syn4 <- tabnet_fit(y ~ ., syn4, epochs = 50, device = "cpu", learn_rate = 0.08)
159 | ```
160 |
161 | In the feature importance plot we have, as expected, strong importance
162 | for `V10`, and the other features that are used conditionally - either
163 | `V01-V02` or `V05-V06`.
164 |
165 | ```{r}
166 | #| fig.alt: "A variable importance plot of the fitted model on syn4 dataset"
167 | vip::vip(fit_syn4)
168 | ```
169 |
170 | Now let's visualize the attention masks. Notice that we arranged the
171 | dataset by `V10` so we can easily visualize the interaction effects.
172 |
173 | We also trimmed to the 98th percentile so the colors show the
174 | importance even if there are strong outliers.
175 |
176 | ```{r}
177 | #| fig.alt: "A tabnet explanation plot of the fitted model on syn4 dataset. The plot shows numerous important observations in V06 for low values of V10, and importance of V01 and V02 for high values of V10."
178 | ex_syn4 <- tabnet_explain(fit_syn4, arrange(syn4, V10))
179 |
180 | autoplot(ex_syn4, quantile=.98)
181 | ```
182 |
183 | From the figure we see that V10 is important for all observations. We
184 | also see that for the first half of the dataset `V05` and `V06` are the
185 | most important features, while for the other half, `V01` and `V02` are
186 | the important ones.
187 |
188 | We can also visualize the masks at each step in the architecture.
189 |
190 | ```{r}
191 | #| fig.alt: "3 tabnet explanation plots, one for each step of the fitted model on syn4 dataset."
192 | autoplot(ex_syn4, type="steps", quantile=.995)
193 | ```
194 |
195 | We see that steps 1 and 3 both focus on `V10`, but on different
196 | additional features, depending on `V10`.
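The raw numbers behind these heatmaps are stored on the `tabnet_explain` object itself: `M_explain` holds the aggregated values and `masks` is a list with one tibble per step (the same data `autoplot()` uses). As a minimal sketch, assuming the `ex_syn4` object computed in the previous chunk, each step mask can be summarised numerically:

```{r}
# Mean mask value per predictor at each attention step: a numeric view of the step heatmaps
ex_syn4$masks %>%
  purrr::map_dfr(~ summarise(.x, across(everything(), mean)), .id = "step")
```

Coming back to the step heatmaps: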
Step 2 seems to have found some 197 | noise in `V08`, but also focuses strongly on `V01-V02` and `V05-V06`. 198 | -------------------------------------------------------------------------------- /man/tabnet_config.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model.R 3 | \name{tabnet_config} 4 | \alias{tabnet_config} 5 | \title{Configuration for TabNet models} 6 | \usage{ 7 | tabnet_config( 8 | batch_size = 1024^2, 9 | penalty = 0.001, 10 | clip_value = NULL, 11 | loss = "auto", 12 | epochs = 5, 13 | drop_last = FALSE, 14 | decision_width = NULL, 15 | attention_width = NULL, 16 | num_steps = 3, 17 | feature_reusage = 1.3, 18 | mask_type = "sparsemax", 19 | mask_topk = NULL, 20 | virtual_batch_size = 256^2, 21 | valid_split = 0, 22 | learn_rate = 0.02, 23 | optimizer = "adam", 24 | lr_scheduler = NULL, 25 | lr_decay = 0.1, 26 | step_size = 30, 27 | checkpoint_epochs = 10, 28 | cat_emb_dim = 1, 29 | num_independent = 2, 30 | num_shared = 2, 31 | num_independent_decoder = 1, 32 | num_shared_decoder = 1, 33 | momentum = 0.02, 34 | pretraining_ratio = 0.5, 35 | verbose = FALSE, 36 | device = "auto", 37 | importance_sample_size = NULL, 38 | early_stopping_monitor = "auto", 39 | early_stopping_tolerance = 0, 40 | early_stopping_patience = 0L, 41 | num_workers = 0L, 42 | skip_importance = FALSE 43 | ) 44 | } 45 | \arguments{ 46 | \item{batch_size}{(int) Number of examples per batch, large batch sizes are 47 | recommended. (default: 1024^2)} 48 | 49 | \item{penalty}{This is the extra sparsity loss coefficient as proposed 50 | in the original paper. The bigger this coefficient is, the sparser your model 51 | will be in terms of feature selection. Depending on the difficulty of your 52 | problem, reducing this value could help (default 1e-3).} 53 | 54 | \item{clip_value}{If a num is given this will clip the gradient at 55 | clip_value. Pass \code{NULL} to not clip.} 56 | 57 | \item{loss}{(character or function) Loss function for training (default to mse 58 | for regression and cross entropy for classification)} 59 | 60 | \item{epochs}{(int) Number of training epochs.} 61 | 62 | \item{drop_last}{(logical) Whether to drop last batch if not complete during 63 | training} 64 | 65 | \item{decision_width}{(int) Width of the decision prediction layer. Bigger values gives 66 | more capacity to the model with the risk of overfitting. Values typically 67 | range from 8 to 64.} 68 | 69 | \item{attention_width}{(int) Width of the attention embedding for each mask. According to 70 | the paper n_d = n_a is usually a good choice. (default=8)} 71 | 72 | \item{num_steps}{(int) Number of steps in the architecture 73 | (usually between 3 and 10)} 74 | 75 | \item{feature_reusage}{(num) This is the coefficient for feature reusage in the masks. 76 | A value close to 1 will make mask selection least correlated between layers. 77 | Values range from 1 to 2.} 78 | 79 | \item{mask_type}{(character) Final layer of feature selector in the attentive_transformer 80 | block, either \code{"sparsemax"}, \code{"entmax"} or \code{"entmax15"}.Defaults to \code{"sparsemax"}.} 81 | 82 | \item{mask_topk}{(int) mask sparsity top-k for \code{sparsemax15} and \code{entmax15.} See \code{\link[=entmax15]{entmax15()}} for detail.} 83 | 84 | \item{virtual_batch_size}{(int) Size of the mini batches used for 85 | "Ghost Batch Normalization" (default=256^2)} 86 | 87 | \item{valid_split}{In [0, 1). 
The fraction of the dataset used for validation. 88 | (default = 0 means no split)} 89 | 90 | \item{learn_rate}{initial learning rate for the optimizer.} 91 | 92 | \item{optimizer}{the optimization method. currently only \code{"adam"} is supported, 93 | you can also pass any torch optimizer function.} 94 | 95 | \item{lr_scheduler}{if \code{NULL}, no learning rate decay is used. If "step" 96 | decays the learning rate by \code{lr_decay} every \code{step_size} epochs. If "reduce_on_plateau" 97 | decays the learning rate by \code{lr_decay} when no improvement after \code{step_size} epochs. 98 | It can also be a \link[torch:lr_scheduler]{torch::lr_scheduler} function that only takes the optimizer 99 | as parameter. The \code{step} method is called once per epoch.} 100 | 101 | \item{lr_decay}{multiplies the initial learning rate by \code{lr_decay} every 102 | \code{step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler} 103 | or \code{NULL}.} 104 | 105 | \item{step_size}{the learning rate scheduler step size. Unused if 106 | \code{lr_scheduler} is a \code{torch::lr_scheduler} or \code{NULL}.} 107 | 108 | \item{checkpoint_epochs}{checkpoint model weights and architecture every 109 | \code{checkpoint_epochs}. (default is 10). This may cause large memory usage. 110 | Use \code{0} to disable checkpoints.} 111 | 112 | \item{cat_emb_dim}{Size of the embedding of categorical features. If int, all categorical 113 | features will have same embedding size, if list of int, every corresponding feature will have 114 | specific embedding size.} 115 | 116 | \item{num_independent}{Number of independent Gated Linear Units layers at each step of the encoder. 117 | Usual values range from 1 to 5.} 118 | 119 | \item{num_shared}{Number of shared Gated Linear Units at each step of the encoder. Usual values 120 | at each step of the decoder. range from 1 to 5} 121 | 122 | \item{num_independent_decoder}{For pretraining, number of independent Gated Linear Units layers 123 | Usual values range from 1 to 5.} 124 | 125 | \item{num_shared_decoder}{For pretraining, number of shared Gated Linear Units at each step of the 126 | decoder. Usual values range from 1 to 5.} 127 | 128 | \item{momentum}{Momentum for batch normalization, typically ranges from 0.01 129 | to 0.4 (default=0.02)} 130 | 131 | \item{pretraining_ratio}{Ratio of features to mask for reconstruction during 132 | pretraining. Ranges from 0 to 1 (default=0.5)} 133 | 134 | \item{verbose}{(logical) Whether to print progress and loss values during 135 | training.} 136 | 137 | \item{device}{the device to use for training. "cpu" or "cuda". The default ("auto") 138 | uses to "cuda" if it's available, otherwise uses "cpu".} 139 | 140 | \item{importance_sample_size}{sample of the dataset to compute importance metrics. 141 | If the dataset is larger than 1e5 obs we will use a sample of size 1e5 and 142 | display a warning.} 143 | 144 | \item{early_stopping_monitor}{Metric to monitor for early_stopping. One of "valid_loss", "train_loss" or "auto" (defaults to "auto").} 145 | 146 | \item{early_stopping_tolerance}{Minimum relative improvement to reset the patience counter. 147 | 0.01 for 1\% tolerance (default 0)} 148 | 149 | \item{early_stopping_patience}{Number of epochs without improving until stopping training. (default=5)} 150 | 151 | \item{num_workers}{(int, optional): how many subprocesses to use for data 152 | loading. 0 means that the data will be loaded in the main process. 
153 | (default: \code{0})} 154 | 155 | \item{skip_importance}{if feature importance calculation should be skipped (default: \code{FALSE})} 156 | } 157 | \value{ 158 | A named list with all hyperparameters of the TabNet implementation. 159 | } 160 | \description{ 161 | Configuration for TabNet models 162 | } 163 | \examples{ 164 | \dontshow{if ((torch::torch_is_installed() && require("modeldata"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 165 | data("ames", package = "modeldata") 166 | 167 | # change the model config for an faster ignite optimizer 168 | config <- tabnet_config(optimizer = torch::optim_ignite_adamw) 169 | 170 | ## Single-outcome regression using formula specification 171 | fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1, config = config) 172 | \dontshow{\}) # examplesIf} 173 | } 174 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | merge_config_and_dots <- function(config, ...) { 2 | default_config <- tabnet_config() 3 | new_config <- do.call(tabnet_config, list(...)) 4 | # TODO currently we cannot not compare two nn_optimizer nor nn_loss values 5 | new_config <- new_config[ 6 | mapply( 7 | function(x, y) ifelse(is_null_or_optim_generator_or_loss(x), 8 | !is_null_or_optim_generator_or_loss(y), # TRUE 9 | ifelse(is_optim_generator_or_loss(y), TRUE, x != y)), # FALSE 10 | default_config, 11 | new_config) 12 | ] 13 | merged_config <- utils::modifyList(config, as.list(new_config)) 14 | merged_config$optimizer <- resolve_optimizer(merged_config$optimizer) 15 | merged_config 16 | } 17 | # 18 | # is_different_param <- function(x, y) { 19 | # if (rlang::inherits_any(x, c("nn_loss", "nn_optim_generatorclass"))) { 20 | # 21 | # } 22 | # } 23 | 24 | check_net_is_empty_ptr <- function(object) { 25 | is_null_external_pointer(object$fit$network$.check$ptr) 26 | } 27 | 28 | # https://stackoverflow.com/a/27350487/3297472 29 | is_null_external_pointer <- function(pointer) { 30 | a <- attributes(pointer) 31 | attributes(pointer) <- NULL 32 | out <- identical(pointer, methods::new("externalptr")) 33 | attributes(pointer) <- a 34 | out 35 | } 36 | 37 | #' Check that Node object names are compliant 38 | #' 39 | #' @param node the Node object, or a dataframe ready to be parsed by `data.tree::as.Node()` 40 | #' 41 | #' @return node if it is compliant, else an Error with the column names to fix 42 | #' @export 43 | #' 44 | #' @examplesIf (require("data.tree") || require("dplyr")) 45 | #' library(dplyr) 46 | #' library(data.tree) 47 | #' data(starwars) 48 | #' starwars_tree <- starwars %>% 49 | #' mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/")) 50 | #' 51 | #' # pre as.Node() check 52 | #' try(check_compliant_node(starwars_tree)) 53 | #' 54 | #' # post as.Node() check 55 | #' check_compliant_node(as.Node(starwars_tree)) 56 | #' 57 | check_compliant_node <- function(node) { 58 | # prevent reserved data.tree Node colnames and the level_1 ... 
level_n names used for coercion 59 | if (inherits(node, "Node")) { 60 | # Node has already lost its reserved colnames 61 | reserved_names <- paste0("level_", c(1:node$height)) 62 | actual_names <- node$attributesAll 63 | } else if (inherits(node, "data.frame") && "pathString" %in% colnames(node)) { 64 | node_height <- max(stringr::str_count(node$pathString, "/")) 65 | reserved_names <- c(paste0("level_", c(1:node_height)), data.tree::NODE_RESERVED_NAMES_CONST) 66 | actual_names <- colnames(node)[!colnames(node) %in% "pathString"] 67 | } else { 68 | type_error("The provided hierarchical object is not recognized with a valid format that can be checked") 69 | } 70 | 71 | if (any(actual_names %in% reserved_names)) { 72 | value_error("The attributes or colnames in the provided hierarchical object use the following reserved names: 73 | {.vars {actual_names[actual_names %in% reserved_names]}}. 74 | Please change those names as they will lead to unexpected tabnet behavior.") 75 | } 76 | 77 | invisible(node) 78 | } 79 | 80 | #' Turn a Node object into predictor and outcome. 81 | #' 82 | #' @param x Node object 83 | #' @param drop_last_level TRUE unused 84 | #' 85 | #' @return a named list of x and y, being respectively the predictor data-frame and the outcomes data-frame, 86 | #' as expected inputs for `hardhat::mold()` function. 87 | #' @export 88 | #' 89 | #' @examplesIf (require("data.tree") || require("dplyr")) 90 | #' library(dplyr) 91 | #' library(data.tree) 92 | #' data(starwars) 93 | #' starwars_tree <- starwars %>% 94 | #' mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/")) %>% 95 | #' as.Node() 96 | #' node_to_df(starwars_tree)$x %>% head() 97 | #' node_to_df(starwars_tree)$y %>% head() 98 | #' @importFrom dplyr last_col mutate mutate_if select starts_with where 99 | node_to_df <- function(x, drop_last_level = TRUE) { 100 | # TODO get rid of all those import through base R equivalent 101 | xy_df <- data.tree::ToDataFrameTypeCol(x, x$attributesAll) 102 | x_df <- xy_df %>% 103 | select(-starts_with("level_")) %>% 104 | mutate_if(is.character, as.factor) 105 | y_df <- xy_df %>% 106 | select(starts_with("level_")) %>% 107 | # drop first (and all zero-variance) column 108 | select(where(~ nlevels(as.factor(.x)) > 1 )) %>% 109 | # TODO take the drop_last_level param into account 110 | # drop last level column 111 | select(-last_col()) %>% 112 | # TODO impute "NA" with parent through coalesce() via an option 113 | mutate_if(is.character, as.factor) 114 | return(list(x = x_df, y = y_df)) 115 | } 116 | 117 | 118 | model_to_raw <- function(model) { 119 | con <- rawConnection(raw(), open = "wr") 120 | torch::torch_save(model, con) 121 | on.exit({close(con)}, add = TRUE) 122 | r <- rawConnectionValue(con) 123 | r 124 | } 125 | 126 | # generalize torch to_device to nested list of tensors 127 | to_device <- function(x, device) { 128 | lapply(x, function(x) { 129 | if (inherits(x, "torch_tensor")) { 130 | x$to(device=device) 131 | } else if (is.list(x)) { 132 | lapply(x, to_device) 133 | } else { 134 | x 135 | } 136 | }) 137 | } 138 | 139 | # `optim_ignite_*` requires a minimum torch version 140 | torch_has_optim_ignite <- function() { 141 | utils::compareVersion(as.character(utils::packageVersion("torch")), "0.14.0") >= 0 142 | } 143 | 144 | # turn "adam" or a torch_optim_* generator into a proper torch_optim_ generator 145 | resolve_optimizer <- function(optimizer) { 146 | if (is_optim_generator(optimizer)) { 147 | torch_optimizer <- optimizer 148 | } else if 
(is.character(optimizer)) { 149 | if (optimizer == "adam" && torch_has_optim_ignite()) { 150 | torch_optimizer <- torch::optim_ignite_adam 151 | } else if (optimizer == "adam") { 152 | torch_optimizer <- torch::optim_adam 153 | } else { 154 | value_error("Currently only {.val adam} is supported as character for {.var optimizer}.") 155 | } 156 | } else { 157 | value_error("Currently only {.val adam} is supported as character for {.var optimizer}.") 158 | } 159 | torch_optimizer 160 | 161 | } 162 | 163 | 164 | is_optim_generator <- function(x) { 165 | inherits(x, "torch_optimizer_generator") 166 | } 167 | 168 | is_loss_generator <- function(x) { 169 | rlang::inherits_all(x, c("nn_loss", "nn_module_generator")) 170 | } 171 | 172 | is_null_or_optim_generator_or_loss <- function(x) { 173 | is.null(x) || is_optim_generator(x) || inherits(x, "nn_loss") 174 | } 175 | 176 | is_optim_generator_or_loss <- function(x) { 177 | is_optim_generator(x) || inherits(x, "nn_loss") 178 | } 179 | 180 | 181 | value_error <- function(..., env = rlang::caller_env()) { 182 | cli::cli_abort(gettext(..., domain = "R-tabnet")[[1]], .envir = env) 183 | } 184 | 185 | type_error <- function(..., env = rlang::caller_env()) { 186 | cli::cli_abort(gettext(..., domain = "R-tabnet")[[1]], .envir = env) 187 | } 188 | 189 | runtime_error <- function(..., env = rlang::caller_env()) { 190 | cli::cli_abort(gettext(..., domain = "R-tabnet")[[1]], .envir = env) 191 | } 192 | 193 | not_implemented_error <- function(..., env = rlang::caller_env()) { 194 | cli::cli_abort(gettext(..., domain = "R-tabnet")[[1]], .envir = env) 195 | } 196 | 197 | warn <- function(..., env = rlang::caller_env()) { 198 | cli::cli_warn(gettext(..., domain = "R-tabnet")[[1]], .envir = env) 199 | } 200 | -------------------------------------------------------------------------------- /tests/testthat/test-parsnip.R: -------------------------------------------------------------------------------- 1 | test_that("parsnip fit model works", { 2 | 3 | # default params 4 | expect_no_error( 5 | model <- tabnet() %>% 6 | parsnip::set_mode("regression") %>% 7 | parsnip::set_engine("torch") 8 | ) 9 | 10 | expect_no_error( 11 | fit <- model %>% 12 | parsnip::fit(Sale_Price ~ ., data = small_ames) 13 | ) 14 | 15 | # some setup params 16 | expect_no_error( 17 | model <- tabnet(epochs = 2, learn_rate = 1e-5) %>% 18 | parsnip::set_mode("regression") %>% 19 | parsnip::set_engine("torch") 20 | ) 21 | 22 | expect_no_error( 23 | fit <- model %>% 24 | parsnip::fit(Sale_Price ~ ., data = small_ames) 25 | ) 26 | 27 | # new batch of setup params 28 | expect_no_error( 29 | model <- tabnet(penalty = 0.2, verbose = FALSE, early_stopping_tolerance = 1e-3) %>% 30 | parsnip::set_mode("classification") %>% 31 | parsnip::set_engine("torch") 32 | ) 33 | 34 | expect_no_error( 35 | fit <- model %>% 36 | parsnip::fit(Overall_Cond ~ ., data = small_ames) 37 | ) 38 | 39 | }) 40 | 41 | test_that("parsnip fit model works from a pretrained model", { 42 | 43 | # default params 44 | expect_no_error( 45 | model <- tabnet(tabnet_model = ames_pretrain, from_epoch = 1, epoch = 1) %>% 46 | parsnip::set_mode("regression") %>% 47 | parsnip::set_engine("torch") 48 | ) 49 | 50 | expect_no_error( 51 | fit <- model %>% 52 | parsnip::fit(Sale_Price ~ ., data = small_ames) 53 | ) 54 | 55 | 56 | 57 | 58 | }) 59 | 60 | test_that("multi_predict works as expected", { 61 | 62 | model <- tabnet(checkpoint_epoch = 1) %>% 63 | parsnip::set_mode("regression") %>% 64 | parsnip::set_engine("torch") 65 | 66 | expect_no_error( 
67 | fit <- model %>% 68 | parsnip::fit(Sale_Price ~ ., data = small_ames) 69 | ) 70 | 71 | preds <- parsnip::multi_predict(fit, small_ames, epochs = c(1,2,3,4,5)) 72 | 73 | expect_equal(nrow(preds), nrow(small_ames)) 74 | expect_equal(nrow(preds$.pred[[1]]), 5) 75 | }) 76 | 77 | test_that("Check we can finalize a workflow", { 78 | 79 | model <- tabnet(penalty = tune(), epochs = tune()) %>% 80 | parsnip::set_mode("regression") %>% 81 | parsnip::set_engine("torch") 82 | 83 | wf <- workflows::workflow() %>% 84 | workflows::add_model(model) %>% 85 | workflows::add_formula(Sale_Price ~ .) 86 | 87 | wf <- tune::finalize_workflow(wf, tibble::tibble(penalty = 0.01, epochs = 1)) 88 | 89 | expect_no_error( 90 | fit <- wf %>% parsnip::fit(data = small_ames) 91 | ) 92 | 93 | expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$penalty), 0.01) 94 | expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$epochs), 1) 95 | }) 96 | 97 | test_that("Check we can finalize a workflow from a tune_grid", { 98 | 99 | model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>% 100 | parsnip::set_mode("regression") %>% 101 | parsnip::set_engine("torch") 102 | 103 | wf <- workflows::workflow() %>% 104 | workflows::add_model(model) %>% 105 | workflows::add_formula(Sale_Price ~ .) 106 | 107 | custom_grid <- tidyr::crossing(epochs = c(1,2,3)) 108 | cv_folds <- small_ames %>% 109 | rsample::vfold_cv(v = 2, repeats = 1) 110 | 111 | at <- tune::tune_grid( 112 | object = wf, 113 | resamples = cv_folds, 114 | grid = custom_grid, 115 | metrics = yardstick::metric_set(yardstick::rmse), 116 | control = tune::control_grid(verbose = F) 117 | ) 118 | 119 | best_rmse <- tune::select_best(at, metric = "rmse") 120 | 121 | expect_no_error( 122 | final_wf <- tune::finalize_workflow(wf, best_rmse) 123 | ) 124 | }) 125 | 126 | test_that("tabnet grid reduction - torch", { 127 | 128 | mod <- tabnet() %>% 129 | parsnip::set_engine("torch") 130 | 131 | # A typical grid 132 | reg_grid <- expand.grid(epochs = 1:3, penalty = 1:2) 133 | reg_grid_smol <- tune::min_grid(mod, reg_grid) 134 | 135 | expect_equal(reg_grid_smol$epochs, rep(3, 2)) 136 | expect_equal(reg_grid_smol$penalty, 1:2) 137 | for (i in 1:nrow(reg_grid_smol)) { 138 | expect_equal(reg_grid_smol$.submodels[[i]], list(epochs = 1:2)) 139 | } 140 | 141 | # Unbalanced grid 142 | reg_ish_grid <- expand.grid(epochs = 1:3, penalty = 1:2)[-3, ] 143 | reg_ish_grid_smol <- tune::min_grid(mod, reg_ish_grid) 144 | 145 | expect_equal(reg_ish_grid_smol$epochs, 2:3) 146 | expect_equal(reg_ish_grid_smol$penalty, 1:2) 147 | for (i in 2:nrow(reg_ish_grid_smol)) { 148 | expect_equal(reg_ish_grid_smol$.submodels[[i]], list(epochs = 1:2)) 149 | } 150 | 151 | # Grid with a third parameter 152 | reg_grid_extra <- expand.grid(epochs = 1:3, penalty = 1:2, batch_size = 10:12) 153 | reg_grid_extra_smol <- tune::min_grid(mod, reg_grid_extra) 154 | 155 | expect_equal(reg_grid_extra_smol$epochs, rep(3, 6)) 156 | expect_equal(reg_grid_extra_smol$penalty, rep(1:2, each = 3)) 157 | expect_equal(reg_grid_extra_smol$batch_size, rep(10:12, 2)) 158 | for (i in 1:nrow(reg_grid_extra_smol)) { 159 | expect_equal(reg_grid_extra_smol$.submodels[[i]], list(epochs = 1:2)) 160 | } 161 | 162 | # Only epochs 163 | only_epochs <- expand.grid(epochs = 1:3) 164 | only_epochs_smol <- tune::min_grid(mod, only_epochs) 165 | 166 | expect_equal(only_epochs_smol$epochs, 3) 167 | expect_equal(only_epochs_smol$.submodels, list(list(epochs = 1:2))) 168 | 169 | # No submodels 170 | no_sub <- tibble::tibble(epochs = 1, penalty 
= 1:2) 171 | no_sub_smol <- tune::min_grid(mod, no_sub) 172 | 173 | expect_equal(no_sub_smol$epochs, rep(1, 2)) 174 | expect_equal(no_sub_smol$penalty, 1:2) 175 | for (i in 1:nrow(no_sub_smol)) { 176 | expect_length(no_sub_smol$.submodels[[i]], 0) 177 | } 178 | 179 | # different id names 180 | mod_1 <- tabnet(epochs = tune("Amos")) %>% 181 | parsnip::set_engine("torch") 182 | reg_grid <- expand.grid(Amos = 1:3, penalty = 1:2) 183 | reg_grid_smol <- tune::min_grid(mod_1, reg_grid) 184 | 185 | expect_equal(reg_grid_smol$Amos, rep(3, 2)) 186 | expect_equal(reg_grid_smol$penalty, 1:2) 187 | for (i in 1:nrow(reg_grid_smol)) { 188 | expect_equal(reg_grid_smol$.submodels[[i]], list(Amos = 1:2)) 189 | } 190 | 191 | all_sub <- expand.grid(Amos = 1:3) 192 | all_sub_smol <- tune::min_grid(mod_1, all_sub) 193 | 194 | expect_equal(all_sub_smol$Amos, 3) 195 | expect_equal(all_sub_smol$.submodels[[1]], list(Amos = 1:2)) 196 | 197 | mod_2 <- tabnet(epochs = tune("Ade Tukunbo")) %>% 198 | parsnip::set_engine("torch") 199 | reg_grid <- expand.grid(`Ade Tukunbo` = 1:3, penalty = 1:2, ` \t123` = 10:11) 200 | reg_grid_smol <- tune::min_grid(mod_2, reg_grid) 201 | 202 | expect_equal(reg_grid_smol$`Ade Tukunbo`, rep(3, 4)) 203 | expect_equal(reg_grid_smol$penalty, rep(1:2, each = 2)) 204 | expect_equal(reg_grid_smol$` \t123`, rep(10:11, 2)) 205 | for (i in 1:nrow(reg_grid_smol)) { 206 | expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2)) 207 | } 208 | }) 209 | 210 | test_that("Check workflow can use case_weight", { 211 | 212 | small_ames_cw <- small_ames %>% dplyr::mutate(case_weight = hardhat::frequency_weights(Year_Sold)) 213 | model <- tabnet(epochs = 2) %>% 214 | parsnip::set_mode("regression") %>% 215 | parsnip::set_engine("torch") 216 | 217 | wf <- workflows::workflow() %>% 218 | workflows::add_model(model) %>% 219 | workflows::add_formula(Sale_Price ~ .) %>% 220 | workflows::add_case_weights(case_weight) 221 | 222 | expect_no_error( 223 | fit <- wf %>% parsnip::fit(data = small_ames_cw) 224 | ) 225 | expect_no_error( 226 | predict(fit, small_ames) 227 | ) 228 | 229 | 230 | }) 231 | -------------------------------------------------------------------------------- /man/tabnet.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/parsnip.R 3 | \name{tabnet} 4 | \alias{tabnet} 5 | \title{Parsnip compatible tabnet model} 6 | \usage{ 7 | tabnet( 8 | mode = "unknown", 9 | cat_emb_dim = NULL, 10 | decision_width = NULL, 11 | attention_width = NULL, 12 | num_steps = NULL, 13 | mask_type = NULL, 14 | mask_topk = NULL, 15 | num_independent = NULL, 16 | num_shared = NULL, 17 | num_independent_decoder = NULL, 18 | num_shared_decoder = NULL, 19 | penalty = NULL, 20 | feature_reusage = NULL, 21 | momentum = NULL, 22 | epochs = NULL, 23 | batch_size = NULL, 24 | virtual_batch_size = NULL, 25 | learn_rate = NULL, 26 | optimizer = NULL, 27 | loss = NULL, 28 | clip_value = NULL, 29 | drop_last = NULL, 30 | lr_scheduler = NULL, 31 | rate_decay = NULL, 32 | rate_step_size = NULL, 33 | checkpoint_epochs = NULL, 34 | verbose = NULL, 35 | importance_sample_size = NULL, 36 | early_stopping_monitor = NULL, 37 | early_stopping_tolerance = NULL, 38 | early_stopping_patience = NULL, 39 | skip_importance = NULL, 40 | tabnet_model = NULL, 41 | from_epoch = NULL 42 | ) 43 | } 44 | \arguments{ 45 | \item{mode}{A single character string for the type of model. 
Possible values 46 | for this model are "unknown", "regression", or "classification".} 47 | 48 | \item{cat_emb_dim}{Size of the embedding of categorical features. If int, all categorical 49 | features will have same embedding size, if list of int, every corresponding feature will have 50 | specific embedding size.} 51 | 52 | \item{decision_width}{(int) Width of the decision prediction layer. Bigger values gives 53 | more capacity to the model with the risk of overfitting. Values typically 54 | range from 8 to 64.} 55 | 56 | \item{attention_width}{(int) Width of the attention embedding for each mask. According to 57 | the paper n_d = n_a is usually a good choice. (default=8)} 58 | 59 | \item{num_steps}{(int) Number of steps in the architecture 60 | (usually between 3 and 10)} 61 | 62 | \item{mask_type}{(character) Final layer of feature selector in the attentive_transformer 63 | block, either \code{"sparsemax"}, \code{"entmax"} or \code{"entmax15"}.Defaults to \code{"sparsemax"}.} 64 | 65 | \item{mask_topk}{(int) mask sparsity top-k for \code{sparsemax15} and \code{entmax15.} See \code{\link[=entmax15]{entmax15()}} for detail.} 66 | 67 | \item{num_independent}{Number of independent Gated Linear Units layers at each step of the encoder. 68 | Usual values range from 1 to 5.} 69 | 70 | \item{num_shared}{Number of shared Gated Linear Units at each step of the encoder. Usual values 71 | at each step of the decoder. range from 1 to 5} 72 | 73 | \item{num_independent_decoder}{For pretraining, number of independent Gated Linear Units layers 74 | Usual values range from 1 to 5.} 75 | 76 | \item{num_shared_decoder}{For pretraining, number of shared Gated Linear Units at each step of the 77 | decoder. Usual values range from 1 to 5.} 78 | 79 | \item{penalty}{This is the extra sparsity loss coefficient as proposed 80 | in the original paper. The bigger this coefficient is, the sparser your model 81 | will be in terms of feature selection. Depending on the difficulty of your 82 | problem, reducing this value could help (default 1e-3).} 83 | 84 | \item{feature_reusage}{(num) This is the coefficient for feature reusage in the masks. 85 | A value close to 1 will make mask selection least correlated between layers. 86 | Values range from 1 to 2.} 87 | 88 | \item{momentum}{Momentum for batch normalization, typically ranges from 0.01 89 | to 0.4 (default=0.02)} 90 | 91 | \item{epochs}{(int) Number of training epochs.} 92 | 93 | \item{batch_size}{(int) Number of examples per batch, large batch sizes are 94 | recommended. (default: 1024^2)} 95 | 96 | \item{virtual_batch_size}{(int) Size of the mini batches used for 97 | "Ghost Batch Normalization" (default=256^2)} 98 | 99 | \item{learn_rate}{initial learning rate for the optimizer.} 100 | 101 | \item{optimizer}{the optimization method. currently only \code{"adam"} is supported, 102 | you can also pass any torch optimizer function.} 103 | 104 | \item{loss}{(character or function) Loss function for training (default to mse 105 | for regression and cross entropy for classification)} 106 | 107 | \item{clip_value}{If a num is given this will clip the gradient at 108 | clip_value. Pass \code{NULL} to not clip.} 109 | 110 | \item{drop_last}{(logical) Whether to drop last batch if not complete during 111 | training} 112 | 113 | \item{lr_scheduler}{if \code{NULL}, no learning rate decay is used. If "step" 114 | decays the learning rate by \code{lr_decay} every \code{step_size} epochs. 
If "reduce_on_plateau" 115 | decays the learning rate by \code{lr_decay} when no improvement after \code{step_size} epochs. 116 | It can also be a \link[torch:lr_scheduler]{torch::lr_scheduler} function that only takes the optimizer 117 | as parameter. The \code{step} method is called once per epoch.} 118 | 119 | \item{rate_decay}{multiplies the initial learning rate by \code{rate_decay} every 120 | \code{rate_step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler} 121 | or \code{NULL}.} 122 | 123 | \item{rate_step_size}{the learning rate scheduler step size. Unused if 124 | \code{lr_scheduler} is a \code{torch::lr_scheduler} or \code{NULL}.} 125 | 126 | \item{checkpoint_epochs}{checkpoint model weights and architecture every 127 | \code{checkpoint_epochs}. (default is 10). This may cause large memory usage. 128 | Use \code{0} to disable checkpoints.} 129 | 130 | \item{verbose}{(logical) Whether to print progress and loss values during 131 | training.} 132 | 133 | \item{importance_sample_size}{sample of the dataset to compute importance metrics. 134 | If the dataset is larger than 1e5 obs we will use a sample of size 1e5 and 135 | display a warning.} 136 | 137 | \item{early_stopping_monitor}{Metric to monitor for early_stopping. One of "valid_loss", "train_loss" or "auto" (defaults to "auto").} 138 | 139 | \item{early_stopping_tolerance}{Minimum relative improvement to reset the patience counter. 140 | 0.01 for 1\% tolerance (default 0)} 141 | 142 | \item{early_stopping_patience}{Number of epochs without improving until stopping training. (default=5)} 143 | 144 | \item{skip_importance}{if feature importance calculation should be skipped (default: \code{FALSE})} 145 | 146 | \item{tabnet_model}{A previously fitted \code{tabnet_model} object to continue the fitting on. 147 | if \code{NULL} (the default) a brand new model is initialized.} 148 | 149 | \item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch. 150 | Default is last available checkpoint for restored model, or last epoch for in-memory model.} 151 | } 152 | \value{ 153 | A TabNet \code{parsnip} instance. It can be used to fit tabnet models using 154 | \code{parsnip} machinery. 155 | } 156 | \description{ 157 | Parsnip compatible tabnet model 158 | } 159 | \section{Threading}{ 160 | 161 | 162 | TabNet uses \code{torch} as its backend for computation and \code{torch} uses all 163 | available threads by default. 164 | 165 | You can control the number of threads used by \code{torch} with: 166 | 167 | \if{html}{\out{
}}\preformatted{torch::torch_set_num_threads(1) 168 | torch::torch_set_num_interop_threads(1) 169 | }\if{html}{\out{
}} 170 | } 171 | 172 | \examples{ 173 | \dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 174 | library(parsnip) 175 | data("ames", package = "modeldata") 176 | model <- tabnet() \%>\% 177 | set_mode("regression") \%>\% 178 | set_engine("torch") 179 | model \%>\% 180 | fit(Sale_Price ~ ., data = ames) 181 | \dontshow{\}) # examplesIf} 182 | } 183 | \seealso{ 184 | tabnet_fit 185 | } 186 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # tabnet 17 | 18 | 19 | 20 | [![R build status](https://github.com/mlverse/tabnet/workflows/R-CMD-check/badge.svg)](https://github.com/mlverse/tabnet/actions) [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) [![CRAN status](https://www.r-pkg.org/badges/version/tabnet)](https://CRAN.R-project.org/package=tabnet) [![](https://cranlogs.r-pkg.org/badges/tabnet)](https://cran.r-project.org/package=tabnet) [![Discord](https://img.shields.io/discord/837019024499277855?logo=discord)](https://discord.com/invite/s3D5cKhBkx) 21 | 22 | 23 | 24 | An R implementation of: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas Pfister)](https://doi.org/10.48550/arXiv.1908.07442).\ 25 | 26 | The code in this repository started by an R port using the [torch](https://github.com/mlverse/torch) package of [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet) implementation. 27 | 28 | TabNet is now augmented with 29 | 30 | - [Coherent Hierarchical Multi-label Classification Networks](https://proceedings.neurips.cc//paper/2020/file/6dd4e10e3296fa63738371ec0d5df818-Paper.pdf) [(Eleonora Giunchiglia et Al.)](https://doi.org/10.48550/arXiv.2010.10151) for hierarchical outcomes 31 | 32 | - [Optimizing ROC Curves with a Sort-Based Surrogate Loss for Binary Classification and Changepoint Detection (J Hillman, TD Hocking)](https://jmlr.org/papers/v24/21-0751.html) for imbalanced binary classification. 33 | 34 | ## Installation 35 | 36 | Install [{tabnet} from CRAN](https://CRAN.R-project.org/package=tabnet) with: 37 | 38 | ``` r 39 | install.packages('tabnet') 40 | ``` 41 | 42 | The development version can be installed from [GitHub](https://github.com/mlverse/tabnet) with: 43 | 44 | ``` r 45 | # install.packages("pak") 46 | pak::pak("mlverse/tabnet") 47 | ``` 48 | 49 | ## Basic Binary Classification Example 50 | 51 | Here we show a **binary classification** example of the `attrition` dataset, using a **recipe** for dataset input specification. 52 | 53 | ```{r model-fit} 54 | #| fig.alt: "A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded." 
55 | library(tabnet) 56 | suppressPackageStartupMessages(library(recipes)) 57 | library(yardstick) 58 | library(ggplot2) 59 | set.seed(1) 60 | 61 | data("attrition", package = "modeldata") 62 | test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition)) 63 | 64 | train <- attrition[-test_idx,] 65 | test <- attrition[test_idx,] 66 | 67 | rec <- recipe(Attrition ~ ., data = train) %>% 68 | step_normalize(all_numeric(), -all_outcomes()) 69 | 70 | fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3) 71 | autoplot(fit) 72 | ``` 73 | 74 | The plots gives you an immediate insight about model over-fitting, and if any, the available model checkpoints available before the over-fitting 75 | 76 | Keep in mind that **regression** as well as **multi-class classification** are also available, and that you can specify dataset through **data.frame** and **formula** as well. You will find them in the package vignettes. 77 | 78 | ## Model performance results 79 | 80 | As the standard method `predict()` is used, you can rely on your usual metric functions for model performance results. Here we use {yardstick} : 81 | 82 | ```{r} 83 | metrics <- metric_set(accuracy, precision, recall) 84 | cbind(test, predict(fit, test)) %>% 85 | metrics(Attrition, estimate = .pred_class) 86 | 87 | cbind(test, predict(fit, test, type = "prob")) %>% 88 | roc_auc(Attrition, .pred_No) 89 | ``` 90 | 91 | ## Explain model on test-set with attention map 92 | 93 | TabNet has intrinsic explainability feature through the visualization of attention map, either **aggregated**: 94 | 95 | ```{r model-explain} 96 | #| fig.alt: "An expainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask agggregate." 97 | explain <- tabnet_explain(fit, test) 98 | autoplot(explain) 99 | ``` 100 | 101 | or at **each layer** through the `type = "steps"` option: 102 | 103 | ```{r step-explain} 104 | #| fig.alt: "An small-multiple expainability plot for each step of the Tabnet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis." 105 | autoplot(explain, type = "steps") 106 | ``` 107 | 108 | ## Self-supervised pretraining 109 | 110 | For cases when a consistent part of your dataset has no outcome, TabNet offers a self-supervised training step allowing to model to capture predictors intrinsic features and predictors interactions, upfront the supervised task. 111 | 112 | ```{r step-pretrain} 113 | #| fig.alt: "A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded." 114 | pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2) 115 | autoplot(pretrain) 116 | ``` 117 | 118 | The example here is a toy example as the `train` dataset does actually contain outcomes. The vignette [`vignette("selfsupervised_training")`](articles/selfsupervised_training.html) will gives you the complete correct workflow step-by-step. 119 | 120 | ## {tidymodels} integration 121 | 122 | The integration within tidymodels workflows offers you unlimited opportunity to compare {tabnet} models with challengers. 123 | 124 | Don't miss the [`vignette("tidymodels-interface")`](articles/tidymodels-interface.html) for that. 
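If you just want a flavour of what that looks like, here is a minimal, non-evaluated sketch (the `rec` recipe and `train` data are the ones defined above; the tuned parameter is purely illustrative):

``` r
library(workflows)
library(tune)

# a tabnet parsnip spec with one parameter marked for tuning
tab_spec <- tabnet(epochs = 30, decision_width = tune()) %>%
  parsnip::set_mode("classification") %>%
  parsnip::set_engine("torch")

# bundle the recipe and the model into a single workflow
tab_wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(tab_spec)

# tab_wflow can then go through tune_grid(), fit_resamples(), last_fit(), ...
```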
125 | 126 | ## Missing data in predictors 127 | 128 | {tabnet} leverage the masking mechanism to deal with missing data, so you don't have to remove the entries in your dataset with some missing values in the predictors variables. 129 | 130 | See [`vignette("Missing_data_predictors")`](articles/Missing_data_predictors.html) 131 | 132 | ## Imbalanced binary classification 133 | 134 | {tabnet} includes a Area under the $Min(FPR,FNR)$ (AUM) loss function `nn_aum_loss()` dedicated to your imbalanced binary classification tasks. 135 | 136 | Try it out in [`vignette("aum_loss")`](articles/aum_loss.html) 137 | 138 | # Comparison with other implementations 139 | 140 | | Group | Feature | {tabnet} | dreamquark-ai | fast-tabnet | 141 | |---------------|---------------|:-------------:|:-------------:|:-------------:| 142 | | Input format | data-frame | ✅ | ✅ | ✅ | 143 | | | formula | ✅ | | | 144 | | | recipe | ✅ | | | 145 | | | Node | ✅ | | | 146 | | | missings in predictor | ✅ | | | 147 | | Output format | data-frame | ✅ | ✅ | ✅ | 148 | | | workflow | ✅ | | | 149 | | ML Tasks | self-supervised learning | ✅ | ✅ | | 150 | | | classification (binary, multi-class) | ✅ | ✅ | ✅ | 151 | | | unbalanced binary classification | ✅ | | | 152 | | | regression | ✅ | ✅ | ✅ | 153 | | | multi-outcome | ✅ | ✅ | | 154 | | | hierarchical multi-label classif. | ✅ | | | 155 | | Model management | from / to file | ✅ | ✅ | v | 156 | | | resume from snapshot | ✅ | | | 157 | | | training diagnostic | ✅ | | | 158 | | Interpretability | | ✅ | ✅ | ✅ | 159 | | Performance | | 1 x | 2 - 4 x | | 160 | | Code quality | test coverage | 85% | | | 161 | | | continuous integration | 4 OS including GPU | | | 162 | 163 | : Alternative TabNet implementation features 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # tabnet 5 | 6 | 7 | 8 | [![R build 9 | status](https://github.com/mlverse/tabnet/workflows/R-CMD-check/badge.svg)](https://github.com/mlverse/tabnet/actions) 10 | [![Lifecycle: 11 | experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) 12 | [![CRAN 13 | status](https://www.r-pkg.org/badges/version/tabnet)](https://CRAN.R-project.org/package=tabnet) 14 | [![](https://cranlogs.r-pkg.org/badges/tabnet)](https://cran.r-project.org/package=tabnet) 15 | [![Discord](https://img.shields.io/discord/837019024499277855?logo=discord)](https://discord.com/invite/s3D5cKhBkx) 16 | 17 | 18 | 19 | An R implementation of: [TabNet: Attentive Interpretable Tabular 20 | Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas 21 | Pfister)](https://doi.org/10.48550/arXiv.1908.07442). 22 | 23 | The code in this repository started by an R port using the 24 | [torch](https://github.com/mlverse/torch) package of 25 | [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet) 26 | implementation. 
27 | 28 | TabNet is now augmented with 29 | 30 | - [Coherent Hierarchical Multi-label Classification 31 | Networks](https://proceedings.neurips.cc//paper/2020/file/6dd4e10e3296fa63738371ec0d5df818-Paper.pdf) 32 | [(Eleonora Giunchiglia et 33 | Al.)](https://doi.org/10.48550/arXiv.2010.10151) for hierarchical 34 | outcomes 35 | 36 | - [Optimizing ROC Curves with a Sort-Based Surrogate Loss for Binary 37 | Classification and Changepoint Detection (J Hillman, TD 38 | Hocking)](https://jmlr.org/papers/v24/21-0751.html) for imbalanced 39 | binary classification. 40 | 41 | ## Installation 42 | 43 | Install [{tabnet} from CRAN](https://CRAN.R-project.org/package=tabnet) 44 | with: 45 | 46 | ``` r 47 | install.packages('tabnet') 48 | ``` 49 | 50 | The development version can be installed from 51 | [GitHub](https://github.com/mlverse/tabnet) with: 52 | 53 | ``` r 54 | # install.packages("pak") 55 | pak::pak("mlverse/tabnet") 56 | ``` 57 | 58 | ## Basic Binary Classification Example 59 | 60 | Here we show a **binary classification** example of the `attrition` 61 | dataset, using a **recipe** for dataset input specification. 62 | 63 | ``` r 64 | library(tabnet) 65 | suppressPackageStartupMessages(library(recipes)) 66 | library(yardstick) 67 | library(ggplot2) 68 | set.seed(1) 69 | 70 | data("attrition", package = "modeldata") 71 | test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition)) 72 | 73 | train <- attrition[-test_idx,] 74 | test <- attrition[test_idx,] 75 | 76 | rec <- recipe(Attrition ~ ., data = train) %>% 77 | step_normalize(all_numeric(), -all_outcomes()) 78 | 79 | fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3) 80 | autoplot(fit) 81 | ``` 82 | 83 | A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded. 84 | 85 | The plots gives you an immediate insight about model over-fitting, and 86 | if any, the available model checkpoints available before the 87 | over-fitting 88 | 89 | Keep in mind that **regression** as well as **multi-class 90 | classification** are also available, and that you can specify dataset 91 | through **data.frame** and **formula** as well. You will find them in 92 | the package vignettes. 93 | 94 | ## Model performance results 95 | 96 | As the standard method `predict()` is used, you can rely on your usual 97 | metric functions for model performance results. Here we use {yardstick} 98 | : 99 | 100 | ``` r 101 | metrics <- metric_set(accuracy, precision, recall) 102 | cbind(test, predict(fit, test)) %>% 103 | metrics(Attrition, estimate = .pred_class) 104 | #> # A tibble: 3 × 3 105 | #> .metric .estimator .estimate 106 | #> 107 | #> 1 accuracy binary 0.840 108 | #> 2 precision binary 0.840 109 | #> 3 recall binary 1 110 | 111 | cbind(test, predict(fit, test, type = "prob")) %>% 112 | roc_auc(Attrition, .pred_No) 113 | #> # A tibble: 1 × 3 114 | #> .metric .estimator .estimate 115 | #> 116 | #> 1 roc_auc binary 0.466 117 | ``` 118 | 119 | ## Explain model on test-set with attention map 120 | 121 | TabNet has intrinsic explainability feature through the visualization of 122 | attention map, either **aggregated**: 123 | 124 | ``` r 125 | explain <- tabnet_explain(fit, test) 126 | autoplot(explain) 127 | ``` 128 | 129 | An expainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask agggregate. 
130 | 131 | or at **each layer** through the `type = "steps"` option: 132 | 133 | ``` r 134 | autoplot(explain, type = "steps") 135 | ``` 136 | 137 | An small-multiple expainability plot for each step of the Tabnet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis. 138 | 139 | ## Self-supervised pretraining 140 | 141 | For cases when a consistent part of your dataset has no outcome, TabNet 142 | offers a self-supervised training step allowing to model to capture 143 | predictors intrinsic features and predictors interactions, upfront the 144 | supervised task. 145 | 146 | ``` r 147 | pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2) 148 | autoplot(pretrain) 149 | ``` 150 | 151 | A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded. 152 | 153 | The example here is a toy example as the `train` dataset does actually 154 | contain outcomes. The vignette 155 | [`vignette("selfsupervised_training")`](articles/selfsupervised_training.html) 156 | will gives you the complete correct workflow step-by-step. 157 | 158 | ## {tidymodels} integration 159 | 160 | The integration within tidymodels workflows offers you unlimited 161 | opportunity to compare {tabnet} models with challengers. 162 | 163 | Don’t miss the 164 | [`vignette("tidymodels-interface")`](articles/tidymodels-interface.html) 165 | for that. 166 | 167 | ## Missing data in predictors 168 | 169 | {tabnet} leverage the masking mechanism to deal with missing data, so 170 | you don’t have to remove the entries in your dataset with some missing 171 | values in the predictors variables. 172 | 173 | See 174 | [`vignette("Missing_data_predictors")`](articles/Missing_data_predictors.html) 175 | 176 | ## Imbalanced binary classification 177 | 178 | {tabnet} includes a Area under the $Min(FPR,FNR)$ (AUM) loss function 179 | `nn_aum_loss()` dedicated to your imbalanced binary classification 180 | tasks. 181 | 182 | Try it out in [`vignette("aum_loss")`](articles/aum_loss.html) 183 | 184 | # Comparison with other implementations 185 | 186 | | Group | Feature | {tabnet} | dreamquark-ai | fast-tabnet | 187 | |----|----|:--:|:--:|:--:| 188 | | Input format | data-frame | ✅ | ✅ | ✅ | 189 | | | formula | ✅ | | | 190 | | | recipe | ✅ | | | 191 | | | Node | ✅ | | | 192 | | | missings in predictor | ✅ | | | 193 | | Output format | data-frame | ✅ | ✅ | ✅ | 194 | | | workflow | ✅ | | | 195 | | ML Tasks | self-supervised learning | ✅ | ✅ | | 196 | | | classification (binary, multi-class) | ✅ | ✅ | ✅ | 197 | | | unbalanced binary classification | ✅ | | | 198 | | | regression | ✅ | ✅ | ✅ | 199 | | | multi-outcome | ✅ | ✅ | | 200 | | | hierarchical multi-label classif. 
| ✅ | | |
201 | | Model management | from / to file | ✅ | ✅ | v |
202 | | | resume from snapshot | ✅ | | |
203 | | | training diagnostic | ✅ | | |
204 | | Interpretability | | ✅ | ✅ | ✅ |
205 | | Performance | | 1 x | 2 - 4 x | |
206 | | Code quality | test coverage | 85% | | |
207 | | | continuous integration | 4 OS including GPU | | |
208 | 
209 | Alternative TabNet implementation features
210 | 
--------------------------------------------------------------------------------
/vignettes/aum_loss.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Using ROC AUM loss for imbalanced binary classification"
3 | output: rmarkdown::html_vignette
4 | vignette: >
5 |   %\VignetteIndexEntry{Using ROC AUM loss for imbalanced binary classification}
6 |   %\VignetteEngine{knitr::rmarkdown}
7 |   %\VignetteEncoding{UTF-8}
8 | editor_options:
9 |   markdown:
10 |     fig-width: 9
11 |     fig-height: 6
12 |     fig-cap-location: "top"
13 | ---
14 | 
15 | ```{r, include = FALSE}
16 | knitr::opts_chunk$set(
17 |   collapse = TRUE,
18 |   comment = "#>",
19 |   eval = torch::torch_is_installed()
20 | )
21 | ```
22 | 
23 | ```{r setup}
24 | library(tabnet)
25 | suppressPackageStartupMessages(library(tidymodels))
26 | library(modeldata)
27 | data("lending_club", package = "modeldata")
28 | set.seed(20250809)
29 | ```
30 | 
31 | ::: callout-note
32 | This vignette is a continuation of `vignette("tidymodels-interface")`, so we highly encourage you to read that one first to be up to speed here.
33 | :::
34 | 
35 | ## Introduction
36 | 
37 | The `lending_club` dataset used previously is highly imbalanced, which makes the binary classification task challenging. Although we got fairly good accuracy with the default model design, the `roc_auc()` metric was poor, mainly because of this imbalance.
38 | 
39 | Here, we will see how {tabnet} features can improve performance on this family of classification problems.
40 | 
41 | ## How imbalanced is my problem?
42 | 
43 | The imbalance of the target variable `Class` can be evaluated through the class imbalance ratio:
44 | 
45 | ```{r}
46 | class_ratio <- lending_club |> 
47 |   summarise(sum(Class == "good") / sum(Class == "bad")) |> 
48 |   pull()
49 | 
50 | class_ratio
51 | ```
52 | 
53 | With a `class_ratio` of 18.1, the target variable is seriously imbalanced, making the minority class much harder to model.
54 | 
55 | ## Solutions to improve imbalanced classification models
56 | 
57 | The first, usual solution to such a problem is over-sampling the minority class and/or down-sampling the majority class in the training data. We won't cover this here.
58 | 
59 | The second solution is case weighting. As {tidymodels} offers a framework to manage such case weights, we'll first use it to compare two model families - XGBoost and TabNet - with that feature.
60 | 
61 | Last, we would also like to **optimize** the model according to the metric we are looking at. As the metrics of choice for imbalanced datasets are `roc_auc()` or `pr_auc()`, we definitely want a **loss function** that is a proxy for them. Such a loss is available in {tabnet} as `nn_aum_loss()`, from [Optimizing ROC Curves with a Sort-Based Surrogate Loss for Binary Classification and Changepoint Detection (J Hillman, TD Hocking)](https://jmlr.org/papers/v24/21-0751.html).
62 | 
63 | ## Using the AUC metric and `pr_curve()` plots
64 | 
65 | Measuring the ROC AUC or PR AUC goes hand in hand with plotting the precision-recall curve with `pr_curve()`.
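Both metrics are one-liners in {yardstick} once you have a data frame of predictions. As a hedged sketch (not evaluated; `preds` is a hypothetical tibble holding the truth column `Class` and the class probability `.pred_good`, like the `tab_test` and `xgb_test` objects built below):

```{r, eval = FALSE}
preds |> roc_auc(Class, .pred_good)                 # area under the ROC curve
preds |> pr_auc(Class, .pred_good)                  # area under the precision-recall curve
preds |> pr_curve(Class, .pred_good) |> autoplot()  # the curve behind pr_auc()
```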
66 | 
67 | Let's baseline our models on two different workflows, one for TabNet and one for XGBoost. This is a big chunk of code, but it is mainly a copy of the previous vignette.
68 | 
69 | ```{r}
70 | lending_club <- lending_club |> 
71 |   mutate(
72 |     case_wts = if_else(Class == "bad", class_ratio, 1),
73 |     case_wts = importance_weights(case_wts)
74 |   )
75 | 
76 | split <- initial_split(lending_club, strata = Class)
77 | train <- training(split)
78 | test <- testing(split)
79 | 
80 | tab_rec <- train |> 
81 |   recipe() |> 
82 |   update_role(Class, new_role = "outcome") |> 
83 |   update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")
84 | 
85 | xgb_rec <- tab_rec |> 
86 |   step_dummy(term, sub_grade, addr_state, verification_status, emp_length)
87 | 
88 | tab_mod <- tabnet(epochs = 100) |> 
89 |   set_engine("torch", device = "cpu") |> 
90 |   set_mode("classification")
91 | 
92 | xgb_mod <- boost_tree(trees = 100) |> 
93 |   set_engine("xgboost") |> 
94 |   set_mode("classification")
95 | 
96 | tab_wf <- workflow() |> 
97 |   add_model(tab_mod) |> 
98 |   add_recipe(tab_rec) |> 
99 |   add_case_weights(case_wts)
100 | 
101 | xgb_wf <- workflow() |> 
102 |   add_model(xgb_mod) |> 
103 |   add_recipe(xgb_rec) |> 
104 |   add_case_weights(case_wts)
105 | ```
106 | 
107 | A few details worth noticing:
108 | 
109 | -   we compute a new column `case_wts` in the `lending_club` dataset and tag it as `importance_weights()`.
110 | 
111 | -   this column is excluded from the predictor role in the `recipe()`.
112 | 
113 | -   we explicitly declare this column's role in each `workflow()` via `add_case_weights()`.
114 | 
115 | We can now `fit()` each model and plot the precision-recall curve on the test set:
116 | 
117 | ```{r}
118 | #| label: "vanilia_models_fitting"
119 | #| layout.ncol: 2
120 | #| fig.cap: 
121 | #|   - "Tabnet, no case-weight, default loss"
122 | #|   - "XGBoost, no case-weight"
123 | #| 
124 | tab_fit <- tab_wf |> fit(train)
125 | xgb_fit <- xgb_wf |> fit(train)
126 | 
127 | tab_test <- tab_fit |> augment(test)
128 | xgb_test <- xgb_fit |> augment(test)
129 | 
130 | tab_test |> 
131 |   pr_curve(Class, .pred_good) |> 
132 |   autoplot()
133 | 
134 | xgb_test |> 
135 |   pr_curve(Class, .pred_good) |> 
136 |   autoplot()
137 | 
138 | ```
139 | 
140 | Both models return poor results.
141 | 
142 | ## Case weights
143 | 
144 | Weighting each observation by the importance weight of its class is available in {tabnet} through
145 | 
146 | -   marking one variable as the importance-weight variable via `workflows::add_case_weights()`
147 | 
148 | -   using that variable at inference time through the `case_weights =` argument of the functions that allow it.
149 | 
150 | Let's proceed:
151 | 
152 | ```{r}
153 | #| label: "case-weights_prediction"
154 | #| layout.ncol: 2
155 | #| fig.cap: 
156 | #|   - "Tabnet, with case-weight, default loss"
157 | #|   - "XGBoost, with case-weight"
158 | #| 
159 | tab_test |> 
160 |   pr_curve(Class, .pred_good, case_weights = case_wts) |> 
161 |   autoplot()
162 | 
163 | xgb_test |> 
164 |   pr_curve(Class, .pred_good, case_weights = case_wts) |> 
165 |   autoplot()
166 | ```
167 | 
168 | The boost on the `pr_curve()` is impressive for both models, though TabNet remains behind XGBoost here[^1].
169 | 
170 | [^1]: Or may take the lead if you change the initial random seed.
171 | 
172 | ## ROC AUM loss
173 | 
174 | {tabnet} implements the ROC AUM loss, which will drive the torch optimizer toward the best possible AUC.
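Before using it, a rough picture of what this loss optimizes may help. As its name, Area Under the Minimum, suggests, and paraphrasing the paper loosely, the AUM integrates the smaller of the two error rates over all decision thresholds $c$:

$$
\mathrm{AUM} = \int_{-\infty}^{+\infty} \min\big(\mathrm{FPR}(c),\ \mathrm{FNR}(c)\big)\, dc
$$

Driving the AUM down pushes the ROC curve toward the top-left corner, which is why minimizing it acts as a sort-based surrogate for maximizing the AUC (see the Hillman & Hocking paper for the exact definition).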
Let's use it to compare with the previous models:
175 | 
176 | ```{r}
177 | #| label: "AUM_based_model_fit"
178 | # configure the AUM loss
179 | tab_aum_mod <- tabnet(epochs = 100, loss = tabnet::nn_aum_loss, learn_rate = 0.02) |> 
180 |   set_engine("torch", device = "cpu") |> 
181 |   set_mode("classification")
182 | 
183 | # derive a workflow
184 | tab_aum_wf <- workflow() |> 
185 |   add_model(tab_aum_mod) |> 
186 |   add_recipe(tab_rec) |> 
187 |   add_case_weights(case_wts)
188 | 
189 | # fit and augment the test dataset with predictions
190 | tab_aum_fit <- tab_aum_wf |> fit(train)
191 | tab_aum_test <- tab_aum_fit |> augment(test)
192 | ```
193 | 
194 | Now let's compare the resulting PR curve with the default-loss one, side by side:
195 | 
196 | ```{r}
197 | #| label: "AUM_model_pr_curve"
198 | #| layout.ncol: 2
199 | #| fig.cap: 
200 | #|   - "Tabnet, no case-weight, default loss"
201 | #|   - "Tabnet, no case-weight, ROC_AUM loss"
202 | #| 
203 | tab_test |> 
204 |   pr_curve(Class, .pred_good) |> 
205 |   autoplot()
206 | 
207 | tab_aum_test |> 
208 |   pr_curve(Class, .pred_good) |> 
209 |   autoplot()
210 | ```
211 | 
212 | We can see a real[^2] improvement with the AUM loss compared to the default `nn_bce_loss()`, but recall remains globally poor.
213 | 
214 | [^2]: The level of improvement is sensitive to the random seed.
215 | 
216 | ## All together
217 | 
218 | Nothing prevents us from using both features together, as they are independent. That is what we do here. Moreover, it comes at no additional computation cost, as the case weighting is applied after inference.
219 | 
220 | ```{r}
221 | #| label: "AUM_and_case-weights_prediction"
222 | #| layout.ncol: 2
223 | #| fig.cap: 
224 | #|   - "Tabnet, with case-weight, default loss"
225 | #|   - "Tabnet, with case-weight, ROC_AUM loss"
226 | #| 
227 | tab_test |> 
228 |   pr_curve(Class, .pred_good, case_weights = case_wts) |> 
229 |   autoplot()
230 | 
231 | 
232 | tab_aum_test |> 
233 |   pr_curve(Class, .pred_good, case_weights = case_wts) |> 
234 |   autoplot()
235 | ```
236 | 
237 | Here the boost in recall is impressive, putting the TabNet model far above any of the challenger models we experimented with[^3].
238 | 
239 | [^3]: With an educational and reproducible intent only.
240 | 
--------------------------------------------------------------------------------