├── .Rbuildignore ├── .github ├── .gitignore └── workflows │ ├── CRAN-R-CMD-check.yaml │ ├── pkgdown.yaml │ ├── pr-commands.yaml │ └── test-coverage.yaml ├── .gitignore ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── RcppExports.R ├── fifa20.R ├── globals.R ├── model_unified.R ├── plot_contribution.R ├── plot_feature_dependence.R ├── plot_feature_importance.R ├── plot_interaction.R ├── predict.R ├── set_reference_dataset.R ├── theme_drwhy.R ├── treeshap.R ├── unify.R ├── unify_gbm.R ├── unify_lightgbm.R ├── unify_randomForest.R ├── unify_ranger.R ├── unify_ranger_surv.R └── unify_xgboost.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── cran-comments.md ├── data-raw └── fifa20.R ├── data └── fifa20.rda ├── man ├── colors_drwhy.Rd ├── fifa20.Rd ├── figures │ ├── README-plot_contribution_example-1.png │ ├── README-plot_dependence_example-1.png │ ├── README-plot_importance_example-1.png │ └── README-plot_interaction-1.png ├── gbm.unify.Rd ├── is.model_unified.Rd ├── is.treeshap.Rd ├── lightgbm.unify.Rd ├── model_unified.object.Rd ├── model_unified_multioutput.object.Rd ├── plot_contribution.Rd ├── plot_feature_dependence.Rd ├── plot_feature_importance.Rd ├── plot_interaction.Rd ├── predict.model_unified.Rd ├── print.model_unified.Rd ├── print.model_unified_multioutput.Rd ├── print.treeshap.Rd ├── print.treeshap_multioutput.Rd ├── randomForest.unify.Rd ├── ranger.unify.Rd ├── ranger_surv.unify.Rd ├── set_reference_dataset.Rd ├── theme_drwhy.Rd ├── treeshap.Rd ├── treeshap.object.Rd ├── treeshap_multioutput.object.Rd ├── unify.Rd └── xgboost.unify.Rd ├── src ├── RcppExports.cpp ├── RcppExports.o ├── predict.cpp ├── set_reference_dataset.cpp ├── treeshap.cpp └── treeshap.o ├── tests ├── testthat.R └── testthat │ ├── test_gbm_unify.R │ ├── test_lightgbm_unify.R │ ├── test_randomForest.R │ ├── test_ranger.R │ ├── test_ranger_surv.R │ ├── test_set_reference_dataset.R │ ├── test_treeshap_correctness.R │ └── test_xgboost_unify.R └── treeshap.Rproj /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^LICENSE\.md$ 2 | ^README\.Rmd$ 3 | ^data-raw$ 4 | ^README.md$ 5 | ^.*\.Rproj$ 6 | ^\.Rproj\.user$ 7 | ^misc$ 8 | ^python$ 9 | ^docs$ 10 | ^_pkgdown\.yml$ 11 | .travis.yml 12 | vignettes/ 13 | pkgdown/ 14 | ^LICENSE$ 15 | ^CONTRIBUTING.md 16 | ^NEWS.md 17 | DALEXpiramide 18 | ^codecov\.yml$ 19 | ^\.github$ 20 | ^tox.ini$ 21 | ^pkgdown$ 22 | ^cran-comments\.md$ 23 | ^CRAN-SUBMISSION$ 24 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/workflows/CRAN-R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master, 'dev*', 'fix*', 'issue*', 'doc*', 'gh-actions', 'githubactions'] 6 | 7 | pull_request: 8 | branches: [main, master] 9 | 10 | name: R-CMD-check 11 | 12 | jobs: 13 | CRAN-R-check: 14 | runs-on: ${{ matrix.config.os }} 15 | 16 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 17 | 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | config: 22 | - {os: macOS-latest, r: 'release'} 23 | - {os: windows-latest, r: 'release'} 24 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 25 | - {os: ubuntu-latest, r: 'release'} 26 | - {os: ubuntu-latest, r: 'oldrel-1'} 27 | env: 28 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 29 | R_KEEP_PKG_SOURCE: yes 30 | steps: 31 | - uses: actions/checkout@v3 32 | 33 | - uses: r-lib/actions/setup-r@v2 34 | with: 35 | use-public-rspm: true 36 | 37 | - uses: r-lib/actions/setup-r-dependencies@v2 38 | with: 39 | extra-packages: any::rcmdcheck 40 | needs: check 41 | 42 | - uses: r-lib/actions/check-r-package@v2 43 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown 13 | 14 | jobs: 15 | pkgdown: 16 | runs-on: ubuntu-latest 17 | # Only restrict concurrency for non-PR jobs 18 | concurrency: 19 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 20 | env: 21 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 22 | steps: 23 | - uses: actions/checkout@v2 24 | 25 | - uses: r-lib/actions/setup-pandoc@v2 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::pkgdown, local::. 
34 | needs: website 35 | 36 | - name: Install DrWhy theme 37 | run: | 38 | install.packages("remotes") 39 | remotes::install_deps(dependencies = TRUE) 40 | install.packages("future.apply") 41 | remotes::install_github("ModelOriented/DrWhyTemplate") 42 | shell: Rscript {0} 43 | 44 | - name: Build site 45 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 46 | shell: Rscript {0} 47 | 48 | - name: Deploy to GitHub pages 🚀 49 | if: github.event_name != 'pull_request' 50 | uses: JamesIves/github-pages-deploy-action@4.1.4 51 | with: 52 | clean: false 53 | branch: gh-pages 54 | folder: docs 55 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | issue_comment: 3 | types: [created] 4 | name: Commands 5 | jobs: 6 | document: 7 | if: startsWith(github.event.comment.body, '/document') 8 | name: document 9 | runs-on: macOS-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: r-lib/actions/pr-fetch@v2 13 | with: 14 | repo-token: ${{ secrets.GITHUB_TOKEN }} 15 | - uses: r-lib/actions/setup-r@v2 16 | - name: Install dependencies 17 | run: Rscript -e 'install.packages(c("remotes", "roxygen2"))' -e 'remotes::install_deps(dependencies = TRUE)' 18 | - name: Document 19 | run: Rscript -e 'roxygen2::roxygenise()' 20 | - name: commit 21 | run: | 22 | git add man/\* NAMESPACE 23 | git commit -m 'Document' 24 | - uses: r-lib/actions/pr-push@v2 25 | with: 26 | repo-token: ${{ secrets.GITHUB_TOKEN }} 27 | style: 28 | if: startsWith(github.event.comment.body, '/style') 29 | name: style 30 | runs-on: macOS-latest 31 | steps: 32 | - uses: actions/checkout@v2 33 | - uses: r-lib/actions/pr-fetch@v2 34 | with: 35 | repo-token: ${{ secrets.GITHUB_TOKEN }} 36 | - uses: r-lib/actions/setup-r@v2 37 | - name: Install dependencies 38 | run: Rscript -e 'install.packages("styler")' 39 | - name: Style 40 | run: Rscript -e 'styler::style_pkg()' 41 | - name: commit 42 | run: | 43 | git add \*.R 44 | git commit -m 'Style' 45 | - uses: r-lib/actions/pr-push@v2 46 | with: 47 | repo-token: ${{ secrets.GITHUB_TOKEN }} 48 | # A mock job just to ensure we have a successful build status 49 | finish: 50 | runs-on: ubuntu-latest 51 | steps: 52 | - run: true 53 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | - 'issue*' 6 | - 'gh-actions' 7 | - 'githubactions' 8 | pull_request: 9 | branches: 10 | - master 11 | 12 | name: test-coverage 13 | 14 | jobs: 15 | test-coverage: 16 | runs-on: macOS-latest 17 | env: 18 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 19 | steps: 20 | - uses: actions/checkout@v2 21 | 22 | - uses: r-lib/actions/setup-r@v2 23 | 24 | - uses: r-lib/actions/setup-pandoc@v2 25 | 26 | - name: Query dependencies 27 | run: | 28 | install.packages('remotes') 29 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 30 | shell: Rscript {0} 31 | 32 | - uses: r-lib/actions/setup-r-dependencies@v2 33 | with: 34 | extra-packages: any::covr 35 | needs: coverage 36 | 37 | - name: Test coverage 38 | run: covr::codecov(quiet = FALSE) 39 | shell: Rscript {0} 40 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | src/*.o 6 | src/*.so 7 | src/*.dll 8 | CRAN-SUBMISSION 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: treeshap 2 | Title: Compute SHAP Values for Your Tree-Based Models Using the 'TreeSHAP' 3 | Algorithm 4 | Version: 0.3.1.9000 5 | Authors@R: c( 6 | person("Konrad", "Komisarczyk", email = "komisarczykkonrad@gmail.com", role = "aut"), 7 | person("Pawel", "Kozminski", email = "pkozminski99@gmail.com", role = "aut"), 8 | person("Szymon", "Maksymiuk", role = "aut", comment = c(ORCID = "0000-0002-3120-1601")), 9 | person("Lorenz A.", "Kapsner", role = "ctb", comment = c(ORCID = "0000-0003-1866-860X")), 10 | person("Mikolaj", "Spytek", role = "ctb", comment = c(ORCID = "0000-0001-7111-2286")), 11 | person("Mateusz", "Krzyzinski", email = "krzyzinskimateusz23@gmail.com", role = c("ctb", "cre"), comment = c(ORCID = "0000-0001-6143-488X")), 12 | person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut", "cph"), comment = c(ORCID = "0000-0001-8423-1823")) 13 | ) 14 | Description: An efficient implementation of the 'TreeSHAP' algorithm 15 | introduced by Lundberg et al., (2020) . 16 | It is capable of calculating SHAP (SHapley Additive exPlanations) values 17 | for tree-based models in polynomial time. Currently supported models include 18 | 'gbm', 'randomForest', 'ranger', 'xgboost', 'lightgbm'. 19 | License: GPL-3 20 | URL: https://modeloriented.github.io/treeshap/, 21 | https://github.com/ModelOriented/treeshap 22 | BugReports: https://github.com/ModelOriented/treeshap/issues 23 | Depends: 24 | R (>= 2.10) 25 | Imports: 26 | data.table, 27 | ggplot2, 28 | Rcpp 29 | Suggests: 30 | gbm, 31 | jsonlite, 32 | lightgbm, 33 | randomForest, 34 | ranger, 35 | scales, 36 | survival, 37 | testthat, 38 | xgboost 39 | LinkingTo: 40 | Rcpp 41 | Encoding: UTF-8 42 | LazyData: true 43 | Roxygen: list(markdown = TRUE) 44 | RoxygenNote: 7.2.3 45 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(predict,model_unified) 4 | S3method(print,model_unified) 5 | S3method(print,model_unified_multioutput) 6 | S3method(print,treeshap) 7 | S3method(print,treeshap_multioutput) 8 | S3method(treeshap,model_unified) 9 | S3method(treeshap,model_unified_multioutput) 10 | S3method(unify,default) 11 | S3method(unify,gbm) 12 | S3method(unify,lgb.Booster) 13 | S3method(unify,randomForest) 14 | S3method(unify,ranger) 15 | S3method(unify,xgb.Booster) 16 | export(colors_breakdown_drwhy) 17 | export(colors_discrete_drwhy) 18 | export(gbm.unify) 19 | export(is.model_unified) 20 | export(is.treeshap) 21 | export(lightgbm.unify) 22 | export(plot_contribution) 23 | export(plot_feature_dependence) 24 | export(plot_feature_importance) 25 | export(plot_interaction) 26 | export(randomForest.unify) 27 | export(ranger.unify) 28 | export(ranger_surv.unify) 29 | export(set_reference_dataset) 30 | export(theme_drwhy) 31 | export(theme_drwhy_vertical) 32 | export(treeshap) 33 | export(unify) 34 | export(xgboost.unify) 35 | import(data.table) 36 | import(ggplot2) 37 | importFrom(Rcpp,sourceCpp) 38 | importFrom(graphics,text) 39 | 
importFrom(stats,reorder) 40 | importFrom(stats,stepfun) 41 | importFrom(utils,setTxtProgressBar) 42 | importFrom(utils,txtProgressBar) 43 | useDynLib(treeshap) 44 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # treeshap (development version) 2 | 3 | # treeshap 0.3.1 4 | * Fixed code examples in `lightgbm.unify`. 5 | 6 | # treeshap 0.3.0 7 | * Fixed `ranger_surv.unify` operation for predictions in form of survival and cumulative hazard functions. 8 | * Added `model_unified_multioutput` and `treeshap_multioutput` classes for multi-output models and their explanations. 9 | * Improved documentation of `ranger_surv.unify`. 10 | 11 | # treeshap 0.2.5 12 | * Removed `catboost.unify` function (as the `catboost` package is not on CRAN); it is available on a separate branch 13 | * Fixed `randomForest.unify` for classifiers ([#12](https://github.com/ModelOriented/treeshap/issues/12), [#23](https://github.com/ModelOriented/treeshap/issues/23)) 14 | * Implemented consolidated (generic) `unify` function ([#18](https://github.com/ModelOriented/treeshap/issues/18)) 15 | * An error is thrown when the data passed to the `unify` or `treeshap` functions contain variables that are not used by the model ([#14](https://github.com/ModelOriented/treeshap/issues/14)) 16 | * Added implementation for random survival forests created using `ranger` ([#22](https://github.com/ModelOriented/treeshap/pull/22), [#26](https://github.com/ModelOriented/treeshap/pull/26)) 17 | * Fixed GitHub Actions, check and test issues ([#25](https://github.com/ModelOriented/treeshap/pull/25), [#29](https://github.com/ModelOriented/treeshap/pull/29)) 18 | * Fixed issues with documentation and examples 19 | * Changed use of bitwise '|' to logical '||' with boolean operands in C++ files 20 | 21 | # treeshap 0.1.1 22 | * Fixed `plot_contribution` when `max_vars` is larger than the number of variables ([#16](https://github.com/ModelOriented/treeshap/issues/16)) 23 | 24 | # treeshap 0.1.0 25 | * Rebuilded treeshap function so it now stores observations and whole dataset 26 | * Rebuilded all unifiers so they require passing data. 
27 | 28 | # treeshap 0.0.1 29 | * Made package pass all checks 30 | * Fixed infinite recursion issue in ranger ([see commit](https://github.com/ModelOriented/treeshap/commit/eff70d8095932128151fb4c015fd61b89635aa9e)) 31 | * If there is no missing value in the model, unifiers return `NA` for `Missing` column ([see commit](https://github.com/ModelOriented/treeshap/commit/eff70d8095932128151fb4c015fd61b89635aa9e)) 32 | 33 | # treeshap 0.0.0.9000 34 | * treeshap is now public 35 | * Implemented fast computations of tree ensemble shap values in C++ 36 | * Implemented unifiers for catboost, lightgbm, xgboost, gbm, ranger and randomForest 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /R/RcppExports.R: -------------------------------------------------------------------------------- 1 | # Generated by using Rcpp::compileAttributes() -> do not edit by hand 2 | # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 3 | 4 | predict_cpp <- function(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type, value) { 5 | .Call('_treeshap_predict_cpp', PACKAGE = 'treeshap', x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type, value) 6 | } 7 | 8 | new_covers <- function(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type) { 9 | .Call('_treeshap_new_covers', PACKAGE = 'treeshap', x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type) 10 | } 11 | 12 | treeshap_cpp <- function(x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose) { 13 | .Call('_treeshap_treeshap_cpp', PACKAGE = 'treeshap', x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose) 14 | } 15 | 16 | treeshap_interactions_cpp <- function(x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose) { 17 | .Call('_treeshap_treeshap_interactions_cpp', PACKAGE = 'treeshap', x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose) 18 | } 19 | 20 | -------------------------------------------------------------------------------- /R/fifa20.R: -------------------------------------------------------------------------------- 1 | #' Attributes of all players in FIFA 20 2 | #' 3 | #' Dataset consists of 56 columns, 55 numeric and one of type factor \code{'work_rate'}. 4 | #' \code{value_eur} is a potential target feature. 5 | #' 6 | #' @format A data frame with 18278 rows and 56 columns. 7 | #' Most of variables representing skills are in range from 0 to 100 and will not be described here. 
8 | #' To list non-obvious features: 9 | #' \describe{ 10 | #' \item{overall}{Overall score of player's skills} 11 | #' \item{potential}{Potential of a player; younger players tend to have a higher level of potential} 12 | #' \item{value_eur}{Market value of a player (in mln EUR)} 13 | #' \item{international_reputation}{Range 1 to 5} 14 | #' \item{weak_foot}{Range 1 to 5} 15 | #' \item{skill_moves}{Range 1 to 5} 16 | #' \item{work_rate}{Slash-separated levels of willingness to work in offense and defense, respectively} 17 | #' } 18 | #' 19 | #'@source 20 | #'"Data has been scraped from the publicly available website \url{https://sofifa.com}" 21 | #'\url{https://www.kaggle.com/stefanoleone992/fifa-20-complete-player-dataset} 22 | #' 23 | "fifa20" 24 | -------------------------------------------------------------------------------- /R/globals.R: -------------------------------------------------------------------------------- 1 | globalVariables(c("Feature", "split_index", "tree_index", ".", "node_parent", "default_left", "decision_type", "position", "cumulative", 2 | "prev", "text", "contribution", "var_value", "shap_value", "reorder", "variable", "importance", "Tree", "Node", "Missing", 3 | "Cover", "Yes", "No", 'Prediction', 'Decision.type')) 4 | -------------------------------------------------------------------------------- /R/model_unified.R: -------------------------------------------------------------------------------- 1 | #' Unified model representation 2 | #' 3 | #' \code{model_unified} object produced by \code{*.unify} or \code{unify} function. 4 | #' 5 | #' @return List consisting of two elements: 6 | #' 7 | #' 8 | #' \strong{model} - A \code{data.frame} representing the model, with the following columns: 9 | #' 10 | #' \item{Tree}{0-indexed ID of a tree} 11 | #' \item{Node}{0-indexed ID of a node in a tree. In a tree, the root always has ID 0} 12 | #' \item{Feature}{In case of an internal node - name of a feature to split on. Otherwise - NA} 13 | #' \item{Decision.type}{A factor with two levels: "<" and "<=". In case of an internal node - predicate used for splitting observations. Otherwise - NA} 14 | #' \item{Split}{For internal nodes, the threshold used for splitting observations. All observations that satisfy the predicate Decision.type(Split) ('< Split' / '<= Split') proceed to the node marked as 'Yes'; all others proceed to the 'No' node. For leaves - NA} 15 | #' \item{Yes}{Index of a row containing a child Node. Indicating the row explicitly makes moving between nodes much faster} 16 | #' \item{No}{Index of a row containing a child Node} 17 | #' \item{Missing}{Index of a row containing the child Node to which all observations with a missing value of the splitting feature are passed} 18 | #' \item{Prediction}{For leaves: value of the prediction in the leaf. For internal nodes: NA} 19 | #' \item{Cover}{Number of observations seen by the internal node or collected by the leaf for the reference dataset} 20 | #' 21 | #' \strong{data} - Dataset used as a reference for calculating SHAP values. A dataset passed to the \code{*.unify}, \code{unify} or \code{\link{set_reference_dataset}} function with the \code{data} argument. A \code{data.frame}. 22 | #' 23 | #' 24 | #' The object also has two attributes set: 25 | #' \item{\code{model}}{A string. By what package the model was produced.} 26 | #' \item{\code{missing_support}}{A boolean.
Whether the model allows missing values to be present in explained dataset.} 27 | #' 28 | #' 29 | #' @seealso 30 | #' \code{\link{unify}} 31 | #' 32 | #' 33 | #' @name model_unified.object 34 | #' 35 | NULL 36 | 37 | 38 | #' Unified model representations for multi-output model 39 | #' 40 | #' \code{model_unified_multioutput} object produced by \code{*.unify} or \code{unify} function. 41 | #' 42 | #' @return List consisting of \code{model_unified} objects, one for each individual output of a model. For survival models, the list is named using the time points, for which predictions are calculated. 43 | #' 44 | #' @seealso 45 | #' \code{\link{unify}} 46 | #' 47 | #' 48 | #' @name model_unified_multioutput.object 49 | #' 50 | NULL 51 | 52 | 53 | #' Prints model_unified objects 54 | #' 55 | #' @param x a model_unified object 56 | #' @param ... other arguments 57 | #' 58 | #' @return No return value, called for printing 59 | #' 60 | #' @export 61 | #' 62 | print.model_unified <- function(x, ...){ 63 | print(x$model) 64 | return(invisible(NULL)) 65 | } 66 | 67 | 68 | #' Prints model_unified_multioutput objects 69 | #' 70 | #' @param x a model_unified_multioutput object 71 | #' @param ... other arguments 72 | #' 73 | #' @return No return value, called for printing 74 | #' 75 | #' @export 76 | #' 77 | print.model_unified_multioutput <- function(x, ...){ 78 | output_names <- names(x) 79 | lapply(output_names, function(output_name){ 80 | cat(paste("-> for output:", output_name, "\n")) 81 | print(x[[output_name]]) 82 | cat("\n") 83 | }) 84 | return(invisible(NULL)) 85 | } 86 | 87 | 88 | #' Check whether object is a valid model_unified object 89 | #' 90 | #' Does not check correctness of representation, only basic checks 91 | #' 92 | #' @param x an object to check 93 | #' 94 | #' @return boolean 95 | #' 96 | #' @export 97 | #' 98 | is.model_unified <- function(x) { 99 | # class checks 100 | ("model_unified" %in% class(x)) & 101 | is.data.frame(x$data) & 102 | is.data.frame(x$model) & 103 | # attributes check 104 | is.character(attr(x, "model")) & 105 | is.logical(attr(x, "missing_support")) & 106 | # colnames check 107 | all(c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover") %in% colnames(x$model)) & 108 | # column types check 109 | is.numeric(x$model$Tree) & 110 | is.numeric(x$model$Node) & 111 | is.character(x$model$Feature) & 112 | is.factor(x$model$Decision.type) & 113 | all(levels(x$model$Decision.type) == c("<=", "<")) & 114 | all(unclass(x$model$Decision.type) %in% c(1, 2, NA)) & 115 | is.numeric(x$model$Split) & 116 | is.numeric(x$model$Yes) & 117 | is.numeric(x$model$No) & 118 | (!attr(x, "missing_support") | is.numeric(x$model$Missing)) & 119 | is.numeric(x$model$Prediction) & 120 | is.numeric(x$model$Cover) 121 | } 122 | -------------------------------------------------------------------------------- /R/plot_contribution.R: -------------------------------------------------------------------------------- 1 | #' SHAP value based Break-Down plot 2 | #' 3 | #' This function plots contributions of features into the prediction for a single observation. 4 | #' 5 | #' @param treeshap A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}. 6 | #' @param obs A numeric indicating which observation should be plotted. Be default it's first observation. 7 | #' @param max_vars maximum number of variables that shall be presented. Variables with the highest importance will be presented. 
8 | #' Remaining variables will be summed into one additional contribution. By default \code{5}. 9 | #' @param min_max a range of OX axis. By default \code{NA}, therefore it will be extracted from the contributions of \code{x}. 10 | #' But it can be set to some constants, useful if these plots are to be used for comparisons. 11 | #' @param digits number of decimal places (\code{\link{round}}) to be used. 12 | #' @param explain_deviation if \code{TRUE} then instead of explaining prediction and plotting intercept bar, only deviation from mean prediction of the reference dataset will be explained. By default \code{FALSE}. 13 | #' @param title the plot's title, by default \code{'SHAP Break-Down'}. 14 | #' @param subtitle the plot's subtitle. By default no subtitle. 15 | #' 16 | #' @return a \code{ggplot2} object 17 | #' 18 | #' @export 19 | #' 20 | #' @import ggplot2 21 | #' 22 | #' @seealso 23 | #' \code{\link{treeshap}} for calculation of SHAP values 24 | #' 25 | #' \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 26 | #' 27 | #' 28 | #' @examples 29 | #' \donttest{ 30 | #' library(xgboost) 31 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 32 | #' target <- fifa20$target 33 | #' param <- list(objective = "reg:squarederror", max_depth = 3) 34 | #' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 35 | #' nrounds = 20, verbose = FALSE) 36 | #' unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 37 | #' x <- head(data, 1) 38 | #' shap <- treeshap(unified_model, x) 39 | #' plot_contribution(shap, 1, min_max = c(0, 120000000)) 40 | #' } 41 | plot_contribution <- function(treeshap, 42 | obs = 1, 43 | max_vars = 5, 44 | min_max = NA, 45 | digits = 3, 46 | explain_deviation = FALSE, 47 | title = "SHAP Break-Down", 48 | subtitle = "") { 49 | 50 | shap <- treeshap$shaps[obs, ] 51 | model <- treeshap$unified_model$model 52 | x <- treeshap$observations[obs, ] 53 | 54 | # argument check 55 | if (!is.treeshap(treeshap)) { 56 | stop("treeshap parameter has to be correct object of class treeshap. Produce it using treeshap function.") 57 | } 58 | 59 | if (max_vars > ncol(shap)) { 60 | warning("max_vars exceeds number of variables. All variables will be shown.") 61 | max_vars <- ncol(shap) 62 | } 63 | if (nrow(shap) != 1) { 64 | warning("Only 1 observation can be plotted. 
Plotting 1st one.") 65 | shap <- shap[1, ] 66 | } 67 | 68 | # setting intercept 69 | mean_prediction <- mean(predict.model_unified(treeshap$unified_model, treeshap$unified_model$data)) 70 | if (explain_deviation) { 71 | mean_prediction <- 0 72 | } 73 | 74 | df <- data.frame(variable = colnames(shap), contribution = as.numeric(shap)) 75 | 76 | # setting variable names to showing their value 77 | df$variable <- paste0(df$variable, " = ", as.character(x)) 78 | 79 | # selecting max_vars most important variables 80 | is_important <- order(abs(df$contribution), decreasing = TRUE)[1:max_vars] 81 | other_variables_contribution_sum <- sum(df$contribution[-is_important]) 82 | df <- df[is_important, ] 83 | df$position <- 2:(max_vars + 1) 84 | if (max_vars < ncol(shap)) { 85 | df <- rbind(df, data.frame(variable = "+ all other variables", 86 | contribution = other_variables_contribution_sum, 87 | position = max(df$position) + 1)) 88 | } 89 | 90 | # adding "prediction" bar 91 | df <- rbind(df, data.frame(variable = ifelse(explain_deviation, "prediction deviation", "prediction"), 92 | contribution = mean_prediction + sum(df$contribution), 93 | position = max(df$position) + 1)) 94 | 95 | df$sign <- ifelse(df$contribution >= 0, "1", "-1") 96 | 97 | # adding "intercept" bar 98 | df <- rbind(df, data.frame(variable = "intercept", 99 | contribution = mean_prediction, 100 | position = 1, 101 | sign = "X")) 102 | 103 | # ordering 104 | df <- df[order(df$position), ] 105 | 106 | # adding columns needed by plot 107 | df$cumulative <- cumsum(df$contribution) 108 | df$prev <- df$cumulative - df$contribution 109 | df$text <- as.character(round(df$contribution, digits)) 110 | df$text[df$contribution > 0] <- paste0("+", df$text[df$contribution > 0]) 111 | 112 | # intercept bar corrections: 113 | df$prev[1] <- df$contribution[1] 114 | df$text[1] <- as.character(round(df$contribution[1], digits)) 115 | 116 | # prediction bar corrections: 117 | df$prev[nrow(df)] <- df$contribution[1] 118 | df$cumulative[nrow(df)] <- df$cumulative[max_vars + 2] 119 | if (!explain_deviation) { # assuring it doesn't differ from prediction because of some numeric errors 120 | df$cumulative[nrow(df)] <- predict.model_unified(treeshap$unified_model, x) 121 | } 122 | df$sign[nrow(df)] <- "X" 123 | df$text[nrow(df)] <- as.character(round(df$contribution[nrow(df)], digits)) 124 | 125 | # removing intercept bar if requested by explain_deviation argument 126 | if (explain_deviation) { 127 | df <- df[-1, ] 128 | } 129 | 130 | # reversing postions to sort bars decreasing 131 | df$position <- rev(df$position) 132 | 133 | # base plot 134 | p <- ggplot(df, aes(x = position + 0.5, 135 | y = pmax(cumulative, prev), 136 | xmin = position + 0.15, xmax = position + 0.85, 137 | ymin = cumulative, ymax = prev, 138 | fill = sign, 139 | label = text)) 140 | 141 | # add rectangles and hline 142 | p <- p + 143 | geom_errorbarh(data = df[-c(nrow(df), if (explain_deviation) nrow(df) - 1), ], 144 | aes(xmax = position - 0.85, 145 | xmin = position + 0.85, 146 | y = cumulative), height = 0, 147 | color = "#371ea3") + 148 | geom_rect(alpha = 0.9) + 149 | if (!explain_deviation) (geom_hline(data = df[df$variable == "intercept", ], 150 | aes(yintercept = contribution), 151 | lty = 3, alpha = 0.5, color = "#371ea3")) 152 | 153 | 154 | # add adnotations 155 | drange <- diff(range(df$cumulative)) 156 | p <- p + geom_text(aes(y = pmax(cumulative, cumulative - contribution)), 157 | vjust = 0.5, 158 | nudge_y = drange * 0.05, 159 | hjust = 0, 160 | color = "#371ea3") 161 | 
162 | # set limits for contributions 163 | if (any(is.na(min_max))) { 164 | x_limits <- scale_y_continuous(expand = c(0.05, 0.15), name = "", labels = scales::comma) 165 | } else { 166 | x_limits <- scale_y_continuous(expand = c(0.05, 0.15), name = "", limits = min_max, labels = scales::comma) 167 | } 168 | 169 | p <- p + x_limits + 170 | scale_x_continuous(labels = df$variable, breaks = df$position + 0.5, name = "") + 171 | scale_fill_manual(values = colors_breakdown_drwhy()) 172 | 173 | # add theme 174 | p + coord_flip() + theme_drwhy_vertical() + 175 | theme(legend.position = "none") + 176 | labs(title = title, subtitle = subtitle) 177 | } 178 | -------------------------------------------------------------------------------- /R/plot_feature_dependence.R: -------------------------------------------------------------------------------- 1 | #' SHAP value based Feature Dependence plot 2 | #' 3 | #' Depending on the value of a variable: how does it contribute into the prediction? 4 | #' 5 | #' @param treeshap A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}. 6 | #' @param variable name or index of variable for which feature dependence will be plotted. 7 | #' @param title the plot's title, by default \code{'Feature Dependence'}. 8 | #' @param subtitle the plot's subtitle. By default no subtitle. 9 | #' 10 | #' @return a \code{ggplot2} object 11 | #' 12 | #' @export 13 | #' 14 | #' @import ggplot2 15 | #' 16 | #' @seealso 17 | #' \code{\link{treeshap}} for calculation of SHAP values 18 | #' 19 | #' \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_interaction}} 20 | #' 21 | #' 22 | #' @examples 23 | #' \donttest{ 24 | #' library(xgboost) 25 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 26 | #' target <- fifa20$target 27 | #' param <- list(objective = "reg:squarederror", max_depth = 3) 28 | #' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 29 | #' nrounds = 20, verbose = FALSE) 30 | #' unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 31 | #' x <- head(data, 100) 32 | #' shaps <- treeshap(unified_model, x) 33 | #' plot_feature_dependence(shaps, variable = "overall") 34 | #' } 35 | plot_feature_dependence <- function(treeshap, variable, 36 | title = "Feature Dependence", subtitle = NULL) { 37 | 38 | shaps <- treeshap$shaps 39 | x <- treeshap$observations 40 | 41 | # argument check 42 | if (!is.treeshap(treeshap)) { 43 | stop("treeshap parameter has to be correct object of class treeshap. 
Produce it using treeshap function.") 44 | } 45 | 46 | if (is.character(variable)) { 47 | if (!(variable %in% colnames(shaps))) { 48 | stop("Incorrect variable or shaps dataframe, variable should be one of variables in the shaps dataframe.") 49 | } 50 | if (!(variable %in% colnames(shaps))) { 51 | stop("Incorrect variable or x dataframe, varaible should be one of variables in the shaps dataframe.") 52 | } 53 | } else if (is.numeric(variable) && (length(variable) == 1)) { 54 | if (!all(colnames(shaps) == colnames(x))) { 55 | stop("shaps and x should have the same column names.") 56 | } 57 | if (!(variable %in% 1:ncol(shaps))) { 58 | stop("variable is an incorrect number.") 59 | } 60 | variable <- colnames(shaps)[variable] 61 | } else { 62 | stop("variable is of incorrect type.") 63 | } 64 | 65 | df <- data.frame(var_value = x[, variable], shap_value = shaps[, variable ]) 66 | p <- ggplot(df, aes(x = var_value, y = shap_value)) + 67 | geom_point() 68 | 69 | p + 70 | theme_drwhy() + 71 | xlab(variable) + ylab(paste0("SHAP value for ", variable)) + 72 | labs(title = title, subtitle = subtitle) + 73 | scale_y_continuous(labels = scales::comma) 74 | } 75 | -------------------------------------------------------------------------------- /R/plot_feature_importance.R: -------------------------------------------------------------------------------- 1 | ## plotting functions for treeshap package 2 | 3 | #' SHAP value based Feature Importance plot 4 | #' 5 | #' This function plots feature importance calculated as means of absolute values of SHAP values of variables (average impact on model output magnitude). 6 | #' 7 | #' @param treeshap A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}. 8 | #' @param desc_sorting logical. Should the bars be sorted descending? By default TRUE. 9 | #' @param max_vars maximum number of variables that shall be presented. By default all are presented. 10 | #' @param title the plot's title, by default \code{'Feature Importance'}. 11 | #' @param subtitle the plot's subtitle. By default no subtitle. 12 | #' 13 | #' @return a \code{ggplot2} object 14 | #' 15 | #' @export 16 | #' @import ggplot2 17 | #' @importFrom stats reorder 18 | #' @importFrom graphics text 19 | #' 20 | #' @seealso 21 | #' \code{\link{treeshap}} for calculation of SHAP values 22 | #' 23 | #' \code{\link{plot_contribution}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 24 | #' 25 | #' 26 | #' @examples 27 | #' \donttest{ 28 | #' library(xgboost) 29 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 30 | #' target <- fifa20$target 31 | #' param <- list(objective = "reg:squarederror", max_depth = 3) 32 | #' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 33 | #' nrounds = 20, verbose = FALSE) 34 | #' unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 35 | #' shaps <- treeshap(unified_model, as.matrix(head(data, 3))) 36 | #' plot_feature_importance(shaps, max_vars = 4) 37 | #' } 38 | plot_feature_importance <- function(treeshap, 39 | desc_sorting = TRUE, 40 | max_vars = ncol(shaps), 41 | title = "Feature Importance", 42 | subtitle = NULL) { 43 | shaps <- treeshap$shaps 44 | 45 | # argument check 46 | if (!is.treeshap(treeshap)) { 47 | stop("treeshap parameter has to be correct object of class treeshap. 
Produce it using treeshap function.") 48 | } 49 | 50 | if (!is.logical(desc_sorting)) { 51 | stop("desc_sorting is not logical.") 52 | } 53 | 54 | if (!is.numeric(max_vars)) { 55 | stop("max_vars is not numeric.") 56 | } 57 | 58 | if (max_vars > ncol(shaps)) { 59 | warning("max_vars exceeded number of explained variables. All variables will be shown.") 60 | max_vars <- ncol(shaps) 61 | } 62 | 63 | mean <- colMeans(abs(shaps)) 64 | df <- data.frame(variable = factor(names(mean)), importance = as.vector(mean)) 65 | df$variable <- reorder(df$variable, df$importance * ifelse(desc_sorting, 1, -1)) 66 | df <- df[order(df$importance, decreasing = TRUE)[1:max_vars], ] 67 | 68 | p <- ggplot(df, aes(x = variable, y = importance)) + 69 | geom_bar(stat = "identity", fill = colors_discrete_drwhy(1)) 70 | 71 | p + coord_flip() + 72 | theme_drwhy_vertical() + 73 | ylab("mean(|SHAP value|)") + xlab("") + 74 | labs(title = title, subtitle = subtitle) + 75 | scale_y_continuous(labels = scales::comma) + 76 | theme(legend.position = "none") 77 | } 78 | -------------------------------------------------------------------------------- /R/plot_interaction.R: -------------------------------------------------------------------------------- 1 | #' SHAP Interaction value plot 2 | #' 3 | #' This function plots SHAP Interaction value for two variables depending on the value of the first variable. 4 | #' Value of the second variable is marked with the color. 5 | #' 6 | #' @param treeshap A treeshap object produced with \code{\link{treeshap}(interactions = TRUE)} function. \code{\link{treeshap.object}}. 7 | #' @param var1 name or index of the first variable - plotted on x axis. 8 | #' @param var2 name or index of the second variable - marked with color. 9 | #' @param title the plot's title, by default \code{'SHAP Interaction Value Plot'}. 10 | #' @param subtitle the plot's subtitle. By default no subtitle. 11 | #' 12 | #' @return a \code{ggplot2} object 13 | #' 14 | #' @export 15 | #' 16 | #' @import ggplot2 17 | #' 18 | #' @seealso 19 | #' \code{\link{treeshap}} for calculation of SHAP Interaction values 20 | #' 21 | #' \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}} 22 | #' 23 | #' 24 | #' @examples 25 | #' \donttest{ 26 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 27 | #' target <- fifa20$target 28 | #' param2 <- list(objective = "reg:squarederror", max_depth = 5) 29 | #' xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10) 30 | #' unified_model2 <- xgboost.unify(xgb_model2, data) 31 | #' inters <- treeshap(unified_model2, as.matrix(data[1:50, ]), interactions = TRUE) 32 | #' plot_interaction(inters, "dribbling", "defending") 33 | #' } 34 | plot_interaction <- function(treeshap, var1, var2, 35 | title = "SHAP Interaction Value Plot", 36 | subtitle = "") { 37 | 38 | interactions <- treeshap$interactions 39 | x <- treeshap$observations 40 | 41 | # argument check 42 | if (!is.treeshap(treeshap)) { 43 | stop("treeshap parameter has to be correct object of class treeshap. Produce it using treeshap function.") 44 | } 45 | 46 | if (is.null(interactions)) { 47 | stop("SHAP Interaction values were not calculated in treeshap object. You need to use treeshap(interactions = TRUE).") 48 | } 49 | 50 | if (is.character(var1)) { 51 | if (!(var1 %in% colnames(x))) stop("var1 is not a correct variable name. 
It does not occur in the dataset.") 52 | if (!(var1 %in% colnames(interactions))) stop("var1 is not a correct variable name. It does not occur in interactions object.") 53 | } else if (is.numeric(var1)) { 54 | if (var1 > ncol(x) || var1 < 1) stop("var1 is not a correct number.") 55 | } 56 | 57 | if (is.character(var2)) { 58 | if (!(var2 %in% colnames(x))) stop("var2 is not a correct variable name. It does not occur in the dataset.") 59 | if (!(var2 %in% colnames(interactions))) stop("var2 is not a correct variable name. It does not occur in interactions object.") 60 | } else if (is.numeric(var2)) { 61 | if (var2 > ncol(x) || var2 < 1) stop("var2 is not a correct number.") 62 | } 63 | 64 | 65 | interaction <- interactions[var1, var2, ] 66 | var1_value <- x[,var1] 67 | var2_value <- x[,var2] 68 | plot_data <- data.frame(var1_value = var1_value, var2_value = var2_value, interaction = interaction) 69 | 70 | x_lab <- ifelse(is.character(var1), var1, colnames(x)[var1]) 71 | col_lab <- ifelse(is.character(var2), var2, colnames(x)[var2]) 72 | y_lab <- paste0("SHAP Interaction value for ", x_lab, " and ", col_lab) 73 | 74 | p <- ggplot(plot_data, aes(x = var1_value, y = interaction, color = var2_value)) + 75 | geom_point() + 76 | labs(x = x_lab, color = col_lab, y = y_lab, title = title, subtitle = subtitle) + 77 | scale_y_continuous(labels = scales::comma) + 78 | theme_drwhy() 79 | p 80 | } 81 | -------------------------------------------------------------------------------- /R/predict.R: -------------------------------------------------------------------------------- 1 | #' Predict 2 | #' 3 | #' Predict using unified_model representation. 4 | #' 5 | #' @param object Unified model representation of the model created with a (model).unify function. \code{\link{model_unified.object}} 6 | #' @param x Observations to predict. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. 7 | #' @param ... other parameters 8 | #' 9 | #' @return a vector of predictions. 10 | #' 11 | #' @export 12 | #' 13 | #' @examples 14 | #' \donttest{ 15 | #' library(gbm) 16 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 17 | #' data['value_eur'] <- fifa20$target 18 | #' gbm_model <- gbm::gbm( 19 | #' formula = value_eur ~ ., 20 | #' data = data, 21 | #' distribution = "laplace", 22 | #' n.trees = 20, 23 | #' interaction.depth = 4, 24 | #' n.cores = 1) 25 | #' unified <- gbm.unify(gbm_model, data) 26 | #' predict(unified, data[2001:2005, ]) 27 | #' } 28 | predict.model_unified <- function(object, x, ...) { 29 | unified_model <- object 30 | model <- unified_model$model 31 | x <- as.data.frame(x) 32 | 33 | # argument check 34 | if (!is.model_unified(unified_model)) { 35 | stop("unified_model parameter has to of class model_unified. Produce it using *.unify function.") 36 | } 37 | 38 | if (!("matrix" %in% class(x) | "data.frame" %in% class(x))) { 39 | stop("x parameter has to be data.frame or matrix.") 40 | } 41 | 42 | if (!attr(unified_model, "missing_support") && any(is.na(x))) { 43 | stop("Given model does not work with missing values. 
Dataset x should not contain missing values.") 44 | } 45 | 46 | x <- x[,colnames(x) %in% unified_model$feature_names] 47 | 48 | if (!all(model$Feature %in% c(NA, colnames(x)))) { 49 | stop("Dataset does not contain all features occurring in the model.") 50 | } 51 | 52 | # adapting model representation to C++ and extracting from dataframe to vectors 53 | roots <- which(model$Node == 0) - 1 54 | yes <- model$Yes - 1 55 | no <- model$No - 1 56 | missing <- model$Missing - 1 57 | is_leaf <- is.na(model$Feature) 58 | feature <- match(model$Feature, colnames(x)) - 1 59 | split <- model$Split 60 | decision_type <- unclass(model$Decision.type) 61 | #stopifnot(levels(decision_type) == c("<=", "<")) 62 | #stopifnot(all(decision_type %in% c(1, 2, NA))) 63 | value <- model$Prediction 64 | 65 | n <- nrow(x) 66 | x <- as.data.frame(sapply(x, as.numeric)) 67 | if (n > 1) x <- t(x) 68 | 69 | is_na <- is.na(x) # needed, because dataframe passed to cpp somehow replaces missing values with random values 70 | 71 | predict_cpp(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type, value) 72 | } 73 | -------------------------------------------------------------------------------- /R/set_reference_dataset.R: -------------------------------------------------------------------------------- 1 | #' Set reference dataset 2 | #' 3 | #' Change a dataset used as reference for calculating SHAP values. 4 | #' Reference dataset is initially set with \code{data} argument in unifying function. 5 | #' Usually reference dataset is dataset used to train the model. 6 | #' Important property of reference dataset is that SHAP values for each observation add up to its deviation from mean prediction for a reference dataset. 7 | #' 8 | #' 9 | #' @param unified_model Unified model representation of the model created with a (model).unify function. (\code{\link{model_unified.object}}). 10 | #' @param x Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. 11 | #' 12 | #' @return \code{\link{model_unified.object}}. Unified representation of the model as created with a (model).unify function, 13 | #' but with changed reference dataset (Cover column containing updated values). 
14 | #' 15 | #' @export 16 | #' 17 | #' @seealso 18 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 19 | #' 20 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 21 | #' 22 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 23 | #' 24 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 25 | #' 26 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 27 | #' 28 | #' @examples 29 | #' \donttest{ 30 | #' library(gbm) 31 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 32 | #' data['value_eur'] <- fifa20$target 33 | #' gbm_model <- gbm::gbm( 34 | #' formula = value_eur ~ ., 35 | #' data = data, 36 | #' distribution = "laplace", 37 | #' n.trees = 20, 38 | #' interaction.depth = 4, 39 | #' n.cores = 1) 40 | #' unified <- gbm.unify(gbm_model, data) 41 | #' set_reference_dataset(unified, data[200:700, ]) 42 | #' } 43 | set_reference_dataset <- function(unified_model, x) { 44 | model <- unified_model$model 45 | data <- x 46 | 47 | # argument check 48 | if (!("matrix" %in% class(x) | "data.frame" %in% class(x))) { 49 | stop("x parameter has to be data.frame or matrix.") 50 | } 51 | 52 | if (!("model_unified" %in% class(unified_model))) { 53 | stop("unified_model parameter has to of class model_unified. Produce it using *.unify function.") 54 | } 55 | 56 | if (!all(c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction") %in% colnames(model))) { 57 | stop("Given model dataframe is not a correct unified dataframe representation. Use (model).unify function.") 58 | } 59 | 60 | if (!attr(unified_model, "missing_support") && any(is.na(x))) { 61 | stop("Given model does not work with missing values. 
Dataset x should not contain missing values.") 62 | } 63 | 64 | x <- x[,colnames(x) %in% unified_model$feature_names] 65 | 66 | if (!all(model$Feature %in% c(NA, colnames(x)))) { 67 | stop("Dataset does not contain all features occurring in the model.") 68 | } 69 | 70 | 71 | # adapting model representation to C++ and extracting from dataframe to vectors 72 | roots <- which(model$Node == 0) - 1 73 | yes <- model$Yes - 1 74 | no <- model$No - 1 75 | missing <- model$Missing - 1 76 | is_leaf <- is.na(model$Feature) 77 | feature <- match(model$Feature, colnames(x)) - 1 78 | split <- model$Split 79 | decision_type <- unclass(model$Decision.type) 80 | #stopifnot(levels(decision_type) == c("<=", "<")) 81 | #stopifnot(all(decision_type %in% c(1, 2, NA))) 82 | 83 | n <- nrow(x) 84 | x <- as.data.frame(sapply(x, as.numeric)) 85 | if (n > 1) x <- t(x) 86 | is_na <- is.na(x) # needed, because dataframe passed to cpp somehow replaces missing values with random values 87 | 88 | model$Cover <- new_covers(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type) 89 | 90 | ret <- list(model = as.data.frame(model), data = as.data.frame(data), feature_names = unified_model$feature_names) 91 | #attributes(ret) <- attributes(model_unified) 92 | class(ret) <- "model_unified" 93 | attr(ret, "missing_support") <- attr(unified_model, "missing_support") 94 | attr(ret, 'model') <- attr(unified_model, "model") 95 | 96 | return(ret) 97 | } 98 | -------------------------------------------------------------------------------- /R/theme_drwhy.R: -------------------------------------------------------------------------------- 1 | #' DrWhy Theme for ggplot objects 2 | #' 3 | #' @return theme for ggplot2 objects 4 | #' @export 5 | #' @rdname theme_drwhy 6 | theme_drwhy <- function() { 7 | theme_bw(base_line_size = 0) %+replace% 8 | theme(axis.ticks = element_blank(), legend.background = element_blank(), 9 | legend.key = element_blank(), panel.background = element_blank(), 10 | panel.border = element_blank(), strip.background = element_blank(), 11 | plot.background = element_blank(), complete = TRUE, 12 | legend.direction = "horizontal", legend.position = "top", 13 | axis.line.y = element_line(color = "white"), 14 | axis.ticks.y = element_line(color = "white"), 15 | #axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1), 16 | axis.title = element_text(color = "#371ea3"), 17 | plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), 18 | plot.subtitle = element_text(color = "#371ea3", hjust = 0), 19 | axis.text = element_text(color = "#371ea3", size = 10), 20 | strip.text = element_text(color = "#371ea3", size = 12, hjust = 0), 21 | panel.grid.major.y = element_line(color = "grey90", size = 0.5, linetype = 1), 22 | panel.grid.minor.y = element_line(color = "grey90", size = 0.5, linetype = 1), 23 | panel.grid.minor.x = element_blank(), 24 | panel.grid.major.x = element_blank()) 25 | 26 | } 27 | 28 | #' @export 29 | #' @rdname theme_drwhy 30 | theme_drwhy_vertical <- function() { 31 | theme_bw(base_line_size = 0) %+replace% 32 | theme(axis.ticks = element_blank(), legend.background = element_blank(), 33 | legend.key = element_blank(), panel.background = element_blank(), 34 | panel.border = element_blank(), strip.background = element_blank(), 35 | plot.background = element_blank(), complete = TRUE, 36 | legend.direction = "horizontal", legend.position = "top", 37 | axis.line.x = element_line(color = "white"), 38 | axis.ticks.x = element_line(color = "white"), 39 | plot.title = 
element_text(color = "#371ea3", size = 16, hjust = 0), 40 | plot.subtitle = element_text(color = "#371ea3", hjust = 0), 41 | #axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1), 42 | axis.title = element_text(color = "#371ea3"), 43 | axis.text = element_text(color = "#371ea3", size = 10), 44 | strip.text = element_text(color = "#371ea3", size = 12, hjust = 0), 45 | panel.grid.major.x = element_line(color = "grey90", size = 0.5, linetype = 1), 46 | panel.grid.minor.x = element_line(color = "grey90", size = 0.5, linetype = 1), 47 | panel.grid.minor.y = element_blank(), 48 | panel.grid.major.y = element_blank()) 49 | } 50 | 51 | #' DrWhy color palettes for ggplot objects 52 | #' 53 | #' @param n number of colors for color palette 54 | #' 55 | #' @return color palette as vector of characters 56 | #' @export 57 | #' @rdname colors_drwhy 58 | colors_discrete_drwhy <- function(n = 2) { 59 | if (n == 1) return("#4378bf") 60 | if (n == 2) return(c( "#4378bf", "#8bdcbe")) 61 | if (n == 3) return(c( "#4378bf", "#f05a71", "#8bdcbe")) 62 | if (n == 4) return(c( "#4378bf", "#f05a71", "#8bdcbe", "#ffa58c")) 63 | if (n == 5) return(c( "#4378bf", "#f05a71", "#8bdcbe", "#ae2c87", "#ffa58c")) 64 | if (n == 6) return(c( "#4378bf", "#46bac2", "#8bdcbe", "#ae2c87", "#ffa58c", "#f05a71")) 65 | c( "#4378bf", "#46bac2", "#371ea3", "#8bdcbe", "#ae2c87", "#ffa58c", "#f05a71")[((0:(n-1)) %% 7) + 1] 66 | } 67 | 68 | #' @export 69 | #' @rdname colors_drwhy 70 | colors_breakdown_drwhy <- function() { 71 | c(`-1` = "#f05a71", `0` = "#371ea3", `1` = "#8bdcbe", X = "#371ea3") 72 | } 73 | -------------------------------------------------------------------------------- /R/treeshap.R: -------------------------------------------------------------------------------- 1 | #' Calculate SHAP values of a tree ensemble model. 2 | #' 3 | #' Calculate SHAP values and optionally SHAP Interaction values. 4 | #' 5 | #' 6 | #' @param unified_model Unified data.frame representation of the model created with a (model).unify function. A \code{\link{model_unified.object}} object. 7 | #' @param x Observations to be explained. A \code{data.frame} or \code{matrix} object with the same columns as in the training set of the model. Keep in mind that objects different than \code{data.frame} or plain \code{matrix} will cause an error or unpredictable behavior. 8 | #' @param interactions Whether to calculate SHAP interaction values. By default is \code{FALSE}. Basic SHAP values are always calculated. 9 | #' @param verbose Whether to print progress bar to the console. Should be logical. Progress bar will not be displayed on Windows. 10 | #' 11 | #' @return A \code{\link{treeshap.object}} object (for single-output models) or \code{\link{treeshap_multioutput.object}}, which is a list of \code{\link{treeshap.object}} objects (for multi-output models). SHAP values can be accessed from \code{\link{treeshap.object}} with \code{$shaps}, and interaction values can be accessed with \code{$interactions}. 
12 | #' 13 | #' 14 | #' @export 15 | #' 16 | #' @importFrom Rcpp sourceCpp 17 | #' @importFrom utils setTxtProgressBar txtProgressBar 18 | #' @useDynLib treeshap 19 | #' 20 | #' @seealso 21 | #' \code{\link{xgboost.unify}} for \code{XGBoost models} 22 | #' \code{\link{lightgbm.unify}} for \code{LightGBM models} 23 | #' \code{\link{gbm.unify}} for \code{GBM models} 24 | #' \code{\link{randomForest.unify}} for \code{randomForest models} 25 | #' \code{\link{ranger.unify}} for \code{ranger models} 26 | #' \code{\link{ranger_surv.unify}} for \code{ranger survival models} 27 | #' 28 | #' @examples 29 | #' \donttest{ 30 | #' library(xgboost) 31 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 32 | #' target <- fifa20$target 33 | #' 34 | #' # calculating simple SHAP values 35 | #' param <- list(objective = "reg:squarederror", max_depth = 3) 36 | #' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 37 | #' nrounds = 20, verbose = FALSE) 38 | #' unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 39 | #' treeshap1 <- treeshap(unified_model, head(data, 3)) 40 | #' plot_contribution(treeshap1, obs = 1) 41 | #' treeshap1$shaps 42 | #' 43 | #' # It's possible to calculate the explanation over a different part of the dataset 44 | #' 45 | #' unified_model_rec <- set_reference_dataset(unified_model, data[1:1000, ]) 46 | #' treeshap_rec <- treeshap(unified_model_rec, head(data, 3)) 47 | #' plot_contribution(treeshap_rec, obs = 1) 48 | #' 49 | #' # calculating SHAP interaction values 50 | #' param2 <- list(objective = "reg:squarederror", max_depth = 7) 51 | #' xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10) 52 | #' unified_model2 <- xgboost.unify(xgb_model2, as.matrix(data)) 53 | #' treeshap2 <- treeshap(unified_model2, head(data, 3), interactions = TRUE) 54 | #' treeshap2$interactions 55 | #' } 56 | treeshap <- function(unified_model, x, interactions = FALSE, verbose = TRUE) { 57 | UseMethod("treeshap", unified_model) 58 | } 59 | 60 | #' @export 61 | treeshap.model_unified <- function(unified_model, x, interactions = FALSE, verbose = TRUE){ 62 | model <- unified_model$model 63 | # argument check 64 | if (!("matrix" %in% class(x) | "data.frame" %in% class(x))) { 65 | stop("x parameter has to be data.frame or matrix.") 66 | } 67 | 68 | if (!is.model_unified(unified_model)) { 69 | stop("unified_model parameter has to be of class model_unified. Produce it using *.unify function.") 70 | } 71 | 72 | if (!attr(unified_model, "missing_support") & any(is.na(x))) { 73 | stop("Given model does not work with missing values. Dataset x should not contain missing values.") 74 | } 75 | 76 | x <- x[,colnames(x) %in% unified_model$feature_names] 77 | 78 | if (!all(model$Feature %in% c(NA, colnames(x)))) { 79 | stop("Dataset x does not contain all features occurring in the model.") 80 | } 81 | 82 | if (attr(unified_model, "model") == "LightGBM" & !is.data.frame(x)) { 83 | stop("For LightGBM models, a data.frame object is required as the x parameter.
Please convert.") 84 | } 85 | 86 | if ((!is.numeric(verbose) & !is.logical(verbose)) | is.null(verbose)) { 87 | warning("Incorrect verbose argument, setting verbose = FALSE (progress will not be printed).") 88 | verbose <- FALSE 89 | } 90 | verbose <- verbose[1] > 0 # so verbose = numeric will work too 91 | x <- as.data.frame(x) 92 | 93 | # adapting model representation to C++ and extracting from dataframe to vectors 94 | roots <- which(model$Node == 0) - 1 95 | yes <- model$Yes - 1 96 | no <- model$No - 1 97 | missing <- model$Missing - 1 98 | feature <- match(model$Feature, colnames(x)) - 1 99 | split <- model$Split 100 | decision_type <- unclass(model$Decision.type) 101 | is_leaf <- is.na(model$Feature) 102 | value <- model$Prediction 103 | cover <- model$Cover 104 | 105 | x2 <- as.data.frame(sapply(x, as.numeric)) 106 | if (nrow(x) > 1) x2 <- t(x2) # transposed to be able to pick a observation with [] operator in Rcpp 107 | is_na <- is.na(x2) # needed, because dataframe passed to cpp somehow replaces missing values with random values 108 | 109 | # calculating SHAP values 110 | if (interactions) { 111 | result <- treeshap_interactions_cpp(x2, is_na, 112 | roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, 113 | verbose) 114 | shaps <- result$shaps 115 | interactions_array <- array(result$interactions, 116 | dim = c(ncol(x), ncol(x), nrow(x)), 117 | dimnames = list(colnames(x), colnames(x), rownames(x))) 118 | } else { 119 | shaps <- treeshap_cpp(x2, is_na, 120 | roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, 121 | verbose) 122 | interactions_array <- NULL 123 | } 124 | 125 | dimnames(shaps) <- list(rownames(x), colnames(x)) 126 | treeshap_obj <- list(shaps = as.data.frame(shaps), interactions = interactions_array, 127 | unified_model = unified_model, observations = x) 128 | class(treeshap_obj) <- "treeshap" 129 | return(treeshap_obj) 130 | } 131 | 132 | 133 | #' @export 134 | treeshap.model_unified_multioutput <- function(unified_model, x, interactions = FALSE, verbose = TRUE){ 135 | treeshaps_objects <- lapply(unified_model, 136 | treeshap.model_unified, 137 | x = x, 138 | interactions = interactions, 139 | verbose = verbose) 140 | class(treeshaps_objects) <- "treeshap_multioutput" 141 | return(treeshaps_objects) 142 | } 143 | 144 | 145 | #' treeshap results 146 | #' 147 | #' \code{treeshap} object produced by \code{treeshap} function. 148 | #' 149 | #' @return List consisting of four elements: 150 | #' \describe{ 151 | #' \item{shaps}{A \code{data.frame} with M columns, X rows (M - number of features, X - number of explained observations). Every row corresponds to SHAP values for a observation. } 152 | #' \item{interactions}{An \code{array} with dimensions (M, M, X) (M - number of features, X - number of explained observations). Every \code{[, , i]} slice is a symmetric matrix - SHAP Interaction values for a observation. \code{[a, b, i]} element is SHAP Interaction value of features \code{a} and \code{b} for observation \code{i}. Is \code{NULL} if interactions where not calculated (parameter \code{interactions} set \code{FALSE}.) } 153 | #' \item{unified_model}{An object of type \code{\link{model_unified.object}}. Unified representation of a model for which SHAP values were calculated. It is used by some of the plotting functions.} 154 | #' \item{observations}{Explained dataset. \code{data.frame} or \code{matrix}. 
It is used by some of the plotting functions.} 155 | #' } 156 | #' 157 | #' 158 | #' @seealso 159 | #' \code{\link{treeshap}}, 160 | #' 161 | #' \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 162 | #' 163 | #' 164 | #' @name treeshap.object 165 | NULL 166 | 167 | 168 | #' treeshap results for multi-output model 169 | #' 170 | #' \code{treeshap_multioutput} object produced by \code{treeshap} function. 171 | #' 172 | #' @return List consisting of \code{treeshap} objects, one for each individual output of a model. For survival models, the list is named using the time points, for which TreeSHAP values are calculated. 173 | #' 174 | #' 175 | #' @seealso 176 | #' \code{\link{treeshap}}, 177 | #' 178 | #' \code{\link{treeshap.object}} 179 | #' 180 | #' 181 | #' @name treeshap_multioutput.object 182 | NULL 183 | 184 | 185 | #' Prints treeshap objects 186 | #' 187 | #' @param x a treeshap object 188 | #' @param ... other arguments 189 | #' 190 | #' @return No return value, called for printing 191 | #' 192 | #' @export 193 | #' 194 | print.treeshap <- function(x, ...){ 195 | print(x$shaps) 196 | if (!is.null(x$interactions)) { 197 | print(x$interactions) 198 | } 199 | return(invisible(NULL)) 200 | } 201 | 202 | 203 | #' Prints treeshap_multioutput objects 204 | #' 205 | #' @param x a treeshap_multioutput object 206 | #' @param ... other arguments 207 | #' 208 | #' @return No return value, called for printing 209 | #' 210 | #' @export 211 | #' 212 | print.treeshap_multioutput <- function(x, ...){ 213 | output_names <- names(x) 214 | lapply(output_names, function(output_name){ 215 | cat(paste("-> for output:", output_name, "\n")) 216 | print(x[[output_name]]) 217 | cat("\n") 218 | }) 219 | return(invisible(NULL)) 220 | } 221 | 222 | 223 | #' Check whether object is a valid treeshap object 224 | #' 225 | #' Does not check correctness of result, only basic checks 226 | #' 227 | #' @param x an object to check 228 | #' 229 | #' @return boolean 230 | #' 231 | #' @export 232 | #' 233 | is.treeshap <- function(x) { 234 | # class checks 235 | ("treeshap" %in% class(x)) & 236 | (is.data.frame(x$shaps)) & 237 | (is.null(x$interactions) | is.array(x$interactions)) & 238 | (is.model_unified(x$unified_model)) & 239 | (is.data.frame(x$observations) | is.matrix(x$observations)) & 240 | # dim checks 241 | (all(nrow(x$observations) == nrow(x$shaps)) & all(ncol(x$observations) == ncol(x$shaps))) & 242 | (is.null(x$interactions) | all(dim(x$interactions) == c(ncol(x$shaps), ncol(x$shaps), nrow(x$shaps)))) & 243 | # names check 244 | #all(rownames(x$observations) == rownames(x$shaps)) & 245 | all(colnames(x$observations) == colnames(x$shaps)) & 246 | (is.null(x$interactions) | all(dimnames(x$interactions)[[1]] == colnames(x$shaps) 247 | & dimnames(x$interactions)[[2]] == colnames(x$shaps))) & 248 | #(is.null(x$interactions) | all(dimnames(x$interactions)[[3]] == rownames(x$shaps))) & 249 | # type check 250 | (is.null(x$interactions) | is.numeric(x$interactions)) & 251 | (is.numeric(as.matrix(x$shaps))) 252 | } 253 | 254 | -------------------------------------------------------------------------------- /R/unify.R: -------------------------------------------------------------------------------- 1 | #' Unify tree-based model 2 | #' 3 | #' Convert your tree-based model into a standardized representation. 
4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 5 | #' 6 | #' @param model A tree-based model object of any supported class (\code{gbm}, \code{lgb.Booster}, \code{randomForest}, \code{ranger}, or \code{xgb.Booster}). 7 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 8 | #' @param ... Additional parameters passed to the model-specific unification functions. 9 | #' 10 | #' @return A unified model representation - a \code{\link{model_unified.object}} object (for single-output models) or \code{\link{model_unified_multioutput.object}}, which is a list of \code{\link{model_unified.object}} objects (for multi-output models). 11 | #' 12 | #' 13 | #' @seealso 14 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 15 | #' 16 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 17 | #' 18 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 19 | #' 20 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 21 | #' 22 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 23 | #' 24 | #' @export 25 | #' 26 | #' @examples 27 | #' 28 | #' library(ranger) 29 | #' data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 30 | #' c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 31 | #' 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 32 | #' data <- na.omit(cbind(data_fifa, target = fifa20$target)) 33 | #' 34 | #' rf1 <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 35 | #' unified_model1 <- unify(rf1, data) 36 | #' shaps1 <- treeshap(unified_model1, data[1:2,]) 37 | #' plot_contribution(shaps1, obs = 1) 38 | #' 39 | #' rf2 <- randomForest::randomForest(target~., data = data, maxnodes = 10, ntree = 10) 40 | #' unified_model2 <- unify(rf2, data) 41 | #' shaps2 <- treeshap(unified_model2, data[1:2,]) 42 | #' plot_contribution(shaps2, obs = 1) 43 | unify <- function(model, data, ...){ 44 | UseMethod("unify", model) 45 | } 46 | 47 | #' @export 48 | unify.gbm <- function(model, data, ...){ 49 | gbm.unify(model, data) 50 | } 51 | 52 | #' @export 53 | unify.lgb.Booster <- function(model, data, recalculate = FALSE, ...){ 54 | lightgbm.unify(model, data, recalculate) 55 | } 56 | 57 | #' @export 58 | unify.randomForest <- function(model, data, ...){ 59 | randomForest.unify(model, data) 60 | } 61 | 62 | #' @export 63 | unify.ranger <- function(model, data, ...){ 64 | if (model$treetype == "Survival"){ 65 | return(ranger_surv.unify(model, data, ...)) 66 | } 67 | ranger.unify(model, data) 68 | } 69 | 70 | #' @export 71 | unify.xgb.Booster <- function(model, data, recalculate = FALSE, ...){ 72 | xgboost.unify(model, data, recalculate) 73 | } 74 | 75 | #' @export 76 | unify.default <- function(model, data, ...){ 77 | stop("Provided model is not of type supported by treeshap.") 78 | } 79 | 80 | -------------------------------------------------------------------------------- /R/unify_gbm.R: -------------------------------------------------------------------------------- 1 | #' Unify GBM model 2 | #' 3 | #' Convert your GBM model into a standardized representation. 4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
5 | #' 6 | #' @param gbm_model An object of \code{gbm} class. At the moment, models built on data with categorical features 7 | #' are not supported - please encode them before training. 8 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 9 | #' 10 | #' @return a unified model representation - a \code{\link{model_unified.object}} object 11 | #' 12 | #' @export 13 | #' 14 | #' @seealso 15 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 16 | #' 17 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 18 | #' 19 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 20 | #' 21 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 22 | #' 23 | #' @examples 24 | #' \donttest{ 25 | #' library(gbm) 26 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 27 | #' data['value_eur'] <- fifa20$target 28 | #' gbm_model <- gbm::gbm( 29 | #' formula = value_eur ~ ., 30 | #' data = data, 31 | #' distribution = "gaussian", 32 | #' n.trees = 20, 33 | #' interaction.depth = 4, 34 | #' n.cores = 1) 35 | #' unified_model <- gbm.unify(gbm_model, data) 36 | #' shaps <- treeshap(unified_model, data[1:2,]) 37 | #' plot_contribution(shaps, obs = 1) 38 | #' } 39 | gbm.unify <- function(gbm_model, data) { 40 | if(!inherits(gbm_model,'gbm')) { 41 | stop('Object gbm_model was not of class "gbm"') 42 | } 43 | if(any(gbm_model$var.type > 0)) { 44 | stop('Models built on data with categorical features are not supported - please encode them before training.') 45 | } 46 | x <- lapply(gbm_model$trees, data.table::as.data.table) 47 | times_vec <- sapply(x, nrow) 48 | y <- data.table::rbindlist(x) 49 | data.table::setnames(y, c("Feature", "Split", "Yes", 50 | "No", "Missing", "ErrorReduction", "Cover", 51 | "Prediction")) 52 | y[["Tree"]] <- rep(0:(length(gbm_model$trees) - 1), times = times_vec) 53 | y[["Node"]] <- unlist(lapply(times_vec, function(x) 0:(x - 1))) 54 | y <- y[, Feature := as.character(Feature)] 55 | y[y$Feature < 0, "Feature"] <- NA 56 | y[!is.na(y$Feature), "Feature"] <- attr(gbm_model$Terms, "term.labels")[as.integer(y[["Feature"]][!is.na(y$Feature)]) + 1] 57 | y[is.na(y$Feature), "ErrorReduction"] <- y[is.na(y$Feature), "Split"] 58 | y[is.na(y$Feature), "Split"] <- NA 59 | y[y$Yes < 0, "Yes"] <- NA 60 | y[y$No < 0, "No"] <- NA 61 | y[y$Missing < 0, "Missing"] <- NA 62 | y$Decision.type <- factor(x = rep("<=", times = nrow(y)), levels = c("<=", "<")) 63 | y[is.na(Feature), Decision.type := NA] 64 | y <- y[, c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "ErrorReduction", "Cover")] 65 | colnames(y) <- c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover") 66 | 67 | ID <- paste0(y$Node, "-", y$Tree) 68 | y$Yes <- match(paste0(y$Yes, "-", y$Tree), ID) 69 | y$No <- match(paste0(y$No, "-", y$Tree), ID) 70 | y$Missing <- match(paste0(y$Missing, "-", y$Tree), ID) 71 | 72 | # Here we lose "Quality" information 73 | y[!is.na(Feature), Prediction := NA] 74 | 75 | # GBM calculates prediction as [initF + sum of predictions of trees] 76 | # treeSHAP assumes prediction are calculated as [sum of predictions of trees] 77 | # so here we adjust it 78 | ntrees <- sum(y$Node == 0) 79 | y[is.na(Feature), Prediction := Prediction + gbm_model$initF / ntrees] 80 | 81 | feature_names <- 
gbm_model$var.names 82 | data <- data[,colnames(data) %in% feature_names] 83 | 84 | ret <- list(model = as.data.frame(y), data = as.data.frame(data), feature_names = feature_names) 85 | class(ret) <- "model_unified" 86 | attr(ret, "missing_support") <- TRUE 87 | attr(ret, "model") <- "gbm" 88 | 89 | # Original covers in gbm_model are not correct 90 | ret <- set_reference_dataset(ret, as.data.frame(data)) 91 | 92 | return(ret) 93 | } 94 | -------------------------------------------------------------------------------- /R/unify_lightgbm.R: -------------------------------------------------------------------------------- 1 | # should be preceded with lgb.model.dt.tree 2 | #' Unify LightGBM model 3 | #' 4 | #' Convert your LightGBM model into a standardized representation. 5 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 6 | #' 7 | #' @param lgb_model A lightgbm model - object of class \code{lgb.Booster} 8 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 9 | #' @param recalculate logical indicating if covers should be recalculated according to the dataset given in data. Keep it \code{FALSE} if training data are used. 10 | #' 11 | #' @return a unified model representation - a \code{\link{model_unified.object}} object 12 | #' 13 | #' @export 14 | #' 15 | #' @import data.table 16 | #' 17 | #' @seealso 18 | #' 19 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 20 | #' 21 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 22 | #' 23 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 24 | #' 25 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 26 | #' 27 | #' @examples 28 | #' \donttest{ 29 | #' library(lightgbm) 30 | #' param_lgbm <- list(objective = "regression", max_depth = 2, 31 | #' force_row_wise = TRUE, num_iterations = 20) 32 | #' data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 33 | #' c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 34 | #' 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 35 | #' data <- na.omit(cbind(data_fifa, fifa20$target)) 36 | #' sparse_data <- as.matrix(data[,-ncol(data)]) 37 | #' x <- lightgbm::lgb.Dataset(sparse_data, label = as.matrix(data[,ncol(data)])) 38 | #' lgb_data <- lightgbm::lgb.Dataset.construct(x) 39 | #' lgb_model <- lightgbm::lightgbm(data = lgb_data, params = param_lgbm, 40 | #' verbose = -1, num_threads = 0) 41 | #' unified_model <- lightgbm.unify(lgb_model, sparse_data) 42 | #' shaps <- treeshap(unified_model, data[1:2, ]) 43 | #' plot_contribution(shaps, obs = 1) 44 | #' } 45 | lightgbm.unify <- function(lgb_model, data, recalculate = FALSE) { 46 | if (!requireNamespace("lightgbm", quietly = TRUE)) { 47 | stop("Package \"lightgbm\" needed for this function to work. Please install it.", 48 | call. = FALSE) 49 | } 50 | df <- lightgbm::lgb.model.dt.tree(lgb_model) 51 | stopifnot(c("split_index", "split_feature", "node_parent", "leaf_index", "leaf_parent", "internal_value", 52 | "internal_count", "leaf_value", "leaf_count", "decision_type") %in% colnames(df)) 53 | df <- data.table::as.data.table(df) 54 | #convert node_parent and leaf_parent into one parent column 55 | df[is.na(df$node_parent), "node_parent"] <- df[is.na(df$node_parent), "leaf_parent"] 56 | #convert values into one column... 
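# (leaf rows have NA internal_value and internal_count, so the corresponding leaf_value and leaf_count are copied into them below, giving every node a single value column and a single count column)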
57 | df[is.na(df$internal_value), "internal_value"] <- df[!is.na(df$leaf_value), "leaf_value"] 58 | #...and counts 59 | df[is.na(df$internal_count), "internal_count"] <- df[!is.na(df$leaf_count), "leaf_count"] 60 | df[["internal_count"]] <- as.numeric(df[["internal_count"]]) 61 | #convert split_index and leaf_index into one column: 62 | max_split_index <- df[, max(split_index, na.rm = TRUE), tree_index] 63 | rep_max_split <- rep(max_split_index$V1, times = as.numeric(table(df$tree_index))) 64 | new_leaf_index <- rep_max_split + df[, "leaf_index"] + 1 65 | df[is.na(df$split_index), "split_index"] <- new_leaf_index[!is.na(new_leaf_index[["leaf_index"]]), 'leaf_index'] 66 | df[is.na(df$split_gain), "split_gain"] <- df[is.na(df$split_gain), "leaf_value"] 67 | # On the basis of column 'Parent', create columns with childs: 'Yes', 'No' and 'Missing' like in the xgboost df: 68 | ret.first <- function(x) x[1] 69 | ret.second <- function(x) x[2] 70 | tmp <- data.table::merge.data.table(df[, .(node_parent, tree_index, split_index)], df[, .(tree_index, split_index, default_left, decision_type)], 71 | by.x = c("tree_index", "node_parent"), by.y = c("tree_index", "split_index")) 72 | y_n_m <- unique(tmp[, .(Yes = ifelse(decision_type %in% c("<=", "<"), ret.first(split_index), 73 | ifelse(decision_type %in% c(">=", ">"), ret.second(split_index), stop("Unknown decision_type"))), 74 | No = ifelse(decision_type %in% c(">=", ">"), ret.first(split_index), 75 | ifelse(decision_type %in% c("<=", "<"), ret.second(split_index), stop("Unknown decision_type"))), 76 | Missing = ifelse(default_left, ret.first(split_index),ret.second(split_index)), 77 | decision_type = decision_type), 78 | .(tree_index, node_parent)]) 79 | df <- data.table::merge.data.table(df[, c("tree_index", "depth", "split_index", "split_feature", "node_parent", "split_gain", 80 | "threshold", "internal_value", "internal_count")], 81 | y_n_m, by.x = c("tree_index", "split_index"), 82 | by.y = c("tree_index", "node_parent"), all.x = TRUE) 83 | df[decision_type == ">=", decision_type := "<"] 84 | df[decision_type == ">", decision_type := "<="] 85 | df$Decision.type <- factor(x = df$decision_type, levels = c("<=", "<")) 86 | df[is.na(split_index), Decision.type := NA] 87 | df <- df[, c("tree_index", "split_index", "split_feature", "Decision.type", "threshold", "Yes", "No", "Missing", "split_gain", "internal_count")] 88 | colnames(df) <- c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover") 89 | attr(df, "sorted") <- NULL 90 | 91 | ID <- paste0(df$Node, "-", df$Tree) 92 | df$Yes <- match(paste0(df$Yes, "-", df$Tree), ID) 93 | df$No <- match(paste0(df$No, "-", df$Tree), ID) 94 | df$Missing <- match(paste0(df$Missing, "-", df$Tree), ID) 95 | 96 | # Here we lose "Quality" information 97 | df$Prediction[!is.na(df$Feature)] <- NA 98 | 99 | feature_names <- jsonlite::fromJSON(lgb_model$dump_model())$feature_names 100 | data <- data[,colnames(data) %in% feature_names] 101 | 102 | ret <- list(model = as.data.frame(df), data = as.data.frame(data), feature_names = feature_names) 103 | class(ret) <- "model_unified" 104 | attr(ret, "missing_support") <- TRUE 105 | attr(ret, "model") <- "LightGBM" 106 | 107 | if (recalculate) { 108 | ret <- set_reference_dataset(ret, as.data.frame(data)) 109 | } 110 | 111 | return(ret) 112 | } 113 | -------------------------------------------------------------------------------- /R/unify_randomForest.R: -------------------------------------------------------------------------------- 
1 | #' Unify randomForest model 2 | #' 3 | #' Convert your randomForest model into a standardized representation. 4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 5 | #' 6 | #' Binary classification models with a target variable that is a factor with two levels, 0 and 1, are supported 7 | #' 8 | #' @param rf_model An object of \code{randomForest} class. At the moment, models built on data with categorical features 9 | #' are not supported - please encode them before training. 10 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 11 | #' 12 | #' @return a unified model representation - a \code{\link{model_unified.object}} object 13 | #' 14 | #' @import data.table 15 | #' 16 | #' @export 17 | #' 18 | #' @seealso 19 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 20 | #' 21 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 22 | #' 23 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 24 | #' 25 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 26 | #' 27 | #' @examples 28 | #' 29 | #' library(randomForest) 30 | #' data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 31 | #' c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 32 | #' 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 33 | #' data <- na.omit(cbind(data_fifa, target = fifa20$target)) 34 | #' 35 | #' rf <- randomForest::randomForest(target~., data = data, maxnodes = 10, ntree = 10) 36 | #' unified_model <- randomForest.unify(rf, data) 37 | #' shaps <- treeshap(unified_model, data[1:2,]) 38 | #' # plot_contribution(shaps, obs = 1) 39 | #' 40 | randomForest.unify <- function(rf_model, data) { 41 | if(!inherits(rf_model,'randomForest')){stop('Object rf_model was not of class "randomForest"')} 42 | if(any(attr(rf_model$terms, "dataClasses")[-1] != "numeric")) { 43 | stop('Models built on data with categorical features are not supported - please encode them before training.') 44 | } 45 | n <- rf_model$ntree 46 | ret <- data.table() 47 | prediction <- NULL 48 | x <- lapply(1:n, function(tree){ 49 | tree_data <- as.data.table(randomForest::getTree(rf_model, k = tree, labelVar = TRUE)) 50 | tree_data <- tree_data[ , prediction:=as.numeric(prediction)] 51 | tree_data[, c("left daughter", "right daughter", "split var", "split point", "prediction")] 52 | }) 53 | times_vec <- sapply(x, nrow) 54 | y <- rbindlist(x) 55 | y[, Tree := rep(0:(n - 1), times = times_vec)] 56 | y[, Node := unlist(lapply(times_vec, function(x) 0:(x - 1)))] 57 | setnames(y, c("Yes", "No", "Feature", "Split", "Prediction", "Tree", "Node")) 58 | y[, Feature := as.character(Feature)] 59 | y[, Yes := Yes - 1] 60 | y[, No := No - 1] 61 | y[y$Yes < 0, "Yes"] <- NA 62 | y[y$No < 0, "No"] <- NA 63 | y[, Missing := NA] 64 | y[, Missing := as.integer(Missing)] # seems not, but needed 65 | 66 | ID <- paste0(y$Node, "-", y$Tree) 67 | y$Yes <- match(paste0(y$Yes, "-", y$Tree), ID) 68 | y$No <- match(paste0(y$No, "-", y$Tree), ID) 69 | 70 | y$Cover <- 0 71 | 72 | y$Decision.type <- factor(x = rep("<=", times = nrow(y)), levels = c("<=", "<")) 73 | y[is.na(Feature), Decision.type := NA] 74 | 75 | # Here we lose "Quality" information 76 | y[!is.na(Feature), Prediction := NA] 77 | 78 | # treeSHAP assumes, that [prediction = sum of predictions of the 
trees] 79 | # in random forest [prediction = mean of predictions of the trees] 80 | # so here we correct it by adjusting leaf prediction values 81 | y[is.na(Feature), Prediction := Prediction / n] 82 | 83 | 84 | setcolorder(y, c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover")) 85 | feature_names <- rownames(rf_model$importance) 86 | data <- data[,colnames(data) %in% feature_names] 87 | 88 | ret <- list(model = as.data.frame(y), data = as.data.frame(data), feature_names = feature_names) 89 | class(ret) <- "model_unified" 90 | attr(ret, "missing_support") <- FALSE 91 | attr(ret, "model") <- "randomForest" 92 | return(set_reference_dataset(ret, as.data.frame(data))) 93 | } 94 | -------------------------------------------------------------------------------- /R/unify_ranger.R: -------------------------------------------------------------------------------- 1 | #' Unify ranger model 2 | #' 3 | #' Convert your ranger model into a standardized representation. 4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 5 | #' 6 | #' @param rf_model An object of \code{ranger} class. At the moment, models built on data with categorical features 7 | #' are not supported - please encode them before training. 8 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 9 | #' 10 | #' @return a unified model representation - a \code{\link{model_unified.object}} object 11 | #' 12 | #' @import data.table 13 | #' 14 | #' @export 15 | #' 16 | #' @seealso 17 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 18 | #' 19 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 20 | #' 21 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 22 | #' 23 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 24 | #' 25 | #' @examples 26 | #' 27 | #' library(ranger) 28 | #' data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 29 | #' c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 30 | #' 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 31 | #' data <- na.omit(cbind(data_fifa, target = fifa20$target)) 32 | #' 33 | #' rf <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 34 | #' unified_model <- ranger.unify(rf, data) 35 | #' shaps <- treeshap(unified_model, data[1:2,]) 36 | #' plot_contribution(shaps, obs = 1) 37 | ranger.unify <- function(rf_model, data) { 38 | if(!'ranger' %in% class(rf_model)) { 39 | stop('Object rf_model was not of class "ranger"') 40 | } 41 | n <- rf_model$num.trees 42 | x <- lapply(1:n, function(tree) { 43 | tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model, tree = tree)) 44 | tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", "splitval", "prediction")] 45 | }) 46 | return(ranger_unify.common(x = x, n = n, data = data, feature_names = rf_model$forest$independent.variable.names)) 47 | } 48 | 49 | 50 | ranger_unify.common <- function(x, n, data, feature_names) { 51 | times_vec <- sapply(x, nrow) 52 | y <- data.table::rbindlist(x) 53 | y[, ("Tree") := rep(0:(n - 1), times = times_vec)] 54 | data.table::setnames(y, c("Node", "Yes", "No", "Feature", "Split", "Prediction", "Tree")) 55 | y[, ("Feature") := as.character(get("Feature"))] 56 | y[y$Yes < 0, "Yes"] <- NA 57 
| y[y$No < 0, "No"] <- NA 58 | y[, ("Missing") := NA] 59 | y$Cover <- 0 60 | y$Decision.type <- factor(x = rep("<=", times = nrow(y)), levels = c("<=", "<")) 61 | y[is.na(get("Feature")), ("Decision.type") := NA] 62 | 63 | ID <- paste0(y$Node, "-", y$Tree) 64 | y$Yes <- match(paste0(y$Yes, "-", y$Tree), ID) 65 | y$No <- match(paste0(y$No, "-", y$Tree), ID) 66 | 67 | # Here we lose "Quality" information 68 | y[!is.na(get("Feature")), ("Prediction") := NA] 69 | 70 | # treeSHAP assumes, that [prediction = sum of predictions of the trees] 71 | # in random forest [prediction = mean of predictions of the trees] 72 | # so here we correct it by adjusting leaf prediction values 73 | y[is.na(get("Feature")), ("Prediction") := I(get("Prediction") / n)] 74 | 75 | 76 | data.table::setcolorder( 77 | y, c("Tree", "Node", "Feature", "Decision.type", "Split", 78 | "Yes", "No", "Missing", "Prediction", "Cover")) 79 | 80 | data <- data[,colnames(data) %in% feature_names] 81 | 82 | ret <- list(model = as.data.frame(y), data = as.data.frame(data), feature_names = feature_names) 83 | class(ret) <- "model_unified" 84 | attr(ret, "missing_support") <- FALSE 85 | attr(ret, "model") <- "ranger" 86 | return(set_reference_dataset(ret, as.data.frame(data))) 87 | } 88 | -------------------------------------------------------------------------------- /R/unify_ranger_surv.R: -------------------------------------------------------------------------------- 1 | #' Unify ranger survival model 2 | #' 3 | #' Convert your ranger model into a standardized representation. 4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 5 | #' 6 | #' @details 7 | #' The survival forest implemented in the \code{ranger} package stores cumulative hazard 8 | #' functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests 9 | #' (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs 10 | #' from all the trees. To provide explanations in the form of a survival function, 11 | #' the CHFs from the leaves are converted into survival functions (SFs) using 12 | #' the formula SF(t) = exp(-CHF(t)). 13 | #' However, it is important to note that averaging these SFs does not yield the correct 14 | #' model prediction as the model prediction is the average of CHFs transformed in the same way. 15 | #' Therefore, when you obtain explanations based on the survival function, 16 | #' they are only proxies and may not be fully consistent with the model predictions 17 | #' obtained using for example \code{predict} function. 18 | #' 19 | # 20 | #' @param rf_model An object of \code{ranger} class. At the moment, models built on data with categorical features 21 | #' are not supported - please encode them before training. 22 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 23 | #' @param type A character to define the type of model prediction to use. Either `"risk"` (default), which uses the risk score calculated as a sum of cumulative hazard function values, `"survival"`, which uses the survival probability at certain time-points for each observation, or `"chf"`, which used the cumulative hazard values at certain time-points for each observation. 24 | #' @param times A numeric vector of unique death times at which the prediction should be evaluated. By default `unique.death.times` from model are used. 
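#' If any of the given \code{times} does not match a unique death time exactly, it is internally mapped (via a step function) to the nearest earlier unique death time; values below the first death time are mapped to the first one.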
25 | #' 26 | #' @return For `type = "risk"` a unified model representation is returned - a \code{\link{model_unified.object}} object. For `type = "survival"` or `type = "chf"` - a \code{\link{model_unified_multioutput.object}} object is returned, which is a list that contains unified model representation (\code{\link{model_unified.object}} object) for each time point. In this case, the list names are time points at which the survival function was evaluated. 27 | #' 28 | #' @import data.table 29 | #' @importFrom stats stepfun 30 | #' 31 | #' @export 32 | #' 33 | #' @seealso 34 | #' \code{\link{ranger.unify}} for regression and classification \code{\link[ranger:ranger]{ranger models}} 35 | #' 36 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 37 | #' 38 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 39 | #' 40 | #' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 41 | #' 42 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 43 | #' 44 | #' @examples 45 | #' 46 | #' library(ranger) 47 | #' data_colon <- data.table::data.table(survival::colon) 48 | #' data_colon <- na.omit(data_colon[get("etype") == 2, ]) 49 | #' surv_cols <- c("status", "time", "rx") 50 | #' 51 | #' feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)] 52 | #' 53 | #' train_x <- model.matrix( 54 | #' ~ -1 + ., 55 | #' data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])] 56 | #' ) 57 | #' train_y <- survival::Surv( 58 | #' event = (data_colon[, get("status")] |> 59 | #' as.character() |> 60 | #' as.integer()), 61 | #' time = data_colon[, get("time")], 62 | #' type = "right" 63 | #' ) 64 | #' 65 | #' rf <- ranger::ranger( 66 | #' x = train_x, 67 | #' y = train_y, 68 | #' data = data_colon, 69 | #' max.depth = 10, 70 | #' num.trees = 10 71 | #' ) 72 | #' unified_model_risk <- ranger_surv.unify(rf, train_x, type = "risk") 73 | #' shaps <- treeshap(unified_model_risk, train_x[1:2,]) 74 | #' 75 | #' # compute shaps for 3 selected time points 76 | #' unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73)) 77 | #' shaps_surv <- treeshap(unified_model_surv, train_x[1:2,]) 78 | #' 79 | ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival", "chf"), times = NULL) { 80 | type <- match.arg(type) 81 | 82 | stopifnot( 83 | "`times` must be a numeric vector and argument \ 84 | `type = 'survival'` or `type = 'chf'` must be set." 
= 85 | ifelse(!is.null(times), is.numeric(times) && (type == "survival" || type == "chf"), TRUE) 86 | ) 87 | 88 | surv_common <- ranger_surv.common(rf_model, data) 89 | n <- surv_common$n 90 | chf_table_list <- surv_common$chf_table_list 91 | 92 | if (type == "risk") { 93 | 94 | x <- lapply(chf_table_list, function(tree) { 95 | tree_data <- tree$tree_data 96 | nodes_chf <- tree$table 97 | tree_data$prediction <- rowSums(nodes_chf) 98 | tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", 99 | "splitval", "prediction")] 100 | }) 101 | unified_return <- ranger_unify.common(x = x, n = n, data = data, feature_names = rf_model$forest$independent.variable.names) 102 | 103 | } else if (type == "survival" || type == "chf") { 104 | 105 | unique_death_times <- rf_model$unique.death.times 106 | 107 | if (is.null(times)) { 108 | compute_at_times <- unique_death_times 109 | # eval_times is required for list names (mainly when eval-times are 110 | # differing to the unique death times from the model as in the next case) 111 | eval_times <- as.character(compute_at_times) 112 | } else { 113 | stepfunction <- stepfun(unique_death_times, c(unique_death_times[1], unique_death_times)) 114 | compute_at_times <- stepfunction(times) 115 | eval_times <- as.character(times) 116 | } 117 | 118 | # iterate over time-points 119 | unified_return <- lapply(compute_at_times, function(t) { 120 | time_index <- which(unique_death_times == t) 121 | x <- lapply(chf_table_list, function(tree) { 122 | tree_data <- tree$tree_data 123 | nodes_chf <- tree$table[, time_index] 124 | 125 | # transform cumulative hazards to survival function (if needed) 126 | # H(t) = -ln(S(t)) 127 | # S(t) = exp(-H(t)) 128 | tree_data$prediction <- if(type == "survival") exp(-nodes_chf) else nodes_chf 129 | tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", 130 | "splitval", "prediction")] 131 | }) 132 | ranger_unify.common(x = x, n = n, data = data, feature_names = rf_model$forest$independent.variable.names) 133 | }) 134 | names(unified_return) <- eval_times 135 | class(unified_return) <- "model_unified_multioutput" 136 | } 137 | return(unified_return) 138 | } 139 | 140 | ranger_surv.common <- function(rf_model, data) { 141 | if (!"ranger" %in% class(rf_model)) { 142 | stop("Object rf_model was not of class \"ranger\"") 143 | } 144 | if (!rf_model$treetype == "Survival") { 145 | stop("Object rf_model is not a random survival forest.") 146 | } 147 | n <- rf_model$num.trees 148 | chf_table_list <- lapply(1:n, function(tree) { 149 | tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model, 150 | tree = tree)) 151 | 152 | # first get number of columns 153 | chf_node <- rf_model$forest$chf[[tree]] 154 | nodes_chf_n <- ncol(do.call(rbind, chf_node)) 155 | nodes_prepare_chf_list <- lapply( 156 | X = chf_node, 157 | FUN = function(node) { 158 | if (identical(node, numeric(0L))) { 159 | rep(NA, nodes_chf_n) 160 | } else { 161 | node 162 | } 163 | } 164 | ) 165 | list(table = do.call(rbind, nodes_prepare_chf_list), tree_data = tree_data) 166 | }) 167 | return(list(chf_table_list = chf_table_list, n = n)) 168 | } 169 | 170 | -------------------------------------------------------------------------------- /R/unify_xgboost.R: -------------------------------------------------------------------------------- 1 | #' Unify XGBoost model 2 | #' 3 | #' Convert your XGBoost model into a standardized representation. 
4 | #' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 5 | #' 6 | #' @param xgb_model A XGBoost model - object of class \code{xgb.Booster} 7 | #' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model. 8 | #' @param recalculate logical indicating if covers should be recalculated according to the dataset given in data. Keep it \code{FALSE} if training data are used. 9 | #' 10 | #' @return a unified model representation - a \code{\link{model_unified.object}} object 11 | #' 12 | #' @export 13 | #' 14 | #' @seealso 15 | #' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 16 | #' 17 | #' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 18 | #' 19 | #' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 20 | #' 21 | #' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 22 | #' 23 | #' @examples 24 | #' \donttest{ 25 | #' library(xgboost) 26 | #' data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 27 | #' target <- fifa20$target 28 | #' param <- list(objective = "reg:squarederror", max_depth = 3) 29 | #' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 30 | #' nrounds = 20, verbose = 0) 31 | #' unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 32 | #' shaps <- treeshap(unified_model, data[1:2,]) 33 | #' plot_contribution(shaps, obs = 1) 34 | #' } 35 | #' 36 | xgboost.unify <- function(xgb_model, data, recalculate = FALSE) { 37 | if (!requireNamespace("xgboost", quietly = TRUE)) { 38 | stop("Package \"xgboost\" needed for this function to work. Please install it.", 39 | call. 
= FALSE) 40 | } 41 | xgbtree <- xgboost::xgb.model.dt.tree(model = xgb_model) 42 | stopifnot(c("Tree", "Node", "ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover") %in% colnames(xgbtree)) 43 | xgbtree$Yes <- match(xgbtree$Yes, xgbtree$ID) 44 | xgbtree$No <- match(xgbtree$No, xgbtree$ID) 45 | xgbtree$Missing <- match(xgbtree$Missing, xgbtree$ID) 46 | xgbtree[is.na(xgbtree$Split), 'Feature'] <- NA 47 | xgbtree$Decision.type <- factor(x = rep("<=", times = nrow(xgbtree)), levels = c("<=", "<")) 48 | xgbtree$Decision.type[is.na(xgbtree$Feature)] <- NA 49 | xgbtree <- xgbtree[, c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Quality", "Cover")] 50 | colnames(xgbtree) <- c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover") 51 | 52 | # Here we lose "Quality" information 53 | xgbtree$Prediction[!is.na(xgbtree$Feature)] <- NA 54 | 55 | feature_names <- xgb_model$feature_names 56 | data <- data[,colnames(data) %in% feature_names] 57 | 58 | ret <- list(model = as.data.frame(xgbtree), data = as.data.frame(data), feature_names = feature_names) 59 | class(ret) <- "model_unified" 60 | attr(ret, "missing_support") <- TRUE 61 | attr(ret, "model") <- "xgboost" 62 | 63 | if (recalculate) { 64 | ret <- set_reference_dataset(ret, as.data.frame(data)) 65 | } 66 | 67 | return(ret) 68 | } 69 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | 15 | set.seed(21) 16 | ``` 17 | 18 | # treeshap 19 | 20 | 21 | [![R-CMD-check](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml) 22 | [![CRAN status](https://www.r-pkg.org/badges/version/treeshap)](https://CRAN.R-project.org/package=treeshap) 23 | 24 | 25 | In the era of complicated classifiers conquering their market, sometimes even the authors of algorithms do not know the exact manner of building a tree ensemble model. The difficulties in models' structures are one of the reasons why most users use them simply like black-boxes. But, how can they know whether the prediction made by the model is reasonable? `treeshap` is an efficient answer for this question. Due to implementing an optimized algorithm for tree ensemble models (called TreeSHAP), it calculates the SHAP values in polynomial (instead of exponential) time. Currently, `treeshap` supports models produced with `xgboost`, `lightgbm`, `gbm`, `ranger`, and `randomForest` packages. Support for `catboost` is available only in [`catboost` branch](https://github.com/ModelOriented/treeshap/tree/catboost) (see why [here](#catboost)). 
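In short, the whole workflow boils down to two calls: `unify()` to convert the fitted model and `treeshap()` to compute the explanations. A minimal sketch (using placeholder objects `model` and `train_data`; a complete, runnable example is given in the Example section below):

``` r
library(treeshap)

# convert a supported tree ensemble (e.g. an xgb.Booster or ranger object)
# into the unified representation, with the training data as reference
unified <- unify(model, train_data)

# compute SHAP values for the observations to be explained
shaps <- treeshap(unified, train_data[1:3, ])

shaps$shaps                        # SHAP values as a data.frame
plot_contribution(shaps, obs = 1)  # break-down style plot for one observation
```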
26 | 27 | ## Installation 28 | 29 | The package is available on CRAN: 30 | ``` r 31 | install.packages('treeshap') 32 | ``` 33 | 34 | You can install the latest development version from GitHub using `devtools` with: 35 | ``` r 36 | devtools::install_github('ModelOriented/treeshap') 37 | ``` 38 | 39 | ## Example 40 | 41 | First of all, let's focus on an example how to represent a `xgboost` model as a unified model object: 42 | 43 | ```{r unifier-example, warning=FALSE, message=FALSE} 44 | library(treeshap) 45 | library(xgboost) 46 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 47 | target <- fifa20$target 48 | param <- list(objective = "reg:squarederror", max_depth = 6) 49 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200, verbose = 0) 50 | unified <- unify(xgb_model, data) 51 | head(unified$model) 52 | ``` 53 | 54 | Having the object of unified structure, it is a piece of cake to produce SHAP values for a specific observation. The `treeshap()` function requires passing two data arguments: one representing an ensemble model unified representation and one with the observations about which we want to get the explanations. Obviously, the latter one should contain the same columns as data used during building the model. 55 | 56 | ```{r treeshap-example} 57 | treeshap1 <- treeshap(unified, data[700:800, ], verbose = 0) 58 | treeshap1$shaps[1:3, 1:6] 59 | ``` 60 | 61 | We can also compute SHAP values for interactions. As an example we will calculate them for a model built with simpler (only 5 columns) data and first 100 observations. 62 | 63 | ```{r interactions-example} 64 | data2 <- fifa20$data[, 1:5] 65 | xgb_model2 <- xgboost::xgboost(as.matrix(data2), params = param, label = target, nrounds = 200, verbose = 0) 66 | unified2 <- unify(xgb_model2, data2) 67 | 68 | treeshap_interactions <- treeshap(unified2, data2[1:100, ], interactions = TRUE, verbose = 0) 69 | treeshap_interactions$interactions[, , 1:2] 70 | ``` 71 | 72 | ## Plotting results 73 | 74 | The explanation results can be visualized using [`shapviz`](https://github.com/ModelOriented/shapviz/) package, see [here](https://modeloriented.github.io/shapviz/articles/basic_use.html#treeshap). 75 | 76 | However, `treeshap` also provides 4 plotting functions: 77 | 78 | ### Feature Contribution (Break-Down) 79 | 80 | On this plot we can see how features contribute into the prediction for a single observation. It is similar to the Break Down plot from [iBreakDown](https://github.com/ModelOriented/iBreakDown) package, which uses different method to approximate SHAP values. 81 | 82 | ```{r plot_contribution_example} 83 | plot_contribution(treeshap1, obs = 1, min_max = c(0, 16000000)) 84 | ``` 85 | 86 | ### Feature Importance 87 | 88 | This plot shows us average absolute impact of features on the prediction of the model. 89 | 90 | ```{r plot_importance_example} 91 | plot_feature_importance(treeshap1, max_vars = 6) 92 | ``` 93 | 94 | ### Feature Dependence 95 | 96 | Using this plot we can see, how a single feature contributes into the prediction depending on its value. 97 | 98 | ```{r plot_dependence_example} 99 | plot_feature_dependence(treeshap1, "height_cm") 100 | ``` 101 | 102 | ### Interaction Plot 103 | 104 | Simple plot to visualize an SHAP Interaction value of two features depending on their values. 105 | 106 | ```{r plot_interaction} 107 | plot_interaction(treeshap_interactions, "height_cm", "overall") 108 | ``` 109 | 110 | ## How to use the unifying functions? 
111 | 112 | For your convenience, you can now simply use the `unify()` function by specifying your model and reference dataset. Behind the scenes, it uses one of the six functions from the `.unify()` family (`xgboost.unify()`, `lightgbm.unify()`, `gbm.unify()`, `catboost.unify()`, `randomForest.unify()`, `ranger.unify()`). Even though the objects produced by these functions are identical when it comes to the structure, due to different possibilities of saving and representing the trees among the packages, the usage of these model-specific functions may be slightly different. Therefore, you can use them independently or pass some additional parameters to `unify()`. 113 | 114 | ```{r gbm, eval=FALSE} 115 | library(treeshap) 116 | library(gbm) 117 | x <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 118 | x['value_eur'] <- fifa20$target 119 | gbm_model <- gbm::gbm( 120 | formula = value_eur ~ ., 121 | data = x, 122 | distribution = "laplace", 123 | n.trees = 200, 124 | cv.folds = 2, 125 | interaction.depth = 2 126 | ) 127 | unified_gbm <- unify(gbm_model, x) 128 | unified_gbm2 <- gbm.unify(gbm_model, x) # legacy API 129 | ``` 130 | 131 | 132 | ## Setting reference dataset 133 | 134 | Dataset used as a reference for calculating SHAP values is stored in unified model representation object. It can be set any time using `set_reference_dataset()` function. 135 | 136 | ```{r set_reference_dataset, eval=FALSE} 137 | library(treeshap) 138 | library(ranger) 139 | data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 140 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 141 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 142 | data <- na.omit(cbind(data_fifa, target = fifa20$target)) 143 | rf <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 144 | 145 | unified_ranger_model <- unify(rf, data) 146 | unified_ranger_model2 <- set_reference_dataset(unified_ranger_model, data[c(1000:2000), ]) 147 | ``` 148 | 149 | ## Other functionalities 150 | 151 | Package also implements `predict()` function for calculating model's predictions using unified representation. 152 | 153 | ## How fast does it work? 154 | 155 | The complexity of TreeSHAP is $\mathcal{O}(TLD^2)$, where $T$ is the number of trees, $L$ is the number of leaves in a tree, and $D$ is the depth of a tree. 156 | 157 | Our implementation works at a speed comparable to the original Lundberg's Python package `shap` implementation using C and Python. 158 | 159 | The complexity of SHAP interaction values computation is $\mathcal{O}(MTLD^2)$, where $M$ is the number of explanatory variables used by the explained model, $T$ is the number of trees, $L$ is the number of leaves in a tree, and $D$ is the depth of a tree. 160 | 161 | ## CatBoost 162 | Originally, `treeshap` also supported the CatBoost models from the `catboost` package but due to the lack of this package on CRAN or R-universe (see `catboost` issues issues [#439](https://github.com/catboost/catboost/issues/439), [#1846](https://github.com/catboost/catboost/issues/1846)), we decided to remove support from the main version of our package. 163 | 164 | However, you can still use the `treeshap` implementation for `catboost` by installing our package from [`catboost` branch](https://github.com/ModelOriented/treeshap/tree/catboost). 165 | 166 | This branch can be installed with: 167 | 168 | ``` r 169 | devtools::install_github('ModelOriented/treeshap@catboost') 170 | ``` 171 | 172 | ## References 173 | - Lundberg, S.M., Erion, G., Chen, H. et al. 
"From local explanations to global understanding with explainable AI for trees", Nature Machine Intelligence 2, 56–67 (2020). 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # treeshap 5 | 6 | 7 | 8 | [![R-CMD-check](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml) 9 | [![CRAN 10 | status](https://www.r-pkg.org/badges/version/treeshap)](https://CRAN.R-project.org/package=treeshap) 11 | 12 | 13 | In the era of complicated classifiers conquering their market, sometimes 14 | even the authors of algorithms do not know the exact manner of building 15 | a tree ensemble model. The difficulties in models’ structures are one of 16 | the reasons why most users use them simply like black-boxes. But, how 17 | can they know whether the prediction made by the model is reasonable? 18 | `treeshap` is an efficient answer for this question. Due to implementing 19 | an optimized algorithm for tree ensemble models (called TreeSHAP), it 20 | calculates the SHAP values in polynomial (instead of exponential) time. 21 | Currently, `treeshap` supports models produced with `xgboost`, 22 | `lightgbm`, `gbm`, `ranger`, and `randomForest` packages. Support for 23 | `catboost` is available only in [`catboost` 24 | branch](https://github.com/ModelOriented/treeshap/tree/catboost) (see 25 | why [here](#catboost)). 26 | 27 | ## Installation 28 | 29 | The package is available on CRAN: 30 | 31 | ``` r 32 | install.packages('treeshap') 33 | ``` 34 | 35 | You can install the latest development version from GitHub using 36 | `devtools` with: 37 | 38 | ``` r 39 | devtools::install_github('ModelOriented/treeshap') 40 | ``` 41 | 42 | ## Example 43 | 44 | First of all, let’s focus on an example how to represent a `xgboost` 45 | model as a unified model object: 46 | 47 | ``` r 48 | library(treeshap) 49 | library(xgboost) 50 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 51 | target <- fifa20$target 52 | param <- list(objective = "reg:squarederror", max_depth = 6) 53 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200, verbose = 0) 54 | unified <- unify(xgb_model, data) 55 | head(unified$model) 56 | #> Tree Node Feature Decision.type Split Yes No Missing Prediction Cover 57 | #> 1 0 0 overall <= 81.5 2 3 2 NA 18278 58 | #> 2 0 1 overall <= 73.5 4 5 4 NA 17949 59 | #> 3 0 2 overall <= 84.5 6 7 6 NA 329 60 | #> 4 0 3 overall <= 69.5 8 9 8 NA 15628 61 | #> 5 0 4 potential <= 79.5 10 11 10 NA 2321 62 | #> 6 0 5 potential <= 83.5 12 13 12 NA 221 63 | ``` 64 | 65 | Having the object of unified structure, it is a piece of cake to produce 66 | SHAP values for a specific observation. The `treeshap()` function 67 | requires passing two data arguments: one representing an ensemble model 68 | unified representation and one with the observations about which we want 69 | to get the explanations. Obviously, the latter one should contain the 70 | same columns as data used during building the model. 
71 | 72 | ``` r 73 | treeshap1 <- treeshap(unified, data[700:800, ], verbose = 0) 74 | treeshap1$shaps[1:3, 1:6] 75 | #> age height_cm weight_kg overall potential international_reputation 76 | #> 700 297154.4 5769.186 12136.316 8739757 212428.8 -50855.738 77 | #> 701 -2550066.6 16011.136 3134.526 6525123 244814.2 22784.430 78 | #> 702 300830.3 -9023.299 15374.550 8585145 479118.8 2374.351 79 | ``` 80 | 81 | We can also compute SHAP values for interactions. As an example we will 82 | calculate them for a model built with simpler (only 5 columns) data and 83 | first 100 observations. 84 | 85 | ``` r 86 | data2 <- fifa20$data[, 1:5] 87 | xgb_model2 <- xgboost::xgboost(as.matrix(data2), params = param, label = target, nrounds = 200, verbose = 0) 88 | unified2 <- unify(xgb_model2, data2) 89 | 90 | treeshap_interactions <- treeshap(unified2, data2[1:100, ], interactions = TRUE, verbose = 0) 91 | treeshap_interactions$interactions[, , 1:2] 92 | #> , , 1 93 | #> 94 | #> age height_cm weight_kg overall potential 95 | #> age -1886241.70 -3984.09 -96765.97 -47245.92 1034657.6 96 | #> height_cm -3984.09 -628797.41 -35476.11 1871689.75 685472.2 97 | #> weight_kg -96765.97 -35476.11 -983162.25 2546930.16 1559453.5 98 | #> overall -47245.92 1871689.75 2546930.16 55289985.16 12683135.3 99 | #> potential 1034657.61 685472.23 1559453.46 12683135.27 868268.7 100 | #> 101 | #> , , 2 102 | #> 103 | #> age height_cm weight_kg overall potential 104 | #> age -2349987.9 306165.41 120483.91 -9871270.0 960198.02 105 | #> height_cm 306165.4 -78810.31 -48271.61 -991020.7 -44632.74 106 | #> weight_kg 120483.9 -48271.61 -21657.14 -615688.2 -380810.70 107 | #> overall -9871270.0 -991020.68 -615688.21 57384425.2 9603937.05 108 | #> potential 960198.0 -44632.74 -380810.70 9603937.1 2994190.74 109 | ``` 110 | 111 | ## Plotting results 112 | 113 | The explanation results can be visualized using 114 | [`shapviz`](https://github.com/ModelOriented/shapviz/) package, see 115 | [here](https://modeloriented.github.io/shapviz/articles/basic_use.html#treeshap). 116 | 117 | However, `treeshap` also provides 4 plotting functions: 118 | 119 | ### Feature Contribution (Break-Down) 120 | 121 | On this plot we can see how features contribute into the prediction for 122 | a single observation. It is similar to the Break Down plot from 123 | [iBreakDown](https://github.com/ModelOriented/iBreakDown) package, which 124 | uses different method to approximate SHAP values. 125 | 126 | ``` r 127 | plot_contribution(treeshap1, obs = 1, min_max = c(0, 16000000)) 128 | ``` 129 | 130 | 131 | 132 | ### Feature Importance 133 | 134 | This plot shows us average absolute impact of features on the prediction 135 | of the model. 136 | 137 | ``` r 138 | plot_feature_importance(treeshap1, max_vars = 6) 139 | ``` 140 | 141 | 142 | 143 | ### Feature Dependence 144 | 145 | Using this plot we can see, how a single feature contributes into the 146 | prediction depending on its value. 147 | 148 | ``` r 149 | plot_feature_dependence(treeshap1, "height_cm") 150 | ``` 151 | 152 | 153 | 154 | ### Interaction Plot 155 | 156 | Simple plot to visualize an SHAP Interaction value of two features 157 | depending on their values. 158 | 159 | ``` r 160 | plot_interaction(treeshap_interactions, "height_cm", "overall") 161 | ``` 162 | 163 | 164 | 165 | ## How to use the unifying functions? 166 | 167 | For your convenience, you can now simply use the `unify()` function by 168 | specifying your model and reference dataset. 
Behind the scenes, it uses 169 | one of the six functions from the `.unify()` family (`xgboost.unify()`, 170 | `lightgbm.unify()`, `gbm.unify()`, `catboost.unify()`, 171 | `randomForest.unify()`, `ranger.unify()`). Even though the objects 172 | produced by these functions are identical when it comes to the 173 | structure, due to different possibilities of saving and representing the 174 | trees among the packages, the usage of these model-specific functions 175 | may be slightly different. Therefore, you can use them independently or 176 | pass some additional parameters to `unify()`. 177 | 178 | ``` r 179 | library(treeshap) 180 | library(gbm) 181 | x <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 182 | x['value_eur'] <- fifa20$target 183 | gbm_model <- gbm::gbm( 184 | formula = value_eur ~ ., 185 | data = x, 186 | distribution = "laplace", 187 | n.trees = 200, 188 | cv.folds = 2, 189 | interaction.depth = 2 190 | ) 191 | unified_gbm <- unify(gbm_model, x) 192 | unified_gbm2 <- gbm.unify(gbm_model, x) # legacy API 193 | ``` 194 | 195 | ## Setting reference dataset 196 | 197 | Dataset used as a reference for calculating SHAP values is stored in 198 | unified model representation object. It can be set any time using 199 | `set_reference_dataset()` function. 200 | 201 | ``` r 202 | library(treeshap) 203 | library(ranger) 204 | data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 205 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 206 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 207 | data <- na.omit(cbind(data_fifa, target = fifa20$target)) 208 | rf <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 209 | 210 | unified_ranger_model <- unify(rf, data) 211 | unified_ranger_model2 <- set_reference_dataset(unified_ranger_model, data[c(1000:2000), ]) 212 | ``` 213 | 214 | ## Other functionalities 215 | 216 | Package also implements `predict()` function for calculating model’s 217 | predictions using unified representation. 218 | 219 | ## How fast does it work? 220 | 221 | The complexity of TreeSHAP is $\mathcal{O}(TLD^2)$, where $T$ is the 222 | number of trees, $L$ is the number of leaves in a tree, and $D$ is the 223 | depth of a tree. 224 | 225 | Our implementation works at a speed comparable to the original 226 | Lundberg’s Python package `shap` implementation using C and Python. 227 | 228 | The complexity of SHAP interaction values computation is 229 | $\mathcal{O}(MTLD^2)$, where $M$ is the number of explanatory variables 230 | used by the explained model, $T$ is the number of trees, $L$ is the 231 | number of leaves in a tree, and $D$ is the depth of a tree. 232 | 233 | ## CatBoost 234 | 235 | Originally, `treeshap` also supported the CatBoost models from the 236 | `catboost` package but due to the lack of this package on CRAN or 237 | R-universe (see `catboost` issues issues 238 | [\#439](https://github.com/catboost/catboost/issues/439), 239 | [\#1846](https://github.com/catboost/catboost/issues/1846)), we decided 240 | to remove support from the main version of our package. 241 | 242 | However, you can still use the `treeshap` implementation for `catboost` 243 | by installing our package from [`catboost` 244 | branch](https://github.com/ModelOriented/treeshap/tree/catboost). 245 | 246 | This branch can be installed with: 247 | 248 | ``` r 249 | devtools::install_github('ModelOriented/treeshap@catboost') 250 | ``` 251 | 252 | ## References 253 | 254 | - Lundberg, S.M., Erion, G., Chen, H. et al. 
“From local explanations to 255 | global understanding with explainable AI for trees”, Nature Machine 256 | Intelligence 2, 56–67 (2020). 257 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | template: 2 | package: DrWhyTemplate 3 | default_assets: false 4 | reference: 5 | - title: TreeSHAP 6 | desc: Calculate SHAP values for your model 7 | - contents: 8 | - treeshap 9 | - treeshap.object 10 | - treeshap_multioutput.object 11 | - title: Unifiers 12 | desc: Convert your model into a standardized representation 13 | - contents: 14 | - unify 15 | - ends_with(".unify") 16 | - model_unified.object 17 | - model_unified_multioutput.object 18 | - title: Plotting functions 19 | desc: Plot explanation results 20 | - contents: 21 | - starts_with("plot") 22 | - colors_discrete_drwhy 23 | - colors_breakdown_drwhy 24 | - theme_drwhy 25 | - theme_drwhy_vertical 26 | - title: Printing functions 27 | - contents: 28 | - starts_with("print") 29 | - title: Utility functions 30 | - contents: 31 | - set_reference_dataset 32 | - predict.model_unified 33 | - is.model_unified 34 | - is.treeshap 35 | - title: Data 36 | - contents: fifa20 37 | 38 | 39 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## R CMD check results 2 | 3 | 0 errors | 0 warnings | 1 note 4 | 5 | Note about the package being archived on CRAN as issues were not corrected 6 | in time. There were errors in examples related to changed syntax in 'lightgbm' package. I have resolved these issues in the current version. 7 | 8 | ## revdepcheck results 9 | 10 | We checked 1 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 
11 | 12 | * We saw 0 new problems 13 | * We failed to check 0 packages 14 | -------------------------------------------------------------------------------- /data-raw/fifa20.R: -------------------------------------------------------------------------------- 1 | options(stringsAsFactors = FALSE) 2 | fifa20_raw <- read.csv('~/Documents/players_20.csv') 3 | fifa20_num <- fifa20_raw[,sapply(fifa20_raw, is.numeric)] 4 | fifa20_all <- fifa20_num[!(colnames(fifa20_num) %in% c('sofifa_id', 5 | 'wage_eur', 'release_clause_eur', 'team_jersey_number', 'contract_valid_until', 'nation_jersey_number'))] 6 | fifa20_all[['work_rate']] <- as.factor(fifa20_raw[['work_rate']]) 7 | fifa_target <- fifa20_all[['value_eur']] 8 | fifa20_data <- fifa20_all[colnames(fifa20_all) != 'value_eur'] 9 | fifa20 <- list(data = fifa20_data, target = fifa_target) 10 | 11 | 12 | usethis::use_data(fifa20, overwrite = TRUE) 13 | -------------------------------------------------------------------------------- /data/fifa20.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/data/fifa20.rda -------------------------------------------------------------------------------- /man/colors_drwhy.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/theme_drwhy.R 3 | \name{colors_discrete_drwhy} 4 | \alias{colors_discrete_drwhy} 5 | \alias{colors_breakdown_drwhy} 6 | \title{DrWhy color palettes for ggplot objects} 7 | \usage{ 8 | colors_discrete_drwhy(n = 2) 9 | 10 | colors_breakdown_drwhy() 11 | } 12 | \arguments{ 13 | \item{n}{number of colors for color palette} 14 | } 15 | \value{ 16 | color palette as vector of characters 17 | } 18 | \description{ 19 | DrWhy color palettes for ggplot objects 20 | } 21 | -------------------------------------------------------------------------------- /man/fifa20.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/fifa20.R 3 | \docType{data} 4 | \name{fifa20} 5 | \alias{fifa20} 6 | \title{Attributes of all players in FIFA 20} 7 | \format{ 8 | A data frame with 18278 rows and 56 columns. 9 | Most of variables representing skills are in range from 0 to 100 and will not be described here. 10 | To list non obvious features: 11 | \describe{ 12 | \item{overall}{Overall score of player's skills} 13 | \item{potential}{Potential of a player, younger players tend to have higher level of potential} 14 | \item{value_eur}{Market value of a player (in mln EUR)} 15 | \item{international_reputation}{Range 1 to 5} 16 | \item{weak_foot}{Range 1 to 5} 17 | \item{skill_moves}{Range 1 to 5} 18 | \item{work_rate}{Divided by slash levels of willingness to work in offense and defense respectively} 19 | } 20 | } 21 | \source{ 22 | "Data has been scraped from the publicly available website \url{https://sofifa.com}" 23 | \url{https://www.kaggle.com/stefanoleone992/fifa-20-complete-player-dataset} 24 | } 25 | \usage{ 26 | fifa20 27 | } 28 | \description{ 29 | Dataset consists of 56 columns, 55 numeric and one of type factor \code{'work_rate'}. 30 | \code{value_eur} is a potential target feature. 
31 | } 32 | \keyword{datasets} 33 | -------------------------------------------------------------------------------- /man/figures/README-plot_contribution_example-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/man/figures/README-plot_contribution_example-1.png -------------------------------------------------------------------------------- /man/figures/README-plot_dependence_example-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/man/figures/README-plot_dependence_example-1.png -------------------------------------------------------------------------------- /man/figures/README-plot_importance_example-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/man/figures/README-plot_importance_example-1.png -------------------------------------------------------------------------------- /man/figures/README-plot_interaction-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/man/figures/README-plot_interaction-1.png -------------------------------------------------------------------------------- /man/gbm.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_gbm.R 3 | \name{gbm.unify} 4 | \alias{gbm.unify} 5 | \title{Unify GBM model} 6 | \usage{ 7 | gbm.unify(gbm_model, data) 8 | } 9 | \arguments{ 10 | \item{gbm_model}{An object of \code{gbm} class. At the moment, models built on data with categorical features 11 | are not supported - please encode them before training.} 12 | 13 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 14 | } 15 | \value{ 16 | a unified model representation - a \code{\link{model_unified.object}} object 17 | } 18 | \description{ 19 | Convert your GBM model into a standardized representation. 20 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
21 | } 22 | \examples{ 23 | \donttest{ 24 | library(gbm) 25 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 26 | data['value_eur'] <- fifa20$target 27 | gbm_model <- gbm::gbm( 28 | formula = value_eur ~ ., 29 | data = data, 30 | distribution = "gaussian", 31 | n.trees = 20, 32 | interaction.depth = 4, 33 | n.cores = 1) 34 | unified_model <- gbm.unify(gbm_model, data) 35 | shaps <- treeshap(unified_model, data[1:2,]) 36 | plot_contribution(shaps, obs = 1) 37 | } 38 | } 39 | \seealso{ 40 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 41 | 42 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 43 | 44 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 45 | 46 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 47 | } 48 | -------------------------------------------------------------------------------- /man/is.model_unified.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_unified.R 3 | \name{is.model_unified} 4 | \alias{is.model_unified} 5 | \title{Check whether object is a valid model_unified object} 6 | \usage{ 7 | is.model_unified(x) 8 | } 9 | \arguments{ 10 | \item{x}{an object to check} 11 | } 12 | \value{ 13 | boolean 14 | } 15 | \description{ 16 | Does not check correctness of representation, only basic checks 17 | } 18 | -------------------------------------------------------------------------------- /man/is.treeshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{is.treeshap} 4 | \alias{is.treeshap} 5 | \title{Check whether object is a valid treeshap object} 6 | \usage{ 7 | is.treeshap(x) 8 | } 9 | \arguments{ 10 | \item{x}{an object to check} 11 | } 12 | \value{ 13 | boolean 14 | } 15 | \description{ 16 | Does not check correctness of result, only basic checks 17 | } 18 | -------------------------------------------------------------------------------- /man/lightgbm.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_lightgbm.R 3 | \name{lightgbm.unify} 4 | \alias{lightgbm.unify} 5 | \title{Unify LightGBM model} 6 | \usage{ 7 | lightgbm.unify(lgb_model, data, recalculate = FALSE) 8 | } 9 | \arguments{ 10 | \item{lgb_model}{A lightgbm model - object of class \code{lgb.Booster}} 11 | 12 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 13 | 14 | \item{recalculate}{logical indicating if covers should be recalculated according to the dataset given in data. Keep it \code{FALSE} if training data are used.} 15 | } 16 | \value{ 17 | a unified model representation - a \code{\link{model_unified.object}} object 18 | } 19 | \description{ 20 | Convert your LightGBM model into a standardized representation. 21 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
22 | } 23 | \examples{ 24 | \donttest{ 25 | library(lightgbm) 26 | param_lgbm <- list(objective = "regression", max_depth = 2, 27 | force_row_wise = TRUE, num_iterations = 20) 28 | data_fifa <- fifa20$data[!colnames(fifa20$data) \%in\% 29 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 30 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 31 | data <- na.omit(cbind(data_fifa, fifa20$target)) 32 | sparse_data <- as.matrix(data[,-ncol(data)]) 33 | x <- lightgbm::lgb.Dataset(sparse_data, label = as.matrix(data[,ncol(data)])) 34 | lgb_data <- lightgbm::lgb.Dataset.construct(x) 35 | lgb_model <- lightgbm::lightgbm(data = lgb_data, params = param_lgbm, 36 | verbose = -1, num_threads = 0) 37 | unified_model <- lightgbm.unify(lgb_model, sparse_data) 38 | shaps <- treeshap(unified_model, data[1:2, ]) 39 | plot_contribution(shaps, obs = 1) 40 | } 41 | } 42 | \seealso{ 43 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 44 | 45 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 46 | 47 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 48 | 49 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 50 | } 51 | -------------------------------------------------------------------------------- /man/model_unified.object.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_unified.R 3 | \name{model_unified.object} 4 | \alias{model_unified.object} 5 | \title{Unified model representation} 6 | \value{ 7 | List consisting of two elements: 8 | 9 | \strong{model} - A \code{data.frame} representing model with following columns: 10 | 11 | \item{Tree}{0-indexed ID of a tree} 12 | \item{Node}{0-indexed ID of a node in a tree. In a tree the root always has ID 0} 13 | \item{Feature}{In case of an internal node - name of a feature to split on. Otherwise - NA} 14 | \item{Decision.type}{A factor with two levels: "<" and "<=". In case of an internal node - predicate used for splitting observations. Otherwise - NA} 15 | \item{Split}{For internal nodes threshold used for splitting observations. All observations that satisfy the predicate Decision.type(Split) ('< Split' / '<= Split') are proceeded to the node marked as 'Yes'. Otherwise to the 'No' node. For leaves - NA} 16 | \item{Yes}{Index of a row containing a child Node. Thanks to explicit indicating the row it is much faster to move between nodes} 17 | \item{No}{Index of a row containing a child Node} 18 | \item{Missing}{Index of a row containing a child Node where are proceeded all observations with no value of the dividing feature} 19 | \item{Prediction}{For leaves: Value of prediction in the leaf. For internal nodes: NA} 20 | \item{Cover}{Number of observations seen by the internal node or collected by the leaf for the reference dataset} 21 | 22 | \strong{data} - Dataset used as a reference for calculating SHAP values. A dataset passed to the \code{*.unify}, \code{unify} or \code{\link{set_reference_dataset}} function with \code{data} argument. A \code{data.frame}. 23 | 24 | Object has two also attributes set: 25 | \item{\code{model}}{A string. By what package the model was produced.} 26 | \item{\code{missing_support}}{A boolean. 
Whether the model allows missing values to be present in explained dataset.} 27 | } 28 | \description{ 29 | \code{model_unified} object produced by \code{*.unify} or \code{unify} function. 30 | } 31 | \seealso{ 32 | \code{\link{unify}} 33 | } 34 | -------------------------------------------------------------------------------- /man/model_unified_multioutput.object.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_unified.R 3 | \name{model_unified_multioutput.object} 4 | \alias{model_unified_multioutput.object} 5 | \title{Unified model representations for multi-output model} 6 | \value{ 7 | List consisting of \code{model_unified} objects, one for each individual output of a model. For survival models, the list is named using the time points, for which predictions are calculated. 8 | } 9 | \description{ 10 | \code{model_unified_multioutput} object produced by \code{*.unify} or \code{unify} function. 11 | } 12 | \seealso{ 13 | \code{\link{unify}} 14 | } 15 | -------------------------------------------------------------------------------- /man/plot_contribution.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot_contribution.R 3 | \name{plot_contribution} 4 | \alias{plot_contribution} 5 | \title{SHAP value based Break-Down plot} 6 | \usage{ 7 | plot_contribution( 8 | treeshap, 9 | obs = 1, 10 | max_vars = 5, 11 | min_max = NA, 12 | digits = 3, 13 | explain_deviation = FALSE, 14 | title = "SHAP Break-Down", 15 | subtitle = "" 16 | ) 17 | } 18 | \arguments{ 19 | \item{treeshap}{A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}.} 20 | 21 | \item{obs}{A numeric indicating which observation should be plotted. Be default it's first observation.} 22 | 23 | \item{max_vars}{maximum number of variables that shall be presented. Variables with the highest importance will be presented. 24 | Remaining variables will be summed into one additional contribution. By default \code{5}.} 25 | 26 | \item{min_max}{a range of OX axis. By default \code{NA}, therefore it will be extracted from the contributions of \code{x}. 27 | But it can be set to some constants, useful if these plots are to be used for comparisons.} 28 | 29 | \item{digits}{number of decimal places (\code{\link{round}}) to be used.} 30 | 31 | \item{explain_deviation}{if \code{TRUE} then instead of explaining prediction and plotting intercept bar, only deviation from mean prediction of the reference dataset will be explained. By default \code{FALSE}.} 32 | 33 | \item{title}{the plot's title, by default \code{'SHAP Break-Down'}.} 34 | 35 | \item{subtitle}{the plot's subtitle. By default no subtitle.} 36 | } 37 | \value{ 38 | a \code{ggplot2} object 39 | } 40 | \description{ 41 | This function plots contributions of features into the prediction for a single observation. 
42 | } 43 | \examples{ 44 | \donttest{ 45 | library(xgboost) 46 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 47 | target <- fifa20$target 48 | param <- list(objective = "reg:squarederror", max_depth = 3) 49 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 50 | nrounds = 20, verbose = FALSE) 51 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 52 | x <- head(data, 1) 53 | shap <- treeshap(unified_model, x) 54 | plot_contribution(shap, 1, min_max = c(0, 120000000)) 55 | } 56 | } 57 | \seealso{ 58 | \code{\link{treeshap}} for calculation of SHAP values 59 | 60 | \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 61 | } 62 | -------------------------------------------------------------------------------- /man/plot_feature_dependence.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot_feature_dependence.R 3 | \name{plot_feature_dependence} 4 | \alias{plot_feature_dependence} 5 | \title{SHAP value based Feature Dependence plot} 6 | \usage{ 7 | plot_feature_dependence( 8 | treeshap, 9 | variable, 10 | title = "Feature Dependence", 11 | subtitle = NULL 12 | ) 13 | } 14 | \arguments{ 15 | \item{treeshap}{A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}.} 16 | 17 | \item{variable}{name or index of variable for which feature dependence will be plotted.} 18 | 19 | \item{title}{the plot's title, by default \code{'Feature Dependence'}.} 20 | 21 | \item{subtitle}{the plot's subtitle. By default no subtitle.} 22 | } 23 | \value{ 24 | a \code{ggplot2} object 25 | } 26 | \description{ 27 | Depending on the value of a variable: how does it contribute into the prediction? 28 | } 29 | \examples{ 30 | \donttest{ 31 | library(xgboost) 32 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 33 | target <- fifa20$target 34 | param <- list(objective = "reg:squarederror", max_depth = 3) 35 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 36 | nrounds = 20, verbose = FALSE) 37 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 38 | x <- head(data, 100) 39 | shaps <- treeshap(unified_model, x) 40 | plot_feature_dependence(shaps, variable = "overall") 41 | } 42 | } 43 | \seealso{ 44 | \code{\link{treeshap}} for calculation of SHAP values 45 | 46 | \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_interaction}} 47 | } 48 | -------------------------------------------------------------------------------- /man/plot_feature_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot_feature_importance.R 3 | \name{plot_feature_importance} 4 | \alias{plot_feature_importance} 5 | \title{SHAP value based Feature Importance plot} 6 | \usage{ 7 | plot_feature_importance( 8 | treeshap, 9 | desc_sorting = TRUE, 10 | max_vars = ncol(shaps), 11 | title = "Feature Importance", 12 | subtitle = NULL 13 | ) 14 | } 15 | \arguments{ 16 | \item{treeshap}{A treeshap object produced with the \code{\link{treeshap}} function. \code{\link{treeshap.object}}.} 17 | 18 | \item{desc_sorting}{logical. Should the bars be sorted descending? By default TRUE.} 19 | 20 | \item{max_vars}{maximum number of variables that shall be presented. 
By default all are presented.} 21 | 22 | \item{title}{the plot's title, by default \code{'Feature Importance'}.} 23 | 24 | \item{subtitle}{the plot's subtitle. By default no subtitle.} 25 | } 26 | \value{ 27 | a \code{ggplot2} object 28 | } 29 | \description{ 30 | This function plots feature importance calculated as means of absolute values of SHAP values of variables (average impact on model output magnitude). 31 | } 32 | \examples{ 33 | \donttest{ 34 | library(xgboost) 35 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 36 | target <- fifa20$target 37 | param <- list(objective = "reg:squarederror", max_depth = 3) 38 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 39 | nrounds = 20, verbose = FALSE) 40 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 41 | shaps <- treeshap(unified_model, as.matrix(head(data, 3))) 42 | plot_feature_importance(shaps, max_vars = 4) 43 | } 44 | } 45 | \seealso{ 46 | \code{\link{treeshap}} for calculation of SHAP values 47 | 48 | \code{\link{plot_contribution}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 49 | } 50 | -------------------------------------------------------------------------------- /man/plot_interaction.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot_interaction.R 3 | \name{plot_interaction} 4 | \alias{plot_interaction} 5 | \title{SHAP Interaction value plot} 6 | \usage{ 7 | plot_interaction( 8 | treeshap, 9 | var1, 10 | var2, 11 | title = "SHAP Interaction Value Plot", 12 | subtitle = "" 13 | ) 14 | } 15 | \arguments{ 16 | \item{treeshap}{A treeshap object produced with \code{\link{treeshap}(interactions = TRUE)} function. \code{\link{treeshap.object}}.} 17 | 18 | \item{var1}{name or index of the first variable - plotted on x axis.} 19 | 20 | \item{var2}{name or index of the second variable - marked with color.} 21 | 22 | \item{title}{the plot's title, by default \code{'SHAP Interaction Value Plot'}.} 23 | 24 | \item{subtitle}{the plot's subtitle. By default no subtitle.} 25 | } 26 | \value{ 27 | a \code{ggplot2} object 28 | } 29 | \description{ 30 | This function plots SHAP Interaction value for two variables depending on the value of the first variable. 31 | Value of the second variable is marked with the color. 
32 | } 33 | \examples{ 34 | \donttest{ 35 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 36 | target <- fifa20$target 37 | param2 <- list(objective = "reg:squarederror", max_depth = 5) 38 | xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10) 39 | unified_model2 <- xgboost.unify(xgb_model2, data) 40 | inters <- treeshap(unified_model2, as.matrix(data[1:50, ]), interactions = TRUE) 41 | plot_interaction(inters, "dribbling", "defending") 42 | } 43 | } 44 | \seealso{ 45 | \code{\link{treeshap}} for calculation of SHAP Interaction values 46 | 47 | \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}} 48 | } 49 | -------------------------------------------------------------------------------- /man/predict.model_unified.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/predict.R 3 | \name{predict.model_unified} 4 | \alias{predict.model_unified} 5 | \title{Predict} 6 | \usage{ 7 | \method{predict}{model_unified}(object, x, ...) 8 | } 9 | \arguments{ 10 | \item{object}{Unified model representation of the model created with a (model).unify function. \code{\link{model_unified.object}}} 11 | 12 | \item{x}{Observations to predict. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model.} 13 | 14 | \item{...}{other parameters} 15 | } 16 | \value{ 17 | a vector of predictions. 18 | } 19 | \description{ 20 | Predict using unified_model representation. 21 | } 22 | \examples{ 23 | \donttest{ 24 | library(gbm) 25 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 26 | data['value_eur'] <- fifa20$target 27 | gbm_model <- gbm::gbm( 28 | formula = value_eur ~ ., 29 | data = data, 30 | distribution = "laplace", 31 | n.trees = 20, 32 | interaction.depth = 4, 33 | n.cores = 1) 34 | unified <- gbm.unify(gbm_model, data) 35 | predict(unified, data[2001:2005, ]) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /man/print.model_unified.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_unified.R 3 | \name{print.model_unified} 4 | \alias{print.model_unified} 5 | \title{Prints model_unified objects} 6 | \usage{ 7 | \method{print}{model_unified}(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{a model_unified object} 11 | 12 | \item{...}{other arguments} 13 | } 14 | \value{ 15 | No return value, called for printing 16 | } 17 | \description{ 18 | Prints model_unified objects 19 | } 20 | -------------------------------------------------------------------------------- /man/print.model_unified_multioutput.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_unified.R 3 | \name{print.model_unified_multioutput} 4 | \alias{print.model_unified_multioutput} 5 | \title{Prints model_unified_multioutput objects} 6 | \usage{ 7 | \method{print}{model_unified_multioutput}(x, ...) 
8 | } 9 | \arguments{ 10 | \item{x}{a model_unified_multioutput object} 11 | 12 | \item{...}{other arguments} 13 | } 14 | \value{ 15 | No return value, called for printing 16 | } 17 | \description{ 18 | Prints model_unified_multioutput objects 19 | } 20 | -------------------------------------------------------------------------------- /man/print.treeshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{print.treeshap} 4 | \alias{print.treeshap} 5 | \title{Prints treeshap objects} 6 | \usage{ 7 | \method{print}{treeshap}(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{a treeshap object} 11 | 12 | \item{...}{other arguments} 13 | } 14 | \value{ 15 | No return value, called for printing 16 | } 17 | \description{ 18 | Prints treeshap objects 19 | } 20 | -------------------------------------------------------------------------------- /man/print.treeshap_multioutput.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{print.treeshap_multioutput} 4 | \alias{print.treeshap_multioutput} 5 | \title{Prints treeshap_multioutput objects} 6 | \usage{ 7 | \method{print}{treeshap_multioutput}(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{a treeshap_multioutput object} 11 | 12 | \item{...}{other arguments} 13 | } 14 | \value{ 15 | No return value, called for printing 16 | } 17 | \description{ 18 | Prints treeshap_multioutput objects 19 | } 20 | -------------------------------------------------------------------------------- /man/randomForest.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_randomForest.R 3 | \name{randomForest.unify} 4 | \alias{randomForest.unify} 5 | \title{Unify randomForest model} 6 | \usage{ 7 | randomForest.unify(rf_model, data) 8 | } 9 | \arguments{ 10 | \item{rf_model}{An object of \code{randomForest} class. At the moment, models built on data with categorical features 11 | are not supported - please encode them before training.} 12 | 13 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 14 | } 15 | \value{ 16 | a unified model representation - a \code{\link{model_unified.object}} object 17 | } 18 | \description{ 19 | Convert your randomForest model into a standardized representation. 20 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
21 | } 22 | \details{ 23 | Binary classification models with a target variable that is a factor with two levels, 0 and 1, are supported 24 | } 25 | \examples{ 26 | 27 | library(randomForest) 28 | data_fifa <- fifa20$data[!colnames(fifa20$data) \%in\% 29 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 30 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 31 | data <- na.omit(cbind(data_fifa, target = fifa20$target)) 32 | 33 | rf <- randomForest::randomForest(target~., data = data, maxnodes = 10, ntree = 10) 34 | unified_model <- randomForest.unify(rf, data) 35 | shaps <- treeshap(unified_model, data[1:2,]) 36 | # plot_contribution(shaps, obs = 1) 37 | 38 | } 39 | \seealso{ 40 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 41 | 42 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 43 | 44 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 45 | 46 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 47 | } 48 | -------------------------------------------------------------------------------- /man/ranger.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_ranger.R 3 | \name{ranger.unify} 4 | \alias{ranger.unify} 5 | \title{Unify ranger model} 6 | \usage{ 7 | ranger.unify(rf_model, data) 8 | } 9 | \arguments{ 10 | \item{rf_model}{An object of \code{ranger} class. At the moment, models built on data with categorical features 11 | are not supported - please encode them before training.} 12 | 13 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 14 | } 15 | \value{ 16 | a unified model representation - a \code{\link{model_unified.object}} object 17 | } 18 | \description{ 19 | Convert your ranger model into a standardized representation. 20 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
21 | } 22 | \examples{ 23 | 24 | library(ranger) 25 | data_fifa <- fifa20$data[!colnames(fifa20$data) \%in\% 26 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 27 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 28 | data <- na.omit(cbind(data_fifa, target = fifa20$target)) 29 | 30 | rf <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 31 | unified_model <- ranger.unify(rf, data) 32 | shaps <- treeshap(unified_model, data[1:2,]) 33 | plot_contribution(shaps, obs = 1) 34 | } 35 | \seealso{ 36 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 37 | 38 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 39 | 40 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 41 | 42 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 43 | } 44 | -------------------------------------------------------------------------------- /man/ranger_surv.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_ranger_surv.R 3 | \name{ranger_surv.unify} 4 | \alias{ranger_surv.unify} 5 | \title{Unify ranger survival model} 6 | \usage{ 7 | ranger_surv.unify( 8 | rf_model, 9 | data, 10 | type = c("risk", "survival", "chf"), 11 | times = NULL 12 | ) 13 | } 14 | \arguments{ 15 | \item{rf_model}{An object of \code{ranger} class. At the moment, models built on data with categorical features 16 | are not supported - please encode them before training.} 17 | 18 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 19 | 20 | \item{type}{A character to define the type of model prediction to use. Either \code{"risk"} (default), which uses the risk score calculated as a sum of cumulative hazard function values, \code{"survival"}, which uses the survival probability at certain time-points for each observation, or \code{"chf"}, which used the cumulative hazard values at certain time-points for each observation.} 21 | 22 | \item{times}{A numeric vector of unique death times at which the prediction should be evaluated. By default \code{unique.death.times} from model are used.} 23 | } 24 | \value{ 25 | For \code{type = "risk"} a unified model representation is returned - a \code{\link{model_unified.object}} object. For \code{type = "survival"} or \code{type = "chf"} - a \code{\link{model_unified_multioutput.object}} object is returned, which is a list that contains unified model representation (\code{\link{model_unified.object}} object) for each time point. In this case, the list names are time points at which the survival function was evaluated. 26 | } 27 | \description{ 28 | Convert your ranger model into a standardized representation. 29 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 30 | } 31 | \details{ 32 | The survival forest implemented in the \code{ranger} package stores cumulative hazard 33 | functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests 34 | (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs 35 | from all the trees. 
To provide explanations in the form of a survival function, 36 | the CHFs from the leaves are converted into survival functions (SFs) using 37 | the formula SF(t) = exp(-CHF(t)). 38 | However, it is important to note that averaging these SFs does not yield the correct 39 | model prediction as the model prediction is the average of CHFs transformed in the same way. 40 | Therefore, when you obtain explanations based on the survival function, 41 | they are only proxies and may not be fully consistent with the model predictions 42 | obtained using for example \code{predict} function. 43 | } 44 | \examples{ 45 | 46 | library(ranger) 47 | data_colon <- data.table::data.table(survival::colon) 48 | data_colon <- na.omit(data_colon[get("etype") == 2, ]) 49 | surv_cols <- c("status", "time", "rx") 50 | 51 | feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)] 52 | 53 | train_x <- model.matrix( 54 | ~ -1 + ., 55 | data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])] 56 | ) 57 | train_y <- survival::Surv( 58 | event = (data_colon[, get("status")] |> 59 | as.character() |> 60 | as.integer()), 61 | time = data_colon[, get("time")], 62 | type = "right" 63 | ) 64 | 65 | rf <- ranger::ranger( 66 | x = train_x, 67 | y = train_y, 68 | data = data_colon, 69 | max.depth = 10, 70 | num.trees = 10 71 | ) 72 | unified_model_risk <- ranger_surv.unify(rf, train_x, type = "risk") 73 | shaps <- treeshap(unified_model_risk, train_x[1:2,]) 74 | 75 | # compute shaps for 3 selected time points 76 | unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73)) 77 | shaps_surv <- treeshap(unified_model_surv, train_x[1:2,]) 78 | 79 | } 80 | \seealso{ 81 | \code{\link{ranger.unify}} for regression and classification \code{\link[ranger:ranger]{ranger models}} 82 | 83 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 84 | 85 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 86 | 87 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 88 | 89 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 90 | } 91 | -------------------------------------------------------------------------------- /man/set_reference_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/set_reference_dataset.R 3 | \name{set_reference_dataset} 4 | \alias{set_reference_dataset} 5 | \title{Set reference dataset} 6 | \usage{ 7 | set_reference_dataset(unified_model, x) 8 | } 9 | \arguments{ 10 | \item{unified_model}{Unified model representation of the model created with a (model).unify function. (\code{\link{model_unified.object}}).} 11 | 12 | \item{x}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model.} 13 | } 14 | \value{ 15 | \code{\link{model_unified.object}}. Unified representation of the model as created with a (model).unify function, 16 | but with changed reference dataset (Cover column containing updated values). 17 | } 18 | \description{ 19 | Change a dataset used as reference for calculating SHAP values. 20 | Reference dataset is initially set with \code{data} argument in unifying function. 21 | Usually reference dataset is dataset used to train the model. 
22 | Important property of reference dataset is that SHAP values for each observation add up to its deviation from mean prediction for a reference dataset. 23 | } 24 | \examples{ 25 | \donttest{ 26 | library(gbm) 27 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 28 | data['value_eur'] <- fifa20$target 29 | gbm_model <- gbm::gbm( 30 | formula = value_eur ~ ., 31 | data = data, 32 | distribution = "laplace", 33 | n.trees = 20, 34 | interaction.depth = 4, 35 | n.cores = 1) 36 | unified <- gbm.unify(gbm_model, data) 37 | set_reference_dataset(unified, data[200:700, ]) 38 | } 39 | } 40 | \seealso{ 41 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 42 | 43 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 44 | 45 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 46 | 47 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 48 | 49 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 50 | } 51 | -------------------------------------------------------------------------------- /man/theme_drwhy.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/theme_drwhy.R 3 | \name{theme_drwhy} 4 | \alias{theme_drwhy} 5 | \alias{theme_drwhy_vertical} 6 | \title{DrWhy Theme for ggplot objects} 7 | \usage{ 8 | theme_drwhy() 9 | 10 | theme_drwhy_vertical() 11 | } 12 | \value{ 13 | theme for ggplot2 objects 14 | } 15 | \description{ 16 | DrWhy Theme for ggplot objects 17 | } 18 | -------------------------------------------------------------------------------- /man/treeshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{treeshap} 4 | \alias{treeshap} 5 | \title{Calculate SHAP values of a tree ensemble model.} 6 | \usage{ 7 | treeshap(unified_model, x, interactions = FALSE, verbose = TRUE) 8 | } 9 | \arguments{ 10 | \item{unified_model}{Unified data.frame representation of the model created with a (model).unify function. A \code{\link{model_unified.object}} object.} 11 | 12 | \item{x}{Observations to be explained. A \code{data.frame} or \code{matrix} object with the same columns as in the training set of the model. Keep in mind that objects different than \code{data.frame} or plain \code{matrix} will cause an error or unpredictable behavior.} 13 | 14 | \item{interactions}{Whether to calculate SHAP interaction values. By default is \code{FALSE}. Basic SHAP values are always calculated.} 15 | 16 | \item{verbose}{Whether to print progress bar to the console. Should be logical. Progress bar will not be displayed on Windows.} 17 | } 18 | \value{ 19 | A \code{\link{treeshap.object}} object (for single-output models) or \code{\link{treeshap_multioutput.object}}, which is a list of \code{\link{treeshap.object}} objects (for multi-output models). SHAP values can be accessed from \code{\link{treeshap.object}} with \code{$shaps}, and interaction values can be accessed with \code{$interactions}. 20 | } 21 | \description{ 22 | Calculate SHAP values and optionally SHAP Interaction values. 
23 | } 24 | \examples{ 25 | \donttest{ 26 | library(xgboost) 27 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 28 | target <- fifa20$target 29 | 30 | # calculating simple SHAP values 31 | param <- list(objective = "reg:squarederror", max_depth = 3) 32 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 33 | nrounds = 20, verbose = FALSE) 34 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 35 | treeshap1 <- treeshap(unified_model, head(data, 3)) 36 | plot_contribution(treeshap1, obs = 1) 37 | treeshap1$shaps 38 | 39 | # It's possible to calculate explanations over a different part of the data set 40 | 41 | unified_model_rec <- set_reference_dataset(unified_model, data[1:1000, ]) 42 | treeshap_rec <- treeshap(unified_model_rec, head(data, 3)) 43 | plot_contribution(treeshap_rec, obs = 1) 44 | 45 | # calculating SHAP interaction values 46 | param2 <- list(objective = "reg:squarederror", max_depth = 7) 47 | xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10) 48 | unified_model2 <- xgboost.unify(xgb_model2, as.matrix(data)) 49 | treeshap2 <- treeshap(unified_model2, head(data, 3), interactions = TRUE) 50 | treeshap2$interactions 51 | } 52 | } 53 | \seealso{ 54 | \code{\link{xgboost.unify}} for \code{XGBoost models} 55 | \code{\link{lightgbm.unify}} for \code{LightGBM models} 56 | \code{\link{gbm.unify}} for \code{GBM models} 57 | \code{\link{randomForest.unify}} for \code{randomForest models} 58 | \code{\link{ranger.unify}} for \code{ranger models} 59 | \code{\link{ranger_surv.unify}} for \code{ranger survival models} 60 | } 61 | -------------------------------------------------------------------------------- /man/treeshap.object.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{treeshap.object} 4 | \alias{treeshap.object} 5 | \title{treeshap results} 6 | \value{ 7 | List consisting of four elements: 8 | \describe{ 9 | \item{shaps}{A \code{data.frame} with M columns, X rows (M - number of features, X - number of explained observations). Every row corresponds to SHAP values for an observation. } 10 | \item{interactions}{An \code{array} with dimensions (M, M, X) (M - number of features, X - number of explained observations). Every \code{[, , i]} slice is a symmetric matrix - SHAP Interaction values for an observation. The \code{[a, b, i]} element is the SHAP Interaction value of features \code{a} and \code{b} for observation \code{i}. Is \code{NULL} if interactions were not calculated (parameter \code{interactions} set to \code{FALSE}). } 11 | \item{unified_model}{An object of type \code{\link{model_unified.object}}. Unified representation of a model for which SHAP values were calculated. It is used by some of the plotting functions.} 12 | \item{observations}{Explained dataset. \code{data.frame} or \code{matrix}. It is used by some of the plotting functions.} 13 | } 14 | } 15 | \description{ 16 | \code{treeshap} object produced by the \code{treeshap} function.
17 | } 18 | \seealso{ 19 | \code{\link{treeshap}}, 20 | 21 | \code{\link{plot_contribution}}, \code{\link{plot_feature_importance}}, \code{\link{plot_feature_dependence}}, \code{\link{plot_interaction}} 22 | } 23 | -------------------------------------------------------------------------------- /man/treeshap_multioutput.object.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treeshap.R 3 | \name{treeshap_multioutput.object} 4 | \alias{treeshap_multioutput.object} 5 | \title{treeshap results for multi-output model} 6 | \value{ 7 | List consisting of \code{treeshap} objects, one for each individual output of a model. For survival models, the list is named using the time points, for which TreeSHAP values are calculated. 8 | } 9 | \description{ 10 | \code{treeshap_multioutput} object produced by \code{treeshap} function. 11 | } 12 | \seealso{ 13 | \code{\link{treeshap}}, 14 | 15 | \code{\link{treeshap.object}} 16 | } 17 | -------------------------------------------------------------------------------- /man/unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify.R 3 | \name{unify} 4 | \alias{unify} 5 | \title{Unify tree-based model} 6 | \usage{ 7 | unify(model, data, ...) 8 | } 9 | \arguments{ 10 | \item{model}{A tree-based model object of any supported class (\code{gbm}, \code{lgb.Booster}, \code{randomForest}, \code{ranger}, or \code{xgb.Booster}).} 11 | 12 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 13 | 14 | \item{...}{Additional parameters passed to the model-specific unification functions.} 15 | } 16 | \value{ 17 | A unified model representation - a \code{\link{model_unified.object}} object (for single-output models) or \code{\link{model_unified_multioutput.object}}, which is a list of \code{\link{model_unified.object}} objects (for multi-output models). 18 | } 19 | \description{ 20 | Convert your tree-based model into a standardized representation. 21 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
22 | } 23 | \examples{ 24 | 25 | library(ranger) 26 | data_fifa <- fifa20$data[!colnames(fifa20$data) \%in\% 27 | c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 28 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 29 | data <- na.omit(cbind(data_fifa, target = fifa20$target)) 30 | 31 | rf1 <- ranger::ranger(target~., data = data, max.depth = 10, num.trees = 10) 32 | unified_model1 <- unify(rf1, data) 33 | shaps1 <- treeshap(unified_model1, data[1:2,]) 34 | plot_contribution(shaps1, obs = 1) 35 | 36 | rf2 <- randomForest::randomForest(target~., data = data, maxnodes = 10, ntree = 10) 37 | unified_model2 <- unify(rf2, data) 38 | shaps2 <- treeshap(unified_model2, data[1:2,]) 39 | plot_contribution(shaps2, obs = 1) 40 | } 41 | \seealso{ 42 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 43 | 44 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 45 | 46 | \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}} 47 | 48 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 49 | 50 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 51 | } 52 | -------------------------------------------------------------------------------- /man/xgboost.unify.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/unify_xgboost.R 3 | \name{xgboost.unify} 4 | \alias{xgboost.unify} 5 | \title{Unify XGBoost model} 6 | \usage{ 7 | xgboost.unify(xgb_model, data, recalculate = FALSE) 8 | } 9 | \arguments{ 10 | \item{xgb_model}{A XGBoost model - object of class \code{xgb.Booster}} 11 | 12 | \item{data}{Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.} 13 | 14 | \item{recalculate}{logical indicating if covers should be recalculated according to the dataset given in data. Keep it \code{FALSE} if training data are used.} 15 | } 16 | \value{ 17 | a unified model representation - a \code{\link{model_unified.object}} object 18 | } 19 | \description{ 20 | Convert your XGBoost model into a standardized representation. 21 | The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function. 
22 | } 23 | \examples{ 24 | \donttest{ 25 | library(xgboost) 26 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 27 | target <- fifa20$target 28 | param <- list(objective = "reg:squarederror", max_depth = 3) 29 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, 30 | nrounds = 20, verbose = 0) 31 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 32 | shaps <- treeshap(unified_model, data[1:2,]) 33 | plot_contribution(shaps, obs = 1) 34 | } 35 | 36 | } 37 | \seealso{ 38 | \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}} 39 | 40 | \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}} 41 | 42 | \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}} 43 | 44 | \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}} 45 | } 46 | -------------------------------------------------------------------------------- /src/RcppExports.cpp: -------------------------------------------------------------------------------- 1 | // Generated by using Rcpp::compileAttributes() -> do not edit by hand 2 | // Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 3 | 4 | #include 5 | 6 | using namespace Rcpp; 7 | 8 | #ifdef RCPP_USE_GLOBAL_ROSTREAM 9 | Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); 10 | Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); 11 | #endif 12 | 13 | // predict_cpp 14 | NumericVector predict_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, IntegerVector no, IntegerVector missing, LogicalVector is_leaf, IntegerVector feature, NumericVector split, IntegerVector decision_type, NumericVector value); 15 | RcppExport SEXP _treeshap_predict_cpp(SEXP xSEXP, SEXP is_naSEXP, SEXP rootsSEXP, SEXP yesSEXP, SEXP noSEXP, SEXP missingSEXP, SEXP is_leafSEXP, SEXP featureSEXP, SEXP splitSEXP, SEXP decision_typeSEXP, SEXP valueSEXP) { 16 | BEGIN_RCPP 17 | Rcpp::RObject rcpp_result_gen; 18 | Rcpp::RNGScope rcpp_rngScope_gen; 19 | Rcpp::traits::input_parameter< DataFrame >::type x(xSEXP); 20 | Rcpp::traits::input_parameter< DataFrame >::type is_na(is_naSEXP); 21 | Rcpp::traits::input_parameter< IntegerVector >::type roots(rootsSEXP); 22 | Rcpp::traits::input_parameter< IntegerVector >::type yes(yesSEXP); 23 | Rcpp::traits::input_parameter< IntegerVector >::type no(noSEXP); 24 | Rcpp::traits::input_parameter< IntegerVector >::type missing(missingSEXP); 25 | Rcpp::traits::input_parameter< LogicalVector >::type is_leaf(is_leafSEXP); 26 | Rcpp::traits::input_parameter< IntegerVector >::type feature(featureSEXP); 27 | Rcpp::traits::input_parameter< NumericVector >::type split(splitSEXP); 28 | Rcpp::traits::input_parameter< IntegerVector >::type decision_type(decision_typeSEXP); 29 | Rcpp::traits::input_parameter< NumericVector >::type value(valueSEXP); 30 | rcpp_result_gen = Rcpp::wrap(predict_cpp(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type, value)); 31 | return rcpp_result_gen; 32 | END_RCPP 33 | } 34 | // new_covers 35 | IntegerVector new_covers(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, IntegerVector no, IntegerVector missing, LogicalVector is_leaf, IntegerVector feature, NumericVector split, IntegerVector decision_type); 36 | RcppExport SEXP _treeshap_new_covers(SEXP xSEXP, SEXP is_naSEXP, SEXP rootsSEXP, SEXP yesSEXP, SEXP noSEXP, SEXP missingSEXP, SEXP is_leafSEXP, SEXP featureSEXP, SEXP splitSEXP, SEXP decision_typeSEXP) { 37 | BEGIN_RCPP 38 | Rcpp::RObject 
rcpp_result_gen; 39 | Rcpp::RNGScope rcpp_rngScope_gen; 40 | Rcpp::traits::input_parameter< DataFrame >::type x(xSEXP); 41 | Rcpp::traits::input_parameter< DataFrame >::type is_na(is_naSEXP); 42 | Rcpp::traits::input_parameter< IntegerVector >::type roots(rootsSEXP); 43 | Rcpp::traits::input_parameter< IntegerVector >::type yes(yesSEXP); 44 | Rcpp::traits::input_parameter< IntegerVector >::type no(noSEXP); 45 | Rcpp::traits::input_parameter< IntegerVector >::type missing(missingSEXP); 46 | Rcpp::traits::input_parameter< LogicalVector >::type is_leaf(is_leafSEXP); 47 | Rcpp::traits::input_parameter< IntegerVector >::type feature(featureSEXP); 48 | Rcpp::traits::input_parameter< NumericVector >::type split(splitSEXP); 49 | Rcpp::traits::input_parameter< IntegerVector >::type decision_type(decision_typeSEXP); 50 | rcpp_result_gen = Rcpp::wrap(new_covers(x, is_na, roots, yes, no, missing, is_leaf, feature, split, decision_type)); 51 | return rcpp_result_gen; 52 | END_RCPP 53 | } 54 | // treeshap_cpp 55 | NumericVector treeshap_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, IntegerVector no, IntegerVector missing, IntegerVector feature, NumericVector split, IntegerVector decision_type, LogicalVector is_leaf, NumericVector value, NumericVector cover, bool verbose); 56 | RcppExport SEXP _treeshap_treeshap_cpp(SEXP xSEXP, SEXP is_naSEXP, SEXP rootsSEXP, SEXP yesSEXP, SEXP noSEXP, SEXP missingSEXP, SEXP featureSEXP, SEXP splitSEXP, SEXP decision_typeSEXP, SEXP is_leafSEXP, SEXP valueSEXP, SEXP coverSEXP, SEXP verboseSEXP) { 57 | BEGIN_RCPP 58 | Rcpp::RObject rcpp_result_gen; 59 | Rcpp::RNGScope rcpp_rngScope_gen; 60 | Rcpp::traits::input_parameter< DataFrame >::type x(xSEXP); 61 | Rcpp::traits::input_parameter< DataFrame >::type is_na(is_naSEXP); 62 | Rcpp::traits::input_parameter< IntegerVector >::type roots(rootsSEXP); 63 | Rcpp::traits::input_parameter< IntegerVector >::type yes(yesSEXP); 64 | Rcpp::traits::input_parameter< IntegerVector >::type no(noSEXP); 65 | Rcpp::traits::input_parameter< IntegerVector >::type missing(missingSEXP); 66 | Rcpp::traits::input_parameter< IntegerVector >::type feature(featureSEXP); 67 | Rcpp::traits::input_parameter< NumericVector >::type split(splitSEXP); 68 | Rcpp::traits::input_parameter< IntegerVector >::type decision_type(decision_typeSEXP); 69 | Rcpp::traits::input_parameter< LogicalVector >::type is_leaf(is_leafSEXP); 70 | Rcpp::traits::input_parameter< NumericVector >::type value(valueSEXP); 71 | Rcpp::traits::input_parameter< NumericVector >::type cover(coverSEXP); 72 | Rcpp::traits::input_parameter< bool >::type verbose(verboseSEXP); 73 | rcpp_result_gen = Rcpp::wrap(treeshap_cpp(x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose)); 74 | return rcpp_result_gen; 75 | END_RCPP 76 | } 77 | // treeshap_interactions_cpp 78 | List treeshap_interactions_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, IntegerVector no, IntegerVector missing, IntegerVector feature, NumericVector split, IntegerVector decision_type, LogicalVector is_leaf, NumericVector value, NumericVector cover, bool verbose); 79 | RcppExport SEXP _treeshap_treeshap_interactions_cpp(SEXP xSEXP, SEXP is_naSEXP, SEXP rootsSEXP, SEXP yesSEXP, SEXP noSEXP, SEXP missingSEXP, SEXP featureSEXP, SEXP splitSEXP, SEXP decision_typeSEXP, SEXP is_leafSEXP, SEXP valueSEXP, SEXP coverSEXP, SEXP verboseSEXP) { 80 | BEGIN_RCPP 81 | Rcpp::RObject rcpp_result_gen; 82 | Rcpp::RNGScope rcpp_rngScope_gen; 83 | 
Rcpp::traits::input_parameter< DataFrame >::type x(xSEXP); 84 | Rcpp::traits::input_parameter< DataFrame >::type is_na(is_naSEXP); 85 | Rcpp::traits::input_parameter< IntegerVector >::type roots(rootsSEXP); 86 | Rcpp::traits::input_parameter< IntegerVector >::type yes(yesSEXP); 87 | Rcpp::traits::input_parameter< IntegerVector >::type no(noSEXP); 88 | Rcpp::traits::input_parameter< IntegerVector >::type missing(missingSEXP); 89 | Rcpp::traits::input_parameter< IntegerVector >::type feature(featureSEXP); 90 | Rcpp::traits::input_parameter< NumericVector >::type split(splitSEXP); 91 | Rcpp::traits::input_parameter< IntegerVector >::type decision_type(decision_typeSEXP); 92 | Rcpp::traits::input_parameter< LogicalVector >::type is_leaf(is_leafSEXP); 93 | Rcpp::traits::input_parameter< NumericVector >::type value(valueSEXP); 94 | Rcpp::traits::input_parameter< NumericVector >::type cover(coverSEXP); 95 | Rcpp::traits::input_parameter< bool >::type verbose(verboseSEXP); 96 | rcpp_result_gen = Rcpp::wrap(treeshap_interactions_cpp(x, is_na, roots, yes, no, missing, feature, split, decision_type, is_leaf, value, cover, verbose)); 97 | return rcpp_result_gen; 98 | END_RCPP 99 | } 100 | 101 | static const R_CallMethodDef CallEntries[] = { 102 | {"_treeshap_predict_cpp", (DL_FUNC) &_treeshap_predict_cpp, 11}, 103 | {"_treeshap_new_covers", (DL_FUNC) &_treeshap_new_covers, 10}, 104 | {"_treeshap_treeshap_cpp", (DL_FUNC) &_treeshap_treeshap_cpp, 13}, 105 | {"_treeshap_treeshap_interactions_cpp", (DL_FUNC) &_treeshap_treeshap_interactions_cpp, 13}, 106 | {NULL, NULL, 0} 107 | }; 108 | 109 | RcppExport void R_init_treeshap(DllInfo *dll) { 110 | R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); 111 | R_useDynamicSymbols(dll, FALSE); 112 | } 113 | -------------------------------------------------------------------------------- /src/RcppExports.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/src/RcppExports.o -------------------------------------------------------------------------------- /src/predict.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | using namespace Rcpp; 3 | 4 | // [[Rcpp::plugins(cpp11)]] 5 | 6 | // [[Rcpp::export]] 7 | NumericVector predict_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, IntegerVector no, 8 | IntegerVector missing, LogicalVector is_leaf, IntegerVector feature, NumericVector split, 9 | IntegerVector decision_type, NumericVector value) { 10 | NumericVector prediction(x.ncol()); 11 | for (int i = 0; i < x.ncol(); ++i) { 12 | NumericVector observation = x[i]; 13 | LogicalVector observation_is_na = is_na[i]; 14 | 15 | for (int node: roots) { 16 | while (!is_leaf[node]) { 17 | if (observation_is_na[feature[node]]) { 18 | node = missing[node]; 19 | } else if (((decision_type[node] == 1) && (observation[feature[node]] <= split[node])) 20 | || ((decision_type[node] == 2) && (observation[feature[node]] < split[node]))) { 21 | node = yes[node]; 22 | } else { 23 | node = no[node]; 24 | } 25 | } 26 | prediction[i] += value[node]; 27 | } 28 | } 29 | 30 | return prediction; 31 | } 32 | 33 | -------------------------------------------------------------------------------- /src/set_reference_dataset.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | using namespace Rcpp; 3 | 4 | // [[Rcpp::plugins(cpp11)]] 5 | 6 | // 
[[Rcpp::export]] 7 | IntegerVector new_covers(DataFrame x, DataFrame is_na, IntegerVector roots, IntegerVector yes, 8 | IntegerVector no, IntegerVector missing, LogicalVector is_leaf, IntegerVector feature, 9 | NumericVector split, IntegerVector decision_type) { 10 | IntegerVector cover(is_leaf.size()); 11 | for (int i = 0; i < x.ncol(); ++i) { 12 | NumericVector observation = x[i]; 13 | LogicalVector observation_is_na = is_na[i]; 14 | 15 | for (int node: roots) { 16 | while (!is_leaf[node]) { 17 | cover[node]++; 18 | 19 | if (observation_is_na[feature[node]]) { 20 | node = missing[node]; 21 | } else if (((decision_type[node] == 1) && (observation[feature[node]] <= split[node])) 22 | || ((decision_type[node] == 2) && (observation[feature[node]] < split[node]))) { 23 | node = yes[node]; 24 | } else { 25 | node = no[node]; 26 | } 27 | } 28 | cover[node]++; 29 | } 30 | } 31 | 32 | return cover; 33 | } 34 | 35 | -------------------------------------------------------------------------------- /src/treeshap.cpp: -------------------------------------------------------------------------------- 1 | #include <Rcpp.h> 2 | 3 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 4 | #include 5 | #include 6 | #endif 7 | 8 | using namespace Rcpp; 9 | 10 | // [[Rcpp::plugins(cpp11)]] 11 | 12 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 13 | void initProgressBar() { 14 | std::stringstream strs; 15 | strs <<"|0%----|------|20%---|------|40%---|------|60%---|------|80%---|------|100%\n" << 16 | "=---------------------------------------------------------------------- (0%)"; 17 | std::string temp_str = strs.str(); 18 | char const* char_type = temp_str.c_str(); 19 | Rprintf("\r"); 20 | Rprintf("%s", char_type); 21 | Rprintf("\r"); 22 | R_FlushConsole(); 23 | R_CheckUserInterrupt(); 24 | } 25 | 26 | void updateProgressBar(int steps_done, int steps_all) { 27 | std::stringstream strs; 28 | int progress_signs = int(.5 + 70 * steps_done / steps_all); 29 | int progress_percent = int(.5 + 100 * steps_done / steps_all); 30 | strs << std::string(progress_signs + 1, '=') << std::string(70 - progress_signs, '-') << " (" << progress_percent << "%)"; 31 | std::string temp_str = strs.str(); 32 | char const* char_type = temp_str.c_str(); 33 | Rprintf("\r"); 34 | Rprintf("%s", char_type); 35 | Rprintf("\r"); 36 | R_FlushConsole(); 37 | R_CheckUserInterrupt(); 38 | } 39 | #endif 40 | 41 | typedef double tnumeric; 42 | 43 | struct PathElem { 44 | PathElem(int d, bool o, tnumeric z, tnumeric w) : d(d), o(o), z(z), w(w) {} 45 | int d; 46 | bool o; 47 | tnumeric z, w; 48 | }; 49 | 50 | typedef std::vector<PathElem> Path; 51 | 52 | void extend(Path &m, tnumeric p_z, bool p_o, int p_i) { 53 | int depth = m.size(); 54 | 55 | PathElem tmp(p_i, p_o, p_z, (depth == 0) ?
1.0 : 0.0); 56 | 57 | m.push_back(tmp); 58 | 59 | for (int i = depth - 1; i >= 0; i--) { 60 | m[i + 1].w += p_o * m[i].w * (i + 1) / static_cast(depth + 1); 61 | m[i].w = p_z * m[i].w * (depth - i) / static_cast(depth + 1); 62 | } 63 | } 64 | 65 | void unwind(Path &m, int i) { 66 | int depth = m.size() - 1; 67 | tnumeric n = m[depth].w; 68 | 69 | if (m[i].o != 0) { 70 | for (int j = depth - 1; j >= 0; --j) { 71 | tnumeric tmp = m[j].w; 72 | m[j].w = n * (depth + 1) / static_cast(j + 1); 73 | n = tmp - m[j].w * m[i].z * (depth - j) / static_cast(depth + 1); 74 | } 75 | } else { 76 | for (int j = depth - 1; j >= 0; --j) { 77 | m[j].w = (m[j].w * (depth + 1)) / static_cast(m[i].z * (depth - j)); 78 | } 79 | } 80 | 81 | for (int j = i; j < depth; ++j) { 82 | m[j].d = m[j + 1].d; 83 | m[j].z = m[j + 1].z; 84 | m[j].o = m[j + 1].o; 85 | } 86 | 87 | m.pop_back(); 88 | } 89 | 90 | tnumeric unwound_sum(const Path &m, int i) { 91 | int depth = m.size() - 1; 92 | tnumeric total = 0; 93 | 94 | if (m[i].o != 0) { 95 | tnumeric n = m[depth].w; 96 | for (int j = depth - 1; j >= 0; --j) { 97 | tnumeric tmp = n / static_cast((j + 1) * m[i].o); 98 | total += tmp; 99 | n = m[j].w - tmp * m[i].z * (depth - j); 100 | } 101 | } else { 102 | for (int j = depth - 1; j >= 0; --j) { 103 | total += m[j].w / static_cast((depth - j) * m[i].z); 104 | } 105 | } 106 | 107 | return total * (depth + 1); 108 | } 109 | 110 | // SHAP computation for a single decision tree 111 | void recurse(const IntegerVector &yes, const IntegerVector &no, const IntegerVector &missing, const IntegerVector &feature, 112 | const LogicalVector &is_leaf, const NumericVector &value, const NumericVector &cover, 113 | const NumericVector &split, const IntegerVector &decision_type, const NumericVector &observation, const LogicalVector &observation_is_na, 114 | NumericVector &shaps, Path &m, int j, tnumeric p_z, bool p_o, int p_i, 115 | int condition, int condition_feature, tnumeric condition_fraction) { 116 | 117 | if (condition_fraction == 0) { 118 | return; 119 | } 120 | 121 | if (p_z == 0) { // entering a node with Cover = 0 122 | return; 123 | } 124 | 125 | if (condition == 0 || // not calculating interactions 126 | condition_feature != p_i) { 127 | extend(m, p_z, p_o, p_i); 128 | } 129 | 130 | if (is_leaf[j]) { 131 | for (int i = 1; i < m.size(); ++i) { 132 | shaps[m[i].d] += unwound_sum(m, i) * (m[i].o - m[i].z) * condition_fraction * value[j]; 133 | } 134 | } else { 135 | tnumeric i_z = 1.0; 136 | bool i_o = 1; 137 | 138 | // undo previous extension if we have already seen this feature 139 | for (int k = 1; k < m.size(); ++k) { 140 | if (m[k].d == feature[j]) { 141 | i_z = m[k].z; 142 | i_o = m[k].o; 143 | unwind(m, k); 144 | break; 145 | } 146 | } 147 | 148 | if ((missing[j] == NA_INTEGER) || (missing[j] == no[j]) || (missing[j] == yes[j])) { //'missing' is one of ['yes', 'no'] nodes, or is NA 149 | int hot = no[j]; 150 | 151 | if (observation_is_na[feature[j]]) { 152 | hot = missing[j]; 153 | } else if (((decision_type[j] == 1) && (observation[feature[j]] <= split[j])) 154 | || ((decision_type[j] == 2) && (observation[feature[j]] < split[j]))) { 155 | hot = yes[j]; 156 | } 157 | int cold = (hot == yes[j]) ? 
no[j] : yes[j]; 158 | 159 | // divide up the condition_fraction among the recursive calls 160 | // if we are not calculating interactions then condition fraction is always 1 161 | tnumeric hot_condition_fraction = condition_fraction; 162 | tnumeric cold_condition_fraction = condition_fraction; 163 | if (feature[j] == condition_feature) { 164 | if (condition > 0) { 165 | cold_condition_fraction = 0; 166 | } else if (condition < 0) { 167 | hot_condition_fraction *= cover[hot] / static_cast(cover[j]); 168 | cold_condition_fraction *= cover[cold] / static_cast(cover[j]); 169 | } 170 | } 171 | 172 | Path m_copy = Path(m); 173 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps, 174 | m, hot, 175 | i_z * cover[hot] / static_cast(cover[j]), 176 | i_o, 177 | feature[j], 178 | condition, condition_feature, hot_condition_fraction); 179 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps, 180 | m_copy, cold, 181 | i_z * cover[cold] / static_cast(cover[j]), 182 | 0, 183 | feature[j], 184 | condition, condition_feature, cold_condition_fraction); 185 | } else { // 'missing' node is a third son = not one of ['yes', 'no'] nodes 186 | int hot = no[j]; 187 | int cold1 = yes[j]; 188 | int cold2 = missing[j]; 189 | if (observation_is_na[feature[j]]) { 190 | hot = missing[j]; 191 | cold1 = yes[j]; 192 | cold2 = no[j]; 193 | } else if(((decision_type[j] == 1) && (observation[feature[j]] <= split[j])) 194 | || ((decision_type[j] == 2) && (observation[feature[j]] < split[j]))) { 195 | hot = yes[j]; 196 | cold1 = missing[j]; 197 | cold2 = no[j]; 198 | } 199 | 200 | // divide up the condition_fraction among the recursive calls 201 | // if we are not calculating interactions condition fraction is always 1 202 | tnumeric hot_condition_fraction = condition_fraction; 203 | tnumeric cold1_condition_fraction = condition_fraction; 204 | tnumeric cold2_condition_fraction = condition_fraction; 205 | if (feature[j] == condition_feature) { 206 | if (condition > 0) { 207 | cold1_condition_fraction = 0; 208 | cold2_condition_fraction = 0; 209 | } else if (condition < 0) { 210 | hot_condition_fraction *= cover[hot] / static_cast(cover[j]); 211 | cold1_condition_fraction *= cover[cold1] / static_cast(cover[j]); 212 | cold2_condition_fraction *= cover[cold2] / static_cast(cover[j]); 213 | } 214 | } 215 | 216 | Path m_copy1 = Path(m); 217 | Path m_copy2 = Path(m); 218 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps, 219 | m, hot, 220 | i_z * cover[hot] / static_cast(cover[j]), 221 | i_o, 222 | feature[j], 223 | condition, condition_feature, hot_condition_fraction); 224 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps, 225 | m_copy1, cold1, 226 | i_z * cover[cold1] / static_cast(cover[j]), 227 | 0, 228 | feature[j], 229 | condition, condition_feature, cold1_condition_fraction); 230 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps, 231 | m_copy2, cold2, 232 | i_z * cover[cold2] / static_cast(cover[j]), 233 | 0, 234 | feature[j], 235 | condition, condition_feature, cold2_condition_fraction); 236 | } 237 | } 238 | } 239 | 240 | // [[Rcpp::export]] 241 | NumericVector treeshap_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, 242 | IntegerVector yes, IntegerVector no, IntegerVector 
missing, IntegerVector feature, 243 | NumericVector split, IntegerVector decision_type, 244 | LogicalVector is_leaf, NumericVector value, NumericVector cover, 245 | bool verbose) { 246 | NumericMatrix shaps(x.ncol(), x.nrow()); 247 | 248 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 249 | if (verbose) { 250 | initProgressBar(); 251 | } 252 | #endif 253 | 254 | for (int obs = 0; obs < x.ncol(); obs++) { 255 | NumericVector observation = x[obs]; 256 | LogicalVector observation_is_na = is_na[obs]; 257 | 258 | NumericVector shaps_row(x.nrow()); 259 | 260 | for (int i = 0; i < roots.size(); ++i) { 261 | Path m; 262 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps_row, 263 | m, roots[i], 1, 1, -1, 264 | 0, 0, 1); 265 | } 266 | 267 | shaps(obs, _) = shaps_row; 268 | 269 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 270 | if (verbose) { 271 | updateProgressBar(obs + 1, x.ncol()); 272 | } 273 | #endif 274 | } 275 | 276 | return shaps; 277 | } 278 | 279 | 280 | // recursive tree traversal listing all features in the tree 281 | void unique_features_tree_traversal(int node, const IntegerVector &yes, const IntegerVector &no, 282 | const IntegerVector &missing, const IntegerVector &feature, const LogicalVector &is_leaf, 283 | std::vector &tree_features) { 284 | if (!is_leaf[node]) { 285 | tree_features.push_back(feature[node]); 286 | unique_features_tree_traversal(yes[node], yes, no, missing, feature, is_leaf, tree_features); 287 | unique_features_tree_traversal(no[node], yes, no, missing, feature, is_leaf, tree_features); 288 | if (missing[node] != NA_INTEGER && missing[node] != yes[node] && missing[node] != no[node]) { 289 | unique_features_tree_traversal(missing[node], yes, no, missing, feature, is_leaf, tree_features); 290 | } 291 | } 292 | } 293 | 294 | // function listing all unique features inside the tree 295 | std::vector unique_features(int root, const IntegerVector &yes, const IntegerVector &no, 296 | const IntegerVector &missing, const IntegerVector &feature, const LogicalVector &is_leaf) { 297 | std::vector tree_features; 298 | unique_features_tree_traversal(root, yes, no, missing, feature, is_leaf, tree_features); 299 | 300 | // removing duplicates 301 | std::sort(tree_features.begin(), tree_features.end()); 302 | auto last = std::unique(tree_features.begin(), tree_features.end()); 303 | tree_features.erase(last, tree_features.end()); 304 | 305 | return tree_features; 306 | } 307 | 308 | // [[Rcpp::export]] 309 | List treeshap_interactions_cpp(DataFrame x, DataFrame is_na, IntegerVector roots, 310 | IntegerVector yes, IntegerVector no, IntegerVector missing, IntegerVector feature, 311 | NumericVector split, IntegerVector decision_type, 312 | LogicalVector is_leaf, NumericVector value, NumericVector cover, 313 | bool verbose) { 314 | NumericMatrix shaps(x.ncol(), x.nrow()); 315 | NumericVector interactions(x.ncol() * x.nrow() * x.nrow()); 316 | 317 | 318 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 319 | if (verbose) { 320 | initProgressBar(); 321 | } 322 | #endif 323 | 324 | for (int obs = 0; obs < x.ncol(); obs++) { 325 | NumericVector observation = x[obs]; 326 | LogicalVector observation_is_na = is_na[obs]; 327 | 328 | NumericMatrix interactions_slice(x.nrow(), x.nrow()); 329 | NumericVector shaps_row(x.nrow()); 330 | NumericVector diagonal(x.nrow()); 331 | 332 | for (int i = 0; i < roots.size(); ++i) { 333 | Path m; 334 | recurse(yes, no, missing, feature, 
is_leaf, value, cover, split, decision_type, observation, observation_is_na, shaps_row, 335 | m, roots[i], 1, 1, -1, 336 | 0, 0, 1); // standard shaps computation 337 | 338 | std::vector tree_features = unique_features(roots[i], yes, no, missing, feature, is_leaf); 339 | for (auto tree_feature : tree_features) { 340 | NumericVector with(x.nrow()); 341 | NumericVector without(x.nrow()); 342 | 343 | Path m_with; 344 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, with, 345 | m_with, roots[i], 1, 1, -1, 346 | 1, tree_feature, 1); 347 | Path m_without; 348 | recurse(yes, no, missing, feature, is_leaf, value, cover, split, decision_type, observation, observation_is_na, without, 349 | m_without, roots[i], 1, 1, -1, 350 | -1, tree_feature, 1); 351 | 352 | NumericVector v = (with - without) / 2; 353 | interactions_slice(tree_feature, _) = interactions_slice(tree_feature, _) + v; 354 | diagonal = diagonal - v; 355 | } 356 | } 357 | 358 | // filling diagonal 359 | diagonal = shaps_row + diagonal; 360 | for (int k = 0; k < x.nrow(); ++k) { 361 | interactions_slice(k, k) = diagonal[k]; 362 | } 363 | 364 | // prescribing results from observation's vector and matrix to result's matrix and "array" 365 | shaps(obs, _) = shaps_row; 366 | for (int i = 0; i < x.nrow(); i++) { 367 | for (int j = 0; j < x.nrow(); j++) { 368 | interactions[obs * x.nrow() * x.nrow() + i * x.nrow() + j] = interactions_slice(i, j); 369 | } 370 | } 371 | 372 | #if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) 373 | if (verbose) { 374 | updateProgressBar(obs + 1, x.ncol()); 375 | } 376 | #endif 377 | } 378 | 379 | List ret = List::create(Named("shaps") = shaps, _["interactions"] = interactions); 380 | return ret; 381 | } 382 | -------------------------------------------------------------------------------- /src/treeshap.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/treeshap/a1a5472fe5cb5039079b20eace513c4cf82dac51/src/treeshap.o -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(treeshap) 3 | 4 | test_check("treeshap") 5 | -------------------------------------------------------------------------------- /tests/testthat/test_gbm_unify.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | suppressWarnings(library(gbm, quietly = TRUE)) 3 | 4 | x <- fifa20$data 5 | x['value_eur'] <- fifa20$target 6 | 7 | gbm_with_cat_model <- gbm::gbm( 8 | formula = value_eur ~ ., 9 | data = x, 10 | distribution = "laplace", 11 | n.trees = 10, 12 | interaction.depth = 2, 13 | n.cores = 1 14 | ) 15 | 16 | x <- x[colnames(fifa20$data) != 'work_rate'] 17 | 18 | gbm_num_model <- gbm::gbm( 19 | formula = value_eur ~ ., 20 | data = x, 21 | n.trees = 50, 22 | distribution = 'gaussian', 23 | n.cores = 1 24 | ) 25 | 26 | test_that('gbm.unify returns an object of appropriate class', { 27 | expect_true(is.model_unified(gbm.unify(gbm_num_model, x))) 28 | expect_true(is.model_unified(unify(gbm_num_model, x))) 29 | }) 30 | 31 | 32 | test_that('gbm.unify returns an object with correct attributes', { 33 | unified_model <- gbm.unify(gbm_num_model, x) 34 | 35 | expect_equal(attr(unified_model, "missing_support"), TRUE) 36 | expect_equal(attr(unified_model, "model"), "gbm") 37 | }) 38 | 39 | 
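# A minimal usage sketch of the workflow these tests exercise (illustrative
# only; the `example_*` objects are new names, and both calls mirror the ones
# used in the tests of this file): unify the fitted gbm ensemble defined
# above, then compute SHAP values for a few observations of `x`.
example_unified_gbm <- gbm.unify(gbm_num_model, x)
example_shaps_gbm <- treeshap(example_unified_gbm, x[1:3, ], verbose = FALSE)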
test_that('the gbm.unify function does not support models with categorical features', { 40 | expect_error(gbm.unify(gbm_with_cat_model), "Models built on data with categorical features are not supported - please encode them before training.") 41 | }) 42 | 43 | test_that('the gbm.unify function returns data frame with columns of appropriate column', { 44 | unifier <- gbm.unify(gbm_num_model, x)$model 45 | expect_true(is.integer(unifier$Tree)) 46 | expect_true(is.integer(unifier$Node)) 47 | expect_true(is.character(unifier$Feature)) 48 | expect_true(is.factor(unifier$Decision.type)) 49 | expect_true(is.numeric(unifier$Split)) 50 | expect_true(is.integer(unifier$Yes)) 51 | expect_true(is.integer(unifier$No)) 52 | expect_true(is.integer(unifier$Missing)) 53 | expect_true(is.numeric(unifier$Prediction)) 54 | expect_true(is.numeric(unifier$Cover)) 55 | }) 56 | 57 | test_that("shap calculates without an error", { 58 | unifier <- gbm.unify(gbm_num_model, x) 59 | expect_error(treeshap(unifier, x[1:3,], verbose = FALSE), NA) 60 | }) 61 | 62 | test_that("gbm: mean prediction calculated using predict == using covers", { 63 | unifier <- gbm.unify(gbm_num_model, x) 64 | 65 | intercept_predict <- mean(predict(unifier, x)) 66 | 67 | ntrees <- sum(unifier$model$Node == 0) 68 | leaves <- unifier$model[is.na(unifier$model$Feature), ] 69 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 70 | 71 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 72 | expect_equal(intercept_predict, intercept_covers) 73 | }) 74 | 75 | test_that("gbm: covers correctness", { 76 | unifier <- gbm.unify(gbm_num_model, x) 77 | 78 | roots <- unifier$model[unifier$model$Node == 0, ] 79 | expect_true(all(roots$Cover == nrow(x))) 80 | 81 | internals <- unifier$model[!is.na(unifier$model$Feature), ] 82 | yes_child_cover <- unifier$model[internals$Yes, ]$Cover 83 | no_child_cover <- unifier$model[internals$No, ]$Cover 84 | if (all(is.na(internals$Missing))) { 85 | children_cover <- yes_child_cover + no_child_cover 86 | } else { 87 | missing_child_cover <- unifier$model[internals$Missing, ]$Cover 88 | missing_child_cover[is.na(missing_child_cover)] <- 0 89 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 90 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 91 | } 92 | expect_true(all(internals$Cover == children_cover)) 93 | }) 94 | 95 | -------------------------------------------------------------------------------- /tests/testthat/test_lightgbm_unify.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | param_lightgbm <- list(objective = "regression", 3 | max_depth = 3, 4 | force_row_wise = TRUE, 5 | learning.rate = 0.1) 6 | 7 | data_fifa <- fifa20$data[!colnames(fifa20$data)%in%c('work_rate', 'value_eur', 'gk_diving', 'gk_handling', 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 8 | data <- as.matrix(na.omit(data.table::as.data.table(cbind(data_fifa, fifa20$target)))) 9 | sparse_data <- data[,-ncol(data)] 10 | x <- lightgbm::lgb.Dataset(sparse_data, label = data[,ncol(data)]) 11 | lgb_data <- lightgbm::lgb.Dataset.construct(x) 12 | lgbm_fifa <- lightgbm::lightgbm(data = lgb_data, 13 | params = param_lightgbm, 14 | verbose = -1, 15 | num_threads = 0) 16 | lgbmtree <- lightgbm::lgb.model.dt.tree(lgbm_fifa) 17 | 18 | test_that('lightgbm.unify returns an object with correct attributes', { 19 | unified_model <- 
lightgbm.unify(lgbm_fifa, sparse_data) 20 | 21 | expect_equal(attr(unified_model, "missing_support"), TRUE) 22 | expect_equal(attr(unified_model, "model"), "LightGBM") 23 | }) 24 | 25 | test_that('Columns after lightgbm.unify are of appropriate type', { 26 | unified_model <- lightgbm.unify(lgbm_fifa, sparse_data)$model 27 | expect_true(is.integer(unified_model$Tree)) 28 | expect_true(is.integer(unified_model$Node)) 29 | expect_true(is.character(unified_model$Feature)) 30 | expect_true(is.factor(unified_model$Decision.type)) 31 | expect_true(is.numeric(unified_model$Split)) 32 | expect_true(is.integer(unified_model$Yes)) 33 | expect_true(is.integer(unified_model$No)) 34 | expect_true(is.integer(unified_model$Missing)) 35 | expect_true(is.numeric(unified_model$Prediction)) 36 | expect_true(is.numeric(unified_model$Cover)) 37 | }) 38 | 39 | test_that('lightgbm.unify creates an object of the appropriate class', { 40 | expect_true(is.model_unified(lightgbm.unify(lgbm_fifa, sparse_data))) 41 | expect_true(is.model_unified(unify(lgbm_fifa, sparse_data))) 42 | }) 43 | 44 | test_that('basic columns after lightgbm.unify are correct', { 45 | unified_model <- lightgbm.unify(lgbm_fifa, sparse_data)$model 46 | expect_equal(lgbmtree$tree_index, unified_model$Tree) 47 | to_test_features <- lgbmtree[order(lgbmtree$split_index), .(split_feature,split_index, threshold, leaf_count, internal_count),tree_index] 48 | expect_equal(to_test_features[!is.na(to_test_features$split_index),][['split_index']], unified_model[!is.na(unified_model$Feature),][['Node']]) 49 | expect_equal(to_test_features[['split_feature']], unified_model[['Feature']]) 50 | expect_equal(to_test_features[['threshold']], unified_model[['Split']]) 51 | expect_equal(to_test_features[!is.na(internal_count),][['internal_count']], unified_model[!is.na(unified_model$Feature),][['Cover']]) 52 | }) 53 | 54 | test_that('connections between nodes and leaves after lightgbm.unify are correct', { 55 | test_object <- as.data.table(lightgbm.unify(lgbm_fifa, sparse_data)$model) 56 | #Check if the sums of children's covers are correct 57 | expect_equal(test_object[test_object[!is.na(test_object$Yes)][['Yes']]][['Cover']] + 58 | test_object[test_object[!is.na(test_object$No)][['No']]][['Cover']], test_object[!is.na(Feature)][['Cover']]) 59 | #check if default_left information is correctly used 60 | df_default_left <- lgbmtree[default_left == "TRUE", c('tree_index', 'split_index')] 61 | test_object_actual_default_left <- test_object[Yes == Missing, c('Tree', 'Node')] 62 | colnames(test_object_actual_default_left) <- c('tree_index', 'split_index') 63 | attr(test_object_actual_default_left, 'model') <- NULL 64 | expect_equal(test_object_actual_default_left[order(tree_index, split_index)], df_default_left[order(tree_index, split_index)]) 65 | #and default_left = FALSE analogically: 66 | df_default_right <- lgbmtree[default_left != 'TRUE', c('tree_index', 'split_index')] 67 | test_object_actual_default_right <- test_object[No == Missing, c('Tree', 'Node')] 68 | colnames(test_object_actual_default_right) <- c('tree_index', 'split_index') 69 | attr(test_object_actual_default_right, 'model') <- NULL 70 | expect_equal(test_object_actual_default_right[order(tree_index, split_index)], df_default_right[order(tree_index, split_index)]) 71 | #One more test with checking the usage of 'decision_type' column needed 72 | }) 73 | 74 | # Function that return the predictions for sample observations indicated by vector contatining values -1, 0, 1, where -1 means 75 | # going to the 
'Yes' Node, 1 - to the 'No' node and 0 - to the missing node. The vectors are randomly produced during executing 76 | # the function and should be passed to prepare_original_preds_ to save the conscistence. Later we can compare the 'predicted' values 77 | prepare_test_preds <- function(unify_out){ 78 | stopifnot(all(c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Prediction", "Cover") %in% colnames(unify_out))) 79 | test_tree <- unify_out[unify_out$Tree %in% 0:9,] 80 | test_tree[['node_row_id']] <- seq_len(nrow(test_tree)) 81 | test_obs <- lapply(table(test_tree$Tree), function(y) sample(c(-1, 0, 1), y, replace = T)) 82 | test_tree <- split(test_tree, test_tree$Tree) 83 | determine_val <- function(obs, tree){ 84 | root_id <- tree[['node_row_id']][1] 85 | tree[,c('Yes', 'No', 'Missing')] <- tree[,c('Yes', 'No', 'Missing')] - root_id + 1 86 | i <- 1 87 | indx <- 1 88 | while(!is.na(tree$Feature[indx])) { 89 | indx <- ifelse(obs[i] == 0, tree$Missing[indx], ifelse(obs[i] < 0, tree$Yes[indx], tree$No[indx])) 90 | #if(length(is.na(tree$Feature[indx]))>1) {print(paste(indx, i)); print(tree); print(obs)} 91 | i <- i + 1 92 | } 93 | return(tree[['Prediction']][indx]) 94 | } 95 | x = numeric() 96 | for(i in seq_along(test_obs)) { 97 | x[i] <- determine_val(test_obs[[i]], test_tree[[i]]) 98 | 99 | } 100 | return(list(preds = x, test_obs = test_obs)) 101 | } 102 | 103 | prepare_original_preds_lgbm <- function(orig_tree, test_obs){ 104 | test_tree <- orig_tree[orig_tree$tree_index %in% 0:9,] 105 | test_tree <- split(test_tree, test_tree$tree_index) 106 | stopifnot(length(test_tree) == length(test_obs)) 107 | determine_val <- function(obs, tree){ 108 | i <- 1 109 | indx <- 1 110 | while(!is.na(tree$split_feature[indx])) { 111 | children <- ifelse(is.na(tree$node_parent), tree$leaf_parent, tree$node_parent) 112 | if((obs[i] < 0) | (tree$default_left[indx] == 'TRUE' & obs[i] == 0)){ 113 | indx <- which(tree$split_index[indx] == children)[1] 114 | } 115 | else if((obs[i] > 0) | (tree$default_left[indx] == 'FALSE' & obs[i] == 0)) { 116 | indx <- which(tree$split_index[indx] == children)[2] 117 | } 118 | else{ 119 | stop('Error in the connections') 120 | indx <- 0 121 | } 122 | i <- i + 1 123 | } 124 | return(tree[['leaf_value']][indx]) 125 | } 126 | y = numeric() 127 | for(i in seq_along(test_obs)) { 128 | y[i] <- determine_val(test_obs[[i]], test_tree[[i]]) 129 | } 130 | return(y) 131 | } 132 | 133 | test_that('the connections between the nodes are correct', { 134 | # The test is passed only if the predictions for sample observations are equal in the first 10 trees of the ensemble 135 | x <- prepare_test_preds(lightgbm.unify(lgbm_fifa, sparse_data)$model) 136 | preds <- x[['preds']] 137 | test_obs <- x[['test_obs']] 138 | original_preds <- prepare_original_preds_lgbm(lgbmtree, test_obs) 139 | expect_equal(preds, original_preds) 140 | }) 141 | 142 | test_that("LightGBM: predictions from unified == original predictions", { 143 | unifier <- lightgbm.unify(lgbm_fifa, sparse_data) 144 | obs <- c(1:16000) 145 | original <- stats::predict(lgbm_fifa, sparse_data[obs, ]) 146 | from_unified <- predict(unifier, sparse_data[obs, ]) 147 | expect_equal(from_unified, original) 148 | #expect_true(all(abs((from_unified - original) / original) < 10**(-14))) #not needed 149 | }) 150 | 151 | test_that("LightGBM: mean prediction calculated using predict == using covers", { 152 | unifier <- lightgbm.unify(lgbm_fifa, sparse_data) 153 | 154 | intercept_predict <- mean(predict(unifier, sparse_data)) 155 | 156 | 
ntrees <- sum(unifier$model$Node == 0) 157 | leaves <- unifier$model[is.na(unifier$model$Feature), ] 158 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 159 | 160 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 161 | expect_equal(intercept_predict, intercept_covers) 162 | }) 163 | 164 | test_that("LightGBM: covers correctness", { 165 | unifier <- lightgbm.unify(lgbm_fifa, sparse_data) 166 | 167 | roots <- unifier$model[unifier$model$Node == 0, ] 168 | expect_true(all(roots$Cover == nrow(sparse_data))) 169 | 170 | internals <- unifier$model[!is.na(unifier$model$Feature), ] 171 | yes_child_cover <- unifier$model[internals$Yes, ]$Cover 172 | no_child_cover <- unifier$model[internals$No, ]$Cover 173 | if (all(is.na(internals$Missing))) { 174 | children_cover <- yes_child_cover + no_child_cover 175 | } else { 176 | missing_child_cover <- unifier$model[internals$Missing, ]$Cover 177 | missing_child_cover[is.na(missing_child_cover)] <- 0 178 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 179 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 180 | } 181 | expect_true(all(internals$Cover == children_cover)) 182 | }) 183 | -------------------------------------------------------------------------------- /tests/testthat/test_randomForest.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | 3 | data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 4 | c('value_eur', 'gk_diving', 'gk_handling', 5 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 6 | x <- na.omit(cbind(data_fifa, target = fifa20$target)) 7 | 8 | rf_with_cat_model <- randomForest::randomForest(target~., data = x, maxnodes = 10, ntree = 10) 9 | 10 | x <- x[colnames(x) != 'work_rate'] 11 | 12 | 13 | rf_num_model <- randomForest::randomForest(target~., data = x, maxnodes = 10, ntree = 10) 14 | 15 | 16 | test_that('randomForest.unify creates an object of the appropriate class', { 17 | expect_true(is.model_unified(randomForest.unify(rf_num_model, x))) 18 | expect_true(is.model_unified(unify(rf_num_model, x))) 19 | }) 20 | 21 | test_that('randomForest.unify returns an object with correct attributes', { 22 | unified_model <- randomForest.unify(rf_num_model, x) 23 | 24 | expect_equal(attr(unified_model, "missing_support"), FALSE) 25 | expect_equal(attr(unified_model, "model"), "randomForest") 26 | }) 27 | 28 | test_that('the randomForest.unify function returns data frame with columns of appropriate column', { 29 | unifier <- randomForest.unify(rf_num_model, x)$model 30 | expect_true(is.integer(unifier$Tree)) 31 | expect_true(is.integer(unifier$Node)) 32 | expect_true(is.character(unifier$Feature)) 33 | expect_true(is.factor(unifier$Decision.type)) 34 | expect_true(is.numeric(unifier$Split)) 35 | expect_true(is.integer(unifier$Yes)) 36 | expect_true(is.integer(unifier$No)) 37 | expect_true(is.integer(unifier$Missing)) 38 | expect_true(is.numeric(unifier$Prediction)) 39 | expect_true(is.numeric(unifier$Cover)) 40 | }) 41 | 42 | test_that("shap calculates without an error", { 43 | unifier <- randomForest.unify(rf_num_model, x) 44 | expect_error(treeshap(unifier, x[1:3,], verbose = FALSE), NA) 45 | }) 46 | 47 | test_that("randomForest: predictions from unified == original predictions", { 48 | unifier <- randomForest.unify(rf_num_model, x) 49 | obs <- x[1:16000, ] 50 | original <- stats::predict(rf_num_model, obs) 51 | 
names(original) <- NULL 52 | from_unified <- predict(unifier, obs) 53 | expect_true(all(abs((from_unified - original) / original) < 10**(-14))) 54 | }) 55 | 56 | test_that("randomForest: mean prediction calculated using predict == using covers", { 57 | unifier <- randomForest.unify(rf_num_model, x) 58 | 59 | intercept_predict <- mean(predict(unifier, x)) 60 | 61 | ntrees <- sum(unifier$model$Node == 0) 62 | leaves <- unifier$model[is.na(unifier$model$Feature), ] 63 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 64 | 65 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 66 | expect_equal(intercept_predict, intercept_covers) 67 | }) 68 | 69 | test_that("randomForest: covers correctness", { 70 | unifier <- randomForest.unify(rf_num_model, x) 71 | 72 | roots <- unifier$model[unifier$model$Node == 0, ] 73 | expect_true(all(roots$Cover == nrow(x))) 74 | 75 | internals <- unifier$model[!is.na(unifier$model$Feature), ] 76 | yes_child_cover <- unifier$model[internals$Yes, ]$Cover 77 | no_child_cover <- unifier$model[internals$No, ]$Cover 78 | if (all(is.na(internals$Missing))) { 79 | children_cover <- yes_child_cover + no_child_cover 80 | } else { 81 | missing_child_cover <- unifier$model[internals$Missing, ]$Cover 82 | missing_child_cover[is.na(missing_child_cover)] <- 0 83 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 84 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 85 | } 86 | expect_true(all(internals$Cover == children_cover)) 87 | }) 88 | -------------------------------------------------------------------------------- /tests/testthat/test_ranger.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | 3 | data_fifa <- fifa20$data[!colnames(fifa20$data) %in% 4 | c('value_eur', 'gk_diving', 'gk_handling', 5 | 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')] 6 | x <- na.omit(cbind(data_fifa, target = fifa20$target)) 7 | 8 | ranger_with_cat_model <- ranger::ranger(target ~ ., data = x, max.depth = 10, num.trees = 10) 9 | 10 | x <- x[colnames(x) != 'work_rate'] 11 | 12 | 13 | ranger_num_model <- ranger::ranger(target ~ ., data = x, max.depth = 10, num.trees = 10) 14 | 15 | 16 | test_that('ranger.unify creates an object of the appropriate class', { 17 | expect_true(is.model_unified(ranger.unify(ranger_num_model, x))) 18 | expect_true(is.model_unified(unify(ranger_num_model, x))) 19 | }) 20 | 21 | test_that('ranger.unify returns an object with correct attributes', { 22 | unified_model <- ranger.unify(ranger_num_model, x) 23 | 24 | expect_equal(attr(unified_model, "missing_support"), FALSE) 25 | expect_equal(attr(unified_model, "model"), "ranger") 26 | }) 27 | 28 | test_that('the ranger.unify function returns data frame with columns of appropriate column', { 29 | unifier <- ranger.unify(ranger_num_model, x)$model 30 | expect_true(is.integer(unifier$Tree)) 31 | expect_true(is.integer(unifier$Node)) 32 | expect_true(is.character(unifier$Feature)) 33 | expect_true(is.factor(unifier$Decision.type)) 34 | expect_true(is.numeric(unifier$Split)) 35 | expect_true(is.integer(unifier$Yes)) 36 | expect_true(is.integer(unifier$No)) 37 | expect_true(all(is.na(unifier$Missing))) 38 | expect_true(is.numeric(unifier$Prediction)) 39 | expect_true(is.numeric(unifier$Cover)) 40 | }) 41 | 42 | test_that("shap calculates without an error", { 43 | unifier <- ranger.unify(ranger_num_model, x) 44 
| expect_error(treeshap(unifier, x[1:3,], verbose = FALSE), NA) 45 | }) 46 | 47 | test_that("ranger: predictions from unified == original predictions", { 48 | unifier <- ranger.unify(ranger_num_model, x) 49 | obs <- x[1:16000, ] 50 | original <- ranger::predictions(stats::predict(ranger_num_model, obs)) 51 | from_unified <- predict(unifier, obs) 52 | expect_true(all(abs((from_unified - original) / original) < 10**(-14))) 53 | }) 54 | 55 | test_that("ranger: mean prediction calculated using predict == using covers", { 56 | unifier <- ranger.unify(ranger_num_model, x) 57 | 58 | intercept_predict <- mean(predict(unifier, x)) 59 | 60 | ntrees <- sum(unifier$model$Node == 0) 61 | leaves <- unifier$model[is.na(unifier$model$Feature), ] 62 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 63 | 64 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 65 | expect_equal(intercept_predict, intercept_covers) 66 | }) 67 | 68 | test_that("ranger: covers correctness", { 69 | unifier <- ranger.unify(ranger_num_model, x) 70 | 71 | roots <- unifier$model[unifier$model$Node == 0, ] 72 | expect_true(all(roots$Cover == nrow(x))) 73 | 74 | internals <- unifier$model[!is.na(unifier$model$Feature), ] 75 | yes_child_cover <- unifier$model[internals$Yes, ]$Cover 76 | no_child_cover <- unifier$model[internals$No, ]$Cover 77 | if (all(is.na(internals$Missing))) { 78 | children_cover <- yes_child_cover + no_child_cover 79 | } else { 80 | missing_child_cover <- unifier$model[internals$Missing, ]$Cover 81 | missing_child_cover[is.na(missing_child_cover)] <- 0 82 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 83 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 84 | } 85 | expect_true(all(internals$Cover == children_cover)) 86 | }) 87 | -------------------------------------------------------------------------------- /tests/testthat/test_ranger_surv.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | 3 | data_colon <- data.table::data.table(survival::colon) 4 | data_colon <- na.omit(data_colon[get("etype") == 2, ]) 5 | surv_cols <- c("status", "time", "rx") 6 | 7 | feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)] 8 | 9 | x <- model.matrix( 10 | ~ -1 + ., 11 | data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])] 12 | ) 13 | y <- survival::Surv( 14 | event = (data_colon[, get("status")] |> 15 | as.character() |> 16 | as.integer()), 17 | time = data_colon[, get("time")], 18 | type = "right" 19 | ) 20 | 21 | set.seed(123) 22 | ranger_num_model <- ranger::ranger( 23 | x = x, 24 | y = y, 25 | data = data_colon, 26 | max.depth = 10, 27 | num.trees = 10 28 | ) 29 | 30 | 31 | 32 | # to save some time for these tests, compute model here once: 33 | unified_model <- ranger_surv.unify(ranger_num_model, x) 34 | unified_model2 <- unify(ranger_num_model, x) 35 | 36 | 37 | test_that('ranger_surv.unify creates an object of the appropriate class', { 38 | expect_true(is.model_unified(unified_model)) 39 | expect_true(is.model_unified(unified_model2)) 40 | }) 41 | 42 | test_that('ranger_surv.unify returns an object with correct attributes', { 43 | expect_equal(attr(unified_model, "missing_support"), FALSE) 44 | expect_equal(attr(unified_model, "model"), "ranger") 45 | }) 46 | 47 | test_that('the ranger_surv.unify function returns data frame with columns of appropriate column', { 48 | unifier <- 
unified_model$model 49 | expect_true(is.integer(unifier$Tree)) 50 | expect_true(is.integer(unifier$Node)) 51 | expect_true(is.character(unifier$Feature)) 52 | expect_true(is.factor(unifier$Decision.type)) 53 | expect_true(is.numeric(unifier$Split)) 54 | expect_true(is.integer(unifier$Yes)) 55 | expect_true(is.integer(unifier$No)) 56 | expect_true(all(is.na(unifier$Missing))) 57 | expect_true(is.numeric(unifier$Prediction)) 58 | expect_true(is.numeric(unifier$Cover)) 59 | }) 60 | 61 | test_that("ranger_surv: shap calculates without an error", { 62 | expect_error(treeshap(unified_model, x[1:3,], verbose = FALSE), NA) 63 | }) 64 | 65 | test_that("ranger_surv: predictions from unified == original predictions", { 66 | obs <- x[1:800, ] 67 | surv_preds <- stats::predict(ranger_num_model, obs) 68 | original <- rowSums(surv_preds$chf) 69 | from_unified <- predict(unified_model, obs) 70 | expect_true(all(abs((from_unified - original) / original) < 10**(-13))) 71 | }) 72 | 73 | test_that("ranger_surv: mean prediction calculated using predict == using covers", { 74 | 75 | intercept_predict <- mean(predict(unified_model, x)) 76 | 77 | ntrees <- sum(unified_model$model$Node == 0) 78 | leaves <- unified_model$model[is.na(unified_model$model$Feature), ] 79 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 80 | 81 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 82 | expect_equal(intercept_predict, intercept_covers) 83 | }) 84 | 85 | test_that("ranger_surv: covers correctness", { 86 | 87 | roots <- unified_model$model[unified_model$model$Node == 0, ] 88 | expect_true(all(roots$Cover == nrow(x))) 89 | 90 | internals <- unified_model$model[!is.na(unified_model$model$Feature), ] 91 | yes_child_cover <- unified_model$model[internals$Yes, ]$Cover 92 | no_child_cover <- unified_model$model[internals$No, ]$Cover 93 | if (all(is.na(internals$Missing))) { 94 | children_cover <- yes_child_cover + no_child_cover 95 | } else { 96 | missing_child_cover <- unified_model$model[internals$Missing, ]$Cover 97 | missing_child_cover[is.na(missing_child_cover)] <- 0 98 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 99 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 100 | } 101 | expect_true(all(internals$Cover == children_cover)) 102 | }) 103 | 104 | 105 | # tests for ranger_surv.unify (type = "survival") 106 | # to save some time for these tests, compute model here once: 107 | unified_model <- ranger_surv.unify(ranger_num_model, x, type = "survival", times = c(10, 50, 100)) 108 | unified_model2 <- unify(ranger_num_model, x, type = "survival", times = c(10, 50, 100)) 109 | 110 | 111 | test_that('ranger_surv.unify (type = "survival") list names == unique.death.times', { 112 | expect_equal(names(unified_model), as.character(c(10, 50, 100))) 113 | expect_equal(names(unified_model2), as.character(c(10, 50, 100))) 114 | }) 115 | 116 | test_that('ranger_surv.unify (type = "survival") creates an object of the appropriate class', { 117 | expect_s3_class(unified_model, "model_unified_multioutput") 118 | expect_s3_class(unified_model2, "model_unified_multioutput") 119 | lapply(unified_model, function(m) expect_true(is.model_unified(m))) 120 | lapply(unified_model2, function(m) expect_true(is.model_unified(m))) 121 | }) 122 | 123 | test_that('ranger_surv.unify (type = "survival") returns an object with correct attributes', { 124 | m <- unified_model[[1]] 125 | 
expect_equal(attr(m, "missing_support"), FALSE) 126 | expect_equal(attr(m, "model"), "ranger") 127 | }) 128 | 129 | test_that('the ranger_surv.unify (type = "survival") function returns data frame with columns of appropriate column', { 130 | unifier <- unified_model[[1]]$model 131 | expect_true(is.integer(unifier$Tree)) 132 | expect_true(is.integer(unifier$Node)) 133 | expect_true(is.character(unifier$Feature)) 134 | expect_true(is.factor(unifier$Decision.type)) 135 | expect_true(is.numeric(unifier$Split)) 136 | expect_true(is.integer(unifier$Yes)) 137 | expect_true(is.integer(unifier$No)) 138 | expect_true(all(is.na(unifier$Missing))) 139 | expect_true(is.numeric(unifier$Prediction)) 140 | expect_true(is.numeric(unifier$Cover)) 141 | }) 142 | 143 | test_that('ranger_surv.unify (type = "survival"): shap calculates without an error', { 144 | expect_error(treeshap(unified_model[[1]], x[1:3,], verbose = FALSE), NA) 145 | }) 146 | 147 | test_that('ranger_surv.unify (type = "survival"): predictions from unified == original predictions', { 148 | for (t in names(unified_model)) { 149 | m <- unified_model[[t]] 150 | death_time <- as.integer(t) 151 | obs <- x[1:800, ] 152 | surv_preds <- stats::predict(ranger_num_model, obs) 153 | original <- surv_preds$survival[, which(surv_preds$unique.death.times == death_time)] 154 | from_unified <- predict(m, obs) 155 | # this is yet kind of strange that values differ so much 156 | expect_true(all(abs((from_unified - original) / original) < 8e-1)) 157 | } 158 | }) 159 | 160 | test_that('ranger_surv.unify (type = "survival"): mean prediction calculated using predict == using covers', { 161 | m <- unified_model[[1]] 162 | intercept_predict <- mean(predict(m, x)) 163 | 164 | ntrees <- sum(m$model$Node == 0) 165 | leaves <- m$model[is.na(m$model$Feature), ] 166 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 167 | 168 | #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14))) 169 | expect_equal(intercept_predict, intercept_covers) 170 | }) 171 | 172 | test_that('ranger_surv.unify (type = "survival"): covers correctness', { 173 | for (m in unified_model) { 174 | roots <- m$model[m$model$Node == 0, ] 175 | expect_true(all(roots$Cover == nrow(x))) 176 | 177 | internals <- m$model[!is.na(m$model$Feature), ] 178 | yes_child_cover <- m$model[internals$Yes, ]$Cover 179 | no_child_cover <- m$model[internals$No, ]$Cover 180 | if (all(is.na(internals$Missing))) { 181 | children_cover <- yes_child_cover + no_child_cover 182 | } else { 183 | missing_child_cover <- m$model[internals$Missing, ]$Cover 184 | missing_child_cover[is.na(missing_child_cover)] <- 0 185 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 186 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 187 | } 188 | expect_true(all(internals$Cover == children_cover)) 189 | } 190 | }) 191 | -------------------------------------------------------------------------------- /tests/testthat/test_set_reference_dataset.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | library(xgboost) 3 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 4 | target <- fifa20$target 5 | 6 | test_that('recalculate covers works correctly for xgboost model', { 7 | param <- list(objective = "reg:squarederror", max_depth = 5) 8 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 100, verbose 
= FALSE) 9 | unified <- xgboost.unify(xgb_model, as.matrix(data)) 10 | a <- set_reference_dataset(unified, data)$Cover 11 | b <- unified$Cover 12 | expect_true(all(a == b)) 13 | }) 14 | -------------------------------------------------------------------------------- /tests/testthat/test_xgboost_unify.R: -------------------------------------------------------------------------------- 1 | library(treeshap) 2 | data <- fifa20$data[colnames(fifa20$data) != 'work_rate'] 3 | target <- fifa20$target 4 | param <- list(objective = "reg:squarederror", max_depth = 3) 5 | xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200, verbose = 0) 6 | xgb_tree <- xgboost::xgb.model.dt.tree(model = xgb_model) 7 | 8 | 9 | test_that('xgboost.unify returns an object of appropriate class', { 10 | expect_true(is.model_unified(xgboost.unify(xgb_model, as.matrix(data)))) 11 | expect_true(is.model_unified(unify(xgb_model, as.matrix(data)))) 12 | }) 13 | 14 | test_that('xgboost.unify returns an object with correct attributes', { 15 | unified_model <- xgboost.unify(xgb_model, as.matrix(data)) 16 | 17 | expect_equal(attr(unified_model, "missing_support"), TRUE) 18 | expect_equal(attr(unified_model, "model"), "xgboost") 19 | }) 20 | 21 | test_that('columns after xgboost.unify are of appropriate type', { 22 | unified_model <- xgboost.unify(xgb_model, as.matrix(data))$model 23 | 24 | expect_true(is.integer(unified_model$Tree)) 25 | expect_true(is.integer(unified_model$Node)) 26 | expect_true(is.character(unified_model$Feature)) 27 | expect_true(is.factor(unified_model$Decision.type)) 28 | expect_true(is.numeric(unified_model$Split)) 29 | expect_true(is.integer(unified_model$Yes)) 30 | expect_true(is.integer(unified_model$No)) 31 | expect_true(is.integer(unified_model$Missing)) 32 | expect_true(is.numeric(unified_model$Prediction)) 33 | expect_true(is.numeric(unified_model$Cover)) 34 | }) 35 | 36 | test_that('values in the columns after xgboost.unify are correct', { 37 | unified_model <- xgboost.unify(xgb_model, as.matrix(data))$model 38 | 39 | expect_equal(xgb_tree$Tree, unified_model$Tree) 40 | expect_equal(xgb_tree$Node, unified_model$Node) 41 | expect_equal(xgb_tree$Cover, unified_model$Cover) 42 | expect_equal(xgb_tree$Quality[xgb_tree$Feature == 'Leaf'], unified_model$Prediction[is.na(unified_model$Feature)]) 43 | expect_equal(xgb_tree$Node, unified_model$Node) 44 | expect_equal(xgb_tree$Split, unified_model$Split) 45 | expect_equal(match(xgb_tree$Yes, xgb_tree$ID), unified_model$Yes) 46 | expect_equal(match(xgb_tree$No, xgb_tree$ID), unified_model$No) 47 | expect_equal(match(xgb_tree$Missing, xgb_tree$ID), unified_model$Missing) 48 | expect_equal(xgb_tree[xgb_tree[['Feature']] != 'Leaf',][['Feature']], 49 | unified_model[!is.na(unified_model$Feature),][['Feature']]) 50 | expect_equal(nrow(xgb_tree[xgb_tree[['Feature']] == 'Leaf',]), 51 | nrow(unified_model[is.na(unified_model$Feature),])) 52 | 53 | }) 54 | 55 | 56 | 57 | test_that('xgboost.unify() does not work for objects produced with other packages', { 58 | param_lightgbm <- list(objective = "regression", 59 | max_depth = 2, 60 | num_leaves = 4L, 61 | force_row_wise = TRUE, 62 | learning.rate = 0.1) 63 | expect_warning({lgbm_fifa <- lightgbm::lightgbm(data = as.matrix(fifa20$data[colnames(fifa20$data) != 'value_eur']), 64 | label = fifa20$target, 65 | params = param_lightgbm, 66 | verbose = -1, 67 | num_threads = 0) 68 | lgbmtree <- lightgbm::lgb.model.dt.tree(lgbm_fifa)}) 69 | expect_error(xgboost.unify(lgbmtree)) 70 | }) 71 | 
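# A minimal usage sketch of the xgboost workflow covered by this file
# (illustrative only; the `example_*` objects are new names, and treeshap()
# is assumed to accept a unified model plus a feature data frame, as in the
# other test files): unify the booster defined above, then compute SHAP
# values for a few rows of `data`.
example_unified_xgb <- xgboost.unify(xgb_model, as.matrix(data))
example_shaps_xgb <- treeshap(example_unified_xgb, data[1:3, ], verbose = FALSE)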
72 | # Function that return the predictions for sample observations indicated by vector contatining values -1, 0, 1, where -1 means 73 | # going to the 'Yes' Node, 1 - to the 'No' node and 0 - to the missing node. The vectors are randomly produced during executing 74 | # the function and should be passed to prepare_original_preds_ to save the conscistence. Later we can compare the 'predicted' values 75 | prepare_test_preds <- function(unify_out){ 76 | stopifnot(all(c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Prediction", "Cover") %in% colnames(unify_out))) 77 | test_tree <- unify_out[unify_out$Tree %in% 0:9, ] 78 | test_tree[['node_row_id']] <- seq_len(nrow(test_tree)) 79 | test_obs <- lapply(table(test_tree$Tree), function(y) sample(c(-1, 0, 1), y, replace = T)) 80 | test_tree <- split(test_tree, test_tree$Tree) 81 | determine_val <- function(obs, tree){ 82 | root_id <- tree[['node_row_id']][1] 83 | tree[,c('Yes', 'No', 'Missing')] <- tree[,c('Yes', 'No', 'Missing')] - root_id + 1 84 | i <- 1 85 | indx <- 1 86 | while(!is.na(tree$Feature[indx])) { 87 | indx <- ifelse(obs[i] == 0, tree$Missing[indx], ifelse(obs[i] < 0, tree$Yes[indx], tree$No[indx])) 88 | i <- i + 1 89 | } 90 | return(tree[['Prediction']][indx]) 91 | } 92 | x = numeric() 93 | for(i in seq_along(test_obs)) { 94 | x[i] <- determine_val(test_obs[[i]], test_tree[[i]]) 95 | 96 | } 97 | return(list(preds = x, test_obs = test_obs)) 98 | } 99 | 100 | prepare_original_preds_xgb <- function(orig_tree, test_obs){ 101 | test_tree <- orig_tree[orig_tree$Tree %in% 0:9, ] 102 | test_tree <- split(test_tree, test_tree$Tree) 103 | stopifnot(length(test_tree) == length(test_obs)) 104 | determine_val <- function(obs, tree){ 105 | i <- 1 106 | indx <- 1 107 | while(!is.na(tree$Split[indx])) { 108 | indx <- ifelse(obs[i] == 0, match(tree$Missing[indx], tree$ID), ifelse(obs[i] < 0, match(tree$Yes[indx], tree$ID), 109 | match(tree$No[indx], tree$ID))) 110 | 111 | i <- i + 1 112 | } 113 | return(tree[['Quality']][indx]) 114 | } 115 | y = numeric() 116 | for(i in seq_along(test_obs)) { 117 | y[i] <- determine_val(test_obs[[i]], test_tree[[i]]) 118 | } 119 | return(y) 120 | } 121 | 122 | test_that('the connections between the nodes are correct', { 123 | # The test is passed only if the predictions for sample observations are equal in the first 10 trees of the ensemble 124 | x <- prepare_test_preds(xgboost.unify(xgb_model, as.matrix(data))$model) 125 | preds <- x[['preds']] 126 | test_obs <- x[['test_obs']] 127 | original_preds <- prepare_original_preds_xgb(xgb_tree, test_obs) 128 | expect_equal(preds, original_preds) 129 | }) 130 | 131 | test_that("xgboost: predictions from unified == original predictions", { 132 | unifier <- xgboost.unify(xgb_model, data) 133 | obs <- data[1:16000, ] 134 | original <- stats::predict(xgb_model, as.matrix(obs)) 135 | from_unified <- predict(unifier, obs) 136 | # expect_equal(from_unified, original) #there are small differences 137 | expect_true(all(abs((from_unified - original) / original) < 5 * 10**(-3))) 138 | }) 139 | 140 | test_that("xgboost: mean prediction calculated using predict == using covers", { 141 | unifier <- xgboost.unify(xgb_model, data) 142 | 143 | intercept_predict <- mean(predict(unifier, data)) 144 | 145 | ntrees <- sum(unifier$model$Node == 0) 146 | leaves <- unifier$model[is.na(unifier$model$Feature), ] 147 | intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees 148 | 149 | #expect_true(all(abs((intercept_predict - intercept_covers) / 
intercept_predict) < 10**(-14))) 150 | expect_equal(intercept_predict, intercept_covers) 151 | }) 152 | 153 | test_that("xgboost: covers correctness", { 154 | unifier <- xgboost.unify(xgb_model, data) 155 | 156 | roots <- unifier$model[unifier$model$Node == 0, ] 157 | expect_true(all(roots$Cover == nrow(data))) 158 | 159 | internals <- unifier$model[!is.na(unifier$model$Feature), ] 160 | yes_child_cover <- unifier$model[internals$Yes, ]$Cover 161 | no_child_cover <- unifier$model[internals$No, ]$Cover 162 | if (all(is.na(internals$Missing))) { 163 | children_cover <- yes_child_cover + no_child_cover 164 | } else { 165 | missing_child_cover <- unifier$model[internals$Missing, ]$Cover 166 | missing_child_cover[is.na(missing_child_cover)] <- 0 167 | missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0 168 | children_cover <- yes_child_cover + no_child_cover + missing_child_cover 169 | } 170 | expect_true(all(internals$Cover == children_cover)) 171 | }) 172 | -------------------------------------------------------------------------------- /treeshap.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: No 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | LineEndingConversion: Posix 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | --------------------------------------------------------------------------------