├── .Rbuildignore ├── .github ├── .gitignore └── workflows │ ├── R-CMD-check.yaml │ ├── pkgdown.yaml │ └── test-coverage.yaml ├── .gitignore ├── CRAN-SUBMISSION ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── aa_deprecated.R ├── flashlight.R ├── is_flashlight.R ├── light_breakdown.R ├── light_check.R ├── light_combine.R ├── light_effects.R ├── light_global_surrogate.R ├── light_ice.R ├── light_importance.R ├── light_interaction.R ├── light_performance.R ├── light_profile.R ├── light_profile2d.R ├── light_recode.R ├── light_scatter.R ├── methods.R ├── multiflashlight.R ├── utils.R ├── utils_cut.R ├── utils_grouped.R ├── utils_plot.R └── zzz.R ├── README.md ├── cran-comments.md ├── flashlight.Rproj ├── logo.png ├── man ├── add_shap.Rd ├── figures │ ├── ale.svg │ ├── breakdown.svg │ ├── cice.svg │ ├── effects.svg │ ├── ice.svg │ ├── imp.svg │ ├── logo.png │ ├── pdp.svg │ ├── pdp2d.svg │ ├── pdp_grouped.svg │ ├── pdp_grouped_multi.svg │ ├── perf.svg │ ├── perf_grouped.svg │ ├── perf_grouped_multi.svg │ └── surrogate.svg ├── flashlight.Rd ├── is.flashlight.Rd ├── light_breakdown.Rd ├── light_check.Rd ├── light_combine.Rd ├── light_effects.Rd ├── light_global_surrogate.Rd ├── light_ice.Rd ├── light_importance.Rd ├── light_interaction.Rd ├── light_performance.Rd ├── light_profile.Rd ├── light_profile2d.Rd ├── light_recode.Rd ├── light_scatter.Rd ├── most_important.Rd ├── multiflashlight.Rd ├── plot.light_breakdown.Rd ├── plot.light_effects.Rd ├── plot.light_global_surrogate.Rd ├── plot.light_ice.Rd ├── plot.light_importance.Rd ├── plot.light_performance.Rd ├── plot.light_profile.Rd ├── plot.light_profile2d.Rd ├── plot.light_scatter.Rd ├── plot_counts.Rd ├── predict.flashlight.Rd ├── predict.multiflashlight.Rd ├── print.flashlight.Rd ├── print.light.Rd ├── print.multiflashlight.Rd ├── residuals.flashlight.Rd ├── residuals.multiflashlight.Rd └── response.Rd ├── packaging.R ├── revdep ├── .gitignore ├── README.md ├── cran.md ├── email.yml ├── failures.md 
└── problems.md ├── tests ├── testthat.R └── testthat │ ├── tests-breakdown.R │ ├── tests-cut.R │ ├── tests-eff.R │ ├── tests-globaltree.R │ ├── tests-grouped.R │ ├── tests-ice.R │ ├── tests-importance.R │ ├── tests-interaction.R │ ├── tests-methods.R │ ├── tests-perf.R │ └── tests-profile2d.R └── vignettes ├── .gitignore ├── biblio.bib └── flashlight.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^LICENSE\.md$ 2 | ^packaging.R$ 3 | [.]Rproj$ 4 | ^backlog$ 5 | ^cran-comments.md$ 6 | ^logo.png$ 7 | ^.*\.Rproj$ 8 | ^\.Rproj\.user$ 9 | ^cran-comments\.md$ 10 | ^doc$ 11 | ^Meta$ 12 | ^\.github$ 13 | ^revdep$ 14 | ^CRAN-SUBMISSION$ 15 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | 9 | name: R-CMD-check 10 | 11 | jobs: 12 | R-CMD-check: 13 | runs-on: ${{ matrix.config.os }} 14 | 15 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | config: 21 | - {os: macos-latest, r: 'release'} 22 | - {os: windows-latest, r: 'release'} 23 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 24 | - {os: ubuntu-latest, r: 'release'} 25 | - {os: ubuntu-latest, r: 'oldrel-1'} 26 | 27 | env: 28 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 29 | R_KEEP_PKG_SOURCE: yes 30 | 31 | steps: 32 | - uses: actions/checkout@v3 33 | 34 | - uses: r-lib/actions/setup-pandoc@v2 35 | 36 | - uses: r-lib/actions/setup-r@v2 37 | with: 38 | r-version: ${{ matrix.config.r }} 39 | http-user-agent: ${{ matrix.config.http-user-agent }} 40 | use-public-rspm: true 41 | 42 | - uses: r-lib/actions/setup-r-dependencies@v2 43 | with: 44 | extra-packages: any::rcmdcheck 45 | needs: check 46 | 47 | - uses: r-lib/actions/check-r-package@v2 48 | with: 49 | upload-snapshots: true 50 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown 13 | 14 | jobs: 15 | pkgdown: 16 | runs-on: ubuntu-latest 17 | # Only restrict concurrency for non-PR jobs 18 | concurrency: 19 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 20 | env: 21 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 22 | permissions: 23 | contents: write 24 | steps: 25 | - uses: actions/checkout@v3 26 | 27 | - uses: r-lib/actions/setup-pandoc@v2 28 | 29 | - uses: r-lib/actions/setup-r@v2 30 | with: 31 | use-public-rspm: true 32 | 33 | - uses: r-lib/actions/setup-r-dependencies@v2 34 | with: 35 | extra-packages: any::pkgdown, local::. 36 | needs: website 37 | 38 | - name: Build site 39 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 40 | shell: Rscript {0} 41 | 42 | - name: Deploy to GitHub pages 🚀 43 | if: github.event_name != 'pull_request' 44 | uses: JamesIves/github-pages-deploy-action@v4.4.1 45 | with: 46 | clean: false 47 | branch: gh-pages 48 | folder: docs 49 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | 8 | name: test-coverage.yaml 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | test-coverage: 14 | runs-on: ubuntu-latest 15 | env: 16 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | with: 23 | use-public-rspm: true 24 | 25 | - uses: r-lib/actions/setup-r-dependencies@v2 26 | with: 27 | extra-packages: any::covr, any::xml2 28 | needs: coverage 29 | 30 | - name: Test coverage 31 | run: | 32 | cov <- covr::package_coverage( 33 | quiet = FALSE, 34 | clean = FALSE, 35 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 36 | ) 37 | print(cov) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v5 42 | with: 43 | # Fail if error if not on PR, or if on PR and token is given 44 | fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} 45 | files: ./cobertura.xml 46 | plugins: noop 47 | disable_search: true 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | 50 | - name: Show testthat output 51 | if: always() 52 | run: | 53 | ## -------------------------------------------------------------------- 54 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 55 | shell: bash 56 | 57 | - name: Upload test results 58 | if: failure() 59 | uses: actions/upload-artifact@v4 60 | with: 61 | name: coverage-test-failures 62 | path: ${{ runner.temp }}/package 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | doc 6 | Meta 7 | inst/doc 8 | /doc/ 9 | /Meta/ 10 | -------------------------------------------------------------------------------- 
/CRAN-SUBMISSION: -------------------------------------------------------------------------------- 1 | Version: 0.9.0 2 | Date: 2023-05-09 19:40:00 UTC 3 | SHA: a3279404b0fa89f5fb1256cc2a370ed92c3a5716 4 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: flashlight 2 | Title: Shed Light on Black Box Machine Learning Models 3 | Version: 0.9.0.9000 4 | Authors@R: 5 | person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre", "cph")) 6 | Description: Shed light on black box machine learning models by the help 7 | of model performance, variable importance, global surrogate models, 8 | ICE profiles, partial dependence (Friedman J. H. (2001) 9 | <doi:10.1214/aos/1013203451>), accumulated local effects (Apley D. W. 10 | (2016) <arXiv:1612.08468>), further effects plots, interaction 11 | strength, and variable contribution breakdown (Gosiewska and Biecek 12 | (2019) <arxiv:1903.11420>). All tools are implemented to work with 13 | case weights and allow for stratified analysis. Furthermore, multiple 14 | flashlights can be combined and analyzed together. 
15 | License: GPL (>= 2) 16 | Depends: 17 | R (>= 3.2.0) 18 | Encoding: UTF-8 19 | Roxygen: list(markdown = TRUE) 20 | RoxygenNote: 7.3.1 21 | Imports: 22 | dplyr (>= 1.1.0), 23 | ggplot2, 24 | MetricsWeighted (>= 0.3.0), 25 | rlang (>= 0.3.0), 26 | rpart, 27 | rpart.plot, 28 | stats, 29 | tibble, 30 | tidyr (>= 1.0.0), 31 | tidyselect, 32 | utils 33 | URL: https://github.com/mayer79/flashlight 34 | BugReports: https://github.com/mayer79/flashlight/issues 35 | Suggests: 36 | knitr, 37 | rmarkdown, 38 | testthat (>= 3.0.0) 39 | VignetteBuilder: knitr 40 | Config/testthat/edition: 3 41 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(flashlight,default) 4 | S3method(flashlight,flashlight) 5 | S3method(light_breakdown,default) 6 | S3method(light_breakdown,flashlight) 7 | S3method(light_breakdown,multiflashlight) 8 | S3method(light_check,default) 9 | S3method(light_check,flashlight) 10 | S3method(light_check,multiflashlight) 11 | S3method(light_combine,default) 12 | S3method(light_combine,light) 13 | S3method(light_combine,list) 14 | S3method(light_effects,default) 15 | S3method(light_effects,flashlight) 16 | S3method(light_effects,multiflashlight) 17 | S3method(light_global_surrogate,default) 18 | S3method(light_global_surrogate,flashlight) 19 | S3method(light_global_surrogate,multiflashlight) 20 | S3method(light_ice,default) 21 | S3method(light_ice,flashlight) 22 | S3method(light_ice,multiflashlight) 23 | S3method(light_importance,default) 24 | S3method(light_importance,flashlight) 25 | S3method(light_importance,multiflashlight) 26 | S3method(light_interaction,default) 27 | S3method(light_interaction,flashlight) 28 | S3method(light_interaction,multiflashlight) 29 | S3method(light_performance,default) 30 | S3method(light_performance,flashlight) 31 | 
S3method(light_performance,multiflashlight) 32 | S3method(light_profile,default) 33 | S3method(light_profile,flashlight) 34 | S3method(light_profile,multiflashlight) 35 | S3method(light_profile2d,default) 36 | S3method(light_profile2d,flashlight) 37 | S3method(light_profile2d,multiflashlight) 38 | S3method(light_scatter,default) 39 | S3method(light_scatter,flashlight) 40 | S3method(light_scatter,multiflashlight) 41 | S3method(multiflashlight,default) 42 | S3method(multiflashlight,flashlight) 43 | S3method(multiflashlight,list) 44 | S3method(multiflashlight,multiflashlight) 45 | S3method(plot,light_breakdown) 46 | S3method(plot,light_effects) 47 | S3method(plot,light_global_surrogate) 48 | S3method(plot,light_ice) 49 | S3method(plot,light_importance) 50 | S3method(plot,light_performance) 51 | S3method(plot,light_profile) 52 | S3method(plot,light_profile2d) 53 | S3method(plot,light_scatter) 54 | S3method(predict,flashlight) 55 | S3method(predict,multiflashlight) 56 | S3method(print,flashlight) 57 | S3method(print,light) 58 | S3method(print,multiflashlight) 59 | S3method(residuals,flashlight) 60 | S3method(residuals,multiflashlight) 61 | S3method(response,default) 62 | S3method(response,flashlight) 63 | S3method(response,multiflashlight) 64 | export(add_shap) 65 | export(flashlight) 66 | export(is.flashlight) 67 | export(is.light) 68 | export(is.light_breakdown) 69 | export(is.light_breakdown_multi) 70 | export(is.light_effects) 71 | export(is.light_effects_multi) 72 | export(is.light_global_surrogate) 73 | export(is.light_global_surrogate_multi) 74 | export(is.light_ice) 75 | export(is.light_ice_multi) 76 | export(is.light_importance) 77 | export(is.light_importance_multi) 78 | export(is.light_performance) 79 | export(is.light_performance_multi) 80 | export(is.light_profile) 81 | export(is.light_profile2d) 82 | export(is.light_profile2d_multi) 83 | export(is.light_profile_multi) 84 | export(is.light_scatter) 85 | export(is.light_scatter_multi) 86 | 
export(is.multiflashlight) 87 | export(is.shap) 88 | export(light_breakdown) 89 | export(light_check) 90 | export(light_combine) 91 | export(light_effects) 92 | export(light_global_surrogate) 93 | export(light_ice) 94 | export(light_importance) 95 | export(light_interaction) 96 | export(light_performance) 97 | export(light_profile) 98 | export(light_profile2d) 99 | export(light_recode) 100 | export(light_scatter) 101 | export(most_important) 102 | export(multiflashlight) 103 | export(plot_counts) 104 | export(response) 105 | importFrom(rlang,.data) 106 | -------------------------------------------------------------------------------- /R/aa_deprecated.R: -------------------------------------------------------------------------------- 1 | #' DEPRECATED 2 | #' 3 | #' Deprecated in favor of {kernelshap}/{fastshap}. 4 | #' 5 | #' @export 6 | #' @param ... Deprecated 7 | #' @returns Error message. 8 | add_shap <- function(...) { 9 | stop("Deprecated in favor of {kernelshap} or {fastshap}.") 10 | } 11 | 12 | #' DEPRECATED 13 | #' 14 | #' @param ... Any input. 15 | #' @returns Error message. 16 | #' @export 17 | plot_counts <- function(...) { 18 | stop("'plot_counts()' has been deprecated.") 19 | } 20 | 21 | #' DEPRECATED 22 | #' 23 | #' @param ... Any input. 24 | #' @returns Error message. 25 | #' @export 26 | light_recode <- function(...) { 27 | stop("'light_recode()' is deprecated.") 28 | } 29 | -------------------------------------------------------------------------------- /R/flashlight.R: -------------------------------------------------------------------------------- 1 | #' Create or Update a flashlight 2 | #' 3 | #' Creates or updates a "flashlight" object. If a flashlight is to be created, 4 | #' all arguments are optional except `label`. If a flashlight is to be updated, 5 | #' all arguments are optional except `x` (the flashlight to be updated). 6 | #' 7 | #' @param x An object of class "flashlight". 
If not provided, a new flashlight is 8 | #' created based on further input. Otherwise, `x` is updated based on further input. 9 | #' @param model A fitted model of any type. Most models require a customized 10 | #' `predict_function`. 11 | #' @param data A `data.frame` or `tibble` used as basis for calculations. 12 | #' @param y Variable name of response. 13 | #' @param predict_function A real-valued function with two arguments: 14 | #' a model and a data set with the same structure as `data`. 15 | #' Only the order of the two arguments matters, not their names. 16 | #' @param linkinv An inverse transformation function applied after `predict_function`. 17 | #' @param w A variable name of case weights. 18 | #' @param by A character vector with names of grouping variables. 19 | #' @param metrics A named list of metrics. Here, a metric is a function with exactly 20 | #' four arguments: actual, predicted, w (case weights) and `...` 21 | #' like those in package {MetricsWeighted}. 22 | #' @param label Name of the flashlight. Required. 23 | #' @param shap An optional shap object. Previously added by calling [add_shap()], which is now deprecated. 24 | #' @param check When updating the flashlight: Should internal checks be performed? 25 | #' Default is `TRUE`. 26 | #' @param ... Arguments passed from or to other functions. 27 | #' @returns An object of class "flashlight" (and `list`) containing each 28 | #' input (except `x`) as element. 29 | #' @export 30 | #' @examples 31 | #' fit <- lm(Sepal.Length ~ ., data = iris) 32 | #' (fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")) 33 | #' (fl_updated <- flashlight(fl, linkinv = exp)) 34 | #' @seealso [multiflashlight()] 35 | flashlight <- function(x, ...) { 36 | UseMethod("flashlight") 37 | } 38 | 39 | #' @describeIn flashlight Used to create a flashlight object. 40 | #' No `x` has to be passed in this case. 
41 | #' @export 42 | flashlight.default <- function(x, model = NULL, data = NULL, y = NULL, 43 | predict_function = stats::predict, 44 | linkinv = function(z) z, 45 | w = NULL, by = NULL, 46 | metrics = list(rmse = MetricsWeighted::rmse), 47 | label = NULL, shap = NULL, ...) { 48 | x <- c( 49 | list( 50 | model = model, 51 | data = data, 52 | y = y, 53 | predict_function = predict_function, 54 | linkinv = linkinv, 55 | w = w, 56 | by = by, 57 | metrics = metrics, 58 | label = label, 59 | shap = shap 60 | ), 61 | list(...) 62 | ) 63 | class(x) <- c("flashlight", "list") 64 | light_check(x) 65 | } 66 | 67 | #' @describeIn flashlight Used to update an existing flashlight object. 68 | #' @export 69 | flashlight.flashlight <- function(x, check = TRUE, ...) { 70 | args <- list(...) 71 | x[names(args)] <- args 72 | if (check) light_check(x) else invisible(x) 73 | } 74 | -------------------------------------------------------------------------------- /R/is_flashlight.R: -------------------------------------------------------------------------------- 1 | #' Check functions for flashlight Classes 2 | #' 3 | #' Checks if an object inherits specific class relevant for the flashlight package. 4 | #' 5 | #' @param x Any object. 6 | #' @returns A logical vector of length one. 7 | #' @export 8 | #' @examples 9 | #' a <- flashlight(label = "a") 10 | #' is.flashlight(a) 11 | #' is.flashlight("a") 12 | is.flashlight <- function(x) { 13 | inherits(x, "flashlight") 14 | } 15 | 16 | #' @describeIn is.flashlight Check for multiflashlight object. 17 | #' @export 18 | is.multiflashlight <- function(x) { 19 | inherits(x, "multiflashlight") 20 | } 21 | 22 | #' @describeIn is.flashlight Check for light object. 23 | #' @export 24 | is.light <- function(x) { 25 | inherits(x, "light") 26 | } 27 | 28 | #' @describeIn is.flashlight Check for light_performance object. 
29 | #' @export 30 | is.light_performance <- function(x) { 31 | inherits(x, "light_performance") 32 | } 33 | 34 | #' @describeIn is.flashlight Check for light_performance_multi object. 35 | #' @export 36 | is.light_performance_multi <- function(x) { 37 | inherits(x, "light_performance_multi") 38 | } 39 | 40 | #' @describeIn is.flashlight Check for light_importance object. 41 | #' @export 42 | is.light_importance <- function(x) { 43 | inherits(x, "light_importance") 44 | } 45 | 46 | #' @describeIn is.flashlight Check for light_importance_multi object. 47 | #' @export 48 | is.light_importance_multi <- function(x) { 49 | inherits(x, "light_importance_multi") 50 | } 51 | 52 | #' @describeIn is.flashlight Check for light_breakdown object. 53 | #' @export 54 | is.light_breakdown <- function(x) { 55 | inherits(x, "light_breakdown") 56 | } 57 | 58 | #' @describeIn is.flashlight Check for light_breakdown_multi object. 59 | #' @export 60 | is.light_breakdown_multi <- function(x) { 61 | inherits(x, "light_breakdown_multi") 62 | } 63 | 64 | #' @describeIn is.flashlight Check for light_ice object. 65 | #' @export 66 | is.light_ice <- function(x) { 67 | inherits(x, "light_ice") 68 | } 69 | 70 | #' @describeIn is.flashlight Check for light_ice_multi object. 71 | #' @export 72 | is.light_ice_multi <- function(x) { 73 | inherits(x, "light_ice_multi") 74 | } 75 | 76 | #' @describeIn is.flashlight Check for light_profile object. 77 | #' @export 78 | is.light_profile <- function(x) { 79 | inherits(x, "light_profile") 80 | } 81 | 82 | #' @describeIn is.flashlight Check for light_profile_multi object. 83 | #' @export 84 | is.light_profile_multi <- function(x) { 85 | inherits(x, "light_profile_multi") 86 | } 87 | 88 | #' @describeIn is.flashlight Check for light_profile2d object. 89 | #' @export 90 | is.light_profile2d <- function(x) { 91 | inherits(x, "light_profile2d") 92 | } 93 | 94 | #' @describeIn is.flashlight Check for light_profile2d_multi object. 
95 | #' @export 96 | is.light_profile2d_multi <- function(x) { 97 | inherits(x, "light_profile2d_multi") 98 | } 99 | 100 | #' @describeIn is.flashlight Check for light_effects object. 101 | #' @export 102 | is.light_effects <- function(x) { 103 | inherits(x, "light_effects") 104 | } 105 | 106 | #' @describeIn is.flashlight Check for light_effects_multi object. 107 | #' @export 108 | is.light_effects_multi <- function(x) { 109 | inherits(x, "light_effects_multi") 110 | } 111 | 112 | #' @describeIn is.flashlight Check for shap object. 113 | #' @export 114 | is.shap <- function(x) { 115 | inherits(x, "shap") 116 | } 117 | 118 | #' @describeIn is.flashlight Check for light_scatter object. 119 | #' @export 120 | is.light_scatter <- function(x) { 121 | inherits(x, "light_scatter") 122 | } 123 | 124 | #' @describeIn is.flashlight Check for light_scatter_multi object. 125 | #' @export 126 | is.light_scatter_multi <- function(x) { 127 | inherits(x, "light_scatter_multi") 128 | } 129 | 130 | #' @describeIn is.flashlight Check for light_global_surrogate object. 131 | #' @export 132 | is.light_global_surrogate <- function(x) { 133 | inherits(x, "light_global_surrogate") 134 | } 135 | 136 | #' @describeIn is.flashlight Check for light_global_surrogate_multi object. 137 | #' @export 138 | is.light_global_surrogate_multi <- function(x) { 139 | inherits(x, "light_global_surrogate_multi") 140 | } 141 | -------------------------------------------------------------------------------- /R/light_check.R: -------------------------------------------------------------------------------- 1 | #' Check flashlight 2 | #' 3 | #' Checks if an object of class "flashlight" or "multiflashlight" 4 | #' is consistently defined. 5 | #' 6 | #' @param x An object of class "flashlight" or "multiflashlight". 7 | #' @param ... Further arguments passed from or to other methods. 8 | #' @returns The input `x` or an error message. 
9 | #' @export 10 | #' @examples 11 | #' fit <- lm(Sepal.Length ~ ., data = iris) 12 | #' fit_log <- lm(log(Sepal.Length) ~ ., data = iris) 13 | #' fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols") 14 | #' fl_log <- flashlight(model = fit_log, y = "Sepal.Length", label = "ols", linkinv = exp) 15 | #' light_check(fl) 16 | #' light_check(fl_log) 17 | light_check <- function(x, ...) { 18 | UseMethod("light_check") 19 | } 20 | 21 | #' @describeIn light_check Default check method not implemented yet. 22 | #' @export 23 | light_check.default <- function(x, ...) { 24 | stop("No default method available yet.") 25 | } 26 | 27 | #' @describeIn light_check Checks if a flashlight object is consistently defined. 28 | #' @export 29 | light_check.flashlight <- function(x, ...) { 30 | if (is.null(x$label)) { 31 | stop("label should not be NULL.") 32 | } 33 | nms <- names(x) 34 | is_function <- function(nm) { 35 | if (nm %in% nms && !is.function(x[[nm]])) { 36 | stop(paste(nm, "needs to be a function.")) 37 | } 38 | } 39 | is_char <- function(nm, max_len = 1L) { 40 | if (nm %in% nms && !is.null(x[[nm]]) && 41 | !(is.character(x[[nm]]) && length(x[[nm]]) <= max_len)) { 42 | stop(paste(nm, "needs to be a character of length one.")) 43 | } 44 | } 45 | in_colnames <- function(nm) { 46 | if (nm %in% nms && !is.null(x[[nm]]) && !all(x[[nm]] %in% colnames(x$data))) { 47 | stop(paste(nm, "needs to be a column in 'data'.")) 48 | } 49 | } 50 | lapply(c("predict_function", "linkinv"), is_function) 51 | lapply(c("y", "w"), is_char) 52 | is_char("by", max_len = Inf) 53 | if ("metrics" %in% nms && !is.null(x[["metrics"]]) && 54 | !(is.list(x[["metrics"]]))) { 55 | stop("metrics needs to be a named list.") 56 | } 57 | if ("data" %in% nms && !is.null(x[["data"]])) { 58 | if (!inherits(x$data, "data.frame")) { 59 | stop("data should be a data.frame.") 60 | } 61 | lapply(c("y", "w", "by"), in_colnames) 62 | } 63 | invisible(x) 64 | } 65 | 66 | #' @describeIn light_check Checks if a 
multiflashlight object is consistently defined. 67 | #' @export 68 | light_check.multiflashlight <- function(x, ...) { 69 | # by 70 | if (!all_identical(x, `[[`, "by")) { 71 | warning("Inconsistent 'by' variables specified. 72 | Please pass 'by' in subsequent calls to 'light_*' functions.") 73 | } 74 | # metrics 75 | if (!all_identical(x, `[[`, "metrics")) { 76 | warning("metrics differ across flashlights. This might produce wrong result in variable importance. 77 | Please pass 'metric(s)' in subsequent calls to 'light_performance' or 'light_importance'.") 78 | } 79 | # colnames(data) 80 | if (!all_identical(x, function(z) colnames(z$data))) { 81 | warning("Column names differ across data in flashlights. This is rarely a good idea and can be 82 | avoided by specifying individual 'predict_function'.") 83 | } 84 | # dim(data) 85 | if (!all_identical(x, function(z) dim(z$data))){ 86 | warning("Data dimensions differ across data in flashlights. This might lead to unfair comparisons. 87 | Please pass 'data' in subsequent calls to 'light_*' functions.") 88 | } 89 | invisible(x) 90 | } 91 | -------------------------------------------------------------------------------- /R/light_combine.R: -------------------------------------------------------------------------------- 1 | #' Combine Objects 2 | #' 3 | #' Combines a list of similar objects each of class "light" by row binding 4 | #' `data.frame` slots and retaining the other slots from the first list element. 5 | #' 6 | #' @param x A list of objects of the same class. 7 | #' @param new_class An optional vector with additional class names to be added 8 | #' to the output. 9 | #' @param ... Further arguments passed from or to other methods. 10 | #' @returns If `x` is a list, an object like each element but with unioned rows 11 | #' in data slots. 
12 | #' @export 13 | #' @examples 14 | #' fit_lm <- lm(Sepal.Length ~ ., data = iris) 15 | #' fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = "log"), data = iris) 16 | #' mod_lm <- flashlight(model = fit_lm, label = "lm", data = iris, y = "Sepal.Length") 17 | #' mod_glm <- flashlight( 18 | #' model = fit_glm, 19 | #' label = "glm", 20 | #' data = iris, 21 | #' y = "Sepal.Length", 22 | #' predict_function = function(object, newdata) 23 | #' predict(object, newdata, type = "response") 24 | #' ) 25 | #' mods <- multiflashlight(list(mod_lm, mod_glm)) 26 | #' perf_lm <- light_performance(mod_lm) 27 | #' perf_glm <- light_performance(mod_glm) 28 | #' manual_comb <- light_combine( 29 | #' list(perf_lm, perf_glm), 30 | #' new_class = "light_performance_multi" 31 | #' ) 32 | #' auto_comb <- light_performance(mods) 33 | #' all.equal(manual_comb, auto_comb) 34 | light_combine <- function(x, ...) { 35 | UseMethod("light_combine") 36 | } 37 | 38 | #' @describeIn light_combine Default method not implemented yet. 39 | #' @export 40 | light_combine.default <- function(x, ...) { 41 | stop("No default method available yet.") 42 | } 43 | 44 | #' @describeIn light_combine Since there is nothing to combine, the input is returned 45 | #' except for additional classes. 46 | #' @export 47 | light_combine.light <- function(x, new_class = NULL, ...) { 48 | add_classes(x, new_class) 49 | } 50 | 51 | #' @describeIn light_combine Combine a list of similar light objects. 52 | #' @export 53 | light_combine.list <- function(x, new_class = NULL, ...) 
{ 54 | stopifnot( 55 | all(sapply(x, inherits, "light")), 56 | all_identical(x, class), 57 | all_identical(x, length), 58 | all_identical(x, names) 59 | ) 60 | 61 | out <- x[[1L]] 62 | data_slots <- names(out)[vapply(out, FUN = is.data.frame, FUN.VALUE = TRUE)] 63 | other_slots <- setdiff(names(out), data_slots) 64 | 65 | # Compare non-data slots for identity 66 | if (length(other_slots)) { 67 | stopifnot( 68 | vapply(other_slots, FUN = function(s) all_identical(x, `[[`, s), FUN.VALUE = TRUE) 69 | ) 70 | } 71 | 72 | # Row bind data elements 73 | if (length(data_slots)) { 74 | for (d in data_slots) { 75 | out[[d]] <- dplyr::bind_rows(lapply(x, `[[`, d)) 76 | out[[d]]$label_ <- factor(out[[d]]$label_, levels = unique(out[[d]]$label_)) 77 | } 78 | } 79 | class(out) <- union(new_class, class(x[[1L]])) 80 | out 81 | } 82 | 83 | -------------------------------------------------------------------------------- /R/light_global_surrogate.R: -------------------------------------------------------------------------------- 1 | #' Global Surrogate Tree 2 | #' 3 | #' Model predictions are modelled by a single decision tree, serving as an 4 | #' easy-to-interpret surrogate for the original model. 5 | #' As suggested in Molnar (see reference below), the quality of the surrogate 6 | #' tree can be measured by its R-squared. The size of the tree can be modified 7 | #' by passing `...` arguments to [rpart::rpart()]. 8 | #' 9 | #' @param x An object of class "flashlight" or "multiflashlight". 10 | #' @param data An optional `data.frame`. 11 | #' @param by An optional vector of column names used to additionally group the results. 12 | #' For each group, a separate tree is grown. 13 | #' @param v Vector of variables used in the surrogate model. 14 | #' Defaults to all variables in `data` except "by", "w" and "y". 15 | #' @param use_linkinv Should the retransformation function be applied? Default is `TRUE`. 16 | #' @param n_max Maximum number of data rows to consider to build the tree. 
17 | #' @param seed An integer random seed used to select data rows if `n_max` is lower than 18 | #' the number of data rows. 19 | #' @param keep_max_levels Number of levels of categorical and factor variables to keep. 20 | #' Other levels are combined into a level "Other". This prevents [rpart::rpart()] from 21 | #' taking too long to split non-numeric variables with many levels. 22 | #' @param ... Arguments passed to [rpart::rpart()], such as `maxdepth`. 23 | #' @returns 24 | #' An object of class "light_global_surrogate" with the following elements: 25 | #' - `data` A tibble with results. 26 | #' - `by` Same as input `by`. 27 | #' @export 28 | #' @references Molnar C. (2019). Interpretable Machine Learning. 29 | #' @examples 30 | #' fit <- lm(Sepal.Length ~ ., data = iris) 31 | #' x <- flashlight(model = fit, label = "lm", data = iris) 32 | #' sur <- light_global_surrogate(x) 33 | #' sur$data$r_squared 34 | #' plot(sur) 35 | #' @seealso [plot.light_global_surrogate()] 36 | light_global_surrogate <- function(x, ...) { 37 | UseMethod("light_global_surrogate") 38 | } 39 | 40 | #' @describeIn light_global_surrogate Default method not implemented yet. 41 | #' @export 42 | light_global_surrogate.default <- function(x, ...) { 43 | stop("light_global_surrogate method is only available for objects of class flashlight or multiflashlight.") 44 | } 45 | 46 | #' @describeIn light_global_surrogate Surrogate model for a flashlight. 47 | #' @export 48 | light_global_surrogate.flashlight <- function(x, data = x$data, by = x$by, 49 | v = NULL, use_linkinv = TRUE, 50 | n_max = Inf, seed = NULL, 51 | keep_max_levels = 4L, ...) { 52 | stopifnot( 53 | "No data!" = is.data.frame(data) && nrow(data) >= 1L, 54 | "'by' not in 'data'!" 
= by %in% colnames(data), 55 | "Not all 'v' in 'data'" = v %in% colnames(data), 56 | !any(c("label_", "r_squared", "tree_") %in% by) 57 | ) 58 | 59 | # Set v and remove 'by' from it 60 | if (is.null(v)) { 61 | v <- setdiff(colnames(data), c(x$y, by, x$w)) 62 | } else if (!is.null(by)) { 63 | v <- setdiff(v, by) 64 | } 65 | 66 | # Subsample if data is very large 67 | n <- nrow(data) 68 | if (n > n_max) { 69 | if (!is.null(seed)) { 70 | set.seed(seed) 71 | } 72 | data <- data[sample(n, n_max), , drop = FALSE] 73 | } 74 | 75 | x <- flashlight( 76 | x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z 77 | ) 78 | 79 | # Add response of tree model 80 | stopifnot(!("pred_" %in% colnames(data))) 81 | data$pred_ <- stats::predict(x) 82 | 83 | # Lump factors with many levels for tree fit 84 | for (vv in v) { 85 | data[[vv]] <- .fct_lump(data[[vv]], keep_max_levels = keep_max_levels) 86 | } 87 | 88 | # Fit tree within by group 89 | core_func <- function(X) { 90 | fit <- rpart::rpart( 91 | stats::reformulate(v, "pred_"), data = X, 92 | weights = if (!is.null(x$w)) X[[x$w]], 93 | model = FALSE, 94 | y = FALSE, 95 | xval = 0, 96 | ... 97 | ) 98 | r2 <- MetricsWeighted::r_squared(X$pred_, stats::predict(fit, X)) 99 | stats::setNames(data.frame(r2, I(list(fit))), c("r_squared", "tree_")) 100 | } 101 | res <- Reframe(data, FUN = core_func, .by = by) 102 | 103 | # Organize output 104 | res$label_ <- x$label 105 | out <- list(data = res[, c("label_", by, "r_squared", "tree_")], by = by) 106 | add_classes(out, c("light_global_surrogate", "light")) 107 | } 108 | 109 | #' @describeIn light_global_surrogate Surrogate model for a multiflashlight. 110 | #' @export 111 | light_global_surrogate.multiflashlight <- function(x, ...) 
{ 112 | light_combine( 113 | lapply(x, light_global_surrogate, ...), 114 | new_class = "light_global_surrogate_multi" 115 | ) 116 | } 117 | 118 | #' Plot Global Surrogate Trees 119 | #' 120 | #' Use [rpart.plot::rpart.plot()] to visualize trees fitted by 121 | #' [light_global_surrogate()]. 122 | #' 123 | #' @param x An object of class "light_global_surrogate". 124 | #' @param type Plot type, see help of [rpart.plot::rpart.plot()]. Default is 5. 125 | #' @param auto_main Automatic plot titles (only if multiple trees are shown). 126 | #' @param mfrow If multiple trees are shown in the same figure: 127 | #' what value of `mfrow` to use in [graphics::par()]? 128 | #' @param ... Further arguments passed to [rpart.plot::rpart.plot()]. 129 | #' @returns An object of class "ggplot". 130 | #' @export 131 | #' @seealso [light_global_surrogate()] 132 | plot.light_global_surrogate <- function(x, type = 5, auto_main = TRUE, 133 | mfrow = NULL, ...) { 134 | data <- x$data 135 | multi <- is.light_global_surrogate_multi(x) 136 | ndim <- length(x$by) + multi 137 | if (ndim == 0L) { 138 | rpart.plot::rpart.plot(data$tree_[[1L]], roundint = FALSE, type = type, ...) 139 | } else if (ndim == 1L) { 140 | dim_col <- data[[if (multi) "label_" else x$by[1L]]] 141 | m <- length(dim_col) 142 | if (is.null(mfrow)) { 143 | nr <- floor(sqrt(m)) 144 | mfrow <- c(nr, ceiling(m / nr)) 145 | } 146 | old_params <- graphics::par(mfrow = mfrow) 147 | on.exit(graphics::par(old_params)) 148 | 149 | for (i in seq_len(m)) { 150 | rpart.plot::rpart.plot( 151 | data$tree_[[i]], 152 | roundint = FALSE, 153 | type = type, 154 | main = if (auto_main) dim_col[i], 155 | ... 
156 | ) 157 | } 158 | } else { 159 | stop("Either one 'by' variable or a multiflashlight is supported.") 160 | } 161 | } 162 | 163 | # Helper function to lump too many levels of a factor to category "Other" 164 | .fct_lump <- function(x, keep_max_levels = 4L, other_name = "Other") { 165 | if (!is.character(x) && !is.factor(x)) { 166 | return(x) 167 | } 168 | if (!is.factor(x)) { 169 | x <- factor(x) 170 | } 171 | m <- nlevels(x) 172 | if (m > keep_max_levels + 1L) { 173 | drop_levels <- names(sort(table(x), decreasing = TRUE))[(keep_max_levels + 1L):m] 174 | levels(x)[match(drop_levels, levels(x))] <- other_name 175 | } 176 | x 177 | } 178 | -------------------------------------------------------------------------------- /R/light_performance.R: -------------------------------------------------------------------------------- 1 | #' Model Performance of Flashlight 2 | #' 3 | #' Calculates performance of a flashlight with respect to one or more 4 | #' performance measure. 5 | #' 6 | #' The minimal required elements in the (multi-) flashlight are "y", "predict_function", 7 | #' "model", "data" and "metrics". The latter two can also directly be passed to 8 | #' [light_performance()]. Note that by default, no retransformation function is applied. 9 | #' 10 | #' @param x An object of class "flashlight" or "multiflashlight". 11 | #' @param data An optional `data.frame`. 12 | #' @param by An optional vector of column names used to additionally group the results. 13 | #' Will overwrite `x$by`. 14 | #' @param metrics An optional named list with metrics. Each metric takes at least 15 | #' four arguments: actual, predicted, case weights w and `...`. 16 | #' @param use_linkinv Should retransformation function be applied? Default is `FALSE`. 17 | #' @param ... Arguments passed from or to other functions. 18 | #' @returns 19 | #' An object of class "light_performance" with the following elements: 20 | #' - `data`: A tibble containing the results. 21 | #' - `by` Same as input `by`. 
#' @export
#' @examples
#' fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
#' fl_part <- flashlight(
#'   model = fit_part, label = "part", data = iris, y = "Sepal.Length"
#' )
#' plot(light_performance(fl_part, by = "Species"), fill = "chartreuse4")
#'
#' # Second model
#' fit_full <- lm(Sepal.Length ~ ., data = iris)
#' fl_full <- flashlight(
#'   model = fit_full, label = "full", data = iris, y = "Sepal.Length"
#' )
#' fls <- multiflashlight(list(fl_part, fl_full))
#'
#' plot(light_performance(fls, by = "Species"))
#' plot(light_performance(fls, by = "Species"), swap_dim = TRUE)
#' @seealso [plot.light_performance()]
light_performance <- function(x, ...) {
  UseMethod("light_performance")
}

