├── .Rbuildignore ├── .github ├── .gitignore └── workflows │ ├── R-CMD-check.yaml │ └── pkgdown.yaml ├── .gitignore ├── CRAN-SUBMISSION ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── assertions.R ├── globals.R ├── predict_boots.R ├── summarise_boots.R ├── utils.R └── vi_boots.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── cran-comments.md ├── man ├── figures │ ├── logo.png │ └── logo.svg ├── predict_boots.Rd ├── summarise_importance.Rd ├── summarise_predictions.Rd └── vi_boots.Rd ├── pkgdown └── favicon │ ├── apple-touch-icon-120x120.png │ ├── apple-touch-icon-152x152.png │ ├── apple-touch-icon-180x180.png │ ├── apple-touch-icon-60x60.png │ ├── apple-touch-icon-76x76.png │ ├── apple-touch-icon.png │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ └── favicon.ico ├── revdep ├── README.md ├── checks │ └── libraries.csv ├── cran.md ├── data.sqlite ├── failures.md └── problems.md ├── tests ├── testthat.R └── testthat │ ├── data │ ├── test_data_bad.csv │ ├── test_importances.rds │ ├── test_preds.rds │ ├── test_test.csv │ ├── test_train.csv │ ├── test_wf.rds │ └── test_wf_bad.rds │ ├── test-predict-boots.R │ ├── test-summarise-boots.R │ └── test-vi-boots.R ├── vignettes ├── .gitignore ├── Estimating-Linear-Intervals.Rmd ├── Getting-Started-with-workboots.Rmd └── The-Math-Behind-workboots.Rmd └── workboots.Rproj /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^workboots\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^README\.Rmd$ 4 | ^LICENSE\.md$ 5 | ^cran-comments\.md$ 6 | ^data-raw$ 7 | ^pkgdown$ 8 | ^revdep$ 9 | ^_pkgdown\.yml$ 10 | ^\.github$ 11 | 12 | ^CRAN-SUBMISSION$ 13 | ^codecov\.yml$ 14 | ^docs$ 15 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- 
/.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/master/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | 9 | name: R-CMD-check 10 | 11 | jobs: 12 | R-CMD-check: 13 | runs-on: ${{ matrix.config.os }} 14 | 15 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | config: 21 | - {os: macOS-latest, r: 'release'} 22 | - {os: windows-latest, r: 'release'} 23 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 24 | - {os: ubuntu-latest, r: 'release'} 25 | - {os: ubuntu-latest, r: 'oldrel-1'} 26 | 27 | env: 28 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 29 | R_KEEP_PKG_SOURCE: yes 30 | 31 | steps: 32 | - uses: actions/checkout@v3 33 | 34 | - uses: r-lib/actions/setup-pandoc@v2 35 | 36 | - uses: r-lib/actions/setup-r@v2 37 | with: 38 | r-version: ${{ matrix.config.r }} 39 | http-user-agent: ${{ matrix.config.http-user-agent }} 40 | use-public-rspm: true 41 | 42 | - uses: r-lib/actions/setup-r-dependencies@v2 43 | with: 44 | extra-packages: rcmdcheck 45 | 46 | - uses: r-lib/actions/check-r-package@v2 47 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown 13 | 14 | jobs: 15 | pkgdown: 16 | runs-on: ubuntu-latest 17 | # Only restrict concurrency for non-PR jobs 18 | concurrency: 19 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 20 | env: 21 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 22 | steps: 23 | - uses: actions/checkout@v3 24 | 25 | - uses: r-lib/actions/setup-pandoc@v2 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::pkgdown, local::. 34 | needs: website 35 | 36 | - name: Build site 37 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 38 | shell: Rscript {0} 39 | 40 | - name: Deploy to GitHub pages 🚀 41 | if: github.event_name != 'pull_request' 42 | uses: JamesIves/github-pages-deploy-action@v4.4.1 43 | with: 44 | clean: false 45 | branch: gh-pages 46 | folder: docs 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # History files 2 | .Rhistory 3 | .Rapp.history 4 | 5 | # Session Data files 6 | .RData 7 | 8 | # User-specific files 9 | .Ruserdata 10 | 11 | # Example code in package build process 12 | *-Ex.R 13 | 14 | # Output files from R CMD build 15 | /*.tar.gz 16 | 17 | # Output files from R CMD check 18 | /*.Rcheck/ 19 | 20 | # RStudio files 21 | .Rproj.user/ 22 | 23 | # produced vignettes 24 | vignettes/*.html 25 | vignettes/*.pdf 26 | 27 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 28 | .httr-oauth 29 | 30 | # knitr and R markdown default cache directories 31 | *_cache/ 32 | /cache/ 33 | 34 | # Temporary files created by R markdown 35 | 
*.utf8.md 36 | *.knit.md 37 | 38 | # R Environment Variables 39 | .Renviron 40 | inst/doc 41 | docs 42 | -------------------------------------------------------------------------------- /CRAN-SUBMISSION: -------------------------------------------------------------------------------- 1 | Version: 0.2.1 2 | Date: 2023-08-23 20:28:01 UTC 3 | SHA: ad4ba3c9142f7e010986efa1484bcfc91262fd89 4 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: workboots 2 | Title: Generate Bootstrap Prediction Intervals from a 'tidymodels' Workflow 3 | Version: 0.2.1 4 | Authors@R: 5 | person(given = "Mark", 6 | family = "Rieke", 7 | role = c("aut", "cre"), 8 | email = "markjrieke@gmail.com") 9 | Description: Provides functions for generating bootstrap prediction 10 | intervals from a 'tidymodels' workflow. 'tidymodels' 11 | is a collection of packages for modeling 12 | and machine learning using 'tidyverse' 13 | principles. This package is not affiliated with or maintained by 14 | 'RStudio' or the 'tidymodels' maintainers. 
15 | License: MIT + file LICENSE 16 | URL: https://github.com/markjrieke/workboots, 17 | https://markjrieke.github.io/workboots/ 18 | BugReports: https://github.com/markjrieke/workboots/issues 19 | Imports: 20 | assertthat, 21 | dplyr, 22 | generics, 23 | lifecycle, 24 | Metrics, 25 | purrr, 26 | rlang, 27 | rsample, 28 | stats, 29 | tibble, 30 | tidyr, 31 | vip (>= 0.4.1), 32 | workflows 33 | Encoding: UTF-8 34 | LazyData: true 35 | Roxygen: list(markdown = TRUE) 36 | RoxygenNote: 7.2.3 37 | Suggests: 38 | forcats, 39 | ggplot2, 40 | kknn, 41 | knitr, 42 | readr, 43 | recipes, 44 | rmarkdown, 45 | scales, 46 | testthat (>= 3.0.0), 47 | tidymodels, 48 | tune, 49 | xgboost 50 | VignetteBuilder: knitr 51 | Config/testthat/edition: 3 52 | Depends: 53 | R (>= 2.10) 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2022 2 | COPYRIGHT HOLDER: Mark Rieke 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2022 Mark Rieke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(predict_boots) 4 | export(summarise_importance) 5 | export(summarise_predictions) 6 | export(summarize_importance) 7 | export(summarize_predictions) 8 | export(vi_boots) 9 | importFrom(Metrics,rmse) 10 | importFrom(dplyr,filter) 11 | importFrom(dplyr,mutate) 12 | importFrom(dplyr,pull) 13 | importFrom(dplyr,relocate) 14 | importFrom(dplyr,rename) 15 | importFrom(dplyr,rename_with) 16 | importFrom(dplyr,select) 17 | importFrom(generics,fit) 18 | importFrom(lifecycle,deprecate_soft) 19 | importFrom(purrr,map) 20 | importFrom(purrr,map_dfc) 21 | importFrom(purrr,map_dfr) 22 | importFrom(rlang,":=") 23 | importFrom(rlang,arg_match) 24 | importFrom(rlang,sym) 25 | importFrom(rlang,warn) 26 | importFrom(rsample,bootstraps) 27 | importFrom(rsample,testing) 28 | importFrom(rsample,training) 29 | importFrom(stats,predict) 30 | importFrom(stats,quantile) 31 | importFrom(stats,rnorm) 32 | importFrom(stats,sd) 33 | importFrom(tibble,add_column) 34 | importFrom(tibble,rowid_to_column) 35 | importFrom(tidyr,crossing) 36 | importFrom(tidyr,nest) 37 | importFrom(tidyr,pivot_longer) 38 | importFrom(tidyr,pivot_wider) 39 | importFrom(tidyr,unnest) 40 | importFrom(vip,vi) 41 | importFrom(workflows,extract_fit_engine) 42 | 
-------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # workboots 0.2.1 2 | 3 | * Workboots now requires that [`{vip}`](https://koalaverse.github.io/vip/) is at least version 0.4.1 (addresses [`{vip}`'s removal from CRAN](https://github.com/koalaverse/vip/issues/153)). 4 | 5 | # workboots 0.2.0 6 | 7 | ### Function updates 8 | 9 | * Rearranged column order output of `summarise_*` functions from `*_lower`, `*`, `*_upper` to `*`, `*_lower`, `*_upper` 10 | * Deprecated `conf` parameter in `summarise_*` functions in favor of `interval_width`. 11 | * Added support for generating confidence intervals from `predict_boots()`. 12 | 13 | # workboots 0.1.1 14 | 15 | ### Function updates 16 | 17 | * Updates to `predict_boots()` 18 | + updated function to generate prediction interval (previously was generating a prediction's confidence interval) 19 | + updated default setting to assume residuals are normally distributed 20 | + updated default number of resamples 21 | + updated function to draw residuals based on the [632+ rule](https://stats.stackexchange.com/questions/96739/what-is-the-632-rule-in-bootstrapping) (previously was using training residuals) 22 | * updated default number of resamples in `vi_boots()` 23 | * added param `verbose` to both `predict_boots()` and `vi_boots()` to display progress in the console. 24 | * added new function `summarise_importance()`, as well as an alias `summarize_*` for `summarise_*` 25 | 26 | ### Bug fixes 27 | 28 | * Fixed bug in `assert_pred_data()` that caused some workflows to be rejected by `predict_boots()` 29 | 30 | # workboots 0.1.0 31 | 32 | * Initial release 33 | * Core functions: 34 | + `predict_boots()` for generating bootstrap prediction intervals from tidymodel workflows. 35 | + `summarise_predictions()` for summarizing bootstrap predictions with expected, lower bound, and upper bound values. 
36 | + `vi_boots()` for generating bootstrap feature importance from tidymodel workflows. 37 | * Vignette for Getting Started with workboots. 38 | * Unit testing for core functions (via testthat). 39 | -------------------------------------------------------------------------------- /R/assertions.R: -------------------------------------------------------------------------------- 1 | # Util function for checking workflow passed to predict/vi boots 2 | assert_workflow <- function(workflow) { 3 | 4 | # assert that object is actually a workflow 5 | assertthat::assert_that( 6 | assertthat::are_equal(class(workflow), "workflow"), 7 | msg = "argument `workflow` must be of class \"workflow\"." 8 | ) 9 | 10 | # assert that there are no remaining tuning parameters 11 | n_params <- length(workflow$fit$actions$model$spec$args) 12 | 13 | purrr::walk( 14 | seq(1, n_params, 1), 15 | ~assert_tune(workflow, .x) 16 | ) 17 | 18 | } 19 | 20 | # Util function for checking that workflow argument is not set to tune() 21 | assert_tune <- function(workflow, index) { 22 | 23 | # get model arg (need to know if it's NULL or not) 24 | model_arg <- as.character(rlang::eval_tidy(workflow$fit$actions$model$spec$args[[index]])) 25 | 26 | # don't check if NULL 27 | if (length(model_arg) != 0) { 28 | 29 | assertthat::assert_that( 30 | model_arg[1] != "tune", 31 | msg = paste0("all tuning parameters must be final before passing workflow to `predict_boots()`.") 32 | ) 33 | 34 | } 35 | 36 | } 37 | 38 | # Util function for checking n param 39 | assert_n <- function(n) { 40 | 41 | # >= 1 42 | assertthat::assert_that( 43 | n >= 1, 44 | msg = "argument `n` must be >= 1." 45 | ) 46 | 47 | # don't pass double 48 | assertthat::assert_that( 49 | n == as.integer(n), 50 | msg = "argmuent `n` must be an integer." 
51 | ) 52 | 53 | } 54 | 55 | # Util function for checking training_data and new_data params 56 | assert_pred_data <- function(workflow, data, type) { 57 | 58 | # get colnames from workflow 59 | var_info <- workflow$pre$actions$recipe$recipe$var_info 60 | 61 | if (type == "training") { 62 | 63 | # check that colnames include all predictors and outcomes 64 | var_info <- dplyr::filter(var_info, role %in% c("predictor", "outcome")) 65 | 66 | } else { # type == "new" 67 | 68 | # check that colnames include all predictors 69 | var_info <- dplyr::filter(var_info, role == "predictor") 70 | 71 | } 72 | 73 | # get colnames for comparison 74 | cols_wf <- var_info$variable 75 | cols_dat <- colnames(data) 76 | 77 | # check that all cols in wf appear in data 78 | # message displays any names that appear in cols_wf that don't appear in cols_dat 79 | assertthat::assert_that( 80 | sum(cols_wf %in% cols_dat) == length(cols_wf), 81 | msg = paste0("missing cols in ", type, "_data:\n", 82 | paste(cols_wf[which(!cols_wf %in% cols_dat)], 83 | collapse = ", ")) 84 | ) 85 | 86 | } 87 | 88 | # Util function for checking .data passed to summary functions 89 | assert_pred_summary <- function(data) { 90 | 91 | # check for col .preds 92 | assertthat::assert_that( 93 | ".preds" %in% names(data), 94 | msg = "col `.preds` missing." 95 | 96 | ) 97 | 98 | # check that .preds is list 99 | assertthat::assert_that( 100 | class(data$.preds) == "list", 101 | msg = "col `.preds` must be a list-col." 102 | ) 103 | 104 | # check that model.pred is numeric 105 | assertthat::assert_that( 106 | class(data$.preds[[1]]$model.pred) == "numeric", 107 | msg = "col `model.pred` must be numeric." 108 | ) 109 | 110 | } 111 | 112 | # Util function for checking .data passed to importance summary function 113 | assert_importance_summary <- function(data) { 114 | 115 | # check for col .importances 116 | assertthat::assert_that( 117 | ".importances" %in% names(data), 118 | msg = "col `.importances` missing." 
 119 | ) 120 | 121 | # check that .importances is a list 122 | assertthat::assert_that( 123 | class(data$.importances) == "list", 124 | msg = "col `.importances` must be a list-col." 125 | ) 126 | 127 | # check that model.importance is numeric 128 | assertthat::assert_that( 129 | class(data$.importances[[1]]$model.importance) == "numeric", 130 | msg = "col `model.importance` must be numeric." 131 | ) 132 | 133 | } 134 | 135 | # Util function for checking interval 136 | assert_interval <- function(interval_width) { 137 | 138 | # numeric 139 | assertthat::assert_that( 140 | is.numeric(interval_width), 141 | msg = "argument `interval_width` must be numeric." 142 | ) 143 | 144 | # must be between [0, 1] 145 | assertthat::assert_that( 146 | interval_width >= 0 && interval_width <= 1, 147 | msg = "argument `interval_width` must be between [0, 1]." 148 | ) 149 | 150 | } 151 | -------------------------------------------------------------------------------- /R/globals.R: -------------------------------------------------------------------------------- 1 | globalVariables(c( 2 | ".pred", 3 | ".pred_lower", 4 | ".importance", 5 | ".importance_lower", 6 | ".importances", 7 | "importance", 8 | "int_level", 9 | "interval", 10 | "model", 11 | "model.importance", 12 | "model.pred", 13 | "pred_level", 14 | "resid_add", 15 | "resid_oob", 16 | "resid_train", 17 | "role", 18 | "variable" 19 | )) 20 | -------------------------------------------------------------------------------- /R/predict_boots.R: -------------------------------------------------------------------------------- 1 | #' Fit and predict from a workflow using many bootstrap resamples. 2 | #' 3 | #' Generate a prediction interval from arbitrary model types using bootstrap 4 | #' resampling. `predict_boots()` generates `n` bootstrap resamples, fits a model 5 | #' to each resample (creating `n` models), then creates `n` predictions for each 6 | #' observation in `new_data`. 
