├── .DS_Store ├── .Rbuildignore ├── .gitignore ├── DESCRIPTION ├── DRsims.Rmd ├── NAMESPACE ├── R ├── .DS_Store ├── hal_cate_partially_linear.R ├── hal_cate_plugin.R ├── inference.R ├── isoreg.R ├── lasso_cate_partially_linear.R ├── truncate_pscore_adaptive.R └── utils.R ├── README.md ├── causalHAL.Rproj ├── man ├── fit_cate_hal_partially_linear.Rd ├── fit_cate_lasso_partially_linear.Rd ├── fit_hal_cate_plugin.Rd ├── hello.Rd ├── inference_ate.Rd ├── inference_cate.Rd └── isoreg_with_xgboost.Rd ├── simResults └── .DS_Store ├── simulationScripts ├── .DS_Store ├── R_setup.R ├── Rout │ ├── par-24844814.err │ ├── par-24844814.out │ ├── par-24844815.err │ ├── par-24844815.out │ ├── par-24844816.err │ ├── par-24844816.out │ ├── par-24844817.err │ ├── par-24844817.out │ ├── par-24844818.err │ ├── par-24844818.out │ ├── par-24844819.err │ ├── par-24844819.out │ ├── par-24844820.err │ ├── par-24844820.out │ ├── par-24844821.err │ ├── par-24844821.out │ ├── par-24844822.err │ ├── par-24844822.out │ ├── par-24844823.err │ ├── par-24844823.out │ ├── par-24844824.err │ ├── par-24844824.out │ ├── par-24844825.err │ ├── par-24844825.out │ ├── par-24844826.err │ ├── par-24844826.out │ ├── par-24844827.err │ ├── par-24844827.out │ ├── par-24844828.err │ ├── par-24844828.out │ ├── par-24844829.err │ ├── par-24844829.out │ ├── par-24844830.err │ ├── par-24844830.out │ ├── par-24844831.err │ ├── par-24844831.out │ ├── par-24844832.err │ ├── par-24844832.out │ ├── par-24844833.err │ ├── par-24844833.out │ ├── par-24844834.err │ ├── par-24844834.out │ ├── par-24844835.err │ ├── par-24844835.out │ ├── par-24844836.err │ ├── par-24844836.out │ ├── par-24844837.err │ └── par-24844837.out ├── install_packages.R ├── install_packages.sbatch ├── simScriptAdapt.R ├── simScriptAdapt.sbatch ├── simScriptAdaptLocal.sbatch ├── simsAdapt.sh └── simsAdaptLocal.sh └── vignette.Rmd /.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Larsvanderlaan/AdaptiveDML/3d35b5d2c0444511e53d4c6e554d2c5123945e27/.DS_Store -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: causalHAL 2 | Type: Package 3 | Title: Automated nonparametric causal inference using the highly adaptive lasso. 4 | Version: 0.1.0 5 | Author: Who wrote it 6 | Maintainer: Lars van der Laan 7 | Description: More about what it does (maybe more than one line) 8 | Use four spaces when indenting paragraphs within the Description. 9 | License: What license is it under? 
10 | Encoding: UTF-8
11 | LazyData: true
12 | RoxygenNote: 7.3.2
13 | Imports:
14 |     hal9001,
15 |     sl3,
16 |     stats,
17 |     data.table,
18 |     R6,
19 |     stringr,
20 |     origami,
21 |     delayed,
22 |     future,
23 |     earth,
24 |     fastglm,
25 |     doMC,
26 |     glmnet
27 | Depends:
28 |     hal9001,
29 |     sl3,
30 |     stats,
31 |     data.table,
32 |     R6,
33 |     stringr,
34 |     origami,
35 |     delayed,
36 |     future,
37 |     earth,
38 |     fastglm,
39 |     doMC,
40 |     glmnet
41 | Suggests:
42 |     xgboost
43 | Remotes: github::tlverse/hal9001@screeningHAL, github::tlverse/sl3@develVersionChangeLars
44 |
--------------------------------------------------------------------------------
/DRsims.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "simsDR"
3 | output: html_document
4 | date: '2023-07-05'
5 | ---
6 |
7 | ```{r setup, include=FALSE}
8 | knitr::opts_chunk$set(echo = TRUE)
9 | ```
10 |
11 |
12 |
13 |
14 |
15 | ```{r}
16 |
17 | results <- out
18 | results[, coverage := CI_left <= ATE & CI_right >= ATE]
19 | results[, sd := abs(CI_left -CI_right)/1.96/2]
20 |
21 | results
22 | results <- unique(results[, .(sd = mean(sd), se = sd(estimate-ATE), bias = abs(mean(estimate - ATE)), coverage = mean(coverage)), by = c( "misp", "estimator", "lrnr")])
23 | results[, rmse := sqrt(bias^2 + se^2)]
24 |
25 | results
26 | ```
27 |
28 | Do table
29 |
30 | For each lrnr.
31 | n Treatment Outcome Both
32 | mse (coverage)
33 |
34 |
35 | ```{r}
36 | results[results$lrnr=="xgboost",]
37 |
38 | ```
39 | ```{r}
40 | library(data.table)
41 |
42 | n_list <- c( 1000, 2000, 3000, 4000, 5000 )
43 | pos_list <- c(2 )
44 |
45 | results <-
46 |   rbindlist(unlist(lapply(n_list, function(n){
47 |     unlist(lapply(pos_list, function(pos) {
48 |       lapply(c(0), function(misp) {
49 |         try({
50 |
51 |           key <- paste0("DR_iter=", "1000", "_n=", n, "_pos=", pos )
52 |           data <- fread(paste0("./simResultsDR/sim_results_", key, ".csv"))
53 |
54 |           return(data)
55 |         })
56 |         return(data.table())
57 |       })
58 |     }), recursive = F)
59 |   }), recursive = F))
60 |
61 |
62 | results[, coverage := CI_left <= ATE & CI_right >= ATE]
63 | results[, sd := abs(CI_left -CI_right)/1.96/2]
64 |
65 | results <- unique(results[, .(sd = mean(sd), se = sd(estimate-ATE), bias = abs(mean(estimate - ATE)), coverage = mean(coverage)), by = c("pos_const", "n", "misp", "estimator", "lrnr")])
66 | results[, rmse := sqrt(bias^2 + se^2)]
67 |
68 | library(ggplot2)
69 | w <- 7
70 | h <- 5
71 | results$misp <- c("Both", "Treatment", "Outcome", "Neither")[match(results$misp, c("1", "2", "3", "4"))]
72 |
73 | for(pos_const in unique(results$pos_const)) {
74 |   for(lrnr in c("pooled")) {
75 |     for(misp in c("Both", "Treatment", "Outcome", "Neither")) {
76 |       tmp <- as.data.frame(results)[results$pos_const == pos_const,]
77 |       #tmp <- as.data.frame(tmp)[tmp$lrnr == lrnr,]
78 |       tmp <- as.data.frame(tmp)[tmp$misp == misp,]
79 |
80 |
81 |       #maxval <- 3*min(tmp$rmse[tmp$n ==500])
82 |       maxval <- max(tmp$rmse )
83 |       tmp$bias <- pmin(tmp$bias, maxval)
84 |       tmp$se <- pmin(tmp$se, maxval)
85 |       tmp$mse <- pmin(tmp$rmse, maxval)
86 |
87 |       limits <- c(0, maxval + .01)
88 |       p <- ggplot(tmp, aes(x= n, y = bias, color = lrnr, shape= lrnr)) + geom_point(size = 4) + geom_line( color="grey", linetype = "dashed") + facet_wrap(~estimator, ncol =2) + theme_bw() + theme( text = element_text(size=18), axis.text.x = element_text(size = 14 , hjust = 1, vjust = 0.5), legend.position = "bottom", legend.box = "horizontal"
89 |       ) + labs(x = "Sample Size (n)", y = "Bias", color = "Estimator", group = "Estimator", shape= "Estimator") + scale_y_continuous( limits = limits)
90 |
91 |       ggsave( filename = paste0("DR=", pos_const,lrnr,misp, "Bias.pdf") , width = w, height = h)
92 |
93 |
94 |
95 |
96 |
97 |
98 |       ggplot(tmp, aes(x= n, y = se, color = lrnr, shape= lrnr)) + geom_point(size = 4) + geom_line( color="grey", linetype = "dashed") + facet_wrap(~estimator, ncol =2) + theme_bw() + theme( text = element_text(size=18), axis.text.x = element_text(size = 14 , hjust = 1, vjust = 0.5), legend.position = "bottom", legend.box = "horizontal"
99 |       ) + labs(x = "Sample Size (n)", y = "Standard Error", color = "Estimator", group = "Estimator", shape= "Estimator") + scale_y_continuous( limits = limits)
100 |
101 |       ggsave(filename = paste0("DR=", pos_const,lrnr,misp, "SE.pdf"), width = w, height = h)
102 |
103 |
104 |       ggplot(tmp, aes(x= n, y = mse, color = lrnr, shape= lrnr)) + geom_point(size = 4) + geom_line( color="grey", linetype = "dashed") + facet_wrap(~estimator, ncol =2) + theme_bw() + theme( text = element_text(size=18), axis.text.x = element_text(size = 14 , hjust = 1, vjust = 0.5), legend.position = "bottom", legend.box = "horizontal"
105 |       ) + labs(x = "Sample Size (n)", y = "Root Mean Square Error", color = "Estimator", group = "Estimator", shape= "Estimator") + scale_y_continuous( limits = limits)
106 |
107 |       ggsave(filename = paste0("DR=", pos_const,lrnr,misp, "MSE.pdf"), width = w, height = h)
108 |
109 |       p <- ggplot(tmp, aes(x= n, y = coverage, color = lrnr, shape= lrnr)) + geom_point(size = 4) + geom_line( color="grey", linetype = "dashed") + facet_wrap(~estimator, ncol =2)+ scale_y_continuous(limits = c(min(tmp$coverage), 0.97)) + geom_hline(yintercept = 0.95, color = "grey") + theme_bw() + theme( text = element_text(size=18), axis.text.x = element_text(size = 14 , hjust = 1, vjust = 0.5), legend.position = "bottom", legend.box = "horizontal"
110 |       ) + labs(x = "Sample Size (n)", y = "CI Coverage", color = "Estimator", group = "Estimator", shape= "Estimator")
111 |
112 |
113 |       ggsave(filename = paste0("DR=", pos_const,lrnr,misp, "CI.pdf"), width = w, height = h)
114 |     }
115 |   }}
116 |
117 |
118 | p <- ggplot(tmp, aes(x= n, y = coverage, color = lrnr, shape= lrnr)) + geom_point(size = 4) + facet_wrap(~estimator, ncol =2)+ scale_y_log10() + geom_hline(yintercept = 0.95, color = "grey") + theme_bw() + theme( text = element_text(size=18), axis.text.x = element_text(size = 14 , hjust = 1, vjust = 0.5), legend.position = "bottom", legend.box = "horizontal" , legend.title = element_blank()
119 | ) + labs(x = "Sample Size (n)", y = "CI Coverage", color = "estimator", group = "estimator", shape= "estimator")
120 |
121 |
122 | ```
123 |
124 |
125 |
126 | ```{r}
127 |
128 | results[n >= 1000 & pos_const == 4e-02]
129 |
130 | ```
131 |
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 |
3 | export(fit_cate_hal_partially_linear)
4 | export(fit_cate_lasso_partially_linear)
5 | export(fit_hal_cate_plugin)
6 | export(inference_ate)
7 | export(isoreg_with_xgboost)
8 | export(truncate_pscore_adaptive)
9 | import(glmnet)
10 | import(hal9001)
11 |
--------------------------------------------------------------------------------
/R/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Larsvanderlaan/AdaptiveDML/3d35b5d2c0444511e53d4c6e554d2c5123945e27/R/.DS_Store
--------------------------------------------------------------------------------
/R/hal_cate_partially_linear.R:
--------------------------------------------------------------------------------
1 | #' Doubly-robust nonparametric superefficient estimation of the partial conditional average treatment effect
2 | #' using the highly adaptive lasso R-learner
3 | #'
4 | #' This function estimates the Partial Conditional Average Treatment Effect (CATE) function `w -> tau(w) := E[Y | A=1, W=w] - E[Y | A=0, W=w]`
5 | #' within the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`.
6 | #'
7 | #' This method implements a doubly-robust and superefficient estimation technique for the Partial CATE using the highly adaptive lasso R-learner.
8 | #' It can data-adaptively learn complex relationships in the CATE function, while benefiting from simpler structure and parsimony when present.
9 | #' By using the highly adaptive lasso, the method aims to provide robust and precise nonparametric inference of the Partial CATE.
10 | #'
11 | #' @param W A numeric matrix of covariate values. Unnamed columns are named `W1, W2, ...`.
12 | #' @param A A numeric vector of treatment values. Can be binary or continuous.
13 | #' @param Y A numeric vector of outcome values.
14 | #' @param weights (Optional) A numeric vector of observation weights.
15 | #' @param pi.hat A numeric vector containing estimated propensity scores `pi(W) := P(A=1 | W)`.
16 | #' @param m.hat A numeric vector containing estimates of treatment-marginalized outcome regression `m(W) := E[Y | W]`.
17 | #' @param sl3_Lrnr_pi.hat If `pi.hat` is not provided, a `Lrnr_base` object for estimation of `pi(W) := P(A=1 | W)`.
18 | #' @param sl3_Lrnr_m.hat If `m.hat` is not provided, a `Lrnr_base` object for estimation of `m(W) := E[Y | W]`.
19 | #' @param formula_cate (Optional) A `hal9001`-formatted formula object for the CATE to be passed to `formula_hal`.
20 | #'   By default, the CATE model is learned data-adaptively by HAL (fit with `screen_variables = FALSE`). See documentation for `fit_hal`.
21 | #' @param max_degree_cate (Optional) Same as `max_degree` but for CATE model. See documentation for `fit_hal`.
22 | #' @param num_knots_cate (Optional) Same as `num_knots` but for CATE model. See documentation for `fit_hal`.
23 | #' @param smoothness_orders_cate (Optional) Same as `smoothness_orders` but for CATE model. See documentation for `fit_hal`.
24 | #' @param verbose A logical. If `TRUE`, progress messages, the mean of the fitted CATE and the selected CATE formula are printed.
25 | #' @param ... Other arguments to be passed to `fit_hal`.
26 | #'
27 | #' @return The `fit_hal` object for the CATE, with class `c("hal9001", "hal_cate")` and an `internal`
28 | #'   element (influence-function map `IF_Y_map`, nuisance fits and fitted data) used by `inference_ate()`.
29 | #'
30 | #' @import hal9001
31 | #' @export
32 | fit_cate_hal_partially_linear <- function(W, A, Y, weights = NULL, pi.hat = NULL, m.hat = NULL,
33 |                                           formula_cate = NULL, max_degree_cate = 1, smoothness_orders_cate = 1,
34 |                                           num_knots_cate = c(50), sl3_Lrnr_pi.hat = NULL, sl3_Lrnr_m.hat = NULL,
35 |                                           verbose = TRUE, ...) {
36 |   if (!is.matrix(W)) W <- as.matrix(W)
37 |   # sl3 tasks select covariates by column name, so unnamed matrices need names.
38 |   if (is.null(colnames(W))) colnames(W) <- paste0("W", seq_len(ncol(W)))
39 |   if (is.null(weights)) weights <- rep(1, length(Y))
40 |   if (is.null(m.hat) && is.null(sl3_Lrnr_m.hat)) {
41 |     stop("Either `m.hat` or `sl3_Lrnr_m.hat` must be provided.", call. = FALSE)
42 |   }
43 |   if (is.null(pi.hat) && is.null(sl3_Lrnr_pi.hat)) {
44 |     stop("Either `pi.hat` or `sl3_Lrnr_pi.hat` must be provided.", call. = FALSE)
45 |   }
46 |
47 |   ###############################################
48 |   ### Fit outcome regression m(W) := E[Y | W]
49 |   ###############################################
50 |   if (verbose) print("Fitting m = E[Y|W]")
51 |   if (!is.null(m.hat)) {
52 |     # A user-supplied estimate takes precedence over the sl3 learner.
53 |     m <- m.hat
54 |     fit_m <- NULL
55 |   } else {
56 |     task_Y <- sl3_Task$new(data.table(W, Y = Y), covariates = colnames(W), outcome = "Y", outcome_type = "continuous")
57 |     fit_m <- sl3_Lrnr_m.hat$train(task_Y)
58 |     m <- fit_m$predict(task_Y)
59 |   }
60 |
61 |   ###############################################
62 |   ### Fit propensity score pi(W) := E[A | W]
63 |   ###############################################
64 |   if (verbose) print("Fitting pi = E[A|W]")
65 |   if (!is.null(pi.hat)) {
66 |     pi <- pi.hat
67 |     fit_pi <- NULL
68 |   } else {
69 |     task_A <- sl3_Task$new(data.table(W, A = A), covariates = colnames(W), outcome = "A", outcome_type = "continuous")
70 |     fit_pi <- sl3_Lrnr_pi.hat$train(task_A)
71 |     pi <- fit_pi$predict(task_A)
72 |   }
73 |   # Truncate the propensity score to lie in (c_n, 1 - c_n) with a data-adaptive truncation level c_n.
74 |   pi <- truncate_pscore_adaptive(A, pi)
75 |
76 |   ###############################################
77 |   ### Fit CATE tau(W) := E[Y_1 - Y_0 | W] using HAL-based R-learner
78 |   ###############################################
79 |   # The R-learner is least-squares regression of the pseudo-outcome (Y - m) / (A - pi)
80 |   # with pseudo-weights (A - pi)^2 * weights.
81 |   resid_A <- A - pi
82 |   pseudo_outcome <- ifelse(abs(resid_A) < 1e-10, 0, (Y - m) / resid_A)
83 |   pseudo_weights <- resid_A^2 * weights
84 |   # Observations with (near) zero pseudo-weight carry no information and are dropped.
85 |   keep <- which(abs(resid_A) > 1e-10)
86 |
87 |   if (verbose) print("Fitting tau = E[Y|A=1,W] - E[Y|A=0,W]")
88 |   fit_cate <- fit_hal(W[keep, , drop = FALSE], pseudo_outcome[keep], formula = formula_cate,
89 |                       weights = pseudo_weights[keep], max_degree = max_degree_cate,
90 |                       num_knots = num_knots_cate, smoothness_orders = smoothness_orders_cate,
91 |                       family = "gaussian", screen_variables = FALSE, ...)
92 |   tau <- predict(fit_cate, new_data = W)
93 |   if (verbose) {
94 |     print(mean(tau))
95 |     print(fit_cate$formula)
96 |   }
97 |
98 |   # Relaxed HAL: unpenalized refit on the selected basis functions to allow for plug-in inference.
99 |   basis_list_reduced_tau <- fit_cate$basis_list[fit_cate$coefs[-1] != 0]
100 |   x_basis_tau <- cbind(1, as.matrix(hal9001::make_design_matrix(W, basis_list_reduced_tau)))
101 |   fit_cate_relaxed <- glm.fit(x_basis_tau[keep, , drop = FALSE], pseudo_outcome[keep], family = gaussian(),
102 |                               weights = pseudo_weights[keep], intercept = FALSE)
103 |   beta <- coef(fit_cate_relaxed)
104 |   # Aliased (collinear) basis functions get NA coefficients; treat them as absent.
105 |   beta[is.na(beta)] <- 0
106 |   tau_relaxed <- x_basis_tau %*% beta
107 |
108 |   # Influence-function map of the ATE estimator. `x_proj` is currently unused: the returned
109 |   # influence function always targets E[tau(W)].
110 |   # NOTE(review): the Gram matrix ignores `weights`, unlike the relaxed fit above.
111 |   IF_Y_map <- function(x_proj) {
112 |     n <- length(Y)
113 |     # t(X) %*% diag(v) %*% X without materializing the n x n diagonal matrix.
114 |     gram <- crossprod(x_basis_tau, x_basis_tau * as.vector(resid_A^2)) / n
115 |     gamma <- x_basis_tau %*% solve(gram, colMeans(x_basis_tau))
116 |     IF_cate <- gamma * resid_A * (Y - m - resid_A * tau)
117 |     IF_cate + tau - mean(tau)
118 |   }
119 |
120 |   fit_cate$internal <- list(
121 |     IF_Y_map = IF_Y_map, fit_EAW = fit_pi, fit_m.hat = fit_m,
122 |     data = list(tau_relaxed = tau_relaxed, W = W, A = A, Y = Y, pi = pi, m = m, tau = tau,
123 |                 pseudo_outcome = pseudo_outcome, pseudo_weights = pseudo_weights)
124 |   )
125 |   class(fit_cate) <- c("hal9001", "hal_cate")
126 |   fit_cate
127 | }
128 |
--------------------------------------------------------------------------------
/R/hal_cate_plugin.R:
--------------------------------------------------------------------------------
1 |
2 | # turn to plug-in
3 |
4 |
5 | #' Doubly-robust nonparametric superefficient estimation of the conditional average treatment effect
6 | #' using the highly adaptive lasso plug-in estimator.
7 | #'
8 | #' This method estimates the conditional average treatment effect function `w -> tau(w)`
9 | #' under the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`.
10 | #' `E[Y | A=0, W]` is fit by HAL on the untreated, `tau` is fit by HAL on the treated (with the
11 | #' `E[Y | A=0, W]` fit subtracted from the outcome), and both are then refit jointly without
12 | #' penalty on the selected basis functions (relaxed HAL) to allow for plug-in inference.
13 | #' @param W A \code{matrix} of covariate values.
14 | #' @param A A \code{numeric} binary vector of treatment values.
15 | #' @param Y A \code{numeric} vector of outcome values.
16 | #' @param weights (Optional) A \code{numeric} vector of observation weights.
17 | #' @param formula_cate (Optional) A \code{hal9001}-formatted \code{formula} object for the CATE/tau to be passed to \code{\link[hal9001]{formula_hal}}.
18 | #'   By default the CATE model is learned data-adaptively by HAL.
19 | #' @param max_degree_cate The maximum interaction degree of the CATE basis functions.
20 | #'   Passed as \code{max_degree} to \code{\link[hal9001]{fit_hal}}.
21 | #' @param num_knots_cate A \code{numeric} vector whose \code{d}-th entry specifies the number of univariable
22 | #'   spline knot points used to generate CATE basis functions of interaction degree \code{d}.
23 | #'   Passed as \code{num_knots} to \code{\link[hal9001]{fit_hal}}.
24 | #' @param smoothness_orders_cate An integer taking values in (0,1,2,...) specifying the smoothness order
25 | #'   of the CATE basis functions. Passed as \code{smoothness_orders} to \code{\link[hal9001]{fit_hal}}.
26 | #' @param screen_variable_cate Passed as \code{screen_variables} to \code{\link[hal9001]{fit_hal}} for the CATE model.
27 | #'   Highly recommended.
28 | #' @param params_EY0W A named \code{list} of arguments passed to \code{\link[hal9001]{fit_hal}} when estimating `E[Y | A=0, W]`.
29 | #'   The entries \code{X}, \code{Y}, \code{weights} and \code{family} are set internally.
30 | #' @param include_propensity_score A \code{logical}. If \code{TRUE}, the propensity score `E[A | W]` is estimated
31 | #'   with HAL and used as an additional covariate when estimating `E[Y | A=0, W]`.
32 | #' @param verbose A \code{logical}. If \code{TRUE}, progress messages and the selected formulas are printed.
33 | #' @param params_EAW (Optional) A named \code{list} of arguments passed to \code{\link[hal9001]{fit_hal}} when
34 | #'   estimating `E[A | W]` (only used if \code{include_propensity_score = TRUE}). By default, the tuning
35 | #'   parameters in \code{params_EY0W} are reused with \code{family = "binomial"}.
36 | #'   The entries \code{X}, \code{Y} and \code{weights} are set internally.
37 | #' @param ... Other arguments to be passed to \code{\link[hal9001]{fit_hal}} for the CATE model.
38 | #' @return The CATE \code{fit_hal} object with class \code{c("hal9001", "hal_cate")} and an \code{internal}
39 | #'   element (influence-function map, nuisance fits and fitted data) used by \code{\link{inference_ate}}.
40 | #' @import hal9001
41 | #' @export
42 | fit_hal_cate_plugin <- function(W, A, Y, weights = NULL, formula_cate = NULL, max_degree_cate = 3,
43 |                                 num_knots_cate = c(sqrt(length(Y)), length(Y)^(1/3), length(Y)^(1/5)),
44 |                                 smoothness_orders_cate = 1, screen_variable_cate = TRUE,
45 |                                 params_EY0W = list(max_degree = 3,
46 |                                                    num_knots = c(sqrt(length(Y)), length(Y)^(1/3), length(Y)^(1/5)),
47 |                                                    smoothness_orders = 1, screen_variables = TRUE),
48 |                                 include_propensity_score = FALSE, verbose = TRUE, params_EAW = NULL, ...) {
49 |   if (!all(A %in% c(0, 1))) {
50 |     stop("The treatment `A` should be binary. For continuous treatments consider using `fit_cate_hal_partially_linear()`.")
51 |   }
52 |   if (!is.matrix(W)) W <- as.matrix(W)
53 |   if (is.null(weights)) weights <- rep(1, length(Y))
54 |
55 |   ## Optional propensity score pi(W) := E[A | W], used as an extra covariate for mu0.
56 |   if (include_propensity_score) {
57 |     if (verbose) print("Fitting pi = E[A|W]")
58 |     params_pi <- params_EAW
59 |     if (is.null(params_pi)) {
60 |       params_pi <- params_EY0W
61 |       params_pi$family <- "binomial"
62 |     }
63 |     if (is.null(params_pi$family)) params_pi$family <- "binomial"
64 |     params_pi$X <- W
65 |     params_pi$Y <- A
66 |     params_pi$weights <- weights
67 |     fit_pi <- sl3:::call_with_args(fit_hal, params_pi)
68 |     pi <- predict(fit_pi, new_data = W)
69 |     if (verbose) print(fit_pi$formula)
70 |     W_pi <- cbind(W, pi)
71 |   } else {
72 |     pi <- NULL
73 |     fit_pi <- NULL
74 |     W_pi <- W
75 |   }
76 |
77 |   ## mu0(W) := E[Y | A=0, W], fit by HAL on the untreated.
78 |   if (verbose) print("Fitting mu0 = E[Y|A=0,W]")
79 |   untreated <- A == 0
80 |   params_EY0W$X <- W_pi[untreated, , drop = FALSE]
81 |   params_EY0W$Y <- Y[untreated]
82 |   params_EY0W$weights <- weights[untreated]
83 |   params_EY0W$family <- "gaussian"
84 |   fit_mu0 <- sl3:::call_with_args(fit_hal, params_EY0W)
85 |   mu0 <- predict(fit_mu0, new_data = W_pi)
86 |   basis_list_reduced_mu0 <- fit_mu0$basis_list[fit_mu0$coefs[-1] != 0]
87 |   x_basis_mu0 <- cbind(1, as.matrix(hal9001::make_design_matrix(W_pi, basis_list_reduced_mu0)))
88 |   if (verbose) print(fit_mu0$formula)
89 |
90 |   ## tau(W) := E[Y | A=1, W] - E[Y | A=0, W], fit by HAL on the treated.
91 |   if (verbose) print("Fitting tau = E[Y|A=1,W] - E[Y|A=0,W]")
92 |   treated <- A == 1
93 |   # Offsets are not supported by the MARS screener, so mu0 is subtracted from the outcome
94 |   # instead (equivalent for least squares).
95 |   fit_cate <- fit_hal(W[treated, , drop = FALSE], Y[treated] - mu0[treated], weights = weights[treated],
96 |                       max_degree = max_degree_cate, formula = formula_cate, num_knots = num_knots_cate,
97 |                       smoothness_orders = smoothness_orders_cate, family = "gaussian",
98 |                       screen_variables = screen_variable_cate, ...)
99 |   tau <- predict(fit_cate, new_data = W)
100 |   # Basis functions selected by the CATE fit (not the mu0 fit).
101 |   basis_list_reduced_tau <- fit_cate$basis_list[fit_cate$coefs[-1] != 0]
102 |   x_basis_tau <- cbind(1, as.matrix(hal9001::make_design_matrix(W, basis_list_reduced_tau)))
103 |   if (verbose) print(fit_cate$formula)
104 |   mu1 <- mu0 + tau
105 |
106 |   ## Relaxed HAL: joint unpenalized refit of mu0(W) + A * tau(W) on the selected basis functions.
107 |   x_basis <- cbind(x_basis_mu0, A * x_basis_tau)
108 |   x_basis0 <- cbind(x_basis_mu0, 0 * x_basis_tau)
109 |   x_basis1 <- cbind(x_basis_mu0, 1 * x_basis_tau)
110 |   fit_mu_relaxed <- glm.fit(x_basis, Y, family = gaussian(), weights = weights, intercept = FALSE)
111 |   beta_mu <- coef(fit_mu_relaxed)
112 |   # Aliased (collinear) basis functions get NA coefficients; treat them as absent.
113 |   beta_mu[is.na(beta_mu)] <- 0
114 |   mu1_relaxed <- x_basis1 %*% beta_mu
115 |   mu0_relaxed <- x_basis0 %*% beta_mu
116 |   tau_relaxed <- mu1_relaxed - mu0_relaxed
117 |
118 |   # (Uncentered) influence function of the plug-in ATE estimator mean(tau_relaxed).
119 |   # `x_proj` is currently unused.
120 |   # NOTE(review): the Gram matrix ignores `weights`, unlike the relaxed fit above.
121 |   IF_Y_map <- function(x_proj) {
122 |     n <- length(A)
123 |     x_contrast <- x_basis1 - x_basis0
124 |     riesz <- x_basis %*% solve(crossprod(x_basis) / n, colMeans(x_contrast))
125 |     riesz * (Y - x_basis %*% beta_mu) + x_contrast %*% beta_mu
126 |   }
127 |
128 |   # Pseudo-data consumed by inference_cate().
129 |   pseudo_outcome <- tau_relaxed
130 |   pseudo_weights <- weights
131 |
132 |   fit_cate$internal <- list(
133 |     IF_Y_map = IF_Y_map, fit_EAW = fit_pi, fit_EY0W = fit_mu0, fit_cate = fit_cate,
134 |     data = list(tau_relaxed = tau_relaxed, W = W, A = A, Y = Y, pi = pi, mu1 = mu1, mu0 = mu0, tau = tau,
135 |                 pseudo_outcome = pseudo_outcome, pseudo_weights = pseudo_weights)
136 |   )
137 |   class(fit_cate) <- c("hal9001", "hal_cate")
138 |   fit_cate
139 | }
140 |
--------------------------------------------------------------------------------
/R/inference.R:
--------------------------------------------------------------------------------
1 |
2 | #' Estimates and confidence intervals for the ATE.
3 | #' @param fit_cate A \code{hal_cate} object obtained from \code{\link{fit_cate_hal_partially_linear}} or \code{\link{fit_hal_cate_plugin}}.
4 | #' @param alpha Significance level for confidence intervals.
5 | #' @param return_cov_mat A \code{logical} for whether to also return the estimated (co)variance of the influence function.
6 | #' @return A \code{data.table} with columns \code{variable}, \code{coef}, \code{se}, \code{CI_left} and \code{CI_right}.
7 | #'   If \code{return_cov_mat = TRUE}, a list with elements \code{summary} (that table) and \code{cov_mat}.
8 | #' @export
9 | inference_ate <- function(fit_cate, alpha = 0.05, return_cov_mat = FALSE) {
10 |   inference_cate(fit_cate, formula = ~ 1, alpha = alpha, return_cov_mat = return_cov_mat)
11 | }
12 |
13 | #' Estimates and confidence intervals for the projection of the CATE onto a user-specified parametric working model.
14 | #' @param fit_cate A \code{hal_cate} object obtained from \code{\link{fit_cate_hal_partially_linear}} or \code{\link{fit_hal_cate_plugin}}.
15 | #' @param formula A \code{formula} object specifying a working parametric model for the conditional average treatment effect.
16 | #'   For instance, `formula = ~ 1` specifies the marginal average treatment effect `E[CATE(W)]`.
17 | #'   More complex formulas like `formula = ~ W1` specify the best `W1`-linear approximation of the true CATE;
18 | #'   these currently raise an error because the stored influence-function map only targets the ATE.
19 | #' @param alpha Significance level for confidence intervals.
20 | #' @param return_cov_mat A \code{logical} for whether to also return the estimated covariance matrix of the
21 | #'   influence function (divide by the sample size for the covariance of the coefficient estimates).
22 | #' @return See \code{\link{inference_ate}}.
23 | inference_cate <- function(fit_cate, formula = ~ 1, alpha = 0.05, return_cov_mat = FALSE) {
24 |   internal <- fit_cate$internal
25 |   data <- internal$data
26 |
27 |   # Relaxed (unpenalized) refit of the CATE on the basis functions selected by HAL.
28 |   coefs <- fit_cate$coefs
29 |   basis_list <- fit_cate$basis_list[coefs[-1] != 0]
30 |   x_basis <- cbind(1, as.matrix(hal9001::make_design_matrix(as.matrix(data$W), basis_list)))
31 |   beta <- coef(glm.fit(x_basis, data$pseudo_outcome, weights = data$pseudo_weights))
32 |   beta[is.na(beta)] <- 0
33 |   tau <- x_basis %*% beta
34 |   n <- nrow(x_basis)
35 |
36 |   # Least-squares projection of the relaxed CATE onto the working model.
37 |   x_basis_proj <- model.matrix(formula, data = as.data.frame(data$W))
38 |   coef_proj <- coef(glm.fit(x_basis_proj, tau))
39 |
40 |   # Influence function: one column per working-model coefficient. as.matrix() keeps var()/diag()
41 |   # well-defined even if the map returns a plain vector.
42 |   IF_full <- as.matrix(internal$IF_Y_map(x_basis))
43 |   if (ncol(IF_full) != length(coef_proj)) {
44 |     stop("Only the ATE (`formula = ~ 1`) is currently supported: the stored influence function has ",
45 |          ncol(IF_full), " column(s) but the working model has ", length(coef_proj), " coefficient(s).",
46 |          call. = FALSE)
47 |   }
48 |   var_mat <- var(IF_full)
49 |   se <- sqrt(diag(var_mat)) / sqrt(n)
50 |
51 |   # Wald intervals, computed elementwise per coefficient.
52 |   z <- abs(qnorm(alpha / 2))
53 |   summary_table <- data.table(variable = colnames(x_basis_proj), coef = coef_proj, se = se,
54 |                               CI_left = coef_proj - z * se, CI_right = coef_proj + z * se)
55 |   if (!return_cov_mat) return(summary_table)
56 |   list(summary = summary_table, cov_mat = var_mat)
57 | }
58 |
--------------------------------------------------------------------------------
/R/isoreg.R:
--------------------------------------------------------------------------------
1 | #' Isotonic Regression with XGBoost
2 | #'
3 | #' Fits an isotonic regression model using XGBoost with monotonic constraints.
4 | #'
5 | #' @param x A vector or matrix of predictor variables.
6 | #' @param y A vector of response variables.
7 | #' @param max_depth Integer. Maximum depth of the trees in XGBoost (default is 15).
8 | #' @param min_child_weight Numeric. Minimum sum of instance weights (Hessian) needed in a child node (default is 20).
9 | #' @param weights A vector of weights to apply to each instance during training (default is NULL, meaning equal weights).
10 | #'
11 | #' @return A function that takes a new predictor variable \code{x} and returns the model's predicted values.
12 | #'
13 | #' @details
14 | #' This function uses XGBoost to fit a monotonic increasing model to the data, enforcing isotonic regression
15 | #' through the use of monotonic constraints. The model is trained with a single boosting round
16 | #' (\code{eta = 1}, \code{gamma = 0}, \code{lambda = 0}), so the fit consists of one monotone-constrained
17 | #' regression tree that can be interpreted as an isotonic regression. Requires the \pkg{xgboost} package.
18 | #'
19 | #' @examples
20 | #' \dontrun{
21 | #' # Example data
22 | #' x <- matrix(rnorm(100), ncol = 1)
23 | #' y <- sort(rnorm(100))
24 | #'
25 | #' # Fit the model
26 | #' iso_model <- isoreg_with_xgboost(x, y)
27 | #'
28 | #' # Predict on new data
29 | #' x_new <- matrix(rnorm(10), ncol = 1)
30 | #' predictions <- iso_model(x_new)
31 | #' }
32 | #'
33 | #' @export
34 | isoreg_with_xgboost <- function(x, y, max_depth = 15, min_child_weight = 20, weights = NULL) {
35 |   # xgboost is an optional dependency, so fail early with an informative error.
36 |   if (!requireNamespace("xgboost", quietly = TRUE)) {
37 |     stop("Package 'xgboost' is required for `isoreg_with_xgboost()`; please install it.", call. = FALSE)
38 |   }
39 |   if (is.null(weights)) {
40 |     weights <- rep(1, length(y))
41 |   }
42 |   # Create an XGBoost DMatrix object from the data, including weights
43 |   data <- xgboost::xgb.DMatrix(data = as.matrix(x), label = as.vector(y), weight = weights)
44 |
45 |   # Set parameters for the monotonic XGBoost model
46 |   params <- list(
47 |     max_depth = max_depth,
48 |     min_child_weight = min_child_weight,
49 |     monotone_constraints = 1, # Enforce monotonic increase in the predictor
50 |     eta = 1,
51 |     gamma = 0,
52 |     lambda = 0
53 |   )
54 |
55 |   # Train the model with one boosting round
56 |   iso_fit <- xgboost::xgb.train(params = params, data = data, nrounds = 1)
57 |
58 |   # Prediction function for new data
59 |   fun <- function(x) {
60 |     data_pred <- xgboost::xgb.DMatrix(data = as.matrix(x))
61 |     predict(iso_fit, data_pred)
62 |   }
63 |
64 |   fun
65 | }
66 |
--------------------------------------------------------------------------------
/R/lasso_cate_partially_linear.R:
--------------------------------------------------------------------------------
1 | #' Doubly-robust nonparametric superefficient estimation of the conditional average treatment effect
2 | #' using the lasso-based R-learner
3 | #'
4 | #' This function estimates the Conditional Average Treatment Effect (CATE) function `w -> tau(w) := E[Y | A=1, W=w] - E[Y | A=0, W=w]`
5 | #' within the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`.
6 | #'
7 | #' This method implements adaptive debiased machine learning of the Average Treatment Effect (ATE) through data-driven partially linear
8 | #' model selection based on the LASSO (Least Absolute Shrinkage and Selection Operator). By incorporating the LASSO technique,
9 | #' the method leverages learned structural information and promotes parsimony in the CATE function estimation. The approach
10 | #' offers adaptivity and super-efficiency in nonparametric ATE inference.
11 | #'
12 | #' @param W A numeric matrix of covariate values.
13 | #' @param A A numeric vector of treatment values. Can be binary or continuous.
14 | #' @param Y A numeric vector of outcome values.
15 | #' @param pi.hat A numeric vector containing estimated propensity scores `pi(W) := P(A=1 | W)`. Required.
16 | #' @param m.hat A numeric vector containing estimates of treatment-marginalized outcome regression `m(W) := E[Y | W]`. Required.
17 | #' @param verbose A logical. If `TRUE`, progress messages are printed.
18 | #' @param ... Additional arguments to be passed to \code{\link[glmnet]{cv.glmnet}}.
19 | #'
20 | #' @return The \code{\link[glmnet]{cv.glmnet}} fit for the CATE, with an added \code{internal} element holding
21 | #'   the influence-function map \code{IF_Y_map} and the fitted data.
22 | #'
23 | #' @import glmnet
24 | #' @export
25 | fit_cate_lasso_partially_linear <- function(W, A, Y, pi.hat = NULL, m.hat = NULL, verbose = TRUE, ...) {
26 |   if (is.null(pi.hat) || is.null(m.hat)) {
27 |     stop("Both `pi.hat` and `m.hat` must be provided.", call. = FALSE)
28 |   }
29 |   if (!is.matrix(W)) W <- as.matrix(W)
30 |   weights <- rep(1, length(Y))
31 |
32 |   ###############################################
33 |   ### Fit CATE tau.hat(W) := E[Y_1 - Y_0 | W] using the lasso-based R-learner
34 |   ###############################################
35 |   # The R-learner is least-squares regression of the pseudo-outcome (Y - m) / (A - pi)
36 |   # with pseudo-weights (A - pi)^2 * weights.
37 |   resid_A <- A - pi.hat
38 |   pseudo_outcome <- ifelse(abs(resid_A) < 1e-10, 0, (Y - m.hat) / resid_A)
39 |   pseudo_weights <- resid_A^2 * weights
40 |   # Observations with (near) zero pseudo-weight are dropped for stability.
41 |   keep <- which(abs(resid_A) > 1e-10)
42 |
43 |   if (verbose) print("Fitting tau = E[Y|A=1,W] - E[Y|A=0,W]")
44 |   # glmnet is imported via the package NAMESPACE, so no library() call is needed here.
45 |   fit_cate <- cv.glmnet(W[keep, , drop = FALSE], pseudo_outcome[keep], weights = pseudo_weights[keep],
46 |                         family = "gaussian", ...)
47 |   # NOTE: predict() on a cv.glmnet fit uses s = "lambda.1se" by default, while the relaxed
48 |   # refit below uses the support selected at lambda.min.
49 |   tau.hat <- predict(fit_cate, newx = W)
50 |
51 |   # Relaxed lasso: unpenalized refit on the support selected at lambda.min.
52 |   nonzero <- as.vector(coef(fit_cate, s = "lambda.min")) != 0
53 |   # drop = FALSE keeps a matrix even when a single column (e.g. only the intercept) is selected.
54 |   W_post <- cbind(1, W)[, nonzero, drop = FALSE]
55 |   fit_cate_relaxed <- glm.fit(W_post[keep, , drop = FALSE], pseudo_outcome[keep], family = gaussian(),
56 |                               weights = pseudo_weights[keep], intercept = FALSE)
57 |   beta <- coef(fit_cate_relaxed)
58 |   # Aliased (collinear) columns get NA coefficients; treat them as absent.
59 |   beta[is.na(beta)] <- 0
60 |   tau.hat_relaxed <- W_post %*% beta
61 |
62 |   # Influence-function map of the ATE estimator. `x_proj` is currently unused.
63 |   IF_Y_map <- function(x_proj) {
64 |     n <- length(Y)
65 |     # t(X) %*% diag(v) %*% X without materializing the n x n diagonal matrix.
66 |     gram <- crossprod(W_post, W_post * as.vector(resid_A^2)) / n
67 |     gamma <- W_post %*% solve(gram, colMeans(W_post))
68 |     IF_cate <- gamma * resid_A * (Y - m.hat - resid_A * tau.hat)
69 |     IF_cate + tau.hat - mean(tau.hat)
70 |   }
71 |
72 |   fit_cate$internal <- list(
73 |     IF_Y_map = IF_Y_map,
74 |     data = list(tau.hat_relaxed = tau.hat_relaxed, W = W, A = A, Y = Y, pi.hat = pi.hat, m.hat = m.hat,
75 |                 tau.hat = tau.hat, pseudo_outcome = pseudo_outcome, pseudo_weights = pseudo_weights)
76 |   )
77 |
78 |   fit_cate
79 | }
80 |
--------------------------------------------------------------------------------
/R/truncate_pscore_adaptive.R:
--------------------------------------------------------------------------------
1 |
2 |
3 | #' @export
4 | #'
5 | truncate_pscore_adaptive <- function(A, pi, min_trunc_level = 1e-8) {
6 |   risk_function <- function(cutoff, level) {
7 |     pi <- pmax(pi, cutoff)
8 |     pi <- pmin(pi, 1 - cutoff)
9 |     alpha <- A/pi - (1-A)/(1-pi) #Riesz-representor
10 |     alpha1 <- 1/pi
11 |     alpha0 <- - 1/(1-pi)
12 |     mean(alpha^2 - 2*(alpha1 - alpha0))
13 |   }
14 |   cutoff <- optim(1e-5, fn = risk_function, method = "Brent", lower = min_trunc_level, upper = 0.5, level = 1)$par
15 |   pi <- pmin(pi, 1 - cutoff)
16 |   pi <- pmax(pi, cutoff)
17
| pi 18 | } 19 | 20 | 21 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | detect_family <- function(Z, family = NULL){ 2 | if( inherits(family, "family") || (!is.null(family) && family != "auto") ) return(family) 3 | if(all(Z %in% c(0,1))) return("binomial") 4 | if(all(Z >=0)) return("poisson") 5 | return("gaussian") 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # causalHAL: Adaptive Debiased Machine Learning with HAL 2 | 3 | This (in-development) package implements adaptive debiased machine learning estimators for the ATE in data-driven linear and partially linear regression models using the highly adaptive lasso. The theory for these methods is provided in the working paper: `https://arxiv.org/abs/2307.12544`. 4 | 5 | `vignette.Rmd` contains example code for running the partially linear ADMLE of the ATE using the highly adaptive lasso (HAL) or lasso (via glmnet). 6 | 7 | The R, sh, and sbatch scripts used to run the simulations in the paper can be found in the folder `simulationScripts`. 8 | Note, at this point in time, the code documentation is fairly poor. 9 | 10 | # Motivation and framework 11 | 12 | Debiased machine learning estimators for nonparametric inference of smooth functionals of the data-generating distribution can suffer from excessive variability and instability. For this reason, practitioners may resort to simpler models based on parametric or semiparametric assumptions. However, such simplifying assumptions may fail to hold, and estimates may then be biased due to model misspecification. 
To address this problem, we propose Adaptive Debiased Machine Learning (ADML), a nonparametric framework that combines data-driven model selection and debiased machine learning techniques to construct asymptotically linear, adaptive, and superefficient estimators for pathwise differentiable functionals. 13 | 14 | By learning model structure directly from data, ADML avoids the bias introduced by model misspecification and remains free from the restrictions of parametric and semiparametric models. While ADML estimators may exhibit irregular behavior for the target parameter in a nonparametric statistical model, we demonstrate that they provide regular and locally uniformly valid inference for a projection-based oracle parameter. Importantly, this oracle parameter agrees with the original target parameter for distributions within an unknown but correctly specified oracle statistical submodel that is learned from the data. This finding implies that there is no penalty, in a local asymptotic sense, for conducting data-driven model selection compared to having prior knowledge of the oracle submodel and oracle parameter.
15 |
16 |
17 |
18 |
19 | ## Install
20 |
21 | ```{r}
22 | devtools::install_github("tlverse/hal9001")
23 | devtools::install_github("Larsvanderlaan/causalHAL")
24 | ```
25 |
26 | ## Run ADMLE using HAL and glmnet
27 |
28 | ```{r}
29 | library(causalHAL)
30 | library(hal9001)
31 | seed <- rnorm(1)
32 | n <- 1000
33 | d <- 4
34 | pos_const <- 1
35 | W <- replicate(d, runif(n, -1, 1))
36 | colnames(W) <- paste0("W", 1:d)
37 | pi0 <- plogis(pos_const * ( W[,1] + sin(4*W[,1]) + W[,2] + cos(4*W[,2]) + W[,3] + sin(4*W[,3]) + W[,4] + cos(4*W[,4]) ))
38 | A <- rbinom(n, 1, pi0)
39 | mu0 <- sin(4*W[,1]) + sin(4*W[,2]) + sin(4*W[,3])+ sin(4*W[,4]) + cos(4*W[,2])
40 | tau <- 1 + W[,1] + abs(W[,2]) + cos(4*W[,3]) + W[,4]
41 | Y <- rnorm(n, mu0 + A * tau, 0.5)
42 |
43 |
44 | # User-supplied estimate of propensity score pi = P(A=1|W)
45 | pi.hat <- pi0
46 | # User-supplied estimate of treatment-marginalized outcome regression m = E(Y|W) = (1 - pi) * mu0 + pi * (mu0 + tau)
47 | m.hat <- mu0 * (1 - pi.hat) + (mu0 + tau) * pi.hat
48 |
49 | # ADMLE for ATE using partially linear model with HAL.
50 | # Fits additive piece-wise linear spline model for CATE with 50 knot points per covariate using highly adaptive lasso (see tlverse/hal9001 github R package)
51 | set.seed(seed)
52 | ADMLE_fit <- fit_cate_hal_partially_linear(W, A, Y,
53 |                                            m.hat = m.hat,
54 |                                            pi.hat = pi.hat,
55 |                                            smoothness_orders_cate = 1, num_knots_cate = c(50), max_degree_cate = 1)
56 | # Provides estimates and CI for ATE
57 | inference_ate(ADMLE_fit)
58 |
59 | # Same analysis but using glmnet implementation with hal9001-basis design matrix.
60 | # May not reproduce estimates exactly but should be close.
61 | # For those not familiar with hal9001 package, the below code may be easier to play around with.
62 | basis_list <- hal9001::enumerate_basis(W, smoothness_orders = 1, num_knots = 50, max_degree = 1)
63 | tau_basis <- hal9001::make_design_matrix(W, basis_list)
64 | set.seed(seed)
65 | ADMLE_fit <- fit_cate_lasso_partially_linear(tau_basis, A, Y,
66 |                                              m.hat = m.hat,
67 |                                              pi.hat = pi.hat, standardize = FALSE)
68 |
69 | # Provides estimates and CI for ATE
70 | inference_ate(ADMLE_fit)
71 | ```
72 |
73 |
--------------------------------------------------------------------------------
/causalHAL.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 |
3 | RestoreWorkspace: Default
4 | SaveWorkspace: Default
5 | AlwaysSaveHistory: Default
6 |
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 |
12 | RnwWeave: knitr
13 | LaTeX: pdfLaTeX
14 |
15 | AutoAppendNewline: Yes
16 | StripTrailingWhitespace: Yes
17 |
18 | BuildType: Package
19 | PackageUseDevtools: Yes
20 | PackageInstallArgs: --no-multiarch --with-keep.source
21 |
--------------------------------------------------------------------------------
/man/fit_cate_hal_partially_linear.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/hal_cate_partially_linear.R
3 | \name{fit_cate_hal_partially_linear}
4 | \alias{fit_cate_hal_partially_linear}
5 | \title{Doubly-robust nonparametric superefficient estimation of the partial conditional average treatment effect
6 | using the highly adaptive lasso R-learner}
7 | \usage{
8 | fit_cate_hal_partially_linear(
9 |   W,
10 |   A,
11 |   Y,
12 |   weights = NULL,
13 |   pi.hat = NULL,
14 |   m.hat = NULL,
15 |   formula_cate = NULL,
16 |   max_degree_cate = 1,
17 |   smoothness_orders_cate = 1,
18 |   num_knots_cate = c(50),
19 |   sl3_Lrnr_pi.hat = NULL,
20 |   sl3_Lrnr_m.hat = NULL,
21 |   verbose = TRUE,
22 |   ...
23 | ) 24 | } 25 | \arguments{ 26 | \item{W}{A numeric matrix of covariate values.} 27 | 28 | \item{A}{A numeric vector of treatment values. Can be binary or continuous.} 29 | 30 | \item{Y}{A numeric vector of outcome values.} 31 | 32 | \item{weights}{(Optional) A numeric vector of observation weights.} 33 | 34 | \item{pi.hat}{A numeric vector containing estimated propensity scores `pi(W) := P(A=1 | W)`.} 35 | 36 | \item{m.hat}{A numeric vector containing estimates of treatment-marginalized outcome regression `m(W) := E[Y | W]`.} 37 | 38 | \item{formula_cate}{(Optional) A `hal9001`-formatted formula object for the CATE to be passed to `formula_hal`. 39 | By default, the CATE model is learned data-adaptively using MARS-based screening and HAL. See documentation for `fit_hal`.} 40 | 41 | \item{max_degree_cate}{(Optional) Same as `max_degree` but for CATE model. See documentation for `fit_hal`.} 42 | 43 | \item{smoothness_orders_cate}{(Optional) Same as `smoothness_orders` but for CATE model. See documentation for `fit_hal`.} 44 | 45 | \item{num_knots_cate}{(Optional) Same as `num_knots` but for CATE model. See documentation for `fit_hal`.} 46 | 47 | \item{sl3_Lrnr_pi.hat}{If `pi.hat` is not provided, a `Lrnr_base` object for estimation of `pi(W) := P(A=1 | W)`.} 48 | 49 | \item{sl3_Lrnr_m.hat}{If `m.hat` is not provided, a `Lrnr_base` object for estimation of `m(W) := E[Y | W]`.} 50 | 51 | \item{...}{Other arguments to be passed to `fit_hal`.} 52 | } 53 | \description{ 54 | This function estimates the Partial Conditional Average Treatment Effect (CATE) function `w -> tau(w) := E[Y | A=1, W=w] - E[Y | A=0, W=w]` 55 | within the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`. 56 | } 57 | \details{ 58 | This method implements a doubly-robust and superefficient estimation technique for the Partial CATE using the highly adaptive lasso R-learner.
59 | It can data-adaptively learn complex relationships in the CATE function, while benefiting from simpler structure and parsimony when present. 60 | By using the highly adaptive lasso, the method aims to provide robust and precise nonparametric inference of the Partial CATE. 61 | } 62 | -------------------------------------------------------------------------------- /man/fit_cate_lasso_partially_linear.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lasso_cate_partially_linear.R 3 | \name{fit_cate_lasso_partially_linear} 4 | \alias{fit_cate_lasso_partially_linear} 5 | \title{Doubly-robust nonparametric superefficient estimation of the conditional average treatment effect 6 | using the lasso-based R-learner} 7 | \usage{ 8 | fit_cate_lasso_partially_linear( 9 | W, 10 | A, 11 | Y, 12 | pi.hat = NULL, 13 | m.hat = NULL, 14 | verbose = TRUE, 15 | ... 16 | ) 17 | } 18 | \arguments{ 19 | \item{W}{A numeric matrix of covariate values.} 20 | 21 | \item{A}{A numeric vector of treatment values. Can be binary or continuous.} 22 | 23 | \item{Y}{A numeric vector of outcome values.} 24 | 25 | \item{pi.hat}{A numeric vector containing estimated propensity scores `pi(W) := P(A=1 | W)`.} 26 | 27 | \item{m.hat}{A numeric vector containing estimates of treatment-marginalized outcome regression `m(W) := E[Y | W]`.} 28 | 29 | \item{...}{Additional arguments to be passed to \code{\link[glmnet]{cv.glmnet}}.} 30 | } 31 | \description{ 32 | This function estimates the Conditional Average Treatment Effect (CATE) function `w -> tau(w) := E[Y | A=1, W=w] - E[Y | A=0, W=w]` 33 | within the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`.
34 | } 35 | \details{ 36 | This method implements adaptive debiased machine learning of the Average Treatment Effect (ATE) through data-driven partially linear 37 | model selection based on the LASSO (Least Absolute Shrinkage and Selection Operator). By incorporating the LASSO technique, 38 | the method leverages learned structural information and promotes parsimony in the CATE function estimation. The approach 39 | offers adaptivity and super-efficiency in nonparametric ATE inference. 40 | } 41 | -------------------------------------------------------------------------------- /man/fit_hal_cate_plugin.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/hal_cate_plugin.R 3 | \name{fit_hal_cate_plugin} 4 | \alias{fit_hal_cate_plugin} 5 | \title{Doubly-robust nonparametric superefficient estimation of the conditional average treatment effect 6 | using the highly adaptive lasso plug-in estimator.} 7 | \usage{ 8 | fit_hal_cate_plugin( 9 | W, 10 | A, 11 | Y, 12 | weights = NULL, 13 | formula_cate = NULL, 14 | max_degree_cate = 3, 15 | num_knots_cate = c(sqrt(length(Y)), length(Y)^(1/3), length(Y)^(1/5)), 16 | smoothness_orders_cate = 1, 17 | screen_variable_cate = TRUE, 18 | params_EY0W = list(max_degree = 3, num_knots = c(sqrt(length(Y)), length(Y)^(1/3), 19 | length(Y)^(1/5)), smoothness_orders = 1, screen_variables = TRUE), 20 | include_propensity_score = FALSE, 21 | verbose = TRUE, 22 | ... 23 | ) 24 | } 25 | \arguments{ 26 | \item{W}{A \code{matrix} of covariate values.} 27 | 28 | \item{A}{A \code{numeric} binary vector of treatment values.} 29 | 30 | \item{Y}{A \code{numeric} vector of outcome values.} 31 | 32 | \item{weights}{(Optional) A \code{numeric} vector of observation weights.} 33 | 34 | \item{formula_cate}{(Optional) A \code{hal9001}-formatted \code{formula} object for the CATE/tau to be passed to \code{\link[hal9001]{formula_hal}}. 
35 | By default, the CATE model is learned data-adaptively using MARS-based screening and HAL.} 36 | 37 | \item{max_degree_cate}{(Optional) Same as \code{max_degree} but for CATE model.} 38 | 39 | \item{num_knots_cate}{(Optional) Same as \code{num_knots} but for CATE model.} 40 | 41 | \item{smoothness_orders_cate}{(Optional) Same as \code{smoothness_orders} but for CATE model.} 42 | 43 | \item{...}{Other arguments to be passed to \code{\link[hal9001]{fit_hal}}.} 44 | 45 | \item{Delta}{(Not used)} 46 | 47 | \item{max_degree}{For estimation of nuisance functions `E[Y|W]` and `E[A|W]`. 48 | The maximum interaction degree of basis functions generated. 49 | Passed to \code{\link[hal9001]{fit_hal}} function of \code{hal9001} package.} 50 | 51 | \item{num_knots}{For estimation of nuisance functions `E[Y|W]` and `E[A|W]`. 52 | Passed to \code{\link[hal9001]{fit_hal}} function of \code{hal9001} package. 53 | A \code{numeric} vector of length \code{max_degree} where 54 | the `d`-th entry specifies the number of univariable spline knot points to use 55 | when generating the tensor-product basis functions of interaction degree `d`.} 56 | 57 | \item{smoothness_orders}{For estimation of nuisance functions `E[Y|W]` and `E[A|W]`. 58 | An integer taking values in (0,1,2,...) 59 | specifying the smoothness order of the basis functions. See documentation for \code{\link[hal9001]{fit_hal}}.} 60 | 61 | \item{family_Y}{A \code{\link[stats]{family}} object specifying the outcome type of the outcome \code{Y}. 62 | This is passed internally to \code{\link[hal9001]{fit_hal}} when estimating `E[Y | W]`.} 63 | 64 | \item{family_A}{A \code{\link[stats]{family}} object specifying the outcome type of the treatment \code{A}. 65 | This is passed internally to \code{\link[hal9001]{fit_hal}} when estimating `E[A | W]`.} 66 | 67 | \item{screen_variables}{Highly recommended. See documentation for \code{\link[hal9001]{fit_hal}}.} 68 | 69 | \item{screen_interactions}{Highly recommended.
See documentation for \code{\link[hal9001]{fit_hal}}.} 70 | } 71 | \description{ 72 | This method estimates the conditional average treatment effect function `w -> tau(w)` 73 | under the regression model `E[Y | A, W] = E[Y | A=0, W] + A * tau(W)`. 74 | } 75 | -------------------------------------------------------------------------------- /man/hello.Rd: -------------------------------------------------------------------------------- 1 | \name{hello} 2 | \alias{hello} 3 | \title{Hello, World!} 4 | \usage{ 5 | hello() 6 | } 7 | \description{ 8 | Prints 'Hello, world!'. 9 | } 10 | \examples{ 11 | hello() 12 | } 13 | -------------------------------------------------------------------------------- /man/inference_ate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/inference.R 3 | \name{inference_ate} 4 | \alias{inference_ate} 5 | \title{Estimates and confidence intervals for the ATE.} 6 | \usage{ 7 | inference_ate(fit_cate, alpha = 0.05, return_cov_mat = FALSE) 8 | } 9 | \arguments{ 10 | \item{fit_cate}{A \code{hal_cate} object obtained from the function \code{fit_hal_cate}.} 11 | 12 | \item{alpha}{Significance level for confidence intervals} 13 | 14 | \item{return_cov_mat}{A \code{logical} for whether to return the asymptotic covariance matrix of the coefficient estimates.} 15 | } 16 | \description{ 17 | Estimates and confidence intervals for the ATE.
18 | } 19 | -------------------------------------------------------------------------------- /man/inference_cate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/inference.R 3 | \name{inference_cate} 4 | \alias{inference_cate} 5 | \title{Estimates and confidence intervals for the projection of the CATE onto a user-specified parametric working model.} 6 | \usage{ 7 | inference_cate(fit_cate, formula = ~1, alpha = 0.05, return_cov_mat = FALSE) 8 | } 9 | \arguments{ 10 | \item{fit_cate}{A \code{hal_cate} object obtained from the function \code{fit_hal_cate}.} 11 | 12 | \item{formula}{A \code{formula} object specifying a working parametric model for the conditional average treatment effect. 13 | For instance, `formula = ~ 1` specifies the marginal average treatment effect `E[CATE(W)]`. 14 | More complex formulas like `formula = ~ W1` specify the best `W1`-linear approximation of the true CATE.} 15 | 16 | \item{alpha}{Significance level for confidence intervals} 17 | 18 | \item{return_cov_mat}{A \code{logical} for whether to return the asymptotic covariance matrix of the coefficient estimates.} 19 | } 20 | \description{ 21 | Estimates and confidence intervals for the projection of the CATE onto a user-specified parametric working model.
22 | } 23 | -------------------------------------------------------------------------------- /man/isoreg_with_xgboost.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/isoreg.R 3 | \name{isoreg_with_xgboost} 4 | \alias{isoreg_with_xgboost} 5 | \title{Isotonic Regression with XGBoost} 6 | \usage{ 7 | isoreg_with_xgboost( 8 | x, 9 | y, 10 | max_depth = 15, 11 | min_child_weight = 20, 12 | weights = NULL 13 | ) 14 | } 15 | \arguments{ 16 | \item{x}{A vector or matrix of predictor variables.} 17 | 18 | \item{y}{A vector of response variables.} 19 | 20 | \item{max_depth}{Integer. Maximum depth of the trees in XGBoost (default is 15).} 21 | 22 | \item{min_child_weight}{Numeric. Minimum sum of instance weights (Hessian) needed in a child node (default is 20).} 23 | 24 | \item{weights}{A vector of weights to apply to each instance during training (default is NULL, meaning equal weights).} 25 | } 26 | \value{ 27 | A function that takes a new predictor variable \code{x} and returns the model's predicted values. 28 | } 29 | \description{ 30 | Fits an isotonic regression model using XGBoost with monotonic constraints. 31 | } 32 | \details{ 33 | This function uses XGBoost to fit a monotonic increasing model to the data, enforcing isotonic regression 34 | through the use of monotonic constraints. The model is trained with one boosting round to achieve a fit 35 | that is interpretable as an isotonic regression. 
36 | }
37 | \examples{
38 | \dontrun{
39 | # Example data
40 | x <- matrix(rnorm(100), ncol = 1)
41 | y <- sort(rnorm(100))
42 |
43 | # Fit the model
44 | iso_model <- isoreg_with_xgboost(x, y)
45 |
46 | # Predict on new data
47 | x_new <- matrix(rnorm(10), ncol = 1)
48 | predictions <- iso_model(x_new)
49 | }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/simResults/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Larsvanderlaan/AdaptiveDML/3d35b5d2c0444511e53d4c6e554d2c5123945e27/simResults/.DS_Store
--------------------------------------------------------------------------------
/simulationScripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Larsvanderlaan/AdaptiveDML/3d35b5d2c0444511e53d4c6e554d2c5123945e27/simulationScripts/.DS_Store
--------------------------------------------------------------------------------
/simulationScripts/R_setup.R:
--------------------------------------------------------------------------------
1 | # Put the personal cluster library (~/Rlibs2) first on the library search path, then report it.
2 | .libPaths(new = c("~/Rlibs2", .libPaths()))
3 | print(.libPaths())
4 | # Switch to the simulation working directory, then report it.
5 | # NOTE(review): "~/sieveSims" differs from the ~/causalHAL paths used by the other scripts -- confirm.
6 | setwd(dir = "~/sieveSims")
7 | print(getwd())
8 |
--------------------------------------------------------------------------------
/simulationScripts/Rout/par-24844814.err:
--------------------------------------------------------------------------------
1 | Failed to query server: Connection timed out
2 | Loading required package: foreach
3 | Loading required package: future
4 | Warning message:
5 | In system("timedatectl", intern = TRUE) :
6 | running command 'timedatectl' had status 1
7 |
--------------------------------------------------------------------------------
/simulationScripts/Rout/par-24844815.err:
--------------------------------------------------------------------------------
1 | Loading required package: foreach
2 | Loading required package: future
3 |
-------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844816.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844817.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844818.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844819.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844820.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844821.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844822.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | 
-------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844823.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844824.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844825.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844826.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | There were 50 or more warnings (use warnings() to see the first 50) 4 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844827.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | There were 50 or more warnings (use warnings() to see the first 50) 4 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844828.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | There were 50 or more warnings (use warnings() to see the first 50) 4 | -------------------------------------------------------------------------------- 
/simulationScripts/Rout/par-24844829.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | There were 50 or more warnings (use warnings() to see the first 50) 4 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844830.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844831.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844832.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844833.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844834.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844835.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | 
-------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844836.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/Rout/par-24844837.err: -------------------------------------------------------------------------------- 1 | Loading required package: foreach 2 | Loading required package: future 3 | -------------------------------------------------------------------------------- /simulationScripts/install_packages.R: -------------------------------------------------------------------------------- 1 | #devtools::install_github("tlverse/hal9001", ref = "screeningHAL") 2 | devtools::install_github("tlverse/sl3", ref = "develVersionChangeLars") 3 | devtools::install_github("tlverse/origami") 4 | devtools::install_cran("future") 5 | devtools::install_cran("doFuture") 6 | #devtools::install_github("Larsvanderlaan/causalHAL") 7 | 8 | -------------------------------------------------------------------------------- /simulationScripts/install_packages.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name LarsJob%jls # Set a name for your job. This is especially useful if you have multiple jobs queued. 3 | #SBATCH --partition short # Slurm partition to use 4 | #SBATCH --ntasks 1 # Number of tasks to run. 
By default, one CPU core will be allocated per task 5 | #SBATCH --time 0-00:20 # Wall time limit in D-HH:MM 6 | #SBATCH --mem-per-cpu=1000 # Memory limit for each tasks (in MB) 7 | #SBATCH -o ./%j.out # File to which STDOUT will be written 8 | #SBATCH -e ./%j.err # File to which STDERR will be written 9 | #SBATCH --mail-type=NONE # Type of email notification- NONE,BEGIN,END,FAIL,ALL 10 | #SBATCH --mail-user=lvdlaan@uw.edu # Email to which notifications will be sent 11 | export R_LIBS=~/Rlibs2 12 | export R_LIBS_USER=~/Rlibs2 13 | module load R 14 | Rscript -e 'source("~/causalHAL/simScripts/install_packages.R")' 15 | -------------------------------------------------------------------------------- /simulationScripts/simScriptAdapt.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | library(hal9001) 3 | library(sl3) 4 | library(causalHAL) 5 | library(doFuture) 6 | library(future) 7 | 8 | #out <- do_sims(10, 3000, 2, TRUE, do_local_alt = FALSE) 9 | 10 | 11 | 12 | do_sims <- function(niter, n, pos_const, muIsHard, do_local_alt = FALSE) { 13 | seed_init <- 12345 14 | sim_results <- rbindlist(lapply(1:niter, function(iter) { 15 | set.seed(seed_init*iter) 16 | print(paste0("Iteration number: ", iter)) 17 | try({ 18 | if(!do_local_alt) { 19 | data_list <- get_data(n, pos_const, muIsHard) 20 | } else if(do_local_alt) { 21 | data_list <- get_data_local_alt(n, pos_const, muIsHard) 22 | } 23 | return(as.data.table(get_estimates(data_list$W, data_list$A, data_list$Y,iter, NULL))) 24 | }) 25 | return(data.table()) 26 | })) 27 | key <- paste0("iter=", niter, "_n=", n, "_pos=", pos_const, "_hard=", muIsHard, "_local_",do_local_alt ) 28 | try({fwrite(sim_results, paste0("~/causalHAL/simResults/sim_results_", key, ".csv"))}) 29 | return(sim_results) 30 | } 31 | 32 | 33 | #' generates dataset of size n. 34 | #' constant in propensity score can be used to vary overlap. 
35 | #' two settings for outcome regression: easy form and hard form 36 | get_data <- function(n, pos_const, muIsHard = TRUE) { 37 | d <- 4 38 | W <- replicate(d, runif(n, -1, 1)) 39 | colnames(W) <- paste0("W", 1:d) 40 | pi0 <- plogis(pos_const * ( W[,1] + sin(4*W[,1]) + W[,2] + cos(4*W[,2]) + W[,3] + sin(4*W[,3]) + W[,4] + cos(4*W[,4]) )) 41 | print("pos") 42 | print(range(pi0)) 43 | A <- rbinom(n, 1, pi0) 44 | if(muIsHard) { 45 | mu0 <- sin(4*W[,1]) + sin(4*W[,2]) + sin(4*W[,3])+ sin(4*W[,4]) + cos(4*W[,2]) 46 | } else { 47 | mu0 <- W[,1] + abs(W[,2]) + W[,3] + abs(W[,4]) 48 | } 49 | tau <- 1 + W[,1] + abs(W[,2]) + cos(4*W[,3]) + W[,4] 50 | Y <- rnorm(n, mu0 + A * tau, 0.5) 51 | return(list(W=W, A = A, Y = Y, ATE = 1.31, pi = pi0)) 52 | } 53 | 54 | get_data_local_alt <- function(n, pos_const, muIsHard = TRUE) { 55 | ates <- list("0" = 4, "0.5" = 5.219036, "1"= 14.44244, "2"= 1048.994) 56 | 57 | d <- 4 58 | W <- replicate(d, runif(n, -1, 1)) 59 | colnames(W) <- paste0("W", 1:d) 60 | pi0 <- plogis(pos_const * ( W[,1] + sin(4*W[,1]) + W[,2] + cos(4*W[,2]) + W[,3] + sin(4*W[,3]) + W[,4] + cos(4*W[,4]) )) 61 | 62 | print("pos") 63 | print(range(pi0)) 64 | A <- rbinom(n, 1, pi0) 65 | if(muIsHard) { 66 | mu0 <- sin(4*W[,1]) + sin(4*W[,2]) + sin(4*W[,3])+ sin(4*W[,4]) + cos(4*W[,2]) 67 | } else { 68 | mu0 <- W[,1] + abs(W[,2]) + W[,3] + abs(W[,4]) 69 | } 70 | 71 | mu0 <- mu0 - (pi0/(pi0*(1-pi0)))/sqrt(n) 72 | tau <- 1 + ( 1 / (pi0*(1-pi0)) )/sqrt(n) 73 | Y <- rnorm(n, mu0 + A * tau, 0.5) 74 | return(list(W=W, A = A, Y = Y, ATE = 1 + ates[[as.character(pos_const)]]/sqrt(n), pi = pi0)) 75 | } 76 | 77 | #' Given simulated data (W,A,Y) and simulation iteration number `iter`, 78 | #' computes ATE estimates, se, and CI for plug-in T-learner HAL, plug-in R-learner HAL, partially linear intercept model, AIPW, and the isotonic-calibrated R-learner ("cal").
79 | get_estimates <- function(W, A, Y,iter, pi_true) { 80 | n <- length(Y) 81 | if(n <= 500) { 82 | num_knots <- c(10, 10, 1, 0) 83 | } else if(n <= 1000) { 84 | num_knots <- c(50, 15, 15, 15) 85 | } else if(n <= 3000) { 86 | num_knots <- c(75, 25,30,30) 87 | } else{ 88 | num_knots <- c(100, 50, 50,50) 89 | } 90 | fit_T <- fit_hal_cate_plugin (W, A, Y, max_degree_cate = 1, num_knots_cate = num_knots , smoothness_orders_cate = 1, screen_variable_cate = FALSE, params_EY0W = list(max_degree = 1, num_knots = num_knots , smoothness_orders = 1, screen_variables = FALSE, fit_control = list(parallel = TRUE)), fit_control = list(parallel = TRUE), include_propensity_score = FALSE, verbose = TRUE ) 91 | ate_T <- unlist(inference_ate(fit_T)) 92 | ate_T[1] <- "Tlearner" 93 | 94 | mu1 <- fit_T$internal$data$mu1 95 | mu0 <- fit_T$internal$data$mu0 96 | mu <- ifelse(A==1, mu1, mu0) 97 | 98 | lrnr_stack <- Stack$new(list( Lrnr_earth$new(degree = 2, family = "gaussian"),Lrnr_gam$new(family = "gaussian"), Lrnr_ranger$new(), Lrnr_xgboost$new(max_depth = 4, nrounds = 20), Lrnr_xgboost$new(max_depth = 5, nrounds = 20) )) 99 | lrnr_A<- make_learner(Pipeline, Lrnr_cv$new(lrnr_stack), Lrnr_cv_selector$new(loss_squared_error) ) 100 | task_A <- sl3_Task$new(data.table(W, A = A), covariates = colnames(W), outcome = "A", outcome_type = "continuous") 101 | 102 | fit_pi <- lrnr_A$train(task_A) 103 | 104 | pi <- fit_pi$predict(task_A) 105 | pi <- truncate_pscore_adaptive(A, pi) 106 | 107 | 108 | 109 | m <- mu0 * (1-pi) + mu1 * (pi) 110 | fit_R <- fit_hal_cate_partially_linear(W, A, Y, fit_control = list(parallel = TRUE), pi.hat = pi, m.hat = m, formula_cate = NULL, max_degree_cate = 1, num_knots_cate = num_knots, smoothness_orders_cate = 1, verbose = TRUE) 111 | ate_R<- unlist(inference_ate(fit_R)) 112 | ate_R[1] <- "Rlearner" 113 | 114 | # 115 | cate.hat <- fit_R$data$tau_relaxed 116 | calibrator <- isoreg_with_xgboost(cate.hat, fit_R$data$pseudo_outcome, weights = fit_R$data$pseudo_weights)
117 | cate_cal <- calibrator(cate.hat) 118 | # Create a data.table 119 | dt <- data.table(cate_cal = cate_cal, weight = (A - pi)^2) 120 | # Riesz weight 1 / E[(A - pi)^2 | cate_cal], estimated within each calibrated CATE level; with a single level it equals the intercept-model weight 1 / mean((A - pi)^2) used below. By-group := keeps the original row order. 121 | dt[, gamma := 1 / mean(weight), by = cate_cal] 122 | gamma_n <- dt$gamma 123 | IF <- (A - pi) * gamma_n * (Y - m - (A-pi)*cate_cal) 124 | CI <- mean(cate_cal) + 1.96*c(-1,1)*sd(IF)/sqrt(n) 125 | ate_cal <- c("cal", mean(cate_cal), sd(IF)/sqrt(n), CI) 126 | 127 | tau_int <- mean((A-pi) * (Y - m)) / mean((A-pi)^2) 128 | IF <- (A - pi) / mean((A-pi)^2) * (Y - m - (A-pi)*tau_int) 129 | CI <- tau_int + 1.96*c(-1,1)*sd(IF)/sqrt(n) 130 | ate_intercept <- c("intercept", tau_int, sd(IF)/sqrt(n), CI) 131 | names(ate_intercept) <- c("method", "coef","se", "CI_left", "CI_right") 132 | 133 | 134 | # 135 | 136 | IF <- mu1 - mu0 + (A/pi - (1-A)/(1-pi)) * (Y - mu) 137 | est_AIPW <- mean(IF) 138 | CI <- est_AIPW + 1.96*c(-1,1)*sd(IF)/sqrt(n) 139 | ate_aipw <- c("AIPW", est_AIPW, sd(IF)/sqrt(n), CI) 140 | names(ate_aipw) <- c("method", "coef","se", "CI_left", "CI_right") 141 | 142 | mat <- cbind(iter,rbind(ate_T, ate_R, ate_intercept, ate_aipw, ate_cal)) 143 | colnames(mat) <- c("iter", "method", "coef","se", "CI_left", "CI_right") 144 | 145 | 146 | return(mat) 147 | } 148 | -------------------------------------------------------------------------------- /simulationScripts/simScriptAdapt.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --array=1 4 | #SBATCH --nodes=1 5 | #SBATCH --cpus-per-task=11 6 | #SBATCH --output=Rout/par-%J.out 7 | #SBATCH --error=Rout/par-%J.err 8 | echo "LOADING R" 9 | module load R 10 | echo "R LOADED" 11 | Rscript -e 'source("~/causalHAL/simulationScripts/R_setup.R"); source("~/causalHAL/simulationScripts/simScriptAdapt.R"); n = 
as.numeric(Sys.getenv("n")); pos_const = as.numeric(Sys.getenv("const")); muIsHard = as.logical(Sys.getenv("hard"));
do_sims(5000, n, pos_const, muIsHard)' 12 | 13 | -------------------------------------------------------------------------------- /simulationScripts/simScriptAdaptLocal.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --array=1 4 | #SBATCH --nodes=1 5 | #SBATCH --cpus-per-task=11 6 | #SBATCH --output=Rout/par-%J.out 7 | #SBATCH --error=Rout/par-%J.err 8 | echo "LOADING R" 9 | module load R 10 | echo "R LOADED" 11 | Rscript -e 'source("~/causalHAL/simScripts/R_setup.R"); source("~/causalHAL/simScripts/simScriptAdapt.R"); n = as.numeric(Sys.getenv("n")); pos_const = as.numeric(Sys.getenv("const")); muIsHard = as.logical(Sys.getenv("hard")); do_sims(5000, n, pos_const, muIsHard, TRUE)' 12 | -------------------------------------------------------------------------------- /simulationScripts/simsAdapt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/usr/env bash 2 | nsims=2500 3 | export R_LIBS=~/Rlibs2 4 | export R_LIBS_USER=~/Rlibs2 5 | for n in 500 1000 2000 3000 4000 5000 6 | do 7 | for const in 0.5 1 2 3 8 | do 9 | for hard in "TRUE" "FALSE" 10 | do 11 | sbatch --export=n=$n,const=$const,hard=$hard ~/causalHAL/simulationScripts/simScriptAdapt.sbatch 12 | done 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /simulationScripts/simsAdaptLocal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/usr/env bash 2 | nsims=2500 3 | export R_LIBS=~/Rlibs2 4 | export R_LIBS_USER=~/Rlibs2 5 | for n in 500 1000 2000 3000 4000 5000 6 | do 7 | for const in 0 0.5 1 2 8 | do 9 | for hard in "TRUE" "FALSE" 10 | do 11 | sbatch --export=n=$n,const=$const,hard=$hard ~/causalHAL/simScripts/simScriptAdaptLocal.sbatch 12 | done 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /vignette.Rmd: 
-------------------------------------------------------------------------------- 1 | --- 2 | title: "Vignette" 3 | output: html_document 4 | date: '2023-07-26' 5 | --- 6 | 7 | ```{r setup, include=FALSE} 8 | knitr::opts_chunk$set(echo = TRUE) 9 | ``` 10 | 11 | 12 | ```{r} 13 | # install package if needed 14 | if (!requireNamespace("causalHAL", quietly = TRUE)) { 15 | devtools::install_github("Larsvanderlaan/causalHAL") 16 | } 17 | 18 | ``` 19 | 20 | # ADMLE for ATE using adaptive partially linear regression models 21 | 22 | ## Generate example dataset 23 | 24 | ```{r} 25 | # Dataset used for simulations 26 | get_data <- function(n, pos_const, muIsHard = TRUE) { 27 | # n: sample size 28 | # pos_const: used to control treatment overlap 29 | # muIsHard: whether the outcome regression is hard (TRUE) or simple (FALSE). 30 | 31 | # covariate dimension 32 | d <- 4 33 | W <- replicate(d, runif(n, -1, 1)) 34 | colnames(W) <- paste0("W", 1:d) 35 | # propensity score 36 | pi0 <- plogis(pos_const * ( W[,1] + sin(4*W[,1]) + W[,2] + cos(4*W[,2]) + W[,3] + sin(4*W[,3]) + W[,4] + cos(4*W[,4]) )) 37 | 38 | # treatment 39 | A <- rbinom(n, 1, pi0) 40 | 41 | # control outcome regression 42 | if(muIsHard) { 43 | mu0 <- sin(4*W[,1]) + sin(4*W[,2]) + sin(4*W[,3])+ sin(4*W[,4]) + cos(4*W[,2]) 44 | } else { 45 | mu0 <- W[,1] + abs(W[,2]) + W[,3] + abs(W[,4]) 46 | } 47 | # CATE 48 | tau <- 1 + W[,1] + abs(W[,2]) + cos(4*W[,3]) + W[,4] 49 | # outcome 50 | Y <- rnorm(n, mu0 + A * tau, 0.5) 51 | return(list(W=W, A = A, Y = Y, ATE = 1.31, pi = pi0, mu0 = mu0, tau = tau )) 52 | } 53 | 54 | ``` 55 | 56 | 57 | ## Run ADMLE using HAL and glmnet 58 | 59 | ```{r} 60 | library(causalHAL) 61 | seed <- rnorm(1) 62 | 63 | data <- get_data(1000, 1, TRUE) 64 | print(paste0("True ATE: ", data$ATE)) 65 | # get nuisance functions for R-learner 66 | 67 | # User-supplied estimate of propensity score pi = P(A=1|W) 68 | pi.hat <- data$pi 69 | # User-supplied estimate of treatment-marginalized outcome regression m = E(Y|W) = mu0 * (1 - pi) + (mu0 + tau) * pi 70 | m.hat <- data$mu0 * (1 - pi.hat) + (data$mu0 +
data$tau) * pi.hat 71 | 72 | # ADMLE for ATE using partially linear model with HAL. 73 | # Fits additive piece-wise linear spline model for CATE with 50 knot points per covariate using highly adaptive lasso (see tlverse/hal9001 github R package) 74 | set.seed(seed) 75 | ADMLE_fit <- fit_cate_hal_partially_linear(data$W, data$A, data$Y, 76 | m.hat = m.hat, 77 | pi.hat = pi.hat, 78 | smoothness_orders_cate = 1, num_knots_cate = c(50), max_degree_cate = 1) 79 | # Provides estimates and CI for ATE 80 | inference_ate(ADMLE_fit) 81 | 82 | # Same analysis but using glmnet implementation with hal9001-basis design matrix. 83 | # May not reproduce estimates exactly but should be close. 84 | # For those not familiar with hal9001 package, the below code may be easier to play around with. 85 | basis_list <- hal9001::enumerate_basis(data$W, smoothness_orders = 1, num_knots = 50, max_degree = 1) 86 | tau_basis <- hal9001::make_design_matrix(data$W, basis_list) 87 | set.seed(seed) 88 | ADMLE_fit <- fit_cate_lasso_partially_linear(tau_basis, data$A, data$Y, 89 | m.hat = m.hat, 90 | pi.hat = pi.hat, standardize = FALSE) 91 | 92 | # Provides estimates and CI for ATE 93 | inference_ate(ADMLE_fit) 94 | ``` 95 | --------------------------------------------------------------------------------