#' @describeIn light_performance Default method not implemented yet.
#' @export
light_performance.default <- function(x, ...) {
  stop("light_performance method is only available for objects of class flashlight or multiflashlight.")
}

#' @describeIn light_performance Model performance of flashlight object.
#' @export
light_performance.flashlight <- function(x, data = x$data, by = x$by,
                                         metrics = x$metrics,
                                         use_linkinv = FALSE, ...) {
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "No metric!" = !is.null(metrics),
    "No 'y' defined in flashlight!" = !is.null(x$y),
    !any(c("metric_", "value_", "label_", "pred_") %in% by)
  )

  # Refresh the flashlight with the data/by actually used; the identity
  # replaces the link inverse unless retransformation was requested
  link <- if (use_linkinv) x$linkinv else function(z) z
  x <- flashlight(x, data = data, by = by, linkinv = link)

  # Predictions and (possibly retransformed) response
  data$pred_ <- stats::predict(x)
  data[[x$y]] <- response(x)  # Applies linkinv

  # Evaluate all metrics per 'by' group
  result <- Reframe(
    data,
    FUN = function(X) {
      MetricsWeighted::performance(
        X,
        actual = x$y,
        predicted = "pred_",
        w = x$w,
        metrics = metrics,
        key = "metric_",
        value = "value_",
        ...
      )
    },
    .by = by
  )
  result$label_ <- x$label

  # Assemble the "light" object
  out <- list(data = result, by = by)
  add_classes(out, classes = c("light_performance", "light"))
}

#' @describeIn light_performance Model performance of multiflashlight object.
#' @export
light_performance.multiflashlight <- function(x, ...) {
  per_flashlight <- lapply(x, light_performance, ...)
  light_combine(per_flashlight, new_class = "light_performance_multi")
}

#' Visualize Model Performance
#'
#' Minimal visualization of an object of class "light_performance" as
#' [ggplot2::geom_bar()]. The object returned has class "ggplot",
#' and can be further customized.
#'
#' The plot is organized as a bar plot as follows:
#' For flashlights without "by" variable specified, a single bar is drawn.
#' Otherwise, the "by" variable (or the flashlight label if there is no "by" variable)
#' is represented by the "x" aesthetic.
#'
#' The flashlight label (in case of one "by" variable) is represented by dodged bars.
#' This strategy makes sure that performance of different flashlights can
#' be compared easiest. Set "swap_dim = TRUE" to revert the role of dodging and x
#' aesthetic. Different metrics are always represented by facets.
#'
#' @importFrom rlang .data
#' @param x An object of class "light_performance".
#' @param swap_dim Should representation of dimensions
#'   (either two "by" variables or one "by" variable and multiflashlight)
#'   of x aesthetic and dodge fill aesthetic be swapped? Default is `FALSE`.
#' @param geom Geometry of plot (either "bar" or "point")
#' @param facet_scales Scales argument passed to [ggplot2::facet_wrap()].
#' @param rotate_x Should x axis labels be rotated by 45 degrees?
#' @param ... Further arguments passed to [ggplot2::geom_bar()] or
#'   [ggplot2::geom_point()].
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_performance()]
plot.light_performance <- function(x, swap_dim = FALSE, geom = c("bar", "point"),
                                   facet_scales = "free_y", rotate_x = FALSE, ...) {
  geom <- match.arg(geom)
  data <- x$data
  nby <- length(x$by)
  multi <- is.light_performance_multi(x)
  # Number of grouping dimensions: "by" variables plus one for a multiflashlight
  ndim <- nby + multi
  if (ndim > 2L) {
    # Message is pasted from pieces so that it contains no raw line break
    stop(
      "Plot method not defined for single flashlights with more than two by variables ",
      "or a multiflashlight with more than one by variable."
    )
  }

  # Differentiate some plot cases
  if (ndim <= 1L) {
    xvar <- if (nby) x$by[1L] else "label_"
    p <- ggplot2::ggplot(data, ggplot2::aes(y = value_, x = .data[[xvar]]))
    if (geom == "bar") {
      p <- p + ggplot2::geom_bar(stat = "identity", ...)
    } else if (geom == "point") {
      p <- p + ggplot2::geom_point(...)
    }
  } else {
    # Two dimensions: one goes to the x axis, the other to fill/color
    second_dim <- if (multi) "label_" else x$by[2L]
    x_var <- if (!swap_dim) x$by[1L] else second_dim
    dodge_var <- if (!swap_dim) second_dim else x$by[1L]

    p <- ggplot2::ggplot(data, ggplot2::aes(y = value_, x = .data[[x_var]]))
    if (geom == "bar") {
      p <- p + ggplot2::geom_bar(
        ggplot2::aes(fill = .data[[dodge_var]]), stat = "identity", position = "dodge",
        ...
      )
    } else if (geom == "point") {
      p <- p + ggplot2::geom_point(
        ggplot2::aes(group = .data[[dodge_var]], color = .data[[dodge_var]]), ...
      )
    }
  }

  # Multiple metrics always go into a facet due to different y scales
  if (length(unique(data$metric_)) >= 2L) {
    p <- p + ggplot2::facet_wrap(~ metric_, scales = facet_scales)
  }
  if (rotate_x) {
    # The call resolves to the function rotate_x(), not to the logical argument
    p <- p + rotate_x()
  }
  p + ggplot2::ylab("Value")
}

# --------------------------------------------------------------------------------
# /R/light_recode.R:
# --------------------------------------------------------------------------------
#' Recode Factor Columns - DEPRECATED
#'
#' @param ... Deprecated.
#' @returns Deprecated.
#' @export
light_recode <- function(...) {
  stop("'light_recode()' is deprecated.")
}

# --------------------------------------------------------------------------------
# /R/light_scatter.R:
# --------------------------------------------------------------------------------
#' Scatter Plot Data
#'
#' This function prepares values for drawing a scatter plot of predicted values,
#' responses, or residuals against a selected variable.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be shown on the x-axis.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "predicted", "response", or "residual".
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param n_max Maximum number of data rows to select. Will be randomly picked.
#' @param seed An integer random seed used for subsampling.
#' @param ... Further arguments passed from or to other methods.
#' @returns
#'   An object of class "light_scatter" with the following elements:
#'   - `data`: A tibble with results.
#'   - `by`: Same as input `by`.
#'   - `v`: The variable name evaluated.
#'   - `type`: Same as input `type`. For information only.
#' @export
#' @examples
#' fit_a <- lm(Sepal.Length ~ . -Petal.Length, data = iris)
#' fit_b <- lm(Sepal.Length ~ ., data = iris)
#'
#' fl_a <- flashlight(model = fit_a, label = "no Petal.Length")
#' fl_b <- flashlight(model = fit_b, label = "all")
#' fls <- multiflashlight(list(fl_a, fl_b), data = iris, y = "Sepal.Length")
#'
#' plot(light_scatter(fls, v = "Petal.Width"), color = "darkred")
#'
#' sc <- light_scatter(fls, "Petal.Length", by = "Species", type = "residual")
#' plot(sc)
#' @seealso [plot.light_scatter()]
light_scatter <- function(x, ...) {
  UseMethod("light_scatter")
}

#' @describeIn light_scatter Default method not implemented yet.
#' @export
light_scatter.default <- function(x, ...) {
  stop("light_scatter method is only available for objects of class flashlight or multiflashlight.")
}

#' @describeIn light_scatter Variable profile for a flashlight.
#' @export
light_scatter.flashlight <- function(x, v, data = x$data, by = x$by,
                                     type = c("predicted", "response",
                                              "residual", "shap"),
                                     use_linkinv = TRUE, n_max = 400,
                                     seed = NULL, ...) {
  type <- match.arg(type)
  if (identical(type, "shap")) {
    stop("type = 'shap' is deprecated.")
  }

  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'!" = v %in% colnames(data),
    !any(c("value_", "label_") %in% by)
  )

  # Everything except plain predictions requires the response
  if (type != "predicted" && is.null(x$y)) {
    stop("You need to specify 'y' in flashlight.")
  }

  # Randomly keep at most n_max rows
  n_rows <- nrow(data)
  if (n_rows > n_max) {
    if (!is.null(seed)) {
      set.seed(seed)
    }
    keep <- sample(n_rows, n_max)
    data <- data[keep, , drop = FALSE]
  }

  # Refresh the flashlight with the data/by actually used
  link <- if (use_linkinv) x$linkinv else function(z) z
  x <- flashlight(x, data = data, by = by, linkinv = link)

  # Values to be shown on the y axis
  data$value_ <- switch(
    type,
    response = response(x),
    predicted = stats::predict(x),
    residual = stats::residuals(x)
  )
  data$label_ <- x$label

  # Assemble the "light" object
  out <- list(
    data = tibble::as_tibble(data[, c("label_", by, v, "value_")]),
    by = by,
    v = v,
    type = type
  )
  add_classes(out, c("light_scatter", "light"))
}

#' @describeIn light_scatter light_scatter for a multiflashlight.
#' @export
light_scatter.multiflashlight <- function(x, ...) {
  per_flashlight <- lapply(x, light_scatter, ...)
  light_combine(per_flashlight, new_class = "light_scatter_multi")
}

#' Scatter Plot
#'
#' Values are plotted against a variable. The object returned is of class "ggplot"
#' and can be further customized. To avoid overplotting, try `alpha = 0.2` or
#' `position = "jitter"`.
#'
#' @importFrom rlang .data
#'
#' @inheritParams plot.light_performance
#' @param x An object of class "light_scatter".
#' @param swap_dim If multiflashlight and one "by" variable, or single flashlight
#'   with two "by" variables, swap the role of color variable and facet variable.
#'   If multiflashlight or one "by" variable, use colors instead of facets.
#' @param ... Further arguments passed to [ggplot2::geom_point()]. Typical arguments
#'   would be `alpha = 0.2` or `position = "jitter"` to avoid overplotting.
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_scatter()]
plot.light_scatter <- function(x, swap_dim = FALSE, facet_scales = "free_x",
                               rotate_x = FALSE, ...) {
  data <- x$data
  nby <- length(x$by)
  multi <- is.light_scatter_multi(x)
  # Number of grouping dimensions: "by" variables plus one for a multiflashlight
  ndim <- nby + multi
  if (ndim > 2L) {
    # Message is pasted from pieces so that it contains no raw line break
    stop(
      "Plot method not defined for more than two by variables ",
      "or multiflashlight with more than one by variable."
    )
  }
  # Distinguish some cases
  p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[x$v]], y = value_))
  if (ndim == 0L) {
    p <- p + ggplot2::geom_point(...)
  } else if (ndim == 1L) {
    first_dim <- if (multi) "label_" else x$by[1L]
    if (swap_dim) {
      # Colors instead of facets; legend keys are drawn fully opaque
      p <- p +
        ggplot2::geom_point(ggplot2::aes(color = .data[[first_dim]]), ...) +
        ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(alpha = 1)))
    } else {
      p <- p +
        ggplot2::geom_point(...) +
        ggplot2::facet_wrap(first_dim, scales = facet_scales)
    }
  } else {
    # Two dimensions: one is mapped to color, the other to facets
    second_dim <- if (multi) "label_" else x$by[2L]
    wrap_var <- if (swap_dim) x$by[1L] else second_dim
    col_var <- if (swap_dim) second_dim else x$by[1L]
    p <- p +
      ggplot2::geom_point(ggplot2::aes(color = .data[[col_var]]), ...) +
      ggplot2::facet_wrap(wrap_var, scales = facet_scales) +
      override_alpha()
  }
  if (rotate_x) {
    # The call resolves to the function rotate_x(), not to the logical argument
    p <- p + rotate_x()
  }
  p + ggplot2::ylab(x$type)
}

# --------------------------------------------------------------------------------
# /R/methods.R:
# --------------------------------------------------------------------------------
#' Prints a flashlight
#'
#' Print method for an object of class "flashlight".
#'
#' @param x An object of class "flashlight".
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' x <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris)
#' x
#' @seealso [flashlight()]
print.flashlight <- function(x, ...) {
  # Each slot is shown as its value (or "Yes" for the model), or "No" if NULL
  cat("\nFlashlight", x$label, "\n")
  cat("\nModel:\t\t\t", .yn(x$model, "Yes"))
  cat("\ny:\t\t\t", .yn(x$y))
  cat("\nw:\t\t\t", .yn(x$w))
  cat("\nby:\t\t\t", .yn(x$by))
  cat("\ndata dim:\t\t", .yn(dim(x$data)))
  cat("\nmetrics:\t\t", .yn(x[["metrics"]], names(x$metrics)))
  cat("\n")
  invisible(x)
}

#' Prints a multiflashlight
#'
#' Print method for an object of class "multiflashlight".
#'
#' @param x An object of class "multiflashlight".
#' @param ... Further arguments passed to [print.flashlight()].
#' @returns Invisibly, the input is returned.
#' @export
#' @examples
#' fit_lm <- lm(Sepal.Length ~ ., data = iris)
#' fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris)
#' fl_lm <- flashlight(model = fit_lm, label = "lm")
#' fl_glm <- flashlight(model = fit_glm, label = "glm")
#' multiflashlight(list(fl_lm, fl_glm), data = iris)
#' @seealso [multiflashlight()]
print.multiflashlight <- function(x, ...) {
  lapply(x, print.flashlight, ...)
  invisible(x)
}

#' Prints light Object
#'
#' Print method for an object of class "light".
#'
#' @param x An object of class "light".
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @method print light
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris)
#' light_performance(fl, by = "Species")
print.light <- function(x, ...) {
  cat("\nI am an object of class", class(x)[1L], "\n")
  # Only data.frame elements are printed
  x_cs <- x[vapply(x, FUN = is.data.frame, FUN.VALUE = TRUE)]
  if (length(x_cs)) {
    cat("\ndata.frames:\n")
    for (nm in names(x_cs)) {
      cat("\n", nm, "\n")
      print(x_cs[[nm]])
    }
  }
  invisible(x)
}

#' Predictions for flashlight
#'
#' Predict method for an object of class "flashlight".
#' Pass additional elements to update the flashlight, typically `data`.
#'
#' @param object An object of class "flashlight".
#' @param ... Arguments used to update the flashlight.
#' @returns A vector with predictions.
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")
#' predict(fl)[1:5]
#' predict(fl, data = iris[1:5, ])
predict.flashlight <- function(object, ...) {
  object <- flashlight(object, check = FALSE, ...)
  if (is.null(object[["data"]])) {
    stop("No 'data' to predict.")
  }
  if (!is.data.frame(object[["data"]])) {
    stop("'data' needs to be a data.frame.")
  }
  # Raw predictions of the model are passed through the link inverse
  pred <- with(object, linkinv(predict_function(model, data)))
  if (!is.numeric(pred) && !is.logical(pred)) {
    stop("Non-numeric/non-logical predictions detected. Please modify 'predict_function' accordingly.")
  }
  pred
}

#' Predictions for multiflashlight
#'
#' Predict method for an object of class "multiflashlight".
#' Pass additional elements to update the flashlight, typically `data`.
#'
#' @param object An object of class "multiflashlight".
#' @param ... Arguments used to update the multiflashlight.
#' @returns A named list of prediction vectors.
#' @export
#' @examples
#' fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
#' fit_full <- lm(Sepal.Length ~ ., data = iris)
#' mod_full <- flashlight(model = fit_full, label = "full")
#' mod_part <- flashlight(model = fit_part, label = "part")
#' mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length")
#' predict(mods, data = iris[1:5, ])
predict.multiflashlight <- function(object, ...) {
  lapply(object, stats::predict, ...)
}

#' Residuals for flashlight
#'
#' Residuals method for an object of class "flashlight".
#' Pass additional elements to update the flashlight before calculation of residuals.
#'
#' @param object An object of class "flashlight".
#' @param ... Arguments used to update the flashlight before calculating the residuals.
#' @returns A numeric vector with residuals.
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' x <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")
#' residuals(x)[1:5]
residuals.flashlight <- function(object, ...) {
  object <- flashlight(object, check = FALSE, ...)
  # Response minus prediction, both taken from the updated flashlight
  response(object) - stats::predict(object)
}

#' Residuals for multiflashlight
#'
#' Residuals method for an object of class "multiflashlight".
#' Pass additional elements to update the multiflashlight before calculation of
#' residuals.
#'
#' @param object An object of class "multiflashlight".
#' @param ... Arguments used to update the multiflashlight before
#'   calculating the residuals.
#' @returns A named list with residuals per flashlight.
#' @export
#' @examples
#' fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
#' fit_full <- lm(Sepal.Length ~ ., data = iris)
#' mod_full <- flashlight(model = fit_full, label = "full")
#' mod_part <- flashlight(model = fit_part, label = "part")
#' mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length")
#' residuals(mods, data = head(iris))
residuals.multiflashlight <- function(object, ...) {
  lapply(object, stats::residuals, ...)
}

#' Response of multi/-flashlight
#'
#' Extracts response from object of class "flashlight".
#'
#' @param object An object of class "flashlight".
#' @param ... Arguments used to update the flashlight before extracting the response.
#' @returns A numeric vector of responses.
#' @export
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' (fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols"))
#' response(fl)[1:5]
#' response(fl, data = iris[1:5, ])
#' response(fl, data = iris[1:5, ], linkinv = exp)
response <- function(object, ...) {
  UseMethod("response")
}

#' @describeIn response Default method not implemented yet.
#' @export
response.default <- function(object, ...) {
  stop("No default method available yet.")
}

#' @describeIn response Extract response from flashlight object.
#' @export
response.flashlight <- function(object, ...) {
  object <- flashlight(object, check = FALSE, ...)

  # Check the required *elements* of the flashlight for NULL. (Testing the
  # character vector of their names, as done before, could never fail.)
  required <- c("y", "linkinv", "data")
  is_missing <- vapply(
    required, FUN = function(nm) is.null(object[[nm]]), FUN.VALUE = logical(1L)
  )
  if (any(is_missing)) {
    stop(
      "Flashlight is missing required element(s): ",
      paste0("'", required[is_missing], "'", collapse = ", ")
    )
  }

  # Response column of the data, passed through the link inverse
  with(object, linkinv(data[[y]]))
}

#' @describeIn response Extract responses from multiflashlight object.
#' @export
response.multiflashlight <- function(object, ...) {
  lapply(object, response, ...)
}



# Helper functions

# Returns `ret` (by default `z` itself) if `z` is not NULL, and "No" otherwise
.yn <- function(z, ret = z) {
  if (!is.null(z)) ret else "No"
}

# --------------------------------------------------------------------------------
# /R/multiflashlight.R:
# --------------------------------------------------------------------------------
#' Create or Update a multiflashlight
#'
#' Combines a list of flashlights to an object of class "multiflashlight"
#' and/or updates a multiflashlight.
#'
#' @param x An object of class "multiflashlight", "flashlight" or a list of flashlights.
#' @param ... Optional arguments in the flashlights to update, see examples.
#' @returns An object of class "multiflashlight" (a named list of flashlight objects).
#' @export
#' @examples
#' fit_lm <- lm(Sepal.Length ~ ., data = iris)
#' fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris)
#' mod_lm <- flashlight(model = fit_lm, label = "lm")
#' mod_glm <- flashlight(model = fit_glm, label = "glm")
#' (mods <- multiflashlight(list(mod_lm, mod_glm)))
#' @seealso [flashlight()]
multiflashlight <- function(x, ...) {
  UseMethod("multiflashlight")
}

#' @describeIn multiflashlight Default method not implemented yet.
#' @export
multiflashlight.default <- function(x, ...) {
  stop("No default method available yet.")
}

#' @describeIn multiflashlight Updates an existing flashlight object and turns
#'   into a multiflashlight.
#' @export
multiflashlight.flashlight <- function(x, ...) {
  multiflashlight(list(x), ...)
}

#' @describeIn multiflashlight Creates (and updates) a multiflashlight from a list
#'   of flashlights.
#' @export
multiflashlight.list <- function(x, ...) {
  stopifnot(
    "x must be a list of flashlight objects" = is.list(x),
    "x must be a list of flashlight objects" =
      vapply(x, is.flashlight, FUN.VALUE = TRUE)
  )

  # Update single flashlights
  out <- lapply(x, flashlight, ...)

  # Set names. Labels are read from the *updated* flashlights so that names
  # and the uniqueness check stay in sync with anything changed via `...`.
  lab <- sapply(out, `[[`, "label")
  if (anyDuplicated(lab)) {
    stop("flashlights must have different 'label'.")
  }
  names(out) <- lab

  # Organize output
  class(out) <- c("multiflashlight", "list")
  light_check(out)
}

#' @describeIn multiflashlight Updates an object of class "multiflashlight".
#' @export
multiflashlight.multiflashlight <- function(x, ...) {
  multiflashlight(lapply(x, flashlight, ...))
}

# --------------------------------------------------------------------------------
# /R/utils.R:
# --------------------------------------------------------------------------------
# Helper functions

# Add vector of classes upfront existing ones
add_classes <- function(x, classes) {
  class(x) <- union(classes, class(x))
  x
}

# Renames one column of a data.frame by passing character strings for old, new
rename_one <- function(x, old, new) {
  colnames(x)[colnames(x) == old] <- new
  x
}

# Organize binning strategy per variable for 2d partial dependence.
# Returns a list named by the two elements of v; each entry holds `breaks`,
# `pd_evaluate_at`, `n_bins`, and `cut_type` for that variable. Length-one
# `n_bins`/`cut_type` are used for both variables.
fix_strategy <- function(v, n_bins, cut_type, breaks, pd_evaluate_at) {
  stopifnot(
    "breaks must be NULL or a named list" =
      is.null(breaks) || is.list(breaks),
    "pd_evaluate_at must be NULL or a named list" =
      is.null(pd_evaluate_at) || is.list(pd_evaluate_at),
    "n_bins should be a numeric vector of length <=2" =
      length(n_bins) <= 2L && is.numeric(n_bins),
    "cut_type should be a character vector of length <=2" =
      length(cut_type) <= 2L && all(cut_type %in% c("equal", "quantile"))
  )
  strategy <- list()
  for (i in 1:2) {
    vv <- v[i]
    strategy[[vv]] <- list(
      breaks = if (vv %in% names(breaks)) breaks[[vv]],
      pd_evaluate_at = if (vv %in% names(pd_evaluate_at)) pd_evaluate_at[[vv]],
      n_bins = n_bins[min(i, length(n_bins))],
      cut_type = cut_type[min(i, length(cut_type))]
    )
  }
  return(strategy)
}

