├── .Rbuildignore ├── .gitignore ├── .travis.yml ├── DESCRIPTION ├── NAMESPACE ├── R ├── buildExplainer.R ├── buildExplainerFromTree.R ├── explainPredictions.R ├── findPath.R ├── getStatsForTrees.R ├── getTreeBreakdown.R └── showWaterfall.R ├── README.md ├── example └── example.R ├── lightgbmExplainer.Rproj └── man ├── buildExplainer.Rd ├── explainPredictions.Rd └── showWaterfall.Rd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # History files 2 | .Rhistory 3 | .Rapp.history 4 | 5 | # Session Data files 6 | .RData 7 | 8 | # Example code in package build process 9 | *-Ex.R 10 | 11 | # Output files from R CMD build 12 | /*.tar.gz 13 | 14 | # Output files from R CMD check 15 | /*.Rcheck/ 16 | 17 | # RStudio files 18 | .Rproj.user/ 19 | 20 | # produced vignettes 21 | vignettes/*.html 22 | vignettes/*.pdf 23 | 24 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 25 | .httr-oauth 26 | 27 | # knitr and R markdown default cache directories 28 | /*_cache/ 29 | /cache/ 30 | 31 | # Temporary files created by R markdown 32 | *.utf8.md 33 | *.knit.md 34 | .Rproj.user 35 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: r 2 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: lightgbmExplainer 2 | Type: Package 3 | Title: An R package that makes lightgbm models fully interpretable 4 | Version: 0.2.0 5 | Author: Albert cheng 6 | Maintainer: Albert cheng 7 | Description: An R package that makes lightgbm models fully 
interpretable 8 | Depends: R (>= 3.4.0) 9 | Imports: data.table, waterfalls, scales, ggplot2, purrr 10 | Suggests: lightgbm (>= 2.1.0) 11 | License: GPL-3 12 | Encoding: UTF-8 13 | LazyData: true 14 | RoxygenNote: 6.0.1 15 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(buildExplainer) 4 | export(explainPredictions) 5 | export(showWaterfall) 6 | import(data.table) 7 | import(ggplot2) 8 | import(lightgbm) 9 | import(scales) 10 | import(waterfalls) 11 | importFrom(purrr,map) 12 | importFrom(purrr,walk) 13 | importFrom(purrr,walk2) 14 | -------------------------------------------------------------------------------- /R/buildExplainer.R: -------------------------------------------------------------------------------- 1 | #' Step 1: Build an lightgbmExplainer 2 | #' 3 | #' This function outputs an lightgbmExplainer (a data table that stores the feature impact breakdown for each leaf of each tree in an lightgbm model). It is required as input into the explainPredictions and showWaterfall functions. 4 | #' @param lgb_tree A lightgbm.dt.tree 5 | #' @return The lightgbm Explainer for the model. This is a data table where each row is a leaf of a tree in the lightgbm model 6 | #' and each column is the impact of each feature on the prediction at the leaf. 7 | #' 8 | #' The leaf and tree columns uniquely identify the node. 9 | #' 10 | #' The sum of the other columns equals the prediction at the leaf (log-odds if binary response). 11 | #' 12 | #' The 'intercept' column is identical for all rows and is analogous to the intercept term in a linear / logistic regression. 
13 | #' 14 | #' @export 15 | #' @import data.table 16 | #' @import lightgbm 17 | #' @examples 18 | #' library(lightgbm) # v2.1.0 or above 19 | #' library(lightgbmExplainer) 20 | #' 21 | #' # Load Data 22 | #' data(agaricus.train, package = "lightgbm") 23 | #' # Train a model 24 | #' lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 25 | #' lgb.params <- list(objective = "binary") 26 | #' lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 27 | #' # Build Explainer 28 | #' lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 29 | #' explainer <- buildExplainer(lgb.trees) 30 | #' # compute contribution for each data point 31 | #' pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 32 | #' # Show waterfall for the 8th observation 33 | #' showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 34 | 35 | 36 | buildExplainer = function(lgb_tree){ 37 | 38 | # TODO - Add test case for lgb.dt.tree order 39 | 40 | cat('\nBuilding the Explainer...') 41 | cat('\nSTEP 1 of 2') 42 | lgb_tree_with_stat = getStatsForTrees(lgb_tree) 43 | cat('\n\nSTEP 2 of 2') 44 | explainer = buildExplainerFromTree(lgb_tree_with_stat) 45 | 46 | cat('\n\nDONE!\n') 47 | 48 | return (explainer) 49 | } 50 | -------------------------------------------------------------------------------- /R/buildExplainerFromTree.R: -------------------------------------------------------------------------------- 1 | 2 | #' @import data.table 3 | #' @import lightgbm 4 | 5 | buildExplainerFromTree = function(lgb_tree_with_stat){ 6 | 7 | ####accepts a list of trees and column names 8 | ####outputs a data table, of the impact of each variable + intercept, for each leaf 9 | col_names <- purrr::discard(unique(lgb_tree_with_stat$split_feature), is.na) 10 | 11 | lgb_tree_with_stat_breakdown <- 12 | setNames(data.table(matrix(nrow = 0, ncol = length(col_names) + 3)), 13 | c(col_names,'intercept', 'leaf','tree')) 14 | 15 | num_trees 
= length(unique(lgb_tree_with_stat$tree_index)) 16 | 17 | cat('\n\nGetting breakdown for each leaf of each tree...\n') 18 | pb <- txtProgressBar(style=3) 19 | 20 | for (x in 0:(num_trees-1)){ 21 | tree = lgb_tree_with_stat[tree_index == x] 22 | tree_breakdown = getTreeBreakdown(tree, col_names) 23 | tree_breakdown$tree = x 24 | lgb_tree_with_stat_breakdown = rbindlist(append(list(lgb_tree_with_stat_breakdown),list(tree_breakdown))) 25 | setTxtProgressBar(pb, (x+1) / num_trees) 26 | } 27 | # replace NA with 0 28 | lgb_tree_with_stat_breakdown[is.na(lgb_tree_with_stat_breakdown)] <- 0 29 | return (lgb_tree_with_stat_breakdown) 30 | 31 | } 32 | -------------------------------------------------------------------------------- /R/explainPredictions.R: -------------------------------------------------------------------------------- 1 | #' Step 2: Get multiple prediction breakdowns from a trained lightgbm model 2 | #' 3 | #' This function outputs the feature impact breakdown of a set of predictions made using an lightgbm model. 4 | #' @param lgb.model A trained lightgbm model 5 | #' @param explainer The output from the buildExplainer function, for this model 6 | #' @param data A DMatrix of data to be explained 7 | #' @return A data table where each row is an observation in the data and each column is the impact of each feature on the prediction. 8 | #' 9 | #' The sum of the row equals the prediction of the lightgbm model for this observation (log-odds if binary response). 
10 | #' 11 | #' @export 12 | #' @import data.table 13 | #' @import lightgbm 14 | #' @examples 15 | #' library(lightgbm) # v2.1.0 or above 16 | #' library(lightgbmExplainer) 17 | #' 18 | #' # Load Data 19 | #' data(agaricus.train, package = "lightgbm") 20 | #' # Train a model 21 | #' lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 22 | #' lgb.params <- list(objective = "binary") 23 | #' lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 24 | #' # Build Explainer 25 | #' lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 26 | #' explainer <- buildExplainer(lgb.trees) 27 | #' # compute contribution for each data point 28 | #' pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 29 | #' # Show waterfall for the 8th observation 30 | #' showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 31 | 32 | explainPredictions = function(lgb.model, explainer ,data){ 33 | 34 | #Accepts data table of the breakdown for each leaf of each tree and the node matrix 35 | #Returns the breakdown for each prediction as a data table 36 | 37 | nodes = predict(lgb.model,data,predleaf =TRUE) 38 | 39 | colnames = names(explainer)[1:(ncol(explainer)-2)] 40 | 41 | preds_breakdown = data.table(matrix(0,nrow = nrow(nodes), ncol = length(colnames))) 42 | setnames(preds_breakdown, colnames) 43 | 44 | num_trees = ncol(nodes) 45 | 46 | cat('\n\nExtracting the breakdown of each prediction...\n') 47 | pb <- txtProgressBar(style=3) 48 | for (x in 1:num_trees){ 49 | nodes_for_tree = nodes[,x] 50 | tree_breakdown = explainer[tree==x-1] 51 | 52 | preds_breakdown_for_tree = tree_breakdown[match(nodes_for_tree, tree_breakdown$leaf),] 53 | preds_breakdown = preds_breakdown + preds_breakdown_for_tree[,colnames,with=FALSE] 54 | 55 | setTxtProgressBar(pb, x / num_trees) 56 | } 57 | 58 | cat('\n\nDONE!\n') 59 | 60 | return (preds_breakdown) 61 | 62 | } 63 | 
-------------------------------------------------------------------------------- /R/findPath.R: -------------------------------------------------------------------------------- 1 | 2 | #' @import data.table 3 | #' @import lightgbm 4 | findPath = function(currentnode, index, parent, path = c()){ 5 | # print(currentnode) 6 | # print(index) 7 | # print(parent) 8 | #accepts a tree data table, and the node to reach 9 | #path is used in the recursive function - do not set this 10 | 11 | while(currentnode!=0){ 12 | path = c(currentnode, path) 13 | currentnode = parent[index==currentnode] 14 | } 15 | # print(c(0,path)) 16 | return(c(0,path)) 17 | 18 | } 19 | 20 | -------------------------------------------------------------------------------- /R/getStatsForTrees.R: -------------------------------------------------------------------------------- 1 | 2 | #' @import data.table 3 | #' @import lightgbm 4 | #' @importFrom purrr walk map walk2 5 | getStatsForTrees = function(lgb_tree){ 6 | #Accepts data table of tree (the output of lgb.model.dt.tree) 7 | #Returns a list of tree, with the stats filled in 8 | #weight equal internal_value in lgb tree 9 | #Assumption: the trees are ordered from left to right 10 | 11 | lgb_tree_with_stat = copy(lgb_tree) 12 | # using native index representation 13 | lgb_tree_with_stat[, index := ifelse(is.na(leaf_index), split_index, -leaf_index-1)] 14 | lgb_tree_with_stat[, parent := ifelse(is.na(node_parent), leaf_parent, node_parent)] 15 | lgb_tree_with_stat[, weight := ifelse(is.na(internal_value), leaf_value, internal_value)] 16 | 17 | # 0 - left child, 1 - right child 18 | # Must not sort at this step 19 | # TODO: handle default_left 20 | lgb_tree_with_stat[,child_type := 0:(.N-1L), by = .(tree_index, parent)] 21 | 22 | lgb_tree_with_stat <- merge(lgb_tree_with_stat, 23 | lgb_tree_with_stat[!is.na(split_index)][,.(tree_index, 24 | parent = split_index, 25 | previous_weight = weight, 26 | previous_feature = split_feature, 27 | previous_threshold 
= threshold, 28 | previous_decision_type = decision_type)], 29 | all.x = T, sort = F) 30 | lgb_tree_with_stat[,previous_decision_type := ifelse(child_type == 0, 31 | previous_decision_type, 32 | ifelse(previous_decision_type == "<=", 33 | ">", "!="))] 34 | lgb_tree_with_stat[,uplift_weight := weight - previous_weight] 35 | 36 | return (lgb_tree_with_stat) 37 | } 38 | -------------------------------------------------------------------------------- /R/getTreeBreakdown.R: -------------------------------------------------------------------------------- 1 | 2 | #' @import data.table 3 | #' @import lightgbm 4 | getTreeBreakdown = function(tree, col_names){ 5 | 6 | ####accepts a tree (data table), and column names 7 | ####outputs a data table, of the impact of each variable + intercept, for each leaf 8 | 9 | tree_breakdown <- 10 | setNames(data.table(matrix(nrow = 0, ncol = length(col_names) + 2)), 11 | c(col_names,'intercept', 'leaf')) 12 | 13 | temp <- copy(tree) 14 | temp[,path:=purrr::map(index, findPath, index, parent)] 15 | temp <- data.table(merge(tidyr::unnest(temp[index <0, .(leaf = -index-1, path)]), 16 | temp[, .(path = index, previous_feature,uplift_weight)], 17 | all.x = T, sort = F)) 18 | temp <- temp[!is.na(previous_feature), 19 | .(uplift_weight = sum(uplift_weight)), 20 | by =.(leaf, previous_feature)] 21 | temp <- dcast(temp, formula = leaf ~ previous_feature, 22 | value.var = "uplift_weight", fill = 0) 23 | 24 | tree_breakdown = rbindlist(list(tree_breakdown, temp), use.names = T, fill = TRUE) 25 | 26 | return (tree_breakdown) 27 | } 28 | -------------------------------------------------------------------------------- /R/showWaterfall.R: -------------------------------------------------------------------------------- 1 | #' Step 3: Get prediction breakdown and waterfall chart for a single row of data 2 | #' 3 | #' This function prints the feature impact breakdown for a single data row, and plots an accompanying waterfall chart. 
4 | #' @param lgb.model A trained lightgbm model
5 | #' @param explainer The output from the buildExplainer function, for this model
6 | #' @param lgb.dtrain The lgb.dtrain in which the row to be predicted is stored
7 | #' @param lgb.train.data The matrix of data from which the lgb.dtrain was built
8 | #' @param id The row number of the data to be explained
9 | #' @param type The objective function of the model - either "binary" (for binary:logistic) or "regression" (for reg:linear)
10 | #' @param threshold Default = 0.0001. The waterfall chart will group all variables with absolute impact less than the threshold into a variable called 'Other'
11 | #' @return None
12 | #' @export
13 | #' @import data.table
14 | #' @import lightgbm
15 | #' @import waterfalls
16 | #' @import scales
17 | #' @import ggplot2
18 | 
19 | showWaterfall = function(lgb.model, explainer, lgb.dtrain, lgb.train.data, id, type = "binary", threshold = 0.0001){
20 | 
21 | 
22 | breakdown = explainPredictions(lgb.model, explainer, lgb.train.data[id,,drop=FALSE])
23 | 
24 | weight = rowSums(breakdown)
25 | if (type == 'regression'){
26 | pred = weight
27 | }else{
28 | pred = 1/(1+exp(-weight))
29 | }
30 | 
31 | 
32 | breakdown_summary = as.matrix(breakdown)[1,]
33 | data_for_label = lgb.train.data[id,]
34 | 
35 | idx = order(abs(breakdown_summary),decreasing=TRUE)
36 | breakdown_summary = breakdown_summary[idx]
37 | data_for_label = data_for_label[idx]
38 | 
39 | intercept = breakdown_summary[names(breakdown_summary)=='intercept']
40 | data_for_label = data_for_label[names(breakdown_summary)!='intercept']
41 | breakdown_summary = breakdown_summary[names(breakdown_summary)!='intercept']
42 | 
43 | idx_other = which(abs(breakdown_summary) < threshold)
44 | # Default so the abs(other_impact) check below is safe when no
45 | # variable falls under the threshold (other_impact was previously undefined in that case)
46 | other_impact = 0
47 | if (length(idx_other) > 0){
48 | other_impact = sum(breakdown_summary[idx_other])
49 | names(other_impact) = 'other'
50 | breakdown_summary = breakdown_summary[-idx_other]
51 | data_for_label = data_for_label[-idx_other]
52 | }
53 | 
54 | if (abs(other_impact) > 0){
55 | breakdown_summary = c(intercept, breakdown_summary, other_impact)
56 | data_for_label = c("", data_for_label,"")
57 | labels = paste0(names(breakdown_summary)," = ", data_for_label)
58 | labels[1] = 'intercept'
59 | labels[length(labels)] = 'other'
60 | }else{
61 | breakdown_summary = c(intercept, breakdown_summary)
62 | data_for_label = c("", data_for_label)
63 | labels = paste0(names(breakdown_summary)," = ", data_for_label)
64 | labels[1] = 'intercept'
65 | }
66 | 
67 | 
68 | if (!is.null(lgb.dtrain)){
69 | if (!is.null(lgb.dtrain$getinfo("label")[id])){
70 | cat("\nActual: ", lgb.dtrain$getinfo("label")[id])
71 | }
72 | }
73 | cat("\nPrediction: ", pred)
74 | cat("\nWeight: ", weight)
75 | cat("\nBreakdown")
76 | cat('\n')
77 | print(breakdown_summary)
78 | 
79 | if (type == 'regression'){
80 | 
81 | waterfall(values = round(breakdown_summary,2), labels = labels
82 | , calc_total = TRUE
83 | , total_axis_text = "Prediction") + theme(axis.text.x = element_text(angle = 45, hjust = 1))
84 | }else{
85 | 
86 | inverse_logit_trans <- trans_new("inverse logit",
87 | transform = plogis,
88 | inverse = qlogis)
89 | 
90 | inverse_logit_labels = function(x){return (1/(1+exp(-x)))}
91 | logit = function(x){return(log(x/(1-x)))}
92 | 
93 | ybreaks<-logit(seq(2,98,2)/100)
94 | 
95 | waterfall(values = round(breakdown_summary,2), labels = labels
96 | , calc_total = TRUE
97 | , total_axis_text = "Prediction") + scale_y_continuous(labels = inverse_logit_labels, breaks = ybreaks) + theme(axis.text.x = element_text(angle = 45, hjust = 1))
98 | 
99 | }
100 | }
101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/lantanacamara/lightgbmExplainer.svg?branch=master)](https://travis-ci.org/lantanacamara/lightgbmExplainer)
2 | 
3 | ## lightgbmExplainer
4 | An R package that makes LightGBM models fully interpretable
5 | 
6 | ### Example
7 | ```
8 | 
library(lightgbm) # v2.1.0 or above 9 | library(lightgbmExplainer) 10 | 11 | # Load Data 12 | data(agaricus.train, package = "lightgbm") 13 | # Train a model 14 | lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 15 | lgb.params <- list(objective = "binary") 16 | lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 17 | # Build Explainer 18 | lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 19 | explainer <- buildExplainer(lgb.trees) 20 | # compute contribution for each data point 21 | pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 22 | # Show waterfall for the 8th observation 23 | showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 24 | ``` 25 | 26 | Take reference from [xgboostExplainer](https://github.com/AppliedDataSciencePartners/xgboostExplainer) and credit to David Foster. 27 | 28 | Note: LightGBM provides similar function *lgb.interprete* and *lgb.plot.interpretation*. *lgb.interprete* could be faster if you only want to interprete a few data point, but it could be much slower if you want to interprete many data point. 
29 | -------------------------------------------------------------------------------- /example/example.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | 3 | ################## lightgbmExplainer Exaxmple ################## 4 | library(lightgbm) # v2.1.0 or above 5 | library(lightgbmExplainer) 6 | 7 | # Load Data 8 | data(agaricus.train, package = "lightgbm") 9 | # Train a model 10 | lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 11 | lgb.params <- list(objective = "binary") 12 | lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 13 | # Build Explainer 14 | lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 15 | explainer <- buildExplainer(lgb.trees) 16 | # compute contribution for each data point 17 | pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 18 | # Show waterfall for the 8th observation 19 | showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 20 | # Check if the result is correct 21 | compare <- merge( 22 | melt(pred.breakdown[8], measure.vars = colnames(pred.breakdown), 23 | variable.name = "Feature", value.name="Explainer"), 24 | lgb.interprete(lgb.model,agaricus.train$data,8), 25 | sort = F) 26 | all(with(compare, Explainer == Contribution)) 27 | 28 | ################## XGBoost Explainer Exaxmple ################## 29 | 30 | library(xgboost) 31 | library(xgboostExplainer) 32 | 33 | xgb.train <- agaricus.train 34 | xgb.train.data <- xgb.DMatrix(as.matrix(xgb.train$data), label = xgb.train$label) 35 | xgb.param <- list(objective = "binary:logistic") 36 | xgb.model <- xgboost(param =xgb.param, data = xgb.train.data, nrounds=5) 37 | xgb.col_names = colnames(xgb.train$data) 38 | xgb.trees = xgb.model.dt.tree(xgb.col_names, model = xgb.model) 39 | #### The XGBoost Explainer example 40 | xgb.explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5) 41 | 
xgb.pred.breakdown = explainPredictions(xgb.model, xgb.explainer, xgb.train.data) 42 | showWaterfall(xgb.model, xgb.explainer, xgb.train.data, as.matrix(xgb.train$data), 8, type = "binary") 43 | 44 | ################## Profiling with custom data (Don't Run) ################## 45 | # Profiling 46 | library(readr) 47 | target_data <- read_rds("../../temp/lightgbmexplainer/2018-04-25/target_data.rds") 48 | lgb.model <- lgb.load("../../temp/lightgbmexplainer/2018-04-25/lightgbm_v2_4_2_2.txt") 49 | library(profvis) 50 | # 8040ms/654.7MB 51 | profvis(lgb.trees <- lgb.model.dt.tree(lgb.model)) 52 | # 23380ms/3136.3MB (v0.1 benchmark - 142150ms/8134.5MB (lgb.model.dt.tree excluded)) 53 | profvis(explainer <- buildExplainer(lgb.trees)) 54 | # 14160ms / 388.0MB 55 | profvis(pred.breakdown <- explainPredictions(lgb.model, explainer, as.matrix(target_data$data))) 56 | # 16580ms/1048.0MB 57 | profvis(temp <- lgb.interprete(lgb.model,as.matrix(target_data$data),2)) 58 | # Extremely Long 59 | profvis(temp2 <- lgb.interprete(lgb.model,as.matrix(target_data$data),1:94)) 60 | 61 | -------------------------------------------------------------------------------- /lightgbmExplainer.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | -------------------------------------------------------------------------------- /man/buildExplainer.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please 
edit documentation in R/buildExplainer.R 3 | \name{buildExplainer} 4 | \alias{buildExplainer} 5 | \title{Step 1: Build an lightgbmExplainer} 6 | \usage{ 7 | buildExplainer(lgb_tree) 8 | } 9 | \arguments{ 10 | \item{lgb_tree}{A lightgbm.dt.tree} 11 | } 12 | \value{ 13 | The lightgbm Explainer for the model. This is a data table where each row is a leaf of a tree in the lightgbm model 14 | and each column is the impact of each feature on the prediction at the leaf. 15 | 16 | The leaf and tree columns uniquely identify the node. 17 | 18 | The sum of the other columns equals the prediction at the leaf (log-odds if binary response). 19 | 20 | The 'intercept' column is identical for all rows and is analogous to the intercept term in a linear / logistic regression. 21 | } 22 | \description{ 23 | This function outputs an lightgbmExplainer (a data table that stores the feature impact breakdown for each leaf of each tree in an lightgbm model). It is required as input into the explainPredictions and showWaterfall functions. 
24 | } 25 | \examples{ 26 | library(lightgbm) # v2.1.0 or above 27 | library(lightgbmExplainer) 28 | 29 | # Load Data 30 | data(agaricus.train, package = "lightgbm") 31 | # Train a model 32 | lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 33 | lgb.params <- list(objective = "binary") 34 | lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 35 | # Build Explainer 36 | lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 37 | explainer <- buildExplainer(lgb.trees) 38 | # compute contribution for each data point 39 | pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 40 | # Show waterfall for the 8th observation 41 | showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 42 | } 43 | -------------------------------------------------------------------------------- /man/explainPredictions.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/explainPredictions.R 3 | \name{explainPredictions} 4 | \alias{explainPredictions} 5 | \title{Step 2: Get multiple prediction breakdowns from a trained lightgbm model} 6 | \usage{ 7 | explainPredictions(lgb.model, explainer, data) 8 | } 9 | \arguments{ 10 | \item{lgb.model}{A trained lightgbm model} 11 | 12 | \item{explainer}{The output from the buildExplainer function, for this model} 13 | 14 | \item{data}{A DMatrix of data to be explained} 15 | } 16 | \value{ 17 | A data table where each row is an observation in the data and each column is the impact of each feature on the prediction. 18 | 19 | The sum of the row equals the prediction of the lightgbm model for this observation (log-odds if binary response). 20 | } 21 | \description{ 22 | This function outputs the feature impact breakdown of a set of predictions made using an lightgbm model. 
23 | } 24 | \examples{ 25 | library(lightgbm) # v2.1.0 or above 26 | library(lightgbmExplainer) 27 | 28 | # Load Data 29 | data(agaricus.train, package = "lightgbm") 30 | # Train a model 31 | lgb.dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) 32 | lgb.params <- list(objective = "binary") 33 | lgb.model <- lgb.train(lgb.params, lgb.dtrain, 5) 34 | # Build Explainer 35 | lgb.trees <- lgb.model.dt.tree(lgb.model) # First get a lgb tree 36 | explainer <- buildExplainer(lgb.trees) 37 | # compute contribution for each data point 38 | pred.breakdown <- explainPredictions(lgb.model, explainer, agaricus.train$data) 39 | # Show waterfall for the 8th observation 40 | showWaterfall(lgb.model, explainer, lgb.dtrain, agaricus.train$data, 8, type = "binary") 41 | } 42 | -------------------------------------------------------------------------------- /man/showWaterfall.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/showWaterfall.R 3 | \name{showWaterfall} 4 | \alias{showWaterfall} 5 | \title{Step 3: Get prediction breakdown and waterfall chart for a single row of data} 6 | \usage{ 7 | showWaterfall(lgb.model, explainer, lgb.dtrain, lgb.train.data, id, 8 | type = "binary", threshold = 1e-04) 9 | } 10 | \arguments{ 11 | \item{lgb.model}{A trained lightgbm model} 12 | 13 | \item{explainer}{The output from the buildExplainer function, for this model} 14 | 15 | \item{lgb.dtrain}{The lgb.dtrain in which the row to be predicted is stored} 16 | 17 | \item{lgb.train.data}{The matrix of data from which the lgb.dtrain was built} 18 | 19 | \item{type}{The objective function of the model - either "binary" (for binary:logistic) or "regression" (for reg:linear)} 20 | 21 | \item{threshold}{Default = 0.0001. 
The waterfall chart will group all variables with absolute impact less than the threshold into a variable called 'Other'} 22 | 23 | \item{id}{The row number of the data to be explained} 24 | 25 | \value{ 26 | None 27 | } 28 | \description{ 29 | This function prints the feature impact breakdown for a single data row, and plots an accompanying waterfall chart. 30 | } 31 | --------------------------------------------------------------------------------