7 | #' 8 | #' @details Since `predict_boots()` fits a new model to each resample, the 9 | #' argument `workflow` must not yet be fit. Any tuned hyperparameters must be 10 | #' finalized prior to calling `predict_boots()`. 11 | #' 12 | #' @return A tibble with a column indicating the row index of each observation in 13 | #' `new_data` and a nested list of the model predictions for each observation. 14 | #' 15 | #' @param new_data A tibble or dataframe used to make predictions. 16 | #' @param interval One of `prediction`, `confidence`. Specifies the interval type to be generated. 17 | #' @inheritParams vi_boots 18 | #' 19 | #' @export 20 | #' 21 | #' @importFrom rlang arg_match 22 | #' @importFrom rlang warn 23 | #' @importFrom rsample bootstraps 24 | #' @importFrom purrr map_dfc 25 | #' @importFrom tibble rowid_to_column 26 | #' @importFrom tidyr pivot_longer 27 | #' @importFrom tidyr nest 28 | #' 29 | #' @examples 30 | #' \dontrun{ 31 | #' library(tidymodels) 32 | #' 33 | #' # setup a workflow without fitting 34 | #' wf <- 35 | #' workflow() %>% 36 | #' add_recipe(recipe(qsec ~ wt, data = mtcars)) %>% 37 | #' add_model(linear_reg()) 38 | #' 39 | #' # fit and predict 2000 bootstrap resampled models to mtcars 40 | #' set.seed(123) 41 | #' wf %>% 42 | #' predict_boots(n = 2000, training_data = mtcars, new_data = mtcars) 43 | #' } 44 | predict_boots <- function(workflow, 45 | n = 2000, 46 | training_data, 47 | new_data, 48 | interval = c("prediction", "confidence"), 49 | verbose = FALSE, 50 | ...) 
{ 51 | 52 | # convert interval type 53 | interval <- rlang::arg_match(interval) 54 | 55 | # check arguments 56 | assert_workflow(workflow) 57 | assert_n(n) 58 | assert_pred_data(workflow, training_data, "training") 59 | assert_pred_data(workflow, new_data, "new") 60 | 61 | # warn if low n 62 | if (n < 2000) { 63 | 64 | rlang::warn( 65 | paste0("At least 2000 resamples recommended for stable results.") 66 | ) 67 | 68 | } 69 | 70 | # create resamples from training set 71 | training_boots <- 72 | rsample::bootstraps( 73 | training_data, 74 | times = n, 75 | ... 76 | ) 77 | 78 | # map sequence of indices to `predict_single_boot()` 79 | # returns a column of predictions for each model 80 | preds <- 81 | purrr::map_dfc( 82 | seq(1, n), 83 | ~predict_single_boot( 84 | workflow = workflow, 85 | boot_splits = training_boots, 86 | new_data = new_data, 87 | interval = interval, 88 | verbose = verbose, 89 | index = .x 90 | ) 91 | ) 92 | 93 | # nest & return predictions in long format 94 | preds <- tibble::rowid_to_column(preds) 95 | 96 | preds <- 97 | tidyr::pivot_longer( 98 | preds, 99 | dplyr::starts_with(".pred_"), 100 | names_to = "model", 101 | values_to = "model.pred" 102 | ) 103 | 104 | preds <- 105 | tidyr::nest( 106 | preds, 107 | .preds = c(model, model.pred) 108 | ) 109 | 110 | return(preds) 111 | 112 | } 113 | 114 | # --------------------------------internals------------------------------------- 115 | 116 | #' (Internal) Generate a column of predictions on new data based on a model fit 117 | #' to a single training bootstrap. 
118 | #' 119 | #' @param workflow passed from `predict_boots()` 120 | #' @param boot_splits passed from `predict_boots()` 121 | #' @param new_data passed from `predict_boots()` 122 | #' @param verbose passed from `predict_boots()` 123 | #' @param index passed from `predict_boots()` 124 | #' 125 | #' @importFrom rsample training 126 | #' @importFrom rsample testing 127 | #' @importFrom generics fit 128 | #' @importFrom dplyr filter 129 | #' @importFrom dplyr pull 130 | #' @importFrom stats predict 131 | #' @importFrom rlang sym 132 | #' @importFrom tidyr crossing 133 | #' @importFrom Metrics rmse 134 | #' @importFrom stats sd 135 | #' @importFrom tibble add_column 136 | #' @importFrom stats rnorm 137 | #' @importFrom dplyr mutate 138 | #' @importFrom dplyr rename 139 | #' @importFrom rlang := 140 | #' 141 | #' @noRd 142 | #' 143 | predict_single_boot <- function(workflow, 144 | boot_splits, 145 | new_data, 146 | interval, 147 | verbose, 148 | index) { 149 | 150 | # get training data from bootstrap resample split 151 | boot_train <- 152 | rsample::training( 153 | boot_splits$splits[[index]] 154 | ) 155 | 156 | # get oob sample 157 | boot_oob <- 158 | rsample::testing( 159 | boot_splits$splits[[index]] 160 | ) 161 | 162 | # fit workflow to training data 163 | model <- generics::fit(workflow, boot_train) 164 | 165 | # predict given model and new data 166 | preds <- stats::predict(model, new_data) 167 | 168 | # get predicted var name 169 | pred_name <- dplyr::filter(workflow$pre$actions$recipe$recipe$var_info, role == "outcome") 170 | pred_name <- dplyr::pull(pred_name, variable) 171 | 172 | # apply prediction interval using bootstrap 632+ estimate 173 | # if not, just returns absolute prediction (when summarised, this generates a confidence interval) 174 | if (interval == "prediction") { 175 | 176 | # get training residuals 177 | preds_train <- dplyr::pull(stats::predict(model, boot_train), .pred) 178 | actuals_train <- dplyr::pull(boot_train, rlang::sym(pred_name)) 
179 | resids_train <- actuals_train - preds_train 180 | resids_train <- resids_train - mean(resids_train) 181 | 182 | # get oob residuals 183 | preds_oob <- dplyr::pull(stats::predict(model, boot_oob), .pred) 184 | actuals_oob <- dplyr::pull(boot_oob, rlang::sym(pred_name)) 185 | resids_oob <- actuals_oob - preds_oob 186 | resids_oob <- resids_oob - mean(resids_oob) 187 | 188 | # calculate no-information error rate (rmse_ni) with RMSE as loss function 189 | combos <- tidyr::crossing(actuals_train, preds_train) 190 | rmse_ni <- Metrics::rmse(combos$actuals_train, combos$preds_train) 191 | 192 | # calculate overfit rate 193 | rmse_oob <- Metrics::rmse(actuals_oob, preds_oob) 194 | rmse_train <- Metrics::rmse(actuals_train, preds_train) 195 | overfit <- (rmse_oob - rmse_train)/(rmse_ni - rmse_train) 196 | 197 | # calculate weight (if overfit = 0, weight = .632 & residual used will just be .632) 198 | # uses the actual proportion of distinct training/oob samples, rather than average of 0.632/0.368 199 | prop_368 <- nrow(boot_oob)/nrow(boot_train) 200 | prop_632 <- 1 - prop_368 201 | weight <- prop_632/(1 - (prop_368 * overfit)) 202 | 203 | # determine residual std.dev based on weight 204 | sd_oob <- stats::sd(resids_oob) 205 | sd_train <- stats::sd(resids_train) 206 | sd_resid <- weight * sd_oob + (1 - weight) * sd_train 207 | 208 | # add residuals to fit 209 | preds <- tibble::add_column(preds, resid_add = stats::rnorm(nrow(new_data), 0, sd_resid)) 210 | preds <- dplyr::mutate(preds, .pred = .pred + resid_add) 211 | preds <- preds[, 1] 212 | 213 | } 214 | 215 | # rename .pred col based on index number 216 | preds <- dplyr::rename(preds, !!rlang::sym(paste0(".pred_", index)) := .pred) 217 | 218 | # print progress when verbose is set to TRUE 219 | verbose_print(verbose, index, nrow(boot_splits)) 220 | 221 | return(preds) 222 | 223 | } 224 | 225 | 226 | -------------------------------------------------------------------------------- /R/summarise_boots.R: 
-------------------------------------------------------------------------------- 1 | #' Append a tibble of predictions returned by `predict_boots()` with upper and 2 | #' lower bounds. 3 | #' 4 | #' @details Generates a summary of predictions with an upper and lower interval 5 | #' range. Presently, the `quantile()` function from the `{stats}` package is 6 | #' used to determine the lower, 50th percentile, and upper interval ranges. 7 | #' 8 | #' @return Appends the tibble of predictions returned by `predict_boots()` with 9 | #' three new columns: `.pred_lower`, `.pred`, and `.pred_upper`. 10 | #' 11 | #' @aliases `summarize_predictions()` 12 | #' 13 | #' @param .data a tibble of predictions returned by `predict_boots()`. 14 | #' @param interval_width a value between (0, 1) specifying the interval range. 15 | #' @param conf deprecated - please use `interval_width` instead. 16 | #' 17 | #' @export 18 | #' 19 | #' @importFrom lifecycle deprecate_soft 20 | #' 21 | #' @examples 22 | #' \dontrun{ 23 | #' library(tidymodels) 24 | #' 25 | #' # setup a workflow without fitting 26 | #' wf <- 27 | #' workflow() %>% 28 | #' add_recipe(recipe(qsec ~ wt, data = mtcars)) %>% 29 | #' add_model(linear_reg()) 30 | #' 31 | #' # fit and predict 2000 bootstrap resampled models to mtcars 32 | #' set.seed(123) 33 | #' preds <- 34 | #' wf %>% 35 | #' predict_boots(n = 2000, training_data = mtcars, new_data = mtcars) 36 | #' 37 | #' # append with prediction interval summary columns 38 | #' preds %>% 39 | #' summarise_predictions(interval_width = 0.95) 40 | #' } 41 | summarise_predictions <- function(.data, 42 | interval_width = 0.95, 43 | conf = NULL) { 44 | 45 | # warn about parameter deprecation 46 | if (!is.null(conf)) { 47 | 48 | lifecycle::deprecate_soft( 49 | when = "0.2.0", 50 | what = "summarise_predictions(conf)", 51 | with = "summarise_predictions(interval_width)" 52 | ) 53 | 54 | # reassign to interval width 55 | interval_width <- conf 56 | 57 | } 58 | 59 | # check arguments 60 | 
assert_pred_summary(.data) 61 | assert_interval(interval_width) 62 | 63 | # pass to summarise_generic 64 | summarise_generic( 65 | .data = .data, 66 | nest_col = ".preds", 67 | interval_width = interval_width 68 | ) 69 | 70 | } 71 | #' @rdname summarise_predictions 72 | #' @export 73 | summarize_predictions <- summarise_predictions 74 | 75 | #' Append a tibble of variable importances returned by `vi_boots()` with upper 76 | #' and lower bounds. 77 | #' 78 | #' @details Generates a summary of variable importances with an upper and lower 79 | #' interval range. Uses the `vi()` function from the `{vip}` package to compute 80 | #' variable importances (not all model types are supported by `vip::vi()`; please 81 | #' refer to `{vip}` package documentation for supported model types). Presently, 82 | #' the `quantile()` function from the `{stats}` package is used to determine 83 | #' the lower, 50th percentile, and upper interval ranges. 84 | #' 85 | #' @param .data a tibble of variable importances returned by `vi_boots()`. 86 | #' @param interval_width a value between (0, 1) specifying the interval range. 87 | #' @param conf deprecated - please use `interval_width` instead. 
 88 | #' 89 | #' @export 90 | #' 91 | #' @importFrom lifecycle deprecate_soft 92 | #' 93 | #' @examples 94 | #' \dontrun{ 95 | #' library(tidymodels) 96 | #' 97 | #' # setup a workflow without fitting 98 | #' wf <- 99 | #' workflow() %>% 100 | #' add_recipe(recipe(qsec ~ wt, data = mtcars)) %>% 101 | #' add_model(linear_reg()) 102 | #' 103 | #' # evaluate variable importance from 2000 models fit to mtcars 104 | #' set.seed(123) 105 | #' importances <- 106 | #' wf %>% 107 | #' vi_boots(n = 2000, training_data = mtcars) 108 | #' 109 | #' # append with lower and upper bound importance summary columns 110 | #' importances %>% 111 | #' summarise_importance(interval_width = 0.95) 112 | #' } 113 | summarise_importance <- function(.data, 114 | interval_width = 0.95, 115 | conf = NULL) { 116 | 117 | # warn about parameter deprecation 118 | if (!is.null(conf)) { 119 | 120 | lifecycle::deprecate_soft( 121 | when = "0.2.0", 122 | what = "summarise_importance(conf)", 123 | with = "summarise_importance(interval_width)" 124 | ) 125 | 126 | # reassign to interval width 127 | interval_width <- conf 128 | 129 | } 130 | 131 | # check arguments 132 | assert_importance_summary(.data) 133 | assert_interval(interval_width) 134 | 135 | # pass arguments to summarise_generic 136 | summarise_generic( 137 | .data = .data, 138 | nest_col = "importance", 139 | interval_width = interval_width 140 | ) 141 | 142 | } 143 | #' @rdname summarise_importance 144 | #' @export 145 | summarize_importance <- summarise_importance 146 | 147 | # ------------------------------internals--------------------------------------- 148 | 149 | #' (Internal) Function for generating generic summaries from either `predict_boots()` 150 | #' or `vi_boots()`. 
#'
#' @param .data passed from one of the summarise_* functions
#' @param nest_col passed from one of the summarise_* functions
#' @param interval_width passed from one of the summarise_* functions
#'
#' @importFrom dplyr mutate
#' @importFrom purrr map
#' @importFrom stats quantile
#' @importFrom tidyr unnest
#' @importFrom dplyr select
#' @importFrom tidyr nest
#' @importFrom tidyr pivot_wider
#' @importFrom dplyr relocate
#'
#' @noRd
#'
summarise_generic <- function(.data,
                              nest_col,
                              interval_width) {

  # internal renaming
  summary <- .data

  # symmetric interval tails, e.g. interval_width = 0.95 -> 0.025 & 0.975
  int_lower <- (1 - interval_width)/2
  int_upper <- int_lower + interval_width

  # number of rows to summarise (one per observation or variable)
  n_rows <- nrow(summary)

  # map nested values to the quantile fn
  if (nest_col == ".preds") {

    # summarise predictions: lower/median/upper quantile per observation
    summary <-
      dplyr::mutate(
        summary,
        interval = purrr::map(summary$.preds, ~stats::quantile(.x$model.pred, probs = c(int_lower, 0.5, int_upper)))
      )

    # vector of lower/med/upper column names
    col_vec <- rep(c(".pred_lower", ".pred", ".pred_upper"), n_rows)

  } else {

    # some model types (e.g., linear models) return abs(importance) plus a
    # sign column; fold the sign back into a signed importance before summarising
    if (ncol(summary$.importances[[1]]) > 2) {

      summary <- tidyr::unnest(summary, .importances)
      summary <- dplyr::mutate(summary, model.importance = ifelse(sign == "NEG", -model.importance, model.importance))
      summary <- dplyr::select(summary, -sign)
      summary <- tidyr::nest(summary, .importances = -variable)

    }

    # summarise variable importances
    summary <-
      dplyr::mutate(
        summary,
        interval = purrr::map(summary$.importances, ~stats::quantile(.x$model.importance, probs = c(int_lower, 0.5, int_upper)))
      )

    # vector of lower/med/upper column names
    col_vec <- rep(c(".importance_lower", ".importance", ".importance_upper"), n_rows)

  }

  # add interval labels & display as their own cols
  summary <- tidyr::unnest(summary, interval)
  summary <-
    dplyr::mutate(
      summary,
      int_level = col_vec
    )

  summary <-
    tidyr::pivot_wider(
      summary,
      names_from = int_level,
      values_from = interval
    )

  # rearrange cols to be .mid .lower .upper
  if (nest_col == ".preds") {

    summary <- dplyr::relocate(summary, .pred_lower, .after = .pred)

  } else {

    summary <- dplyr::relocate(summary, .importance_lower, .after = .importance)

  }

  return(summary)

}

# ------------------------------- /R/utils.R ------------------------------------

# util function for printing training index to screen when verbose = TRUE.
# isTRUE() guards against NA / non-scalar input, which `verbose == TRUE`
# would pass straight into if() and error on.
verbose_print <- function(verbose, index, total) {

  if (isTRUE(verbose)) {

    message(paste0("Trained ", index, "/", total, " models."))

  }

}

# ------------------------------ /R/vi_boots.R ----------------------------------

#' Fit and estimate variable importance from a workflow using many bootstrap resamples.
#'
#' Generate variable importances from a tidymodel workflow using bootstrap resampling.
#' `vi_boots()` generates `n` bootstrap resamples, fits a model to each (creating
#' `n` models), then creates `n` estimates of variable importance for each variable
#' in the model.
#'
#' @details Since `vi_boots()` fits a new model to each resample, the
#' argument `workflow` must not yet be fit. Any tuned hyperparameters must be
#' finalized prior to calling `vi_boots()`.
#'
#' @return A tibble with a column indicating each variable in the model and a
#' nested list of variable importances for each variable. The shape of the list
#' may vary by model type. For example, linear models return two nested columns:
#' the absolute value of each variable's importance and the sign (POS/NEG),
#' whereas tree-based models return a single nested column of variable importance.
#' Similarly, the number of nested rows may vary by model type as some models
#' may not utilize every possible predictor.
#'
#' @param workflow An un-fitted workflow object.
#' @param training_data A tibble or dataframe of data to be resampled and used for training.
#' @param n An integer for the number of bootstrap resampled models that will be created.
#' @param verbose A logical. Defaults to `FALSE`. If set to `TRUE`, prints progress
#' of training to console.
#' @param ... Additional params passed to `rsample::bootstraps()`.
#'
#' @export
#'
#' @importFrom rlang warn
#' @importFrom rsample bootstraps
#' @importFrom purrr map_dfr
#' @importFrom dplyr rename_with
#' @importFrom tidyr nest
#'
#' @examples
#' \dontrun{
#' library(tidymodels)
#'
#' # setup a workflow without fitting
#' wf <-
#'   workflow() %>%
#'   add_recipe(recipe(qsec ~ wt, data = mtcars)) %>%
#'   add_model(linear_reg())
#'
#' # fit and estimate variable importance from 2000 bootstrap resampled models
#' set.seed(123)
#' wf %>%
#'   vi_boots(n = 2000, training_data = mtcars)
#' }
vi_boots <- function(workflow,
                     n = 2000,
                     training_data,
                     verbose = FALSE,
                     ...) {

  # check arguments
  assert_workflow(workflow)
  assert_n(n)
  assert_pred_data(workflow, training_data, "training")

  # warn if low n
  if (n < 2000) {

    rlang::warn("At least 2000 resamples recommended for stable results.")

  }

  # create resamples from training set
  training_boots <-
    rsample::bootstraps(
      training_data,
      times = n,
      ...
    )

  # map sequence of indices to `vi_single_boot()`
  # returns a variable + importance for each model (number of cols may vary by model type)
  bootstrap_vi <-
    purrr::map_dfr(
      seq_len(n),
      ~vi_single_boot(
        workflow = workflow,
        boot_splits = training_boots,
        verbose = verbose,
        index = .x
      )
    )

  # standardise column names across model/engine types
  bootstrap_vi <- dplyr::rename_with(bootstrap_vi, tolower)
  bootstrap_vi <- dplyr::rename(bootstrap_vi, model.importance = importance)

  # return a nested tibble: one row per variable, importances nested
  bootstrap_vi <- tidyr::nest(bootstrap_vi, .importances = -variable)

  return(bootstrap_vi)

}

# -------------------------------internals--------------------------------------

#' Fit a model and get the variable importance based on a single bootstrap resample
#'
#' @param workflow passed from `vi_boots()`
#' @param boot_splits passed from `vi_boots()`
#' @param verbose passed from `vi_boots()`
#' @param index passed from `vi_boots()`
#'
#' @importFrom rsample training
#' @importFrom generics fit
#' @importFrom vip vi
#' @importFrom workflows extract_fit_engine
#'
#' @noRd
#'
vi_single_boot <- function(workflow,
                           boot_splits,
                           verbose,
                           index) {

  # get training data from bootstrap resample split
  boot_train <-
    rsample::training(
      boot_splits$splits[[index]]
    )

  # fit workflow to the training data
  model <- generics::fit(workflow, boot_train)

  # get the variable importance from the underlying engine fit
  vi_boot <- vip::vi(workflows::extract_fit_engine(model))

  # tag each row with the resampled model it came from
  vi_boot <- dplyr::mutate(vi_boot, model = paste0(".importance_", index))

  # print progress when verbose is set to TRUE
  verbose_print(
    verbose = verbose,
    index = index,
    total = nrow(boot_splits)
  )

  return(vi_boot)

}

# -------------------------------- /README.Rmd ----------------------------------

---
output: github_document
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(
  echo = TRUE,
  comment = "#>",
  message = FALSE,
  warning = FALSE
)
```

# workboots

**Author:** [Mark Rieke](https://www.thedatadiary.net/about/about/)
17 | **License:** [MIT](https://github.com/markjrieke/workboots/blob/main/LICENSE) 18 | 19 | 20 | [![R-CMD-check](https://github.com/markjrieke/workboots/workflows/R-CMD-check/badge.svg)](https://github.com/markjrieke/workboots/actions) 21 | [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) 22 | [![CRAN status](https://www.r-pkg.org/badges/version/workboots)](https://CRAN.R-project.org/package=workboots) 23 | [![](https://cranlogs.r-pkg.org/badges/grand-total/workboots)](https://cran.r-project.org/package=workboots) 24 | 25 | 26 | ## Overview 27 | 28 | `{workboots}` is a tidy method of generating bootstrap prediction intervals for arbitrary model types from a tidymodel workflow. 29 | 30 | By using [bootstrap resampling](https://en.wikipedia.org/wiki/Bootstrapping_(statistics)), we can create many models --- one for each resample. Each model will be slightly different based on the resample it was trained on. Each model will also generate slightly different predictions for new data, allowing us to generate a prediction distribution for models that otherwise just return point predictions. 31 | 32 | ## Installation 33 | 34 | You can install the released version of workboots from CRAN or the development version from github with the [devtools](https://cran.r-project.org/package=devtools) or [remotes](https://cran.r-project.org/package=remotes) package: 35 | 36 | ```{r, eval=FALSE} 37 | # install from CRAN 38 | install.packages("workboots") 39 | 40 | # or the development version 41 | devtools::install_github("markjrieke/workboots") 42 | ``` 43 | 44 | ## Usage 45 | 46 | workboots builds on top of the `{tidymodels}` suite of packages. Teaching how to use tidymodels is beyond the scope of this package, but some helpful resources are linked at the bottom of this README. 47 | 48 | To get started, we'll need to create a workflow. 
49 | 50 | ```{r} 51 | library(tidymodels) 52 | 53 | # load our dataset 54 | data("penguins") 55 | penguins <- penguins %>% drop_na() 56 | 57 | # split data into testing & training sets 58 | set.seed(123) 59 | penguins_split <- initial_split(penguins) 60 | penguins_test <- testing(penguins_split) 61 | penguins_train <- training(penguins_split) 62 | 63 | # create a workflow 64 | penguins_wf <- 65 | workflow() %>% 66 | add_recipe(recipe(body_mass_g ~ ., data = penguins_train) %>% step_dummy(all_nominal())) %>% 67 | add_model(boost_tree("regression")) 68 | ``` 69 | 70 | Boosted tree models can only generate point predictions, but with workboots we can generate a prediction interval for each observation in `penguins_test` by passing the workflow to `predict_boots()`: 71 | 72 | ```{r, eval=FALSE} 73 | library(workboots) 74 | 75 | # generate predictions from 2000 bootstrap models 76 | set.seed(345) 77 | penguins_pred_int <- 78 | penguins_wf %>% 79 | predict_boots( 80 | n = 2000, 81 | training_data = penguins_train, 82 | new_data = penguins_test 83 | ) 84 | 85 | # summarise predictions with a 95% prediction interval 86 | penguins_pred_int %>% 87 | summarise_predictions() 88 | ``` 89 | 90 | ```{r,echo=FALSE} 91 | library(workboots) 92 | 93 | # load data from workboots_support (avoid re-fitting on knit) 94 | penguins_pred_int <- readr::read_rds("https://github.com/markjrieke/workboots_support/blob/main/data/penguins_pred_int.rds?raw=true") 95 | 96 | penguins_pred_int %>% 97 | summarise_predictions() 98 | ``` 99 | 100 | Alternatively, we can generate a confidence interval around each prediction by setting the parameter `interval` to `"confidence"`: 101 | 102 | ```{r, eval=FALSE} 103 | # generate predictions from 2000 bootstrap models 104 | set.seed(456) 105 | penguins_conf_int <- 106 | penguins_wf %>% 107 | predict_boots( 108 | n = 2000, 109 | training_data = penguins_train, 110 | new_data = penguins_test, 111 | interval = "confidence" 112 | ) 113 | 114 | # summarise with a 95%
confidence interval 115 | penguins_conf_int %>% 116 | summarise_predictions() 117 | ``` 118 | 119 | ```{r, echo=FALSE} 120 | # load data from workboots_support (avoid re-fitting on knit) 121 | penguins_conf_int <- readr::read_rds("https://github.com/markjrieke/workboots_support/blob/main/data/penguins_conf_int.rds?raw=true") 122 | 123 | penguins_conf_int %>% 124 | summarise_predictions() 125 | ``` 126 | 127 | ## Bug reports/feature requests 128 | 129 | If you notice a bug, want to request a new feature, or have recommendations on improving documentation, please [open an issue](https://github.com/markjrieke/workboots/issues) in this repository. 130 | 131 | ### Tidymodels Resources 132 | 133 | * [Getting started with Tidymodels](https://www.tidymodels.org/start/) 134 | * [Tidy Modeling with R](https://www.tmwr.org/) 135 | * [Julia Silge's Blog](https://juliasilge.com/blog/) provides use cases of tidymodels with weekly [#tidytuesday](https://github.com/rfordatascience/tidytuesday) datasets. 136 | 137 |

The hex logo for workboots was designed by Sarah Power.

138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # workboots 3 | 4 | **Author:** [Mark Rieke](https://www.thedatadiary.net/about/about/) 5 |
**License:** 6 | [MIT](https://github.com/markjrieke/workboots/blob/main/LICENSE) 7 | 8 | 9 | 10 | [![R-CMD-check](https://github.com/markjrieke/workboots/workflows/R-CMD-check/badge.svg)](https://github.com/markjrieke/workboots/actions) 11 | [![Lifecycle: 12 | experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) 13 | [![CRAN 14 | status](https://www.r-pkg.org/badges/version/workboots)](https://CRAN.R-project.org/package=workboots) 15 | [![](https://cranlogs.r-pkg.org/badges/grand-total/workboots)](https://cran.r-project.org/package=workboots) 16 | 17 | 18 | ## Overview 19 | 20 | `{workboots}` is a tidy method of generating bootstrap prediction 21 | intervals for arbitrary model types from a tidymodel workflow. 22 | 23 | By using [bootstrap 24 | resampling](https://en.wikipedia.org/wiki/Bootstrapping_(statistics)), 25 | we can create many models — one for each resample. Each model will be 26 | slightly different based on the resample it was trained on. Each model 27 | will also generate slightly different predictions for new data, allowing 28 | us to generate a prediction distribution for models that otherwise just 29 | return point predictions. 30 | 31 | ## Installation 32 | 33 | You can install the released version of workboots from CRAN or the 34 | development version from github with the 35 | [devtools](https://cran.r-project.org/package=devtools) or 36 | [remotes](https://cran.r-project.org/package=remotes) package: 37 | 38 | ``` r 39 | # install from CRAN 40 | install.packages("workboots") 41 | 42 | # or the development version 43 | devtools::install_github("markjrieke/workboots") 44 | ``` 45 | 46 | ## Usage 47 | 48 | workboots builds on top of the `{tidymodels}` suite of packages. 49 | Teaching how to use tidymodels is beyond the scope of this package, but 50 | some helpful resources are linked at the bottom of this README. 
51 | 52 | To get started, we’ll need to create a workflow. 53 | 54 | ``` r 55 | library(tidymodels) 56 | 57 | # load our dataset 58 | data("penguins") 59 | penguins <- penguins %>% drop_na() 60 | 61 | # split data into testing & training sets 62 | set.seed(123) 63 | penguins_split <- initial_split(penguins) 64 | penguins_test <- testing(penguins_split) 65 | penguins_train <- training(penguins_split) 66 | 67 | # create a workflow 68 | penguins_wf <- 69 | workflow() %>% 70 | add_recipe(recipe(body_mass_g ~ ., data = penguins_train) %>% step_dummy(all_nominal())) %>% 71 | add_model(boost_tree("regression")) 72 | ``` 73 | 74 | Boosted tree models can only generate point predictions, but with 75 | workboots we can generate a prediction interval for each observation in 76 | `penguins_test` by passing the workflow to `predict_boots()`: 77 | 78 | ``` r 79 | library(workboots) 80 | 81 | # generate predictions from 2000 bootstrap models 82 | set.seed(345) 83 | penguins_pred_int <- 84 | penguins_wf %>% 85 | predict_boots( 86 | n = 2000, 87 | training_data = penguins_train, 88 | new_data = penguins_test 89 | ) 90 | 91 | # summarise predictions with a 95% prediction interval 92 | pengins_pred_int %>% 93 | summarise_predictions() 94 | ``` 95 | 96 | #> # A tibble: 84 × 5 97 | #> rowid .preds .pred .pred_lower .pred_upper 98 | #> 99 | #> 1 1 3465. 2913. 3994. 100 | #> 2 2 3535. 2982. 4100. 101 | #> 3 3 3604. 3050. 4187. 102 | #> 4 4 4157. 3477. 4764. 103 | #> 5 5 3868. 3305. 4372. 104 | #> 6 6 3519. 2996. 4078. 105 | #> 7 7 3435. 2914. 3954. 106 | #> 8 8 4072. 3483. 4653. 107 | #> 9 9 3445. 2926. 3966. 108 | #> 10 10 3405. 2876. 3938. 
109 | #> # ℹ 74 more rows 110 | 111 | Alternatively, we can generate a confidence interval around each 112 | prediction by setting the parameter `interval` to `"confidence"`: 113 | 114 | ``` r 115 | # generate predictions from 2000 bootstrap models 116 | set.seed(456) 117 | penguins_conf_int <- 118 | penguins_wf %>% 119 | predict_boots( 120 | n = 2000, 121 | training_data = penguins_train, 122 | new_data = penguins_test, 123 | interval = "confidence" 124 | ) 125 | 126 | # summarise with a 95% confidence interval 127 | penguins_conf_int %>% 128 | summarise_predictions() 129 | ``` 130 | 131 | #> # A tibble: 84 × 5 132 | #> rowid .preds .pred .pred_lower .pred_upper 133 | #> 134 | #> 1 1 3466. 3257. 3635. 135 | #> 2 2 3534. 3291. 3811. 136 | #> 3 3 3623. 3306. 3921. 137 | #> 4 4 4155. 3722. 4504. 138 | #> 5 5 3868. 3644. 4086. 139 | #> 6 6 3509. 3286. 3768. 140 | #> 7 7 3439. 3249. 3624. 141 | #> 8 8 4064. 3737. 4369. 142 | #> 9 9 3450. 3253. 3635. 143 | #> 10 10 3405. 3222. 3651. 144 | #> # ℹ 74 more rows 145 | 146 | ## Bug reports/feature requests 147 | 148 | If you notice a bug, want to request a new feature, or have 149 | recommendations on improving documentation, please [open an 150 | issue](https://github.com/markjrieke/workboots/issues) in this 151 | repository. 152 | 153 | ### Tidymodels Resources 154 | 155 | - [Getting started with Tidymodels](https://www.tidymodels.org/start/) 156 | - [Tidy Modeling with R](https://www.tmwr.org/) 157 | - [Julia Silge’s Blog](https://juliasilge.com/blog/) provides use cases 158 | of tidymodels with weekly 159 | [\#tidytuesday](https://github.com/rfordatascience/tidytuesday) 160 | datasets. 161 | 162 |

163 | The hex logo for workboots was designed by 164 | Sarah 165 | Power. 166 |

167 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://markjrieke.github.io/workboots/ 2 | template: 3 | bootstrap: 5 4 | bootswatch: yeti 5 | 6 | figures: 7 | fig.width: 8 8 | fig.height: 5.75 9 | 10 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## Release summary 2 | 3 | This is a reupload of workboots 0.2.1. In this version I have: 4 | 5 | * Updated package dependencies to depend on vip version 0.4.1. vip 0.4.1 fixes issues that caused vip 0.4.0 to be removed from CRAN. 6 | * Updated links in vignettes that caused 403 errors on submission. 7 | 8 | ## R CMD check results 9 | There was 1 NOTE: 10 | 11 | CRAN repository db overrides: 12 | X-CRAN-Comment: Archived on 2023-08-14 as it requires archived 13 | package 'vip'. 14 | 15 | 'vip' 0.4.0 was archived from CRAN and has been re-instated with version 0.4.1. This release of workboots 0.2.1 addresses the dependency issue by requiring 'vip' 0.4.1 or greater. 16 | 17 | 18 | ## revdepcheck results 19 | 20 | We checked 0 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 
21 | 22 | * We saw 0 new problems 23 | * We failed to check 0 packages 24 | 25 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/man/figures/logo.png -------------------------------------------------------------------------------- /man/figures/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 9 | 10 | 11 | 12 | 16 | 17 | 18 | 22 | 25 | 28 | 30 | 32 | 35 | 38 | 41 | 44 | 48 | 53 | 56 | 59 | 60 | 61 | 65 | 68 | 71 | 73 | 75 | 77 | 78 | 80 | 81 | 84 | 85 | 86 | 89 | 91 | 94 | 97 | 100 | 103 | 106 | 109 | 115 | 118 | 119 | 120 | 123 | 125 | 129 | 132 | 133 | 136 | 139 | 141 | 144 | 147 | 150 | 153 | 156 | 161 | 165 | 168 | 169 | 170 | 173 | 176 | 179 | 182 | 183 | 186 | 187 | 190 | 192 | 193 | 195 | 197 | 200 | 202 | 205 | 208 | 211 | 214 | 217 | 220 | 227 | 228 | 231 | 234 | 235 | 236 | 240 | 242 | 245 | 248 | 250 | 255 | 258 | 261 | 264 | 267 | 270 | 274 | 275 | 278 | 279 | 282 | 285 | 288 | 291 | 293 | 295 | 297 | 299 | 302 | 304 | 307 | 310 | 313 | 315 | 317 | 318 | 319 | 331 | 348 | 360 | 369 | 376 | 410 | 411 | 416 | 417 | 418 | 419 | 420 | 421 | -------------------------------------------------------------------------------- /man/predict_boots.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/predict_boots.R 3 | \name{predict_boots} 4 | \alias{predict_boots} 5 | \title{Fit and predict from a workflow using many bootstrap resamples.} 6 | \usage{ 7 | predict_boots( 8 | workflow, 9 | n = 2000, 10 | training_data, 11 | new_data, 12 | interval = c("prediction", "confidence"), 13 | verbose = FALSE, 14 | ... 
15 | ) 16 | } 17 | \arguments{ 18 | \item{workflow}{An un-fitted workflow object.} 19 | 20 | \item{n}{An integer for the number of bootstrap resampled models that will be created.} 21 | 22 | \item{training_data}{A tibble or dataframe of data to be resampled and used for training.} 23 | 24 | \item{new_data}{A tibble or dataframe used to make predictions.} 25 | 26 | \item{interval}{One of \code{prediction}, \code{confidence}. Specifies the interval type to be generated.} 27 | 28 | \item{verbose}{A logical. Defaults to \code{FALSE}. If set to \code{TRUE}, prints progress 29 | of training to console.} 30 | 31 | \item{...}{Additional params passed to \code{rsample::bootstraps()}.} 32 | } 33 | \value{ 34 | A tibble with a column indicating the row index of each observation in 35 | \code{new_data} and a nested list of the model predictions for each observation. 36 | } 37 | \description{ 38 | Generate a prediction interval from arbitrary model types using bootstrap 39 | resampling. \code{predict_boots()} generates \code{n} bootstrap resamples, fits a model 40 | to each resample (creating \code{n} models), then creates \code{n} predictions for each 41 | observation in \code{new_data}. 42 | } 43 | \details{ 44 | Since \code{predict_boots()} fits a new model to each resample, the 45 | argument \code{workflow} must not yet be fit. Any tuned hyperparameters must be 46 | finalized prior to calling \code{predict_boots()}. 
47 | } 48 | \examples{ 49 | \dontrun{ 50 | library(tidymodels) 51 | 52 | # setup a workflow without fitting 53 | wf <- 54 | workflow() \%>\% 55 | add_recipe(recipe(qsec ~ wt, data = mtcars)) \%>\% 56 | add_model(linear_reg()) 57 | 58 | # fit and predict 2000 bootstrap resampled models to mtcars 59 | set.seed(123) 60 | wf \%>\% 61 | predict_boots(n = 2000, training_data = mtcars, new_data = mtcars) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /man/summarise_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/summarise_boots.R 3 | \name{summarise_importance} 4 | \alias{summarise_importance} 5 | \alias{summarize_importance} 6 | \title{Append a tibble of variable importances returned by \code{vi_boots()} with upper 7 | and lower bounds.} 8 | \usage{ 9 | summarise_importance(.data, interval_width = 0.95, conf = NULL) 10 | 11 | summarize_importance(.data, interval_width = 0.95, conf = NULL) 12 | } 13 | \arguments{ 14 | \item{.data}{a tibble of variable importances returned by \code{vi_boots()}.} 15 | 16 | \item{interval_width}{a value between (0, 1) specifying the interval range.} 17 | 18 | \item{conf}{deprecated - please use \code{interval_width} instead.} 19 | } 20 | \description{ 21 | Append a tibble of variable importances returned by \code{vi_boots()} with upper 22 | and lower bounds. 23 | } 24 | \details{ 25 | Generates a summary of variable importances with an upper and lower 26 | interval range. Uses the \code{vi()} function from the \code{{vip}} package to compute 27 | variable importances (not all model types are supported by \code{vip::vi()}; please 28 | refer to \code{{vip}} package documentation for supported model types). 
Presently, 29 | the \code{quantile()} function from the \code{{stats}} package is used to determine 30 | the lower, 50th percentile, and upper interval ranges. 31 | } 32 | \examples{ 33 | \dontrun{ 34 | library(tidymodels) 35 | 36 | # setup a workflow without fitting 37 | wf <- 38 | workflow() \%>\% 39 | add_recipe(recipe(qsec ~ wt, data = mtcars)) \%>\% 40 | add_model(linear_reg()) 41 | 42 | # evaluate variable importance from 2000 models fit to mtcars 43 | set.seed(123) 44 | importances <- 45 | wf \%>\% 46 | vi_boots(n = 2000, training_data = mtcars, new_data = mtcars) 47 | 48 | # append with lower and upper bound importance summary columns 49 | importances \%>\% 50 | summarise_importance(interval_width = 0.95) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /man/summarise_predictions.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/summarise_boots.R 3 | \name{summarise_predictions} 4 | \alias{summarise_predictions} 5 | \alias{`summarize_predictions()`} 6 | \alias{summarize_predictions} 7 | \title{Append a tibble of predictions returned by \code{predict_boots()} with upper and 8 | lower bounds.} 9 | \usage{ 10 | summarise_predictions(.data, interval_width = 0.95, conf = NULL) 11 | 12 | summarize_predictions(.data, interval_width = 0.95, conf = NULL) 13 | } 14 | \arguments{ 15 | \item{.data}{a tibble of predictions returned by \code{predict_boots()}.} 16 | 17 | \item{interval_width}{a value between (0, 1) specifying the interval range.} 18 | 19 | \item{conf}{deprecated - please use \code{interval_width} instead.} 20 | } 21 | \value{ 22 | Appends the tibble of predictions returned by \code{predict_boots()} with 23 | three new columns: \code{.pred_lower}, \code{.pred}, and \code{.pred_upper}. 
24 | } 25 | \description{ 26 | Append a tibble of predictions returned by \code{predict_boots()} with upper and 27 | lower bounds. 28 | } 29 | \details{ 30 | Generates a summary of predictions with a upper and lower interval 31 | range. Presently, the \code{quantile()} function from the \code{{stats}} package is 32 | used to determine the lower, 50th percentile, and upper interval ranges. 33 | } 34 | \examples{ 35 | \dontrun{ 36 | library(tidymodels) 37 | 38 | # setup a workflow without fitting 39 | wf <- 40 | workflow() \%>\% 41 | add_recipe(recipe(qsec ~ wt, data = mtcars)) \%>\% 42 | add_model(linear_reg()) 43 | 44 | # fit and predict 2000 bootstrap resampled models to mtcars 45 | set.seed(123) 46 | preds <- 47 | wf \%>\% 48 | predict_boots(n = 2000, training_data = mtcars, new_data = mtcars) 49 | 50 | # append with prediction interval summary columns 51 | preds \%>\% 52 | summarise_predictions(conf = 0.95) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /man/vi_boots.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/vi_boots.R 3 | \name{vi_boots} 4 | \alias{vi_boots} 5 | \title{Fit and estimate variable importance from a workflow using many bootstrap resamples.} 6 | \usage{ 7 | vi_boots(workflow, n = 2000, training_data, verbose = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{workflow}{An un-fitted workflow object.} 11 | 12 | \item{n}{An integer for the number of bootstrap resampled models that will be created.} 13 | 14 | \item{training_data}{A tibble or dataframe of data to be resampled and used for training.} 15 | 16 | \item{verbose}{A logical. Defaults to \code{FALSE}. 
If set to \code{TRUE}, prints progress 17 | of training to console.} 18 | 19 | \item{...}{Additional params passed to \code{rsample::bootstraps()}.} 20 | } 21 | \value{ 22 | A tibble with a column indicating each variable in the model and a 23 | nested list of variable importances for each variable. The shape of the list 24 | may vary by model type. For example, linear models return two nested columns: 25 | the absolute value of each variable's importance and the sign (POS/NEG), 26 | whereas tree-based models return a single nested column of variable importance. 27 | Similarly, the number of nested rows may vary by model type as some models 28 | may not utilize every possible predictor. 29 | } 30 | \description{ 31 | Generate variable importances from a tidymodel workflow using bootstrap resampling. 32 | \code{vi_boots()} generates \code{n} bootstrap resamples, fits a model to each (creating 33 | \code{n} models), then creates \code{n} estimates of variable importance for each variable 34 | in the model. 35 | } 36 | \details{ 37 | Since \code{vi_boots()} fits a new model to each resample, the 38 | argument \code{workflow} must not yet be fit. Any tuned hyperparameters must be 39 | finalized prior to calling \code{vi_boots()}. 
40 | } 41 | \examples{ 42 | \dontrun{ 43 | library(tidymodels) 44 | 45 | # setup a workflow without fitting 46 | wf <- 47 | workflow() \%>\% 48 | add_recipe(recipe(qsec ~ wt, data = mtcars)) \%>\% 49 | add_model(linear_reg()) 50 | 51 | # fit and estimate variable importance from 125 bootstrap resampled models 52 | set.seed(123) 53 | wf \%>\% 54 | vi_boots(n = 2000, training_data = mtcars) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon-180x180.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /revdep/README.md: -------------------------------------------------------------------------------- 1 | # Platform 2 | 3 | |field |value | 4 | |:--------|:------------------------------| 5 | |version |R version 4.1.0 (2021-05-18) | 6 | |os |Windows 10 x64 (build 18363) | 7 | |system |x86_64, mingw32 | 8 | |ui |RStudio | 9 | |language |(EN) | 10 | |collate |English_United States.1252 | 11 | 
|ctype |English_United States.1252 | 12 | |tz |America/Chicago | 13 | |date |2022-04-11 | 14 | |rstudio |1.4.1717 Juliet Rose (desktop) | 15 | |pandoc |NA | 16 | 17 | # Dependencies 18 | 19 | |package |old |new | | 20 | |:------------|:-----|:------|:--| 21 | |workboots |0.1.0 |0.1.1 |* | 22 | |magrittr |NA |2.0.3 |* | 23 | |parallelly |NA |1.31.0 |* | 24 | |RColorBrewer |NA |1.1-3 |* | 25 | |vctrs |NA |0.4.0 |* | 26 | 27 | # Revdeps 28 | 29 | -------------------------------------------------------------------------------- /revdep/checks/libraries.csv: -------------------------------------------------------------------------------- 1 | package,old,new,delta 2 | workboots,0.1.0,0.1.1,* 3 | magrittr,NA,2.0.3,* 4 | parallelly,NA,1.31.0,* 5 | RColorBrewer,NA,1.1-3,* 6 | vctrs,NA,0.4.0,* 7 | -------------------------------------------------------------------------------- /revdep/cran.md: -------------------------------------------------------------------------------- 1 | ## revdepcheck results 2 | 3 | We checked 0 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 0 packages 7 | 8 | -------------------------------------------------------------------------------- /revdep/data.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/revdep/data.sqlite -------------------------------------------------------------------------------- /revdep/failures.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /revdep/problems.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. 
:)* -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(workboots) 3 | 4 | test_check("workboots") 5 | -------------------------------------------------------------------------------- /tests/testthat/data/test_data_bad.csv: -------------------------------------------------------------------------------- 1 | mpg,cyl,disp,hp,drat,wt,qsec,vs,am,gear,carb 2 | 21,6,160,110,3.9,2.62,16.46,0,1,4,4 3 | 21,6,160,110,3.9,2.875,17.02,0,1,4,4 4 | 22.8,4,108,93,3.85,2.32,18.61,1,1,4,1 5 | 21.4,6,258,110,3.08,3.215,19.44,1,0,3,1 6 | 18.7,8,360,175,3.15,3.44,17.02,0,0,3,2 7 | 18.1,6,225,105,2.76,3.46,20.22,1,0,3,1 8 | 14.3,8,360,245,3.21,3.57,15.84,0,0,3,4 9 | 24.4,4,146.7,62,3.69,3.19,20,1,0,4,2 10 | 22.8,4,140.8,95,3.92,3.15,22.9,1,0,4,2 11 | 19.2,6,167.6,123,3.92,3.44,18.3,1,0,4,4 12 | 17.8,6,167.6,123,3.92,3.44,18.9,1,0,4,4 13 | 16.4,8,275.8,180,3.07,4.07,17.4,0,0,3,3 14 | 17.3,8,275.8,180,3.07,3.73,17.6,0,0,3,3 15 | 15.2,8,275.8,180,3.07,3.78,18,0,0,3,3 16 | 10.4,8,472,205,2.93,5.25,17.98,0,0,3,4 17 | 10.4,8,460,215,3,5.424,17.82,0,0,3,4 18 | 14.7,8,440,230,3.23,5.345,17.42,0,0,3,4 19 | 32.4,4,78.7,66,4.08,2.2,19.47,1,1,4,1 20 | 30.4,4,75.7,52,4.93,1.615,18.52,1,1,4,2 21 | 33.9,4,71.1,65,4.22,1.835,19.9,1,1,4,1 22 | 21.5,4,120.1,97,3.7,2.465,20.01,1,0,3,1 23 | 15.5,8,318,150,2.76,3.52,16.87,0,0,3,2 24 | 15.2,8,304,150,3.15,3.435,17.3,0,0,3,2 25 | 13.3,8,350,245,3.73,3.84,15.41,0,0,3,4 26 | 19.2,8,400,175,3.08,3.845,17.05,0,0,3,2 27 | 27.3,4,79,66,4.08,1.935,18.9,1,1,4,1 28 | 26,4,120.3,91,4.43,2.14,16.7,0,1,5,2 29 | 30.4,4,95.1,113,3.77,1.513,16.9,1,1,5,2 30 | 15.8,8,351,264,4.22,3.17,14.5,0,1,5,4 31 | 19.7,6,145,175,3.62,2.77,15.5,0,1,5,6 32 | 15,8,301,335,3.54,3.57,14.6,0,1,5,8 33 | 21.4,4,121,109,4.11,2.78,18.6,1,1,4,2 34 | -------------------------------------------------------------------------------- 
/tests/testthat/data/test_importances.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/tests/testthat/data/test_importances.rds -------------------------------------------------------------------------------- /tests/testthat/data/test_preds.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/tests/testthat/data/test_preds.rds -------------------------------------------------------------------------------- /tests/testthat/data/test_test.csv: -------------------------------------------------------------------------------- 1 | species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex 2 | Adelie,Torgersen,39.5,17.4,186,3800,female 3 | Adelie,Torgersen,40.3,18,195,3250,female 4 | Adelie,Torgersen,38.7,19,195,3450,female 5 | Adelie,Torgersen,46,21.5,194,4200,male 6 | Adelie,Biscoe,38.2,18.1,185,3950,male 7 | Adelie,Dream,39.5,17.8,188,3300,female 8 | Adelie,Dream,36,18.5,186,3100,female 9 | Adelie,Dream,42.3,21.2,191,4150,male 10 | Adelie,Biscoe,35,17.9,190,3450,female 11 | Adelie,Biscoe,34.5,18.1,187,2900,female 12 | Adelie,Biscoe,41.4,18.6,191,3700,male 13 | Adelie,Biscoe,41.3,21.1,195,4400,male 14 | Adelie,Biscoe,41.1,18.2,192,4050,male 15 | Adelie,Biscoe,41.1,19.1,188,4100,male 16 | Adelie,Torgersen,33.5,19,190,3600,female 17 | Adelie,Torgersen,45.8,18.9,197,4150,male 18 | Adelie,Torgersen,36.2,16.1,187,3550,female 19 | Adelie,Dream,41.3,20.3,194,3550,male 20 | Adelie,Dream,36.9,18.6,189,3500,female 21 | Adelie,Dream,34,17.1,185,3400,female 22 | Adelie,Dream,40.3,18.5,196,4350,male 23 | Adelie,Biscoe,35,17.9,192,3725,female 24 | Adelie,Biscoe,41,20,203,4725,male 25 | Adelie,Biscoe,39.7,18.9,184,3550,male 26 | Adelie,Biscoe,38.1,17,181,3175,female 27 | 
Adelie,Torgersen,35.2,15.9,186,3050,female 28 | Adelie,Torgersen,41.5,18.3,195,4300,male 29 | Adelie,Torgersen,44.1,18,210,4000,male 30 | Adelie,Torgersen,43.1,19.2,197,3500,male 31 | Adelie,Dream,37.5,18.5,199,4475,male 32 | Adelie,Dream,41.1,17.5,190,3900,male 33 | Adelie,Dream,40.2,20.1,200,3975,male 34 | Adelie,Dream,37,16.5,185,3400,female 35 | Adelie,Dream,40.7,17,190,3725,male 36 | Adelie,Dream,39,18.7,185,3650,male 37 | Adelie,Dream,36,17.1,187,3700,female 38 | Adelie,Dream,41.5,18.5,201,4000,male 39 | Gentoo,Biscoe,50,16.3,230,5700,male 40 | Gentoo,Biscoe,50,15.2,218,5700,male 41 | Gentoo,Biscoe,45.8,14.6,210,4200,female 42 | Gentoo,Biscoe,42.9,13.1,215,5000,female 43 | Gentoo,Biscoe,47.8,15,215,5650,male 44 | Gentoo,Biscoe,48.4,16.3,220,5400,male 45 | Gentoo,Biscoe,42.6,13.7,213,4950,female 46 | Gentoo,Biscoe,49.6,16,225,5700,male 47 | Gentoo,Biscoe,49.6,15,216,4750,male 48 | Gentoo,Biscoe,43.6,13.9,217,4900,female 49 | Gentoo,Biscoe,45.5,15,220,5000,male 50 | Gentoo,Biscoe,50.4,15.3,224,5550,male 51 | Gentoo,Biscoe,45.3,13.8,208,4200,female 52 | Gentoo,Biscoe,45.7,13.9,214,4400,female 53 | Gentoo,Biscoe,46.4,15,216,4700,female 54 | Gentoo,Biscoe,51.1,16.3,220,6000,male 55 | Gentoo,Biscoe,49.1,14.5,212,4625,female 56 | Gentoo,Biscoe,50,15.9,224,5350,male 57 | Gentoo,Biscoe,43.4,14.4,218,4600,female 58 | Gentoo,Biscoe,52.1,17,230,5550,male 59 | Gentoo,Biscoe,49.4,15.8,216,4925,male 60 | Gentoo,Biscoe,49.1,15,228,5500,male 61 | Gentoo,Biscoe,43.3,14,208,4575,female 62 | Gentoo,Biscoe,48.1,15.1,209,5500,male 63 | Gentoo,Biscoe,48.8,16.2,222,6000,male 64 | Gentoo,Biscoe,49.9,16.1,213,5400,male 65 | Chinstrap,Dream,51.3,19.2,193,3650,male 66 | Chinstrap,Dream,51.3,19.9,198,3700,male 67 | Chinstrap,Dream,46.6,17.8,193,3800,female 68 | Chinstrap,Dream,47,17.3,185,3700,female 69 | Chinstrap,Dream,58,17.8,181,3700,female 70 | Chinstrap,Dream,46.4,18.6,190,3450,female 71 | Chinstrap,Dream,49.2,18.2,195,4400,male 72 | Chinstrap,Dream,48.5,17.5,191,3400,male 73 | 
Chinstrap,Dream,50.6,19.4,193,3800,male 74 | Chinstrap,Dream,46.4,17.8,191,3700,female 75 | Chinstrap,Dream,52.8,20,205,4550,male 76 | Chinstrap,Dream,54.2,20.8,201,4300,male 77 | Chinstrap,Dream,42.5,16.7,187,3350,female 78 | Chinstrap,Dream,47.6,18.3,195,3850,female 79 | Chinstrap,Dream,45.5,17,196,3500,female 80 | Chinstrap,Dream,50.9,17.9,196,3675,female 81 | Chinstrap,Dream,50.1,17.9,190,3400,female 82 | Chinstrap,Dream,51.5,18.7,187,3250,male 83 | Chinstrap,Dream,52.2,18.8,197,3450,male 84 | Chinstrap,Dream,51.9,19.5,206,3950,male 85 | Chinstrap,Dream,50.2,18.7,198,3775,female 86 | -------------------------------------------------------------------------------- /tests/testthat/data/test_train.csv: -------------------------------------------------------------------------------- 1 | species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex 2 | Gentoo,Biscoe,59.6,17,230,6050,male 3 | Adelie,Torgersen,34.4,18.4,184,3325,female 4 | Gentoo,Biscoe,45.2,15.8,215,5300,male 5 | Chinstrap,Dream,49,19.5,210,3950,male 6 | Adelie,Torgersen,41.4,18.5,202,3875,male 7 | Chinstrap,Dream,51,18.8,203,4100,male 8 | Gentoo,Biscoe,44.9,13.8,212,4750,female 9 | Gentoo,Biscoe,51.1,16.5,225,5250,male 10 | Chinstrap,Dream,50.8,19,210,4100,male 11 | Gentoo,Biscoe,45.4,14.6,211,4800,female 12 | Adelie,Dream,40.8,18.9,208,4300,male 13 | Adelie,Dream,38.1,18.6,190,3700,female 14 | Gentoo,Biscoe,43.5,15.2,213,4650,female 15 | Gentoo,Biscoe,48.5,14.1,220,5300,male 16 | Chinstrap,Dream,45.2,16.6,191,3250,female 17 | Adelie,Dream,32.1,15.5,188,3050,female 18 | Adelie,Dream,39.5,16.7,178,3250,female 19 | Adelie,Torgersen,39.2,19.6,195,4675,male 20 | Chinstrap,Dream,45.7,17.3,193,3600,female 21 | Gentoo,Biscoe,50.5,15.2,216,5000,female 22 | Gentoo,Biscoe,49.8,16.8,230,5700,male 23 | Adelie,Torgersen,35.1,19.4,193,4200,male 24 | Adelie,Dream,36.3,19.5,190,3800,male 25 | Adelie,Dream,36,17.9,190,3450,female 26 | Adelie,Dream,36,17.8,195,3450,female 27 | 
Adelie,Dream,38.8,20,190,3950,male 28 | Adelie,Biscoe,39.6,20.7,191,3900,female 29 | Gentoo,Biscoe,50.4,15.7,222,5750,male 30 | Adelie,Biscoe,40.5,17.9,187,3200,female 31 | Adelie,Dream,40.2,17.1,193,3400,female 32 | Gentoo,Biscoe,45.2,16.4,223,5950,male 33 | Gentoo,Biscoe,48.7,15.1,222,5350,male 34 | Gentoo,Biscoe,48.2,15.6,221,5100,male 35 | Chinstrap,Dream,46.7,17.9,195,3300,female 36 | Adelie,Torgersen,35.5,17.5,190,3700,female 37 | Adelie,Torgersen,37.2,19.4,184,3900,male 38 | Adelie,Torgersen,42.9,17.6,196,4700,male 39 | Adelie,Torgersen,35.9,16.6,190,3050,female 40 | Adelie,Dream,39.2,18.6,190,4250,male 41 | Gentoo,Biscoe,45.8,14.2,219,4700,female 42 | Chinstrap,Dream,50.2,18.8,202,3800,male 43 | Chinstrap,Dream,51.7,20.3,194,3775,male 44 | Adelie,Dream,39.6,18.8,190,4600,male 45 | Chinstrap,Dream,49.3,19.9,203,4050,male 46 | Gentoo,Biscoe,45.2,13.8,215,4750,female 47 | Adelie,Biscoe,37.8,18.3,174,3400,female 48 | Adelie,Torgersen,37.7,19.8,198,3500,male 49 | Adelie,Dream,43.2,18.5,192,4100,male 50 | Gentoo,Biscoe,46.8,14.3,215,4850,female 51 | Gentoo,Biscoe,47.5,15,218,4950,female 52 | Adelie,Dream,41.1,18.1,205,4300,male 53 | Adelie,Dream,44.1,19.7,196,4400,male 54 | Gentoo,Biscoe,45.5,13.7,214,4650,female 55 | Gentoo,Biscoe,50.8,15.7,226,5200,male 56 | Gentoo,Biscoe,54.3,15.7,231,5650,male 57 | Adelie,Dream,37.6,19.3,181,3300,female 58 | Adelie,Torgersen,36.7,19.3,193,3450,female 59 | Adelie,Torgersen,42.5,20.7,197,4500,male 60 | Chinstrap,Dream,45.7,17,195,3650,female 61 | Gentoo,Biscoe,48.4,14.4,203,4625,female 62 | Adelie,Dream,36.2,17.3,187,3300,female 63 | Adelie,Biscoe,40.5,18.9,180,3950,male 64 | Chinstrap,Dream,50.5,19.6,201,4050,male 65 | Adelie,Torgersen,38.8,17.6,191,3275,female 66 | Adelie,Biscoe,42.7,18.3,196,4075,male 67 | Gentoo,Biscoe,49,16.1,216,5550,male 68 | Adelie,Torgersen,41.8,19.4,198,4450,male 69 | Gentoo,Biscoe,50.1,15,225,5000,male 70 | Adelie,Torgersen,39.6,17.2,196,3550,female 71 | Gentoo,Biscoe,47.6,14.5,215,5400,male 72 | 
Adelie,Dream,35.7,18,202,3550,female 73 | Gentoo,Biscoe,46.2,14.5,209,4800,female 74 | Adelie,Dream,40.6,17.2,187,3475,male 75 | Adelie,Biscoe,39,17.5,186,3550,female 76 | Adelie,Torgersen,42.1,19.1,195,4000,male 77 | Gentoo,Biscoe,45.1,14.5,207,5050,female 78 | Gentoo,Biscoe,52.2,17.1,228,5400,male 79 | Adelie,Biscoe,37.8,20,190,4250,male 80 | Gentoo,Biscoe,55.1,16,230,5850,male 81 | Gentoo,Biscoe,50.7,15,223,5550,male 82 | Adelie,Dream,36.8,18.5,193,3500,female 83 | Gentoo,Biscoe,43.2,14.5,208,4450,female 84 | Gentoo,Biscoe,49.5,16.2,229,5800,male 85 | Gentoo,Biscoe,48.2,14.3,210,4600,female 86 | Adelie,Biscoe,37.7,18.7,180,3600,male 87 | Gentoo,Biscoe,50.5,15.9,222,5550,male 88 | Adelie,Biscoe,40.1,18.9,188,4300,male 89 | Adelie,Biscoe,37.6,19.1,194,3750,male 90 | Gentoo,Biscoe,42,13.5,210,4150,female 91 | Gentoo,Biscoe,51.3,14.2,218,5300,male 92 | Chinstrap,Dream,45.4,18.7,188,3525,female 93 | Chinstrap,Dream,42.4,17.3,181,3600,female 94 | Adelie,Dream,37.3,17.8,191,3350,female 95 | Adelie,Biscoe,37.9,18.6,172,3150,female 96 | Adelie,Torgersen,35.7,17,189,3350,female 97 | Gentoo,Biscoe,47.7,15,216,4750,female 98 | Adelie,Biscoe,39.7,17.7,193,3200,female 99 | Chinstrap,Dream,46.9,16.6,192,2700,female 100 | Gentoo,Biscoe,47.5,14.2,209,4600,female 101 | Gentoo,Biscoe,46.2,14.9,221,5300,male 102 | Gentoo,Biscoe,45.2,14.8,212,5200,female 103 | Gentoo,Biscoe,46.7,15.3,219,5200,male 104 | Adelie,Biscoe,38.2,20,190,3900,male 105 | Gentoo,Biscoe,48.4,14.6,213,5850,male 106 | Gentoo,Biscoe,51.5,16.3,230,5500,male 107 | Gentoo,Biscoe,43.3,13.4,209,4400,female 108 | Gentoo,Biscoe,46.3,15.8,215,5050,male 109 | Adelie,Torgersen,39.3,20.6,190,3650,male 110 | Adelie,Torgersen,42.8,18.5,195,4250,male 111 | Chinstrap,Dream,43.2,16.6,187,2900,female 112 | Chinstrap,Dream,45.9,17.1,190,3575,female 113 | Chinstrap,Dream,50.9,19.1,196,3550,male 114 | Adelie,Biscoe,35.3,18.9,187,3800,female 115 | Adelie,Biscoe,35.7,16.9,185,3150,female 116 | Adelie,Torgersen,34.6,17.2,189,3200,female 
117 | Gentoo,Biscoe,49.2,15.2,221,6300,male 118 | Adelie,Dream,38.3,19.2,189,3950,male 119 | Chinstrap,Dream,50.3,20,197,3300,male 120 | Gentoo,Biscoe,55.9,17,228,5600,male 121 | Chinstrap,Dream,49.8,17.3,198,3675,female 122 | Chinstrap,Dream,46.5,17.9,192,3500,female 123 | Adelie,Biscoe,42,19.5,200,4050,male 124 | Adelie,Torgersen,36.7,18.8,187,3800,female 125 | Chinstrap,Dream,47.5,16.8,199,3900,female 126 | Gentoo,Biscoe,44,13.6,208,4350,female 127 | Adelie,Torgersen,38.6,17,188,2900,female 128 | Gentoo,Biscoe,46.4,15.6,221,5000,male 129 | Gentoo,Biscoe,46.6,14.2,210,4850,female 130 | Adelie,Torgersen,39.1,18.7,181,3750,male 131 | Gentoo,Biscoe,45.1,14.5,215,5000,female 132 | Adelie,Dream,36.4,17,195,3325,female 133 | Gentoo,Biscoe,46.9,14.6,222,4875,female 134 | Chinstrap,Dream,52,20.7,210,4800,male 135 | Adelie,Dream,39.6,18.1,186,4450,male 136 | Gentoo,Biscoe,49.5,16.1,224,5650,male 137 | Gentoo,Biscoe,47.3,15.3,222,5250,male 138 | Adelie,Biscoe,38.8,17.2,180,3800,male 139 | Adelie,Biscoe,43.2,19,197,4775,male 140 | Chinstrap,Dream,49.6,18.2,193,3775,male 141 | Adelie,Dream,33.1,16.1,178,2900,female 142 | Gentoo,Biscoe,46.5,14.8,217,5200,female 143 | Adelie,Biscoe,40.6,18.8,193,3800,male 144 | Adelie,Biscoe,42.2,19.5,197,4275,male 145 | Adelie,Biscoe,40.6,18.6,183,3550,male 146 | Gentoo,Biscoe,42.8,14.2,209,4700,female 147 | Adelie,Dream,41.1,19,182,3425,male 148 | Adelie,Biscoe,36.4,17.1,184,2850,female 149 | Adelie,Dream,38.9,18.8,190,3600,female 150 | Adelie,Torgersen,36.6,17.8,185,3700,female 151 | Chinstrap,Dream,52.7,19.8,197,3725,male 152 | Gentoo,Biscoe,44.4,17.3,219,5250,male 153 | Gentoo,Biscoe,47.2,13.7,214,4925,female 154 | Gentoo,Biscoe,47.2,15.5,215,4975,female 155 | Chinstrap,Dream,43.5,18.1,202,3400,female 156 | Chinstrap,Dream,46.2,17.5,187,3650,female 157 | Gentoo,Biscoe,45.5,14.5,212,4750,female 158 | Chinstrap,Dream,49,19.6,212,4300,male 159 | Gentoo,Biscoe,41.7,14.7,210,4700,female 160 | Gentoo,Biscoe,46.1,15.1,215,5100,male 161 | 
Gentoo,Biscoe,50.8,17.3,228,5600,male 162 | Adelie,Dream,36.5,18,182,3150,female 163 | Adelie,Dream,36.6,18.4,184,3475,female 164 | Gentoo,Biscoe,45,15.4,220,5050,male 165 | Adelie,Torgersen,38.5,17.9,190,3325,female 166 | Adelie,Dream,42.2,18.5,180,3550,female 167 | Adelie,Dream,37,16.9,185,3000,female 168 | Adelie,Torgersen,34.6,21.1,198,4400,male 169 | Gentoo,Biscoe,47.4,14.6,212,4725,female 170 | Adelie,Torgersen,38.6,21.2,191,3800,male 171 | Chinstrap,Dream,48.1,16.4,199,3325,female 172 | Chinstrap,Dream,46.8,16.5,189,3650,female 173 | Adelie,Biscoe,35.5,16.2,195,3350,female 174 | Gentoo,Biscoe,46.5,13.5,210,4550,female 175 | Chinstrap,Dream,40.9,16.6,187,3200,female 176 | Gentoo,Biscoe,43.5,14.2,220,4700,female 177 | Gentoo,Biscoe,47.5,14,212,4875,female 178 | Chinstrap,Dream,51.4,19,201,3950,male 179 | Gentoo,Biscoe,46.5,14.5,213,4400,female 180 | Adelie,Torgersen,36.2,17.2,187,3150,female 181 | Gentoo,Biscoe,50.2,14.3,218,5700,male 182 | Adelie,Biscoe,36.5,16.6,181,2850,female 183 | Chinstrap,Dream,45.2,17.8,198,3950,female 184 | Gentoo,Biscoe,46.2,14.1,217,4375,female 185 | Gentoo,Biscoe,48.7,15.7,208,5350,male 186 | Chinstrap,Dream,46.1,18.2,178,3250,female 187 | Gentoo,Biscoe,45.5,13.9,210,4200,female 188 | Gentoo,Biscoe,40.9,13.7,214,4650,female 189 | Gentoo,Biscoe,48.5,15,219,4850,female 190 | Chinstrap,Dream,49.7,18.6,195,3600,male 191 | Adelie,Biscoe,37.6,17,185,3600,female 192 | Adelie,Biscoe,38.1,16.5,198,3825,female 193 | Adelie,Biscoe,38.6,17.2,199,3750,female 194 | Chinstrap,Dream,53.5,19.9,205,4500,male 195 | Chinstrap,Dream,45.6,19.4,194,3525,female 196 | Chinstrap,Dream,46,18.9,195,4150,female 197 | Adelie,Dream,39.7,17.9,193,4250,male 198 | Chinstrap,Dream,50.7,19.7,203,4050,male 199 | Adelie,Dream,38.1,17.6,187,3425,female 200 | Adelie,Biscoe,45.6,20.3,191,4600,male 201 | Gentoo,Biscoe,44.5,14.7,214,4850,female 202 | Chinstrap,Dream,50.8,18.5,201,4450,male 203 | Adelie,Dream,40.9,18.9,184,3900,male 204 | 
Gentoo,Biscoe,46.1,13.2,211,4500,female 205 | Adelie,Dream,37.2,18.1,178,3900,male 206 | Gentoo,Biscoe,50.5,15.9,225,5400,male 207 | Adelie,Biscoe,41.6,18,192,3950,male 208 | Adelie,Torgersen,40.2,17,176,3450,female 209 | Adelie,Dream,35.6,17.5,191,3175,female 210 | Chinstrap,Dream,52,19,197,4150,male 211 | Gentoo,Biscoe,42.7,13.7,208,3950,female 212 | Adelie,Torgersen,38.9,17.8,181,3625,female 213 | Adelie,Torgersen,41.1,18.6,189,3325,male 214 | Gentoo,Biscoe,43.8,13.9,208,4300,female 215 | Adelie,Dream,39.2,21.1,196,4150,male 216 | Gentoo,Biscoe,44.9,13.3,213,5100,female 217 | Gentoo,Biscoe,46.8,16.1,215,5500,male 218 | Chinstrap,Dream,50,19.5,196,3900,male 219 | Adelie,Dream,40.8,18.4,195,3900,male 220 | Adelie,Dream,37.8,18.1,193,3750,male 221 | Chinstrap,Dream,42.5,17.3,187,3350,female 222 | Adelie,Dream,39.8,19.1,184,4650,male 223 | Chinstrap,Dream,51.3,18.2,197,3750,male 224 | Adelie,Dream,37.3,16.8,192,3000,female 225 | Gentoo,Biscoe,46.8,15.4,215,5150,male 226 | Adelie,Torgersen,39.7,18.4,190,3900,male 227 | Chinstrap,Dream,50.5,18.4,200,3400,female 228 | Gentoo,Biscoe,49.8,15.9,229,5950,male 229 | Adelie,Biscoe,37.7,16,183,3075,female 230 | Gentoo,Biscoe,52.5,15.6,221,5450,male 231 | Gentoo,Biscoe,48.7,14.1,210,4450,female 232 | Adelie,Torgersen,41.1,17.6,182,3200,female 233 | Gentoo,Biscoe,48.6,16,230,5800,male 234 | Adelie,Torgersen,37.3,20.5,199,3775,male 235 | Gentoo,Biscoe,49.1,14.8,220,5150,female 236 | Adelie,Biscoe,39.6,17.7,186,3500,female 237 | Adelie,Biscoe,35.9,19.2,189,3800,female 238 | Gentoo,Biscoe,45.3,13.7,210,4300,female 239 | Adelie,Torgersen,40.6,19,199,4000,male 240 | Adelie,Torgersen,39,17.1,191,3050,female 241 | Adelie,Torgersen,40.9,16.8,191,3700,female 242 | Gentoo,Biscoe,49.3,15.7,217,5850,male 243 | Gentoo,Biscoe,53.4,15.8,219,5500,male 244 | Chinstrap,Dream,55.8,19.8,207,4000,male 245 | Adelie,Biscoe,37.9,18.6,193,2925,female 246 | Gentoo,Biscoe,45.1,14.4,210,4400,female 247 | Chinstrap,Dream,49.5,19,200,3800,male 248 | 
Gentoo,Biscoe,46.5,14.4,217,4900,female 249 | Gentoo,Biscoe,50,15.3,220,5550,male 250 | Chinstrap,Dream,52,18.1,201,4050,male 251 | -------------------------------------------------------------------------------- /tests/testthat/data/test_wf.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/tests/testthat/data/test_wf.rds -------------------------------------------------------------------------------- /tests/testthat/data/test_wf_bad.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markjrieke/workboots/be558ce5f3971bdb4fdde92e94c450f44121c3c0/tests/testthat/data/test_wf_bad.rds -------------------------------------------------------------------------------- /tests/testthat/test-predict-boots.R: -------------------------------------------------------------------------------- 1 | # read in data to use in tests 2 | # test_wf: wf using xgboost to predict body_mass_g from all predictors in the 3 | # palmer penguins dataset. one recipe step - step_dummy(all_nominal()) 4 | # test_train: training df of palmer penguins 5 | # test_test: testing df of palmer penguins 6 | test_wf <- readRDS("data/test_wf.rds") 7 | test_train <- read.csv("data/test_train.csv") 8 | test_test <- read.csv("data/test_test.csv") 9 | 10 | test_that("predict_boots() returns prediction interval in expected format", { 11 | 12 | # generate predictions 13 | expect_warning( 14 | x <- 15 | predict_boots( 16 | workflow = test_wf, 17 | n = 5, 18 | training_data = test_train, 19 | new_data = test_test 20 | ), 21 | 22 | "At least 2000 resamples recommended for stable results." 
23 | ) 24 | 25 | # tests 26 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame")) 27 | expect_named(x, c("rowid", ".preds")) 28 | expect_named(x$.preds[[1]], c("model", "model.pred")) 29 | expect_type(x$rowid, "integer") 30 | expect_type(x$.preds, "list") 31 | expect_type(x$.preds[[1]]$model, "character") 32 | expect_type(x$.preds[[1]]$model.pred, "double") 33 | expect_equal(nrow(x), nrow(test_test)) 34 | 35 | }) 36 | 37 | test_that("predict_boots() returns confidence interval in expected format", { 38 | 39 | # generate predictions 40 | expect_warning( 41 | x <- 42 | predict_boots( 43 | workflow = test_wf, 44 | n = 5, 45 | training_data = test_train, 46 | new_data = test_test, 47 | interval = "confidence" 48 | ), 49 | 50 | "At least 2000 resamples recommended for stable results." 51 | ) 52 | 53 | # tests 54 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame")) 55 | expect_named(x, c("rowid", ".preds")) 56 | expect_named(x$.preds[[1]], c("model", "model.pred")) 57 | expect_type(x$rowid, "integer") 58 | expect_type(x$.preds, "list") 59 | expect_type(x$.preds[[1]]$model, "character") 60 | expect_type(x$.preds[[1]]$model.pred, "double") 61 | expect_equal(nrow(x), nrow(test_test)) 62 | 63 | }) 64 | 65 | test_that("predict_boots() throws an error when not passed a workflow", { 66 | 67 | expect_error( 68 | predict_boots( 69 | workflow = test_train, 70 | n = 1, 71 | training_data = test_train, 72 | new_data = test_test 73 | ), 74 | 75 | "argument `workflow` must be of class \"workflow\"." 76 | ) 77 | 78 | }) 79 | 80 | test_that("predict_boots() throws an error when workflow is not final", { 81 | 82 | # load bad wf - same as test_wf but has 1 non-final tuning param 83 | test_wf_bad <- readRDS("data/test_wf_bad.rds") 84 | 85 | expect_error( 86 | predict_boots( 87 | workflow = test_wf_bad, 88 | n = 1, 89 | training_data = test_train, 90 | new_data = test_test 91 | ), 92 | 93 | "all tuning parameters must be final." 
94 | ) 95 | 96 | }) 97 | 98 | test_that("predict_boots() throws an error when bad n is specified", { 99 | 100 | expect_error( 101 | predict_boots( 102 | workflow = test_wf, 103 | n = 0, 104 | training_data = test_train, 105 | new_data = test_test 106 | ), 107 | 108 | "argument `n` must be >= 1." 109 | ) 110 | 111 | expect_error( 112 | predict_boots( 113 | workflow = test_wf, 114 | n = 1.5, 115 | training_data = test_train, 116 | new_data = test_test 117 | ), 118 | 119 | "argmuent `n` must be an integer." 120 | ) 121 | 122 | }) 123 | 124 | test_that("predict_boots() throws an error when training_data/new_data doesn't match expected format", { 125 | 126 | # predictors & outcome missing from training_data 127 | expect_error( 128 | predict_boots( 129 | workflow = test_wf, 130 | n = 1, 131 | training_data = test_train[, 3], 132 | new_data = test_test 133 | ), 134 | 135 | paste0("missing cols in training_data:\n", 136 | "species, island, bill_length_mm, bill_depth_mm, flipper_length_mm, sex, body_mass_g") 137 | ) 138 | 139 | # predictors missing from new_data 140 | expect_error( 141 | predict_boots( 142 | workflow = test_wf, 143 | n = 1, 144 | training_data = test_train, 145 | new_data = test_test[, 3] 146 | ), 147 | 148 | paste0("missing cols in new_data:\n", 149 | "species, island, bill_length_mm, bill_depth_mm, flipper_length_mm, sex") 150 | ) 151 | 152 | }) 153 | 154 | -------------------------------------------------------------------------------- /tests/testthat/test-summarise-boots.R: -------------------------------------------------------------------------------- 1 | # read in test data 2 | test_preds <- readRDS("data/test_preds.rds") 3 | test_importances <- readRDS("data/test_importances.rds") 4 | 5 | test_that("summarise_predictions() returns predictions in expected format", { 6 | 7 | # read in data used to make predictions (new_data in predict_boots()) 8 | test_test <- read.csv("data/test_test.csv") 9 | 10 | # generate summary 11 | x <- 
summarise_predictions(test_preds) 12 | 13 | # tests 14 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame")) 15 | expect_named(x, c("rowid", ".preds", ".pred", ".pred_lower", ".pred_upper")) 16 | expect_type(x$.preds, "list") 17 | expect_type(x$.pred_lower, "double") 18 | expect_type(x$.pred, "double") 19 | expect_type(x$.pred_upper, "double") 20 | expect_type(x$.preds[[1]]$model, "character") 21 | expect_type(x$.preds[[1]]$model.pred, "double") 22 | expect_equal(nrow(x), nrow(test_test)) 23 | 24 | }) 25 | 26 | test_that("summarise_importances() returns importances in expected format", { 27 | 28 | # generate summary 29 | x <- summarise_importance(test_importances) 30 | 31 | # tests 32 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame")) 33 | expect_named(x, c("variable", ".importances", ".importance", ".importance_lower", ".importance_upper")) 34 | expect_type(x$.importances, "list") 35 | expect_type(x$.importance_lower, "double") 36 | expect_type(x$.importance, "double") 37 | expect_type(x$.importance_upper, "double") 38 | expect_type(x$.importances[[1]]$model, "character") 39 | expect_type(x$.importances[[1]]$model.importance, "double") 40 | 41 | }) 42 | -------------------------------------------------------------------------------- /tests/testthat/test-vi-boots.R: -------------------------------------------------------------------------------- 1 | # read in data to use in tests 2 | # test_wf: wf using xgboost to predict body_mass_g from all predictors in the 3 | # palmer penguins dataset. 
one recipe step - step_dummy(all_nominal()) 4 | # test_train: training df of palmer penguins 5 | # test_test: testing df of palmer penguins 6 | test_wf <- readRDS("data/test_wf.rds") 7 | test_train <- read.csv("data/test_train.csv") 8 | 9 | test_that("vi_boots() returns importances in expected format", { 10 | 11 | # generate predictions 12 | expect_warning( 13 | x <- 14 | vi_boots( 15 | workflow = test_wf, 16 | n = 5, 17 | training_data = test_train 18 | ), 19 | 20 | "At least 2000 resamples recommended for stable results." 21 | 22 | ) 23 | 24 | 25 | # tests 26 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame")) 27 | expect_named(x, c("variable", ".importances")) 28 | expect_type(x$variable, "character") 29 | expect_type(x$.importances, "list") 30 | 31 | }) 32 | 33 | test_that("vi_boots() throws an error when not passed a workflow", { 34 | 35 | expect_error( 36 | vi_boots( 37 | workflow = test_train, 38 | n = 1, 39 | training_data = test_train 40 | ), 41 | 42 | "argument `workflow` must be of class \"workflow\"." 43 | ) 44 | 45 | }) 46 | 47 | test_that("vi_boots() throws an error when workflow is not final", { 48 | 49 | # load bad wf - same as test_wf but has 1 non-final tuning param 50 | test_wf_bad <- readRDS("data/test_wf_bad.rds") 51 | 52 | expect_error( 53 | vi_boots( 54 | workflow = test_wf_bad, 55 | n = 1, 56 | training_data = test_train 57 | ), 58 | 59 | "all tuning parameters must be final." 60 | ) 61 | 62 | }) 63 | 64 | test_that("vi_boots() throws an error when bad n is specified", { 65 | 66 | expect_error( 67 | vi_boots( 68 | workflow = test_wf, 69 | n = 0, 70 | training_data = test_train 71 | ), 72 | 73 | "argument `n` must be >= 1." 74 | ) 75 | 76 | expect_error( 77 | vi_boots( 78 | workflow = test_wf, 79 | n = 1.5, 80 | training_data = test_train 81 | ), 82 | 83 | "argmuent `n` must be an integer." 
84 | ) 85 | 86 | }) 87 | 88 | test_that("vi_boots() throws an error when training_data doesn't match expected format", { 89 | 90 | # predictors & outcome missing from training_data 91 | expect_error( 92 | vi_boots( 93 | workflow = test_wf, 94 | n = 1, 95 | training_data = test_train[, 3] 96 | ), 97 | 98 | paste0("missing cols in training_data:\n", 99 | "species, island, bill_length_mm, bill_depth_mm, flipper_length_mm, sex, body_mass_g") 100 | ) 101 | 102 | }) 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/Estimating-Linear-Intervals.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Estimating Linear Intervals" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Estimating Linear Intervals} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, include = FALSE} 11 | knitr::opts_chunk$set( 12 | collapse = TRUE, 13 | comment = "#>", 14 | warning = FALSE, 15 | message = FALSE 16 | ) 17 | 18 | ggplot2::theme_set( 19 | ggplot2::theme_minimal() + 20 | ggplot2::theme(plot.title.position = "plot", 21 | plot.background = ggplot2::element_rect(fill = "white", color = "white")) 22 | ) 23 | ``` 24 | 25 | When using a linear model, we can generate prediction and confidence intervals without much effort. By way of example, we can use workboots to approximate linear model intervals. Let's start by building a baseline model. In this example, we'll predict a home's sale price based on the first floor's square footage with data from the [Ames housing dataset](https://modeldata.tidymodels.org/reference/ames.html). 
26 | 27 | ```{r, warning=FALSE, message=FALSE} 28 | library(tidymodels) 29 | 30 | # setup our data 31 | data("ames") 32 | ames_mod <- ames %>% select(First_Flr_SF, Sale_Price) 33 | 34 | # baseline plot 35 | ames_mod %>% 36 | ggplot(aes(x = First_Flr_SF, y = Sale_Price)) + 37 | geom_point(alpha = 0.25) + 38 | scale_x_log10(labels = scales::comma_format()) + 39 | scale_y_log10(labels = scales::label_number(scale_cut = scales::cut_short_scale())) 40 | ``` 41 | 42 | We can use a linear model to predict the log transform of `Sale_Price` based on the log transform of `First_Flr_SF`and plot our predictions against a holdout set with a prediction interval. 43 | 44 | ```{r} 45 | # log transform 46 | ames_mod <- 47 | ames_mod %>% 48 | mutate(across(everything(), log10)) 49 | 50 | # split into train/test data 51 | set.seed(918) 52 | ames_split <- initial_split(ames_mod) 53 | ames_train <- training(ames_split) 54 | ames_test <- testing(ames_split) 55 | 56 | # train a linear model 57 | set.seed(314) 58 | ames_lm <- lm(Sale_Price ~ First_Flr_SF, data = ames_train) 59 | 60 | # predict on new data with a prediction interval 61 | ames_lm_pred_int <- 62 | ames_lm %>% 63 | predict(ames_test, interval = "predict") %>% 64 | as_tibble() 65 | 66 | ames_lm_pred_int %>% 67 | 68 | # rescale predictions to match the original dataset's scale 69 | bind_cols(ames_test) %>% 70 | mutate(across(everything(), ~10^.x)) %>% 71 | 72 | # plot! 73 | ggplot(aes(x = First_Flr_SF)) + 74 | geom_point(aes(y = Sale_Price), 75 | alpha = 0.25) + 76 | geom_line(aes(y = fit), 77 | size = 1) + 78 | geom_ribbon(aes(ymin = lwr, 79 | ymax = upr), 80 | alpha = 0.25) + 81 | scale_x_log10(labels = scales::comma_format()) + 82 | scale_y_log10(labels = scales::label_number(scale_cut = scales::cut_short_scale())) 83 | ``` 84 | 85 | We can use workboots to approximate the linear model's prediction interval by passing a workflow built on a linear model to `predict_boots()`. 
86 | 87 | ```{r, eval=FALSE} 88 | library(workboots) 89 | 90 | # setup a workflow with a linear model 91 | ames_wf <- 92 | workflow() %>% 93 | add_recipe(recipe(Sale_Price ~ First_Flr_SF, data = ames_train)) %>% 94 | add_model(linear_reg()) 95 | 96 | # generate bootstrap predictions on ames test 97 | set.seed(713) 98 | ames_boot_pred_int <- 99 | ames_wf %>% 100 | predict_boots( 101 | n = 2000, 102 | training_data = ames_train, 103 | new_data = ames_test 104 | ) 105 | ``` 106 | 107 | ```{r, echo=FALSE} 108 | library(workboots) 109 | 110 | options(timeout = 120) 111 | 112 | # load data from workboots_support (avoid re-fitting on knit) 113 | ames_boot_pred_int <- readr::read_rds("https://github.com/markjrieke/workboots_support/blob/main/data/ames_boot_pred_int.rds?raw=true") 114 | ``` 115 | 116 | By overlaying the intervals on top of one another, we can see that the prediction interval generated by `predict_boots()` (in blue) is a good approximation of the theoretical interval from `lm()`. 117 | 118 | ```{r} 119 | ames_boot_pred_int %>% 120 | summarise_predictions() %>% 121 | 122 | # rescale predictions to match original dataset's scale 123 | bind_cols(ames_lm_pred_int) %>% 124 | bind_cols(ames_test) %>% 125 | mutate(across(.pred:Sale_Price, ~10^.x)) %>% 126 | 127 | # plot! 
128 | ggplot(aes(x = First_Flr_SF)) + 129 | geom_point(aes(y = Sale_Price), 130 | alpha = 0.25) + 131 | scale_x_log10(labels = scales::comma_format()) + 132 | scale_y_log10(labels = scales::label_number(scale_cut = scales::cut_short_scale())) + 133 | 134 | # add prediction interval created by lm() 135 | geom_line(aes(y = fit), 136 | size = 1) + 137 | geom_ribbon(aes(ymin = lwr, 138 | ymax = upr), 139 | alpha = 0.25) + 140 | 141 | # add prediction interval created by workboots 142 | geom_point(aes(y = .pred), 143 | color = "blue", 144 | alpha = 0.25) + 145 | geom_errorbar(aes(ymin = .pred_lower, 146 | ymax = .pred_upper), 147 | color = "blue", 148 | alpha = 0.25, 149 | width = 0.0125) 150 | ``` 151 | 152 | Both `lm()` and `summarise_predictions()` use a 95% prediction interval by default but we can generate other intervals by passing different values to the parameter `interval_width`: 153 | 154 | ```{r} 155 | ames_boot_pred_int %>% 156 | 157 | # generate a 95% prediction interval 158 | summarise_predictions(interval_width = 0.95) %>% 159 | rename(.pred_lower_95 = .pred_lower, 160 | .pred_upper_95 = .pred_upper) %>% 161 | select(-.pred) %>% 162 | 163 | # generate 80% prediction interval 164 | summarise_predictions(interval_width = 0.80) %>% 165 | rename(.pred_lower_80 = .pred_lower, 166 | .pred_upper_80 = .pred_upper) %>% 167 | 168 | # rescale predictions to match original dataset's scale 169 | bind_cols(ames_test) %>% 170 | mutate(across(.pred_lower_95:Sale_Price, ~10^.x)) %>% 171 | 172 | # plot! 
173 | ggplot(aes(x = First_Flr_SF)) + 174 | geom_point(aes(y = Sale_Price), 175 | alpha = 0.25) + 176 | geom_line(aes(y = .pred), 177 | size = 1, 178 | color = "blue") + 179 | geom_ribbon(aes(ymin = .pred_lower_95, 180 | ymax = .pred_upper_95), 181 | alpha = 0.25, 182 | fill = "blue") + 183 | geom_ribbon(aes(ymin = .pred_lower_80, 184 | ymax = .pred_upper_80), 185 | alpha = 0.25, 186 | fill = "blue") + 187 | scale_x_log10(labels = scales::comma_format()) + 188 | scale_y_log10(labels = scales::label_number(scale_cut = scales::cut_short_scale())) 189 | ``` 190 | 191 | Alternatively, we can estimate the confidence interval around each prediction by passing the argument `"confidence"` to the `interval` parameter of `predict_boots()`. 192 | 193 | ```{r, eval=FALSE} 194 | # generate linear model confidence interval for reference 195 | ames_lm_conf_int <- 196 | ames_lm %>% 197 | predict(ames_test, interval = "confidence") %>% 198 | as_tibble() 199 | 200 | # generate bootstrap predictions on test set 201 | set.seed(867) 202 | ames_boot_conf_int <- 203 | ames_wf %>% 204 | predict_boots( 205 | n = 2000, 206 | training_data = ames_train, 207 | new_data = ames_test, 208 | interval = "confidence" 209 | ) 210 | ``` 211 | 212 | ```{r, echo=FALSE} 213 | # load data from workboots_support (avoid re-fitting on knit) 214 | ames_boot_conf_int <- readr::read_rds("https://github.com/markjrieke/workboots_support/blob/main/data/ames_boot_conf_int.rds?raw=true") 215 | 216 | ames_lm_conf_int <- 217 | ames_lm %>% 218 | predict(ames_test, interval = "confidence") %>% 219 | as_tibble() 220 | ``` 221 | 222 | Again, by overlaying the intervals on the same plot, we can see that the confidence interval generated by `predict_boots()` is a good approximation of the theoretical interval. 
223 | 224 | ```{r} 225 | ames_boot_conf_int %>% 226 | summarise_predictions() %>% 227 | 228 | # rescale predictions to match original dataset's scale 229 | bind_cols(ames_lm_conf_int) %>% 230 | bind_cols(ames_test) %>% 231 | mutate(across(.pred:Sale_Price, ~10^.x)) %>% 232 | 233 | # plot! 234 | ggplot(aes(x = First_Flr_SF)) + 235 | geom_point(aes(y = Sale_Price), 236 | alpha = 0.25) + 237 | scale_x_log10(labels = scales::comma_format()) + 238 | scale_y_log10(labels = scales::label_number(scale_cut = scales::cut_short_scale())) + 239 | 240 | # add prediction interval created by lm() 241 | geom_line(aes(y = fit), 242 | size = 1) + 243 | geom_ribbon(aes(ymin = lwr, 244 | ymax = upr), 245 | alpha = 0.25) + 246 | 247 | # add prediction interval created by workboots 248 | geom_point(aes(y = .pred), 249 | color = "blue", 250 | alpha = 0.25) + 251 | geom_ribbon(aes(ymin = .pred_lower, 252 | ymax = .pred_upper), 253 | fill = "blue", 254 | alpha = 0.25) 255 | ``` 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /vignettes/Getting-Started-with-workboots.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Getting Started with workboots" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Getting Started with workboots} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, include = FALSE} 11 | knitr::opts_chunk$set( 12 | collapse = TRUE, 13 | comment = "#>", 14 | warning = FALSE, 15 | message = FALSE 16 | ) 17 | 18 | ggplot2::theme_set( 19 | ggplot2::theme_minimal() + 20 | ggplot2::theme(plot.title.position = "plot", 21 | plot.background = ggplot2::element_rect(fill = "white", color = "white")) 22 | ) 23 | ``` 24 | 25 | Sometimes, we want a model that generates a range of possible outcomes around each prediction. Other times, we just care about point predictions and may be able to use a powerful model. 
workboots allows us to get the best of both worlds --- getting a range of predictions while still using a powerful model!
 26 | 
 27 | In this vignette, we'll walk through the entire process of building a boosted tree model to predict the range of possible car prices from the [modeldata::car_prices](https://modeldata.tidymodels.org/reference/car_prices.html) dataset. Prior to estimating ranges with workboots, we'll need to build and tune a workflow. This vignette will walk through several steps:
 28 | 
 29 | 1. Building a baseline model with default parameters.
 30 | 2. Tuning and finalizing model parameters.
 31 | 3. Predicting price ranges with a tuned workflow.
 32 | 4. Estimating variable importance ranges with a tuned workflow.
 33 | 
 34 | ## Building a baseline model
 35 | 
 36 | What's included in the `car_prices` dataset?
 37 | 
 38 | ```{r}
 39 | library(tidymodels)
 40 | 
 41 | # setup data
 42 | data("car_prices")
 43 | car_prices
 44 | ```
 45 | 
 46 | The `car_prices` dataset is already well set-up for modeling --- we'll apply a bit of light preprocessing before training a boosted tree model to predict the price.
 47 | 
 48 | ```{r}
 49 | # apply global transformations
 50 | car_prices <- 
 51 |   car_prices %>% 
 52 |   mutate(Price = log10(Price),
 53 |          Cylinder = as.character(Cylinder),
 54 |          Doors = as.character(Doors))
 55 | 
 56 | # split into testing and training
 57 | set.seed(999)
 58 | car_split <- initial_split(car_prices)
 59 | car_train <- training(car_split)
 60 | car_test <- testing(car_split)
 61 | ```
 62 | 
 63 | We'll save the test data until the very end and use a validation split to evaluate our first model.
 64 | 
 65 | ```{r}
 66 | set.seed(888)
 67 | car_val_split <- initial_split(car_train)
 68 | car_val_train <- training(car_val_split)
 69 | car_val_test <- testing(car_val_split)
 70 | ```
 71 | 
 72 | How does an XGBoost model with default parameters perform on this dataset?
73 | 74 | ```{r} 75 | car_val_rec <- 76 | recipe(Price ~ ., data = car_val_train) %>% 77 | step_BoxCox(Mileage) %>% 78 | step_dummy(all_nominal()) 79 | 80 | # fit and predict on our validation set 81 | set.seed(777) 82 | car_val_preds <- 83 | workflow() %>% 84 | add_recipe(car_val_rec) %>% 85 | add_model(boost_tree("regression", engine = "xgboost")) %>% 86 | fit(car_val_train) %>% 87 | predict(car_val_test) %>% 88 | bind_cols(car_val_test) 89 | 90 | car_val_preds %>% 91 | rmse(truth = Price, estimate = .pred) 92 | ``` 93 | 94 | We can also plot our predictions against the actual prices to see how the baseline model performs. 95 | 96 | ```{r} 97 | car_val_preds %>% 98 | ggplot(aes(x = Price, y = .pred)) + 99 | geom_point(size = 2, alpha = 0.25) + 100 | geom_abline(linetype = "dashed") 101 | ``` 102 | 103 | We can extract a bit of extra performance by tuning the model parameters --- this is also needed if we want to stray from the default parameters when predicting ranges with the workboots package. 104 | 105 | ## Tuning model parameters 106 | 107 | Boosted tree models have a lot of available tuning parameters --- given our relatively small dataset, we'll just focus on the `mtry` and `trees` parameters. 
108 | 109 | ```{r, eval=FALSE} 110 | # re-setup recipe with training dataset 111 | car_rec <- 112 | recipe(Price ~ ., data = car_train) %>% 113 | step_BoxCox(Mileage) %>% 114 | step_dummy(all_nominal()) 115 | 116 | # setup model spec 117 | car_spec <- 118 | boost_tree( 119 | mode = "regression", 120 | engine = "xgboost", 121 | mtry = tune(), 122 | trees = tune() 123 | ) 124 | 125 | # combine into workflow 126 | car_wf <- 127 | workflow() %>% 128 | add_recipe(car_rec) %>% 129 | add_model(car_spec) 130 | 131 | # setup cross-validation folds 132 | set.seed(666) 133 | car_folds <- vfold_cv(car_train) 134 | 135 | # tune model 136 | set.seed(555) 137 | car_tune <- 138 | tune_grid( 139 | car_wf, 140 | car_folds, 141 | grid = 5 142 | ) 143 | ``` 144 | 145 | Tuning gives us *slightly* better performance than the baseline model: 146 | 147 | ```{r, echo=FALSE} 148 | # re-setup recipe with training dataset 149 | car_rec <- 150 | recipe(Price ~ ., data = car_train) %>% 151 | step_BoxCox(Mileage) %>% 152 | step_dummy(all_nominal()) 153 | 154 | # setup model spec 155 | car_spec <- 156 | boost_tree( 157 | mode = "regression", 158 | engine = "xgboost", 159 | mtry = tune(), 160 | trees = tune() 161 | ) 162 | 163 | # combine into workflow 164 | car_wf <- 165 | workflow() %>% 166 | add_recipe(car_rec) %>% 167 | add_model(car_spec) 168 | 169 | car_tune <- readr::read_rds("https://github.com/markjrieke/workboots_support/raw/main/data/car_tune.rds") 170 | ``` 171 | 172 | ```{r} 173 | car_tune %>% 174 | show_best("rmse") 175 | ``` 176 | 177 | Now we can finalize the workflow with the best tuning parameters. With this finalized workflow, we can start predicting intervals with workboots! 
178 | 179 | ```{r} 180 | car_wf_final <- 181 | car_wf %>% 182 | finalize_workflow(car_tune %>% select_best("rmse")) 183 | 184 | car_wf_final 185 | ``` 186 | 187 | ## Predicting price ranges 188 | 189 | To generate a prediction interval for each car's price, we can pass the finalized workflow to `predict_boots()`. 190 | 191 | ```{r, eval=FALSE} 192 | library(workboots) 193 | 194 | set.seed(444) 195 | car_preds <- 196 | car_wf_final %>% 197 | predict_boots( 198 | n = 2000, 199 | training_data = car_train, 200 | new_data = car_test 201 | ) 202 | ``` 203 | 204 | ```{r, echo=FALSE} 205 | library(workboots) 206 | 207 | car_preds <- readr::read_rds("https://github.com/markjrieke/workboots_support/raw/main/data/car_preds.rds") 208 | ``` 209 | 210 | We can summarize the predictions with upper and lower bounds of a prediction interval by passing `car_preds` to `summarise_predictions()`. 211 | 212 | ```{r} 213 | car_preds %>% 214 | summarise_predictions() 215 | ``` 216 | 217 | How do our predictions compare against the actual values? 218 | 219 | ```{r} 220 | car_preds %>% 221 | summarise_predictions() %>% 222 | bind_cols(car_test) %>% 223 | ggplot(aes(x = Price, 224 | y = .pred, 225 | ymin = .pred_lower, 226 | ymax = .pred_upper)) + 227 | geom_point(size = 2, 228 | alpha = 0.25) + 229 | geom_errorbar(alpha = 0.25, 230 | width = 0.0125) + 231 | geom_abline(linetype = "dashed", 232 | color = "gray") 233 | ``` 234 | 235 | ## Estimating variable importance 236 | 237 | With workboots, we can also estimate variable importance by passing the finalized workflow to `vi_boots()`. This uses [`vip::vi()`](https://koalaverse.github.io/vip/reference/vi.html) under the hood, which doesn't support all the model types that are available in tidymodels --- please refer to [vip's package documentation](https://koalaverse.github.io/vip/articles/vip.html) for a full list of supported models. 
238 | 
239 | ```{r, eval=FALSE}
240 | set.seed(333)
241 | car_importance <- 
242 |   car_wf_final %>% 
243 |   vi_boots(
244 |     n = 2000,
245 |     training_data = car_train
246 |   )
247 | ```
248 | 
249 | ```{r, echo=FALSE}
250 | car_importance <- readr::read_rds("https://github.com/markjrieke/workboots_support/raw/main/data/car_importance.rds")
251 | ```
252 | 
253 | Similar to predictions, we can summarise each variable's importance by passing `car_importance` to the function `summarise_importance()` and plot the results.
254 | 
255 | ```{r}
256 | car_importance %>% 
257 |   summarise_importance() %>% 
258 |   mutate(variable = forcats::fct_reorder(variable, .importance)) %>% 
259 |   ggplot(aes(x = variable,
260 |              y = .importance,
261 |              ymin = .importance_lower,
262 |              ymax = .importance_upper)) +
263 |   geom_point(size = 2) +
264 |   geom_errorbar() +
265 |   coord_flip()
266 | ```
267 | 
268 | 
--------------------------------------------------------------------------------
/vignettes/The-Math-Behind-workboots.Rmd:
--------------------------------------------------------------------------------
  1 | ---
  2 | title: "The Math Behind workboots"
  3 | output: 
  4 |   html_document:
  5 |     code_folding: hide
  6 | vignette: >
  7 |   %\VignetteIndexEntry{The Math Behind workboots}
  8 |   %\VignetteEngine{knitr::rmarkdown}
  9 |   %\VignetteEncoding{UTF-8}
 10 | ---
 11 | 
 12 | ```{r, include = FALSE}
 13 | knitr::opts_chunk$set(
 14 |   collapse = TRUE,
 15 |   comment = "#>",
 16 |   warning = FALSE,
 17 |   message = FALSE
 18 | )
 19 | 
 20 | ggplot2::theme_set(
 21 |   ggplot2::theme_minimal() + 
 22 |     ggplot2::theme(plot.title.position = "plot",
 23 |                    plot.background = ggplot2::element_rect(fill = "white", color = "white"))
 24 | )
 25 | ```
 26 | 
 27 | Generating prediction intervals with workboots hinges on a few core concepts: bootstrap resampling, estimating prediction error for each resample, and aggregating the resampled prediction errors for each observation.
The [`bootstraps()` documentation from {rsample}](https://rsample.tidymodels.org/reference/bootstraps.html) gives a concise definition of bootstrap resampling: 28 | 29 | > A bootstrap sample is a sample that is the same size as the original data set that is made using replacement. This results in analysis samples that have multiple replicates of some of the original rows of the data. The assessment set is defined as the rows of the original data that were not included in the bootstrap sample. This is often referred to as the "out-of-bag" (OOB) sample. 30 | 31 | This vignette will walk through the details of estimating and aggregating prediction errors --- additional resources can be found in Davison and Hinkley's book, [*Bootstrap Methods and their Application*](https://www.cambridge.org/core/books/bootstrap-methods-and-their-application/ED2FD043579F27952363566DC09CBD6A), or Efron and Tibshirani's paper, *Improvements on Cross-Validation: The Bootstrap .632+ Method* (available on JSTOR). 32 | 33 | ### The Bootstrap .632+ Method 34 | 35 | *What follows here is largely a summary of [this explanation](https://stats.stackexchange.com/questions/96739/what-is-the-632-rule-in-bootstrapping/96750#96750) of the .632+ error rate by Benjamin Deonovic.* 36 | 37 | When working with bootstrap resamples of a dataset, there are two error estimates we can work with: the bootstrap training error and the out-of-bag (oob) error. Using the [Sacramento housing dataset](https://modeldata.tidymodels.org/reference/Sacramento.html), we can estimate the training and oob error for a single bootstrap. 
38 | 39 | ```{r, echo=FALSE} 40 | library(tidymodels) 41 | 42 | # setup our data 43 | data("Sacramento") 44 | Sacramento <- 45 | Sacramento %>% 46 | select(sqft, type, price) %>% 47 | mutate(across(c(sqft, price), log10)) 48 | 49 | set.seed(987) 50 | sacramento_split <- initial_split(Sacramento) 51 | sacramento_train <- training(sacramento_split) 52 | sacramento_test <- testing(sacramento_split) 53 | 54 | # setup bootstrapped dataset for .632+ example 55 | sacramento_boots <- bootstraps(sacramento_train, times = 1) 56 | ``` 57 | ```{r class.source = 'fold-show'} 58 | sacramento_boots 59 | ``` 60 | 61 | Using a [k-nearest-neighbor regression model](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm#k-NN_regression) and [rmse](https://en.wikipedia.org/wiki/Root-mean-square_deviation#:~:text=The%20root%2Dmean%2Dsquare%20deviation,estimator%20and%20the%20values%20observed.) as our error metric, we find that the training and oob error differ, with the training error lesser than the oob error. 
62 | 63 | ```{r, echo=FALSE} 64 | # setup a workflow to predict price using a knn regressor 65 | sacramento_recipe <- 66 | recipe(price ~ ., data = sacramento_train) %>% 67 | step_dummy(all_nominal()) 68 | 69 | sacramento_spec <- 70 | nearest_neighbor() %>% 71 | set_mode("regression") 72 | 73 | sacramento_wf <- 74 | workflow() %>% 75 | add_recipe(sacramento_recipe) %>% 76 | add_model(sacramento_spec) 77 | 78 | set.seed(876) 79 | sacramento_fit <- 80 | sacramento_wf %>% 81 | fit(training(sacramento_boots$splits[[1]])) 82 | 83 | # get bootstrap training error 84 | sacramento_train_err <- 85 | Metrics::rmse( 86 | training(sacramento_boots$splits[[1]])$price, 87 | sacramento_fit %>% predict(training(sacramento_boots$splits[[1]])) %>% pull(.pred) 88 | ) 89 | 90 | # get oob error 91 | sacramento_oob_err <- 92 | Metrics::rmse( 93 | testing(sacramento_boots$splits[[1]])$price, 94 | sacramento_fit %>% predict(testing(sacramento_boots$splits[[1]])) %>% pull(.pred) 95 | ) 96 | ``` 97 | 98 | ```{r, class.source = 'fold-show'} 99 | sacramento_train_err 100 | sacramento_oob_err 101 | ``` 102 | 103 | The training error is overly optimistic in the model's performance and likely to under-estimate the prediction error. We are interested in the model's performance on new data. The oob error, on the other hand, is likely to over-estimate the prediction error! This is due to non-distinct observations in the bootstrap sample that results from sampling with replacement. 
Given that [the average number of distinct observations in a bootstrap training set is about `0.632 * total_observations`](https://stats.stackexchange.com/questions/88980/why-on-average-does-each-bootstrap-sample-contain-roughly-two-thirds-of-observat?lq=1), Efron and Tibshirani proposed a blend of the training and oob error with the 0.632 estimate:
104 | 
105 | \begin{align*}
106 | Err_{.632} & = 0.368 Err_{train} + 0.632 Err_{oob}
107 | \end{align*}
108 | 
109 | ```{r, class.source = 'fold-show'}
110 | sacramento_632 <- 0.368 * sacramento_train_err + 0.632 * sacramento_oob_err
111 | sacramento_632
112 | ```
113 | 
114 | If, however, the model is highly overfit to the bootstrap training set, the training error will approach 0 and the 0.632 estimate will *under estimate* the prediction error.
115 | 
116 | An example from [*Applied Predictive Modeling*](http://appliedpredictivemodeling.com/) shows that as model complexity increases, the reported resample accuracy by the 0.632 estimate continues to increase whereas other resampling strategies report diminishing returns:
117 | 
118 | ![](https://user-images.githubusercontent.com/5731043/157986232-9c32c1c2-a7ed-4f9f-b28e-7d8ccb7ac41c.png)
119 | 
120 | As an alternative to the 0.632 estimate, Efron & Tibshirani also propose the 0.632+ estimate, which re-weights the blend of training and oob error based on the model overfit rate:
121 | 
122 | \begin{align*}
123 | Err_{0.632+} & = (1 - w) Err_{train} + w Err_{oob} \\
124 | \\
125 | w & = \frac{0.632}{1 - 0.368 R} \\
126 | \\
127 | R & = \frac{Err_{oob} - Err_{train}}{\gamma - Err_{train}}
128 | \end{align*}
129 | 
130 | Here, $R$ represents the overfit rate and $\gamma$ is the no-information error rate, estimated by evaluating all combinations of predictions and actual values in the bootstrap training set.
131 | 132 | ```{r, echo=FALSE} 133 | # estimate the no-information error rate 134 | preds_train <- predict(sacramento_fit, training(sacramento_boots$splits[[1]])) %>% pull(.pred) 135 | actuals_train <- training(sacramento_boots$splits[[1]]) %>% pull(price) 136 | all_combinations <- crossing(actuals_train, preds_train) 137 | rmse_ni <- Metrics::rmse(all_combinations$actuals_train, all_combinations$preds_train) 138 | 139 | # estimate the overfit rate 140 | overfit <- (sacramento_oob_err - sacramento_train_err)/(rmse_ni - sacramento_train_err) 141 | 142 | # estimate weight 143 | w <- 0.632 / (1 - 0.368 * overfit) 144 | ``` 145 | 146 | ```{r, class.source = 'fold-show'} 147 | sacramento_632_plus <- (1 - w) * sacramento_train_err + w * sacramento_oob_err 148 | sacramento_632_plus 149 | ``` 150 | 151 | When there is no overfitting (i.e., $R = 0$) the 0.632+ estimate will equal the 0.632 estimate. In this case, however, the model is overfitting the training set and the 0.632+ error estimate is pushed a bit closer to the oob error. 152 | 153 | ### Prediction intervals with many bootstraps 154 | 155 | [For an unbiased estimator, rmse is the standard deviation of the residuals](https://en.wikipedia.org/wiki/Root-mean-square_deviation#Formula). With this in mind, we can modify our predictions to include a sample from the residual distribution (for more information, see Algorithm 6.4 from Davison and Hinkley's *Bootstrap Methods and their Application*): 156 | 157 | ```{r, class.source = 'fold-show'} 158 | set.seed(999) 159 | resid_train_add <- rnorm(length(preds_train), 0, sacramento_632_plus) 160 | 161 | preds_train_mod <- preds_train + resid_train_add 162 | ``` 163 | 164 | Thus far, we've been working with a single bootstrap resample. 
When working with a single bootstrap resample, adding this residual term gives a pretty poor estimate for each observation: 165 | 166 | ```{r, echo=FALSE} 167 | library(ggplot2) 168 | 169 | tibble(.pred = preds_train_mod) %>% 170 | bind_cols(training(sacramento_boots$splits[[1]])) %>% 171 | mutate(across(c(.pred, price), ~10^.x)) %>% 172 | ggplot(aes(x = .pred, y = price)) + 173 | geom_point(alpha = 0.25, 174 | size = 2.5, 175 | color = "midnightblue") + 176 | geom_abline(linetype = "dashed", 177 | size = 1, 178 | color = "gray") + 179 | labs(title = "Predicted sale price of home in Sacramento", 180 | subtitle = "Adding a single error estimate produces poor predictions of price", 181 | x = "Predicted price", 182 | y = "Actual price") + 183 | scale_x_log10(labels = scales::label_dollar(scale_cut = cut_short_scale())) + 184 | scale_y_log10(labels = scales::label_dollar(scale_cut = cut_short_scale())) 185 | 186 | ``` 187 | 188 | With workboots, however, we can repeat this process over many bootstrap datasets to generate a prediction distribution for each observation: 189 | 190 | ```{r, class.source = 'fold-show'} 191 | library(workboots) 192 | 193 | # fit and predict price in sacramento_test from 100 models 194 | # the default number of resamples is 2000 - dropping here to speed up knitting 195 | set.seed(555) 196 | sacramento_pred_int <- 197 | sacramento_wf %>% 198 | predict_boots( 199 | n = 100, 200 | training_data = sacramento_train, 201 | new_data = sacramento_test 202 | ) 203 | ``` 204 | 205 | ```{r, echo=FALSE} 206 | sacramento_pred_int %>% 207 | summarise_predictions() %>% 208 | bind_cols(sacramento_test) %>% 209 | mutate(across(c(.pred:.pred_upper, price), ~ 10^.x)) %>% 210 | ggplot(aes(x = .pred, 211 | y = price, 212 | ymin = .pred_lower, 213 | ymax = .pred_upper)) + 214 | geom_point(alpha = 0.25, 215 | size = 2.5, 216 | color = "midnightblue") + 217 | geom_errorbar(alpha = 0.25, 218 | color = "midnightblue", 219 | width = 0.0125) + 220 | scale_x_log10(labels 
= scales::label_dollar(scale_cut = cut_short_scale())) + 221 | scale_y_log10(labels = scales::label_dollar(scale_cut = cut_short_scale())) + 222 | geom_abline(linetype = "dashed", 223 | size = 1, 224 | color = "gray") + 225 | labs(title = "Predicted sale price of home in Sacramento", 226 | subtitle = "Using many resamples allows us to generate prediction intervals", 227 | x = "Predicted price", 228 | y = "Actual price") 229 | ``` 230 | 231 | This methodology produces prediction distributions that are [consistent with what we might expect from linear models](https://markjrieke.github.io/workboots/articles/Estimating-Linear-Intervals.html) while making no assumptions about model type (i.e., we can use a non-parametric model; in this case, a k-nearest neighbors regression). 232 | -------------------------------------------------------------------------------- /workboots.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageBuildArgs: --compact-vignettes=qpdf 22 | PackageCheckArgs: --compact-vignettes=qpdf 23 | --------------------------------------------------------------------------------