# Applies df-valued FUN to X grouped by BY.
# The result is a tibble if as_tib = TRUE, and a plain data.frame otherwise.
Reframe <- function(X, FUN, .by = NULL, as_tib = TRUE) {
  if (is.null(.by)) {
    out <- FUN(X)
  } else {
    X_grouped <- dplyr::group_by(X, dplyr::across(tidyselect::all_of(.by)))
    out <- dplyr::reframe(X_grouped, FUN(dplyr::pick(dplyr::everything())))
  }
  if (as_tib && !tibble::is_tibble(out)) {
    out <- tibble::as_tibble(out)
  }
  if (!as_tib && tibble::is_tibble(out)) {
    out <- as.data.frame(out)
  }
  out
}

#' all_identical
#'
#' Checks if an aspect is identical for all elements in a nested list.
#' The aspect is specified by `fun`, e.g., `[[`, followed by the element
#' name to compare.
#'
#' @noRd
#' @param x A nested list of objects.
#' @param fun Function used to extract information of each element of `x`.
#' @param ... Further arguments passed to `fun()`.
#' @returns A logical vector of length one.
#' @examples
#' x <- list(a = 1, b = 2)
#' y <- list(a = 1, b = 3)
#' all_identical(list(x, y), `[[`, "a")
#' all_identical(list(x, y), `[[`, "b")
all_identical <- function(x, fun, ...) {
  # Zero or one element: nothing to compare
  if ((m <- length(x)) <= 1L) {
    return(TRUE)
  }
  subs <- lapply(x, fun, ...)
  # Compare every further element with the first one
  all(vapply(subs[2:m], FUN = identical, FUN.VALUE = TRUE, subs[[1L]]))
}

