├── .github ├── .gitignore └── workflows │ └── R-CMD-check.yaml ├── vignettes ├── .gitignore ├── vignette.bib └── use-tidytreatment-BART.Rmd ├── LICENSE ├── data ├── bartmodel1.rda ├── suhillsim1.rda ├── highDim_testdataset3.rda └── bartmodel1_modelmatrix.rda ├── tests ├── testthat.R └── testthat │ ├── test-fitted.R │ ├── test-residuals.R │ ├── test-counter-factuals.R │ ├── test-predict.R │ └── test-treatment-effects.R ├── R ├── generics-bartMachine.R ├── tidytreatment-package.R ├── helper.R ├── print.R ├── tidy-variance.R ├── data.R ├── common-support.R ├── average-treatment-effects-posterior.R ├── covariate-importance.R ├── tree-extract-BART.R ├── tidy-posterior-bartMachine.R ├── treatment-effects-posterior.R ├── simulate-su-hill.R └── tidy-posterior-BART.R ├── .gitignore ├── .Rbuildignore ├── NEWS.md ├── tidytreatment.Rproj ├── man ├── tidytreatment.Rd ├── has_tidytreatment_methods.Rd ├── bartmodel1_modelmatrix.Rd ├── bartmodel1.Rd ├── covariate_importance.Rd ├── variance_draws.Rd ├── covariate_with_treatment_importance.Rd ├── suhillsim1.Rd ├── residual_draws_BART.Rd ├── tidy_ate.Rd ├── tidy_att.Rd ├── fitted_draws.lbart.Rd ├── fitted_draws.mbart.Rd ├── fitted_draws.pbart.Rd ├── fitted_draws.wbart.Rd ├── fitted_draws.mbart2.Rd ├── fitted_draws.bartMachine.Rd ├── residual_draws.bartMachine.Rd ├── residual_draws.pbart.Rd ├── residual_draws.wbart.Rd ├── fitted_draws_BART.Rd ├── posterior_trees_BART.Rd ├── predicted_draws.bartMachine.Rd ├── predicted_draws.wbart.Rd ├── predicted_draws_BART.Rd ├── has_common_support.Rd ├── treatment_effects.Rd ├── avg_treatment_effects.Rd ├── highDim_testdataset3.Rd ├── treatment_effects.default.Rd └── simulate_su_hill_data.Rd ├── LICENSE.md ├── DESCRIPTION ├── README.md ├── data-raw └── suhillsim1_bartmodel1.R ├── NAMESPACE └── examples └── use-tidytreatment-bartMachine.Rmd /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | 
-------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2019 2 | COPYRIGHT HOLDER: Joshua J Bon 3 | -------------------------------------------------------------------------------- /data/bartmodel1.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjskay/tidytreatment/master/data/bartmodel1.rda -------------------------------------------------------------------------------- /data/suhillsim1.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjskay/tidytreatment/master/data/suhillsim1.rda -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(tidytreatment) 3 | 4 | test_check("tidytreatment") 5 | -------------------------------------------------------------------------------- /R/generics-bartMachine.R: -------------------------------------------------------------------------------- 1 | #' @export 2 | model.matrix.bartMachine <- function(object, ...)
{ # S3 model.matrix() method for bartMachine fits: returns the `X` component stored on the fitted object unchanged. 3 | object$X 4 | } 5 | -------------------------------------------------------------------------------- /data/highDim_testdataset3.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjskay/tidytreatment/master/data/highDim_testdataset3.rda -------------------------------------------------------------------------------- /data/bartmodel1_modelmatrix.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjskay/tidytreatment/master/data/bartmodel1_modelmatrix.rda -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Meta 2 | doc 3 | inst/doc 4 | .Rproj.user 5 | .Rhistory 6 | .RData 7 | .Ruserdata 8 | *_cache/ 9 | /doc/ 10 | /Meta/ 11 | -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^Meta$ 2 | ^doc$ 3 | ^LICENSE\.md$ 4 | ^.*\.Rproj$ 5 | ^\.Rproj\.user$ 6 | ^data-raw$ 7 | ^examples$ 8 | ^\.github$ 9 | ^cran-comments\.md$ 10 | ^CRAN-RELEASE$ 11 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # tidytreatment 2 | 3 | ## tidytreatment 0.2.0 4 | 5 | * First CRAN submission 6 | 7 | * Implementation of (average) treatment effects, common support, tidy posterior draws, and (simple) covariate importance functions. 8 | 9 | * For fitted models from the BART and bartMachine packages.
10 | -------------------------------------------------------------------------------- /tidytreatment.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | -------------------------------------------------------------------------------- /man/tidytreatment.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidytreatment-package.R 3 | \docType{package} 4 | \name{tidytreatment} 5 | \alias{tidytreatment} 6 | \title{tidytreatment: Tidy methods for Bayesian treatment effect models} 7 | \description{ 8 | tidytreatment provides functions for extracting tidy data from Bayesian treatment effect models, estimating treatment effects, and plotting useful summaries of these. 9 | } 10 | -------------------------------------------------------------------------------- /man/has_tidytreatment_methods.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/helper.R 3 | \name{has_tidytreatment_methods} 4 | \alias{has_tidytreatment_methods} 5 | \title{Check if a model class has required generic methods for tidytreatment functions.} 6 | \usage{ 7 | has_tidytreatment_methods(model) 8 | } 9 | \arguments{ 10 | \item{model}{Model to be checked.} 11 | } 12 | \value{ 13 | Boolean 14 | } 15 | \description{ 16 | Check if a model class has required generic methods for tidytreatment functions. 
17 | } 18 | -------------------------------------------------------------------------------- /man/bartmodel1_modelmatrix.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{bartmodel1_modelmatrix} 5 | \alias{bartmodel1_modelmatrix} 6 | \title{Model matrix used for \code{bartmodel1}} 7 | \format{ 8 | A numeric matrix whose columns include the 0/1 treatment indicator \code{z}. 9 | } 10 | \source{ 11 | \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 12 | } 13 | \usage{ 14 | bartmodel1_modelmatrix 15 | } 16 | \description{ 17 | Useful for testing tidytreatment package functions. 18 | } 19 | \keyword{datasets} 20 | -------------------------------------------------------------------------------- /R/tidytreatment-package.R: -------------------------------------------------------------------------------- 1 | #' tidytreatment: Tidy methods for Bayesian treatment effect models 2 | #' 3 | #' tidytreatment provides functions for extracting tidy data from Bayesian treatment effect models, estimating treatment effects, and plotting useful summaries of these. 4 | #' 5 | #' @docType package 6 | #' @name tidytreatment 7 | #' @importFrom tidybayes fitted_draws predicted_draws residual_draws add_predicted_draws add_fitted_draws add_residual_draws 8 | #' @importFrom stats rnorm predict terms 9 | #' @importFrom rlang := !!
.data 10 | #' @importFrom utils methods 11 | #' 12 | NULL 13 | -------------------------------------------------------------------------------- /man/bartmodel1.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{bartmodel1} 5 | \alias{bartmodel1} 6 | \title{Example model 1} 7 | \format{ 8 | Object of type \code{BART::wbart} 9 | } 10 | \source{ 11 | \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 12 | } 13 | \usage{ 14 | bartmodel1 15 | } 16 | \description{ 17 | Model fit to the simulated dataset \code{suhillsim1}. 18 | } 19 | \details{ 20 | The propensity score was estimated and included in \code{suhillsim1} for fitting the model. 21 | } 22 | \keyword{datasets} 23 | -------------------------------------------------------------------------------- /man/covariate_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/covariate-importance.R 3 | \name{covariate_importance} 4 | \alias{covariate_importance} 5 | \title{Counts of variable overall inclusion} 6 | \usage{ 7 | covariate_importance(model, ...) 8 | } 9 | \arguments{ 10 | \item{model}{Model} 11 | 12 | \item{...}{Arguments to pass to particular methods.} 13 | } 14 | \value{ 15 | Tidy data with counts of overall variable inclusion. 16 | } 17 | \description{ 18 | Inclusion metrics for bartMachine and BART are scaled differently. 19 | The bartMachine metric is averaged over the number of trees, in addition to the number of MCMC draws.
20 | } 21 | -------------------------------------------------------------------------------- /man/variance_draws.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-variance.R 3 | \name{variance_draws} 4 | \alias{variance_draws} 5 | \title{Get variance draws from posterior of BART models} 6 | \usage{ 7 | variance_draws(model, value = ".sigma_sq", ...) 8 | } 9 | \arguments{ 10 | \item{model}{A model from a supported package.} 11 | 12 | \item{value}{The name of the output column for variance parameter; default \code{".sigma_sq"}.} 13 | 14 | \item{...}{Additional arguments.} 15 | } 16 | \value{ 17 | A tidy data frame (tibble) with draws of variance parameter 18 | } 19 | \description{ 20 | Models from \code{BART}-package include warm-up and skipped MCMC draws. 21 | } 22 | -------------------------------------------------------------------------------- /man/covariate_with_treatment_importance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/covariate-importance.R 3 | \name{covariate_with_treatment_importance} 4 | \alias{covariate_with_treatment_importance} 5 | \title{Counts of variable inclusion when interacting with treatment} 6 | \usage{ 7 | covariate_with_treatment_importance(model, treatment, ...) 8 | } 9 | \arguments{ 10 | \item{model}{Model} 11 | 12 | \item{treatment}{A character string specifying the name of the treatment variable.} 13 | 14 | \item{...}{Arguments to pass to particular methods.} 15 | } 16 | \value{ 17 | Tidy data with counts of variable inclusion, when interacting with treatment variable. 
18 | } 19 | \description{ 20 | Counts of variable inclusion when interacting with treatment 21 | } 22 | -------------------------------------------------------------------------------- /tests/testthat/test-fitted.R: -------------------------------------------------------------------------------- 1 | # context("Fitted") deprecated 2 | 3 | library(BART) 4 | library(dplyr) 5 | library(tidyr) 6 | 7 | # rows = MCMC samples, cols = observations 8 | check_matrix <- bartmodel1$yhat.train 9 | colnames(check_matrix) <- 1:ncol(check_matrix) 10 | check_df <- check_matrix %>% 11 | as_tibble() %>% 12 | mutate(.draw = 1:n()) %>% 13 | pivot_longer( 14 | cols = all_of(1:ncol(check_matrix)), 15 | names_to = ".row", 16 | values_to = "fitted_check" 17 | ) %>% 18 | mutate(.row = as.integer(.row)) 19 | 20 | test_that("Fitted values calculated correctly", { 21 | td_fd <- fitted_draws(bartmodel1, newdata = suhillsim1$data, include_newdata = FALSE, value = "fitted") 22 | comp_df <- td_fd %>% full_join(check_df, by = c(".row", ".draw")) 23 | expect_equal(comp_df$fitted, comp_df$fitted_check) 24 | }) 25 | -------------------------------------------------------------------------------- /man/suhillsim1.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{suhillsim1} 5 | \alias{suhillsim1} 6 | \title{Example simulated dataset 1} 7 | \format{ 8 | See \code{?simulate_su_hill_data} for output format. 9 | } 10 | \source{ 11 | \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 12 | } 13 | \usage{ 14 | suhillsim1 15 | } 16 | \description{ 17 | Simulated with \code{simulate_su_hill_data(...)}, see details. 18 | Includes propensity score estimated using BART (\code{prop_score}), see source. 
19 | } 20 | \details{ 21 | \preformatted{set.seed(101) 22 | suhillsim1 <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, 23 | coef_categorical_treatment = c(0,0,1), 24 | coef_categorical_nontreatment = c(-1,0,-1)) 25 | } 26 | } 27 | \keyword{datasets} 28 | -------------------------------------------------------------------------------- /tests/testthat/test-residuals.R: -------------------------------------------------------------------------------- 1 | # context("Residuals") deprecated 2 | 3 | library(BART) 4 | library(dplyr) 5 | library(tidyr) 6 | 7 | # rows = MCMC samples, cols = observations 8 | smpls <- nrow(bartmodel1$yhat.train) 9 | check_matrix <- matrix(rep(suhillsim1$data$y, smpls), nrow = smpls, byrow = TRUE) - bartmodel1$yhat.train 10 | colnames(check_matrix) <- 1:ncol(check_matrix) 11 | check_df <- check_matrix %>% 12 | as_tibble() %>% 13 | mutate(.draw = 1:n()) %>% 14 | pivot_longer( 15 | cols = all_of(1:ncol(check_matrix)), 16 | names_to = ".row", 17 | values_to = "resid_check" 18 | ) %>% 19 | mutate(.row = as.integer(.row)) 20 | 21 | test_that("Residual values calculated correctly", { 22 | td_fd <- residual_draws(bartmodel1, newdata = suhillsim1$data, response = suhillsim1$data$y, include_newdata = FALSE, value = "resid") 23 | comp_df <- td_fd %>% full_join(check_df, by = c(".row", ".draw")) 24 | expect_equal(comp_df$resid, comp_df$resid_check) 25 | }) 26 | -------------------------------------------------------------------------------- /R/helper.R: -------------------------------------------------------------------------------- 1 | #' Check if a model class has required generic methods for tidytreatment functions. 2 | #' 3 | #' @param model Model to be checked. 
4 | #' 5 | #' @return Boolean 6 | #' @export 7 | #' 8 | has_tidytreatment_methods <- function(model) { 9 | # Every class of `model` is queried (via has_method_str()), not only the first. 10 | all(has_method_str(class(model), c("fitted_draws", "model.matrix"))) 11 | } 12 | 13 | # TRUE when `x` is a dimensionless integer vector holding only 0L/1L values. 14 | # is.integer() with `&&` always gives a single TRUE/FALSE, whereas 15 | # class(x) == "integer" gives one value per class (e.g. two for a matrix). 16 | is_01_integer_vector <- function(x) { 17 | is.integer(x) && is.null(dim(x)) && all(x %in% c(0L, 1L)) 18 | } 19 | 20 | # TRUE when `package` is found in the library paths; the package is not loaded. 21 | has_installed_package <- function(package) { 22 | length(find.package(package, quiet = TRUE)) > 0 23 | } 24 | 25 | # For each generic named in `method`, is there an S3 method for any class in 26 | # `cl`? utils::methods() takes a single class name, so each class is queried 27 | # separately and the generic names are pooled. Returns one logical per `method`. 28 | has_method_str <- function(cl, method) { 29 | generics <- unlist(lapply(cl, function(one_cl) { 30 | as.character(attr(utils::methods(class = one_cl), "info")[, "generic"]) 31 | })) 32 | method %in% generics 33 | } 34 | 35 | # Stop with an informative error when `x` has no method for `method`; `helper` 36 | # is appended to the message (e.g. a hint on how to fix the problem). 37 | check_method <- function(x, method, helper = "") { 38 | x_cl <- class(x) 39 | if (!has_method_str(x_cl, method)) { 40 | stop("Object of class '", paste(x_cl, collapse = "/"), "' does not have method '", method, "'.\n", helper, call. = FALSE) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /tests/testthat/test-counter-factuals.R: -------------------------------------------------------------------------------- 1 | # context("counter factuals") deprecated 2 | 3 | library(BART) 4 | library(dplyr) 5 | library(tidyr) 6 | 7 | # set up counter factual values 8 | md_cf <- bartmodel1_modelmatrix 9 | md_cf[, "z"] <- 1 - md_cf[, "z"] # 0 -> 1, 1 -> 0 10 | 11 | # rows = MCMC samples, cols = observations 12 | check_matrix <- predict(bartmodel1, newdata = md_cf) 13 | colnames(check_matrix) <- 1:ncol(check_matrix) 14 | check_df <- check_matrix %>% 15 | as_tibble() %>% 16 | mutate(.draw = 1:n()) %>% 17 | pivot_longer( 18 | cols = all_of(1:ncol(check_matrix)), 19 | names_to = ".row", 20 | values_to = "cf_check" 21 | ) %>% 22 | mutate(.row = as.integer(.row)) 23 | 24 | test_that("Counter factuals calculated correctly", { 25 | td_cf <- tidytreatment:::fitted_with_counter_factual_draws(bartmodel1, treatment = "z", newdata = suhillsim1$data, subset = "all") 26 | comp_df <- td_cf %>% full_join(check_df, by = c(".row", ".draw")) 27 | expect_equal(comp_df$cfactual,
comp_df$cf_check) 28 | }) 29 | -------------------------------------------------------------------------------- /man/residual_draws_BART.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{residual_draws_BART} 4 | \alias{residual_draws_BART} 5 | \title{Get residual draw for BART model} 6 | \usage{ 7 | residual_draws_BART( 8 | object, 9 | response, 10 | newdata = NULL, 11 | value = ".residual", 12 | include_newdata = TRUE, 13 | include_sigsqs = FALSE 14 | ) 15 | } 16 | \arguments{ 17 | \item{object}{model from \code{BART} package.} 18 | 19 | \item{response}{Original response vector.} 20 | 21 | \item{newdata}{Data frame to generate predictions from. If omitted, original data used to fit the model.} 22 | 23 | \item{value}{Name of the output column for residual_draws; default is \code{.residual}.} 24 | 25 | \item{include_newdata}{Should the newdata be included in the tibble?} 26 | 27 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 28 | } 29 | \value{ 30 | Tibble with residuals. 31 | } 32 | \description{ 33 | Classes from \code{BART}-package models 34 | } 35 | -------------------------------------------------------------------------------- /man/tidy_ate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/average-treatment-effects-posterior.R 3 | \name{tidy_ate} 4 | \alias{tidy_ate} 5 | \title{Get average treatment effect draws from posterior} 6 | \usage{ 7 | tidy_ate(model, treatment, common_support_method, cutoff, ...) 
8 | } 9 | \arguments{ 10 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 11 | 12 | \item{treatment}{A character string specifying the name of the treatment variable.} 13 | 14 | \item{common_support_method}{Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.} 15 | 16 | \item{cutoff}{Cutoff for common support (if in use).} 17 | 18 | \item{...}{Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.} 19 | } 20 | \value{ 21 | A tidy data frame (tibble) with treatment effect values. 22 | } 23 | \description{ 24 | ATE = Average Treatment Effects 25 | Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE. 26 | } 27 | -------------------------------------------------------------------------------- /R/print.R: -------------------------------------------------------------------------------- 1 | # Shared implementation of the print methods below: prints the model label, a 2 | # count taken from length(x$yhat.train.mean), and the names of the components 3 | # stored on the fit. Returns `x` invisibly, as print methods should. 4 | # NOTE(review): the count is labelled "samples" but is length(yhat.train.mean); 5 | # confirm its meaning (and that the component exists) for each class below. 6 | print_BART_fit <- function(x, label) { 7 | cmps <- paste0("\t$", names(x), collapse = "\n") 8 | cat("\n", label, " with ", length(x$yhat.train.mean), " samples\n", sep = "") 9 | cat("components:\n", cmps) 10 | cat("\n") 11 | invisible(x) 12 | } 13 | 14 | #' @export 15 | print.wbart <- function(x, ...) { 16 | print_BART_fit(x, "BART::wbart") 17 | } 18 | 19 | #' @export 20 | print.pbart <- function(x, ...) { 21 | print_BART_fit(x, "BART::pbart") 22 | } 23 | 24 | #' @export 25 | print.lbart <- function(x, ...) { 26 | print_BART_fit(x, "BART::lbart") 27 | } 28 | 29 | #' @export 30 | print.mbart <- function(x, ...) { 31 | print_BART_fit(x, "BART::mbart") 32 | } 33 | 34 | #' @export 35 | print.mbart2 <- function(x, ...)
{ 36 | print_BART_fit(x, "BART::mbart2") 37 | } 38 | -------------------------------------------------------------------------------- /man/tidy_att.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/average-treatment-effects-posterior.R 3 | \name{tidy_att} 4 | \alias{tidy_att} 5 | \title{Get average treatment effect on treated draws from posterior} 6 | \usage{ 7 | tidy_att(model, treatment, common_support_method, cutoff, ...) 8 | } 9 | \arguments{ 10 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 11 | 12 | \item{treatment}{A character string specifying the name of the treatment variable.} 13 | 14 | \item{common_support_method}{Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.} 15 | 16 | \item{cutoff}{Cutoff for common support (if in use).} 17 | 18 | \item{...}{Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.} 19 | } 20 | \value{ 21 | A tidy data frame (tibble) with treatment effect values. 22 | } 23 | \description{ 24 | ATT = average Treatment Effects on Treated 25 | Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
26 | } 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2019 Joshua J Bon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /man/fitted_draws.lbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws.lbart} 4 | \alias{fitted_draws.lbart} 5 | \title{Get fitted draws from posterior of \code{lbart} model} 6 | \usage{ 7 | \method{fitted_draws}{lbart}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 
34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{lbart} model 37 | } 38 | -------------------------------------------------------------------------------- /man/fitted_draws.mbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws.mbart} 4 | \alias{fitted_draws.mbart} 5 | \title{Get fitted draws from posterior of \code{mbart} model} 6 | \usage{ 7 | \method{fitted_draws}{mbart}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 
34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{mbart} model 37 | } 38 | -------------------------------------------------------------------------------- /man/fitted_draws.pbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws.pbart} 4 | \alias{fitted_draws.pbart} 5 | \title{Get fitted draws from posterior of \code{pbart} model} 6 | \usage{ 7 | \method{fitted_draws}{pbart}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 
34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{pbart} model 37 | } 38 | -------------------------------------------------------------------------------- /man/fitted_draws.wbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws.wbart} 4 | \alias{fitted_draws.wbart} 5 | \title{Get fitted draws from posterior of \code{wbart} model} 6 | \usage{ 7 | \method{fitted_draws}{wbart}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 
34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{wbart} model 37 | } 38 | -------------------------------------------------------------------------------- /man/fitted_draws.mbart2.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws.mbart2} 4 | \alias{fitted_draws.mbart2} 5 | \title{Get fitted draws from posterior of \code{mbart2} model} 6 | \usage{ 7 | \method{fitted_draws}{mbart2}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 
34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{mbart2} model 37 | } 38 | -------------------------------------------------------------------------------- /man/fitted_draws.bartMachine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-bartMachine.R 3 | \name{fitted_draws.bartMachine} 4 | \alias{fitted_draws.bartMachine} 5 | \title{Get fitted draws from posterior of \code{bartMachine} model} 6 | \usage{ 7 | \method{fitted_draws}{bartMachine}( 8 | model, 9 | newdata, 10 | value = ".value", 11 | ..., 12 | n = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A \code{bartMachine} model.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Not currently in use.} 25 | 26 | \item{n}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{bartMachine} model 37 | } 38 | -------------------------------------------------------------------------------- /R/tidy-variance.R: -------------------------------------------------------------------------------- 1 | #' Get variance draws from posterior of BART models 2 | #' 3 | #' Models from \code{BART}-package include warm-up and skipped MCMC draws. 4 | #' 5 | #' @param model A model from a supported package. 6 | #' @param value The name of the output column for variance parameter; default \code{".sigma_sq"}. 7 | #' @param ...
Additional arguments. 8 | #' 9 | #' @return A tidy data frame (tibble) with draws of variance parameter 10 | #' 11 | #' @export 12 | variance_draws <- function(model, value = ".sigma_sq", ...) { 13 | UseMethod("variance_draws") 14 | } 15 | 16 | #' @export 17 | variance_draws.wbart <- function(model, value = ".sigma_sq", ...) { 18 | sigma_draws <- model$sigma # draws of sigma (standard deviation); squared below 19 | 20 | dplyr::tibble( 21 | .chain = NA_integer_, 22 | .iteration = NA_integer_, 23 | .draw = seq_along(sigma_draws), # seq_along() stays correct when there are no draws 24 | !!value := sigma_draws^2 25 | ) 26 | } 27 | 28 | #' @export 29 | variance_draws.bartMachine <- function(model, value = ".sigma_sq", ...) { 30 | sigma2_draws <- bartMachine::get_sigsqs(model) # already sigma-squared draws; used as-is 31 | 32 | dplyr::tibble( 33 | .chain = NA_integer_, 34 | .iteration = NA_integer_, 35 | .draw = seq_along(sigma2_draws), 36 | !!value := sigma2_draws 37 | ) 38 | } 39 | -------------------------------------------------------------------------------- /man/residual_draws.bartMachine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-bartMachine.R 3 | \name{residual_draws.bartMachine} 4 | \alias{residual_draws.bartMachine} 5 | \title{Get residual draw for \code{bartMachine} model} 6 | \usage{ 7 | \method{residual_draws}{bartMachine}( 8 | object, 9 | newdata, 10 | value = ".residual", 11 | ..., 12 | ndraws = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{object}{\code{bartMachine} model.} 19 | 20 | \item{newdata}{Data frame to generate predictions from.
If omitted, original data used to fit the model.} 21 | 22 | \item{value}{Name of the output column for residual_draws; default is \code{.residual}.} 23 | 24 | \item{...}{Additional arguments passed to the underlying prediction method for the type of model given.} 25 | 26 | \item{ndraws}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | Tibble with residuals. 34 | } 35 | \description{ 36 | Get residual draw for \code{bartMachine} model 37 | } 38 | -------------------------------------------------------------------------------- /tests/testthat/test-predict.R: -------------------------------------------------------------------------------- 1 | # context("Fitted") deprecated 2 | 3 | library(BART) 4 | library(dplyr) 5 | library(tidyr) 6 | 7 | smpl_id <- c( 8 | 30L, 81L, 97L, 44L, 52L, 43L, 34L, 89L, 88L, 87L, 93L, 14L, 9 | 68L, 17L, 8L, 45L, 85L, 66L, 94L, 35L 10 | ) 11 | 12 | pdata <- suhillsim1$data[smpl_id, ] # fixed subset of rows (ids hard-coded above) 13 | pdata_mm <- bartModelMatrix(pdata[, -1]) # remove "y" variable 14 | 15 | # rows = MCMC samples, cols = observations 16 | check_matrix <- predict(bartmodel1, newdata = pdata_mm) 17 | colnames(check_matrix) <- seq_len(ncol(check_matrix)) 18 | check_df <- check_matrix %>% 19 | as_tibble() %>% 20 | mutate(.draw = row_number()) %>% 21 | pivot_longer( 22 | cols = all_of(seq_len(ncol(check_matrix))), 23 | names_to = ".row", 24 | values_to = "pred_check" 25 | ) %>% 26 | mutate(.row = as.integer(.row)) 27 | 28 | test_that("Predicted values calculated correctly", { 29 | td_pd <- predicted_draws(bartmodel1, 30 | newdata = pdata, include_newdata = FALSE, value = "pred", 31 | rng = function(n, mean, ...)
{ 32 | mean + 0.1 33 | } 34 | ) # random noise fixed 35 | comp_df <- td_pd %>% full_join(check_df, by = c(".row", ".draw")) 36 | expect_equal(comp_df$pred, comp_df$pred_check + 0.1) 37 | }) 38 | -------------------------------------------------------------------------------- /man/residual_draws.pbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{residual_draws.pbart} 4 | \alias{residual_draws.pbart} 5 | \title{Get residual draw for \code{pbart} model} 6 | \usage{ 7 | \method{residual_draws}{pbart}( 8 | object, 9 | newdata, 10 | value = ".residual", 11 | ..., 12 | ndraws = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{object}{\code{wbart} model.} 19 | 20 | \item{newdata}{Data frame to generate predictions from. If omitted, original data used to fit the model.} 21 | 22 | \item{value}{Name of the output column for residual_draws; default is \code{.residual}.} 23 | 24 | \item{...}{Additional arguments passed to the underlying prediction method for the type of model given.} 25 | 26 | \item{ndraws}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | Tibble with residuals. 34 | } 35 | \description{ 36 | The original response variable must be passed as an argument to this function. 37 | e.g.
`response = y` 38 | } 39 | -------------------------------------------------------------------------------- /man/residual_draws.wbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{residual_draws.wbart} 4 | \alias{residual_draws.wbart} 5 | \title{Get residual draw for \code{wbart} model} 6 | \usage{ 7 | \method{residual_draws}{wbart}( 8 | object, 9 | newdata, 10 | value = ".residual", 11 | ..., 12 | ndraws = NULL, 13 | include_newdata = TRUE, 14 | include_sigsqs = FALSE 15 | ) 16 | } 17 | \arguments{ 18 | \item{object}{\code{wbart} model.} 19 | 20 | \item{newdata}{Data frame to generate predictions from. If omitted, original data used to fit the model.} 21 | 22 | \item{value}{Name of the output column for residual_draws; default is \code{.residual}.} 23 | 24 | \item{...}{Additional arguments passed to the underlying prediction method for the type of model given.} 25 | 26 | \item{ndraws}{Not currently implemented.} 27 | 28 | \item{include_newdata}{Should the newdata be included in the tibble?} 29 | 30 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 31 | } 32 | \value{ 33 | Tibble with residuals. 34 | } 35 | \description{ 36 | The original response variable must be passed as an argument to this function. 37 | e.g. 
`response = y` 38 | } 39 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: tidytreatment 2 | Type: Package 3 | Title: Tidy Methods for Bayesian Treatment Effect Models 4 | Version: 0.2.0.9000 5 | Authors@R: person("Joshua J", "Bon", email = "joshuajbon@gmail.com", 6 | role = c("aut", "cre"), 7 | comment = c(ORCID = "0000-0003-2313-2949")) 8 | Description: Functions for extracting tidy data from Bayesian treatment effect models, in particular BART, but extensions are possible. Functionality includes extracting tidy posterior summaries as in 'tidybayes' , estimating (average) treatment effects, common support calculations, and plotting useful summaries of these. 9 | Encoding: UTF-8 10 | LazyData: true 11 | License: MIT + file LICENSE 12 | URL: https://github.com/bonStats/tidytreatment 13 | BugReports: https://github.com/bonStats/tidytreatment/issues 14 | Depends: R (>= 3.1.0) 15 | Suggests: 16 | knitr, 17 | rmarkdown, 18 | BART, 19 | ggplot2, 20 | testthat (>= 3.0.0), 21 | withr 22 | VignetteBuilder: knitr 23 | RoxygenNote: 7.1.1 24 | Imports: 25 | tidybayes, 26 | purrr, 27 | tidyr, 28 | dplyr, 29 | readr, 30 | rlang 31 | Enhances: 32 | bartMachine 33 | Config/testthat/edition: 3 34 | -------------------------------------------------------------------------------- /man/fitted_draws_BART.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{fitted_draws_BART} 4 | \alias{fitted_draws_BART} 5 | \title{Get fitted draws from posterior of \code{BART}-package models} 6 | \usage{ 7 | fitted_draws_BART( 8 | model, 9 | newdata = NULL, 10 | value = ".value", 11 | ..., 12 | include_newdata = TRUE, 13 | include_sigsqs = FALSE, 14 | scale = "real" 15 | ) 16 | } 17 | \arguments{ 18 | 
\item{model}{A model from \code{BART} package.} 19 | 20 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 21 | 22 | \item{value}{The name of the output column for \code{fitted_draws}; default \code{".value"}.} 23 | 24 | \item{...}{Arguments to pass to \code{predict} (e.g. \code{BART:::predict.wbart}).} 25 | 26 | \item{include_newdata}{Should the newdata be included in the tibble?} 27 | 28 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 29 | 30 | \item{scale}{Should the fitted values be on the real, probit or logit scale?} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with fitted values. 34 | } 35 | \description{ 36 | Get fitted draws from posterior of \code{BART}-package models 37 | } 38 | -------------------------------------------------------------------------------- /man/posterior_trees_BART.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tree-extract-BART.R 3 | \name{posterior_trees_BART} 4 | \alias{posterior_trees_BART} 5 | \title{Get posterior tree draws into tibble format from BART model} 6 | \usage{ 7 | posterior_trees_BART(model, label_digits = 2) 8 | } 9 | \arguments{ 10 | \item{model}{BART model.} 11 | 12 | \item{label_digits}{Rounding for labels.} 13 | } 14 | \value{ 15 | A tibble with columns to \describe{ 16 | \item{iter}{Integer describing unique MCMC iteration.} 17 | \item{tree_id}{Integer. Unique tree id with each `iter`.} 18 | \item{node}{Integer describing node in tree. Unique to each `tree`-`iter`.} 19 | \item{parent}{Integer describing parent node in tree.} 20 | \item{label}{Label for the node.} 21 | \item{tier}{Position in tree hierarchy.} 22 | \item{var}{Variable for split.} 23 | \item{cut}{Numeric. Value of decision rule for `var`.} 24 | \item{is_leaf}{Logical. 
`TRUE` if leaf, `FALSE` if stem.} 25 | \item{leaf_value}{} 26 | \item{child_left}{Integer. Left child of node.} 27 | \item{child_right}{Integer. Right child of node.} 28 | } 29 | } 30 | \description{ 31 | Tibble grouped by iteration (`iter`) and tree id (`tree_id`). All information 32 | calculated by method is included in output. 33 | } 34 | -------------------------------------------------------------------------------- /man/predicted_draws.bartMachine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-bartMachine.R 3 | \name{predicted_draws.bartMachine} 4 | \alias{predicted_draws.bartMachine} 5 | \title{Get predict draws from posterior of \code{bartMachine} model} 6 | \usage{ 7 | \method{predicted_draws}{bartMachine}( 8 | object, 9 | newdata, 10 | value = ".prediction", 11 | ..., 12 | ndraws = NULL, 13 | include_newdata = TRUE, 14 | include_fitted = FALSE, 15 | include_sigsqs = FALSE 16 | ) 17 | } 18 | \arguments{ 19 | \item{object}{A \code{bartMachine} model.} 20 | 21 | \item{newdata}{Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.} 22 | 23 | \item{value}{The name of the output column for \code{predicted_draws}; default \code{".prediction"}.} 24 | 25 | \item{...}{Not currently in use.} 26 | 27 | \item{ndraws}{Not currently implemented.} 28 | 29 | \item{include_newdata}{Should the newdata be included in the tibble?} 30 | 31 | \item{include_fitted}{Should the posterior fitted values be included in the tibble?} 32 | 33 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 34 | } 35 | \value{ 36 | A tidy data frame (tibble) with predicted values. 
37 | } 38 | \description{ 39 | Get predict draws from posterior of \code{bartMachine} model 40 | } 41 | -------------------------------------------------------------------------------- /man/predicted_draws.wbart.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{predicted_draws.wbart} 4 | \alias{predicted_draws.wbart} 5 | \title{Get predict draws from posterior of \code{wbart} model} 6 | \usage{ 7 | \method{predicted_draws}{wbart}( 8 | object, 9 | newdata, 10 | value = ".prediction", 11 | ..., 12 | ndraws = NULL, 13 | include_newdata = TRUE, 14 | include_fitted = FALSE, 15 | include_sigsqs = FALSE 16 | ) 17 | } 18 | \arguments{ 19 | \item{object}{A \code{wbart} model.} 20 | 21 | \item{newdata}{Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.} 22 | 23 | \item{value}{The name of the output column for \code{predicted_draws}; default \code{".prediction"}.} 24 | 25 | \item{...}{Use to specify random number generator, default is \code{rng=stats::rnorm}.} 26 | 27 | \item{ndraws}{Not currently implemented.} 28 | 29 | \item{include_newdata}{Should the newdata be included in the tibble?} 30 | 31 | \item{include_fitted}{Should the posterior fitted values be included in the tibble?} 32 | 33 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 34 | } 35 | \value{ 36 | A tidy data frame (tibble) with predicted values. 
37 | } 38 | \description{ 39 | Get predict draws from posterior of \code{wbart} model 40 | } 41 | -------------------------------------------------------------------------------- /man/predicted_draws_BART.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidy-posterior-BART.R 3 | \name{predicted_draws_BART} 4 | \alias{predicted_draws_BART} 5 | \title{Get predict draws from posterior of \code{BART}-package models} 6 | \usage{ 7 | predicted_draws_BART( 8 | object, 9 | newdata = NULL, 10 | value = ".prediction", 11 | ..., 12 | rng = stats::rnorm, 13 | include_newdata = TRUE, 14 | include_fitted = FALSE, 15 | include_sigsqs = FALSE 16 | ) 17 | } 18 | \arguments{ 19 | \item{object}{A \code{BART}-package model.} 20 | 21 | \item{newdata}{Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.} 22 | 23 | \item{value}{The name of the output column for \code{predicted_draws}; default \code{".prediction"}.} 24 | 25 | \item{...}{Arguments to pass to \code{predict} (e.g. \code{BART:::predict.wbart}).} 26 | 27 | \item{rng}{Random number generator function. Default is \code{rnorm} for models with Gaussian errors.} 28 | 29 | \item{include_newdata}{Should the newdata be included in the tibble?} 30 | 31 | \item{include_fitted}{Should the posterior fitted values be included in the tibble?} 32 | 33 | \item{include_sigsqs}{Should the posterior sigma-squared draw be included?} 34 | } 35 | \value{ 36 | A tidy data frame (tibble) with predicted values. 
37 | } 38 | \description{ 39 | Get predict draws from posterior of \code{BART}-package models 40 | } 41 | -------------------------------------------------------------------------------- /man/has_common_support.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/common-support.R 3 | \name{has_common_support} 4 | \alias{has_common_support} 5 | \title{Evaluate if observations have common support.} 6 | \usage{ 7 | has_common_support(model, treatment, method, cutoff, modeldata = NULL) 8 | } 9 | \arguments{ 10 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 11 | 12 | \item{treatment}{A character string specifying the name of the treatment variable.} 13 | 14 | \item{method}{Method to use in determining common support. 'chisq', or 'sd'.} 15 | 16 | \item{cutoff}{Cutoff point to use for method.} 17 | 18 | \item{modeldata}{Manually provide model data for some models (e.g. from BART package)} 19 | } 20 | \value{ 21 | Tibble with a row for each observation and a column indicating whether common support exists. 22 | } 23 | \description{ 24 | The common support identification methods are based on Hill and Su (2013). 25 | Loosely speaker, an individuals treatment effect estimate has common support if the counter factual 26 | estimate is not too uncertain. The estimates are uncertain when the prediction is 'far away' from 27 | other observations. Removing estimates without common support can be beneficial for treat effect 28 | estimates. 29 | } 30 | \details{ 31 | Hill, Jennifer; Su, Yu-Sung. Ann. Appl. Stat. 7 (2013), no. 3, 1386--1420. doi:10.1214/13-AOAS630. 
\url{https://projecteuclid.org/euclid.aoas/1380804800} 32 | } 33 | -------------------------------------------------------------------------------- /tests/testthat/test-treatment-effects.R: -------------------------------------------------------------------------------- 1 | # context("treatment effects") deprecated 2 | 3 | library(BART) 4 | library(dplyr) 5 | library(tidyr) 6 | 7 | # set up treatment effects values 8 | md_z1 <- md_z0 <- bartmodel1_modelmatrix 9 | md_z1[, "z"] <- 1 10 | md_z0[, "z"] <- 0 11 | 12 | # rows = MCMC samples, cols = observations 13 | check_matrix <- predict(bartmodel1, newdata = md_z1) - predict(bartmodel1, newdata = md_z0) 14 | colnames(check_matrix) <- seq_len(ncol(check_matrix)) 15 | check_teff_df <- check_matrix %>% 16 | as_tibble() %>% 17 | mutate(.draw = row_number()) %>% 18 | pivot_longer( 19 | cols = all_of(seq_len(ncol(check_matrix))), 20 | names_to = ".row", 21 | values_to = "cte_check" 22 | ) %>% 23 | mutate(.row = as.integer(.row)) 24 | 25 | test_that("Treatment effects calculated correctly", { 26 | td_teff <- treatment_effects(bartmodel1, treatment = "z", newdata = suhillsim1$data) 27 | comp_df <- td_teff %>% full_join(check_teff_df, by = c(".row", ".draw")) 28 | expect_equal(comp_df$cte, comp_df$cte_check) 29 | }) 30 | 31 | test_that("ATE calculated correctly", { 32 | td_ate <- tidy_ate(bartmodel1, treatment = "z", newdata = suhillsim1$data) %>% 33 | arrange(.draw) 34 | expect_equal(td_ate$ate, rowMeans(check_matrix)) # average across obs 35 | }) 36 | 37 | test_that("ATT calculated correctly", { 38 | td_att <- tidy_att(bartmodel1, treatment = "z", newdata = suhillsim1$data) %>% 39 | arrange(.draw) 40 | expect_equal(td_att$att, rowMeans(check_matrix[, bartmodel1_modelmatrix[, "z"] == 1, drop = FALSE])) 41 | }) 42 | -------------------------------------------------------------------------------- /man/treatment_effects.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | %
Please edit documentation in R/treatment-effects-posterior.R 3 | \name{treatment_effects} 4 | \alias{treatment_effects} 5 | \title{Get (individual) treatment effect draws from posterior} 6 | \usage{ 7 | treatment_effects( 8 | model, 9 | treatment, 10 | newdata, 11 | subset = "all", 12 | common_support_method, 13 | cutoff, 14 | ... 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 19 | 20 | \item{treatment}{A character string specifying the name of the treatment variable.} 21 | 22 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 23 | 24 | \item{subset}{Either "treated", "nontreated", or "all". Default is "all".} 25 | 26 | \item{common_support_method}{Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.} 27 | 28 | \item{cutoff}{Cutoff for common support (if in use).} 29 | 30 | \item{...}{Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with treatment effect values. 34 | } 35 | \description{ 36 | CTE = Conditional Treatment Effects (usually used to generate (C)ATE or ATT) 37 | \code{newdata} specifies the conditions, if unspecified it defaults to the original data. 38 | Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE. 
39 | } 40 | -------------------------------------------------------------------------------- /man/avg_treatment_effects.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/average-treatment-effects-posterior.R 3 | \name{avg_treatment_effects} 4 | \alias{avg_treatment_effects} 5 | \title{Get (conditional) average treatment effect draws from posterior} 6 | \usage{ 7 | avg_treatment_effects( 8 | model, 9 | treatment, 10 | newdata, 11 | subset = "all", 12 | common_support_method, 13 | cutoff, 14 | ... 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 19 | 20 | \item{treatment}{A character string specifying the name of the treatment variable.} 21 | 22 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 23 | 24 | \item{subset}{Either "treated", "nontreated", or "all". Default is "all".} 25 | 26 | \item{common_support_method}{Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.} 27 | 28 | \item{cutoff}{Cutoff for common support (if in use).} 29 | 30 | \item{...}{Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with treatment effect values. 34 | } 35 | \description{ 36 | (C)ATE = (Conditional) Average Treatment Effects 37 | \code{newdata} specifies the conditions, if unspecified it defaults to the original data. 38 | Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE. 
39 | } 40 | -------------------------------------------------------------------------------- /man/highDim_testdataset3.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{highDim_testdataset3} 5 | \alias{highDim_testdataset3} 6 | \title{ACIC2019 High Dimensional Test Dataset} 7 | \format{ 8 | A data frame with 2000 observations, and 187 variables. 9 | \describe{ 10 | \item{Y}{Outcome variable} 11 | \item{A}{Treatment variable} 12 | \item{V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,V29,V30,V31,V32,V33,V34,V35,V36,V37,V38,V39,V40,V41,V42,V43,V44,V45,V46,V47,V48,V49,V50,V51,V52,V53,V54,V55,V56,V57,V58,V59,V60,V61,V62,V63,V64,V65,V66,V67,V68,V69,V70,V71,V72,V73,V74,V75,V76,V77,V78,V79,V80,V81,V82,V83,V84,V85,V86,V87,V88,V89,V90,V91,V92,V93,V94,V95,V96,V97,V98,V99,V100,V101,V102,V103,V104,V105,V106,V107,V108,V109,V110,V111,V112,V113,V114,V115,V116,V117,V118,V119,V120,V121,V122,V123,V124,V125,V126,V127,V128,V129,V130,V131,V132,V133,V134,V135,V136,V137,V138,V139,V140,V141,V142,V143,V144,V145,V146,V147,V148,V149,V150,V151,V152,V153,V154,V155,V156,V157,V158,V159,V160,V161,V162,V163,V164,V165,V166,V167,V168,V169,V170,V171,V172,V173,V174,V175,V176,V177,V178,V179,V180,V181,V182,V183,V184,V185}{Other covariates} 13 | ... 14 | } 15 | } 16 | \source{ 17 | \url{https://www.mcgill.ca/epi-biostat-occh/seminars-events/atlantic-causal-inference-conference-2019/data-challenge} 18 | } 19 | \usage{ 20 | highDim_testdataset3 21 | } 22 | \description{ 23 | Dataset from the "Data Challenge" for the Atlantic Causal Inference Conference 2019. 
24 | } 25 | \keyword{datasets} 26 | -------------------------------------------------------------------------------- /man/treatment_effects.default.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/treatment-effects-posterior.R 3 | \name{treatment_effects.default} 4 | \alias{treatment_effects.default} 5 | \title{Get treatment effect draws from posterior} 6 | \usage{ 7 | \method{treatment_effects}{default}( 8 | model, 9 | treatment, 10 | newdata, 11 | subset = "all", 12 | common_support_method, 13 | cutoff, 14 | ... 15 | ) 16 | } 17 | \arguments{ 18 | \item{model}{A supported Bayesian model fit that can provide fits and predictions.} 19 | 20 | \item{treatment}{A character string specifying the name of the treatment variable.} 21 | 22 | \item{newdata}{Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.} 23 | 24 | \item{subset}{Either "treated", "nontreated", or "all". Default is "all".} 25 | 26 | \item{common_support_method}{Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.} 27 | 28 | \item{cutoff}{Cutoff for common support (if in use).} 29 | 30 | \item{...}{Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.} 31 | } 32 | \value{ 33 | A tidy data frame (tibble) with treatment effect values. 34 | } 35 | \description{ 36 | CTE = Conditional Treatment Effects (or CATE, the average effects) 37 | \code{newdata} specifies the conditions, if unspecified it defaults to the original data. 38 | Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE. 
39 | } 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tidy methods for Bayesian treatment effect models 2 | 3 | 4 | [![R-CMD-check](https://github.com/bonStats/tidytreatment/workflows/R-CMD-check/badge.svg)](https://github.com/bonStats/tidytreatment/actions) 5 | [![CRAN status](https://www.r-pkg.org/badges/version/tidytreatment)](https://CRAN.R-project.org/package=tidytreatment) 6 | [![CRAN downloads](https://cranlogs.r-pkg.org/badges/tidytreatment)](https://cran.r-project.org/package=tidytreatment) 7 | 8 | 9 | `tidytreatment` is an `R` package that provides functions for extracting tidy data from Bayesian treatment effect models, estimating treatment effects, and plotting useful summaries of these. This package closely follows the output style from the [tidybayes](https://github.com/mjskay/tidybayes) `R` package in order to use some functions provided by `tidybayes`. 10 | 11 | The package currently supports the following models: 12 | 13 | - `BART`: see [CRAN](https://cran.r-project.org/package=BART). 14 | - `bartMachine`: see [CRAN](https://cran.r-project.org/package=bartMachine). 15 | - `bcf`: see [CRAN](https://cran.r-project.org/package=bcf) (in development, see branch `bcf-hold` on GitHub). 16 | 17 | See this [HTML vignette](https://cloud.r-project.org/web/packages/tidytreatment/vignettes/use-tidytreatment-BART.html) or `vignette("use-tidytreatment-BART")` for examples of usage. 18 | 19 | ## How to install 20 | 21 | ### CRAN 22 | 23 | Install the release version from CRAN with `install.packages("tidytreatment")`. 24 | 25 | ### GitHub 26 | 27 | To install the latest development version: 28 | 29 | 1. Make sure at least one of the above model fitting packages is installed. 30 | 2. In `R` make sure `remotes` is installed. Install with `install.packages("remotes")`.
31 | - For help: see the Rtools (windows) and Xcode (macOS) links on [this page](https://support.rstudio.com/hc/en-us/articles/200486498-Package-Development-Prerequisites). 32 | 3. Run `remotes::install_github("bonStats/tidytreatment")` 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /data-raw/suhillsim1_bartmodel1.R: -------------------------------------------------------------------------------- 1 | ## Code to prepare `suhillsim1`, `bartmodel1` and `bartmodel1_modelmatrix` 2 | library(tidytreatment) 3 | library(BART) 4 | library(dplyr) 5 | withr::with_seed(101, { 6 | sim <- simulate_su_hill_data( 7 | n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, 8 | coef_categorical_treatment = c(0, 0, 1), 9 | coef_categorical_nontreatment = c(-1, 0, -1) 10 | ) 11 | 12 | # regress y ~ covariates 13 | var_select_bart <- wbart( 14 | x.train = select(sim$data, -y, -z), 15 | y.train = pull(sim$data, y), 16 | sparse = TRUE, 17 | nskip = 2000, 18 | ndpost = 5000, 19 | printevery = 1000L 20 | ) 21 | 22 | # select most important vars from y ~ covariates model 23 | # very simple selection mechanism. Should use cross-validation in practice 24 | var_select <- var_select_bart$varprob.mean %>% 25 | { 26 | which(.
> 1 / length(.)) 27 | } %>% 28 | names() 29 | # change categoricals to just one variable 30 | var_select <- gsub("c1[1-3]$", "c1", var_select) 31 | 32 | # regress z ~ most important covariates to get propensity score 33 | # BART::pbart is for probit regression 34 | prop_bart <- pbart( 35 | x.train = select(sim$data, all_of(var_select)), 36 | y.train = pull(sim$data, z), 37 | nskip = 2000, 38 | ndpost = 5000, 39 | printevery = 1000L 40 | ) 41 | 42 | sim$data$prop_score <- prop_bart$prob.train.mean 43 | 44 | x.train <- select(sim$data, -y) 45 | y.train <- pull(sim$data, y) 46 | 47 | bmodel <- wbart( 48 | x.train = x.train, 49 | y.train = y.train, 50 | nskip = 10000L, 51 | ndpost = 200L, # keep small to manage size on CRAN 52 | keepevery = 100L, 53 | printevery = 3000L 54 | ) 55 | 56 | datamatrix1 <- bartModelMatrix(X = x.train) 57 | }) 58 | 59 | suhillsim1 <- sim 60 | bartmodel1 <- bmodel 61 | bartmodel1_modelmatrix <- datamatrix1 62 | 63 | usethis::use_data(suhillsim1, overwrite = TRUE) 64 | usethis::use_data(bartmodel1, overwrite = TRUE) 65 | usethis::use_data(bartmodel1_modelmatrix, overwrite = TRUE) 66 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(covariate_importance,bartMachine) 4 | S3method(covariate_importance,lbart) 5 | S3method(covariate_importance,mbart) 6 | S3method(covariate_importance,mbart2) 7 | S3method(covariate_importance,pbart) 8 | S3method(covariate_importance,wbart) 9 | S3method(covariate_with_treatment_importance,bartMachine) 10 | S3method(covariate_with_treatment_importance,lbart) 11 | S3method(covariate_with_treatment_importance,mbart) 12 | S3method(covariate_with_treatment_importance,mbart2) 13 | S3method(covariate_with_treatment_importance,pbart) 14 | S3method(covariate_with_treatment_importance,wbart) 15 | S3method(fitted_draws,bartMachine) 16 |
S3method(fitted_draws,lbart) 17 | S3method(fitted_draws,mbart) 18 | S3method(fitted_draws,mbart2) 19 | S3method(fitted_draws,pbart) 20 | S3method(fitted_draws,wbart) 21 | S3method(model.matrix,bartMachine) 22 | S3method(predicted_draws,bartMachine) 23 | S3method(predicted_draws,wbart) 24 | S3method(print,lbart) 25 | S3method(print,mbart) 26 | S3method(print,mbart2) 27 | S3method(print,pbart) 28 | S3method(print,suhillsim) 29 | S3method(print,wbart) 30 | S3method(residual_draws,bartMachine) 31 | S3method(residual_draws,pbart) 32 | S3method(residual_draws,wbart) 33 | S3method(treatment_effects,default) 34 | S3method(variance_draws,bartMachine) 35 | S3method(variance_draws,wbart) 36 | export(avg_treatment_effects) 37 | export(covariate_importance) 38 | export(covariate_with_treatment_importance) 39 | export(has_common_support) 40 | export(has_tidytreatment_methods) 41 | export(simulate_su_hill_data) 42 | export(tidy_ate) 43 | export(tidy_att) 44 | export(treatment_effects) 45 | export(variance_draws) 46 | importFrom(rlang,"!!") 47 | importFrom(rlang,":=") 48 | importFrom(rlang,.data) 49 | importFrom(stats,predict) 50 | importFrom(stats,rnorm) 51 | importFrom(stats,terms) 52 | importFrom(tidybayes,add_fitted_draws) 53 | importFrom(tidybayes,add_predicted_draws) 54 | importFrom(tidybayes,add_residual_draws) 55 | importFrom(tidybayes,fitted_draws) 56 | importFrom(tidybayes,predicted_draws) 57 | importFrom(tidybayes,residual_draws) 58 | importFrom(utils,methods) 59 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' ACIC2019 High Dimensional Test Dataset 2 | #' 3 | #' Dataset from the "Data Challenge" for the Atlantic Causal Inference Conference 2019. 4 | #' 5 | #' @format A data frame with 2000 observations, and 187 variables. 
6 | #' \describe{ 7 | #' \item{Y}{Outcome variable} 8 | #' \item{A}{Treatment variable} 9 | #' \item{V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,V29,V30,V31,V32,V33,V34,V35,V36,V37,V38,V39,V40,V41,V42,V43,V44,V45,V46,V47,V48,V49,V50,V51,V52,V53,V54,V55,V56,V57,V58,V59,V60,V61,V62,V63,V64,V65,V66,V67,V68,V69,V70,V71,V72,V73,V74,V75,V76,V77,V78,V79,V80,V81,V82,V83,V84,V85,V86,V87,V88,V89,V90,V91,V92,V93,V94,V95,V96,V97,V98,V99,V100,V101,V102,V103,V104,V105,V106,V107,V108,V109,V110,V111,V112,V113,V114,V115,V116,V117,V118,V119,V120,V121,V122,V123,V124,V125,V126,V127,V128,V129,V130,V131,V132,V133,V134,V135,V136,V137,V138,V139,V140,V141,V142,V143,V144,V145,V146,V147,V148,V149,V150,V151,V152,V153,V154,V155,V156,V157,V158,V159,V160,V161,V162,V163,V164,V165,V166,V167,V168,V169,V170,V171,V172,V173,V174,V175,V176,V177,V178,V179,V180,V181,V182,V183,V184,V185}{Other covariates} 10 | #' ... 11 | #' } 12 | #' @source \url{https://www.mcgill.ca/epi-biostat-occh/seminars-events/atlantic-causal-inference-conference-2019/data-challenge} 13 | "highDim_testdataset3" 14 | 15 | #' Example simulated dataset 1 16 | #' 17 | #' Simulated with \code{simulate_su_hill_data(...)}, see details. 18 | #' Includes propensity score estimated using BART (\code{prop_score}), see source. 19 | #' 20 | #' \preformatted{set.seed(101) 21 | #' suhillsim1 <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, 22 | #' coef_categorical_treatment = c(0,0,1), 23 | #' coef_categorical_nontreatment = c(-1,0,-1)) 24 | #' } 25 | #' 26 | #' @format See \code{?simulate_su_hill_data} for output format. 27 | #' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 28 | "suhillsim1" 29 | 30 | #' Example model 1 31 | #' 32 | #' Model fit with simulated data from simulated dataset \code{suhillsim1}. 33 | #' 34 | #' Propensity score estimated and included \code{suhillsim1} for fitting the model. 
35 | #' 36 | #' @format Object of type \code{BART::wbart} 37 | #' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 38 | "bartmodel1" 39 | 40 | #' Model matrix used for \code{bartmodel1} 41 | #' 42 | #' Useful for testing tidytreatment package functions. 43 | #' 44 | #' @format Model matrix created by \code{BART::bartModelMatrix} from \code{suhillsim1$data} with the outcome \code{y} removed. 45 | #' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw} 46 | "bartmodel1_modelmatrix" 47 | -------------------------------------------------------------------------------- /vignettes/vignette.bib: -------------------------------------------------------------------------------- 1 | @Article{Kapelner2016, 2 | title = {{bartMachine}: Machine Learning with {B}ayesian Additive Regression Trees}, 3 | author = {Adam Kapelner and Justin Bleich}, 4 | journal = {Journal of Statistical Software}, 5 | year = {2016}, 6 | volume = {70}, 7 | number = {4}, 8 | pages = {1--40}, 9 | url = {https://doi.org/10.18637/jss.v070.i04}, 10 | } 11 | 12 | @Article{Hill2013, 13 | title={Assessing lack of common support in causal inference using {B}ayesian nonparametrics: Implications for evaluating the effect of breastfeeding on children's cognitive outcomes}, 14 | author={Hill, Jennifer and Su, Yu-Sung}, 15 | journal={The Annals of Applied Statistics}, 16 | volume = {7}, 17 | number = {3}, 18 | pages={1386--1420}, 19 | year={2013}, 20 | publisher={JSTOR}, 21 | url={https://doi.org/10.1214/13-AOAS630} 22 | } 23 | 24 | 25 | @article{Hahn2020, 26 | author = {Hahn, P. Richard and Murray, Jared S.
and Carvalho, Carlos M.}, 27 | doi = {10.1214/19-BA1195}, 28 | journal = {Bayesian Analysis}, 29 | month = {09}, 30 | number = {3}, 31 | pages = {965--1056}, 32 | publisher = {International Society for Bayesian Analysis}, 33 | title = {Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion)}, 34 | url = {https://doi.org/10.1214/19-BA1195}, 35 | volume = {15}, 36 | year = {2020} 37 | } 38 | 39 | 40 | @Article{Chipman2010, 41 | title = {BART: Bayesian additive regression trees}, 42 | author = {Hugh A. Chipman and Edward I. George and Robert {E. McCulloch}}, 43 | doi = {10.1214/09-AOAS285}, 44 | month = {3}, 45 | number = {1}, 46 | pages = {266--298}, 47 | journal = {The Annals of Applied Statistics}, 48 | publisher = {The Institute of Mathematical Statistics}, 49 | url = {http://dx.doi.org/10.1214/09-AOAS285}, 50 | volume = {4}, 51 | year = {2010}, 52 | } 53 | 54 | @article{sparapani2016, 55 | title={Nonparametric survival analysis using Bayesian additive regression trees (BART)}, 56 | author={Sparapani, Rodney A and Logan, Brent R and McCulloch, Robert E and Laud, Purushottam W}, 57 | journal={Statistics in Medicine}, 58 | volume={35}, 59 | number={16}, 60 | pages={2741--2753}, 61 | year={2016}, 62 | publisher={Wiley Online Library}, 63 | url={https://doi.org/10.1002/sim.6893} 64 | } 65 | 66 | @article{bleich2014variable, 67 | title={Variable selection for BART: an application to gene regulation}, 68 | author={Bleich, Justin and Kapelner, Adam and George, Edward I and Jensen, Shane T}, 69 | journal={The Annals of Applied Statistics}, 70 | volume={8}, 71 | number={3}, 72 | pages={1750--1781}, 73 | year={2014}, 74 | publisher={JSTOR}, 75 | url={https://www.jstor.org/stable/24522283} 76 | } 77 | 78 | 79 | -------------------------------------------------------------------------------- /man/simulate_su_hill_data.Rd: -------------------------------------------------------------------------------- 1 | 
% Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/simulate-su-hill.R 3 | \name{simulate_su_hill_data} 4 | \alias{simulate_su_hill_data} 5 | \title{Simulate data with scenarios from Hill and Su (2013)} 6 | \usage{ 7 | simulate_su_hill_data( 8 | n, 9 | treatment_linear = TRUE, 10 | response_parallel = TRUE, 11 | response_aligned = TRUE, 12 | y_sd = 1, 13 | tau = 4, 14 | omega = 0, 15 | add_categorical = FALSE, 16 | coef_categorical_treatment = NULL, 17 | coef_categorical_nontreatment = NULL 18 | ) 19 | } 20 | \arguments{ 21 | \item{n}{Size of simulated sample.} 22 | 23 | \item{treatment_linear}{Treatment assignment mechanism is linear?} 24 | 25 | \item{response_parallel}{Response surface is parallel?} 26 | 27 | \item{response_aligned}{Response surface is aligned?} 28 | 29 | \item{y_sd}{Observation noise.} 30 | 31 | \item{tau}{Treatment effect for parallel response surfaces. Not applicable if surface is nonparallel.} 32 | 33 | \item{omega}{Offset to control treatment assignment ratios.} 34 | 35 | \item{add_categorical}{Should a categorical variable be added? (Not in Hill and Su)} 36 | 37 | \item{coef_categorical_treatment}{What are the coefficients of the categorical variable under treatment? (Not in Hill and Su)} 38 | 39 | \item{coef_categorical_nontreatment}{What are the coefficients of the categorical variable under nontreatment? (Not in Hill and Su)} 40 | } 41 | \value{ 42 | An object of class \code{suhillsim} that is a list with elements 43 | \item{data}{Simulated data in data.frame} 44 | \item{mean_y}{The mean y values for each individual (row)} 45 | \item{args}{List of arguments passed to function} 46 | \item{formulas}{Response formulas used to generate data} 47 | \item{coefs}{Coefficients for the formulas} 48 | } 49 | \description{ 50 | Sample \eqn{n} observations with the following scheme: 51 | \enumerate{ 52 | \item Covariates: \eqn{X_j ~ N(0,1)}. 
53 | \item Assignment: \eqn{Z ~ Bin(n, p)} with \eqn{p = logit^{-1}(a + X \gamma^L + Q \gamma^N)} where \eqn{a = \omega - mean(X \gamma^L + Q \gamma^N)}. 54 | \item Mean response: \eqn{E(Y(0)|X) = X \beta_0^L + Q \beta_0^N } and \eqn{E(Y(1)|X) = X \beta_1^L + Q \beta_1^N}. 55 | \item Observation: \eqn{Y ~ N(\mu,\sigma_y^2)}. 56 | } 57 | Superscript \eqn{L} denotes the linear components, whilst \eqn{N} denotes the non-linear 58 | components. 59 | } 60 | \details{ 61 | Coefficients used are returned in the list this function creates. See Table 1 in Su and Hill (2013) for the table of coefficients. 62 | The \eqn{X_j} are in a data.frame named \code{data} in the returned list. 63 | The formula for the model matrix \eqn{[X,Q]} is named \code{su_hill_formula} in the returned list. 64 | The coefficients used for the model matrix are contained in \code{coefs}. 65 | The Su and Hill (2013) simulations did not include categorical variables, but you can add them here using arguments: \code{add_categorical}, \code{coef_categorical_treatment}, \code{coef_categorical_nontreatment}. 66 | 67 | Hill, Jennifer; Su, Yu-Sung. Ann. Appl. Stat. 7 (2013), no. 3, 1386--1420. doi:10.1214/13-AOAS630. \url{https://projecteuclid.org/euclid.aoas/1380804800} 68 | } 69 | -------------------------------------------------------------------------------- /R/common-support.R: -------------------------------------------------------------------------------- 1 | #' Evaluate if observations have common support. 2 | #' 3 | #' The common support identification methods are based on Hill and Su (2013). 4 | #' Loosely speaking, an individual's treatment effect estimate has common support if the counterfactual 5 | #' estimate is not too uncertain. The estimates are uncertain when the prediction is 'far away' from 6 | #' other observations. Removing estimates without common support can be beneficial for treatment effect 7 | #' estimates. 8 | #' 9 | #' Hill, Jennifer; Su, Yu-Sung. Ann. Appl. Stat. 7 (2013), no.
#' 3, 1386--1420. doi:10.1214/13-AOAS630. \url{https://projecteuclid.org/euclid.aoas/1380804800}
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param method Method to use in determining common support: \code{"chisq"} or \code{"sd"}.
#'   Any other value gives a warning and \code{NA} results.
#' @param cutoff Cutoff point to use for the method. Must be supplied. For \code{"chisq"}
#'   the squared sd ratio is compared with the \code{1 - cutoff} quantile of a chi-squared
#'   distribution with one degree of freedom; the \code{"sd"} method does not use it.
#' @param modeldata Data frame of model data containing the treatment column. If \code{NULL}
#'   (default) it is taken from \code{stats::model.matrix(model)}; provide it manually for
#'   models without a \code{model.matrix} method (e.g. from the BART package).
#'
#' @return Tibble with a row for each observation, the posterior standard deviations of the
#'   observed (\code{sd_observed}) and counterfactual (\code{sd_cfactual}) fits, and a logical
#'   column \code{common_support} indicating whether common support exists.
#' @export
#'
has_common_support <- function(model, treatment, method, cutoff, modeldata = NULL) {
  if (is.null(modeldata)) {
    modeldata <- stats::model.matrix(model)
  }

  stopifnot(
    treatment %in% colnames(modeldata),
    is.data.frame(modeldata),
    !missing(cutoff)
  )

  # `[[` returns the bare column for data.frames and tibbles alike
  # (`[, treatment]` returns a one-column tibble for the latter).
  treatment_values <- modeldata[[treatment]]

  stopifnot(
    is_01_integer_vector(treatment_values) | is.logical(treatment_values)
  )

  calc_common_support_from_fitted_and_cf(
    fitted_and_cf = fitted_with_counter_factual_draws(
      model = model,
      newdata = modeldata,
      treatment = treatment,
      subset = "all"
    ),
    modeldata = modeldata,
    treatment = treatment,
    method = method,
    cutoff = cutoff
  )
}

# Summarise posterior draws of observed and counterfactual fits into their
# standard deviations, then apply the chosen common-support rule.
# `modeldata` is indexed by the `.row` column of the summary so that each
# observation's treatment value lines up with its row.
calc_common_support_from_fitted_and_cf <- function(fitted_and_cf, modeldata, treatment, method, cutoff) {
  posterior_obs_cf_sd <- dplyr::summarise(
    fitted_and_cf,
    sd_observed = stats::sd(.data$observed),
    sd_cfactual = stats::sd(.data$cfactual)
  )

  common_support_cutoff <- switch(method,
    sd = common_support_sd_method,
    chisq = common_support_chisq_method,
    common_support_default
  )

  # drop = TRUE yields a plain vector for data.frames, tibbles and matrices
  treatment_values <- modeldata[posterior_obs_cf_sd$.row, treatment, drop = TRUE]

  # Evaluate the rule on whole columns rather than inside a (possibly still
  # grouped) mutate(): the "sd" rule takes a maximum over all treated rows and
  # `treatment_values` is aligned with every row of the summary. Assigning the
  # column this way keeps any grouping of the summary intact.
  posterior_obs_cf_sd$common_support <- common_support_cutoff(
    sd_obs = posterior_obs_cf_sd$sd_observed,
    sd_cf = posterior_obs_cf_sd$sd_cfactual,
    cutoff = cutoff,
    treatment = treatment_values
  )

  posterior_obs_cf_sd
}

# "chisq" rule: TRUE (common support) when the squared ratio of the
# counterfactual sd to the observed sd is below the (1 - cutoff) quantile of a
# chi-squared distribution with 1 degree of freedom.
common_support_chisq_method <- function(sd_obs, sd_cf, cutoff, ...) {
  (sd_cf / sd_obs)^2 < stats::qchisq(1 - cutoff, 1)
}

# "sd" rule: TRUE (common support) when the counterfactual sd is below the
# largest observed sd among treated observations plus the sd of those observed
# sds. `cutoff` is not used.
# NOTE(review): with no treated observations max() returns -Inf (with a
# warning) and sd() is NA, so every result is NA.
common_support_sd_method <- function(sd_obs, sd_cf, treatment, ...) {
  sd_obs_treated <- sd_obs[treatment == 1L]

  m_a <- max(sd_obs_treated)

  sd_cf < m_a + stats::sd(sd_obs_treated)
}

# Fallback for an unrecognised `method`: warn and return NA for every row.
# Takes `...` because callers also pass `treatment`; without it the call
# failed with "unused argument" before the warning could be issued.
common_support_default <- function(sd_obs, sd_cf, cutoff, ...) {
  warning("Please specify common support 'method'.")
  rep(NA, times = length(sd_obs))
}
-------------------------------------------------------------------------------- /R/average-treatment-effects-posterior.R: --------------------------------------------------------------------------------
#' Get (conditional) average treatment effect draws from posterior
#'
#' (C)ATE = (Conditional) Average Treatment Effects
#' \code{newdata} specifies the conditions, if unspecified it defaults to the original data.
#' Assumes treated column is either an integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param subset Either "treated", "nontreated", or "all". Default is "all".
#' @param common_support_method Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.
#' @param cutoff Cutoff for common support (if in use).
#' @param ... Arguments to be passed to \code{tidybayes::fitted_draws}, typically \code{scale} for \code{BART} models.
#'
#' @return A tidy data frame (tibble) with one row per posterior draw
#'   (\code{.chain}, \code{.iteration}, \code{.draw}) and the (conditional)
#'   average treatment effect in column \code{ate}.
#' @export
#'
#'
avg_treatment_effects <- function(model, treatment, newdata, subset = "all", common_support_method, cutoff, ...) {
  # `newdata`, `common_support_method` and `cutoff` may be missing; they are
  # forwarded unevaluated to treatment_effects().
  te <- treatment_effects(
    model = model, treatment = treatment,
    newdata = newdata, subset = subset,
    common_support_method = common_support_method,
    cutoff = cutoff, ...
  )

  summarise_cte_by_draw(te, "ate")
}

#' Get average treatment effect draws from posterior
#'
#' ATE = Average Treatment Effects
#' Assumes treated column is either an integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param common_support_method Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.
#' @param cutoff Cutoff for common support (if in use).
#' @param ... Arguments to be passed to \code{tidybayes::fitted_draws}, typically \code{scale} for \code{BART} models.
#'   If \code{newdata} is not passed here, \code{model} is checked for a \code{model.matrix} method.
#'
#' @return A tidy data frame (tibble) with one row per posterior draw
#'   (\code{.chain}, \code{.iteration}, \code{.draw}) and the average
#'   treatment effect in column \code{ate}.
#' @export
#'
#'
tidy_ate <- function(model, treatment, common_support_method, cutoff, ...) {
  .dots <- list(...)
  # Without `newdata` the fitted data has to be recovered via model.matrix().
  if (!"newdata" %in% names(.dots)) {
    check_method(model, method = "model.matrix", helper = "Please use 'avg_treatment_effects' function with 'newdata'.")
  }

  te <- treatment_effects(
    model = model, treatment = treatment,
    subset = "all",
    common_support_method = common_support_method,
    cutoff = cutoff, ...
  )

  summarise_cte_by_draw(te, "ate")
}

#' Get average treatment effect on treated draws from posterior
#'
#' ATT = Average Treatment effect on the Treated
#' Assumes treated column is either an integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param common_support_method Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.
#' @param cutoff Cutoff for common support (if in use).
#' @param ... Arguments to be passed to \code{tidybayes::fitted_draws}, typically \code{scale} for \code{BART} models.
#'   If \code{newdata} is not passed here, \code{model} is checked for a \code{model.matrix} method.
#'
#' @return A tidy data frame (tibble) with one row per posterior draw
#'   (\code{.chain}, \code{.iteration}, \code{.draw}) and the average
#'   treatment effect on the treated in column \code{att}.
#' @export
#'
#'
tidy_att <- function(model, treatment, common_support_method, cutoff, ...) {
  .dots <- list(...)
  # Without `newdata` the fitted data has to be recovered via model.matrix().
  if (!"newdata" %in% names(.dots)) {
    check_method(model, method = "model.matrix", helper = "Please use 'avg_treatment_effects' function with 'newdata'.")
  }

  te <- treatment_effects(
    model = model, treatment = treatment,
    subset = "treated",
    common_support_method = common_support_method,
    cutoff = cutoff, ...
  )

  summarise_cte_by_draw(te, "att")
}

# Internal: average the conditional treatment effects (`cte`) within each
# posterior draw (.chain, .iteration, .draw); the averaged column is named
# `name`. Any prior grouping of `te` is replaced, and the result is ungrouped.
summarise_cte_by_draw <- function(te, name) {
  te <- dplyr::group_by(te, .data$.chain, .data$.iteration, .data$.draw)
  dplyr::summarise(te, !!name := mean(.data$cte), .groups = "drop")
}
-------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # NOTE: This workflow is overkill for most R packages 2 | # check-standard.yaml is likely a better choice 3 | # usethis::use_github_action("check-standard") will install it. 4 | # 5 | # For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag. 6 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions 7 | on: 8 | push: 9 | branches: 10 | - main 11 | - master 12 | pull_request: 13 | branches: 14 | - main 15 | - master 16 | 17 | name: R-CMD-check 18 | 19 | jobs: 20 | R-CMD-check: 21 | runs-on: ${{ matrix.config.os }} 22 | 23 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 24 | 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | config: 29 | - {os: macOS-latest, r: 'release'} 30 | - {os: windows-latest, r: 'release'} 31 | - {os: windows-latest, r: '3.6', rspm: "https://packagemanager.rstudio.com/cran/latest"} 32 | - {os: ubuntu-18.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", http-user-agent: "R/4.1.0 (ubuntu-18.04) R (4.1.0 x86_64-pc-linux-gnu x86_64 linux-gnu) on GitHub Actions" } 33 | - {os: ubuntu-18.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"} 34 | - {os: ubuntu-18.04, r: 'oldrel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"} 35 | - {os: ubuntu-18.04, r: '3.6', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"} 36 | - {os: ubuntu-18.04, r: '3.5', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"} 37 | 38 | env: 39 | RSPM: ${{ matrix.config.rspm }} 40 | GITHUB_PAT:
${{ secrets.GITHUB_TOKEN }} 41 | 42 | steps: 43 | - uses: actions/checkout@v2 44 | 45 | - uses: r-lib/actions/setup-r@v1 46 | id: install-r 47 | with: 48 | r-version: ${{ matrix.config.r }} 49 | http-user-agent: ${{ matrix.config.http-user-agent }} 50 | 51 | - uses: r-lib/actions/setup-pandoc@v1 52 | 53 | - name: Install pak and query dependencies 54 | run: | 55 | install.packages("pak", repos = "https://r-lib.github.io/p/pak/dev/") 56 | saveRDS(pak::pkg_deps("local::.", dependencies = TRUE), ".github/r-depends.rds") 57 | shell: Rscript {0} 58 | 59 | - name: Restore R package cache 60 | uses: actions/cache@v2 61 | with: 62 | path: | 63 | ${{ env.R_LIBS_USER }}/* 64 | !${{ env.R_LIBS_USER }}/pak 65 | key: ${{ matrix.config.os }}-${{ steps.install-r.outputs.installed-r-version }}-1-${{ hashFiles('.github/r-depends.rds') }} 66 | restore-keys: ${{ matrix.config.os }}-${{ steps.install-r.outputs.installed-r-version }}-1- 67 | 68 | - name: Install system dependencies 69 | if: runner.os == 'Linux' 70 | run: | 71 | pak::local_system_requirements(execute = TRUE) 72 | pak::pkg_system_requirements("rcmdcheck", execute = TRUE) 73 | shell: Rscript {0} 74 | 75 | - name: Install dependencies 76 | run: | 77 | pak::local_install_dev_deps(upgrade = TRUE) 78 | pak::pkg_install("rcmdcheck") 79 | shell: Rscript {0} 80 | 81 | - name: Session info 82 | run: | 83 | options(width = 100) 84 | pkgs <- installed.packages()[, "Package"] 85 | sessioninfo::session_info(pkgs, include_base = TRUE) 86 | shell: Rscript {0} 87 | 88 | - name: Check 89 | env: 90 | _R_CHECK_CRAN_INCOMING_: false 91 | run: | 92 | options(crayon.enabled = TRUE) 93 | rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check") 94 | shell: Rscript {0} 95 | 96 | - name: Show testthat output 97 | if: always() 98 | run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true 99 | shell: bash 100 | 101 | - name: Upload check results 102 | if: failure() 103 | uses: 
actions/upload-artifact@main 104 | with: 105 | name: ${{ matrix.config.os }}-r${{ matrix.config.r }}-results 106 | path: check 107 | 108 | - name: Don't use tar from old Rtools to store the cache 109 | if: ${{ runner.os == 'Windows' && startsWith(steps.install-r.outputs.installed-r-version, '3.6' ) }} 110 | shell: bash 111 | run: echo "C:/Program Files/Git/usr/bin" >> $GITHUB_PATH 112 | -------------------------------------------------------------------------------- /R/covariate-importance.R: --------------------------------------------------------------------------------
#' Counts of variable inclusion when interacting with treatment
#'
#' @param model Model fit. Methods exist for \code{bartMachine} and \code{BART}
#'   (\code{wbart}, \code{pbart}, \code{lbart}, \code{mbart}, \code{mbart2}) models.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param ... Arguments to pass to particular methods.
#'
#' @return Tidy data with counts of variable inclusion, when interacting with treatment variable:
#'   a tibble with columns \code{variable}, \code{avg_inclusion} and \code{sd}.
#'   The treatment variable itself is excluded.
#' @export
#'
covariate_with_treatment_importance <- function(model, treatment, ...) {
  UseMethod("covariate_with_treatment_importance")
}

#' @export
covariate_with_treatment_importance.bartMachine <- function(model, treatment, ...) {
  ii <- bartMachine::interaction_investigator(model, plot = FALSE)

  # exactly one column of the interaction matrix must be the treatment
  treatment_col <- colnames(ii$interaction_counts_avg) %in% treatment

  stopifnot(
    sum(treatment_col) == 1
  )

  res <- dplyr::tibble(
    variable = colnames(ii$interaction_counts_avg),
    avg_inclusion = ii$interaction_counts_avg[, treatment_col],
    sd = ii$interaction_counts_sd[, treatment_col]
  )

  dplyr::filter(res, .data$variable != treatment)
}

#' Counts of overall variable inclusion
#'
#' Inclusion metrics for bartMachine and BART are scaled differently:
#' bartMachine is averaged over the number of trees, in addition to the number of MCMC draws.
#'
#' @param model Model fit. Methods exist for \code{bartMachine} and \code{BART}
#'   (\code{wbart}, \code{pbart}, \code{lbart}, \code{mbart}, \code{mbart2}) models.
#' @param ... Arguments to pass to particular methods.
#'
#' @return Tidy data with overall inclusion of each variable: a tibble with
#'   columns \code{variable} and \code{avg_inclusion}.
#' @export
#'
covariate_importance <- function(model, ...) {
  UseMethod("covariate_importance")
}

#' @export
covariate_importance.bartMachine <- function(model, ...) {
  vv <- bartMachine::get_var_props_over_chain(model, ...)

  res <- dplyr::tibble(
    variable = names(vv),
    avg_inclusion = vv
  )

  res
}

# Count, over all posterior trees that split on `treatment`, how often each
# variable is used in those trees. Variables never used get a count of 0.
covariate_with_treatment_importance_BART <- function(model, treatment, ...) {
  # currently only use the (single) fitted BART model.
  # Whereas bartMachine uses average over replicates (default 5)
  ttree <- posterior_trees_BART(model)

  # One row per (iter, tree_id) tree that splits on the treatment. distinct()
  # stops a tree that splits on the treatment several times from being joined,
  # and therefore counted, several times below.
  ttree_treat <- dplyr::filter(ttree$trees, .data$var == treatment)
  ttree_treat <- dplyr::ungroup(ttree_treat)
  ttree_treat <- dplyr::select(ttree_treat, dplyr::all_of(c("iter", "tree_id")))
  ttree_treat <- dplyr::distinct(ttree_treat)

  # filtered to trees with treatment; leaves have var = NA and are not counted
  var_counts <- table(
    dplyr::left_join(ttree_treat, ttree$trees, by = c("iter", "tree_id"))$var,
    useNA = "no"
  )

  res <- dplyr::tibble(
    variable = as.character(names(var_counts)),
    avg_inclusion = as.numeric(var_counts),
    sd = NA_real_
  )

  # add vars if missing from table
  var_names <- names(model$varprob.mean)
  missing_vars <- !var_names %in% res$variable

  if (any(missing_vars)) {
    add_res <- dplyr::tibble(
      variable = var_names[missing_vars],
      avg_inclusion = 0,
      sd = NA_real_
    )
    res <- dplyr::bind_rows(res, add_res)
  }

  dplyr::filter(res, .data$variable != treatment)
}

# Overall inclusion for BART models: the per-variable count averaged over MCMC
# draws, as stored by BART in `varcount.mean`.
covariate_importance_BART <- function(model, ...) {
  vv <- model$varcount.mean

  res <- dplyr::tibble(
    variable = names(vv),
    avg_inclusion = vv
  )

  res
}

#' @export
covariate_importance.wbart <- function(model, ...) {
  covariate_importance_BART(model, ...)
}
#' @export
covariate_importance.pbart <- function(model, ...) {
  covariate_importance_BART(model, ...)
}
#' @export
covariate_importance.lbart <- function(model, ...) {
  covariate_importance_BART(model, ...)
}

#' @export
covariate_importance.mbart <- function(model, ...) {
  covariate_importance_BART(model, ...)
}
#' @export
covariate_importance.mbart2 <- function(model, ...) {
  covariate_importance_BART(model, ...)
}

#' @export
covariate_with_treatment_importance.wbart <- function(model, treatment, ...) {
  covariate_with_treatment_importance_BART(model, treatment, ...)
}

#' @export
covariate_with_treatment_importance.pbart <- function(model, treatment, ...) {
  covariate_with_treatment_importance_BART(model, treatment, ...)
}

#' @export
covariate_with_treatment_importance.lbart <- function(model, treatment, ...) {
  covariate_with_treatment_importance_BART(model, treatment, ...)
}

#' @export
covariate_with_treatment_importance.mbart2 <- function(model, treatment, ...) {
  covariate_with_treatment_importance_BART(model, treatment, ...)
}

#' @export
covariate_with_treatment_importance.mbart <- function(model, treatment, ...) {
  covariate_with_treatment_importance_BART(model, treatment, ...)
}
-------------------------------------------------------------------------------- /R/tree-extract-BART.R: --------------------------------------------------------------------------------
#' Get posterior tree draws into tibble format from BART model
#'
#' Tibble grouped by iteration (`iter`) and tree id (`tree_id`). All information
#' calculated by method is included in output.
#'
#' @param model BART model.
#' @param label_digits Number of digits used when rounding values shown in node labels.
#'
#' @return A list with elements \code{n_mcmc}, \code{n_tree} and \code{n_var}
#'   (integers parsed from the first line of \code{model$treedraws$trees}) and
#'   \code{trees}, a tibble grouped by \code{iter} and \code{tree_id} with columns \describe{
#' \item{iter}{Integer describing unique MCMC iteration.}
#' \item{tree_id}{Integer. Unique tree id with each `iter`.}
#' \item{node}{Integer describing node in tree. Unique to each `tree`-`iter`.}
#' \item{parent}{Integer describing parent node in tree.}
#' \item{label}{Label for the node.}
#' \item{tier}{Position in tree hierarchy.}
#' \item{var}{Variable for split.}
#' \item{cut}{Numeric. Value of decision rule for `var`.}
#' \item{is_leaf}{Logical. `TRUE` if leaf, `FALSE` if stem.}
#' \item{leaf_value}{Numeric. Value of the leaf, `NA` for stems.}
#' \item{child_left}{Integer. Left child of node.}
#' \item{child_right}{Integer. Right child of node.}
#' }
#'
posterior_trees_BART <- function(model, label_digits = 2) {
  var_names <- names(model$treedraws$cutpoints)

  # one row per (variable, cutpoint); seq_along() avoids the malformed 1:0
  # index for a variable without cutpoints
  cut_points_tb <- purrr::map_df(
    .x = model$treedraws$cutpoints,
    .f = ~ dplyr::tibble(cut = ., cut_id = seq_along(.)),
    .id = "var"
  )

  out <- list()

  # first line holds three counts, stored as n_mcmc, n_tree and n_var
  fline <- strsplit(
    readr::read_lines(
      file = model$treedraws$trees,
      n_max = 1
    ),
    " "
  )[[1]]
  out$n_mcmc <- as.integer(fline[1])
  out$n_tree <- as.integer(fline[2])
  out$n_var <- as.integer(fline[3])

  # Remaining lines are parsed as node / var / cut / leaf. Lines where var,
  # cut and leaf are all missing mark the start of each tree (used below);
  # presumably those short lines are why parsing warnings are suppressed.
  # NOTE(review): readr::read_table2() is deprecated in readr >= 2.0 in favour
  # of read_table(); kept because read_table() parses differently in older
  # readr releases.
  out$trees <- suppressWarnings(
    readr::read_table2(
      file = model$treedraws$trees,
      col_names = c("node", "var", "cut", "leaf"),
      col_types =
        readr::cols(
          node = readr::col_integer(),
          var = readr::col_integer(),
          cut = readr::col_integer(),
          leaf = readr::col_double()
        ),
      skip = 1,
      na = c(""),
      progress = FALSE
    )
  )

  # indexing and tier
  out$trees <- dplyr::mutate(
    out$trees,
    tier = as.integer(floor(log2(.data$node))),
    cut_id = .data$cut + 1L, # R indexing at 1
    var = var_names[.data$var + 1L] # R indexing at 1
  )

  # define tree id and mcmc iteration number: each tree-start line increments
  # a running tree counter, which is split into iteration and tree-in-iteration
  out$trees <- dplyr::mutate(
    out$trees,
    unique_tree_id = cumsum(is.na(.data$var) & is.na(.data$cut) & is.na(.data$leaf)),
    iter = (.data$unique_tree_id - 1L) %/% out$n_tree + 1L,
    tree_id = (.data$unique_tree_id - 1L) %% out$n_tree + 1L,
    unique_tree_id = NULL
  )

  # remove information about tree groups (was stored as missing lines)
  out$trees <- dplyr::filter(out$trees, stats::complete.cases(out$trees))

  # add cut information (replace the cut index by the cutpoint value)
  out$trees <- dplyr::left_join(
    dplyr::select(out$trees, -dplyr::all_of("cut")),
    cut_points_tb,
    by = c("var", "cut_id")
  )

  # add children information (node ids are only unique within a tree)
  out$trees <- dplyr::group_by(out$trees, .data$iter, .data$tree_id)
  out$trees <- dplyr::mutate(
    out$trees,
    child_left = child_left(.data$node),
    child_right = child_right(.data$node)
  )

  # remove leaf info if no children
  out$trees <- dplyr::mutate(
    dplyr::ungroup(out$trees),
    is_leaf = is.na(.data$child_left) & is.na(.data$child_right),
    leaf_value = ifelse(.data$is_leaf, .data$leaf, NA_real_),
    cut = ifelse(.data$is_leaf, NA_real_, .data$cut), # is leaf, then no cut for stem
    var = ifelse(.data$is_leaf, NA_character_, .data$var), # is leaf, then no var for cut
    label = ifelse(
      .data$is_leaf,
      as.character(round(.data$leaf_value, digits = label_digits)),
      paste(.data$var, "<", round(.data$cut, digits = label_digits))
    ),
    parent = parent(.data$node)
  )

  # regroup and order the output columns
  out$trees <- dplyr::select(
    dplyr::group_by(out$trees, .data$iter, .data$tree_id),
    dplyr::all_of(c(
      "iter", "tree_id", "node", "parent", "label", "tier", "var", "cut",
      "is_leaf", "leaf_value", "child_left", "child_right"
    ))
  )

  out
}

# Left child id (2 * node) for each node, NA when that id is not among `nodes`.
child_left <- function(nodes) {
  # must be grouped by iter and tree to apply
  pot_child <- nodes * 2L
  pot_child[!pot_child %in% nodes] <- NA_integer_

  pot_child
}

# Right child id (2 * node + 1) for each node, NA when that id is not among `nodes`.
child_right <- function(nodes) {
  # must be grouped by iter and tree to apply
  pot_child <- nodes * 2L + 1L
  pot_child[!pot_child %in% nodes] <- NA_integer_

  pot_child
}

# Parent id (node %/% 2) for each node; NA for the root (node 1).
parent <- function(nodes) {
  parents <- nodes %/% 2L
  parents[parents == 0L] <- NA_integer_

  parents
}
-------------------------------------------------------------------------------- /R/tidy-posterior-bartMachine.R: --------------------------------------------------------------------------------
#' Get fitted draws from posterior of \code{bartMachine} model
#'
#' @param model A \code{bartMachine} model.
#' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param value The name of the output column for \code{fitted_draws}; default \code{".value"}.
#' @param n Not currently implemented.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Not currently in use.
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.bartMachine <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) newdata <- stats::model.matrix(model)

  # `n` is a scalar-style option: short-circuit so a NULL `n` is never
  # compared (`NULL > 0` is logical(0), which only passed by accident before)
  stopifnot(
    is.data.frame(newdata),
    is.character(value),
    is.null(n) || (is.integer(n) && all(n > 0)),
    is.logical(include_newdata),
    is.logical(include_sigsqs)
  )

  # order for columns in output
  col_order <- c(".row", ".chain", ".iteration", ".draw", value)

  posterior <- bartMachine::bart_machine_get_posterior(bart_machine = model, new_data = newdata)

  # bind newdata with fitted, wide format: one ".col_iter<k>" column per draw
  out <- dplyr::bind_cols(
    if (include_newdata) dplyr::as_tibble(newdata) else NULL,
    dplyr::as_tibble(posterior$y_hat_posterior_samples, .name_repair = function(names) {
      paste0(".col_iter", seq_along(names))
    }),
    .row = seq_len(nrow(newdata))
  )

  # convert to long format (one row per observation x draw)
  out <- tidyr::gather(out, key = ".draw", value = !!value, dplyr::starts_with(".col_iter"))

  # add variables required by the tidybayes output standard; the draw index is
  # recovered by stripping the literal ".col_iter" prefix
  out <- dplyr::mutate(
    out,
    .chain = NA_integer_,
    .iteration = NA_integer_,
    .draw = as.integer(gsub(pattern = ".col_iter", replacement = "", x = .data$.draw, fixed = TRUE))
  )

  # include sigma^2 if needed (one value per post burn-in draw)
  if (include_sigsqs) {
    sigsq <- dplyr::tibble(
      .draw = seq_len(model$num_iterations_after_burn_in),
      sigsq = bartMachine::get_sigsqs(model)
    )

    out <- dplyr::left_join(out, sigsq, by = ".draw")

    col_order <- c(col_order, "sigsq")
  }

  # rearrange: move the standard columns to the end, in `col_order` order
  out <- dplyr::select(out, -!!col_order, !!col_order)

  # group by every column that identifies a row (newdata columns and .row)
  row_groups <- names(out)[!names(out) %in% col_order[col_order != ".row"]]

  out <- dplyr::group_by(out, dplyr::across(dplyr::all_of(row_groups)))

  out
}


#' Get predict draws from posterior of \code{bartMachine} model
#'
#' @param object A \code{bartMachine} model.
#' @param newdata Data frame to generate predictions from. If omitted, defaults to the data used to fit the model.
#' @param value The name of the output column for \code{predicted_draws}; default \code{".prediction"}.
#' @param ndraws Not currently implemented.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_fitted Should the posterior fitted values be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Not currently in use.
#'
#' @return A tidy data frame (tibble) with predicted values.
#' @export
#'
predicted_draws.bartMachine <- function(object, newdata, value = ".prediction", ..., ndraws = NULL, include_newdata = TRUE, include_fitted = FALSE, include_sigsqs = FALSE) {
  stopifnot(
    is.character(value),
    is.logical(include_fitted),
    is.logical(include_sigsqs)
  )

  # Fitted values plus sigma^2 (always needed to simulate the noise).
  # fitted_draws.bartMachine() names its model argument `model`: passing
  # `object = object` is absorbed by its `...` and leaves `model` missing.
  # A missing `newdata` stays missing when passed on, so the callee's
  # default (the training data) still applies.
  out <- fitted_draws.bartMachine(model = object, newdata = newdata, value = ".fit", include_newdata = include_newdata, include_sigsqs = TRUE)

  # posterior predictive draw: N(fitted value, sigma^2) for every row x draw
  out <- dplyr::mutate(out, !!value := stats::rnorm(n = dplyr::n(), mean = .data$.fit, sd = sqrt(.data$sigsq)))

  # remove sigma^2 value if necessary
  if (!include_sigsqs) out <- dplyr::select(out, -"sigsq")

  # remove fitted value if necessary
  if (!include_fitted) out <- dplyr::select(out, -".fit")

  out
}

#' Get residual draw for \code{bartMachine} model
#'
#' @param object \code{bartMachine} model.
#' @param newdata Data frame to generate predictions from. If omitted, original data used to fit the model.
#' @param value Name of the output column for residual_draws; default is \code{.residual}.
#' @param ... Not currently in use.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ndraws Not currently implemented.
#'
#' @return Tibble with residuals.
#' @export
#'
residual_draws.bartMachine <- function(object, newdata, value = ".residual", ..., ndraws = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  # Observed training response keyed by row position. Residuals are always
  # taken against this training response (joined on `.row`), so a supplied
  # `newdata` should contain the training rows in the same order.
  obs <- dplyr::tibble(y = object$y, .row = seq_along(object$y))

  fitted <- fitted_draws(object, newdata,
    value = ".fitted", n = NULL,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )

  out <- dplyr::mutate(
    dplyr::left_join(fitted, obs, by = ".row"),
    !!value := .data$y - .data$.fitted
  )

  dplyr::group_by(out, .data$.row)
}

--------------------------------------------------------------------------------
/R/treatment-effects-posterior.R:
--------------------------------------------------------------------------------
#' Get (individual) treatment effect draws from posterior
#'
#' CTE = Conditional Treatment Effects (usually used to generate (C)ATE or ATT)
#' \code{newdata} specifies the conditions, if unspecified it defaults to the original data.
#' Assumes treated column is either an integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param subset Either "treated", "nontreated", or "all". Default is "all".
#' @param common_support_method Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.
#' @param cutoff Cutoff for common support (if in use).
#' @param ... Arguments to be passed to \code{tidybayes::fitted_draws}, typically \code{scale} for \code{BART} models.
#'
#' @return A tidy data frame (tibble) with treatment effect values.
#' @export
#'

treatment_effects <- function(model, treatment, newdata, subset = "all", common_support_method, cutoff, ...) {
  UseMethod("treatment_effects")
}

#' Get treatment effect draws from posterior
#'
#' CTE = Conditional Treatment Effects (or CATE, the average effects)
#' \code{newdata} specifies the conditions, if unspecified it defaults to the original data.
#' Assumes treated column is either an integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @inheritParams treatment_effects
#'
#' @return A tidy data frame (tibble) with treatment effect values.
#' @export
#'
treatment_effects.default <- function(model, treatment, newdata, subset = "all", common_support_method, cutoff, ...) {
  stopifnot(
    !missing(treatment),
    is.character(treatment),
    length(treatment) == 1
  )

  if (missing(newdata)) {
    check_method(model,
      method = "model.matrix",
      helper = "Please specify 'newdata' argument = data from model fitting."
    )

    modeldata <- stats::model.matrix(model)
  } else {
    modeldata <- newdata
  }

  posterior_fit_with_cf <- fitted_with_counter_factual_draws(model, modeldata, treatment, subset, ...)

  # (2 * treatment - 1) is +1 on treated rows and -1 on untreated rows, so the
  # product is always E(y | treated) - E(y | untreated) for that row
  posterior_treatment <- dplyr::select(
    dplyr::mutate(
      posterior_fit_with_cf,
      cte = (2L * as.integer(!!rlang::sym(treatment)) - 1L) * (.data$observed - .data$cfactual)
    ),
    -dplyr::all_of(c("observed", "cfactual"))
  )

  # add boolean for common support
  if (!missing(common_support_method)) {
    stopifnot(
      !missing(cutoff),
      # should use model data only, unless needs to be specified (e.g. for BART models).
      missing(newdata) | !has_tidytreatment_methods(model)
    )

    if (!missing(newdata)) message("Note: Argument 'newdata' must be original dataset when calculating common support.")

    common_supp <-
      calc_common_support_from_fitted_and_cf(
        fitted_and_cf = posterior_fit_with_cf,
        modeldata = modeldata,
        treatment = treatment,
        method = common_support_method,
        cutoff = cutoff
      )

    posterior_treatment <- dplyr::left_join(posterior_treatment, common_supp, by = ".row")
  }

  posterior_treatment
}

# Posterior fitted draws under the observed treatment ("observed") and under
# the flipped treatment ("cfactual"), joined per row and draw, optionally
# restricted to treated or nontreated rows.
fitted_with_counter_factual_draws <- function(model, newdata, treatment, subset, ...) {
  stopifnot(
    has_tidytreatment_methods(model) | !missing(newdata)
  )

  if (missing(newdata)) {
    newdata <- stats::model.matrix(model)
  }

  use_subset <- match.arg(subset, c("all", "treated", "nontreated"))

  stopifnot(
    treatment %in% colnames(newdata),
    is.data.frame(newdata)
  )

  # `[[` yields the column as a plain vector for data.frames AND tibbles;
  # `newdata[, treatment]` stays a one-column tibble for tibble input and
  # would fail the type check below.
  treatment_col <- newdata[[treatment]]

  stopifnot(
    is_01_integer_vector(treatment_col) | is.logical(treatment_col)
  )

  obs_fitted <- tidybayes::fitted_draws(
    model = model, value = "observed",
    newdata = newdata,
    include_newdata = FALSE,
    ...
  )

  # same rows with the treatment indicator flipped
  cfactual_fitted <- tidybayes::fitted_draws(
    model = model, value = "cfactual",
    newdata = dplyr::mutate(newdata, !!treatment := counter_factual(!!rlang::sym(treatment))),
    include_newdata = FALSE,
    ...
  )

  # attach the observed treatment indicator to every row (keyed by position)
  obs_fitted <- dplyr::left_join(
    obs_fitted,
    dplyr::tibble(!!treatment := treatment_col, .row = seq_along(treatment_col)),
    by = c(".row")
  )

  out <- dplyr::left_join(
    obs_fitted,
    cfactual_fitted,
    by = c(".row", ".chain", ".iteration", ".draw")
  )

  if (use_subset == "treated") {
    out <- dplyr::filter(out, is_treated(!!rlang::sym(treatment)))
  } else if (use_subset == "nontreated") {
    out <- dplyr::filter(out, !is_treated(!!rlang::sym(treatment)))
  }

  out
}

# Flip a treatment indicator: 0/1 (integer or double) -> 1/0, logical -> !x.
# Any other type yields NA.
counter_factual <- function(x) {
  if (is.integer(x)) {
    1L - x
  } else if (is.numeric(x)) {
    1 - x
  } else if (is.logical(x)) {
    !x
  } else {
    rep(NA, times = length(x))
  }
}

# TRUE for treated units: value 1 for 0/1 numeric indicators, the value itself
# for logical indicators. Any other type yields NA.
is_treated <- function(x) {
  if (is.logical(x)) {
    x
  } else if (is.numeric(x)) {
    x == 1
  } else {
    rep(NA, times = length(x))
  }
}

--------------------------------------------------------------------------------
/R/simulate-su-hill.R:
--------------------------------------------------------------------------------
#' Simulate data with scenarios from Hill and Su (2013)
#'
#' Sample \eqn{n} observations with the following scheme:
#' \enumerate{
#' \item Covariates: \eqn{X_j \sim N(0,1)}.
#' \item Assignment: \eqn{Z \sim Bin(1, p)} with \eqn{p = logit^{-1}(a + X \gamma^L + Q \gamma^N)} where \eqn{a = \omega - mean(X \gamma^L + Q \gamma^N)}.
#' \item Mean response: \eqn{E(Y(0)|X) = X \beta_0^L + Q \beta_0^N } and \eqn{E(Y(1)|X) = X \beta_1^L + Q \beta_1^N}.
#' \item Observation: \eqn{Y \sim N(\mu,\sigma_y^2)}.
#' }
#' Superscript \eqn{L} denotes the linear components, whilst \eqn{N} denotes the non-linear
#' components.
#'
#' Coefficients used are returned in the list this function creates. See Table 1 in Hill and Su (2013) for the table of coefficients.
#' The \eqn{X_j} are in a data.frame named \code{data} in the returned list.
#' The formula for the model matrix \eqn{[X,Q]} is stored as \code{formulas$generic} in the returned list.
#' The coefficients used for the model matrix are contained in \code{coefs}.
#' The Hill and Su (2013) simulations did not include categorical variables, but you can add them here using arguments: \code{add_categorical}, \code{coef_categorical_treatment}, \code{coef_categorical_nontreatment}.
#'
#' Hill, Jennifer; Su, Yu-Sung. Ann. Appl. Stat. 7 (2013), no. 3, 1386--1420. doi:10.1214/13-AOAS630. \url{https://projecteuclid.org/euclid.aoas/1380804800}
#'
#' @param n Size of simulated sample.
#' @param tau Treatment effect for parallel response surfaces. Not applicable if surface is nonparallel.
#' @param omega Offset to control treatment assignment ratios.
#' @param treatment_linear Treatment assignment mechanism is linear?
#' @param response_parallel Response surface is parallel?
#' @param response_aligned Response surface is aligned?
#' @param y_sd Observation noise.
#' @param add_categorical Should a categorical variable be added? (Not in Hill and Su)
#' @param coef_categorical_treatment What are the coefficients of the categorical variable under treatment? (Not in Hill and Su)
#' @param coef_categorical_nontreatment What are the coefficients of the categorical variable under nontreatment? (Not in Hill and Su)
#' @return An object of class \code{suhillsim} that is a list with elements
#' \item{data}{Simulated data in data.frame}
#' \item{mean_y}{The mean y values for each individual (row)}
#' \item{args}{List of arguments passed to function}
#' \item{formulas}{Response formulas used to generate data}
#' \item{coefs}{Coefficients for the formulas}
#' @export
simulate_su_hill_data <- function(n, treatment_linear = TRUE, response_parallel = TRUE, response_aligned = TRUE, y_sd = 1, tau = 4, omega = 0, add_categorical = FALSE, coef_categorical_treatment = NULL, coef_categorical_nontreatment = NULL) {
  fargs <- as.list(match.call())

  coefs <- dplyr::tribble(
    ~"class", ~"linear", ~"parallel", ~"aligned", ~"treatment", ~"values",
    # treatment assignment: linear, nonlinear
    "treatment-assignment", TRUE, NA, NA, NA, c(0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.2, 0.4, 0.2, 0.4, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
    "treatment-assignment", FALSE, NA, NA, NA, c(0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.2, 0.4, 0.2, 0.4, 0.2, 0.8, 0.8, 0.5, 0.3, 0.8, 0.2, 0.4, 0.3, 0.8, 0.5),
    # response surface nonlinear and not parallel, aligned: nontreatment, treatment
    "response", FALSE, FALSE, TRUE, FALSE, c(0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 2.0, 0.0, 0.5, 2.0, 0.4, 0.8, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.7),
    "response", FALSE, FALSE, TRUE, TRUE, c(0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 1.0, 0.5, 0.0, 0.8, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
    # response surface nonlinear and not parallel, not as aligned: nontreatment, treatment
    "response", FALSE, FALSE, FALSE, FALSE, c(0.5, 2.0, 0.4, 0.5, 1.0, 0.5, 2.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.5, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
    "response", FALSE, FALSE, FALSE, TRUE, c(0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
  )

  # add parallel response surfaces: both groups reuse the nontreatment
  # coefficients; the treated group is shifted by `tau` when the mean is built
  add_parallel_nontreatment <-
    dplyr::mutate(
      dplyr::filter(coefs, .data$class == "response", !.data$treatment),
      parallel = TRUE
    )

  add_parallel_treatment <- dplyr::mutate(
    add_parallel_nontreatment,
    treatment = TRUE
  )

  coefs <- dplyr::bind_rows(coefs, add_parallel_nontreatment, add_parallel_treatment)

  coefs_treatment_assignment <-
    dplyr::filter(
      coefs,
      .data$class == "treatment-assignment",
      .data$linear == treatment_linear
    )

  coefs_response <-
    dplyr::filter(
      coefs,
      .data$class == "response",
      .data$parallel == response_parallel,
      .data$aligned == response_aligned
    )

  stopifnot(
    nrow(coefs_treatment_assignment) == 1,
    nrow(coefs_response) == 2
  )

  coef_assign <- coefs_treatment_assignment$values[[1]]
  coef_y_0 <- dplyr::filter(coefs_response, .data$treatment == FALSE)$values[[1]]
  coef_y_1 <- dplyr::filter(coefs_response, .data$treatment == TRUE)$values[[1]]

  # simulate covariates
  X <- as.data.frame(matrix(stats::rnorm(n * 10, mean = 0, sd = 1), ncol = 10))
  colnames(X) <- paste0("x", 1:10)

  su_hill_formula <-
    ~ x1 + x2 + I(x1^2) + I(x2^2) + I(x2 * x6) +
      x5 + x6 + x7 + x8 + x9 + x10 +
      I(x5^2) + I(x6^2) + I(x5 * x6) + I(x5 * x6 * x7) + I(x7^2) +
      I(x7^3) + I(x8^2) + I(x7 * x8) + I(x9^2) + I(x9 * x10)

  model_matrix <- as.matrix(stats::model.frame(su_hill_formula, data = X))

  # assign to treatment
  logit_mean_treat_assignment <- model_matrix %*% coef_assign

  allocation_offset <- omega - mean(logit_mean_treat_assignment)

  # plogis is the numerically stable inverse logit; exp(x) / (1 + exp(x))
  # returns NaN once exp(x) overflows (e.g. for a large `omega`)
  p <- stats::plogis(allocation_offset + logit_mean_treat_assignment)
  z <- stats::rbinom(n = n, size = 1, prob = p)

  # Add categorical variable. Note: not included in Hill and Su
  if (add_categorical) {
    stopifnot(
      is.numeric(coef_categorical_treatment),
      is.numeric(coef_categorical_nontreatment),
      length(coef_categorical_treatment) ==
        length(coef_categorical_nontreatment)
    )

    ss <- length(coef_categorical_treatment)

    # sample categories with equal probability
    c1 <- sample.int(n = ss, replace = TRUE, size = n)

    cat_y_0 <- coef_categorical_nontreatment[c1]
    cat_y_1 <- coef_categorical_treatment[c1]
  } else {
    cat_y_0 <- 0
    cat_y_1 <- 0
  }

  # mean response (treated surface shifted by tau when surfaces are parallel)
  mean_y <- ifelse(z == 0, model_matrix %*% coef_y_0 + cat_y_0, model_matrix %*% coef_y_1 + response_parallel * tau + cat_y_1)

  # add noise
  y <- stats::rnorm(n = n, mean = mean_y, sd = y_sd)

  if (add_categorical) {
    rdata <- cbind(data.frame(y = y, z = z, c1 = factor(c1)), X)
  } else {
    rdata <- cbind(data.frame(y = y, z = z), X)
  }

  # prepare formulas describing the simulation truth
  formula_terms <- attributes(stats::terms(su_hill_formula))$term.labels

  # "coef*term + ..." using only the terms with non-zero coefficients
  coefs_to_chr <- function(coef, labels) {
    keep <- coef != 0
    paste(paste0(coef[keep], "*", labels[keep]), collapse = " + ")
  }

  # indicator terms for the categorical variable, e.g. "1*I(c1=='3')"
  categorical_to_chr <- function(coef) {
    coefs_to_chr(coef, paste0("I(c1==", paste0("'", seq_along(coef), "'"), ")"))
  }

  # join the non-empty pieces with " + " so an empty piece (e.g. all
  # categorical coefficients zero) never leaves a dangling "+" to parse
  join_terms <- function(...) {
    pieces <- c(...)
    paste(pieces[nzchar(pieces)], collapse = " + ")
  }

  frmls <- list()

  # formula for treatment assignment
  frmls$treatment_assignment <- parse(text = coefs_to_chr(coef_assign, formula_terms))

  # formula for response from treatment group (includes the tau shift that
  # is added to the treated mean when surfaces are parallel)
  chr_response_treatment_frm <- join_terms(
    coefs_to_chr(coef_y_1, formula_terms),
    if (response_parallel && tau != 0) as.character(tau) else "",
    if (add_categorical) categorical_to_chr(coef_categorical_treatment) else ""
  )
  frmls$response_treatment <- parse(text = chr_response_treatment_frm)

  # formula for response from non-treatment group
  chr_response_nontreatment_frm <- join_terms(
    coefs_to_chr(coef_y_0, formula_terms),
    if (add_categorical) categorical_to_chr(coef_categorical_nontreatment) else ""
  )
  frmls$response_nontreatment <- parse(text = chr_response_nontreatment_frm)

  frmls$generic <- su_hill_formula

  structure(
    list(
      data = rdata,
      mean_y = mean_y,
      args = fargs,
      formulas = frmls,
      coefs = dplyr::bind_rows(coefs_treatment_assignment, coefs_response)
    ),
    class = "suhillsim"
  )
}

#' @export
print.suhillsim <- function(x, ...) {
  st <- "Su-Hill Simulation"
  ta <- paste0("\n Treatment assignment: ", as.character(x$formulas$treatment_assignment))
  rt <- paste0("\n Response (treated): ", as.character(x$formulas$response_treatment))
  rnt <- paste0("\n Response (nontreated): ", as.character(x$formulas$response_nontreatment))
  en <- paste0("\nList with elements\n", paste0("\t$", names(x), collapse = "\n"))

  cat(st, ta, rt, rnt, en, "\n")

  # print methods return their input invisibly so they compose in pipelines
  invisible(x)
}

--------------------------------------------------------------------------------
/examples/use-tidytreatment-bartMachine.Rmd:
--------------------------------------------------------------------------------
---
title: "Using the tidytreatment package"
author: "Joshua J Bon"
date: "`r Sys.Date()`"
bibliography: ../vignettes/vignette.bib
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Using tidytreatment with bartMachine}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

This vignette simulates data using the scheme described by @Hill2013 with the addition of 1 categorical variable.
It is implemented in the function `simulate_su_hill_data()`: 21 | 22 | ```{r load-data} 23 | 24 | # before running library(bartMachine), set memory 25 | options(java.parameters = "-Xmx2000m") # restart R to take effect 26 | # check memory allocated to Java VM 27 | options("java.parameters") 28 | 29 | suppressPackageStartupMessages({ 30 | library(bartMachine) 31 | library(tidytreatment) 32 | library(dplyr) 33 | library(tidybayes) 34 | library(ggplot2) 35 | }) 36 | 37 | sim <- simulate_su_hill_data(n = 200, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, 38 | coef_categorical_treatment = c(0,0,1), 39 | coef_categorical_nontreatment = c(-1,0,-1) 40 | ) 41 | # non-treated vs treated counts: 42 | table(sim$data$z) 43 | 44 | dat <- sim$data 45 | dat$c1 <- as.integer(dat$c1) 46 | # a selection of data 47 | dat %>% select(y, z, c1, x1:x3) %>% head() 48 | 49 | ``` 50 | 51 | ## Run the `bartMachine` 52 | 53 | Run the model to be used to assess treatment effects. Here we will use `bartMachine`, which is one implementation of Bayesian Additive Regression Trees in `R` [@Kapelner2016]. The package can be found on [CRAN](https://cran.r-project.org/package=bartMachine). 54 | 55 | ```{r run-bart, echo=TRUE, results='hide', cache=FALSE} 56 | 57 | # if you increase the number of cores, the memory needs to be increased, 58 | # this requires restarting R, setting the 'java.parameters' option then 59 | # loading the bartMachine package. 60 | set_bart_machine_num_cores(2) 61 | 62 | # set serialize = TRUE if using the fit over multiple sessions 63 | # The first bart model will be for the propensity score... 64 | # i.e. propensity for selection of treatment?
65 | 66 | # regress y ~ covariates 67 | var_select_bart <- bartMachine( 68 | X = select(dat,-y,-z), 69 | y = select(dat, y)[[1]], 70 | num_burn_in = 2000, 71 | num_iterations_after_burn_in = 5000, 72 | serialize = TRUE, 73 | verbose = FALSE 74 | ) 75 | 76 | # select most important vars from y ~ covariates model 77 | var_select <- bartMachine::var_selection_by_permute_cv(var_select_bart, k_folds = 5) 78 | 79 | # regress z ~ most important covariates to get propensity score 80 | prop_bart <- bartMachine( 81 | X = select(dat,var_select$important_vars_cv), 82 | y = as.factor(select(dat, z)[[1]]), 83 | num_burn_in = 2000, 84 | num_iterations_after_burn_in = 5000, 85 | serialize = TRUE, 86 | verbose = FALSE 87 | ) 88 | 89 | dat$prop_score <- prop_bart$p_hat_train 90 | 91 | destroy_bart_machine(var_select_bart) 92 | destroy_bart_machine(prop_bart) 93 | 94 | # Give z double prior inclusion probability 95 | prior_incl_prob <- setNames(rep(1, times = ncol(dat) - 1), colnames(dat)[colnames(dat) != "y"]) 96 | prior_incl_prob["z"] <- 2 97 | 98 | bartM <- bartMachine( 99 | X = select(dat,-y), 100 | y = select(dat, y)[[1]], 101 | num_burn_in = 2000, 102 | num_iterations_after_burn_in = 5000, 103 | serialize = TRUE, 104 | verbose = FALSE, 105 | cov_prior_vec = prior_incl_prob 106 | ) 107 | 108 | 109 | ``` 110 | 111 | ## Model checking and convergence 112 | 113 | Here are some examples of model checking we can do. 
114 | 115 | ```{r convergence-bart, echo=TRUE, cache=FALSE} 116 | 117 | print(bartM) 118 | 119 | res <- residual_draws(bartM, include_newdata = FALSE) 120 | res %>% 121 | point_interval(.residual, y, .width = c(0.95) ) %>% 122 | select(-y.lower, -y.upper) %>% 123 | ggplot() + 124 | geom_pointinterval(aes(x = y, y = .residual, ymin = .residual.lower, ymax = .residual.upper), alpha = 0.2) + 125 | scale_fill_brewer() + 126 | theme_bw() + ggtitle("Residuals vs observations") 127 | 128 | res %>% summarise(.fitted = mean(.fitted), y = first(y)) %>% 129 | ggplot(aes(x = y, y = .fitted)) + 130 | geom_point() + 131 | geom_smooth(method = "lm") + 132 | theme_bw() + ggtitle("Observations vs fitted") 133 | 134 | res %>% summarise(.residual = mean(.residual)) %>% 135 | ggplot(aes(sample = .residual)) + 136 | geom_qq() + 137 | geom_qq_line() + 138 | theme_bw() + ggtitle("Q-Q plot of residuals") 139 | 140 | bartMachine:::plot_sigsqs_convergence_diagnostics(bartM) 141 | 142 | bartMachine:::plot_mh_acceptance_reject(bartM) 143 | 144 | bartMachine:::plot_tree_num_nodes(bartM) 145 | 146 | bartMachine:::plot_tree_depths(bartM) 147 | 148 | ``` 149 | 150 | ## Extract the posterior (tidy style) 151 | 152 | Methods for extracting the posterior in a tidy format are included in the `tidytreatment` package. 153 | 154 | ```{r tidy-bart-fit, echo=TRUE, cache=FALSE} 155 | 156 | posterior_fitted <- fitted_draws(bartM, value = "fit", include_newdata = FALSE) 157 | # The newdata argument (omitted) defaults to the data from the model. 158 | # include_newdata = FALSE, avoids returning the newdata with the fitted values 159 | # as it is so large. 160 | # The `.row` variable makes sure we know which row in the newdata the fitted 161 | # value came from (if we don't include the data in the result). 162 | 163 | posterior_fitted 164 | 165 | ``` 166 | 167 | ```{r tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE} 168 | 169 | # Function to tidy predicted draws also...
170 | posterior_pred <- predicted_draws(bartM, include_newdata = FALSE) 171 | 172 | ``` 173 | 174 | ## Use some plotting functions from the `tidybayes` package 175 | 176 | Since `tidytreatment` follows the `tidybayes` output specifications, functions from `tidybayes` should work. 177 | 178 | ```{r plot-tidy-bart, echo=TRUE, cache=FALSE} 179 | 180 | treatment_var_and_c1 <- 181 | dat %>% 182 | select(z,c1) %>% 183 | mutate(.row = 1:n(), z = as.factor(z)) 184 | 185 | posterior_fitted %>% 186 | left_join(treatment_var_and_c1, by = ".row") %>% 187 | ggplot() + 188 | geom_eye(aes(x = z, y = fit)) + 189 | facet_wrap(~c1, labeller = as_labeller( function(x) paste("c1 =",x) ) ) + 190 | xlab("Treatment (z)") + ylab("Posterior predicted value") + 191 | theme_bw() + ggtitle("Effect of treatment with 'c1' on posterior fitted values") 192 | 193 | ``` 194 | 195 | ## Calculate Treatment Effects 196 | 197 | Posterior conditional (average) treatment effects can be calculated using the `treatment_effects` function. This function finds the posterior values of 198 | $$ 199 | \text{E}(y ~ \vert~ T = 1, X = x_{i}) - \text{E}(y ~ \vert~ T = 0, X = x_{i}) 200 | $$ 201 | for each unit of measurement, $i$, (e.g. subject) in the data sample. 202 | 203 | Some histogram summaries are presented below. 
204 | 205 | ```{r cates-hist, echo=TRUE, cache=FALSE} 206 | 207 | # sample based (using data from fit) conditional treatment effects, posterior draws 208 | posterior_treat_eff <- 209 | treatment_effects(bartM, treatment = "z") 210 | 211 | posterior_treat_eff %>% 212 | ggplot() + 213 | geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + 214 | theme_bw() + ggtitle("Histogram of treatment effect (all draws)") 215 | 216 | 217 | posterior_treat_eff %>% summarise(cte_hat = median(cte)) %>% 218 | ggplot() + 219 | geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") + 220 | theme_bw() + ggtitle("Histogram of treatment effect (median for each subject)") 221 | 222 | ``` 223 | 224 | We can also focus on the treatment effects for just those that are treated. 225 | 226 | ```{r cates-hist-treated, echo=TRUE, cache=FALSE} 227 | 228 | # sample based (using data from fit) conditional treatment effects, posterior draws 229 | posterior_treat_eff_on_treated <- 230 | treatment_effects(bartM, treatment = "z", subset = "treated") 231 | 232 | posterior_treat_eff_on_treated %>% 233 | ggplot() + 234 | geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + 235 | theme_bw() + ggtitle("Histogram of treatment effect (all draws from treated subjects)") 236 | 237 | ``` 238 | 239 | Plots can be made that stack each subject's posterior CIs of the CATEs.
240 | 241 | ```{r cates-stack-plot, echo=TRUE, cache=FALSE} 242 | 243 | posterior_treat_eff %>% select(-z) %>% point_interval() %>% 244 | arrange(cte) %>% mutate(.orow = 1:n()) %>% 245 | ggplot() + 246 | geom_interval(aes(x = .orow, y= cte), size = 0.5) + 247 | geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.1) + 248 | ylab("Median posterior CATE for each subject (95% CI)") + 249 | theme_bw() + coord_flip() + scale_colour_brewer() + 250 | theme(axis.title.y = element_blank(), 251 | axis.text.y = element_blank(), 252 | axis.ticks.y = element_blank(), 253 | legend.position = "none") 254 | 255 | ``` 256 | 257 | We can also plot the CATEs varying over particular covariates. In this example, instead of grouping by subject, we group by the variable of interest, and calculate the posterior summaries over this variable. 258 | 259 | ```{r cates-line-plot, echo=TRUE, cache=FALSE} 260 | 261 | posterior_treat_eff %>% 262 | left_join(dplyr::tibble(c1 = dat$c1, .row = 1:length(dat$c1) ), by = ".row") %>% 263 | group_by(c1) %>% 264 | ggplot() + 265 | geom_eye(aes(x = c1, y = cte), alpha = 0.2) + 266 | scale_fill_brewer() + 267 | theme_bw() + ggtitle("Treatment effect by `c1`") 268 | 269 | 270 | ``` 271 | 272 | ## Common support 273 | 274 | Common support [@hill] can be tested directly, or a Boolean indicating common support can be included when calculating the treatment effects.
275 | 276 | ```{r common-support, echo=TRUE, cache=FALSE} 277 | 278 | csupp1 <- has_common_support(bartM, treatment = "z", 279 | method = "chisq", cutoff = 0.05) 280 | csupp1 %>% filter(!common_support) 281 | 282 | csupp2 <- has_common_support(bartM, treatment = "z", 283 | method = "sd", cutoff = 1) 284 | csupp2 %>% filter(!common_support) 285 | 286 | posterior_treat_eff_on_treated <- 287 | treatment_effects(bartM, treatment = "z", 288 | subset = "treated", 289 | common_support_method = "sd", cutoff = 1) 290 | 291 | ``` 292 | 293 | ## Investigating variable importance 294 | 295 | We can count how many times a variable was included in the BART model in conjunction with the treatment variable, or overall. 296 | 297 | ```{r interaction-investigator, echo=TRUE, cache=FALSE} 298 | 299 | treatment_interactions <- 300 | covariate_with_treatment_importance(bartM, treatment = "z") 301 | 302 | treatment_interactions %>% 303 | ggplot() + 304 | geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") + 305 | theme_bw() + ggtitle("Important variables interacting with treatment ('z')") + 306 | ylab("Inclusion counts") 307 | 308 | variable_importance <- 309 | covariate_importance(bartM) 310 | 311 | variable_importance %>% 312 | ggplot() + 313 | geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") + 314 | theme_bw() + ggtitle("Important variables overall") + 315 | ylab("Inclusion counts") 316 | 317 | 318 | ``` 319 | 320 | ## References 321 | -------------------------------------------------------------------------------- /R/tidy-posterior-BART.R: -------------------------------------------------------------------------------- 1 | #' Get fitted draws from posterior of \code{BART}-package models 2 | #' 3 | #' @param model A model from \code{BART} package. 4 | #' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param value The name of the output column for \code{fitted_draws}; default \code{".value"}.
#' @param include_newdata Should the newdata be included in the tibble? If \code{TRUE},
#'   \code{newdata} must be supplied.
#' @param include_sigsqs Should the posterior sigma-squared draw be included? Requires a
#'   model that stores posterior \code{sigma} draws (e.g. \code{wbart}).
#' @param scale Scale of the returned values. \code{"real"} (default) returns the draws as
#'   stored or predicted by the model; \code{"prob"} maps \code{pbart} draws through
#'   \code{pnorm} and \code{lbart} draws through \code{plogis} (other classes are unchanged).
#' @param ... Arguments to pass to \code{predict} (e.g. \code{BART:::predict.wbart}). Only
#'   used when \code{newdata} is supplied.
#'
#' @return A tidy data frame (tibble) with fitted values, one row per observation and
#'   posterior draw, grouped by observation (\code{.row} plus any \code{newdata} columns).
#'
fitted_draws_BART <- function(model, newdata = NULL, value = ".value", ..., include_newdata = TRUE, include_sigsqs = FALSE, scale = "real") {
  stopifnot(has_installed_package("BART"))

  # validate arguments before they are used in the conditions below
  stopifnot(
    is.character(value),
    is.logical(include_newdata),
    is.logical(include_sigsqs),
    inherits(model, c("wbart", "pbart", "lbart", "mbart", "mbart2"))
  )

  if (is.null(newdata) && include_newdata) {
    stop("For models from BART package 'newdata' must be specified if 'include_newdata = TRUE'.")
  }

  use_scale <- match.arg(scale, c("real", "prob"), several.ok = FALSE)

  # order for the bookkeeping columns placed at the end of the output
  col_order <- c(".row", ".chain", ".iteration", ".draw", value)

  if (!is.null(newdata)) {
    # S3 predict methods in BART return yhat values; use only the predictors
    # the trees were grown on, keeping a matrix even for a single predictor
    xvars <- names(model$treedraws$cutpoints)
    bartdata <- BART::bartModelMatrix(newdata)[, xvars, drop = FALSE]
    # dodraws = TRUE => all posterior draws (not just their mean)
    posterior <- predict(object = model, newdata = bartdata, dodraws = TRUE, ...)
    # some predict methods return a list holding the draws in `yhat.test`
    if (!is.matrix(posterior)) posterior <- posterior$yhat.test
  } else {
    posterior <- model$yhat.train
  }

  # rows of `posterior` are treated as draws and columns as observations
  # (see `.row` below); optionally map latent draws onto the probability scale
  if (use_scale == "prob" && inherits(model, "lbart")) posterior <- stats::plogis(posterior)
  if (use_scale == "prob" && inherits(model, "pbart")) posterior <- stats::pnorm(posterior)

  # bind newdata with fitted values in wide format: one row per observation and
  # one `.col_iter<k>` column per posterior draw k
  out <- dplyr::bind_cols(
    if (include_newdata) dplyr::as_tibble(newdata) else NULL,
    dplyr::as_tibble(t(posterior), .name_repair = function(names) {
      paste0(".col_iter", seq_along(names))
    }),
    .row = seq_len(ncol(posterior))
  )

  # convert to long format: one row per (observation, draw)
  out <- tidyr::gather(out, key = ".draw", value = !!value, dplyr::starts_with(".col_iter"))

  # add the columns required by the tidybayes output standard and recover the
  # integer draw index from the `.col_iter<k>` names
  out <- dplyr::mutate(
    out,
    .chain = NA_integer_,
    .iteration = NA_integer_,
    .draw = as.integer(sub(".col_iter", "", .data$.draw, fixed = TRUE))
  )

  # include sigma^2 if needed
  if (include_sigsqs) {
    sigma_draws <- model$sigma
    if (length(sigma_draws) == 0) {
      stop("'include_sigsqs = TRUE' requires a model with posterior 'sigma' draws (e.g. 'wbart').")
    }
    # `sigma` can hold more values than there are kept draws (the package
    # vignette notes it includes the skipped burn-in samples), so pair the kept
    # draws with the final sigma values rather than the first (burn-in) ones.
    # NOTE(review): with keepevery > 1 the kept fitted draws are thinned but
    # `sigma` may not be; confirm the exact alignment for the BART version used.
    n_keep <- min(nrow(posterior), length(sigma_draws))
    kept_sigma <- sigma_draws[length(sigma_draws) - n_keep + seq_len(n_keep)]

    sigsq <- dplyr::tibble(.draw = seq_len(n_keep), sigsq = kept_sigma^2)
    out <- dplyr::left_join(out, sigsq, by = ".draw")

    col_order <- c(col_order, "sigsq")
  }

  # move the bookkeeping columns to the end, in `col_order`
  out <- dplyr::select(out, -dplyr::all_of(col_order), dplyr::all_of(col_order))

  # group by every remaining column plus `.row`, i.e. one group per observation
  row_groups <- names(out)[!names(out) %in% col_order[col_order != ".row"]]
  out <- dplyr::group_by(out, dplyr::across(dplyr::all_of(row_groups)))

  out
}

#' Get predict draws from posterior of \code{BART}-package models
#'
#' Each prediction is one draw from \code{rng} centred on a posterior fitted
#' draw, with standard deviation taken from the matching posterior sigma draw.
#'
#' @param object A \code{BART}-package model.
#' @param newdata Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.
#' @param value The name of the output column for \code{predicted_draws}; default \code{".prediction"}.
#' @param rng Random number generator function. Default is \code{rnorm} for models with Gaussian errors.
#'   Called as \code{rng(n = , mean = , sd = )}.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_fitted Should the posterior fitted values be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Currently unused; not passed on to \code{predict}.
#'
#' @return A tidy data frame (tibble) with predicted values.
#'
predicted_draws_BART <- function(object, newdata = NULL, value = ".prediction", ..., rng = stats::rnorm, include_newdata = TRUE, include_fitted = FALSE, include_sigsqs = FALSE) {
  stopifnot(
    is.character(value),
    is.logical(include_fitted),
    is.logical(include_sigsqs)
  )

  # get fitted values; sigma^2 is always needed for the noise draws
  out <- fitted_draws_BART(
    object,
    newdata = newdata, value = ".fit",
    include_newdata = include_newdata, include_sigsqs = TRUE
  )

  # draw one prediction per (observation, draw) around the fitted value
  out <- dplyr::mutate(out, !!value := rng(n = dplyr::n(), mean = .data$.fit, sd = sqrt(.data$sigsq)))

  # remove helper columns the caller did not ask for
  if (!include_sigsqs) out <- dplyr::select(out, -dplyr::all_of("sigsq"))
  if (!include_fitted) out <- dplyr::select(out, -dplyr::all_of(".fit"))

  out
}


#' Get residual draw for BART model
#'
#' Classes from \code{BART}-package models
#'
#' @param object model from \code{BART} package.
#' @param response Original response vector.
#' @param newdata Data frame to generate predictions from. If omitted, original data used to fit the model.
#' @param value Name of the output column for residual_draws; default is \code{.residual}.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Currently ignored; accepted so that extra arguments forwarded by the
#'   \code{residual_draws} methods do not raise an "unused argument" error.
#'
#' @return Tibble with residuals (\code{response} minus fitted value), grouped by \code{.row}.
#'
residual_draws_BART <- function(object, response, newdata = NULL, value = ".residual", include_newdata = TRUE, include_sigsqs = FALSE, ...) {
  if (missing(response)) stop("Models from BART package require response (y) as argument. Specify 'response = ' as argument.")

  stopifnot(is.numeric(response))

  # observed responses keyed by the same row index used in the fitted draws
  obs <- dplyr::tibble(y = response, .row = seq_along(response))

  fitted <- fitted_draws_BART(
    object, newdata,
    value = ".fitted",
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )

  out <- dplyr::mutate(
    dplyr::left_join(fitted, obs, by = ".row"),
    !!value := .data$y - .data$.fitted
  )

  dplyr::group_by(out, .row)
}

#' Get fitted draws from posterior of \code{wbart} model
#'
#' @param model A model from \code{BART} package.
#' @param newdata Data frame to generate fitted values from. If omitted, the draws for the
#'   data used to fit the model are returned (this requires \code{include_newdata = FALSE}).
#' @param value The name of the output column for \code{fitted_draws}; default \code{".value"}.
#' @param n Not currently implemented.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Passed on to \code{fitted_draws_BART}, e.g. \code{scale}, or arguments for
#'   \code{predict} when \code{newdata} is supplied.
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.wbart <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  fitted_draws_BART(
    model = model, newdata = newdata, value = value,
    ...,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )
}

#' Get fitted draws from posterior of \code{pbart} model
#'
#' @inheritParams fitted_draws.wbart
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.pbart <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  fitted_draws_BART(
    model = model, newdata = newdata, value = value,
    ...,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )
}

#' Get fitted draws from posterior of \code{lbart} model
#'
#' @inheritParams fitted_draws.wbart
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.lbart <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  fitted_draws_BART(
    model = model, newdata = newdata, value = value,
    ...,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )
}

#' Get fitted draws from posterior of \code{mbart} model
#'
#' @inheritParams fitted_draws.wbart
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.mbart <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  fitted_draws_BART(
    model = model, newdata = newdata, value = value,
    ...,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )
}

#' Get fitted draws from posterior of \code{mbart2} model
#'
#' @inheritParams fitted_draws.wbart
#'
#' @return A tidy data frame (tibble) with fitted values.
#' @export
#'
fitted_draws.mbart2 <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  fitted_draws_BART(
    model = model, newdata = newdata, value = value,
    ...,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )
}

#' Get predict draws from posterior of \code{wbart} model
#'
#' @param object A \code{wbart} model.
#' @param newdata Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.
#' @param value The name of the output column for \code{predicted_draws}; default \code{".prediction"}.
#' @param ndraws Not currently implemented.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_fitted Should the posterior fitted values be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Use to specify random number generator, default is \code{rng=stats::rnorm}.
#'   Other arguments are currently ignored.
#'
#' @return A tidy data frame (tibble) with predicted values.
#' @export
#'
predicted_draws.wbart <- function(object, newdata, value = ".prediction", ..., ndraws = NULL, include_newdata = TRUE, include_fitted = FALSE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  predicted_draws_BART(
    object = object, newdata = newdata,
    value = value,
    include_newdata = include_newdata,
    # forward include_fitted so the `.fit` column can be kept on request
    include_fitted = include_fitted,
    include_sigsqs = include_sigsqs, ...
  )
}

#' Get residual draw for \code{wbart} model
#'
#' The original response variable must be passed as an argument to this function.
#' e.g. `response = y`
#'
#' @param object \code{wbart} model.
#' @param newdata Data frame to generate predictions from. If omitted, original data used to fit the model.
#' @param value Name of the output column for residual_draws; default is \code{.residual}.
#' @param ... Must include \code{response}, the original response vector; other
#'   arguments are currently ignored.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ndraws Not currently implemented.
#'
#' @return Tibble with residuals.
#' @export
#'
residual_draws.wbart <- function(object, newdata, value = ".residual", ..., ndraws = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  residual_draws_BART(
    object = object, newdata = newdata, value = value,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs, ...
  )
}

#' Get residual draw for \code{pbart} model
#'
#' The original response variable must be passed as an argument to this function.
#' e.g. `response = y`
#'
#' @inheritParams residual_draws.wbart
#'
#' @return Tibble with residuals.
#' @export
#'
residual_draws.pbart <- function(object, newdata, value = ".residual", ..., ndraws = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) {
    newdata <- NULL
  }

  residual_draws_BART(
    object = object, newdata = newdata, value = value,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs, ...
  )
}

-------------------------------------------------------------------------------- /vignettes/use-tidytreatment-BART.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Using the tidytreatment package with BART" 3 | author: "Joshua J Bon" 4 | date: "`r Sys.Date()`" 5 | bibliography: vignette.bib 6 | output: rmarkdown::html_vignette 7 | vignette: > 8 | %\VignetteIndexEntry{Using the tidytreatment package with BART} 9 | %\VignetteEngine{knitr::rmarkdown} 10 | %\VignetteEncoding{UTF-8} 11 | --- 12 | 13 | ```{r setup, include = FALSE} 14 | knitr::opts_chunk$set( 15 | collapse = TRUE, 16 | comment = "#>", 17 | fig.dim = c(6, 4) 18 | ) 19 | 20 | suppressPackageStartupMessages({ 21 | library(BART) 22 | library(tidytreatment) 23 | library(dplyr) 24 | library(tidybayes) 25 | library(ggplot2) 26 | }) 27 | 28 | # load pre-computed data and model 29 | sim <- suhillsim1 30 | te_model <- bartmodel1 31 | 32 | # pre compute 33 | posterior_treat_eff <- treatment_effects(te_model, treatment = "z", newdata = sim$data) 34 | posterior_treat_eff_on_treated <- treatment_effects(te_model, treatment = "z", newdata = sim$data, subset = "treated") 35 | 36 | ``` 37 | 38 | This vignette demonstrates an example workflow for heterogeneous treatment effect models using the `BART` package for fitting Bayesian Additive Regression Trees and `tidytreatment` for investigating the output of such models.
 The `tidytreatment` package can also be used with `bartMachine` models; support for `bcf` is coming soon (see branch `bcf-hold` on GitHub). 39 | 40 | ## Simulate data 41 | 42 | Below we load packages and simulate data using the scheme described by @Hill2013 with the addition of 1 categorical variable. It is implemented in the function `simulate_su_hill_data()`: 43 | 44 | ```{r load-data-print, echo = TRUE, eval = FALSE} 45 | 46 | # load packages 47 | library(BART) 48 | library(tidytreatment) 49 | library(dplyr) 50 | library(tidybayes) 51 | library(ggplot2) 52 | 53 | # set seed so vignette is reproducible 54 | set.seed(101) 55 | 56 | # simulate data 57 | sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, 58 | coef_categorical_treatment = c(0,0,1), 59 | coef_categorical_nontreatment = c(-1,0,-1) 60 | ) 61 | 62 | ``` 63 | 64 | Now we can take a look at some data summaries. 65 | 66 | ```{r data-summary, echo = TRUE, eval = TRUE} 67 | 68 | # non-treated vs treated counts: 69 | table(sim$data$z) 70 | 71 | dat <- sim$data 72 | # a selection of data 73 | dat %>% select(y, z, c1, x1:x3) %>% head() 74 | 75 | ``` 76 | 77 | ## Fit the regression model 78 | 79 | Run the model to be used to assess treatment effects. Here we will use `BART`, which is one implementation of Bayesian Additive Regression Trees in `R` [@Chipman2010; @sparapani2016]. The package can be found on [CRAN](https://cran.r-project.org/package=BART). 80 | 81 | We are following the procedure in @Hahn2020 (albeit without their more sophisticated model) where we estimate a propensity score for being assigned to the treatment regime, which improves estimation properties. The procedure is roughly as follows: 82 | 83 | 1. Fit a 'variable selection' model (VS): Regress outcome against covariates (excluding treatment variable) 84 | 2. Select the subset of covariates from the VS model that are most associated with the outcome 85 | 3.
 Fit a 'propensity score' model (PS): A probit/logit model estimating the propensity score using only the covariates selected in step 2 86 | 4. Fit the treatment effect model (TE): Using the original covariates and propensity score from step 3 87 | 88 | ```{r run-bart, echo = TRUE, eval = FALSE} 89 | 90 | # STEP 1 VS Model: Regress y ~ covariates 91 | var_select_bart <- wbart(x.train = select(dat,-y,-z), 92 | y.train = pull(dat, y), 93 | sparse = TRUE, 94 | nskip = 2000, 95 | ndpost = 5000) 96 | 97 | # STEP 2: Variable selection 98 | # select most important vars from y ~ covariates model 99 | # very simple selection mechanism. Should use cross-validation in practice 100 | covar_ranking <- covariate_importance(var_select_bart) 101 | var_select <- covar_ranking %>% 102 | filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>% 103 | pull(variable) 104 | 105 | # change categorical variables to just one variable 106 | var_select <- unique(gsub("c1[1-3]$","c1", var_select)) 107 | 108 | var_select 109 | 110 | # STEP 3 PS Model: Regress z ~ selected covariates 111 | # BART::pbart is for probit regression 112 | prop_bart <- pbart( 113 | x.train = select(dat, all_of(var_select)), 114 | y.train = pull(dat, z), 115 | nskip = 2000, 116 | ndpost = 5000 117 | ) 118 | 119 | # store propensity score in data 120 | dat$prop_score <- prop_bart$prob.train.mean 121 | 122 | # Step 4 TE Model: Regress y ~ z + covariates + propensity score 123 | te_model <- wbart( 124 | x.train = select(dat,-y), 125 | y.train = pull(dat, y), 126 | nskip = 10000L, 127 | ndpost = 200L, #* 128 | keepevery = 100L #* 129 | ) 130 | 131 | #* The posterior samples are kept small to manage size on CRAN 132 | 133 | ``` 134 | 135 | ## Extract the posterior (tidy style) 136 | 137 | Methods for extracting the posterior in a tidy format are included in `tidytreatment`.
138 | 139 | ```{r tidy-bart-fit, echo=TRUE, cache=FALSE} 140 | 141 | posterior_fitted <- fitted_draws(te_model, value = "fit", include_newdata = FALSE) 142 | # include_newdata = FALSE avoids returning the newdata with the fitted values, 143 | # which can be large. For BART models, newdata must be specified if include_newdata = TRUE. 144 | # The `.row` variable makes sure we know which row in the newdata the fitted 145 | # value came from (if we don't include the data in the result). 146 | 147 | posterior_fitted 148 | 149 | ``` 150 | 151 | ```{r tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE} 152 | 153 | # Tidy the predicted draws as well; this adds random normal noise by default 154 | posterior_pred <- predicted_draws(te_model, include_newdata = FALSE) 155 | 156 | ``` 157 | 158 | ## Use some plotting functions from the `tidybayes` package 159 | 160 | Since `tidytreatment` follows the `tidybayes` output specifications, functions from `tidybayes` should work. 161 | 162 | ```{r plot-tidy-bart, echo=TRUE, cache=FALSE} 163 | 164 | treatment_var_and_c1 <- 165 | dat %>% 166 | select(z,c1) %>% 167 | mutate(.row = 1:n(), z = as.factor(z)) 168 | 169 | posterior_fitted %>% 170 | left_join(treatment_var_and_c1, by = ".row") %>% 171 | ggplot() + 172 | stat_halfeye(aes(x = z, y = fit)) + 173 | facet_wrap(~c1, labeller = as_labeller( function(x) paste("c1 =",x) ) ) + 174 | xlab("Treatment (z)") + ylab("Posterior predicted value") + 175 | theme_bw() + ggtitle("Effect of treatment with 'c1' on posterior fitted values") 176 | 177 | ``` 178 | 179 | ## Calculate Treatment Effects 180 | 181 | Posterior conditional (average) treatment effects can be calculated using the `treatment_effects` function. This function finds the posterior values of 182 | $$ 183 | \tau(x) = \text{E}(y ~ \vert~ T = 1, X = x) - \text{E}(y ~ \vert~ T = 0, X = x) 184 | $$ 185 | for each unit of measurement, $i$, (e.g. subject) in the data sample. 186 | 187 | Some histogram summaries are presented below.
188 | 189 | ```{r post-treatment, eval = FALSE} 190 | 191 | # sample based (using data from fit) conditional treatment effects, posterior draws 192 | posterior_treat_eff <- 193 | treatment_effects(te_model, treatment = "z", newdata = dat) 194 | 195 | ``` 196 | ```{r cates-hist, echo=TRUE, cache=FALSE} 197 | 198 | # Histogram of treatment effect (all draws) 199 | posterior_treat_eff %>% 200 | ggplot() + 201 | geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + 202 | theme_bw() + ggtitle("Histogram of treatment effect (all draws)") 203 | 204 | # Histogram of treatment effect (median for each subject) 205 | posterior_treat_eff %>% summarise(cte_hat = median(cte)) %>% 206 | ggplot() + 207 | geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") + 208 | theme_bw() + ggtitle("Histogram of treatment effect (median for each subject)") 209 | 210 | ``` 211 | ```{r att-ate, eval=FALSE} 212 | # get the ATE and ATT directly: 213 | 214 | posterior_ate <- tidy_ate(te_model, treatment = "z", newdata = dat) 215 | posterior_att <- tidy_att(te_model, treatment = "z", newdata = dat) 216 | 217 | ``` 218 | 219 | ```{r ate-trace-setup, eval = TRUE, echo = FALSE} 220 | 221 | posterior_ate <- posterior_treat_eff %>% group_by(.chain, .iteration, .draw) %>% 222 | summarise(ate = mean(cte), .groups = "drop") 223 | 224 | ``` 225 | 226 | We can create a trace plot for the treatment effect summaries easily too: 227 | 228 | ```{r ate-trace, eval=TRUE, echo=TRUE} 229 | 230 | posterior_ate %>% ggplot(aes(x = .draw, y = ate)) + 231 | geom_line() + 232 | theme_bw() + 233 | ggtitle("Trace plot of ATE") 234 | 235 | ``` 236 | 237 | We can also focus on the treatment effects for just those that are treated. 
238 | 239 | ```{r post-te-treated, echo=TRUE, eval=FALSE} 240 | 241 | # sample based (using data from fit) conditional treatment effects, posterior draws 242 | posterior_treat_eff_on_treated <- 243 | treatment_effects(te_model, treatment = "z", newdata = dat, subset = "treated") 244 | 245 | ``` 246 | 247 | ```{r cates-hist-treated, echo=TRUE, cache=FALSE} 248 | 249 | posterior_treat_eff_on_treated %>% 250 | ggplot() + 251 | geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + 252 | theme_bw() + ggtitle("Histogram of treatment effect (all draws from treated subjects)") 253 | 254 | ``` 255 | 256 | Plots can be made that stack each subject's posterior CIs of the CATEs. 257 | 258 | ```{r cates-stack-plot, echo=TRUE, cache=FALSE} 259 | 260 | posterior_treat_eff %>% select(-z) %>% point_interval() %>% 261 | arrange(cte) %>% mutate(.orow = 1:n()) %>% 262 | ggplot() + 263 | geom_interval(aes(x = .orow, y= cte), size = 0.5) + 264 | geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) + 265 | ylab("Median posterior CATE for each subject (95% CI)") + 266 | theme_bw() + coord_flip() + scale_colour_brewer() + 267 | theme(axis.title.y = element_blank(), 268 | axis.text.y = element_blank(), 269 | axis.ticks.y = element_blank(), 270 | legend.position = "none") 271 | 272 | ``` 273 | 274 | We can also plot the CATEs varying over particular covariates. In this example, instead of grouping by subject, we group by the variable of interest, and calculate the posterior summaries over this variable.
275 | 276 | ```{r cates-line-plot, echo=TRUE, cache=FALSE} 277 | 278 | posterior_treat_eff %>% 279 | left_join(tibble(c1 = dat$c1, .row = 1:length(dat$c1) ), by = ".row") %>% 280 | group_by(c1) %>% 281 | ggplot() + 282 | stat_halfeye(aes(x = c1, y = cte), alpha = 0.7) + 283 | scale_fill_brewer() + 284 | theme_bw() + ggtitle("Treatment effect by `c1`") 285 | 286 | 287 | ``` 288 | 289 | ## Common support calculations 290 | 291 | Common support [@Hill2013] can be tested directly, or a Boolean can be included when calculating the treatment effects. 292 | 293 | ```{r common-support, echo=TRUE, results='hide', cache=FALSE} 294 | 295 | # calculate common support directly 296 | # argument 'modeldata' must be specified for BART models 297 | csupp_chisq <- has_common_support(te_model, treatment = "z", modeldata = dat, 298 | method = "chisq", cutoff = 0.05) 299 | 300 | csupp_chisq %>% filter(!common_support) 301 | 302 | csupp_sd <- has_common_support(te_model, treatment = "z", modeldata = dat, 303 | method = "sd", cutoff = 1) 304 | csupp_sd %>% filter(!common_support) 305 | 306 | # calculate treatment effects (on those who were treated) 307 | # and include only those estimates with common support 308 | posterior_treat_eff_on_treated <- 309 | treatment_effects(te_model, treatment = "z", subset = "treated", newdata = dat, 310 | common_support_method = "sd", cutoff = 1) 311 | 312 | ``` 313 | 314 | ## Investigating variable importance 315 | 316 | We can count how many times a variable was included in the BART (on average) in conjunction with the treatment effect, or overall. This method uses a simple average of occurrences; see @bleich2014variable for more sophisticated methods.
317 | 318 | ```{r interaction-investigator, echo=TRUE, cache=FALSE} 319 | 320 | treatment_interactions <- 321 | covariate_with_treatment_importance(te_model, treatment = "z") 322 | 323 | treatment_interactions %>% 324 | ggplot() + 325 | geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") + 326 | theme_bw() + ggtitle("Important variables interacting with treatment ('z')") + ylab("Inclusion counts") + 327 | theme(axis.text.x = element_text(angle = 45, hjust=1)) 328 | 329 | variable_importance <- 330 | covariate_importance(te_model) 331 | 332 | variable_importance %>% 333 | ggplot() + 334 | geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") + 335 | theme_bw() + ggtitle("Important variables overall") + 336 | ylab("Inclusion counts") + 337 | theme(axis.text.x = element_text(angle = 45, hjust=1)) 338 | 339 | ``` 340 | 341 | ## Model checking and convergence 342 | 343 | Here are some examples of model checking we can do. 344 | 345 | Code for trace plot of model variance ($\sigma^2$). 346 | 347 | ```{r sigma-trace, echo=TRUE, cache=FALSE} 348 | 349 | # includes skipped MCMC samples 350 | variance_draws(te_model, value = "siqsq") %>% 351 | filter(.draw > 10000) %>% 352 | ggplot(aes(x = .draw, y = siqsq)) + 353 | geom_line() + 354 | theme_bw() + 355 | ggtitle("Trace plot of model variance post warm-up") 356 | 357 | ``` 358 | Code for examining model residuals. 
359 | 360 | ```{r convergence-bart, echo=TRUE, cache=FALSE} 361 | 362 | res <- residual_draws(te_model, response = pull(dat, y), include_newdata = FALSE) 363 | res %>% 364 | point_interval(.residual, y, .width = c(0.95) ) %>% 365 | select(-y.lower, -y.upper) %>% 366 | ggplot() + 367 | geom_pointinterval(aes(x = y, y = .residual, ymin = .residual.lower, ymax = .residual.upper), alpha = 0.2) + 368 | scale_fill_brewer() + 369 | theme_bw() + ggtitle("Residuals vs observations") 370 | 371 | res %>% summarise(.fitted = mean(.fitted), y = first(y)) %>% 372 | ggplot(aes(x = y, y = .fitted)) + 373 | geom_point() + 374 | geom_smooth(method = "lm") + 375 | theme_bw() + ggtitle("Observations vs fitted") 376 | 377 | res %>% summarise(.residual = mean(.residual)) %>% 378 | ggplot(aes(sample = .residual)) + 379 | geom_qq() + 380 | geom_qq_line() + 381 | theme_bw() + ggtitle("Q-Q plot of residuals") 382 | 383 | ``` 384 | 385 | ## References 386 | --------------------------------------------------------------------------------