# --------------------------------------------------------------------------------
# /R/utils_cut.R:
# --------------------------------------------------------------------------------
#' Discretizes a Vector
#'
#' This function takes a vector `x` and returns a list with information on
#' discretized version of `x`. The construction of level names can be controlled
#' by passing `...` arguments to [formatC()].
#'
#' @noRd
#' @param x A vector.
#' @param breaks An optional vector of breaks. Only relevant for numeric `x`.
#' @param n_bins If `x` is numeric and no breaks are provided,
#'   this is the maximum number of bins allowed or to be created (approximately).
#' @param cut_type For the default type "equal", bins of equal width are created
#'   by [pretty()]. Choose "quantile" to create quantile bins.
#' @param x_name Column name with the values of `x` in the output.
#' @param level_name Column name with the bin labels of `x` in the output.
#' @param ... Further arguments passed to [cut3()].
#' @returns
#'   A list with the following elements:
#'   - `data`: A `data.frame` with columns `x_name` and
#'     `level_name` each with the same length as `x`.
#'     The column `x_name` has values in output `bin_means`
#'     while the column `level_name` has values in `bin_labels`.
#'   - `breaks`: A vector of increasing and unique breaks used to cut a
#'     numeric `x` with too many distinct levels. `NULL` otherwise.
#'   - `bin_means`: The midpoints of subsequent breaks, or if there are no
#'     `breaks` in the output, factor levels or distinct values of `x`.
#'   - `bin_labels`: Break labels of the form "(low, high]" if there are `breaks`
#'     in the output, otherwise the same as `bin_means`. Same order as `bin_means`.
29 | #' @examples 30 | #' auto_cut(1:10, n_bins = 3) 31 | #' auto_cut(c(NA, 1:10), n_bins = 3) 32 | #' auto_cut(1:10, breaks = 3:4, n_bins = 3) 33 | #' auto_cut(1:10, n_bins = 3, cut_type = "quantile") 34 | #' auto_cut(LETTERS[4:1], n_bins = 2) 35 | #' auto_cut(factor(LETTERS[1:4], LETTERS[4:1]), n_bins = 2) 36 | #' auto_cut(990:1100, n_bins = 3, big.mark = "'", format = "fg") 37 | #' auto_cut(c(0.0001, 0.0002, 0.0003, 0.005), n_bins = 3, format = "fg") 38 | auto_cut <- function(x, breaks = NULL, n_bins = 27L, 39 | cut_type = c("equal", "quantile"), 40 | x_name = "value", level_name = "level", ...) { 41 | cut_type <- match.arg(cut_type) 42 | bin_means <- if (is.factor(x)) levels(x) else sort(unique(x)) 43 | if (!is.numeric(x) || (is.null(breaks) && length(bin_means) <= n_bins)) { 44 | if (!is.numeric(x)) { 45 | breaks <- NULL # ignored for non-numeric 46 | } 47 | data <- data.frame(x, x) 48 | if (anyNA(x)) { 49 | bin_means <- c(bin_means, NA) 50 | } 51 | if (is.factor(x)) { 52 | bin_means <- factor(bin_means, bin_means) 53 | } 54 | bin_labels <- bin_means 55 | } else { 56 | if (is.null(breaks)) { 57 | if (cut_type == "equal") { 58 | breaks <- pretty(x, n = n_bins) 59 | } else { 60 | breaks <- stats::quantile( 61 | x, 62 | probs = seq(0, 1, length.out = n_bins + 1L), 63 | na.rm = TRUE, 64 | names = FALSE, 65 | type = 1 66 | ) 67 | } 68 | } 69 | breaks <- sort(unique(breaks)) 70 | bin_means <- midpoints(breaks) 71 | cuts <- cut3(x, breaks = breaks, include.lowest = TRUE, ...) 
72 | bin_labels <- levels(cuts) 73 | if (anyNA(cuts)) { 74 | bin_labels <- c(bin_labels, NA) 75 | bin_means <- c(bin_means, NA) 76 | } 77 | bin_labels <- factor(bin_labels, levels(cuts)) 78 | stopifnot(length(bin_labels) == length(bin_means)) 79 | int_cuts <- as.integer(cuts) 80 | data <- data.frame((breaks[int_cuts] + breaks[int_cuts + 1L]) / 2, cuts) 81 | } 82 | list( 83 | data = stats::setNames(data, c(x_name, level_name)), 84 | breaks = breaks, 85 | bin_means = bin_means, 86 | bin_labels = bin_labels 87 | ) 88 | } 89 | 90 | #' Modified cut 91 | #' 92 | #' Slightly modified version of [cut.default()]. Both modifications refer 93 | #' to the construction of break labels. Firstly, `...` arguments are passed to 94 | #' [formatC()] in formatting the numbers in the labels. 95 | #' Secondly, a separator between the two numbers can be specified with default ", ". 96 | #' 97 | #' @noRd 98 | #' @param x Numeric vector. 99 | #' @param breaks Numeric vector of cut points or a single number 100 | #' specifying the number of intervals desired. 101 | #' @param labels Labels for the levels of the final categories. 102 | #' @param include.lowest Flag if minimum value should be added to intervals 103 | #' of type "(,]" (or maximum for "[,)"). 104 | #' @param right Flag if intervals should be closed to the right or left. 105 | #' @param dig.lab Number of significant digits passed to [formatC()]. 106 | #' @param ordered_result Flag if resulting output vector should be ordered. 107 | #' @param sep Separator between from-to labels. 108 | #' @param ... Arguments passed to [formatC()]. 109 | #' @returns Vector of the same length as x. 110 | #' @examples 111 | #' x <- 998:1001 112 | #' cut3(x, breaks = 2) 113 | #' cut3(x, breaks = 2, big.mark = "'", sep = ":") 114 | cut3 <- function(x, breaks, labels = NULL, include.lowest = FALSE, right = TRUE, 115 | dig.lab = 3L, ordered_result = FALSE, sep = ", ", ...) 
{ 116 | # Modified version of base::cut.default() 117 | if (!is.numeric(x)) 118 | stop("'x' must be numeric") 119 | if (length(breaks) == 1L) { 120 | if (is.na(breaks) || breaks < 2L) 121 | stop("invalid number of intervals") 122 | nb <- as.integer(breaks + 1) 123 | dx <- diff(rx <- range(x, na.rm = TRUE)) 124 | if (dx == 0) { 125 | dx <- if (rx[1L] != 0) 126 | abs(rx[1L]) 127 | else 1 128 | breaks <- seq.int(rx[1L] - dx / 1000, rx[2L] + dx / 1000, length.out = nb) 129 | } 130 | else { 131 | breaks <- seq.int(rx[1L], rx[2L], length.out = nb) 132 | breaks[c(1L, nb)] <- c(rx[1L] - dx / 1000, rx[2L] + dx / 1000) 133 | } 134 | } 135 | else nb <- length(breaks <- sort.int(as.double(breaks))) 136 | if (anyDuplicated(breaks)) 137 | stop("'breaks' are not unique") 138 | codes.only <- FALSE 139 | if (is.null(labels)) { 140 | for (dig in dig.lab:max(12L, dig.lab)) { 141 | ch.br <- formatC(0L + breaks, digits = dig, width = 1L, ...) 142 | if (ok <- all(ch.br[-1L] != ch.br[-nb])) 143 | break 144 | } 145 | labels <- if (ok) 146 | paste0( 147 | if (right) "(" else "[", ch.br[-nb], sep, ch.br[-1L], if (right) "]" else ")" 148 | ) 149 | else paste0("Range_", seq_len(nb - 1L)) 150 | if (ok && include.lowest) { 151 | if (right) 152 | substr(labels[1L], 1L, 1L) <- "[" 153 | else substring(labels[nb - 1L], nchar(labels[nb - 1L], "c")) <- "]" 154 | } 155 | } 156 | else if (is.logical(labels) && !labels) 157 | codes.only <- TRUE 158 | else if (length(labels) != nb - 1L) 159 | stop("lengths of 'breaks' and 'labels' differ") 160 | code <- .bincode(x, breaks, right, include.lowest) 161 | if (codes.only) 162 | code 163 | else factor(code, seq_along(labels), labels, ordered = ordered_result) 164 | } 165 | 166 | #' Common Breaks for multiflashlight 167 | #' 168 | #' Internal function used to find common breaks from different flashlights. 169 | #' 170 | #' @noRd 171 | #' @param x An object of class "multiflashlight". 172 | #' @param v The variable to be profiled. 
173 | #' @param data A `data.frame`. 174 | #' @param n_bins Maximum number of unique values to evaluate for numeric `v`. 175 | #' @param cut_type Cut type 176 | #' @returns A vector of breaks 177 | common_breaks <- function(x, v, data = NULL, n_bins, cut_type) { 178 | if (is.null(data)) { 179 | # Stack v from all data in flashlights 180 | stopifnot( 181 | all(vapply(x, function(z) nrow(z$data) >= 1L, FUN.VALUE = TRUE)), 182 | all(vapply(x, function(z) v %in% colnames(z$data), FUN.VALUE = TRUE)) 183 | ) 184 | v_vec <- unlist(lapply(x, function(z) z$data[[v]]), use.names = FALSE) 185 | } else { 186 | stopifnot(nrow(data) >= 1L, v %in% colnames(data)) 187 | v_vec <- data[[v]] 188 | } 189 | auto_cut(v_vec, n_bins = n_bins, cut_type = cut_type)$breaks 190 | } 191 | 192 | # Calculates midpoints of subsequent unique breaks 193 | midpoints <- function(breaks) { 194 | # to do: deal with missings 195 | stopifnot(is.numeric(breaks)) 196 | breaks <- sort(unique(breaks)) 197 | stopifnot((m <- length(breaks)) >= 2L) 198 | (breaks[-m] + breaks[-1L]) / 2 199 | } 200 | -------------------------------------------------------------------------------- /R/utils_plot.R: -------------------------------------------------------------------------------- 1 | rotate_x <- function() { 2 | ggplot2::theme( 3 | axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, vjust = 1) 4 | ) 5 | } 6 | 7 | override_alpha <- function() { 8 | ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(alpha = 1))) 9 | } 10 | -------------------------------------------------------------------------------- /R/zzz.R: -------------------------------------------------------------------------------- 1 | # .onLoad <- function(libname, pkgname) { 2 | # op <- options() 3 | # op.flashlight <- list( 4 | # flashlight.after_name = "after", 5 | # ) 6 | # toset <- !(names(op.flashlight) %in% names(op)) 7 | # if (any(toset)) { 8 | # options(op.flashlight[toset]) 9 | # } 10 | # invisible() 11 | # } 12 | 13 | 
utils::globalVariables( 14 | c( 15 | "after_", 16 | "before_", 17 | "cal_xx", 18 | "counts_", 19 | "description_", 20 | "error_", 21 | "fill_", 22 | "high_", 23 | "id_", 24 | "lab_", 25 | "label_", 26 | "low_", 27 | "metric_", 28 | "shift_xx", 29 | "step_", 30 | "temp_", 31 | "tree_", 32 | "type_", 33 | "value_", 34 | "value_i_", 35 | "value_j_", 36 | "value_2_", 37 | "variable_", 38 | "ymax_", 39 | "ymin_", 40 | "x_") 41 | ) 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # {flashlight} 2 | 3 | 4 | 5 | [![R-CMD-check](https://github.com/mayer79/flashlight/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/mayer79/flashlight/actions/workflows/R-CMD-check.yaml) 6 | [![Codecov test coverage](https://codecov.io/gh/mayer79/flashlight/graph/badge.svg)](https://app.codecov.io/gh/mayer79/flashlight) 7 | [![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/flashlight)](https://cran.r-project.org/package=flashlight) 8 | 9 | [![](https://cranlogs.r-pkg.org/badges/flashlight)](https://cran.r-project.org/package=flashlight) 10 | [![](https://cranlogs.r-pkg.org/badges/grand-total/flashlight?color=orange)](https://cran.r-project.org/package=flashlight) 11 | 12 | 13 | 14 | ## Overview 15 | 16 | The goal of this package is to shed light on black box machine learning models. 17 | 18 | The main props of {flashlight}: 19 | 20 | 1. It is simple, yet flexible. 21 | 2. It offers model agnostic tools like model performance, variable importance, global surrogate models, ICE profiles, partial dependence, ALE, and further effects plots, scatter plots, interaction strength, and variable contribution breakdown/SHAP for single observations. 22 | 3. It allows assessing multiple models side-by-side. 23 | 4. It supports "group by" operations. 24 | 5. It works with case weights. 
25 | 26 | Currently, models with numeric or binary response are supported. 27 | 28 | ## Installation 29 | 30 | ```r 31 | # From CRAN 32 | install.packages("flashlight") 33 | 34 | # Development version 35 | devtools::install_github("mayer79/flashlight") 36 | ``` 37 | 38 | ## Usage 39 | 40 | Let's start with an iris example. For simplicity, we do not split the data into training and testing/validation sets. 41 | 42 | ```r 43 | library(ggplot2) 44 | library(MetricsWeighted) 45 | library(flashlight) 46 | 47 | fit_lm <- lm(Sepal.Length ~ ., data = iris) 48 | 49 | # Make explainer object 50 | fl_lm <- flashlight( 51 | model = fit_lm, 52 | data = iris, 53 | y = "Sepal.Length", 54 | label = "lm", 55 | metrics = list(RMSE = rmse, `R-squared` = r_squared) 56 | ) 57 | ``` 58 | 59 | ### Performance 60 | 61 | ```r 62 | fl_lm |> 63 | light_performance() |> 64 | plot(fill = "darkred") + 65 | labs(x = element_blank(), title = "Performance on training data") 66 | 67 | fl_lm |> 68 | light_performance(by = "Species") |> 69 | plot(fill = "darkred") + 70 | ggtitle("Performance split by Species") 71 | ``` 72 | 73 |

74 | Performance 75 | Grouped 76 |

77 | 78 | 79 | ### Permutation importance regarding first metric 80 | 81 | Error bars represent standard errors, i.e., the uncertainty of the estimated importance. 82 | 83 | ```r 84 | fl_lm |> 85 | light_importance(m_repetitions = 4) |> 86 | plot(fill = "darkred") + 87 | labs(title = "Permutation importance", y = "Increase in RMSE") 88 | ``` 89 | 90 | ![](man/figures/imp.svg) 91 | 92 | ### ICE curves for `Sepal.Width` 93 | 94 | ```r 95 | fl_lm |> 96 | light_ice("Sepal.Width", n_max = 200) |> 97 | plot(alpha = 0.3, color = "chartreuse4") + 98 | labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction") 99 | 100 | fl_lm |> 101 | light_ice("Sepal.Width", n_max = 200, center = "middle") |> 102 | plot(alpha = 0.3, color = "chartreuse4") + 103 | labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)") 104 | ``` 105 | 106 |

107 | Performance 108 | Grouped 109 |

110 | 111 | ### PDPs 112 | 113 | ```r 114 | fl_lm |> 115 | light_profile("Sepal.Width", n_bins = 40) |> 116 | plot() + 117 | ggtitle("PDP for 'Sepal.Width'") 118 | 119 | fl_lm |> 120 | light_profile("Sepal.Width", n_bins = 40, by = "Species") |> 121 | plot() + 122 | ggtitle("Same grouped by 'Species'") 123 | ``` 124 | 125 |

126 | Performance 127 | Grouped 128 |

129 | 130 | ### 2D PDP 131 | 132 | ```r 133 | fl_lm |> 134 | light_profile2d(c("Petal.Width", "Petal.Length")) |> 135 | plot() 136 | ``` 137 | 138 | ![](man/figures/pdp2d.svg) 139 | 140 | ### ALE 141 | 142 | ```r 143 | fl_lm |> 144 | light_profile("Sepal.Width", type = "ale") |> 145 | plot() + 146 | ggtitle("ALE plot for 'Sepal.Width'") 147 | ``` 148 | 149 | ![](man/figures/ale.svg) 150 | 151 | ### Different profile plots in one 152 | 153 | ```r 154 | fl_lm |> 155 | light_effects("Sepal.Width") |> 156 | plot(use = "all") + 157 | ggtitle("Different types of profiles for 'Sepal.Width'") 158 | ``` 159 | 160 | ![](man/figures/effects.svg) 161 | 162 | ### Variable contribution breakdown for single observation 163 | 164 | ```r 165 | fl_lm |> 166 | light_breakdown(new_obs = iris[1, ]) |> 167 | plot() 168 | ``` 169 | 170 | ![](man/figures/breakdown.svg) 171 | 172 | ### Global surrogate tree 173 | 174 | ```r 175 | fl_lm |> 176 | light_global_surrogate() |> 177 | plot() 178 | ``` 179 | 180 | ![](man/figures/surrogate.svg) 181 | 182 | ### Multiple models 183 | 184 | Multiple flashlights can be combined to a multiflashlight. 185 | 186 | ```r 187 | library(rpart) 188 | 189 | fit_tree <- rpart( 190 | Sepal.Length ~ ., 191 | data = iris, 192 | control = list(cp = 0, xval = 0, maxdepth = 5) 193 | ) 194 | 195 | # Make explainer object 196 | fl_tree <- flashlight( 197 | model = fit_tree, 198 | data = iris, 199 | y = "Sepal.Length", 200 | label = "tree", 201 | metrics = list(RMSE = rmse, `R-squared` = r_squared) 202 | ) 203 | 204 | # Combine with other explainer 205 | fls <- multiflashlight(list(fl_tree, fl_lm)) 206 | 207 | fls |> 208 | light_performance() |> 209 | plot(fill = "chartreuse4") + 210 | labs(x = "Model", title = "Performance") 211 | 212 | fls |> 213 | light_profile("Petal.Length", n_bins = 40, by = "Species") |> 214 | plot() + 215 | ggtitle("PDP by Species") 216 | ``` 217 | 218 |

219 | Performance 220 | Grouped 221 |

222 | 223 | ## More 224 | 225 | Check out the vignette for more information and important references. 226 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | # flashlight 0.9.0 2 | 3 | Hello CRAN 4 | 5 | This is a relatively large maintenance update: 6 | 7 | - dealing with dplyr/ggplot2 deprecation cycles 8 | - more modern help files 9 | - smaller vignettes 10 | - smaller dependency footprint 11 | - announcing upcoming deprecation changes 12 | 13 | Thanks for keeping CRAN clean. 14 | 15 | ## R CMD check 16 | 17 | Warning: 'qpdf' is needed for checks on size reduction of PDFs 18 | 19 | Note: Skipping checking HTML validation: no command 'tidy' found 20 | 21 | ## Windevel 22 | 23 | Status: OK 24 | 25 | ## Rhub NOTES 26 | 27 | * checking HTML version of manual ... NOTE 28 | Skipping checking HTML validation: no command 'tidy' found 29 | Skipping checking math rendering: package 'V8' unavailable 30 | 31 | 32 | ## Reverse dependencies 33 | 34 | OK: 1 35 | BROKEN: 0 36 | 37 | -------------------------------------------------------------------------------- /flashlight.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | ProjectId: bfd7b514-c956-4808-a6fb-3bb235b6f135 3 | 4 | RestoreWorkspace: No 5 | SaveWorkspace: No 6 | AlwaysSaveHistory: Default 7 | 8 | EnableCodeIndexing: Yes 9 | UseSpacesForTab: Yes 10 | NumSpacesForTab: 2 11 | Encoding: UTF-8 12 | 13 | RnwWeave: Sweave 14 | LaTeX: pdfLaTeX 15 | 16 | AutoAppendNewline: Yes 17 | StripTrailingWhitespace: Yes 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/flashlight/bea46c8bed2587a45a5b82ec718ec5d691d136c6/logo.png -------------------------------------------------------------------------------- /man/add_shap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/aa_deprecated.R 3 | \name{add_shap} 4 | \alias{add_shap} 5 | \title{DEPRECATED} 6 | \usage{ 7 | add_shap(...) 8 | } 9 | \arguments{ 10 | \item{...}{Deprecated} 11 | } 12 | \value{ 13 | Error message. 14 | } 15 | \description{ 16 | Deprecated in favor of {kernelshap}/{fastshap}. 17 | } 18 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/flashlight/bea46c8bed2587a45a5b82ec718ec5d691d136c6/man/figures/logo.png -------------------------------------------------------------------------------- /man/flashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/flashlight.R 3 | \name{flashlight} 4 | \alias{flashlight} 5 | \alias{flashlight.default} 6 | \alias{flashlight.flashlight} 7 | \title{Create or Update a flashlight} 8 | \usage{ 9 | flashlight(x, ...) 10 | 11 | \method{flashlight}{default}( 12 | x, 13 | model = NULL, 14 | data = NULL, 15 | y = NULL, 16 | predict_function = stats::predict, 17 | linkinv = function(z) z, 18 | w = NULL, 19 | by = NULL, 20 | metrics = list(rmse = MetricsWeighted::rmse), 21 | label = NULL, 22 | shap = NULL, 23 | ... 24 | ) 25 | 26 | \method{flashlight}{flashlight}(x, check = TRUE, ...) 27 | } 28 | \arguments{ 29 | \item{x}{An object of class "flashlight". 
If not provided, a new flashlight is 30 | created based on further input. Otherwise, \code{x} is updated based on further input.} 31 | 32 | \item{...}{Arguments passed from or to other functions.} 33 | 34 | \item{model}{A fitted model of any type. Most models require a customized 35 | \code{predict_function}.} 36 | 37 | \item{data}{A \code{data.frame} or \code{tibble} used as basis for calculations.} 38 | 39 | \item{y}{Variable name of response.} 40 | 41 | \item{predict_function}{A real valued function with two arguments: 42 | A model and a data of the same structure as \code{data}. 43 | Only the order of the two arguments matter, not their names.} 44 | 45 | \item{linkinv}{An inverse transformation function applied after \code{predict_function}.} 46 | 47 | \item{w}{A variable name of case weights.} 48 | 49 | \item{by}{A character vector with names of grouping variables.} 50 | 51 | \item{metrics}{A named list of metrics. Here, a metric is a function with exactly 52 | four arguments: actual, predicted, w (case weights) and \code{...} 53 | like those in package {MetricsWeighted}.} 54 | 55 | \item{label}{Name of the flashlight. Required.} 56 | 57 | \item{shap}{An optional shap object. Typically added by calling \code{\link[=add_shap]{add_shap()}}.} 58 | 59 | \item{check}{When updating the flashlight: Should internal checks be performed? 60 | Default is \code{TRUE}.} 61 | } 62 | \value{ 63 | An object of class "flashlight" (and \code{list}) containing each 64 | input (except \code{x}) as element. 65 | } 66 | \description{ 67 | Creates or updates a "flashlight" object. If a flashlight is to be created, 68 | all arguments are optional except \code{label}. If a flashlight is to be updated, 69 | all arguments are optional up to \code{x} (the flashlight to be updated). 70 | } 71 | \section{Methods (by class)}{ 72 | \itemize{ 73 | \item \code{flashlight(default)}: Used to create a flashlight object. 74 | No \code{x} has to be passed in this case. 
75 | 76 | \item \code{flashlight(flashlight)}: Used to update an existing flashlight object. 77 | 78 | }} 79 | \examples{ 80 | fit <- lm(Sepal.Length ~ ., data = iris) 81 | (fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")) 82 | (fl_updated <- flashlight(fl, linkinv = exp)) 83 | } 84 | \seealso{ 85 | \code{\link[=multiflashlight]{multiflashlight()}} 86 | } 87 | -------------------------------------------------------------------------------- /man/is.flashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/is_flashlight.R 3 | \name{is.flashlight} 4 | \alias{is.flashlight} 5 | \alias{is.multiflashlight} 6 | \alias{is.light} 7 | \alias{is.light_performance} 8 | \alias{is.light_performance_multi} 9 | \alias{is.light_importance} 10 | \alias{is.light_importance_multi} 11 | \alias{is.light_breakdown} 12 | \alias{is.light_breakdown_multi} 13 | \alias{is.light_ice} 14 | \alias{is.light_ice_multi} 15 | \alias{is.light_profile} 16 | \alias{is.light_profile_multi} 17 | \alias{is.light_profile2d} 18 | \alias{is.light_profile2d_multi} 19 | \alias{is.light_effects} 20 | \alias{is.light_effects_multi} 21 | \alias{is.shap} 22 | \alias{is.light_scatter} 23 | \alias{is.light_scatter_multi} 24 | \alias{is.light_global_surrogate} 25 | \alias{is.light_global_surrogate_multi} 26 | \title{Check functions for flashlight Classes} 27 | \usage{ 28 | is.flashlight(x) 29 | 30 | is.multiflashlight(x) 31 | 32 | is.light(x) 33 | 34 | is.light_performance(x) 35 | 36 | is.light_performance_multi(x) 37 | 38 | is.light_importance(x) 39 | 40 | is.light_importance_multi(x) 41 | 42 | is.light_breakdown(x) 43 | 44 | is.light_breakdown_multi(x) 45 | 46 | is.light_ice(x) 47 | 48 | is.light_ice_multi(x) 49 | 50 | is.light_profile(x) 51 | 52 | is.light_profile_multi(x) 53 | 54 | is.light_profile2d(x) 55 | 56 | is.light_profile2d_multi(x) 57 | 58 | 
is.light_effects(x) 59 | 60 | is.light_effects_multi(x) 61 | 62 | is.shap(x) 63 | 64 | is.light_scatter(x) 65 | 66 | is.light_scatter_multi(x) 67 | 68 | is.light_global_surrogate(x) 69 | 70 | is.light_global_surrogate_multi(x) 71 | } 72 | \arguments{ 73 | \item{x}{Any object.} 74 | } 75 | \value{ 76 | A logical vector of length one. 77 | } 78 | \description{ 79 | Checks if an object inherits specific class relevant for the flashlight package. 80 | } 81 | \section{Functions}{ 82 | \itemize{ 83 | \item \code{is.multiflashlight()}: Check for multiflashlight object. 84 | 85 | \item \code{is.light()}: Check for light object. 86 | 87 | \item \code{is.light_performance()}: Check for light_performance object. 88 | 89 | \item \code{is.light_performance_multi()}: Check for light_performance_multi object. 90 | 91 | \item \code{is.light_importance()}: Check for light_importance object. 92 | 93 | \item \code{is.light_importance_multi()}: Check for light_importance_multi object. 94 | 95 | \item \code{is.light_breakdown()}: Check for light_breakdown object. 96 | 97 | \item \code{is.light_breakdown_multi()}: Check for light_breakdown_multi object. 98 | 99 | \item \code{is.light_ice()}: Check for light_ice object. 100 | 101 | \item \code{is.light_ice_multi()}: Check for light_ice_multi object. 102 | 103 | \item \code{is.light_profile()}: Check for light_profile object. 104 | 105 | \item \code{is.light_profile_multi()}: Check for light_profile_multi object. 106 | 107 | \item \code{is.light_profile2d()}: Check for light_profile2d object. 108 | 109 | \item \code{is.light_profile2d_multi()}: Check for light_profile2d_multi object. 110 | 111 | \item \code{is.light_effects()}: Check for light_effects object. 112 | 113 | \item \code{is.light_effects_multi()}: Check for light_effects_multi object. 114 | 115 | \item \code{is.shap()}: Check for shap object. 116 | 117 | \item \code{is.light_scatter()}: Check for light_scatter object. 
118 | 119 | \item \code{is.light_scatter_multi()}: Check for light_scatter_multi object. 120 | 121 | \item \code{is.light_global_surrogate()}: Check for light_global_surrogate object. 122 | 123 | \item \code{is.light_global_surrogate_multi()}: Check for light_global_surrogate_multi object. 124 | 125 | }} 126 | \examples{ 127 | a <- flashlight(label = "a") 128 | is.flashlight(a) 129 | is.flashlight("a") 130 | } 131 | -------------------------------------------------------------------------------- /man/light_breakdown.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_breakdown.R 3 | \name{light_breakdown} 4 | \alias{light_breakdown} 5 | \alias{light_breakdown.default} 6 | \alias{light_breakdown.flashlight} 7 | \alias{light_breakdown.multiflashlight} 8 | \title{Variable Contribution Breakdown for Single Observation} 9 | \usage{ 10 | light_breakdown(x, ...) 11 | 12 | \method{light_breakdown}{default}(x, ...) 13 | 14 | \method{light_breakdown}{flashlight}( 15 | x, 16 | new_obs, 17 | data = x$data, 18 | by = x$by, 19 | v = NULL, 20 | visit_strategy = c("importance", "permutation", "v"), 21 | n_max = Inf, 22 | n_perm = 20, 23 | seed = NULL, 24 | use_linkinv = FALSE, 25 | description = TRUE, 26 | digits = 2, 27 | ... 28 | ) 29 | 30 | \method{light_breakdown}{multiflashlight}(x, ...) 31 | } 32 | \arguments{ 33 | \item{x}{An object of class "flashlight" or "multiflashlight".} 34 | 35 | \item{...}{Further arguments passed to \code{\link[=prettyNum]{prettyNum()}} to format numbers 36 | in description text.} 37 | 38 | \item{new_obs}{One single new observation to calculate variable attribution for. 
39 | Needs to be a \code{data.frame} of same structure as \code{data}.} 40 | 41 | \item{data}{An optional \code{data.frame}.} 42 | 43 | \item{by}{An optional vector of column names used to filter \code{data} 44 | for rows with equal values in "by" variables as \code{new_obs}.} 45 | 46 | \item{v}{Vector of variable names to assess contribution for. 47 | Defaults to all except those specified by "y", "w" and "by".} 48 | 49 | \item{visit_strategy}{In what sequence should variables be visited? 50 | By "importance", by \code{n_perm} "permutation" or as "v" (see Details).} 51 | 52 | \item{n_max}{Maximum number of rows in \code{data} to consider in the reference data. 53 | Set to lower value if \code{data} is large.} 54 | 55 | \item{n_perm}{Number of permutations of random visit sequences. 56 | Only used if \code{visit_strategy = "permutation"}.} 57 | 58 | \item{seed}{An integer random seed used to shuffle rows if \code{n_max} 59 | is smaller than the number of rows in \code{data}.} 60 | 61 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{FALSE}.} 62 | 63 | \item{description}{Should descriptions be added? Default is \code{TRUE}.} 64 | 65 | \item{digits}{Passed to \code{\link[=prettyNum]{prettyNum()}} to format numbers in description text.} 66 | } 67 | \value{ 68 | An object of class "light_breakdown" with the following elements: 69 | \itemize{ 70 | \item \code{data} A tibble with results. 71 | \item \code{by} Same as input \code{by}. 72 | } 73 | } 74 | \description{ 75 | Calculates sequential additive variable contributions (approximate SHAP) to 76 | the prediction of a single observation, see Gosiewska and Biecek (see reference) 77 | and the details below. 78 | } 79 | \details{ 80 | The breakdown algorithm works as follows: First, the visit order 81 | \eqn{(x_1, ..., x_m)} of the variables \code{v} is specified. 
82 | Then, in the query \code{data}, the column \eqn{x_1} is set to the value of \eqn{x_1} 83 | of the single observation \code{new_obs} to be explained. 84 | The change in the (weighted) average prediction on \code{data} measures the 85 | contribution of \eqn{x_1} on the prediction of \code{new_obs}. 86 | This procedure is iterated over all \eqn{x_i} until eventually, all rows 87 | in \code{data} are identical to \code{new_obs}. 88 | 89 | A complication with this approach is that the visit order is relevant, 90 | at least for non-additive models. Ideally, the algorithm could be repeated 91 | for all possible permutations of \code{v} and its results averaged per variable. 92 | This is basically what SHAP values do, see the reference below for an explanation. 93 | Unfortunately, there is no efficient way to do this in a model agnostic way. 94 | 95 | We offer two visit strategies to approximate SHAP: 96 | \enumerate{ 97 | \item "importance": Using the short-cut described in the reference below: 98 | The variables are sorted by the size of their contribution in the same way as the 99 | breakdown algorithm but without iteration, i.e., starting from the original query 100 | data for each variable \eqn{x_i}. 101 | \item "permutation": Averages contributions from a small number of random permutations 102 | of \code{v}. 103 | } 104 | 105 | Note that the minimum required elements in the (multi-)flashlight are a 106 | "predict_function", "model", and "data". The latter can also directly be passed to 107 | \code{\link[=light_breakdown]{light_breakdown()}}. Note that by default, no retransformation function is applied. 108 | } 109 | \section{Methods (by class)}{ 110 | \itemize{ 111 | \item \code{light_breakdown(default)}: Default method not implemented yet. 112 | 113 | \item \code{light_breakdown(flashlight)}: Variable attribution to single observation 114 | for a flashlight. 
115 | 116 | \item \code{light_breakdown(multiflashlight)}: Variable attribution to single observation 117 | for a multiflashlight. 118 | 119 | }} 120 | \examples{ 121 | fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris) 122 | fl_part <- flashlight( 123 | model = fit_part, label = "part", data = iris, y = "Sepal.Length" 124 | ) 125 | plot(light_breakdown(fl_part, new_obs = iris[1, ])) 126 | 127 | # Second model 128 | fit_full <- lm(Sepal.Length ~ ., data = iris) 129 | fl_full <- flashlight( 130 | model = fit_full, label = "full", data = iris, y = "Sepal.Length" 131 | ) 132 | fls <- multiflashlight(list(fl_part, fl_full)) 133 | plot(light_breakdown(fls, new_obs = iris[1, ])) 134 | } 135 | \references{ 136 | A. Gosiewska and P. Biecek (2019). IBREAKDOWN: Uncertainty of model explanations 137 | for non-additive predictive models. ArXiv. 138 | } 139 | \seealso{ 140 | \code{\link[=plot.light_breakdown]{plot.light_breakdown()}} 141 | } 142 | -------------------------------------------------------------------------------- /man/light_check.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_check.R 3 | \name{light_check} 4 | \alias{light_check} 5 | \alias{light_check.default} 6 | \alias{light_check.flashlight} 7 | \alias{light_check.multiflashlight} 8 | \title{Check flashlight} 9 | \usage{ 10 | light_check(x, ...) 11 | 12 | \method{light_check}{default}(x, ...) 13 | 14 | \method{light_check}{flashlight}(x, ...) 15 | 16 | \method{light_check}{multiflashlight}(x, ...) 17 | } 18 | \arguments{ 19 | \item{x}{An object of class "flashlight" or "multiflashlight".} 20 | 21 | \item{...}{Further arguments passed from or to other methods.} 22 | } 23 | \value{ 24 | The input \code{x} or an error message. 25 | } 26 | \description{ 27 | Checks if an object of class "flashlight" or "multiflashlight" 28 | is consistently defined. 
29 | } 30 | \section{Methods (by class)}{ 31 | \itemize{ 32 | \item \code{light_check(default)}: Default check method not implemented yet. 33 | 34 | \item \code{light_check(flashlight)}: Checks if a flashlight object is consistently defined. 35 | 36 | \item \code{light_check(multiflashlight)}: Checks if a multiflashlight object is consistently defined. 37 | 38 | }} 39 | \examples{ 40 | fit <- lm(Sepal.Length ~ ., data = iris) 41 | fit_log <- lm(log(Sepal.Length) ~ ., data = iris) 42 | fl <- flashlight(fit, data = iris, y = "Sepal.Length", label = "ols") 43 | fl_log <- flashlight(fit_log, y = "Sepal.Length", label = "ols", linkinv = exp) 44 | light_check(fl) 45 | light_check(fl_log) 46 | } 47 | -------------------------------------------------------------------------------- /man/light_combine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_combine.R 3 | \name{light_combine} 4 | \alias{light_combine} 5 | \alias{light_combine.default} 6 | \alias{light_combine.light} 7 | \alias{light_combine.list} 8 | \title{Combine Objects} 9 | \usage{ 10 | light_combine(x, ...) 11 | 12 | \method{light_combine}{default}(x, ...) 13 | 14 | \method{light_combine}{light}(x, new_class = NULL, ...) 15 | 16 | \method{light_combine}{list}(x, new_class = NULL, ...) 17 | } 18 | \arguments{ 19 | \item{x}{A list of objects of the same class.} 20 | 21 | \item{...}{Further arguments passed from or to other methods.} 22 | 23 | \item{new_class}{An optional vector with additional class names to be added 24 | to the output.} 25 | } 26 | \value{ 27 | If \code{x} is a list, an object like each element but with unioned rows 28 | in data slots. 29 | } 30 | \description{ 31 | Combines a list of similar objects each of class "light" by row binding 32 | \code{data.frame} slots and retaining the other slots from the first list element. 
33 | } 34 | \section{Methods (by class)}{ 35 | \itemize{ 36 | \item \code{light_combine(default)}: Default method not implemented yet. 37 | 38 | \item \code{light_combine(light)}: Since there is nothing to combine, the input is returned 39 | except for additional classes. 40 | 41 | \item \code{light_combine(list)}: Combine a list of similar light objects. 42 | 43 | }} 44 | \examples{ 45 | fit_lm <- lm(Sepal.Length ~ ., data = iris) 46 | fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = "log"), data = iris) 47 | mod_lm <- flashlight(model = fit_lm, label = "lm", data = iris, y = "Sepal.Length") 48 | mod_glm <- flashlight( 49 | model = fit_glm, 50 | label = "glm", 51 | data = iris, 52 | y = "Sepal.Length", 53 | predict_function = function(object, newdata) 54 | predict(object, newdata, type = "response") 55 | ) 56 | mods <- multiflashlight(list(mod_lm, mod_glm)) 57 | perf_lm <- light_performance(mod_lm) 58 | perf_glm <- light_performance(mod_glm) 59 | manual_comb <- light_combine( 60 | list(perf_lm, perf_glm), 61 | new_class = "light_performance_multi" 62 | ) 63 | auto_comb <- light_performance(mods) 64 | all.equal(manual_comb, auto_comb) 65 | } 66 | -------------------------------------------------------------------------------- /man/light_effects.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_effects.R 3 | \name{light_effects} 4 | \alias{light_effects} 5 | \alias{light_effects.default} 6 | \alias{light_effects.flashlight} 7 | \alias{light_effects.multiflashlight} 8 | \title{Combination of Response, Predicted, Partial Dependence, and ALE profiles.} 9 | \usage{ 10 | light_effects(x, ...) 11 | 12 | \method{light_effects}{default}(x, ...) 
13 | 14 | \method{light_effects}{flashlight}( 15 | x, 16 | v, 17 | data = NULL, 18 | by = x$by, 19 | stats = "mean", 20 | breaks = NULL, 21 | n_bins = 11L, 22 | cut_type = c("equal", "quantile"), 23 | use_linkinv = TRUE, 24 | counts_weighted = FALSE, 25 | v_labels = TRUE, 26 | pred = NULL, 27 | pd_indices = NULL, 28 | pd_n_max = 1000L, 29 | pd_seed = NULL, 30 | ale_two_sided = TRUE, 31 | ... 32 | ) 33 | 34 | \method{light_effects}{multiflashlight}( 35 | x, 36 | v, 37 | data = NULL, 38 | breaks = NULL, 39 | n_bins = 11L, 40 | cut_type = c("equal", "quantile"), 41 | ... 42 | ) 43 | } 44 | \arguments{ 45 | \item{x}{An object of class "flashlight" or "multiflashlight".} 46 | 47 | \item{...}{Further arguments passed to \code{\link[=formatC]{formatC()}} in forming the 48 | cut breaks of the \code{v} variable.} 49 | 50 | \item{v}{The variable name to be profiled.} 51 | 52 | \item{data}{An optional \code{data.frame}.} 53 | 54 | \item{by}{An optional vector of column names used to additionally group the results.} 55 | 56 | \item{stats}{Deprecated. Will be removed in version 1.1.0.} 57 | 58 | \item{breaks}{Cut breaks for a numeric \code{v}. Used to overwrite automatic binning via 59 | \code{n_bins} and \code{cut_type}. Ignored if \code{v} is not numeric.} 60 | 61 | \item{n_bins}{Approximate number of unique values to evaluate for numeric \code{v}. 62 | Ignored if \code{v} is not numeric or if \code{breaks} is specified.} 63 | 64 | \item{cut_type}{Should a numeric \code{v} be cut into "equal" or "quantile" bins? 65 | Ignored if \code{v} is not numeric or if \code{breaks} is specified.} 66 | 67 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 68 | 69 | \item{counts_weighted}{Should counts be weighted by the case weights? 70 | If \code{TRUE}, the sum of \code{w} is returned by group.} 71 | 72 | \item{v_labels}{If \code{FALSE}, return group centers of \code{v} instead of labels. 
73 | Only relevant if \code{v} is numeric with many distinct values. 74 | In that case useful for instance when different flashlights use different data sets.} 75 | 76 | \item{pred}{Optional vector with predictions (after application of inverse link). 77 | Can be used to avoid recalculation of predictions over and over if the function 78 | is to be repeatedly called for different \code{v} and predictions are computationally 79 | expensive to make. Not implemented for multiflashlight.} 80 | 81 | \item{pd_indices}{A vector of row numbers to consider in calculating 82 | partial dependence profiles and "ale".} 83 | 84 | \item{pd_n_max}{Maximum number of ICE profiles to calculate (will be randomly 85 | picked from \code{data}) for partial dependence and ALE.} 86 | 87 | \item{pd_seed}{Integer random seed used to select ICE profiles for partial dependence 88 | and ALE.} 89 | 90 | \item{ale_two_sided}{If \code{TRUE}, \code{v} is continuous and \code{breaks} 91 | are passed or being calculated, then two-sided derivatives are calculated 92 | for ALE instead of left derivatives. More specifically: Usually, local effects 93 | at value x are calculated using points in \eqn{[x-e, x]}. 94 | Set \code{ale_two_sided = TRUE} to use points in \eqn{[x-e/2, x+e/2]}.} 95 | } 96 | \value{ 97 | An object of class "light_effects" with the following elements: 98 | \itemize{ 99 | \item \code{response}: A tibble containing the response profiles. 100 | Column names can be controlled by \code{options(flashlight.column_name)}. 101 | \item \code{predicted}: A tibble containing the prediction profiles. 102 | \item \code{pd}: A tibble containing the partial dependence profiles. 103 | \item \code{ale}: A tibble containing the ALE profiles. 104 | \item \code{by}: Same as input \code{by}. 105 | \item \code{v}: The variable(s) evaluated. 
106 | } 107 | } 108 | \description{ 109 | Calculates response-, prediction-, partial dependence, and ALE profiles of a 110 | (multi-)flashlight with respect to a covariable \code{v}. 111 | } 112 | \details{ 113 | Note that ALE profiles are being calibrated by (weighted) average predictions. 114 | The resulting level might be quite different from the one of the partial 115 | dependence profiles. 116 | } 117 | \section{Methods (by class)}{ 118 | \itemize{ 119 | \item \code{light_effects(default)}: Default method. 120 | 121 | \item \code{light_effects(flashlight)}: Profiles for a flashlight object. 122 | 123 | \item \code{light_effects(multiflashlight)}: Effect profiles for a multiflashlight object. 124 | 125 | }} 126 | \examples{ 127 | fit_lin <- lm(Sepal.Length ~ ., data = iris) 128 | fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length") 129 | 130 | # PDP, average response, average predicted by Species 131 | eff <- light_effects(fl_lin, v = "Petal.Length") 132 | plot(eff) 133 | 134 | # PDP and ALE 135 | plot(eff, use = c("pd", "ale"), recode_labels = c(ale = "ALE")) 136 | 137 | # Second model with non-linear Petal.Length effect 138 | fit_nonlin <- lm(Sepal.Length ~ . 
+ I(Petal.Length^2), data = iris) 139 | fl_nonlin <- flashlight( 140 | model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length" 141 | ) 142 | fls <- multiflashlight(list(fl_lin, fl_nonlin)) 143 | 144 | # PDP and ALE 145 | plot(light_effects(fls, v = "Petal.Length"), use = c("pd", "ale")) 146 | } 147 | \seealso{ 148 | \code{\link[=light_profile]{light_profile()}}, \code{\link[=plot.light_effects]{plot.light_effects()}} 149 | } 150 | -------------------------------------------------------------------------------- /man/light_global_surrogate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_global_surrogate.R 3 | \name{light_global_surrogate} 4 | \alias{light_global_surrogate} 5 | \alias{light_global_surrogate.default} 6 | \alias{light_global_surrogate.flashlight} 7 | \alias{light_global_surrogate.multiflashlight} 8 | \title{Global Surrogate Tree} 9 | \usage{ 10 | light_global_surrogate(x, ...) 11 | 12 | \method{light_global_surrogate}{default}(x, ...) 13 | 14 | \method{light_global_surrogate}{flashlight}( 15 | x, 16 | data = x$data, 17 | by = x$by, 18 | v = NULL, 19 | use_linkinv = TRUE, 20 | n_max = Inf, 21 | seed = NULL, 22 | keep_max_levels = 4L, 23 | ... 24 | ) 25 | 26 | \method{light_global_surrogate}{multiflashlight}(x, ...) 27 | } 28 | \arguments{ 29 | \item{x}{An object of class "flashlight" or "multiflashlight".} 30 | 31 | \item{...}{Arguments passed to \code{\link[rpart:rpart]{rpart::rpart()}}, such as \code{maxdepth}.} 32 | 33 | \item{data}{An optional \code{data.frame}.} 34 | 35 | \item{by}{An optional vector of column names used to additionally group the results. 36 | For each group, a separate tree is grown.} 37 | 38 | \item{v}{Vector of variables used in the surrogate model. 
39 | Defaults to all variables in \code{data} except "by", "w" and "y".} 40 | 41 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 42 | 43 | \item{n_max}{Maximum number of data rows to consider to build the tree.} 44 | 45 | \item{seed}{An integer random seed used to select data rows if \code{n_max} is lower than 46 | the number of data rows.} 47 | 48 | \item{keep_max_levels}{Number of levels of categorical and factor variables to keep. 49 | Other levels are combined to a level "Other". This prevents \code{\link[rpart:rpart]{rpart::rpart()}} from 50 | taking too long to split non-numeric variables with many levels.} 51 | } 52 | \value{ 53 | An object of class "light_global_surrogate" with the following elements: 54 | \itemize{ 55 | \item \code{data} A tibble with results. 56 | \item \code{by} Same as input \code{by}. 57 | } 58 | } 59 | \description{ 60 | Model predictions are modelled by a single decision tree, serving as an easy 61 | to interpret surrogate to the original model. 62 | As suggested in Molnar (see reference below), the quality of the surrogate 63 | tree can be measured by its R-squared. The size of the tree can be modified 64 | by passing \code{...} arguments to \code{\link[rpart:rpart]{rpart::rpart()}}. 65 | } 66 | \section{Methods (by class)}{ 67 | \itemize{ 68 | \item \code{light_global_surrogate(default)}: Default method not implemented yet. 69 | 70 | \item \code{light_global_surrogate(flashlight)}: Surrogate model for a flashlight. 71 | 72 | \item \code{light_global_surrogate(multiflashlight)}: Surrogate model for a multiflashlight. 73 | 74 | }} 75 | \examples{ 76 | fit <- lm(Sepal.Length ~ ., data = iris) 77 | x <- flashlight(model = fit, label = "lm", data = iris) 78 | sur <- light_global_surrogate(x) 79 | sur$data$r_squared 80 | plot(sur) 81 | } 82 | \references{ 83 | Molnar C. (2019). Interpretable Machine Learning. 
84 | } 85 | \seealso{ 86 | \code{\link[=plot.light_global_surrogate]{plot.light_global_surrogate()}} 87 | } 88 | -------------------------------------------------------------------------------- /man/light_ice.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_ice.R 3 | \name{light_ice} 4 | \alias{light_ice} 5 | \alias{light_ice.default} 6 | \alias{light_ice.flashlight} 7 | \alias{light_ice.multiflashlight} 8 | \title{Individual Conditional Expectation (ICE)} 9 | \usage{ 10 | light_ice(x, ...) 11 | 12 | \method{light_ice}{default}(x, ...) 13 | 14 | \method{light_ice}{flashlight}( 15 | x, 16 | v = NULL, 17 | data = x$data, 18 | by = x$by, 19 | evaluate_at = NULL, 20 | breaks = NULL, 21 | grid = NULL, 22 | n_bins = 27L, 23 | cut_type = c("equal", "quantile"), 24 | indices = NULL, 25 | n_max = 20L, 26 | seed = NULL, 27 | use_linkinv = TRUE, 28 | center = c("no", "first", "middle", "last", "mean", "0"), 29 | ... 30 | ) 31 | 32 | \method{light_ice}{multiflashlight}(x, ...) 33 | } 34 | \arguments{ 35 | \item{x}{An object of class "flashlight" or "multiflashlight".} 36 | 37 | \item{...}{Further arguments passed to or from other methods.} 38 | 39 | \item{v}{The variable name to be profiled.} 40 | 41 | \item{data}{An optional \code{data.frame}.} 42 | 43 | \item{by}{An optional vector of column names used to additionally group the results.} 44 | 45 | \item{evaluate_at}{Vector with values of \code{v} used to evaluate the profile.} 46 | 47 | \item{breaks}{Cut breaks for a numeric \code{v}. Used to overwrite automatic 48 | binning via \code{n_bins} and \code{cut_type}. Ignored if \code{v} is not numeric or if \code{grid} 49 | or \code{evaluate_at} are specified.} 50 | 51 | \item{grid}{A \code{data.frame} with evaluation grid. 
For instance, can be generated by 52 | \code{\link[=expand.grid]{expand.grid()}}.} 53 | 54 | \item{n_bins}{Approximate number of unique values to evaluate for numeric \code{v}. 55 | Ignored if \code{v} is not numeric or if \code{breaks}, \code{grid} or \code{evaluate_at} are specified.} 56 | 57 | \item{cut_type}{Should a numeric \code{v} be cut into "equal" or "quantile" bins? 58 | Ignored if \code{v} is not numeric or if \code{breaks}, \code{grid} or \code{evaluate_at} are specified.} 59 | 60 | \item{indices}{A vector of row numbers to consider.} 61 | 62 | \item{n_max}{If \code{indices} is not given, maximum number of rows to consider. 63 | Will be randomly picked from \code{data} if necessary.} 64 | 65 | \item{seed}{An integer random seed.} 66 | 67 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 68 | 69 | \item{center}{How should curves be centered? 70 | \itemize{ 71 | \item Default is "no". 72 | \item Choose "first", "middle", or "last" to 0-center at specific evaluation points. 73 | \item Choose "mean" to center all profiles at the within-group means. 74 | \item Choose "0" to mean-center curves at 0. 75 | }} 76 | } 77 | \value{ 78 | An object of class "light_ice" with the following elements: 79 | \itemize{ 80 | \item \code{data} A tibble containing the results. 81 | \item \code{by} Same as input \code{by}. 82 | \item \code{v} The variable(s) evaluated. 83 | \item \code{center} How centering was done. 84 | } 85 | } 86 | \description{ 87 | Generates Individual Conditional Expectation (ICE) profiles. 88 | An ICE profile shows how the prediction of an observation changes if 89 | one or multiple variables are systematically changed across its ranges, 90 | holding all other values fixed (see the reference below for details). 91 | The curves can be centered in order to increase visibility of interaction effects. 92 | } 93 | \details{ 94 | There are two ways to specify the variable(s) to be profiled. 
95 | \enumerate{ 96 | \item Pass the variable name via \code{v} and an optional vector with evaluation points 97 | \code{evaluate_at} (or \code{breaks}). This works for dependence on a single variable. 98 | \item More general: Specify any \code{grid} as a \code{data.frame} with one or 99 | more columns. For instance, it can be generated by a call to \code{\link[=expand.grid]{expand.grid()}}. 100 | } 101 | 102 | The minimum required elements in the (multi-)flashlight are "predict_function", 103 | "model", "linkinv" and "data", where the last one can be passed on the fly. 104 | 105 | Which rows in \code{data} are profiled? This is specified by \code{indices}. 106 | If not given and \code{n_max} is smaller than the number of rows in \code{data}, 107 | then row indices will be sampled randomly from \code{data}. 108 | If the same rows should be used for all flashlights in a multiflashlight, 109 | there are two options: Either pass a \code{seed} or a vector of indices used to select rows. 110 | In both cases, \code{data} should be the same for all flashlights considered. 111 | } 112 | \section{Methods (by class)}{ 113 | \itemize{ 114 | \item \code{light_ice(default)}: Default method not implemented yet. 115 | 116 | \item \code{light_ice(flashlight)}: ICE profiles for a flashlight object. 117 | 118 | \item \code{light_ice(multiflashlight)}: ICE profiles for a multiflashlight object. 119 | 120 | }} 121 | \examples{ 122 | fit_add <- lm(Sepal.Length ~ ., data = iris) 123 | fl_add <- flashlight(model = fit_add, label = "additive", data = iris) 124 | 125 | plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200), alpha = 0.2) 126 | plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200, center = "first")) 127 | 128 | # Second model with interactions 129 | fit_nonadd <- lm(Sepal.Length ~ . 
+ Sepal.Width:Species, data = iris) 130 | fl_nonadd <- flashlight(model = fit_nonadd, label = "nonadditive", data = iris) 131 | fls <- multiflashlight(list(fl_add, fl_nonadd)) 132 | 133 | plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200), alpha = 0.2) 134 | plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200, center = "mid")) 135 | } 136 | \references{ 137 | Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical 138 | learning with plots of individual conditional expectation. 139 | Journal of Computational and Graphical Statistics, 24:1 140 | \doi{10.1080/10618600.2014.907095}. 141 | } 142 | \seealso{ 143 | \code{\link[=light_profile]{light_profile()}}, \code{\link[=plot.light_ice]{plot.light_ice()}} 144 | } 145 | -------------------------------------------------------------------------------- /man/light_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_importance.R 3 | \name{light_importance} 4 | \alias{light_importance} 5 | \alias{light_importance.default} 6 | \alias{light_importance.flashlight} 7 | \alias{light_importance.multiflashlight} 8 | \title{Permutation Variable Importance} 9 | \usage{ 10 | light_importance(x, ...) 11 | 12 | \method{light_importance}{default}(x, ...) 13 | 14 | \method{light_importance}{flashlight}( 15 | x, 16 | data = x$data, 17 | by = x$by, 18 | type = c("permutation", "shap"), 19 | v = NULL, 20 | n_max = Inf, 21 | seed = NULL, 22 | m_repetitions = 1L, 23 | metric = x$metrics[1L], 24 | lower_is_better = TRUE, 25 | use_linkinv = FALSE, 26 | ... 27 | ) 28 | 29 | \method{light_importance}{multiflashlight}(x, ...) 
30 | } 31 | \arguments{ 32 | \item{x}{An object of class "flashlight" or "multiflashlight".} 33 | 34 | \item{...}{Further arguments passed to \code{\link[=light_performance]{light_performance()}}.} 35 | 36 | \item{data}{An optional \code{data.frame}.} 37 | 38 | \item{by}{An optional vector of column names used to additionally group the results.} 39 | 40 | \item{type}{Type of importance: "permutation" (currently the only option).} 41 | 42 | \item{v}{Vector of variable names to assess importance for. 43 | Defaults to all variables in \code{data} except "by" and "y".} 44 | 45 | \item{n_max}{Maximum number of rows to consider.} 46 | 47 | \item{seed}{An integer random seed used to select and shuffle rows.} 48 | 49 | \item{m_repetitions}{Number of permutations. Defaults to 1. 50 | A value above 1 provides more stable estimates of variable importance and 51 | allows the calculation of standard errors measuring the uncertainty from permuting.} 52 | 53 | \item{metric}{An optional named list of length one with a metric as element. 54 | Defaults to the first metric in the flashlight. The metric needs to be a function 55 | with at least four arguments: actual, predicted, case weights w and \code{...}.} 56 | 57 | \item{lower_is_better}{Logical flag indicating if lower values in the metric 58 | are better or not. If set to \code{FALSE}, the increase in metric is multiplied by -1.} 59 | 60 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{FALSE}.} 61 | } 62 | \value{ 63 | An object of class "light_importance" with the following elements: 64 | \itemize{ 65 | \item \code{data} A tibble with results. 66 | \item \code{by} Same as input \code{by}. 67 | \item \code{type} Same as input \code{type}. For information only. 68 | } 69 | } 70 | \description{ 71 | Importance of variable \code{v} is measured as drop in performance 72 | by permuting the values of \code{v}, see Fisher et al. 2018 (reference below). 
73 | } 74 | \details{ 75 | The minimum required elements in the (multi-)flashlight are "y", "predict_function", 76 | "model", "data" and "metrics". 77 | } 78 | \section{Methods (by class)}{ 79 | \itemize{ 80 | \item \code{light_importance(default)}: Default method not implemented yet. 81 | 82 | \item \code{light_importance(flashlight)}: Variable importance for a flashlight. 83 | 84 | \item \code{light_importance(multiflashlight)}: Variable importance for a multiflashlight. 85 | 86 | }} 87 | \examples{ 88 | fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris) 89 | fl_part <- flashlight( 90 | model = fit_part, label = "part", data = iris, y = "Sepal.Length" 91 | ) 92 | 93 | # No effect of some variables (incl. standard errors) 94 | plot(light_importance(fl_part, m_repetitions = 4), fill = "chartreuse4") 95 | 96 | # Second model includes all variables 97 | fit_full <- lm(Sepal.Length ~ ., data = iris) 98 | fl_full <- flashlight( 99 | model = fit_full, label = "full", data = iris, y = "Sepal.Length" 100 | ) 101 | fls <- multiflashlight(list(fl_part, fl_full)) 102 | 103 | plot(light_importance(fls), fill = "chartreuse4") 104 | plot(light_importance(fls, by = "Species")) 105 | } 106 | \references{ 107 | Fisher A., Rudin C., Dominici F. (2018). All Models are Wrong but many are Useful: 108 | Variable Importance for Black-Box, Proprietary, or Misspecified Prediction 109 | Models, using Model Class Reliance. Arxiv. 
110 | } 111 | \seealso{ 112 | \code{\link[=most_important]{most_important()}}, \code{\link[=plot.light_importance]{plot.light_importance()}} 113 | } 114 | -------------------------------------------------------------------------------- /man/light_interaction.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_interaction.R 3 | \name{light_interaction} 4 | \alias{light_interaction} 5 | \alias{light_interaction.default} 6 | \alias{light_interaction.flashlight} 7 | \alias{light_interaction.multiflashlight} 8 | \title{Interaction Strength} 9 | \usage{ 10 | light_interaction(x, ...) 11 | 12 | \method{light_interaction}{default}(x, ...) 13 | 14 | \method{light_interaction}{flashlight}( 15 | x, 16 | data = x$data, 17 | by = x$by, 18 | v = NULL, 19 | pairwise = FALSE, 20 | type = c("H", "ice"), 21 | normalize = TRUE, 22 | take_sqrt = TRUE, 23 | grid_size = 200L, 24 | n_max = 1000L, 25 | seed = NULL, 26 | use_linkinv = FALSE, 27 | ... 28 | ) 29 | 30 | \method{light_interaction}{multiflashlight}(x, ...) 31 | } 32 | \arguments{ 33 | \item{x}{An object of class "flashlight" or "multiflashlight".} 34 | 35 | \item{...}{Further arguments passed to or from other methods.} 36 | 37 | \item{data}{An optional \code{data.frame}.} 38 | 39 | \item{by}{An optional vector of column names used to additionally group the results.} 40 | 41 | \item{v}{Vector of variable names to be assessed.} 42 | 43 | \item{pairwise}{Should overall interaction strength per variable be shown or 44 | pairwise interactions? Defaults to \code{FALSE}.} 45 | 46 | \item{type}{Are measures based on Friedman's H statistic ("H") or on "ice" curves? 47 | Option "ice" is available only if \code{pairwise = FALSE}.} 48 | 49 | \item{normalize}{Should the variances explained be normalized? 
50 | Default is \code{TRUE} in order to reproduce Friedman's H statistic.} 51 | 52 | \item{take_sqrt}{In order to reproduce Friedman's H statistic, 53 | resulting values are root transformed. Set to \code{FALSE} if squared values 54 | should be returned.} 55 | 56 | \item{grid_size}{Grid size used to form the outer product. Will be randomly 57 | picked from data (after limiting to \code{n_max}).} 58 | 59 | \item{n_max}{Maximum number of data rows to consider. Will be randomly picked 60 | from \code{data} if necessary.} 61 | 62 | \item{seed}{An integer random seed used for subsampling.} 63 | 64 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{FALSE}.} 65 | } 66 | \value{ 67 | An object of class "light_importance" with the following elements: 68 | \itemize{ 69 | \item \code{data} A tibble containing the results. Can be used to build fully customized 70 | visualizations. Column names can be controlled by 71 | \code{options(flashlight.column_name)}. 72 | \item \code{by} Same as input \code{by}. 73 | \item \code{type} Same as input \code{type}. For information only. 74 | } 75 | } 76 | \description{ 77 | This function provides Friedman's H statistic for overall interaction strength per 78 | covariable as well as its version for pairwise interactions, see the reference below. 79 | } 80 | \details{ 81 | As a fast alternative to assess overall interaction strength, with \code{type = "ice"}, 82 | the function offers a method based on centered ICE curves: 83 | The corresponding H* statistic measures how much of the variability of a c-ICE curve 84 | is unexplained by the main effect. As for Friedman's H statistic, it can be useful 85 | to consider unnormalized or squared values (see Details below). 86 | 87 | Friedman's H statistic relates the interaction strength of a variable (pair) 88 | to the total effect strength of that variable (pair) based on partial dependence 89 | curves. 
Due to this normalization step, even variables with low importance can 90 | have high values for H. The function \code{\link[=light_interaction]{light_interaction()}} offers the option 91 | to skip normalization in order to have a more direct comparison of the interaction 92 | effects across variable (pairs). The values of such unnormalized H statistics are 93 | on the scale of the response variable. Use \code{take_sqrt = FALSE} to return 94 | squared values of H. Note that in general, for each variable (pair), predictions 95 | are done on a data set with \code{grid_size * n_max} rows, so be cautious with 96 | increasing the defaults too much. Still, even with larger \code{grid_size} 97 | and \code{n_max}, there might be considerable variation across different runs, 98 | thus, setting a seed is recommended. 99 | 100 | The minimum required elements in the (multi-) flashlight are a "predict_function", 101 | "model", and "data". 102 | } 103 | \section{Methods (by class)}{ 104 | \itemize{ 105 | \item \code{light_interaction(default)}: Default method not implemented yet. 106 | 107 | \item \code{light_interaction(flashlight)}: Interaction strengths for a flashlight object. 108 | 109 | \item \code{light_interaction(multiflashlight)}: Interaction strengths for a multiflashlight object. 110 | 111 | }} 112 | \examples{ 113 | # First model with interactions 114 | fit_nonadd <- lm( 115 | Sepal.Length ~ . 
+ Sepal.Width:Species + Petal.Width:Species, data = iris 116 | ) 117 | fl_nonadd <- flashlight( 118 | model = fit_nonadd, label = "nonadditive", data = iris, y = "Sepal.Length" 119 | ) 120 | 121 | # Friedman's H per feature 122 | plot(light_interaction(fl_nonadd), fill = "chartreuse4") 123 | 124 | # Unnormalized H^2 measures proportion of bivariate effect explained by interaction 125 | plot( 126 | light_interaction(fl_nonadd, normalize = TRUE, take_sqrt = TRUE), 127 | fill = "chartreuse4" 128 | ) 129 | 130 | # Pairwise H 131 | plot(light_interaction(fl_nonadd, pairwise = TRUE), fill = "chartreuse4") 132 | 133 | # Second model without interactions 134 | fit_add <- lm(Sepal.Length ~ ., data = iris) 135 | fl_add <- flashlight( 136 | model = fit_add, label = "additive", data = iris, y = "Sepal.Length" 137 | ) 138 | fls <- multiflashlight(list(fl_add, fl_nonadd)) 139 | 140 | plot(light_interaction(fls), fill = "chartreuse4") 141 | } 142 | \references{ 143 | Friedman, J. H. and Popescu, B. E. (2008). "Predictive learning via rule 144 | ensembles." The Annals of Applied Statistics. JSTOR, 916–54. 145 | } 146 | \seealso{ 147 | \code{\link[=light_ice]{light_ice()}} 148 | } 149 | -------------------------------------------------------------------------------- /man/light_performance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_performance.R 3 | \name{light_performance} 4 | \alias{light_performance} 5 | \alias{light_performance.default} 6 | \alias{light_performance.flashlight} 7 | \alias{light_performance.multiflashlight} 8 | \title{Model Performance of Flashlight} 9 | \usage{ 10 | light_performance(x, ...) 11 | 12 | \method{light_performance}{default}(x, ...) 13 | 14 | \method{light_performance}{flashlight}( 15 | x, 16 | data = x$data, 17 | by = x$by, 18 | metrics = x$metrics, 19 | use_linkinv = FALSE, 20 | ... 
21 | ) 22 | 23 | \method{light_performance}{multiflashlight}(x, ...) 24 | } 25 | \arguments{ 26 | \item{x}{An object of class "flashlight" or "multiflashlight".} 27 | 28 | \item{...}{Arguments passed from or to other functions.} 29 | 30 | \item{data}{An optional \code{data.frame}.} 31 | 32 | \item{by}{An optional vector of column names used to additionally group the results. 33 | Will overwrite \code{x$by}.} 34 | 35 | \item{metrics}{An optional named list with metrics. Each metric takes at least 36 | four arguments: actual, predicted, case weights w and \code{...}.} 37 | 38 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{FALSE}.} 39 | } 40 | \value{ 41 | An object of class "light_performance" with the following elements: 42 | \itemize{ 43 | \item \code{data}: A tibble containing the results. 44 | \item \code{by} Same as input \code{by}. 45 | } 46 | } 47 | \description{ 48 | Calculates performance of a flashlight with respect to one or more 49 | performance measures. 50 | } 51 | \details{ 52 | The minimal required elements in the (multi-) flashlight are "y", "predict_function", 53 | "model", "data" and "metrics". The latter two can also directly be passed to 54 | \code{\link[=light_performance]{light_performance()}}. Note that by default, no retransformation function is applied. 55 | } 56 | \section{Methods (by class)}{ 57 | \itemize{ 58 | \item \code{light_performance(default)}: Default method not implemented yet. 59 | 60 | \item \code{light_performance(flashlight)}: Model performance of flashlight object. 61 | 62 | \item \code{light_performance(multiflashlight)}: Model performance of multiflashlight object. 
63 | 64 | }} 65 | \examples{ 66 | fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris) 67 | fl_part <- flashlight( 68 | model = fit_part, label = "part", data = iris, y = "Sepal.Length" 69 | ) 70 | plot(light_performance(fl_part, by = "Species"), fill = "chartreuse4") 71 | 72 | # Second model 73 | fit_full <- lm(Sepal.Length ~ ., data = iris) 74 | fl_full <- flashlight( 75 | model = fit_full, label = "full", data = iris, y = "Sepal.Length" 76 | ) 77 | fls <- multiflashlight(list(fl_part, fl_full)) 78 | 79 | plot(light_performance(fls, by = "Species")) 80 | plot(light_performance(fls, by = "Species"), swap_dim = TRUE) 81 | } 82 | \seealso{ 83 | \code{\link[=plot.light_performance]{plot.light_performance()}} 84 | } 85 | -------------------------------------------------------------------------------- /man/light_profile.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_profile.R 3 | \name{light_profile} 4 | \alias{light_profile} 5 | \alias{light_profile.default} 6 | \alias{light_profile.flashlight} 7 | \alias{light_profile.multiflashlight} 8 | \title{Partial Dependence and other Profiles} 9 | \usage{ 10 | light_profile(x, ...) 11 | 12 | \method{light_profile}{default}(x, ...) 13 | 14 | \method{light_profile}{flashlight}( 15 | x, 16 | v = NULL, 17 | data = NULL, 18 | by = x$by, 19 | type = c("partial dependence", "ale", "predicted", "response", "residual", "shap"), 20 | stats = "mean", 21 | breaks = NULL, 22 | n_bins = 11L, 23 | cut_type = c("equal", "quantile"), 24 | use_linkinv = TRUE, 25 | counts = TRUE, 26 | counts_weighted = FALSE, 27 | v_labels = TRUE, 28 | pred = NULL, 29 | pd_evaluate_at = NULL, 30 | pd_grid = NULL, 31 | pd_indices = NULL, 32 | pd_n_max = 1000L, 33 | pd_seed = NULL, 34 | pd_center = c("no", "first", "middle", "last", "mean", "0"), 35 | ale_two_sided = FALSE, 36 | ... 
37 | ) 38 | 39 | \method{light_profile}{multiflashlight}( 40 | x, 41 | v = NULL, 42 | data = NULL, 43 | type = c("partial dependence", "ale", "predicted", "response", "residual", "shap"), 44 | breaks = NULL, 45 | n_bins = 11L, 46 | cut_type = c("equal", "quantile"), 47 | pd_evaluate_at = NULL, 48 | pd_grid = NULL, 49 | ... 50 | ) 51 | } 52 | \arguments{ 53 | \item{x}{An object of class "flashlight" or "multiflashlight".} 54 | 55 | \item{...}{Further arguments passed to \code{\link[=formatC]{formatC()}} in forming the 56 | cut breaks of the \code{v} variable.} 57 | 58 | \item{v}{The variable name to be profiled.} 59 | 60 | \item{data}{An optional \code{data.frame}.} 61 | 62 | \item{by}{An optional vector of column names used to additionally group the results.} 63 | 64 | \item{type}{Type of the profile: Either "partial dependence", "ale", "predicted", 65 | "response", or "residual".} 66 | 67 | \item{stats}{Deprecated. Will be removed in version 1.1.0.} 68 | 69 | \item{breaks}{Cut breaks for a numeric \code{v}. Used to overwrite automatic binning via 70 | \code{n_bins} and \code{cut_type}. Ignored if \code{v} is not numeric.} 71 | 72 | \item{n_bins}{Approximate number of unique values to evaluate for numeric \code{v}. 73 | Ignored if \code{v} is not numeric or if \code{breaks} is specified.} 74 | 75 | \item{cut_type}{Should a numeric \code{v} be cut into "equal" or "quantile" bins? 76 | Ignored if \code{v} is not numeric or if \code{breaks} is specified.} 77 | 78 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 79 | 80 | \item{counts}{Should observation counts be added?} 81 | 82 | \item{counts_weighted}{If \code{counts = TRUE}: Should counts be weighted by the 83 | case weights? If \code{TRUE}, the sum of \code{w} is returned by group.} 84 | 85 | \item{v_labels}{If \code{FALSE}, return group centers of \code{v} instead of labels. 
86 | Only relevant for types "response", "predicted" or "residual" and if \code{v} 87 | is being binned. In that case useful, for instance, if different flashlights 88 | use different data sets and bin labels would not match.} 89 | 90 | \item{pred}{Optional vector with predictions (after application of inverse link). 91 | Can be used to avoid recalculation of predictions over and over if the function 92 | is to be repeatedly called for different \code{v} and predictions are computationally 93 | expensive to make. Not implemented for multiflashlight.} 94 | 95 | \item{pd_evaluate_at}{Vector with values of \code{v} used to evaluate the profile. 96 | Only relevant for type = "partial dependence" and "ale".} 97 | 98 | \item{pd_grid}{A \code{data.frame} with grid values, e.g., generated by \code{\link[=expand.grid]{expand.grid()}}. 99 | Only used for type = "partial dependence".} 100 | 101 | \item{pd_indices}{A vector of row numbers to consider in calculating 102 | partial dependence profiles and "ale".} 103 | 104 | \item{pd_n_max}{Maximum number of ICE profiles to calculate (will be randomly 105 | picked from \code{data}) for partial dependence and ALE.} 106 | 107 | \item{pd_seed}{Integer random seed used to select ICE profiles for partial dependence 108 | and ALE.} 109 | 110 | \item{pd_center}{How should ICE curves be centered? 111 | \itemize{ 112 | \item Default is "no". 113 | \item Choose "first", "middle", or "last" to 0-center at specific evaluation points. 114 | \item Choose "mean" to center all profiles at the within-group means. 115 | \item Choose "0" to mean-center curves at 0. Only relevant for partial dependence. 116 | }} 117 | 118 | \item{ale_two_sided}{If \code{TRUE}, \code{v} is continuous and \code{breaks} 119 | are passed or being calculated, then two-sided derivatives are calculated 120 | for ALE instead of left derivatives. More specifically: Usually, local effects 121 | at value x are calculated using points in \eqn{[x-e, x]}.
122 | Set \code{ale_two_sided = TRUE} to use points in \eqn{[x-e/2, x+e/2]}.} 123 | } 124 | \value{ 125 | An object of class "light_profile" with the following elements: 126 | \itemize{ 127 | \item \code{data} A tibble containing results. 128 | \item \code{by} Names of group by variable. 129 | \item \code{v} The variable(s) evaluated. 130 | \item \code{type} Same as input \code{type}. For information only. 131 | } 132 | } 133 | \description{ 134 | Calculates different types of profiles across covariable values. 135 | By default, partial dependence profiles are calculated (see Friedman). 136 | Other options are profiles of ALE (accumulated local effects, see Apley), 137 | response, predicted values ("M plots" or "marginal plots", see Apley), and residuals. 138 | The results are aggregated either by (weighted) means or by (weighted) quartiles. 139 | 140 | Note that ALE profiles are calibrated by (weighted) average predictions. 141 | In contrast to the suggestions in Apley, we calculate ALE profiles of factors 142 | in the same order as the factor levels. 143 | They are not being reordered based on similarity of other variables. 144 | } 145 | \details{ 146 | Numeric covariables \code{v} with more than \code{n_bins} distinct values 147 | are binned into \code{n_bins} bins. Alternatively, \code{breaks} can be provided 148 | to specify the binning. For partial dependence profiles 149 | (and partly also ALE profiles), this behaviour can be overwritten either 150 | by providing a vector of evaluation points (\code{pd_evaluate_at}) or an 151 | evaluation \code{pd_grid}. By the latter we mean a data frame with column name(s) 152 | with a (multi-)variate evaluation grid. 153 | 154 | For partial dependence, ALE, and prediction profiles, "model", "predict_function", 155 | "linkinv" and "data" are required. For response profiles it is "y", "linkinv" and 156 | "data". "data" can also be passed on the fly.
157 | } 158 | \section{Methods (by class)}{ 159 | \itemize{ 160 | \item \code{light_profile(default)}: Default method not implemented yet. 161 | 162 | \item \code{light_profile(flashlight)}: Profiles for flashlight. 163 | 164 | \item \code{light_profile(multiflashlight)}: Profiles for multiflashlight. 165 | 166 | }} 167 | \examples{ 168 | fit_lin <- lm(Sepal.Length ~ ., data = iris) 169 | fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length") 170 | 171 | # PDP by Species 172 | plot(light_profile(fl_lin, v = "Petal.Length", by = "Species")) 173 | 174 | # Average predicted 175 | plot(light_profile(fl_lin, v = "Petal.Length", type = "pred")) 176 | 177 | # Second model with non-linear Petal.Length effect 178 | fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris) 179 | fl_nonlin <- flashlight( 180 | model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length" 181 | ) 182 | fls <- multiflashlight(list(fl_lin, fl_nonlin)) 183 | 184 | # PDP by Species 185 | plot(light_profile(fls, v = "Petal.Length", by = "Species")) 186 | plot(light_profile(fls, v = "Petal.Length", by = "Species"), swap_dim = TRUE) 187 | 188 | # Average residuals (calibration) 189 | plot(light_profile(fls, v = "Petal.Length", type = "residual")) 190 | } 191 | \references{ 192 | \itemize{ 193 | \item Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. 194 | The Annals of Statistics, 29:1189–1232. 195 | \item Apley D. W. (2016). Visualizing the effects of predictor variables in black box 196 | supervised learning models. 
197 | } 198 | } 199 | \seealso{ 200 | \code{\link[=light_effects]{light_effects()}}, \code{\link[=plot.light_profile]{plot.light_profile()}} 201 | } 202 | -------------------------------------------------------------------------------- /man/light_profile2d.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_profile2d.R 3 | \name{light_profile2d} 4 | \alias{light_profile2d} 5 | \alias{light_profile2d.default} 6 | \alias{light_profile2d.flashlight} 7 | \alias{light_profile2d.multiflashlight} 8 | \title{2D Partial Dependence and other 2D Profiles} 9 | \usage{ 10 | light_profile2d(x, ...) 11 | 12 | \method{light_profile2d}{default}(x, ...) 13 | 14 | \method{light_profile2d}{flashlight}( 15 | x, 16 | v = NULL, 17 | data = NULL, 18 | by = x$by, 19 | type = c("partial dependence", "predicted", "response", "residual", "shap"), 20 | breaks = NULL, 21 | n_bins = 11L, 22 | cut_type = "equal", 23 | use_linkinv = TRUE, 24 | counts = TRUE, 25 | counts_weighted = FALSE, 26 | pd_evaluate_at = NULL, 27 | pd_grid = NULL, 28 | pd_indices = NULL, 29 | pd_n_max = 1000L, 30 | pd_seed = NULL, 31 | ... 32 | ) 33 | 34 | \method{light_profile2d}{multiflashlight}( 35 | x, 36 | v = NULL, 37 | data = NULL, 38 | type = c("partial dependence", "predicted", "response", "residual", "shap"), 39 | breaks = NULL, 40 | n_bins = 11L, 41 | cut_type = "equal", 42 | pd_evaluate_at = NULL, 43 | pd_grid = NULL, 44 | ... 45 | ) 46 | } 47 | \arguments{ 48 | \item{x}{An object of class "flashlight" or "multiflashlight".} 49 | 50 | \item{...}{Further arguments passed to \code{\link[=formatC]{formatC()}} in forming 51 | the cut breaks of the \code{v} variables. 
Not relevant for partial dependence profiles.} 52 | 53 | \item{v}{A vector of exactly two variable names to be profiled.} 54 | 55 | \item{data}{An optional \code{data.frame}.} 56 | 57 | \item{by}{An optional vector of column names used to additionally group the results.} 58 | 59 | \item{type}{Type of the profile: Either "partial dependence", "predicted", 60 | "response", or "residual".} 61 | 62 | \item{breaks}{Named list of cut breaks specifying how to bin one or more numeric 63 | variables. Used to overwrite automatic binning via \code{n_bins} and \code{cut_type}. 64 | Ignored for non-numeric \code{v}.} 65 | 66 | \item{n_bins}{Approximate number of unique values to evaluate for numeric \code{v}. 67 | Can be an unnamed vector of length 2 to distinguish between the two variables in \code{v}.} 68 | 69 | \item{cut_type}{Should numeric \code{v} be cut into "equal" or "quantile" bins? 70 | Can be an unnamed vector of length 2 to distinguish between the two variables in \code{v}.} 71 | 72 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 73 | 74 | \item{counts}{Should observation counts be added?} 75 | 76 | \item{counts_weighted}{If \code{counts} is TRUE: Should counts be weighted by the 77 | case weights? If \code{TRUE}, the sum of \code{w} is returned by group.} 78 | 79 | \item{pd_evaluate_at}{A named list of evaluation points for one or more variables. 80 | Only relevant for type = "partial dependence".} 81 | 82 | \item{pd_grid}{An evaluation \code{data.frame} with exactly two columns, 83 | e.g., generated by \code{\link[=expand.grid]{expand.grid()}}. Only used for type = "partial dependence". 84 | Offers maximal flexibility.} 85 | 86 | \item{pd_indices}{A vector of row numbers to consider in calculating partial 87 | dependence profiles. Only used for type = "partial dependence".} 88 | 89 | \item{pd_n_max}{Maximum number of ICE profiles to calculate 90 | (will be randomly picked from \code{data}).
Only used for type = "partial dependence".} 91 | 92 | \item{pd_seed}{Integer random seed used to select ICE profiles. 93 | Only used for type = "partial dependence".} 94 | } 95 | \value{ 96 | An object of class "light_profile2d" with the following elements: 97 | \itemize{ 98 | \item \code{data} A tibble containing results. 99 | \item \code{by} Names of group by variables. 100 | \item \code{v} The two variable names evaluated. 101 | \item \code{type} Same as input \code{type}. For information only. 102 | } 103 | } 104 | \description{ 105 | Calculates different types of 2D-profiles across two variables. 106 | By default, partial dependence profiles are calculated (see Friedman). 107 | Other options are response, predicted values, and residuals. 108 | The results are aggregated by (weighted) means. 109 | } 110 | \details{ 111 | Different binning options are available, see arguments below. 112 | For high resolution partial dependence plots, it might be necessary to specify 113 | \code{breaks}, \code{pd_evaluate_at} or \code{pd_grid} in order to avoid empty parts 114 | in the plot. A high value of \code{n_bins} might not have the desired effect as it is 115 | internally capped at the number of distinct values of a variable. 116 | 117 | For partial dependence and prediction profiles, "model", "predict_function", 118 | "linkinv" and "data" are required. For response profiles it is "y", "linkinv" 119 | and "data". "data" can also be passed on the fly. 120 | } 121 | \section{Methods (by class)}{ 122 | \itemize{ 123 | \item \code{light_profile2d(default)}: Default method not implemented yet. 124 | 125 | \item \code{light_profile2d(flashlight)}: 2D profiles for flashlight. 126 | 127 | \item \code{light_profile2d(multiflashlight)}: 2D profiles for multiflashlight.
128 | 129 | }} 130 | \examples{ 131 | fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris) 132 | fl_part <- flashlight( 133 | model = fit_part, label = "part", data = iris, y = "Sepal.Length" 134 | ) 135 | 136 | # No effect of Petal.Width 137 | plot(light_profile2d(fl_part, v = c("Petal.Length", "Petal.Width"))) 138 | 139 | # Second model includes Petal.Width 140 | fit_full <- lm(Sepal.Length ~ ., data = iris) 141 | fl_full <- flashlight( 142 | model = fit_full, label = "full", data = iris, y = "Sepal.Length" 143 | ) 144 | fls <- multiflashlight(list(fl_part, fl_full)) 145 | 146 | plot(light_profile2d(fls, v = c("Petal.Length", "Petal.Width"))) 147 | } 148 | \references{ 149 | Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. 150 | The Annals of Statistics, 29:1189–1232. 151 | } 152 | \seealso{ 153 | \code{\link[=light_profile]{light_profile()}}, \code{\link[=plot.light_profile2d]{plot.light_profile2d()}} 154 | } 155 | -------------------------------------------------------------------------------- /man/light_recode.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/aa_deprecated.R, R/light_recode.R 3 | \name{light_recode} 4 | \alias{light_recode} 5 | \title{DEPRECATED} 6 | \usage{ 7 | light_recode(...) 8 | 9 | light_recode(...) 10 | } 11 | \arguments{ 12 | \item{...}{Deprecated.} 13 | } 14 | \value{ 15 | Error message. 16 | 17 | Deprecated. 
18 | } 19 | \description{ 20 | DEPRECATED 21 | 22 | Recode Factor Columns - DEPRECATED 23 | } 24 | -------------------------------------------------------------------------------- /man/light_scatter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_scatter.R 3 | \name{light_scatter} 4 | \alias{light_scatter} 5 | \alias{light_scatter.default} 6 | \alias{light_scatter.flashlight} 7 | \alias{light_scatter.multiflashlight} 8 | \title{Scatter Plot Data} 9 | \usage{ 10 | light_scatter(x, ...) 11 | 12 | \method{light_scatter}{default}(x, ...) 13 | 14 | \method{light_scatter}{flashlight}( 15 | x, 16 | v, 17 | data = x$data, 18 | by = x$by, 19 | type = c("predicted", "response", "residual", "shap"), 20 | use_linkinv = TRUE, 21 | n_max = 400, 22 | seed = NULL, 23 | ... 24 | ) 25 | 26 | \method{light_scatter}{multiflashlight}(x, ...) 27 | } 28 | \arguments{ 29 | \item{x}{An object of class "flashlight" or "multiflashlight".} 30 | 31 | \item{...}{Further arguments passed from or to other methods.} 32 | 33 | \item{v}{The variable name to be shown on the x-axis.} 34 | 35 | \item{data}{An optional \code{data.frame}.} 36 | 37 | \item{by}{An optional vector of column names used to additionally group the results.} 38 | 39 | \item{type}{Type of the profile: Either "predicted", "response", or "residual".} 40 | 41 | \item{use_linkinv}{Should retransformation function be applied? Default is \code{TRUE}.} 42 | 43 | \item{n_max}{Maximum number of data rows to select. Will be randomly picked.} 44 | 45 | \item{seed}{An integer random seed used for subsampling.} 46 | } 47 | \value{ 48 | An object of class "light_scatter" with the following elements: 49 | \itemize{ 50 | \item \code{data}: A tibble with results. 51 | \item \code{by}: Same as input \code{by}. 52 | \item \code{v}: The variable name evaluated. 53 | \item \code{type}: Same as input \code{type}. 
For information only. 54 | } 55 | } 56 | \description{ 57 | This function prepares values for drawing a scatter plot of predicted values, 58 | responses, or residuals against a selected variable. 59 | } 60 | \section{Methods (by class)}{ 61 | \itemize{ 62 | \item \code{light_scatter(default)}: Default method not implemented yet. 63 | 64 | \item \code{light_scatter(flashlight)}: Variable profile for a flashlight. 65 | 66 | \item \code{light_scatter(multiflashlight)}: light_scatter for a multiflashlight. 67 | 68 | }} 69 | \examples{ 70 | fit_a <- lm(Sepal.Length ~ . -Petal.Length, data = iris) 71 | fit_b <- lm(Sepal.Length ~ ., data = iris) 72 | 73 | fl_a <- flashlight(model = fit_a, label = "no Petal.Length") 74 | fl_b <- flashlight(model = fit_b, label = "all") 75 | fls <- multiflashlight(list(fl_a, fl_b), data = iris, y = "Sepal.Length") 76 | 77 | plot(light_scatter(fls, v = "Petal.Width"), color = "darkred") 78 | 79 | sc <- light_scatter(fls, "Petal.Length", by = "Species", type = "residual") 80 | plot(sc) 81 | } 82 | \seealso{ 83 | \code{\link[=plot.light_scatter]{plot.light_scatter()}} 84 | } 85 | -------------------------------------------------------------------------------- /man/most_important.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_importance.R 3 | \name{most_important} 4 | \alias{most_important} 5 | \title{Most Important Variables.} 6 | \usage{ 7 | most_important(x, top_m = Inf) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_importance".} 11 | 12 | \item{top_m}{Maximum number of important variables to be returned.} 13 | } 14 | \value{ 15 | A character vector of variable names sorted in descending importance. 16 | } 17 | \description{ 18 | Returns the most important variable names sorted descendingly. 
19 | } 20 | \examples{ 21 | fit <- lm(Sepal.Length ~ ., data = iris) 22 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 23 | imp <- light_importance(fl) 24 | most_important(imp) 25 | most_important(imp, top_m = 2) 26 | } 27 | \seealso{ 28 | \code{\link[=light_importance]{light_importance()}} 29 | } 30 | -------------------------------------------------------------------------------- /man/multiflashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/multiflashlight.R 3 | \name{multiflashlight} 4 | \alias{multiflashlight} 5 | \alias{multiflashlight.default} 6 | \alias{multiflashlight.flashlight} 7 | \alias{multiflashlight.list} 8 | \alias{multiflashlight.multiflashlight} 9 | \title{Create or Update a multiflashlight} 10 | \usage{ 11 | multiflashlight(x, ...) 12 | 13 | \method{multiflashlight}{default}(x, ...) 14 | 15 | \method{multiflashlight}{flashlight}(x, ...) 16 | 17 | \method{multiflashlight}{list}(x, ...) 18 | 19 | \method{multiflashlight}{multiflashlight}(x, ...) 20 | } 21 | \arguments{ 22 | \item{x}{An object of class "multiflashlight", "flashlight" or a list of flashlights.} 23 | 24 | \item{...}{Optional arguments in the flashlights to update, see examples.} 25 | } 26 | \value{ 27 | An object of class "multiflashlight" (a named list of flashlight objects). 28 | } 29 | \description{ 30 | Combines a list of flashlights to an object of class "multiflashlight" 31 | and/or updates a multiflashlight. 32 | } 33 | \section{Methods (by class)}{ 34 | \itemize{ 35 | \item \code{multiflashlight(default)}: Used to create a flashlight object. 36 | No \code{x} has to be passed in this case. 37 | 38 | \item \code{multiflashlight(flashlight)}: Updates an existing flashlight object and turns 39 | into a multiflashlight. 
40 | 41 | \item \code{multiflashlight(list)}: Creates (and updates) a multiflashlight from a list 42 | of flashlights. 43 | 44 | \item \code{multiflashlight(multiflashlight)}: Updates an object of class "multiflashlight". 45 | 46 | }} 47 | \examples{ 48 | fit_lm <- lm(Sepal.Length ~ ., data = iris) 49 | fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris) 50 | mod_lm <- flashlight(model = fit_lm, label = "lm") 51 | mod_glm <- flashlight(model = fit_glm, label = "glm") 52 | (mods <- multiflashlight(list(mod_lm, mod_glm))) 53 | } 54 | \seealso{ 55 | \code{\link[=flashlight]{flashlight()}} 56 | } 57 | -------------------------------------------------------------------------------- /man/plot.light_breakdown.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_breakdown.R 3 | \name{plot.light_breakdown} 4 | \alias{plot.light_breakdown} 5 | \title{Visualize Variable Contribution Breakdown for Single Observation} 6 | \usage{ 7 | \method{plot}{light_breakdown}(x, facet_scales = "free", facet_ncol = 1, rotate_x = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_breakdown".} 11 | 12 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 13 | 14 | \item{facet_ncol}{\code{ncol} argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 15 | 16 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 17 | 18 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_text]{ggplot2::geom_label()}}.} 19 | } 20 | \value{ 21 | An object of class "ggplot". 22 | } 23 | \description{ 24 | Minimal visualization of an object of class "light_breakdown" as waterfall plot. 25 | The object returned is of class "ggplot" and can be further customized. 26 | } 27 | \details{ 28 | The waterfall plot is to be read from top to bottom. 
29 | The first line describes the (weighted) average prediction in the query data 30 | used to start with. Then, each additional line shows how the prediction changes 31 | due to the impact of the corresponding variable. 32 | The last line finally shows the original prediction of the selected observation. 33 | Multiple flashlights are shown in different facets. 34 | Positive and negative impacts are visualized with different colors. 35 | } 36 | \seealso{ 37 | \code{\link[=light_breakdown]{light_breakdown()}} 38 | } 39 | -------------------------------------------------------------------------------- /man/plot.light_effects.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_effects.R 3 | \name{plot.light_effects} 4 | \alias{plot.light_effects} 5 | \title{Visualize Multiple Types of Profiles Together} 6 | \usage{ 7 | \method{plot}{light_effects}( 8 | x, 9 | use = c("response", "predicted", "pd"), 10 | zero_counts = TRUE, 11 | size_factor = 1, 12 | facet_scales = "free_x", 13 | facet_nrow = 1L, 14 | rotate_x = TRUE, 15 | show_points = TRUE, 16 | recode_labels = NULL, 17 | ... 18 | ) 19 | } 20 | \arguments{ 21 | \item{x}{An object of class "light_effects".} 22 | 23 | \item{use}{A vector of elements to show. Any subset of ("response", "predicted", 24 | "pd", "ale") or "all". 
Defaults to all except "ale".} 25 | 26 | \item{zero_counts}{Logical flag if 0 count levels should be shown on the x axis.} 27 | 28 | \item{size_factor}{Factor used to enlarge default \code{size/linewidth} in 29 | \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}} and \code{\link[ggplot2:geom_path]{ggplot2::geom_line()}}.} 30 | 31 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 32 | 33 | \item{facet_nrow}{Number of rows in \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 34 | 35 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 36 | 37 | \item{show_points}{Should points be added to the line (default is \code{TRUE}).} 38 | 39 | \item{recode_labels}{Named vector of curve labels. The names refer to the usual 40 | labels, while the values are the desired labels, e.g., 41 | \code{c("partial dependence" = "PDP", "ale" = "ALE")}.} 42 | 43 | \item{...}{Further arguments passed to geoms.} 44 | } 45 | \value{ 46 | An object of class "ggplot". 47 | } 48 | \description{ 49 | Visualizes response-, prediction-, partial dependence, and/or ALE profiles 50 | of a (multi-)flashlight with respect to a covariable \code{v}. 51 | Different flashlights or a single flashlight with one "by" variable are separated 52 | by a facet wrap. 53 | } 54 | \seealso{ 55 | \code{\link[=light_effects]{light_effects()}}, \code{\link[=plot_counts]{plot_counts()}} 56 | } 57 | -------------------------------------------------------------------------------- /man/plot.light_global_surrogate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_global_surrogate.R 3 | \name{plot.light_global_surrogate} 4 | \alias{plot.light_global_surrogate} 5 | \title{Plot Global Surrogate Trees} 6 | \usage{ 7 | \method{plot}{light_global_surrogate}(x, type = 5, auto_main = TRUE, mfrow = NULL, ...)
8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_global_surrogate".} 11 | 12 | \item{type}{Plot type, see help of \code{\link[rpart.plot:rpart.plot]{rpart.plot::rpart.plot()}}. Default is 5.} 13 | 14 | \item{auto_main}{Automatic plot titles (only if multiple trees are shown).} 15 | 16 | \item{mfrow}{If multiple trees are shown in the same figure: 17 | what value of \code{mfrow} to use in \code{\link[graphics:par]{graphics::par()}}?} 18 | 19 | \item{...}{Further arguments passed to \code{\link[rpart.plot:rpart.plot]{rpart.plot::rpart.plot()}}.} 20 | } 21 | \value{ 22 | An object of class "ggplot". 23 | } 24 | \description{ 25 | Use \code{\link[rpart.plot:rpart.plot]{rpart.plot::rpart.plot()}} to visualize trees fitted by 26 | \code{\link[=light_global_surrogate]{light_global_surrogate()}}. 27 | } 28 | \seealso{ 29 | \code{\link[=light_global_surrogate]{light_global_surrogate()}} 30 | } 31 | -------------------------------------------------------------------------------- /man/plot.light_ice.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_ice.R 3 | \name{plot.light_ice} 4 | \alias{plot.light_ice} 5 | \title{Visualize ICE profiles} 6 | \usage{ 7 | \method{plot}{light_ice}(x, facet_scales = "fixed", rotate_x = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_ice".} 11 | 12 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 13 | 14 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 15 | 16 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_path]{ggplot2::geom_line()}}.} 17 | } 18 | \value{ 19 | An object of class "ggplot". 20 | } 21 | \description{ 22 | Minimal visualization of an object of class "light_ice" as \code{\link[ggplot2:geom_path]{ggplot2::geom_line()}}. 
23 | The object returned is of class "ggplot" and can be further customized. 24 | } 25 | \details{ 26 | Each observation is visualized by a line. The first "by" variable is represented 27 | by the color, a second "by" variable or a multiflashlight by facets. 28 | } 29 | \seealso{ 30 | \code{\link[=light_ice]{light_ice()}} 31 | } 32 | -------------------------------------------------------------------------------- /man/plot.light_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_importance.R 3 | \name{plot.light_importance} 4 | \alias{plot.light_importance} 5 | \title{Visualize Variable Importance} 6 | \usage{ 7 | \method{plot}{light_importance}( 8 | x, 9 | top_m = Inf, 10 | swap_dim = FALSE, 11 | facet_scales = "fixed", 12 | rotate_x = FALSE, 13 | error_bars = TRUE, 14 | ... 15 | ) 16 | } 17 | \arguments{ 18 | \item{x}{An object of class "light_importance".} 19 | 20 | \item{top_m}{Maximum number of important variables to be returned.} 21 | 22 | \item{swap_dim}{If multiflashlight and one "by" variable or single flashlight with 23 | two "by" variables, swap the role of dodge/fill variable and facet variable. 24 | If multiflashlight or one "by" variable, use facets instead of colors.} 25 | 26 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 27 | 28 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 29 | 30 | \item{error_bars}{Should error bars be added? Defaults to \code{TRUE}. 31 | Only available if \code{\link[=light_importance]{light_importance()}} was run with multiple permutations 32 | by setting \code{m_repetitions} > 1.} 33 | 34 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_bar]{ggplot2::geom_bar()}}.} 35 | } 36 | \value{ 37 | An object of class "ggplot". 
38 | } 39 | \description{ 40 | Visualization of an object of class "light_importance" via \code{\link[ggplot2:geom_bar]{ggplot2::geom_bar()}}. 41 | If available, standard errors are added by \code{\link[ggplot2:geom_linerange]{ggplot2::geom_errorbar()}}. 42 | The object returned is of class "ggplot" and can be further customized. 43 | } 44 | \details{ 45 | The plot is organized as a bar plot with variable names as x-aesthetic. 46 | Up to two additional dimensions (multiflashlight and one "by" variable or single 47 | flashlight with two "by" variables) can be visualized by facetting and dodge/fill. 48 | Set \code{swap_dim = TRUE} to revert the role of these two dimensions. 49 | One single additional dimension is visualized by a facet wrap, 50 | or - if \code{swap_dim = FALSE} - by dodge/fill. 51 | } 52 | \seealso{ 53 | \code{\link[=light_importance]{light_importance()}} 54 | } 55 | -------------------------------------------------------------------------------- /man/plot.light_performance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_performance.R 3 | \name{plot.light_performance} 4 | \alias{plot.light_performance} 5 | \title{Visualize Model Performance} 6 | \usage{ 7 | \method{plot}{light_performance}( 8 | x, 9 | swap_dim = FALSE, 10 | geom = c("bar", "point"), 11 | facet_scales = "free_y", 12 | rotate_x = FALSE, 13 | ... 14 | ) 15 | } 16 | \arguments{ 17 | \item{x}{An object of class "light_performance".} 18 | 19 | \item{swap_dim}{Should representation of dimensions 20 | (either two "by" variables or one "by" variable and multiflashlight) 21 | of x aesthetic and dodge fill aesthetic be swapped?
Default is \code{FALSE}.} 22 | 23 | \item{geom}{Geometry of plot (either "bar" or "point").} 24 | 25 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 26 | 27 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 28 | 29 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_bar]{ggplot2::geom_bar()}} or 30 | \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}}.} 31 | } 32 | \value{ 33 | An object of class "ggplot". 34 | } 35 | \description{ 36 | Minimal visualization of an object of class "light_performance" as 37 | \code{\link[ggplot2:geom_bar]{ggplot2::geom_bar()}}. The object returned has class "ggplot", 38 | and can be further customized. 39 | } 40 | \details{ 41 | The plot is organized as a bar plot as follows: 42 | For flashlights without "by" variable specified, a single bar is drawn. 43 | Otherwise, the "by" variable (or the flashlight label if there is no "by" variable) 44 | is represented by the "x" aesthetic. 45 | 46 | The flashlight label (in case of one "by" variable) is represented by dodged bars. 47 | This strategy makes sure that performance of different flashlights can 48 | be compared most easily. Set "swap_dim = TRUE" to revert the role of dodging and x 49 | aesthetic. Different metrics are always represented by facets. 50 | } 51 | \seealso{ 52 | \code{\link[=light_performance]{light_performance()}} 53 | } 54 | -------------------------------------------------------------------------------- /man/plot.light_profile.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_profile.R 3 | \name{plot.light_profile} 4 | \alias{plot.light_profile} 5 | \title{Visualize Profiles, e.g.
Partial Dependence} 6 | \usage{ 7 | \method{plot}{light_profile}( 8 | x, 9 | swap_dim = FALSE, 10 | facet_scales = "free_x", 11 | rotate_x = x$type != "partial dependence", 12 | show_points = TRUE, 13 | ... 14 | ) 15 | } 16 | \arguments{ 17 | \item{x}{An object of class "light_profile".} 18 | 19 | \item{swap_dim}{If multiflashlight and one "by" variable or 20 | single flashlight with two "by" variables, swap the role of dodge/fill variable 21 | and facet variable. If multiflashlight or one "by" variable, 22 | use facets instead of colors.} 23 | 24 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 25 | 26 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 27 | 28 | \item{show_points}{Should points be added to the line (default is \code{TRUE}).} 29 | 30 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}} or 31 | \code{\link[ggplot2:geom_path]{ggplot2::geom_line()}}.} 32 | } 33 | \value{ 34 | An object of class "ggplot". 35 | } 36 | \description{ 37 | Minimal visualization of an object of class "light_profile". 38 | The object returned is of class "ggplot" and can be further customized. 39 | } 40 | \details{ 41 | Either lines and points are plotted (if stats = "mean") or quartile boxes. 42 | If there is a "by" variable or a multiflashlight, this first dimension 43 | is represented by color (or if \code{swap_dim = TRUE} by facets). 44 | If there are two "by" variables or a multiflashlight with one "by" variable, 45 | the first "by" variable is visualized as color, while the second one 46 | or the multiflashlight is shown via facet (change with \code{swap_dim}). 
47 | } 48 | \seealso{ 49 | \code{\link[=light_profile]{light_profile()}}, \code{\link[=plot.light_effects]{plot.light_effects()}} 50 | } 51 | -------------------------------------------------------------------------------- /man/plot.light_profile2d.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_profile2d.R 3 | \name{plot.light_profile2d} 4 | \alias{plot.light_profile2d} 5 | \title{Visualize 2D-Profiles, e.g., of Partial Dependence} 6 | \usage{ 7 | \method{plot}{light_profile2d}(x, swap_dim = FALSE, rotate_x = TRUE, numeric_as_factor = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_profile2d".} 11 | 12 | \item{swap_dim}{Swap the \code{\link[ggplot2:facet_grid]{ggplot2::facet_grid()}} dimensions.} 13 | 14 | \item{rotate_x}{Should the x axis labels be rotated by 45 degrees? Default is \code{TRUE}.} 15 | 16 | \item{numeric_as_factor}{Should numeric x and y values be converted to factors first? 17 | Default is \code{FALSE}. Useful if \code{cut_type} was not set to "equal".} 18 | 19 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_tile]{ggplot2::geom_tile()}}.} 20 | } 21 | \value{ 22 | An object of class "ggplot". 23 | } 24 | \description{ 25 | Minimal visualization of an object of class "light_profile2d". 26 | The object returned is of class "ggplot" and can be further customized. 27 | } 28 | \details{ 29 | The main geometry is \code{\link[ggplot2:geom_tile]{ggplot2::geom_tile()}}. Additional dimensions 30 | ("by" variable(s) and/or multiflashlight) are represented by \code{facet_wrap/grid}. 31 | For all types of profiles except "partial dependence", it is natural to see 32 | empty parts in the plot. These are combinations of the \code{v} variables that 33 | do not appear in the data. Even for type "partial dependence", such gaps can occur, 34 | e.g. 
for \code{cut_type = "quantile"} or if \code{n_bins} are larger than the number 35 | of distinct values of a \code{v} variable. 36 | Such gaps can be suppressed by setting \code{numeric_as_factor = TRUE} 37 | or by using the arguments \code{breaks}, \code{pd_evaluate_at} or \code{pd_grid} in 38 | \code{\link[=light_profile2d]{light_profile2d()}}. 39 | } 40 | \seealso{ 41 | \code{\link[=light_profile2d]{light_profile2d()}} 42 | } 43 | -------------------------------------------------------------------------------- /man/plot.light_scatter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/light_scatter.R 3 | \name{plot.light_scatter} 4 | \alias{plot.light_scatter} 5 | \title{Scatter Plot} 6 | \usage{ 7 | \method{plot}{light_scatter}(x, swap_dim = FALSE, facet_scales = "free_x", rotate_x = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light_scatter".} 11 | 12 | \item{swap_dim}{If multiflashlight and one "by" variable, or single flashlight 13 | with two "by" variables, swap the role of color variable and facet variable. 14 | If multiflashlight or one "by" variable, use colors instead of facets.} 15 | 16 | \item{facet_scales}{Scales argument passed to \code{\link[ggplot2:facet_wrap]{ggplot2::facet_wrap()}}.} 17 | 18 | \item{rotate_x}{Should x axis labels be rotated by 45 degrees?} 19 | 20 | \item{...}{Further arguments passed to \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}}. Typical arguments 21 | would be \code{alpha = 0.2} or \code{position = "jitter"} to avoid overplotting.} 22 | } 23 | \value{ 24 | An object of class "ggplot". 25 | } 26 | \description{ 27 | Values are plotted against a variable. The object returned is of class "ggplot" 28 | and can be further customized. To avoid overplotting, try \code{alpha = 0.2} or 29 | \code{position = "jitter"}. 
30 | } 31 | \seealso{ 32 | \code{\link[=light_scatter]{light_scatter()}} 33 | } 34 | -------------------------------------------------------------------------------- /man/plot_counts.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/aa_deprecated.R 3 | \name{plot_counts} 4 | \alias{plot_counts} 5 | \title{DEPRECATED} 6 | \usage{ 7 | plot_counts(...) 8 | } 9 | \arguments{ 10 | \item{...}{Any input.} 11 | } 12 | \value{ 13 | Error message. 14 | } 15 | \description{ 16 | DEPRECATED 17 | } 18 | -------------------------------------------------------------------------------- /man/predict.flashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{predict.flashlight} 4 | \alias{predict.flashlight} 5 | \title{Predictions for flashlight} 6 | \usage{ 7 | \method{predict}{flashlight}(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An object of class "flashlight".} 11 | 12 | \item{...}{Arguments used to update the flashlight.} 13 | } 14 | \value{ 15 | A vector with predictions. 16 | } 17 | \description{ 18 | Predict method for an object of class "flashlight". 19 | Pass additional elements to update the flashlight, typically \code{data}. 
20 | } 21 | \examples{ 22 | fit <- lm(Sepal.Length ~ ., data = iris) 23 | fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols") 24 | predict(fl)[1:5] 25 | predict(fl, data = iris[1:5, ]) 26 | } 27 | -------------------------------------------------------------------------------- /man/predict.multiflashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{predict.multiflashlight} 4 | \alias{predict.multiflashlight} 5 | \title{Predictions for multiflashlight} 6 | \usage{ 7 | \method{predict}{multiflashlight}(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An object of class "multiflashlight".} 11 | 12 | \item{...}{Arguments used to update the multiflashlight.} 13 | } 14 | \value{ 15 | A named list of prediction vectors. 16 | } 17 | \description{ 18 | Predict method for an object of class "multiflashlight". 19 | Pass additional elements to update the flashlight, typically \code{data}. 20 | } 21 | \examples{ 22 | fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris) 23 | fit_full <- lm(Sepal.Length ~ ., data = iris) 24 | mod_full <- flashlight(model = fit_full, label = "full") 25 | mod_part <- flashlight(model = fit_part, label = "part") 26 | mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length") 27 | predict(mods, data = iris[1:5, ]) 28 | } 29 | -------------------------------------------------------------------------------- /man/print.flashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{print.flashlight} 4 | \alias{print.flashlight} 5 | \title{Prints a flashlight} 6 | \usage{ 7 | \method{print}{flashlight}(x, ...) 
8 | } 9 | \arguments{ 10 | \item{x}{An object of class "flashlight".} 11 | 12 | \item{...}{Further arguments passed from other methods.} 13 | } 14 | \value{ 15 | Invisibly, the input is returned. 16 | } 17 | \description{ 18 | Print method for an object of class "flashlight". 19 | } 20 | \examples{ 21 | fit <- lm(Sepal.Length ~ ., data = iris) 22 | x <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris) 23 | x 24 | } 25 | \seealso{ 26 | \code{\link[=flashlight]{flashlight()}} 27 | } 28 | -------------------------------------------------------------------------------- /man/print.light.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{print.light} 4 | \alias{print.light} 5 | \title{Prints light Object} 6 | \usage{ 7 | \method{print}{light}(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "light".} 11 | 12 | \item{...}{Further arguments passed from other methods.} 13 | } 14 | \value{ 15 | Invisibly, the input is returned. 16 | } 17 | \description{ 18 | Print method for an object of class "light". 19 | } 20 | \examples{ 21 | fit <- lm(Sepal.Length ~ ., data = iris) 22 | fl <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris) 23 | light_performance(fl, v = "Species") 24 | } 25 | -------------------------------------------------------------------------------- /man/print.multiflashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{print.multiflashlight} 4 | \alias{print.multiflashlight} 5 | \title{Prints a multiflashlight} 6 | \usage{ 7 | \method{print}{multiflashlight}(x, ...) 
8 | } 9 | \arguments{ 10 | \item{x}{An object of class "multiflashlight".} 11 | 12 | \item{...}{Further arguments passed to \code{\link[=print.flashlight]{print.flashlight()}}.} 13 | } 14 | \value{ 15 | Invisibly, the input is returned. 16 | } 17 | \description{ 18 | Print method for an object of class "multiflashlight". 19 | } 20 | \examples{ 21 | fit_lm <- lm(Sepal.Length ~ ., data = iris) 22 | fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris) 23 | fl_lm <- flashlight(model = fit_lm, label = "lm") 24 | fl_glm <- flashlight(model = fit_glm, label = "glm") 25 | multiflashlight(list(fl_lm, fl_glm), data = iris) 26 | } 27 | \seealso{ 28 | \code{\link[=multiflashlight]{multiflashlight()}} 29 | } 30 | -------------------------------------------------------------------------------- /man/residuals.flashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{residuals.flashlight} 4 | \alias{residuals.flashlight} 5 | \title{Residuals for flashlight} 6 | \usage{ 7 | \method{residuals}{flashlight}(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An object of class "flashlight".} 11 | 12 | \item{...}{Arguments used to update the flashlight before calculating the residuals.} 13 | } 14 | \value{ 15 | A numeric vector with residuals. 16 | } 17 | \description{ 18 | Residuals method for an object of class "flashlight". 19 | Pass additional elements to update the flashlight before calculation of residuals. 
20 | } 21 | \examples{ 22 | fit <- lm(Sepal.Length ~ ., data = iris) 23 | x <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols") 24 | residuals(x)[1:5] 25 | } 26 | -------------------------------------------------------------------------------- /man/residuals.multiflashlight.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{residuals.multiflashlight} 4 | \alias{residuals.multiflashlight} 5 | \title{Residuals for multiflashlight} 6 | \usage{ 7 | \method{residuals}{multiflashlight}(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An object of class "multiflashlight".} 11 | 12 | \item{...}{Arguments used to update the multiflashlight before 13 | calculating the residuals.} 14 | } 15 | \value{ 16 | A named list with residuals per flashlight. 17 | } 18 | \description{ 19 | Residuals method for an object of class "multiflashlight". 20 | Pass additional elements to update the multiflashlight before calculation of 21 | residuals. 22 | } 23 | \examples{ 24 | fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris) 25 | fit_full <- lm(Sepal.Length ~ ., data = iris) 26 | mod_full <- flashlight(model = fit_full, label = "full") 27 | mod_part <- flashlight(model = fit_part, label = "part") 28 | mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length") 29 | residuals(mods, data = head(iris)) 30 | } 31 | -------------------------------------------------------------------------------- /man/response.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{response} 4 | \alias{response} 5 | \alias{response.default} 6 | \alias{response.flashlight} 7 | \alias{response.multiflashlight} 8 | \title{Response of multi/-flashlight} 9 | \usage{ 10 | response(object, ...) 
11 | 12 | \method{response}{default}(object, ...) 13 | 14 | \method{response}{flashlight}(object, ...) 15 | 16 | \method{response}{multiflashlight}(object, ...) 17 | } 18 | \arguments{ 19 | \item{object}{An object of class "flashlight".} 20 | 21 | \item{...}{Arguments used to update the flashlight before extracting the response.} 22 | } 23 | \value{ 24 | A numeric vector of responses. 25 | } 26 | \description{ 27 | Extracts response from object of class "flashlight". 28 | } 29 | \section{Methods (by class)}{ 30 | \itemize{ 31 | \item \code{response(default)}: Default method not implemented yet. 32 | 33 | \item \code{response(flashlight)}: Extract response from flashlight object. 34 | 35 | \item \code{response(multiflashlight)}: Extract responses from multiflashlight object. 36 | 37 | }} 38 | \examples{ 39 | fit <- lm(Sepal.Length ~ ., data = iris) 40 | (fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")) 41 | response(fl)[1:5] 42 | response(fl, data = iris[1:5, ]) 43 | response(fl, data = iris[1:5, ], linkinv = exp) 44 | } 45 | -------------------------------------------------------------------------------- /packaging.R: -------------------------------------------------------------------------------- 1 | #============================================================================= 2 | # Put together the package 3 | #============================================================================= 4 | 5 | # WORKFLOW: UPDATE EXISTING PACKAGE 6 | # 1) Modify package content and documentation. 7 | # 2) Increase package number in "use_description" below. 8 | # 3) Go through this script and carefully answer "no" if a "use_*" function 9 | # asks to overwrite the existing files. Don't skip that function call. 
10 | # devtools::load_all() 11 | 12 | library(usethis) 13 | 14 | # Sketch of description file 15 | use_description( 16 | fields = list( 17 | Title = "Shed Light on Black Box Machine Learning Models", 18 | Version = "0.9.0.9000", 19 | Description = "Shed light on black box machine learning models by the help of model 20 | performance, variable importance, global surrogate models, ICE profiles, 21 | partial dependence (Friedman J. H. (2001) ), 22 | accumulated local effects (Apley D. W. (2016) ), 23 | further effects plots, interaction strength, and variable contribution breakdown 24 | (Gosiewska and Biecek (2019) ). 25 | All tools are implemented to work with case weights and allow for stratified analysis. 26 | Furthermore, multiple flashlights can be combined and analyzed together.", 27 | `Authors@R` = "person('Michael', 'Mayer', email = 'mayermichael79@gmail.com', role = c('aut', 'cre', 'cph'))", 28 | Depends = "R (>= 3.2.0)", 29 | LazyData = NULL 30 | ), 31 | roxygen = TRUE 32 | ) 33 | 34 | # Imports 35 | use_package("dplyr", "Imports", min_version = "1.1.0") 36 | use_package("ggplot2", "Imports") 37 | use_package("MetricsWeighted", "Imports", min_version = "0.3.0") 38 | use_package("rlang", "Imports", min_version = "0.3.0") # dplyr 39 | use_package("rpart", "Imports") 40 | use_package("rpart.plot", "Imports") 41 | use_package("stats", "Imports") 42 | use_package("tibble", "Imports") # dplyr 43 | use_package("tidyr", "Imports", min_version = "1.0.0") 44 | use_package("tidyselect", "Imports") # dplyr 45 | use_package("utils", "Imports") 46 | 47 | use_gpl_license(2) 48 | 49 | use_github_links() # use this if this project is on github 50 | 51 | # Your files that do not belong to the package itself (others are added by "use_* function") 52 | use_build_ignore(c("^packaging.R$", "[.]Rproj$", "^logo.png$"), escape = FALSE) 53 | 54 | # If your code uses the pipe operator %>% 55 | # use_pipe() 56 | 57 | # If your package contains data. 
Google how to document 58 | # use_data() 59 | 60 | # Add short docu in Markdown (without running R code) 61 | use_readme_md() 62 | 63 | # Longer docu in RMarkdown (with running R code). Often quite similar to readme. 64 | use_vignette("flashlight") 65 | 66 | # If you want to add unit tests 67 | use_testthat() 68 | # use_test("test-eff.R") 69 | 70 | # On top of NEWS.md, describe changes made to the package 71 | use_news_md() 72 | 73 | # Add logo 74 | use_logo("logo.png") 75 | 76 | # If package goes to CRAN: infos (check results etc.) for CRAN 77 | use_cran_comments() 78 | 79 | # Github actions 80 | use_github_action("check-standard") 81 | use_github_action("test-coverage") 82 | use_github_action("pkgdown") 83 | 84 | # Revdep 85 | use_revdep() 86 | 87 | #============================================================================= 88 | # Finish package building (can use fresh session) 89 | #============================================================================= 90 | 91 | library(devtools) 92 | 93 | document() 94 | test() 95 | check(manual = TRUE, cran = TRUE) 96 | build(vignettes = FALSE) 97 | # build(binary = TRUE) 98 | install() 99 | 100 | # Run only if package is public(!) 
and should go to CRAN 101 | if (FALSE) { 102 | check_win_devel() 103 | check_rhub() 104 | 105 | # Takes long 106 | revdepcheck::revdep_check(num_workers = 4L) 107 | 108 | # Wait until above checks are passed without relevant notes/warnings 109 | # then submit to CRAN 110 | release() 111 | } 112 | -------------------------------------------------------------------------------- /revdep/.gitignore: -------------------------------------------------------------------------------- 1 | checks 2 | library 3 | checks.noindex 4 | library.noindex 5 | cloud.noindex 6 | data.sqlite 7 | *.html 8 | -------------------------------------------------------------------------------- /revdep/README.md: -------------------------------------------------------------------------------- 1 | # Platform 2 | 3 | |field |value | 4 | |:--------|:----------------------------------------------------| 5 | |version |R version 4.3.0 (2023-04-21 ucrt) | 6 | |os |Windows 11 x64 (build 22621) | 7 | |system |x86_64, mingw32 | 8 | |ui |RStudio | 9 | |language |(EN) | 10 | |collate |German_Switzerland.utf8 | 11 | |ctype |German_Switzerland.utf8 | 12 | |tz |Europe/Zurich | 13 | |date |2023-05-07 | 14 | |rstudio |2023.03.0+386 Cherry Blossom (desktop) | 15 | |pandoc |2.12 @ C:\Users\Michael\anaconda3\Scripts\pandoc.exe | 16 | 17 | # Dependencies 18 | 19 | |package |old |new |Δ | 20 | |:---------------|:------|:------|:--| 21 | |flashlight |0.8.0 |0.9.0 |* | 22 | |cli |3.6.1 |3.6.1 | | 23 | |colorspace |2.1-0 |2.1-0 | | 24 | |cowplot |1.1.1 |1.1.1 | | 25 | |cpp11 |0.4.3 |0.4.3 | | 26 | |dplyr |1.1.2 |1.1.2 | | 27 | |fansi |1.0.4 |1.0.4 | | 28 | |farver |2.1.1 |2.1.1 | | 29 | |generics |0.1.3 |0.1.3 | | 30 | |ggplot2 |3.4.2 |3.4.2 | | 31 | |glue |1.6.2 |1.6.2 | | 32 | |gtable |0.3.3 |0.3.3 | | 33 | |isoband |0.2.7 |0.2.7 | | 34 | |labeling |0.4.2 |0.4.2 | | 35 | |lifecycle |1.0.3 |1.0.3 | | 36 | |magrittr |2.0.3 |2.0.3 | | 37 | |MetricsWeighted |1.0.0 |1.0.0 | | 38 | |munsell |0.5.0 |0.5.0 | | 39 | |pillar 
|1.9.0 |1.9.0 | | 40 | |pkgconfig |2.0.3 |2.0.3 | | 41 | |purrr |1.0.1 |1.0.1 | | 42 | |R6 |2.5.1 |2.5.1 | | 43 | |RColorBrewer |1.1-3 |1.1-3 | | 44 | |rlang |1.1.1 |1.1.1 | | 45 | |rpart.plot |3.1.1 |3.1.1 | | 46 | |scales |1.2.1 |1.2.1 | | 47 | |stringi |1.7.12 |1.7.12 | | 48 | |stringr |1.5.0 |1.5.0 | | 49 | |tibble |3.2.1 |3.2.1 | | 50 | |tidyr |1.3.0 |1.3.0 | | 51 | |tidyselect |1.2.0 |1.2.0 | | 52 | |utf8 |1.2.3 |1.2.3 | | 53 | |vctrs |0.6.2 |0.6.2 | | 54 | |viridisLite |0.4.2 |0.4.2 | | 55 | |withr |2.5.0 |2.5.0 | | 56 | 57 | # Revdeps 58 | 59 | -------------------------------------------------------------------------------- /revdep/cran.md: -------------------------------------------------------------------------------- 1 | ## revdepcheck results 2 | 3 | We checked 1 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 0 packages 7 | 8 | -------------------------------------------------------------------------------- /revdep/email.yml: -------------------------------------------------------------------------------- 1 | release_date: ??? 2 | rel_release_date: ??? 3 | my_news_url: ??? 4 | release_version: ??? 5 | release_details: ??? 6 | -------------------------------------------------------------------------------- /revdep/failures.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /revdep/problems.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. 
:)* -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(flashlight) 3 | 4 | test_check("flashlight") 5 | -------------------------------------------------------------------------------- /tests/testthat/tests-breakdown.R: -------------------------------------------------------------------------------- 1 | fit1 <- stats::lm(Sepal.Length ~ Petal.Width, data = iris) 2 | fit2 <- stats::lm(Sepal.Length ~ Petal.Width + Species + Sepal.Width, data = iris) 3 | fl1 <- flashlight(model = fit1, label = "small", data = iris, y = "Sepal.Length") 4 | fl2 <- flashlight(model = fit2, label = "large", data = iris, y = "Sepal.Length") 5 | fls <- multiflashlight(list(fl1, fl2)) 6 | 7 | test_that("basic functionality works", { 8 | br <- light_breakdown(fl1, iris[1L, ]) 9 | dat <- br$data 10 | expect_equal(dat$before_[-2L], dat$after_[-2L]) 11 | expect_equal(dat$after_[2L] - dat$before_[2L], -0.8879879, tolerance = 0.001) 12 | expect_s3_class(plot(br), "ggplot") 13 | }) 14 | 15 | test_that("light_breakdown reacts on v", { 16 | dat <- light_breakdown(fl1, iris[1L, ], v = "Petal.Width")$data 17 | expect_equal(nrow(dat), 3L) 18 | }) 19 | 20 | test_that("light_breakdown reacts on visit_strategy v", { 21 | br <- light_breakdown(fl1, iris[1L, ], visit_strategy = "v") 22 | dat <- br$data 23 | expect_equal(dat$before_[-4L], dat$after_[-4L]) 24 | expect_equal(dat$after_[4L] - dat$before_[4L], -0.8879879, tolerance = 0.001) 25 | expect_s3_class(plot(br), "ggplot") 26 | }) 27 | 28 | test_that("light_breakdown reacts on visit_strategy shap", { 29 | br <- light_breakdown(fl1, iris[1L, ], visit_strategy = "permutation", seed = 1L) 30 | dat <- br$data 31 | expect_equal(dat$before_[-4L], dat$after_[-4L]) 32 | expect_equal(dat$after_[4L] - dat$before_[4L], -0.8879879, tolerance = 0.001) 33 | expect_s3_class(plot(br), "ggplot") 34 | }) 35 | 36 
| test_that("light_breakdown reacts on weights", { 37 | br <- light_breakdown(flashlight(fl1, w = "Petal.Length"), iris[1L, ]) 38 | dat <- br$data 39 | expect_equal(dat$before_[-2L], dat$after_[-2L]) 40 | expect_equal(dat$after_[2L] - dat$before_[2L], -1.192293, tolerance = 0.001) 41 | expect_s3_class(plot(br), "ggplot") 42 | }) 43 | 44 | test_that("light_breakdown reacts on by", { 45 | br <- light_breakdown(flashlight(fl1, by = "Species"), iris[1L, ]) 46 | dat <- br$data 47 | expect_equal(dat$before_[-2L], dat$after_[-2L]) 48 | expect_equal(dat$after_[2L] - dat$before_[2L], -0.04087469, tolerance = 0.001) 49 | expect_s3_class(plot(br), "ggplot") 50 | }) 51 | 52 | test_that("light_breakdown reacts on multiflashlight", { 53 | br <- light_breakdown(fls, iris[1L, ]) 54 | dat <- br$data 55 | expect_equal(nrow(dat), 6L * 2L) 56 | expect_false(all(dat[1:6, ] == dat[7:12, ])) 57 | expect_s3_class(plot(br), "ggplot") 58 | }) 59 | -------------------------------------------------------------------------------- /tests/testthat/tests-cut.R: -------------------------------------------------------------------------------- 1 | test_that("midpoints are working", { 2 | expect_equal(midpoints(1:2), 1.5) 3 | expect_error(midpoints(1)) 4 | expect_error(midpoints(c(1, NA))) 5 | }) 6 | 7 | test_that("cut3 is working", { 8 | expect_equal(levels(cut3(1:3, breaks = c(0, 1.5, 2.5, 4))), 9 | c("(0, 1.5]", "(1.5, 2.5]", "(2.5, 4]")) 10 | }) 11 | 12 | test_that("auto_cut is working", { 13 | x <- 1:10 14 | expect_equal(auto_cut(x, n_bins = 3)$breaks, c(0, 5, 10)) 15 | ac <- auto_cut(c(NA, x), n_bins = 3) 16 | expect_equal(dim(ac$data), c(11L, 2L)) 17 | expect_equal(ac$breaks, c(0, 5, 10)) 18 | expect_equal(ac$bin_means, c(2.5, 7.5, NA)) 19 | expect_equal(length(ac$bin_labels), 3L) 20 | expect_equal(auto_cut(x, cut_type = "quantile", n_bins = 3)$bin_means, 21 | c(2.5, 5.5, 8.5)) 22 | }) 23 | 24 | test_that("common_breaks give same results for split data as for combined data", { 25 | fit1 <- 
lm(Sepal.Length ~ ., data = iris) 26 | fit2 <- lm(Sepal.Length ~ Petal.Length, data = iris) 27 | fl1 <- flashlight(model = fit1, label = "full") 28 | fl2 <- flashlight(model = fit2, label = "single") 29 | fls <- multiflashlight(list(fl1, fl2), data = iris, y = "Sepal.Length") 30 | expect_equal(common_breaks(fls, v = "Petal.Length", data = NULL, 31 | cut_type = "quantile", n_bins = 2), 32 | c(1.0, 4.3, 6.9)) 33 | 34 | # Same result for multiflashlight with data distributed across flashlights 35 | fl1 <- flashlight(fls$full, data = iris[1:75, ]) 36 | fl2 <- flashlight(fls$single, data = iris[76:150, ]) 37 | fls2 <- multiflashlight(list(fl1, fl2)) 38 | expect_equal(common_breaks(fls, v = "Petal.Length", data = NULL, 39 | cut_type = "quantile", n_bins = 2), 40 | c(1.0, 4.3, 6.9)) 41 | }) 42 | 43 | 44 | -------------------------------------------------------------------------------- /tests/testthat/tests-eff.R: -------------------------------------------------------------------------------- 1 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 2 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 3 | 4 | test_that("light_profile works correctly for type response", { 5 | pr <- light_profile(fl, v = "Species", type = "response") 6 | expect_equal( 7 | pr$data$value_, 8 | stats::aggregate(Sepal.Length ~ Species, data = iris, FUN = mean)$Sepal.Length 9 | ) 10 | expect_s3_class(plot(pr), "ggplot") 11 | }) 12 | 13 | test_that("breaks work", { 14 | pr <- light_profile(fl, v = "Petal.Length", type = "response", breaks = c(1, 4, 7)) 15 | expect_equal( 16 | pr$data$value_, 17 | stats::aggregate( 18 | Sepal.Length ~ Petal.Length > 4, data = iris, FUN = mean)$Sepal.Length 19 | ) 20 | expect_s3_class(plot(pr), "ggplot") 21 | }) 22 | 23 | test_that("n_bins work", { 24 | pr <- light_profile(fl, v = "Petal.Length", type = "response", n_bins = 2) 25 | expect_equal(dim(pr$data), c(2L, 5L)) 26 | }) 27 | 28 | test_that("v_labels work", { 29 | pr <- 
light_profile( 30 | fl, v = "Petal.Length", type = "response", n_bins = 2, v_labels = FALSE 31 | ) 32 | expect_true(is.numeric(pr$data$Petal.Length)) 33 | }) 34 | 35 | test_that("light_profile works correctly for type predicted", { 36 | pr <- light_profile(fl, v = "Species", type = "predicted") 37 | expect_equal( 38 | pr$data$value_, 39 | stats::aggregate(Sepal.Length ~ Species, data = iris, FUN = mean)$Sepal.Length 40 | ) 41 | expect_s3_class(plot(pr), "ggplot") 42 | }) 43 | 44 | test_that("light_profile uses pred", { 45 | pr <- light_profile(fl, v = "Species", type = "predicted", pred = rep(1:3, each = 50)) 46 | expect_equal(pr$data$value_, 1:3) 47 | fls <- multiflashlight(list(fl, flashlight(fl, label = "lm2"))) 48 | expect_error( 49 | light_profile(fls, v = "Species", type = "predicted", pred = rep(1:3, each = 50)) 50 | ) 51 | }) 52 | 53 | test_that("light_profile works correctly for type residual", { 54 | pr <- light_profile(fl, v = "Species", type = "residual") 55 | expect_equal(pr$data$value_, c(0, 0, 0)) 56 | expect_s3_class(plot(pr), "ggplot") 57 | }) 58 | 59 | test_that("partial dependence is the same as ice", { 60 | pr <- light_profile(fl, v = "Species", pd_indices = 1) 61 | ice <- light_ice(fl, v = "Species", indices = 1) 62 | expect_equal(pr$data$value_, unname(ice$data$value_)) 63 | }) 64 | 65 | test_that("partial dependence gives correct output for model with one covariable", { 66 | pr <- light_profile(fl, v = "Species") 67 | expect_equal(as.numeric(pr$data$value_), as.numeric(coef(fit))) 68 | }) 69 | 70 | test_that("partial dependence is constant if covariable not in model", { 71 | pr <- light_profile(fl, v = "Petal.Length") 72 | expect_true(var(pr$data$value_) == 0) 73 | }) 74 | 75 | test_that("ale gives correct output for model with one covariable", { 76 | pr <- light_profile(fl, v = "Species", type = "ale") 77 | expect_equal(pr$data$value_, as.numeric(coef(fit))) 78 | }) 79 | 80 | test_that("ale is constant if covariable not in model", { 81 | pr 
<- light_profile(fl, v = "Petal.Length", type = "ale") 82 | expect_true(var(pr$data$value_) == 0) 83 | }) 84 | 85 | test_that("light_profile works correctly for type 'response' with by variable", { 86 | fit <- stats::lm(Sepal.Length ~ Species + Petal.Length, data = iris) 87 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 88 | pr <- light_profile( 89 | fl, 90 | v = "Petal.Length", 91 | type = "response", 92 | by = "Species", 93 | breaks = c(1, 4, 7) 94 | ) 95 | agg <- stats::aggregate( 96 | Sepal.Length ~ Species + (Petal.Length > 4), data = iris, FUN = mean 97 | ) 98 | pr_data <- pr$data # reframe() has different sort order... 99 | pr_data <- pr$data[order(pr$data$Species, pr$data$Petal.Length), ] 100 | expect_equal(pr_data$value_, agg$Sepal.Length) 101 | expect_s3_class(plot(pr), "ggplot") 102 | 103 | setosa1 <- light_profile( 104 | fl, v = "Petal.Length", breaks = c(1, 4, 7), data = iris[1:50, ] 105 | ) 106 | setosa2 <- light_profile(fl, v = "Petal.Length", breaks = c(1, 4, 7), by = "Species") 107 | expect_equal(setosa1$data$value_, setosa2$data$value_[1:2]) 108 | }) 109 | 110 | test_that("light_profile works correctly for type 'ale' with by variable", { 111 | fit <- stats::lm(Sepal.Length ~ Species + Petal.Width, data = iris) 112 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 113 | ale <- light_profile(fl, v = "Petal.Width", type = "ale", by = "Species") 114 | expect_s3_class(plot(ale), "ggplot") 115 | expect_true(is.light(ale)) 116 | }) 117 | 118 | test_that("basic functionality works for multiflashlight", { 119 | fit1 <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 120 | fl1 <- flashlight(model = fit1, label = "Species", data = iris, y = "Sepal.Length") 121 | fit2 <- stats::lm(Sepal.Length ~ 1, data = iris) 122 | fl2 <- flashlight(model = fit2, label = "Empty", data = iris, y = "Sepal.Length") 123 | fls <- multiflashlight(list(fl1, fl2)) 124 | 125 | pr <- light_profile(fls, v = "Species") 126 
| expect_true(all(dim(pr$data) == c(6L, 5L))) 127 | expect_true(all(pr$data$value_[4:6] != pr$data$value_[1:3])) 128 | expect_equal(pr$data$value_[4:6], rep(mean(iris$Sepal.Length), 3)) 129 | expect_s3_class(plot(pr), "ggplot") 130 | 131 | pr <- light_profile(fls, v = "Species", type = "predicted") 132 | expect_true(all(dim(pr$data) == c(6L, 5L))) 133 | expect_true(all(pr$data$value_[4:6] != pr$data$value_[1:3])) 134 | expect_equal(pr$data$value_[4:6], rep(mean(iris$Sepal.Length), 3)) 135 | expect_s3_class(plot(pr), "ggplot") 136 | 137 | pr <- light_profile(fls, v = "Species", type = "ale") 138 | expect_true(all(dim(pr$data) == c(6L, 5L))) 139 | expect_true(all(pr$data$value_[4:6] != pr$data$value_[1:3])) 140 | expect_equal(pr$data$value_[4:6], rep(mean(iris$Sepal.Length), 3)) 141 | expect_s3_class(plot(pr), "ggplot") 142 | }) 143 | 144 | test_that("light_profile reacts on weights", { 145 | fit <- stats::lm(Sepal.Length ~ 1, data = iris) 146 | fl <- flashlight( 147 | model = fit, label = "empty", data = iris, y = "Sepal.Length", w = "Sepal.Width" 148 | ) 149 | pr <- light_profile(fl, v = "Species", type = "response") 150 | expect_equal( 151 | pr$data$value_[1L], with(iris[1:50, ], weighted.mean(Sepal.Length, Sepal.Width)) 152 | ) 153 | }) 154 | 155 | fit1 <- stats::lm(Sepal.Length ~ ., data = iris) 156 | fl1 <- flashlight(model = fit1, label = "full", data = iris, y = "Sepal.Length") 157 | fit2 <- stats::lm(Sepal.Length ~ 1, data = iris) 158 | fl2 <- flashlight(model = fit2, label = "Empty", data = iris, y = "Sepal.Length") 159 | fls <- multiflashlight(list(fl1, fl2)) 160 | 161 | test_that("light_effects works and produces same as individual profiles", { 162 | eff <- light_effects(fls, v = "Petal.Length") 163 | 164 | pd <- light_profile(fls, v = "Petal.Length", counts = FALSE) 165 | ale <- light_profile(fls, v = "Petal.Length", type = "ale", counts = FALSE) 166 | resp <- light_profile(fls, v = "Petal.Length", type = "response") 167 | predicted <- light_profile( 168 
| fls, v = "Petal.Length", type = "predicted", counts = FALSE 169 | ) 170 | 171 | expect_equal(eff$pd[-1L], pd$data[-1L]) 172 | expect_equal(eff$ale[-1L], ale$data[-1L]) 173 | expect_equal(eff$response, resp$data) 174 | expect_equal(eff$predicted, predicted$data) 175 | 176 | expect_s3_class(plot(eff), "ggplot") 177 | }) 178 | 179 | test_that("light_effects works with 'by'", { 180 | fls2 <- flashlight(fls$full, by = "Species") 181 | eff <- light_effects(fls2, v = "Petal.Length") 182 | 183 | pd <- light_profile(fls2, v = "Petal.Length", counts = FALSE) 184 | ale <- light_profile(fls2, v = "Petal.Length", type = "ale", counts = FALSE) 185 | resp <- light_profile(fls2, v = "Petal.Length", type = "response") 186 | predicted <- light_profile( 187 | fls2, v = "Petal.Length", type = "predicted", counts = FALSE 188 | ) 189 | 190 | expect_equal(eff$pd[-2L], pd$data[-2L]) 191 | expect_equal(eff$ale[-2L], ale$data[-2L]) 192 | expect_equal(eff$response, resp$data) 193 | expect_equal(eff$predicted, predicted$data) 194 | 195 | expect_s3_class(plot(eff, use = "all"), "ggplot") 196 | }) 197 | 198 | test_that("light_scatter works for type response", { 199 | sc <- light_scatter( 200 | fls, v = "Petal.Length", type = "response", data = iris[1:5, ] 201 | ) 202 | expect_equal(sc$data$value_, rep(iris$Sepal.Length[1:5], 2)) 203 | expect_s3_class(plot(sc), "ggplot") 204 | }) 205 | 206 | test_that("light_scatter works for type predicted", { 207 | sc <- light_scatter(fl2, v = "Petal.Length", type = "predicted") 208 | expect_equal(sc$data$value_, rep(mean(iris$Sepal.Length), 150)) 209 | expect_s3_class(plot(sc), "ggplot") 210 | }) 211 | 212 | test_that("light_scatter works for type residual", { 213 | sc <- light_scatter(fl1, v = "Species", type = "residual", data = iris[1:50, ]) 214 | expect_equal(mean(sc$data$value_), 0) 215 | expect_s3_class(plot(sc), "ggplot") 216 | }) 217 | -------------------------------------------------------------------------------- 
/tests/testthat/tests-globaltree.R: -------------------------------------------------------------------------------- 1 | fit <- stats::lm(Sepal.Length ~ ., data = iris) 2 | x <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 3 | 4 | test_that("basic functionality works", { 5 | surr <- light_global_surrogate(x) 6 | expect_equal(surr$data$r_squared, 0.923, tolerance = 0.001) 7 | expect_true(inherits(plot(surr), "list")) 8 | }) 9 | 10 | test_that("by variable work", { 11 | surr <- light_global_surrogate(x, by = "Species", v = c("Species", "Petal.Length")) 12 | expect_equal(dim(surr$data), c(3L, 4L)) 13 | }) 14 | 15 | test_that("argument 'v' works", { 16 | surr <- light_global_surrogate(x, v = "Petal.Length") 17 | expect_equal(dim(surr$data), c(1L, 3L)) 18 | }) 19 | 20 | test_that("multiflashlights work", { 21 | fit1 <- stats::lm(Sepal.Length ~ Species, data = iris) 22 | fl1 <- flashlight(model = fit1, label = "Species", data = iris, y = "Sepal.Length") 23 | fit2 <- stats::lm(Sepal.Length ~ Petal.Length, data = iris) 24 | fl2 <- flashlight( 25 | model = fit2, label = "Petal.Length", data = iris, y = "Sepal.Length" 26 | ) 27 | fls <- multiflashlight(list(fl1, fl2)) 28 | surr <- light_global_surrogate(fls) 29 | expect_equal(dim(surr$data), c(2L, 3L)) 30 | expect_equal(surr$data$r_squared, c(1, 0.978), tolerance = 0.001) 31 | }) 32 | -------------------------------------------------------------------------------- /tests/testthat/tests-grouped.R: -------------------------------------------------------------------------------- 1 | test_that("grouped_counts works", { 2 | expect_equal(grouped_counts(iris), data.frame(n = 150)) 3 | expect_equal(grouped_counts(iris, by = "Species")$n, c(50, 50, 50)) 4 | expect_equal(grouped_counts(iris, w = "Petal.Length"), data.frame(n = 563.7)) 5 | expect_equal(grouped_counts(iris, by = "Species", w = "Petal.Length")$n, 6 | c(73.1, 213, 278), tolerance = 1) 7 | }) 8 | 9 | test_that("grouped_weighted_mean works", { 10 | 
n <- 100 11 | set.seed(1) 12 | data <- data.frame(x = rnorm(n), w = runif(n), group = factor(sample(1:3, n, TRUE))) 13 | data <- data[order(data$group), ] # .by does not sort 14 | expect_equal(grouped_weighted_mean(data, x = "x")$x, 15 | grouped_stats(data, x = "x")$x) 16 | expect_equal(grouped_weighted_mean(data, x = "x", w = "w")$x, 17 | grouped_stats(data, x = "x", w = "w")$x) 18 | expect_equal(grouped_weighted_mean(data, x = "x", by = "group")$x, 19 | grouped_stats(data, x = "x", by = "group")$x) 20 | expect_equal(grouped_weighted_mean(data, x = "x", w = "w", by = "group")$x, 21 | grouped_stats(data, x = "x", w = "w", by = "group")$x) 22 | }) 23 | 24 | test_that("grouped_center works", { 25 | data <- data.frame(x = c(1, 1, 2), w = c(2, 2, 1)) 26 | res <- c(-1, -1, 2) / 3 27 | resw <- c(-1, -1, 4) / 5 28 | expect_equal(grouped_center(data, x = "x"), res) 29 | expect_equal(grouped_center(data, x = "x", w = "w"), resw) 30 | 31 | data2 <- data * 2 32 | data$g <- "A" 33 | data2$g <- "B" 34 | data3 <- rbind(data, data2) 35 | expect_equal(grouped_center(data3, x = "x", by = "g"), c(res, 2 * res)) 36 | expect_equal(grouped_center(data3, x = "x", w = "w", by = "g"), 37 | c(resw, 2 * resw)) 38 | }) 39 | 40 | test_that("grouped_stats works", { 41 | data <- data.frame(x = 1:10, w = 1:10, g = rep(1:2, each = 5)) 42 | 43 | expect_equal(grouped_stats(data, "x"), data.frame(counts_ = 10, x = 5.5)) 44 | expect_equal(grouped_stats(data, "x", stats = "variance"), 45 | data.frame(counts_ = 10, x = var(1:10))) 46 | expect_equal(grouped_stats(data, "x", stats = "quartiles")[c("q1", "q3")], 47 | data.frame(q1 = 3.25, q3 = 7.75)) 48 | 49 | expect_equal(grouped_stats(data, "x", by = "g")$x, c(3, 8)) 50 | expect_equal(grouped_stats(data, "x", by = "g", stats = "variance")$x, 51 | c(2.5, 2.5)) 52 | expect_equal(grouped_stats(data, "x", by = "g", stats = "quartiles")[["q1"]], 53 | c(2, 7)) 54 | 55 | expect_equal(grouped_stats(data, "x", w = "w")$x, weighted.mean(data$x, data$w)) 56 | 
expect_equal(grouped_stats(data, "x", w = "w", stats = "variance")$x, 57 | 6.875) 58 | expect_equal(grouped_stats(data, "x", w = "w", stats = "quartiles")[c("q1", "q3")], 59 | data.frame(q1 = 5, q3 = 9)) 60 | 61 | expect_equal(grouped_stats(data, "x", w = "w", by = "g")$x, 62 | c(weighted.mean(data$x[1:5], data$w[1:5]), 63 | weighted.mean(data$x[6:10], data$w[6:10]))) 64 | 65 | expect_equal(colnames(grouped_stats(data, "x", by = "g", counts_name = "n", 66 | value_name = "median", stats = "quartiles", 67 | q1_name = "p25", q3_name = "p75", )), 68 | c("g", "n", "p25", "median", "p75")) 69 | }) 70 | 71 | -------------------------------------------------------------------------------- /tests/testthat/tests-ice.R: -------------------------------------------------------------------------------- 1 | test_that("basic functionality and n_max work", { 2 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 3 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 4 | ice <- light_ice(fl, v = "Species", n_max = 1) 5 | expect_equal(as.numeric(ice$data$value_), as.numeric(coef(fit))) 6 | expect_s3_class(plot(ice), "ggplot") 7 | }) 8 | 9 | test_that("by functionality, indices and evaluate_at work", { 10 | fit <- stats::lm(Sepal.Length ~ Species * Petal.Width, data = iris) 11 | pred <- predict( 12 | fit, expand.grid(Petal.Width = c(0.1, 0.3, 0.5), Species = levels(iris$Species)) 13 | ) 14 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 15 | ice <- light_ice( 16 | fl, 17 | v = "Petal.Width", 18 | indices = c(1, 51, 101), 19 | by = "Species", 20 | evaluate_at = c(0.1, 0.3, 0.5) 21 | ) 22 | expect_equal(as.numeric(ice$data$value_), as.numeric(pred)) 23 | expect_s3_class(plot(ice), "ggplot") 24 | }) 25 | 26 | fit <- stats::lm(Sepal.Length ~ Species, data = iris) 27 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 28 | 29 | test_that("center first work", { 30 | ice <- light_ice(fl, v = "Species", 
n_max = 1L, center = "first") 31 | expect_equal(as.numeric(ice$data$value_)[-1L], as.numeric(coef(fit))[-1L]) 32 | expect_s3_class(plot(ice), "ggplot") 33 | }) 34 | 35 | test_that("center 0 and n_max work", { 36 | ice <- light_ice(fl, v = "Species", n_max = 10L, center = "0") 37 | expect_equal(dim(ice$data), c(30L, 4L)) 38 | expect_equal(mean(ice$data$value_), 0) 39 | expect_s3_class(plot(ice), "ggplot") 40 | }) 41 | 42 | test_that("basic functionality works for multiflashlight", { 43 | fit1 <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 44 | fl1 <- flashlight(model = fit1, label = "Species", data = iris, y = "Sepal.Length") 45 | fit2 <- stats::lm(Sepal.Length ~ 1, data = iris) 46 | fl2 <- flashlight(model = fit2, label = "Empty", data = iris, y = "Sepal.Length") 47 | fls <- multiflashlight(list(fl1, fl2)) 48 | 49 | ice <- light_ice(fls, v = "Species", n_max = 1) 50 | expect_equal(as.numeric(ice$data$value_)[1:3], as.numeric(coef(fit1))) 51 | expect_equal(as.numeric(ice$data$value_)[4:6], rep(mean(iris$Sepal.Length), 3)) 52 | expect_s3_class(plot(ice), "ggplot") 53 | 54 | ice <- light_ice(fls, v = "Petal.Length", indices = 1L, n_bins = 2L) 55 | expect_equal(ice$data$id_[1L], 1) 56 | expect_equal(as.numeric(ice$data$value_[1:2]), rep(mean(iris$Sepal.Length[1:50]), 2)) 57 | expect_equal(as.numeric(ice$data$value_[3:4]), rep(mean(iris$Sepal.Length), 2)) 58 | expect_s3_class(plot(ice), "ggplot") 59 | }) 60 | 61 | test_that("weights have no impact on results", { 62 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 63 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 64 | ice <- light_ice(fl, v = "Species", indices = 1:3) 65 | 66 | fl_weighted <- flashlight( 67 | model = fit, 68 | label = "weighted by Petal.Length", 69 | data = iris, 70 | y = "Sepal.Length", 71 | w = "Petal.Length" 72 | ) 73 | ice_weighted <- light_ice(fl_weighted, v = "Species", indices = 1:3) 74 | 75 | expect_equal(ice$data$value_, ice_weighted$data$value_) 
76 | }) 77 | -------------------------------------------------------------------------------- /tests/testthat/tests-importance.R: -------------------------------------------------------------------------------- 1 | fit1 <- stats::lm(Sepal.Length ~ Petal.Width + Species + Sepal.Width, data = iris) 2 | fit2 <- stats::lm(Sepal.Length ~ Petal.Width, data = iris) 3 | fl1 <- flashlight(model = fit1, label = "1", data = iris, y = "Sepal.Length") 4 | fl2 <- flashlight(model = fit2, label = "2", data = iris, y = "Sepal.Length") 5 | fls <- multiflashlight(list(fl1, fl2)) 6 | 7 | test_that("most_important works ", { 8 | imp <- light_importance(fl1, seed = 1L) 9 | expect_equal(most_important(imp, 1L), "Species") 10 | expect_equal( 11 | most_important(imp), 12 | c("Species", "Sepal.Width", "Petal.Width", "Petal.Length") 13 | ) 14 | 15 | imp <- light_importance(fls, seed = 1L) 16 | expect_equal(most_important(imp, 1L), "Petal.Width") 17 | expect_equal( 18 | most_important(imp), 19 | c("Petal.Width", "Species", "Sepal.Width", "Petal.Length")) 20 | }) 21 | 22 | test_that("light_importance works", { 23 | imp <- light_importance(fl2, seed = 1L) 24 | expect_equal(imp$data$value_[imp$data$variable_ != "Petal.Width"], rep(0, 3)) 25 | expect_true(imp$data$value_[imp$data$variable_ == "Petal.Width"] > 0) 26 | expect_equal(imp$data$value_[imp$data$variable_ == "Petal.Width"], 0.623468, tolerance = 0.001) 27 | expect_s3_class(plot(imp), "ggplot") 28 | }) 29 | 30 | test_that("light_importance reacts on metric", { 31 | imp <- light_importance( 32 | fl2, seed = 1L, metric = list(r_squared = MetricsWeighted::r_squared) 33 | ) 34 | expect_equal(imp$data$value_[imp$data$variable_ != "Petal.Width"], rep(0, 3)) 35 | expect_true(imp$data$value_[imp$data$variable_ == "Petal.Width"] < 0) 36 | expect_s3_class(plot(imp), "ggplot") 37 | }) 38 | 39 | test_that("'lower_is_better' works", { 40 | imp <- light_importance( 41 | fl2, 42 | seed = 1L, 43 | metric = list(r_squared = MetricsWeighted::r_squared), 
44 | lower_is_better = FALSE 45 | ) 46 | expect_true(imp$data$value_[imp$data$variable_ == "Petal.Width"] > 0) 47 | }) 48 | 49 | test_that("'by' and 'v' works", { 50 | fit3 <- stats::lm(Sepal.Length ~ Petal.Width + Species, data = iris) 51 | fl3 <- flashlight(model = fit3, label = "3", data = iris, y = "Sepal.Length") 52 | imp <- light_importance( 53 | fl3, seed = 1L, by = "Species", v = c("Petal.Width", "Species") 54 | ) 55 | 56 | expect_equal(imp$data$value_[imp$data$variable_ == "Species"], rep(0, 3)) 57 | expect_true(all(imp$data$value_[imp$data$variable_ == "Petal.Width"] > 0)) 58 | expect_equal( 59 | imp$data$value_[imp$data$variable_ == "Petal.Width"], 60 | c(0.02224045, 0.11677754, 0.11914359), 61 | tolerance = 0.001 62 | ) 63 | expect_s3_class(plot(imp), "ggplot") 64 | }) 65 | 66 | test_that("'w' reacts", { 67 | fit3 <- stats::lm(Sepal.Length ~ Petal.Width + Species, data = iris) 68 | fl3 <- flashlight( 69 | model = fit3, label = "3", data = iris, y = "Sepal.Length", w = "Petal.Length" 70 | ) 71 | imp <- light_importance( 72 | fl3, seed = 1L, by = "Species", v = c("Petal.Width", "Species") 73 | ) 74 | 75 | expect_equal( 76 | imp$data$value_[imp$data$variable_ == "Petal.Width"], 77 | c(0.02220016, 0.10929391, 0.11598350), 78 | tolerance = 0.001 79 | ) 80 | expect_s3_class(plot(imp), "ggplot") 81 | }) 82 | 83 | test_that("m_repetitions react", { 84 | imp <- light_importance( 85 | fl1, seed = 1L, v = c("Petal.Width", "Species"), m_repetitions = 4L 86 | ) 87 | expect_equal(imp$data$error_, c(0.005633942, 0.017129828), tolerance = 0.001) 88 | expect_true(all(!is.na(imp$data$error_))) 89 | expect_s3_class(plot(imp), "ggplot") 90 | }) 91 | 92 | test_that("m_repetitions react with 'by'", { 93 | imp <- light_importance( 94 | fl1, seed = 1L, v = c("Petal.Width", "Species"), m_repetitions = 4L, by = "Species" 95 | ) 96 | expect_true(is.light(imp)) 97 | expect_true(all(!is.na(imp$data$error_))) 98 | expect_s3_class(plot(imp), "ggplot") 99 | }) 100 | 101 | 
test_that("multiflashlight works", { 102 | imp <- light_importance(fls, seed = 1L, v = c("Petal.Width", "Species")) 103 | expect_equal(dim(imp$data), c(4L, 5L)) 104 | expect_equal( 105 | imp$data$value_, 106 | c(0.1425708, 0.4369238, 0.5537231, 0.0000000), 107 | tolerance = 0.001 108 | ) 109 | expect_s3_class(plot(imp), "ggplot") 110 | }) 111 | -------------------------------------------------------------------------------- /tests/testthat/tests-interaction.R: -------------------------------------------------------------------------------- 1 | fit_additive <- stats::lm( 2 | Sepal.Length ~ Petal.Length + Petal.Width + Species, data = iris 3 | ) 4 | fit_nonadditive <- stats::lm( 5 | Sepal.Length ~ Petal.Length * Petal.Width + Species, data = iris 6 | ) 7 | fl_additive <- flashlight(model = fit_additive, label = "additive") 8 | fl_nonadditive <- flashlight(model = fit_nonadditive, label = "nonadditive") 9 | fls <- multiflashlight( 10 | list(fl_additive, fl_nonadditive), data = iris, y = "Sepal.Length" 11 | ) 12 | 13 | test_that("basic functionality works for light_interaction", { 14 | inter <- light_interaction(fls$additive) 15 | expect_equal(inter$data$value_, rep(0, 4)) 16 | expect_s3_class(plot(inter), "ggplot") 17 | 18 | inter <- light_interaction(fls$nonadditive) 19 | expect_equal(inter$data$value_, c(0, 0.0421815, 0.0421815, 0), tolerance = 0.001) 20 | expect_s3_class(plot(inter), "ggplot") 21 | }) 22 | 23 | test_that("light_interaction reacts on weights", { 24 | inter <- light_interaction(flashlight(fls$nonadditive, w = "Sepal.Width")) 25 | expect_equal(inter$data$value_, c(0.03917691, 0.03917691, 0), tolerance = 0.001) 26 | expect_s3_class(plot(inter), "ggplot") 27 | }) 28 | 29 | test_that("basic functionality works for light_interaction (ICE approach)", { 30 | inter <- light_interaction(fls$additive, type = "ice") 31 | expect_equal(inter$data$value, rep(0, 4)) 32 | 33 | inter <- light_interaction(fls$nonadditive, type = "ice") 34 | 
expect_equal(inter$data$value_, c(0, 0.0274, 0.843, 0), tolerance = 0.001) 35 | expect_s3_class(plot(inter), "ggplot") 36 | }) 37 | 38 | test_that("basic functionality works for light_interaction with pairwise interactions", { 39 | inter <- light_interaction(fls$additive, pairwise = TRUE) 40 | expect_equal(inter$data$value, rep(0, 6)) 41 | expect_s3_class(plot(inter), "ggplot") 42 | 43 | inter <- light_interaction(fls$nonadditive, pairwise = TRUE) 44 | expect_equal(inter$data$value[4], 0.0207711, tolerance = 0.001) 45 | expect_equal(inter$data$value[-4], rep(0, 5)) 46 | expect_s3_class(plot(inter), "ggplot") 47 | expect_error(light_interaction(fls$additive, pairwise = TRUE, type = "ice")) 48 | }) 49 | 50 | test_that("light_interaction reacts on 'normalize'", { 51 | inter <- light_interaction(fls$nonadditive, normalize = FALSE) 52 | expect_equal(inter$data$value, c(0, 0.031848, 0.031848, 0), tolerance = 0.001) 53 | 54 | inter <- light_interaction(fls$nonadditive, normalize = FALSE, pairwise = TRUE) 55 | expect_equal(inter$data$value[4], 0.03184801, tolerance = 0.001) 56 | 57 | inter <- light_interaction(fls$nonadditive, normalize = FALSE, type = "ice") 58 | expect_equal(inter$data$value, c(0, 0.043085, 0.043085, 0), tolerance = 0.001) 59 | }) 60 | 61 | test_that("light_interaction reacts on 'take_sqrt'", { 62 | inter <- light_interaction(fls$nonadditive, take_sqrt = FALSE) 63 | expect_equal(inter$data$value, c(0, 0.00178, 0.00178, 0), tolerance = 0.001) 64 | 65 | inter <- light_interaction(fls$nonadditive, take_sqrt = FALSE, pairwise = TRUE) 66 | expect_equal(inter$data$value[4L], 0.00043143, tolerance = 0.001) 67 | 68 | inter <- light_interaction(fls$nonadditive, take_sqrt = FALSE, type = "ice") 69 | expect_equal(inter$data$value, c(0, 0.00075, 0.71004, 0), tolerance = 0.001) 70 | }) 71 | 72 | test_that("light_interaction reacts on multiflashlight", { 73 | inter <- light_interaction(fls, type = "ice") 74 | expect_equal(inter$data$value, c(rep(0, 5), 0.0274, 0.843, 
0), tolerance = 0.001) 75 | expect_s3_class(plot(inter), "ggplot") 76 | 77 | inter <- light_interaction(fls, pairwise = TRUE) 78 | expect_equal(inter$data$value[inter$data$label == "additive"], rep(0, 6)) 79 | expect_equal( 80 | inter$data$value[inter$data$label == "nonadditive"][4L], 81 | 0.0207711, 82 | tolerance = 0.001 83 | ) 84 | }) 85 | 86 | test_that("light_interaction reacts on 'by'", { 87 | inter <- light_interaction(fls$nonadditive, by = "Species") 88 | dat <- inter$data 89 | expect_equal(dat$value[dat$variable == "Sepal.Width"], rep(0, 3)) 90 | expect_equal( 91 | dat$value[dat$variable %in% c("Petal.Width", "Petal.Length")], 92 | rep(c(0.00404, 0.007588311, 0.0077), each = 2L), 93 | tolerance = 0.001 94 | ) 95 | expect_s3_class(plot(inter), "ggplot") 96 | }) 97 | -------------------------------------------------------------------------------- /tests/testthat/tests-methods.R: -------------------------------------------------------------------------------- 1 | test_that("all_identical works", { 2 | x <- list(a = 1, b = 2) 3 | y <- list(a = 1, b = 3) 4 | expect_true(all_identical(list(x, y), `[[`, "a")) 5 | expect_false(all_identical(list(x, y), `[[`, "b")) 6 | }) 7 | 8 | test_that("light_check, flashlight and multiflashlight work", { 9 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 10 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 11 | expect_true(is.flashlight(light_check(fl))) 12 | expect_error(light_check(1)) 13 | expect_error(flashlight(fl, metrics = "no metric")) 14 | expect_error(flashlight(fl, linkinv = "no metric")) 15 | expect_error(flashlight(fl, data = "no metric")) 16 | expect_error(flashlight(fl, data = "no metric")) 17 | expect_error(multiflashlight(list(fl, 1))) 18 | expect_error(flashlight(1)) 19 | expect_error(flashlight(fl, data = "bad data")) 20 | }) 21 | 22 | test_that("light_combine works", { 23 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 24 | fl <- flashlight(model = fit, 
label = "lm", data = iris, y = "Sepal.Length") 25 | ell1 <- light_performance(fl) 26 | ell2 <- light_performance(fl) 27 | expect_equal(nrow(light_combine(list(ell1, ell2))$data), 2 * nrow(ell1$data)) 28 | expect_equal(light_combine(ell1), ell1) 29 | }) 30 | 31 | test_that("selected 'is' functions work", { 32 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 33 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 34 | fls <- multiflashlight(list(fl, flashlight(fl, label = "lm2"))) 35 | 36 | expect_true(is.flashlight(fl)) 37 | expect_false(is.flashlight(1)) 38 | 39 | expect_true(is.multiflashlight(fls)) 40 | expect_false(is.flashlight(fls)) 41 | expect_false(is.multiflashlight(1)) 42 | expect_false(is.light(1)) 43 | 44 | expect_true(is.light(light_performance(fl))) 45 | expect_true(is.light_performance(light_performance(fl))) 46 | expect_true(is.light_performance_multi(light_performance(fls))) 47 | expect_false(is.light_performance_multi(light_performance(fl))) 48 | 49 | expect_true(is.light(light_importance(fl))) 50 | expect_true(is.light_importance(light_importance(fl))) 51 | expect_true(is.light_importance_multi(light_importance(fls))) 52 | expect_false(is.light_importance_multi(light_importance(fl))) 53 | 54 | expect_true(is.light(light_importance(fl, v = "Species"))) 55 | expect_true(is.light_importance(light_importance(fl, v = "Species"))) 56 | expect_true(is.light_importance_multi(light_importance(fls, v = "Species"))) 57 | expect_false(is.light_importance_multi(light_importance(fl, v = "Species"))) 58 | 59 | expect_true(is.light(light_scatter(fl, v = "Species"))) 60 | expect_true(is.light_scatter(light_scatter(fl, v = "Species"))) 61 | expect_true(is.light_scatter_multi(light_scatter(fls, v = "Species"))) 62 | expect_false(is.light_scatter_multi(light_scatter(fl, v = "Species"))) 63 | 64 | expect_true(is.light(light_profile(fl, v = "Species"))) 65 | expect_true(is.light_profile(light_profile(fl, v = "Species"))) 66 | 
expect_true(is.light_profile_multi(light_profile(fls, v = "Species"))) 67 | expect_false(is.light_profile_multi(light_profile(fl, v = "Species"))) 68 | 69 | expect_true(is.light(light_effects(fl, v = "Species"))) 70 | expect_true(is.light_effects(light_effects(fl, v = "Species"))) 71 | expect_true(is.light_effects_multi(light_effects(fls, v = "Species"))) 72 | expect_false(is.light_effects_multi(light_effects(fl, v = "Species"))) 73 | 74 | expect_true(is.light(light_global_surrogate(fl))) 75 | expect_true(is.light_global_surrogate(light_global_surrogate(fl))) 76 | expect_true(is.light_global_surrogate_multi(light_global_surrogate(fls))) 77 | expect_false(is.light_global_surrogate_multi(light_global_surrogate(fl))) 78 | 79 | expect_true(is.light(light_breakdown(fl, new_obs = iris[1, ]))) 80 | expect_true(is.light_breakdown(light_breakdown(fl, new_obs = iris[1, ]))) 81 | expect_true( 82 | is.light_breakdown_multi(light_breakdown(fls, new_obs = iris[1, ])) 83 | ) 84 | expect_false( 85 | is.light_breakdown_multi(light_breakdown(fl, new_obs = iris[1, ])) 86 | ) 87 | }) 88 | 89 | fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris) 90 | fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length") 91 | fls <- multiflashlight(list(fl, flashlight(fl, label = "lm2"))) 92 | 93 | test_that("response method works for (multi-)flashlights", { 94 | expect_equal(response(fl), iris$Sepal.Length) 95 | expect_equal(response(fls)[[2L]], iris$Sepal.Length) 96 | expect_equal(response(flashlight(fl, linkinv = log)), log(iris$Sepal.Length)) 97 | }) 98 | 99 | test_that("predict method works for (multi-)flashlights", { 100 | expect_equal(predict(fl, data = head(iris)), predict(fit, head(iris))) 101 | expect_equal(predict(fls)[[1L]], predict(fls)[[2L]]) 102 | expect_equal(predict(flashlight(fl, linkinv = log)), log(predict(fl))) 103 | }) 104 | 105 | test_that("residuals method works for (multi-)flashlights", { 106 | expect_equal(resid(fl, data = head(iris)), 
head(resid(fit))) 107 | expect_equal(resid(fls)[[1L]], resid(fls)[[2L]]) 108 | }) 109 | 110 | -------------------------------------------------------------------------------- /tests/testthat/tests-perf.R: -------------------------------------------------------------------------------- 1 | fit <- stats::lm(Sepal.Length ~ ., data = iris) 2 | fl <- flashlight( 3 | model = fit, 4 | label = "full", 5 | data = iris, 6 | y = "Sepal.Length", 7 | metrics = list(rmse = MetricsWeighted::rmse, r2 = MetricsWeighted::r_squared) 8 | ) 9 | perf <- light_performance(fl) 10 | perf_by <- light_performance(fl, by = "Species") 11 | RMSE <- function(r) sqrt(mean(r^2)) 12 | get_rmse_r2 <- function(fit) c(RMSE(stats::resid(fit)), summary(fit)$r.squared) 13 | 14 | test_that("basic functionality works for two metrics", { 15 | expect_equal(dim(perf$data), c(2L, 3L)) 16 | expect_equal(perf$data$value_, get_rmse_r2(fit)) 17 | expect_s3_class(plot(perf), "ggplot") 18 | }) 19 | 20 | test_that("light_performance works with by variable", { 21 | expect_equal(dim(perf_by$data), c(3 * nrow(perf$data), ncol(perf$data) + 1L)) 22 | expect_equal( 23 | perf_by$data$value_[c(1L, 3L, 5L)], 24 | as.numeric(tapply(stats::resid(fit), iris$Species, FUN = RMSE)) 25 | ) 26 | expect_s3_class(plot(perf_by), "ggplot") 27 | }) 28 | 29 | test_that("light_performance works for weighted flashlight", { 30 | iris_w <- transform(iris, w1 = 1, w2 = 2, w3 = 1:nrow(iris)) 31 | perf_w1 <- light_performance(flashlight(fl, w = "w1", data = iris_w)) 32 | perf_w2 <- light_performance(flashlight(fl, w = "w2", data = iris_w), by = "Species") 33 | perf_w3 <- light_performance(flashlight(fl, w = "w3", data = iris_w)) 34 | 35 | expect_equal(perf, perf_w1) 36 | expect_equal(perf_by, perf_w2) 37 | expect_false(identical(perf, perf_w3)) 38 | 39 | expect_s3_class(plot(perf_w1), "ggplot") 40 | expect_s3_class(plot(perf_w2), "ggplot") 41 | expect_s3_class(plot(perf_w3), "ggplot") 42 | }) 43 | 44 | test_that("R-squared for weighted flashlight 
is the same as the one from summary.lm", { 45 | fit_w <- stats::lm(Sepal.Length ~ ., data = iris, weights = iris$Sepal.Width) 46 | fl_w <- flashlight( 47 | model = fit, 48 | label = "wlm", 49 | data = iris, 50 | w = "Sepal.Width", 51 | y = "Sepal.Length", 52 | metrics = list(r2 = MetricsWeighted::r_squared) 53 | ) 54 | perf_w <- light_performance(fl_w) 55 | expect_equal(perf_w$data$value_, summary(fit_w)$r.squared, tolerance = 0.01) 56 | expect_s3_class(plot(perf_w), "ggplot") 57 | }) 58 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/biblio.bib: -------------------------------------------------------------------------------- 1 | % Encoding: UTF-8 2 | 3 | @book{molnar, 4 | title = {Interpretable Machine Learning}, 5 | author = {Christoph Molnar}, 6 | url = {https://christophm.github.io/interpretable-ml-book/}, 7 | year = {2019}, 8 | subtitle = {A Guide for Making Black Box Models Explainable} 9 | } 10 | 11 | @misc{fisher, 12 | title={All Models are Wrong, but Many are Useful: Learning a Variable's Importance by Studying an Entire Class of Prediction Models Simultaneously}, 13 | author={Aaron Fisher and Cynthia Rudin and Francesca Dominici}, 14 | year={2018}, 15 | url = {https://arxiv.org/abs/1801.01489}, 16 | primaryClass={stat.ME} 17 | } 18 | 19 | @article{goldstein, 20 | author = {Alex Goldstein and Adam Kapelner and Justin Bleich and Emil Pitkin}, 21 | title = {Peeking Inside the Black Box: Visualizing Statistical Learning With Plots of Individual Conditional Expectation}, 22 | journal = {Journal of Computational and Graphical Statistics}, 23 | volume = {24}, 24 | number = {1}, 25 | pages = {44-65}, 26 | year = {2015}, 27 | publisher = {Taylor & Francis}, 28 | doi = {10.1080/10618600.2014.907095}, 29 | URL = 
{https://doi.org/10.1080/10618600.2014.907095}, 30 | eprint = {https://doi.org/10.1080/10618600.2014.907095} 31 | } 32 | 33 | @article{friedman2001, 34 | author = "Friedman, Jerome H.", 35 | doi = "10.1214/aos/1013203451", 36 | fjournal = "The Annals of Statistics", 37 | journal = "Ann. Statist.", 38 | month = "10", 39 | number = "5", 40 | pages = "1189--1232", 41 | publisher = "The Institute of Mathematical Statistics", 42 | title = "Greedy function approximation: A gradient boosting machine.", 43 | url = "https://doi.org/10.1214/aos/1013203451", 44 | volume = "29", 45 | year = "2001" 46 | } 47 | 48 | @misc{apley, 49 | title={Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models}, 50 | author={Daniel W. Apley and Jingyu Zhu}, 51 | year={2016}, 52 | url = {https://arxiv.org/abs/1612.08468}, 53 | primaryClass={stat.ME} 54 | } 55 | 56 | @misc{gosiewska, 57 | title={Do Not Trust Additive Explanations}, 58 | author={Alicja Gosiewska and Przemyslaw Biecek}, 59 | year={2019}, 60 | url = {https://arxiv.org/abs/1903.11420}, 61 | primaryClass={cs.LG} 62 | } 63 | 64 | @article{friedman2008, 65 | ISSN = {19326157}, 66 | author = {Jerome H. Friedman and Bogdan E. Popescu}, 67 | journal = {The Annals of Applied Statistics}, 68 | number = {3}, 69 | pages = {916--954}, 70 | publisher = {Institute of Mathematical Statistics}, 71 | title = {Predictive Learning via Rule Ensembles}, 72 | volume = {2}, 73 | year = {2008} 74 | } 75 | 76 | @incollection{Lundberg2017, 77 | title = {A Unified Approach to Interpreting Model Predictions}, 78 | author = {Lundberg, Scott M and Lee, Su-In}, 79 | booktitle = {Advances in Neural Information Processing Systems 30}, 80 | editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. 
Garnett}, 81 | pages = {4765--4774}, 82 | year = {2017}, 83 | publisher = {Curran Associates, Inc.}, 84 | url = {https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf} 85 | } 86 | 87 | @article{Lundberg2020, 88 | title={From local explanations to global understanding with explainable AI for trees}, 89 | author={Lundberg, Scott M. and Erion, Gabriel and Chen, Hugh and DeGrave, Alex and Prutkin, Jordan M. and Nair, Bala and Katz, Ronit and Himmelfarb, Jonathan and Bansal, Nisha and Lee, Su-In}, 90 | journal={Nature Machine Intelligence}, 91 | volume={2}, 92 | number={1}, 93 | pages={56--67}, 94 | year={2020}, 95 | publisher={Nature Publishing Group} 96 | } 97 | 98 | 99 | -------------------------------------------------------------------------------- /vignettes/flashlight.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Using flashlight" 3 | bibliography: "biblio.bib" 4 | link-citations: true 5 | output: 6 | rmarkdown::html_vignette 7 | vignette: > 8 | %\VignetteIndexEntry{Using flashlight} 9 | %\VignetteEngine{knitr::rmarkdown} 10 | %\VignetteEncoding{UTF-8} 11 | --- 12 | 13 | ```{r, include = FALSE} 14 | knitr::opts_chunk$set( 15 | collapse = TRUE, 16 | comment = "#>", 17 | warning = FALSE, 18 | message = FALSE, 19 | fig.width = 5.5, 20 | fig.height = 4.5 21 | ) 22 | ``` 23 | 24 | ## Overview 25 | 26 | **No black-box model without XAI.** This is where packages like 27 | 28 | - [{DALEX}](https://CRAN.R-project.org/package=DALEX), 29 | - [{iml}](https://CRAN.R-project.org/package=iml), and 30 | - [{flashlight}](https://CRAN.R-project.org/package=flashlight) enter the stage. 
31 | 32 | {flashlight} offers the following XAI methods: 33 | 34 | - `light_performance()`: Performance metrics like RMSE and/or $R^2$ 35 | - `light_importance()`: Permutation variable importance [@fisher] 36 | - `light_ice()`: Individual conditional expectation (ICE) profiles [@goldstein] (centered or uncentered) 37 | - `light_profile()`: Partial dependence [@friedman2001], accumulated local effects (ALE) [@apley], average predicted/observed/residual 38 | - `light_profile2d()`: Two-dimensional version of `light_profile()` 39 | - `light_effects()`: Combines partial dependence, ALE, response and prediction profiles 40 | - `light_interaction()`: Different variants of Friedman's H statistics [@friedman2008] 41 | - `light_breakdown()`: Variable contribution breakdown (approximate SHAP) for single observations [@gosiewska] 42 | - `light_global_surrogate()`: Global surrogate trees [@molnar] 43 | 44 | Good to know: 45 | 46 | - Each method acts on an explainer object called `flashlight` (see examples and Section "flashlights"). 47 | - Multiple models can be compared via `multiflashlight()`. 48 | - Calling `plot()` visualizes the results via {ggplot2}. 49 | - Methods support case weights. 50 | - Methods support a grouping variable. 51 | 52 | ## Installation 53 | 54 | ```r 55 | # From CRAN 56 | install.packages("flashlight") 57 | 58 | # Development version 59 | devtools::install_github("mayer79/flashlight") 60 | ``` 61 | 62 | ## Usage 63 | 64 | Let's start with an iris example. For simplicity, we do not split the data into training and testing/validation sets. 
65 | 66 | ```{r} 67 | library(ggplot2) 68 | library(MetricsWeighted) 69 | library(flashlight) 70 | 71 | fit_lm <- lm(Sepal.Length ~ ., data = iris) 72 | 73 | # Make explainer object 74 | fl_lm <- flashlight( 75 | model = fit_lm, 76 | data = iris, 77 | y = "Sepal.Length", 78 | label = "lm", 79 | metrics = list(RMSE = rmse, `R-squared` = r_squared) 80 | ) 81 | ``` 82 | 83 | ### Performance 84 | 85 | ```{r} 86 | fl_lm |> 87 | light_performance() |> 88 | plot(fill = "darkred") + 89 | labs(x = element_blank(), title = "Performance on training data") 90 | 91 | fl_lm |> 92 | light_performance(by = "Species") |> 93 | plot(fill = "darkred") + 94 | ggtitle("Performance split by Species") 95 | ``` 96 | 97 | ### Permutation importance regarding first metric 98 | 99 | Error bars represent standard errors, i.e., the uncertainty of the estimated importance. 100 | 101 | ```{r} 102 | fl_lm |> 103 | light_importance(m_repetitions = 4) |> 104 | plot(fill = "darkred") + 105 | labs(title = "Permutation importance", y = "Increase in RMSE") 106 | ``` 107 | 108 | ### ICE curves for `Sepal.Width` 109 | 110 | ```{r} 111 | fl_lm |> 112 | light_ice("Sepal.Width", n_max = 200) |> 113 | plot(alpha = 0.3, color = "chartreuse4") + 114 | labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction") 115 | 116 | fl_lm |> 117 | light_ice("Sepal.Width", n_max = 200, center = "middle") |> 118 | plot(alpha = 0.3, color = "chartreuse4") + 119 | labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)") 120 | ``` 121 | ### PDPs 122 | 123 | ```{r} 124 | fl_lm |> 125 | light_profile("Sepal.Width", n_bins = 40) |> 126 | plot() + 127 | ggtitle("PDP for 'Sepal.Width'") 128 | 129 | fl_lm |> 130 | light_profile("Sepal.Width", n_bins = 40, by = "Species") |> 131 | plot() + 132 | ggtitle("Same grouped by 'Species'") 133 | ``` 134 | 135 | ### 2D PDP 136 | 137 | ```{r} 138 | fl_lm |> 139 | light_profile2d(c("Petal.Width", "Petal.Length")) |> 140 | plot() 141 | ``` 142 | 143 | ### ALE 144 | 145 |
```{r} 146 | fl_lm |> 147 | light_profile("Sepal.Width", type = "ale") |> 148 | plot() + 149 | ggtitle("ALE plot for 'Sepal.Width'") 150 | ``` 151 | 152 | ### Different profile plots in one 153 | 154 | ```{r} 155 | fl_lm |> 156 | light_effects("Sepal.Width") |> 157 | plot(use = "all") + 158 | ggtitle("Different types of profiles for 'Sepal.Width'") 159 | ``` 160 | 161 | ### Variable contribution breakdown for single observation 162 | 163 | ```{r} 164 | fl_lm |> 165 | light_breakdown(new_obs = iris[1, ]) |> 166 | plot() 167 | ``` 168 | 169 | ### Global surrogate tree 170 | 171 | ```{r} 172 | fl_lm |> 173 | light_global_surrogate() |> 174 | plot() 175 | ``` 176 | ### Multiple models 177 | 178 | Multiple flashlights can be combined to a multiflashlight. 179 | 180 | ```{r} 181 | library(rpart) 182 | 183 | fit_tree <- rpart( 184 | Sepal.Length ~ ., 185 | data = iris, 186 | control = list(cp = 0, xval = 0, maxdepth = 5) 187 | ) 188 | 189 | # Make explainer object 190 | fl_tree <- flashlight( 191 | model = fit_tree, 192 | data = iris, 193 | y = "Sepal.Length", 194 | label = "tree", 195 | metrics = list(RMSE = rmse, `R-squared` = r_squared) 196 | ) 197 | 198 | # Combine with other explainer 199 | fls <- multiflashlight(list(fl_tree, fl_lm)) 200 | 201 | fls |> 202 | light_performance() |> 203 | plot(fill = "chartreuse4") + 204 | labs(x = "Model", title = "Performance") 205 | 206 | fls |> 207 | light_importance() |> 208 | plot(fill = "chartreuse4") + 209 | labs(y = "Increase in RMSE", title = "Permutation importance") 210 | 211 | fls |> 212 | light_profile("Petal.Length", n_bins = 40) |> 213 | plot() + 214 | ggtitle("PDP") 215 | 216 | fls |> 217 | light_profile("Petal.Length", n_bins = 40, by = "Species") |> 218 | plot() + 219 | ggtitle("PDP by Species") 220 | ``` 221 | 222 | ### flashlights 223 | 224 | The "flashlight" explainer expects the following information: 225 | 226 | - `model`: Fitted model. Currently, this argument must be named. 
227 | - `data`: Reference data used to calculate things, often part of the validation data. 228 | - `y`: Column name in `data` corresponding to the **numeric** response. 229 | - `predict_function`: function of the same signature as `stats::predict()`. It takes a `model` and a data.frame `data`, and provides numeric predictions, see below for more details. 230 | - `linkinv`: Optional function applied to the output of `predict_function()`. *Should actually be called "trafo".* 231 | - `w`: Optional column name in `data` corresponding to case weights. 232 | - `by`: Optional column name in `data` used to group the results. Must be discrete. 233 | - `metrics`: List of metrics, by default `list(rmse = MetricsWeighted::rmse)`. For binary (probabilistic) classification, a good candidate metric would be `MetricsWeighted::logLoss`. 234 | - `label`: Mandatory name of the model. 235 | 236 | #### Typical `predict_function`s (a selection) 237 | 238 | The default `stats::predict()` works for models of class 239 | 240 | - `lm()`, 241 | - `glm()` (for predictions on link scale), and 242 | - `rpart()`. 243 | 244 | It also works for meta-learner models like 245 | 246 | - {caret}, and 247 | - {mlr3}. 248 | 249 | Manual prediction functions are, e.g., required for 250 | 251 | - {ranger}: Use `function(m, X) predict(m, X)$predictions` for regression, and 252 | `function(m, X) predict(m, X)$predictions[, 2]` for probabilistic binary classification 253 | - `glm()`: Use `function(m, X) predict(m, X, type = "response")` to get GLM predictions at the response scale 254 | 255 | A bit more complicated are models whose native predict functions do not work on data.frames: 256 | 257 | - {xgboost} and {lightgbm}: They digest numeric matrices only, so the prediction function also needs to deal with the mapping from data.frame to matrix. 258 | - {keras}: It might accept data.frame inputs, but we need to take care of scalings.
259 | 260 | **Example (XGBoost):** 261 | 262 | This works when non-numeric features are all factors (not categoricals): 263 | 264 | ```r 265 | x <- vector of features 266 | predict_function = function(m, df) predict(m, data.matrix(df[x])) 267 | ``` 268 | 269 | ## References 270 | --------------------------------------------------------------------------------