├── .github ├── .gitignore ├── workflows │ ├── draft-pdf.yml │ ├── lint.yaml │ ├── coverage.yaml │ ├── R-CMD-check.yaml │ ├── pkgdown.yaml │ └── rhub.yaml └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .covrignore ├── vignettes └── .gitignore ├── LICENSE ├── joss ├── paper.pdf └── paper.md ├── .gitattributes ├── .lintr ├── tests ├── testthat.R └── testthat │ ├── test-rroc.R │ ├── test-diagnostics.R │ ├── test-recipe-api.R │ ├── test-data.R │ ├── test-predictor.R │ ├── test-utils.R │ ├── test-vimp.R │ ├── test-pseudooutcomes.R │ ├── test-config-construction.R │ ├── test-api-mcate.R │ ├── test-api-repeats.R │ ├── test-api-weights.R │ ├── test-api-pcate.R │ ├── test-splitting.R │ └── test-api-missing.R ├── codecov.yml ├── .Rbuildignore ├── man ├── abort_data.Rd ├── abort_model.Rd ├── abort_config.Rd ├── abort_package.Rd ├── abort_not_implemented.Rd ├── tidyhte_abort.Rd ├── remove_vimp.Rd ├── calculate_pcate_quantities.Rd ├── check_splits.Rd ├── check_data_has_hte_cfg.Rd ├── check_weights.Rd ├── check_identifier.Rd ├── check_nuisance_models.Rd ├── add_known_propensity_score.Rd ├── add_effect_diagnostic.Rd ├── add_outcome_diagnostic.Rd ├── add_propensity_diagnostic.Rd ├── split_data.Rd ├── listwise_deletion.Rd ├── figures │ ├── lifecycle-stable.svg │ ├── lifecycle-defunct.svg │ ├── lifecycle-archived.svg │ ├── lifecycle-maturing.svg │ ├── lifecycle-deprecated.svg │ ├── lifecycle-superseded.svg │ ├── lifecycle-experimental.svg │ └── lifecycle-questioning.svg ├── add_outcome_model.Rd ├── add_effect_model.Rd ├── predict.SL.glmnet.interaction.Rd ├── add_propensity_score_model.Rd ├── fit_plugin.Rd ├── fit_effect.Rd ├── fit_plugin_A.Rd ├── estimate_diagnostic.Rd ├── fit_plugin_Y.Rd ├── calculate_rroc.Rd ├── calculate_ate.Rd ├── calculate_diagnostics.Rd ├── fit_fx_predictor.Rd ├── add_moderator.Rd ├── construct_pseudo_outcomes.Rd ├── add_vimp.Rd ├── tidyhte-errors.Rd ├── calculate_vimp.Rd ├── tidyhte-package.Rd ├── Model_cfg.Rd ├── calculate_linear_vimp.Rd 
├── estimate_QoI.Rd ├── attach_config.Rd ├── SL.glmnet.interaction.Rd ├── make_splits.Rd ├── Constant_cfg.Rd ├── HTEFold.Rd ├── basic_config.Rd ├── produce_plugin_estimates.Rd ├── Stratified_cfg.Rd ├── Known_cfg.Rd ├── SLLearner_cfg.Rd ├── VIMP_cfg.Rd ├── KernelSmooth_cfg.Rd ├── FX.Predictor.Rd ├── Model_data.Rd ├── Diagnostics_cfg.Rd ├── MCATE_cfg.Rd ├── SLEnsemble_cfg.Rd ├── QoI_cfg.Rd └── HTE_cfg.Rd ├── cran-comments.md ├── R ├── tidyhte-package.R ├── utils.R ├── rroc.R ├── errors.R └── plugin-estimates.R ├── LICENSE.md ├── .gitignore ├── CITATION.cff ├── DESCRIPTION ├── NAMESPACE └── _pkgdown.yml /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.covrignore: -------------------------------------------------------------------------------- 1 | R/SL.glmnet.interaction.R 2 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2021 2 | COPYRIGHT HOLDER: tidyhte authors 3 | -------------------------------------------------------------------------------- /joss/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ddimmery/tidyhte/HEAD/joss/paper.pdf -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.lintr: 
-------------------------------------------------------------------------------- 1 | linters: linters_with_defaults( 2 | line_length_linter(100), 3 | object_name_linter = NULL, 4 | object_usage_linter = NULL) 5 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(checkmate) # for testthat extensions 3 | library(tidyhte) 4 | 5 | test_check("tidyhte") 6 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^LICENSE\.md$ 2 | .lintr 3 | .covrignore 4 | ^\.github$ 5 | ^README\.Rmd$ 6 | ^codecov\.yml$ 7 | ^_pkgdown\.yml$ 8 | ^docs$ 9 | ^pkgdown$ 10 | ^doc$ 11 | ^Meta$ 12 | figure$ 13 | cache$ 14 | figure$ 15 | \.dvi$ 16 | \.bbl$ 17 | \.blg$ 18 | cran-comments.md 19 | ^CRAN-SUBMISSION$ 20 | ^joss$ 21 | ^CODE_OF_CONDUCT\.md$ 22 | ^CONTRIBUTING\.md$ 23 | ^CITATION.cff$ -------------------------------------------------------------------------------- /man/abort_data.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{abort_data} 4 | \alias{abort_data} 5 | \title{Data validation errors} 6 | \usage{ 7 | abort_data(message, ...) 
8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{...}{Additional arguments passed to \code{tidyhte_abort()}} 13 | } 14 | \description{ 15 | Data validation errors 16 | } 17 | \keyword{internal} 18 | -------------------------------------------------------------------------------- /man/abort_model.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{abort_model} 4 | \alias{abort_model} 5 | \title{Model-related errors} 6 | \usage{ 7 | abort_model(message, ...) 8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{...}{Additional arguments passed to \code{tidyhte_abort()}} 13 | } 14 | \description{ 15 | Model-related errors 16 | } 17 | \keyword{internal} 18 | -------------------------------------------------------------------------------- /man/abort_config.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{abort_config} 4 | \alias{abort_config} 5 | \title{Configuration-related errors} 6 | \usage{ 7 | abort_config(message, ...) 8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{...}{Additional arguments passed to \code{tidyhte_abort()}} 13 | } 14 | \description{ 15 | Configuration-related errors 16 | } 17 | \keyword{internal} 18 | -------------------------------------------------------------------------------- /man/abort_package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{abort_package} 4 | \alias{abort_package} 5 | \title{Package dependency errors} 6 | \usage{ 7 | abort_package(message, ...) 
8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{...}{Additional arguments passed to \code{tidyhte_abort()}} 13 | } 14 | \description{ 15 | Package dependency errors 16 | } 17 | \keyword{internal} 18 | -------------------------------------------------------------------------------- /man/abort_not_implemented.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{abort_not_implemented} 4 | \alias{abort_not_implemented} 5 | \title{Not implemented errors} 6 | \usage{ 7 | abort_not_implemented(message = "Not implemented", ...) 8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{...}{Additional arguments passed to \code{tidyhte_abort()}} 13 | } 14 | \description{ 15 | Not implemented errors 16 | } 17 | \keyword{internal} 18 | -------------------------------------------------------------------------------- /man/tidyhte_abort.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{tidyhte_abort} 4 | \alias{tidyhte_abort} 5 | \title{Throw a tidyhte error} 6 | \usage{ 7 | tidyhte_abort(message, class = "general", ...) 
8 | } 9 | \arguments{ 10 | \item{message}{Error message} 11 | 12 | \item{class}{Error class suffix (will be prefixed with "tidyhte_error_")} 13 | 14 | \item{...}{Additional arguments passed to \code{rlang::abort()}} 15 | } 16 | \description{ 17 | Throw a tidyhte error 18 | } 19 | \keyword{internal} 20 | -------------------------------------------------------------------------------- /man/remove_vimp.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{remove_vimp} 4 | \alias{remove_vimp} 5 | \title{Removes variable importance information} 6 | \usage{ 7 | remove_vimp(hte_cfg) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | } 12 | \value{ 13 | Updated \code{HTE_cfg} object 14 | } 15 | \description{ 16 | This removes the variable importance quantity of interest 17 | from an \code{HTE_cfg}. 18 | } 19 | \examples{ 20 | library("dplyr") 21 | basic_config() \%>\% 22 | remove_vimp() -> hte_cfg 23 | } 24 | -------------------------------------------------------------------------------- /man/calculate_pcate_quantities.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/qoi.R 3 | \name{calculate_pcate_quantities} 4 | \alias{calculate_pcate_quantities} 5 | \title{Calculate "partial" CATE estimates} 6 | \usage{ 7 | calculate_pcate_quantities( 8 | full_data, 9 | .weights, 10 | .outcome, 11 | fx_model, 12 | ..., 13 | .MCATE_cfg 14 | ) 15 | } 16 | \description{ 17 | \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} 18 | } 19 | \keyword{internal} 20 | -------------------------------------------------------------------------------- /man/check_splits.Rd: 
-------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{check_splits} 4 | \alias{check_splits} 5 | \title{Checks that splits have been properly created.} 6 | \usage{ 7 | check_splits(data) 8 | } 9 | \arguments{ 10 | \item{data}{Dataframe which should have appropriate \code{.split_id} column.} 11 | } 12 | \value{ 13 | Returns NULL. Errors if a problem is discovered. 14 | } 15 | \description{ 16 | This helper function makes a few simple checks to identify obvious 17 | issues with the way that splits have been made in the supplied data. 18 | } 19 | \keyword{internal} 20 | -------------------------------------------------------------------------------- /man/check_data_has_hte_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{check_data_has_hte_cfg} 4 | \alias{check_data_has_hte_cfg} 5 | \title{Checks that a dataframe has an attached configuration for HTEs} 6 | \usage{ 7 | check_data_has_hte_cfg(data) 8 | } 9 | \arguments{ 10 | \item{data}{Dataframe of interest.} 11 | } 12 | \value{ 13 | Returns NULL. Errors if a problem is discovered. 14 | } 15 | \description{ 16 | This helper function ensures that the provided dataframe has 17 | the necessary auxiliary configuration information for HTE 18 | estimation. 
19 | } 20 | \keyword{internal} 21 | -------------------------------------------------------------------------------- /man/check_weights.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{check_weights} 4 | \alias{check_weights} 5 | \title{Checks that an appropriate weighting variable has been provided} 6 | \usage{ 7 | check_weights(data, weight_col) 8 | } 9 | \arguments{ 10 | \item{data}{Dataframe of interest.} 11 | 12 | \item{weight_col}{Quoted name of weights column.} 13 | } 14 | \value{ 15 | Returns NULL. Errors if a problem is discovered. 16 | } 17 | \description{ 18 | This helper function makes a few simple checks to identify obvious 19 | issues with the weights provided. 20 | } 21 | \keyword{internal} 22 | -------------------------------------------------------------------------------- /tests/testthat/test-rroc.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 1000 4 | mu <- 1:n 5 | mu_hat <- rnorm(n, mu) 6 | 7 | test_that("rroc runs", { 8 | expect_error(calculate_rroc(mu, mu_hat), NA) 9 | }) 10 | 11 | test_that("rroc is approximately correct", { 12 | rroc <- calculate_rroc(mu, mu_hat, n) 13 | pos <- rroc$value 14 | neg <- rroc$estimate 15 | ord <- order(pos) 16 | pos <- pos[ord] 17 | neg <- neg[ord] 18 | dpos <- pos[2:n] - pos[1:(n - 1)] 19 | avgneg <- -(neg[2:n] + neg[1:(n - 1)]) / 2 20 | aoc <- sum(avgneg * dpos) 21 | rsd <- sqrt(2 * aoc) 22 | expect_true(rsd > 0.975 & rsd < 1.025) 23 | }) 24 | -------------------------------------------------------------------------------- /man/check_identifier.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{check_identifier} 4 | 
\alias{check_identifier} 5 | \title{Checks that an appropriate identifier has been provided} 6 | \usage{ 7 | check_identifier(data, id_col) 8 | } 9 | \arguments{ 10 | \item{data}{Dataframe of interest.} 11 | 12 | \item{id_col}{Quoted name of identifier column.} 13 | } 14 | \value{ 15 | Returns NULL. Errors if a problem is discovered. 16 | } 17 | \description{ 18 | This helper function makes a few simple checks to identify obvious 19 | issues with the provided column of unit identifiers. 20 | } 21 | \keyword{internal} 22 | -------------------------------------------------------------------------------- /.github/workflows/draft-pdf.yml: -------------------------------------------------------------------------------- 1 | name: Draft PDF 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | jobs: 14 | paper: 15 | runs-on: ubuntu-latest 16 | name: Paper Draft 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | 21 | - name: Build draft PDF 22 | uses: openjournals/openjournals-draft-action@master 23 | with: 24 | journal: joss 25 | paper-path: joss/paper.md 26 | 27 | - name: Upload 28 | uses: actions/upload-artifact@v4 29 | with: 30 | name: paper 31 | path: joss/paper.pdf 32 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## Resubmission 2 | This is a resubmission. In this version I have: 3 | 4 | * Noted the (intentional) change of contact address. 5 | 6 | * Reduced the distributed file size of vignettes. 7 | 8 | ## R CMD check results 9 | 10 | 0 errors | 0 warnings | 1 note 11 | 12 | * checking CRAN incoming feasibility ... 
[14s] NOTE 13 | Maintainer: 'Drew Dimmery ' 14 | 15 | New maintainer: 16 | Drew Dimmery 17 | Old maintainer(s): 18 | Drew Dimmery 19 | 20 | I am updating my contact email to no longer depend on 21 | institutional affiliation (I have moved). 22 | 23 | ## revdepcheck results 24 | 25 | There are currently no downstream dependencies for this package. -------------------------------------------------------------------------------- /man/check_nuisance_models.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{check_nuisance_models} 4 | \alias{check_nuisance_models} 5 | \title{Checks that nuisance models have been estimated and exist in the supplied dataset.} 6 | \usage{ 7 | check_nuisance_models(data) 8 | } 9 | \arguments{ 10 | \item{data}{Dataframe which should have appropriate columns of nuisance function 11 | predictions: \code{.pi_hat}, \code{.mu0_hat}, and \code{.mu1_hat}} 12 | } 13 | \value{ 14 | Returns NULL. Errors if a problem is discovered. 15 | } 16 | \description{ 17 | This helper function makes a few simple checks to identify obvious 18 | issues with the way that nuisance functions are created and prepared. 
19 | } 20 | \keyword{internal} 21 | -------------------------------------------------------------------------------- /man/add_known_propensity_score.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_known_propensity_score} 4 | \alias{add_known_propensity_score} 5 | \title{Uses a known propensity score} 6 | \usage{ 7 | add_known_propensity_score(hte_cfg, covariate_name) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{covariate_name}{Character indicating the name of the covariate 13 | name in the dataframe corresponding to the known propensity score.} 14 | } 15 | \value{ 16 | Updated \code{HTE_cfg} object 17 | } 18 | \description{ 19 | This replaces the propensity score model with a known value 20 | of the propensity score. 21 | } 22 | \examples{ 23 | library("dplyr") 24 | basic_config() \%>\% 25 | add_known_propensity_score("ps") -> hte_cfg 26 | } 27 | -------------------------------------------------------------------------------- /man/add_effect_diagnostic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_effect_diagnostic} 4 | \alias{add_effect_diagnostic} 5 | \title{Add an additional diagnostic to the effect model} 6 | \usage{ 7 | add_effect_diagnostic(hte_cfg, diag) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{diag}{Character indicating the name of the diagnostic 13 | to include. Possible values are \code{"MSE"}, \code{"RROC"} and, for 14 | \code{SuperLearner} ensembles, \code{"SL_risk"} and \code{"SL_coefs"}.} 15 | } 16 | \value{ 17 | Updated \code{HTE_cfg} object 18 | } 19 | \description{ 20 | This adds a diagnostic to the effect model. 
21 | } 22 | \examples{ 23 | library("dplyr") 24 | basic_config() \%>\% 25 | add_effect_diagnostic("RROC") -> hte_cfg 26 | } 27 | -------------------------------------------------------------------------------- /R/tidyhte-package.R: -------------------------------------------------------------------------------- 1 | #' @details 2 | #' The best place to get started with `tidyhte` is `vignette("experimental_analysis")` which 3 | #' walks through a full analysis of HTE on simulated data, or `vignette("methodological_details")` 4 | #' which gets into more of the details underlying the method. 5 | #' @seealso The core public-facing functions are `make_splits`, `produce_plugin_estimates`, 6 | #' `construct_pseudo_outcomes` and `estimate_QoI`. Configuration is accomplished through `HTE_cfg` 7 | #' in addition to a variety of related classes (see `basic_config`). 8 | #' @references Kennedy, E. H. (2023). Towards optimal doubly robust estimation of heterogeneous 9 | #' causal effects. *Electronic Journal of Statistics*, 17(2), 3008-3049. 10 | #' @keywords internal 11 | "_PACKAGE" 12 | 13 | if (getRversion() >= "2.15.1") utils::globalVariables(".data") 14 | -------------------------------------------------------------------------------- /man/add_outcome_diagnostic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_outcome_diagnostic} 4 | \alias{add_outcome_diagnostic} 5 | \title{Add an additional diagnostic to the outcome model} 6 | \usage{ 7 | add_outcome_diagnostic(hte_cfg, diag) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{diag}{Character indicating the name of the diagnostic 13 | to include. 
Possible values are \code{"MSE"}, \code{"RROC"} and, for 14 | \code{SuperLearner} ensembles, \code{"SL_risk"} and \code{"SL_coefs"}.} 15 | } 16 | \value{ 17 | Updated \code{HTE_cfg} object 18 | } 19 | \description{ 20 | This adds a diagnostic to the outcome model. 21 | } 22 | \examples{ 23 | library("dplyr") 24 | basic_config() \%>\% 25 | add_outcome_diagnostic("RROC") -> hte_cfg 26 | } 27 | -------------------------------------------------------------------------------- /man/add_propensity_diagnostic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_propensity_diagnostic} 4 | \alias{add_propensity_diagnostic} 5 | \title{Add an additional diagnostic to the propensity score} 6 | \usage{ 7 | add_propensity_diagnostic(hte_cfg, diag) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{diag}{Character indicating the name of the diagnostic 13 | to include. Possible values are \code{"MSE"}, \code{"AUC"} and, for 14 | \code{SuperLearner} ensembles, \code{"SL_risk"} and \code{"SL_coefs"}.} 15 | } 16 | \value{ 17 | Updated \code{HTE_cfg} object 18 | } 19 | \description{ 20 | This adds a diagnostic to the propensity score. 
21 | } 22 | \examples{ 23 | library("dplyr") 24 | basic_config() \%>\% 25 | add_propensity_diagnostic(c("AUC", "MSE")) -> hte_cfg 26 | } 27 | -------------------------------------------------------------------------------- /man/split_data.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{split_data} 4 | \alias{split_data} 5 | \title{Partition the data into folds} 6 | \usage{ 7 | split_data(data, split_id) 8 | } 9 | \arguments{ 10 | \item{data}{dataframe} 11 | 12 | \item{split_id}{integer representing the split to construct} 13 | } 14 | \value{ 15 | Returns an R6 object \code{HTEFold} with three public fields: 16 | \itemize{ 17 | \item \code{train} - The split to be used for training the plugin estimates 18 | \item \code{holdout} - The split not used for training 19 | \item \code{in_holdout} - A logical vector indicating for each unit whether they lie in the holdout. 20 | } 21 | } 22 | \description{ 23 | This takes a dataset and a split ID and generates two subsets of the 24 | data corresponding to a training set and a holdout. 25 | } 26 | \keyword{internal} 27 | -------------------------------------------------------------------------------- /man/listwise_deletion.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{listwise_deletion} 4 | \alias{listwise_deletion} 5 | \title{Removes rows which have missing data on any of the supplied columns.} 6 | \usage{ 7 | listwise_deletion(data, ...) 8 | } 9 | \arguments{ 10 | \item{data}{The dataset from which to drop cases which are not fully observed.} 11 | 12 | \item{...}{Unquoted column names which must be non-missing. Missingness in these 13 | columns will result in dropped observations. 
Missingness in other columns will not.} 14 | } 15 | \value{ 16 | The original data with all observations which are fully observed. 17 | } 18 | \description{ 19 | This function removes rows with missingness based on the columns provided. 20 | If rows are dropped, a message is displayed to the user to inform them of this 21 | fact. 22 | } 23 | \keyword{internal} 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report a problem with tidyhte 4 | title: '[BUG] ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the bug 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | ## Reproducible example 14 | 15 | Please provide a minimal reproducible example (reprex): 16 | 17 | ```r 18 | library(tidyhte) 19 | 20 | # Your code here that demonstrates the bug 21 | ``` 22 | 23 | ## Expected behavior 24 | 25 | What you expected to happen. 26 | 27 | ## Actual behavior 28 | 29 | What actually happened instead. 30 | 31 | ## Session info 32 | 33 | Please run `sessionInfo()` or `devtools::session_info()` and paste the output below: 34 | 35 | ```r 36 | # Paste session info here 37 | ``` 38 | 39 | ## Additional context 40 | 41 | Add any other context about the problem here (e.g., screenshots, related issues, potential causes). 
42 | -------------------------------------------------------------------------------- /man/figures/lifecycle-stable.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclestablestable -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an enhancement or new feature for tidyhte 4 | title: '[FEATURE] ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ## Feature description 10 | 11 | A clear and concise description of the feature you'd like to see. 12 | 13 | ## Motivation and use case 14 | 15 | Explain why this feature would be useful. What problem does it solve? What research questions would it help answer? 16 | 17 | ## Proposed solution 18 | 19 | If you have ideas about how this feature could be implemented, describe them here. 20 | 21 | ## Alternative approaches 22 | 23 | Are there alternative ways to achieve the same goal? Have you considered workarounds? 24 | 25 | ## Additional context 26 | 27 | Add any other context, examples, or references about the feature request here. 28 | 29 | ## Related work 30 | 31 | Are there other packages or implementations that have similar functionality? 
32 | -------------------------------------------------------------------------------- /man/figures/lifecycle-defunct.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledefunctdefunct -------------------------------------------------------------------------------- /man/figures/lifecycle-archived.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclearchivedarchived -------------------------------------------------------------------------------- /man/figures/lifecycle-maturing.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclematuringmaturing -------------------------------------------------------------------------------- /man/figures/lifecycle-deprecated.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledeprecateddeprecated -------------------------------------------------------------------------------- /man/figures/lifecycle-superseded.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclesupersededsuperseded -------------------------------------------------------------------------------- /tests/testthat/test-diagnostics.R: -------------------------------------------------------------------------------- 1 | test_that("SL model slot", { 2 | expect_error(SL_model_slot("test"), "Unknown model slot.") 3 | 4 | checkmate::expect_character(SL_model_slot(".pi_hat"), pattern = "^pi$") 5 | 6 | checkmate::expect_character(SL_model_slot(".mu1_hat"), pattern = "^mu1$") 7 | 8 | checkmate::expect_character(SL_model_slot(".mu0_hat"), pattern = "^mu0$") 9 | }) 10 | 11 | 12 | test_that("estimate_diagnostic", { 13 | df <- dplyr::tibble(y = rnorm(100), p = rep(0.5, 100)) 14 | expect_message( 15 | estimate_diagnostic(df, "y", "p", "AUC"), 16 | "Cannot calculate AUC because labels are not binary." 
17 | ) 18 | 19 | expect_message( 20 | estimate_diagnostic(df, "y", "p", "SL_risk"), 21 | "Cannot calculate SL_risk because the model is not SuperLearner." 22 | ) 23 | 24 | expect_message( 25 | estimate_diagnostic(df, "y", "p", "SL_coefs"), 26 | "Cannot calculate SL_coefs because the model is not SuperLearner." 27 | ) 28 | }) 29 | -------------------------------------------------------------------------------- /man/figures/lifecycle-experimental.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycleexperimentalexperimental -------------------------------------------------------------------------------- /man/figures/lifecycle-questioning.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclequestioningquestioning -------------------------------------------------------------------------------- /man/add_outcome_model.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_outcome_model} 4 | \alias{add_outcome_model} 5 | \title{Add an additional model to the outcome ensemble} 6 | \usage{ 7 | add_outcome_model(hte_cfg, model_name, ...) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{model_name}{Character indicating the name of the model to 13 | incorporate into the outcome ensemble. Possible values 14 | use \code{SuperLearner} naming conventions. A full list is available 15 | with \code{SuperLearner::listWrappers("SL")}} 16 | 17 | \item{...}{Parameters over which to grid-search for this model class.} 18 | } 19 | \value{ 20 | Updated \code{HTE_cfg} object 21 | } 22 | \description{ 23 | This adds a learner to the ensemble used for estimating a model 24 | of the conditional expectation of the outcome. 
25 | } 26 | \examples{ 27 | library("dplyr") 28 | basic_config() \%>\% 29 | add_outcome_model("SL.glm.interaction") -> hte_cfg 30 | } 31 | -------------------------------------------------------------------------------- /man/add_effect_model.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_effect_model} 4 | \alias{add_effect_model} 5 | \title{Add an additional model to the joint effect ensemble} 6 | \usage{ 7 | add_effect_model(hte_cfg, model_name, ...) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{model_name}{Character indicating the name of the model to 13 | incorporate into the joint effect ensemble. Possible values 14 | use \code{SuperLearner} naming conventions. A full list is available 15 | with \code{SuperLearner::listWrappers("SL")}} 16 | 17 | \item{...}{Parameters over which to grid-search for this model class.} 18 | } 19 | \value{ 20 | Updated \code{HTE_cfg} object 21 | } 22 | \description{ 23 | This adds a learner to the ensemble used for estimating a model 24 | of the conditional expectation of the pseudo-outcome. 25 | } 26 | \examples{ 27 | library("dplyr") 28 | basic_config() \%>\% 29 | add_effect_model("SL.glm.interaction") -> hte_cfg 30 | } 31 | -------------------------------------------------------------------------------- /man/predict.SL.glmnet.interaction.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/SL.glmnet.interaction.R 3 | \name{predict.SL.glmnet.interaction} 4 | \alias{predict.SL.glmnet.interaction} 5 | \title{Prediction for an SL.glmnet object} 6 | \usage{ 7 | \method{predict}{SL.glmnet.interaction}( 8 | object, 9 | newdata, 10 | remove_extra_cols = TRUE, 11 | add_missing_cols = TRUE, 12 | ... 
13 | ) 14 | } 15 | \arguments{ 16 | \item{object}{Result object from SL.glmnet} 17 | 18 | \item{newdata}{Dataframe or matrix that will generate predictions.} 19 | 20 | \item{remove_extra_cols}{Remove any extra columns in the new data that were 21 | not part of the original model.} 22 | 23 | \item{add_missing_cols}{Add any columns from original data that do not exist 24 | in the new data, and set values to 0.} 25 | 26 | \item{...}{Any additional arguments (not used).} 27 | } 28 | \description{ 29 | Prediction for the glmnet wrapper. 30 | } 31 | \seealso{ 32 | \link[SuperLearner:SL.glmnet]{SuperLearner::SL.glmnet} 33 | } 34 | -------------------------------------------------------------------------------- /man/add_propensity_score_model.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_propensity_score_model} 4 | \alias{add_propensity_score_model} 5 | \title{Add an additional model to the propensity score ensemble} 6 | \usage{ 7 | add_propensity_score_model(hte_cfg, model_name, ...) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{model_name}{Character indicating the name of the model to 13 | incorporate into the propensity score ensemble. Possible values 14 | use \code{SuperLearner} naming conventions. A full list is available 15 | with \code{SuperLearner::listWrappers("SL")}} 16 | 17 | \item{...}{Parameters over which to grid-search for this model class.} 18 | } 19 | \value{ 20 | Updated \code{HTE_cfg} object 21 | } 22 | \description{ 23 | This adds a learner to the ensemble used for estimating propensity 24 | scores. 
25 | } 26 | \examples{ 27 | library("dplyr") 28 | basic_config() \%>\% 29 | add_propensity_score_model("SL.glmnet", alpha = c(0, 0.5, 1)) -> hte_cfg 30 | } 31 | -------------------------------------------------------------------------------- /man/fit_plugin.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plugin-estimates.R 3 | \name{fit_plugin} 4 | \alias{fit_plugin} 5 | \title{Fits a plugin model using the appropriate settings} 6 | \usage{ 7 | fit_plugin(full_data, weight_col, outcome_col, ..., .Model_cfg) 8 | } 9 | \arguments{ 10 | \item{full_data}{The full dataset of interest for the modelling problem.} 11 | 12 | \item{weight_col}{The unquoted weighting variable name to use in model fitting.} 13 | 14 | \item{outcome_col}{The unquoted column name to use as a label for the supervised 15 | learning problem.} 16 | 17 | \item{...}{The unquoted names of covariates to use in the model.} 18 | 19 | \item{.Model_cfg}{A \code{Model_cfg} object configuring the appropriate model type to use.} 20 | } 21 | \value{ 22 | A new \code{Predictor} object of the appropriate subclass corresponding to the 23 | \code{Model_cfg} fit to the data. 24 | } 25 | \description{ 26 | This function prepares data, fits the appropriate models and returns the 27 | resulting estimates in a standardized format. 
28 | } 29 | \keyword{internal} 30 | -------------------------------------------------------------------------------- /man/fit_effect.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plugin-estimates.R 3 | \name{fit_effect} 4 | \alias{fit_effect} 5 | \title{Fits a treatment effect model using the appropriate settings} 6 | \usage{ 7 | fit_effect(full_data, weight_col, fx_col, ..., .Model_cfg) 8 | } 9 | \arguments{ 10 | \item{full_data}{The full dataset of interest for the modelling problem.} 11 | 12 | \item{weight_col}{The unquoted weighting variable name to use in model fitting.} 13 | 14 | \item{fx_col}{The unquoted column name of the pseudo-outcome.} 15 | 16 | \item{...}{The unquoted names of covariates to use in the model.} 17 | 18 | \item{.Model_cfg}{A \code{Model_cfg} object configuring the appropriate model type to use.} 19 | } 20 | \value{ 21 | A list with one element, \code{fx}. This element contains a \code{Predictor} object of 22 | the appropriate subclass corresponding to the \code{Model_cfg} fit to the data. 23 | } 24 | \description{ 25 | This function prepares data, fits the appropriate model and returns the 26 | resulting estimates in a standardized format. 
27 | } 28 | \keyword{internal} 29 | -------------------------------------------------------------------------------- /man/fit_plugin_A.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plugin-estimates.R 3 | \name{fit_plugin_A} 4 | \alias{fit_plugin_A} 5 | \title{Fits a propensity score model using the appropriate settings} 6 | \usage{ 7 | fit_plugin_A(full_data, weight_col, a_col, ..., .Model_cfg) 8 | } 9 | \arguments{ 10 | \item{full_data}{The full dataset of interest for the modelling problem.} 11 | 12 | \item{weight_col}{The unquoted weighting variable name to use in model fitting.} 13 | 14 | \item{a_col}{The unquoted column name of the treatment.} 15 | 16 | \item{...}{The unquoted names of covariates to use in the model.} 17 | 18 | \item{.Model_cfg}{A \code{Model_cfg} object configuring the appropriate model type to use.} 19 | } 20 | \value{ 21 | A list with one element, \code{ps}. This element contains a \code{Predictor} object of 22 | the appropriate subclass corresponding to the \code{Model_cfg} fit to the data. 23 | } 24 | \description{ 25 | This function prepares data, fits the appropriate model and returns the 26 | resulting estimates in a standardized format. 
27 | } 28 | \keyword{internal} 29 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2021 tidyhte authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # History files 2 | .Rhistory 3 | .Rapp.history 4 | 5 | # Session Data files 6 | .RData 7 | 8 | # Example code in package build process 9 | *-Ex.R 10 | 11 | # Output files from R CMD build 12 | /*.tar.gz 13 | 14 | # Output files from R CMD check 15 | /*.Rcheck/ 16 | 17 | # RStudio files 18 | .Rproj.user/ 19 | 20 | # produced vignettes 21 | vignettes/*.html 22 | vignettes/*.pdf 23 | 24 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 25 | .httr-oauth 26 | 27 | # knitr and R markdown default cache directories 28 | /*_cache/ 29 | /cache/ 30 | 31 | # Temporary files created by R markdown 32 | *.utf8.md 33 | *.knit.md 34 | 35 | # Shiny token, see https://shiny.rstudio.com/articles/shinyapps.html 36 | rsconnect/ 37 | inst/doc 38 | .DS_Store 39 | docs 40 | /doc/ 41 | /Meta/ 42 | *.fls 43 | *.aux 44 | *.log 45 | *.fdb_latexmk 46 | *.out 47 | *.synctex.gz 48 | .build.timestamp 49 | vignettes/*.tex 50 | vignettes/methodological-details.dvi 51 | vignettes/methodological-details.blg 52 | vignettes/methodological-details.bbl 53 | vignettes/figure/mcate-1.pdf 54 | vignettes/figure/discretemcate-1.pdf 55 | vignettes/figure/ctsmcate-1.pdf 56 | CRAN-SUBMISSION 57 | -------------------------------------------------------------------------------- /man/estimate_diagnostic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/diagnostics.R 3 | \name{estimate_diagnostic} 4 | \alias{estimate_diagnostic} 5 | \title{Function to calculate diagnostics based on model outputs} 6 | \usage{ 7 | estimate_diagnostic(data, label, prediction, diag_name, params) 8 | } 9 | \arguments{ 10 | \item{data}{The full data frame with all auxilliary columns.} 11 | 12 | \item{label}{The 
(string) column name for the labels to evaluate against.} 13 | 14 | \item{prediction}{The (string) column name of predictions from the model to diagnose.} 15 | 16 | \item{diag_name}{The (string) name of the diagnostic to calculate. Currently 17 | available are "AUC", "MSE", "SL_coefs", "SL_risk", "RROC"} 18 | 19 | \item{params}{Any other necessary options to pass to the given diagnostic.} 20 | } 21 | \description{ 22 | This function defines the calculations of common model diagnostics 23 | which are available. 24 | } 25 | \examples{ 26 | df <- dplyr::tibble(y = rbinom(100, 1, 0.5), p = rep(0.5, 100), w = rexp(100), u = 1:100) 27 | attr(df, "weights") <- "w" 28 | attr(df, "identifier") <- "u" 29 | estimate_diagnostic(df, "y", "p", "AUC") 30 | } 31 | \keyword{internal} 32 | -------------------------------------------------------------------------------- /man/fit_plugin_Y.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plugin-estimates.R 3 | \name{fit_plugin_Y} 4 | \alias{fit_plugin_Y} 5 | \title{Fits a T-learner using the appropriate settings} 6 | \usage{ 7 | fit_plugin_Y(full_data, weight_col, y_col, a_col, ..., .Model_cfg) 8 | } 9 | \arguments{ 10 | \item{full_data}{The full dataset of interest for the modelling problem.} 11 | 12 | \item{weight_col}{The unquoted weighting variable name to use in model fitting.} 13 | 14 | \item{y_col}{The unquoted column name of the outcome.} 15 | 16 | \item{a_col}{The unquoted column name of the treatment.} 17 | 18 | \item{...}{The unquoted names of covariates to use in the model.} 19 | 20 | \item{.Model_cfg}{A \code{Model_cfg} object configuring the appropriate model type to use.} 21 | } 22 | \value{ 23 | A list with two elements, \code{mu1} and \code{mu0} corresponding to the models fit to 24 | the treatment and control potential outcomes, respectively. 
Each is a new \code{Predictor} 25 | object of the appropriate subclass corresponding to the \code{Model_cfg} fit to the data. 26 | } 27 | \description{ 28 | This function prepares data, fits the appropriate model and returns the 29 | resulting estimates in a standardized format. 30 | } 31 | \keyword{internal} 32 | -------------------------------------------------------------------------------- /man/calculate_rroc.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/rroc.R 3 | \name{calculate_rroc} 4 | \alias{calculate_rroc} 5 | \title{Regression ROC Curve calculation} 6 | \usage{ 7 | calculate_rroc(label, prediction, nbins = 100) 8 | } 9 | \arguments{ 10 | \item{label}{True label} 11 | 12 | \item{prediction}{Model prediction of the label (out of sample)} 13 | 14 | \item{nbins}{Number of shift values to sweep over} 15 | } 16 | \value{ 17 | A tibble with \code{nbins} rows. 18 | } 19 | \description{ 20 | This function calculates the Regression ROC Curve 21 | of Hernández-Orallo 22 | \doi{10.1016/j.patcog.2013.06.014}. 23 | It provides estimates for the positive and negative 24 | errors when predictions are shifted by a variety 25 | of constants (which range across the domain of observed 26 | residuals). Curves closer to the axes are, in general, to be 27 | preferred. In general, this curve provides a simple way to 28 | visualize the error properties of a regression model. 29 | } 30 | \details{ 31 | The dot shows the errors when no shift is applied, corresponding 32 | to the base model predictions. 33 | } 34 | \references{ 35 | Hernández-Orallo, J. (2013). ROC curves for regression. 36 | Pattern Recognition, 46(12), 3395-3411. 
37 | } 38 | \keyword{internal} 39 | -------------------------------------------------------------------------------- /man/calculate_ate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/qoi.R 3 | \name{calculate_ate} 4 | \alias{calculate_ate} 5 | \title{Calculates a SATE and a PATE using AIPW} 6 | \usage{ 7 | calculate_ate(data) 8 | } 9 | \arguments{ 10 | \item{data}{The dataset of interest after it has been prepared fully.} 11 | } 12 | \description{ 13 | This function takes fully prepared data (with all auxiliary columns from the 14 | necessary models) and estimates average treatment effects using AIPW. 15 | } 16 | \references{ 17 | \itemize{ 18 | \item Kennedy, E. H. (2023). Towards optimal doubly robust estimation of heterogeneous 19 | causal effects. \emph{Electronic Journal of Statistics}, 17(2), 3008-3049. 20 | \item Tsiatis, A. A., Davidian, M., Zhang, M., & Lu, X. (2008). Covariate adjustment 21 | for two‐sample treatment comparisons in randomized clinical trials: a principled 22 | yet flexible approach. \emph{Statistics in medicine}, 27(23), 4658-4677. 
23 | } 24 | } 25 | \seealso{ 26 | \code{\link[=basic_config]{basic_config()}}, \code{\link[=attach_config]{attach_config()}}, \code{\link[=make_splits]{make_splits()}}, \code{\link[=produce_plugin_estimates]{produce_plugin_estimates()}}, 27 | \code{\link[=construct_pseudo_outcomes]{construct_pseudo_outcomes()}}, \code{\link[=estimate_QoI]{estimate_QoI()}} 28 | } 29 | \keyword{internal} 30 | -------------------------------------------------------------------------------- /man/calculate_diagnostics.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/diagnostics.R 3 | \name{calculate_diagnostics} 4 | \alias{calculate_diagnostics} 5 | \title{Calculate diagnostics} 6 | \usage{ 7 | calculate_diagnostics(data, treatment, outcome, .diag.cfg) 8 | } 9 | \arguments{ 10 | \item{data}{Data frame with all additional columns (such as model predictions) included.} 11 | 12 | \item{treatment}{Unquoted treatment variable name} 13 | 14 | \item{outcome}{Unquoted outcome variable name} 15 | 16 | \item{.diag.cfg}{\code{Diagnostics_cfg} object} 17 | } 18 | \value{ 19 | Returns a tibble with columns: 20 | \itemize{ 21 | \item \code{estimand} - Character indicating the diagnostic that was calculated 22 | \item \code{level} - Indicates the scope of this diagnostic (e.g. does it apply 23 | only to the model of the outcome under treatment). 24 | \item \code{term} - Indicates a more granular descriptor of what the value is for, 25 | such as the specific model within the SuperLearner ensemble. 26 | \item \code{estimate} - Point estimate of the diagnostic. 27 | \item \code{std_error} - Standard error of the diagnostic. 28 | } 29 | } 30 | \description{ 31 | This function calculates the diagnostics requested by the \code{Diagnostics_cfg} object. 
32 | } 33 | \seealso{ 34 | \link{Diagnostics_cfg} 35 | } 36 | \keyword{internal} 37 | -------------------------------------------------------------------------------- /man/fit_fx_predictor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/joint-effect-estimation.R 3 | \name{fit_fx_predictor} 4 | \alias{fit_fx_predictor} 5 | \title{Fit a predictor for treatment effects} 6 | \usage{ 7 | fit_fx_predictor(full_data, weights, psi_col, ..., .pcate.cfg, .Model_cfg) 8 | } 9 | \arguments{ 10 | \item{full_data}{The full original data with all auxiliary columns.} 11 | 12 | \item{weights}{Weights to be used in the analysis.} 13 | 14 | \item{psi_col}{The unquoted column name of the calculated pseudo-outcome.} 15 | 16 | \item{...}{Covariate data, passed in as the unquoted names of columns in \code{full_data}} 17 | 18 | \item{.pcate.cfg}{A \code{PCATE_cfg} object describing what PCATEs to calculate (and how)} 19 | 20 | \item{.Model_cfg}{A \code{Model_cfg} object describing how the effect model should be estimated.} 21 | } 22 | \value{ 23 | A list with two items: 24 | \itemize{ 25 | \item \code{model} - The \code{FX.Predictor} model object used internally for PCATE estimation. 26 | \item \code{data} - The data augmented with column \code{.pseudo_outcome_hat} for the cross-fit predictions 27 | of the HTE for each unit. 28 | } 29 | } 30 | \description{ 31 | This function predicts treatment effects in a second stage model. 
32 | } 33 | \seealso{ 34 | \link{Model_cfg}, \link{PCATE_cfg} 35 | } 36 | \keyword{internal} 37 | -------------------------------------------------------------------------------- /man/add_moderator.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_moderator} 4 | \alias{add_moderator} 5 | \title{Adds moderators to the configuration} 6 | \usage{ 7 | add_moderator(hte_cfg, model_type, ..., .model_arguments = NULL) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{model_type}{Character indicating the model type for these moderators. 13 | Currently two model types are supported: \code{"Stratified"} for discrete moderators 14 | and \code{"KernelSmooth"} for continuous ones.} 15 | 16 | \item{...}{The (unquoted) names of the moderator variables.} 17 | 18 | \item{.model_arguments}{A named list from argument name to value to pass into the 19 | constructor for the model. See \code{Stratified_cfg} and \code{KernelSmooth_cfg} for more details.} 20 | } 21 | \value{ 22 | Updated \code{HTE_cfg} object 23 | } 24 | \description{ 25 | This adds a definition about how to display a moderators to 26 | the MCATE config. A moderator is any variable that you want to view information 27 | about CATEs with respect to. 28 | } 29 | \note{ 30 | For moderators with many levels and limited sample per level, estimates may be noisy. 31 | Consider whether other encodings would be more appropriate. 
32 | } 33 | \examples{ 34 | library("dplyr") 35 | basic_config() \%>\% 36 | add_moderator("Stratified", x2, x3) \%>\% 37 | add_moderator("KernelSmooth", x1, x4, x5) -> hte_cfg 38 | } 39 | -------------------------------------------------------------------------------- /man/construct_pseudo_outcomes.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/qoi.R 3 | \name{construct_pseudo_outcomes} 4 | \alias{construct_pseudo_outcomes} 5 | \title{Construct Pseudo-outcomes} 6 | \usage{ 7 | construct_pseudo_outcomes(data, outcome, treatment, type = "dr") 8 | } 9 | \arguments{ 10 | \item{data}{dataframe (already prepared with \code{attach_config}, \code{make_splits}, 11 | and \code{produce_plugin_estimates})} 12 | 13 | \item{outcome}{Unquoted name of outcome variable.} 14 | 15 | \item{treatment}{Unquoted name of treatment variable.} 16 | 17 | \item{type}{String representing how to construct the pseudo-outcome. Valid 18 | values are "dr" (the default), "ipw" and "plugin". See "Details" for more 19 | discussion of these options.} 20 | } 21 | \description{ 22 | \code{construct_pseudo_outcomes} takes a dataset which has been prepared 23 | with plugin estimators of nuisance parameters and transforms these into 24 | a "pseudo-outcome": an unbiased estimator of the conditional average 25 | treatment effect under exogeneity. 26 | } 27 | \details{ 28 | Taking averages of these pseudo-outcomes (or fitting a model to them) 29 | will approximate averages (or models) of the underlying treatment effect. 
30 | } 31 | \seealso{ 32 | \code{\link[=attach_config]{attach_config()}}, \code{\link[=make_splits]{make_splits()}}, \code{\link[=produce_plugin_estimates]{produce_plugin_estimates()}}, \code{\link[=estimate_QoI]{estimate_QoI()}} 33 | } 34 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | pull_request: 7 | branches: 8 | - main 9 | - master 10 | 11 | name: lint 12 | 13 | jobs: 14 | lint: 15 | runs-on: macOS-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | 23 | - name: Query dependencies 24 | run: | 25 | install.packages('remotes') 26 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 27 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 28 | shell: Rscript {0} 29 | 30 | - name: Restore R package cache 31 | uses: actions/cache@v3 32 | with: 33 | path: ${{ env.R_LIBS_USER }} 34 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 35 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 36 | 37 | - name: Install dependencies 38 | run: | 39 | install.packages(c("remotes")) 40 | remotes::install_deps(dependencies = TRUE) 41 | remotes::install_cran("lintr") 42 | shell: Rscript {0} 43 | 44 | - name: Install package 45 | run: R CMD INSTALL . 
46 | 47 | - name: Lint 48 | run: lintr::lint_package() 49 | shell: Rscript {0} 50 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | pull_request: 7 | branches: 8 | - main 9 | - master 10 | 11 | name: test-coverage 12 | 13 | jobs: 14 | test-coverage: 15 | runs-on: macOS-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | 23 | - uses: r-lib/actions/setup-pandoc@v2 24 | 25 | - name: Query dependencies 26 | run: | 27 | install.packages('remotes') 28 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 29 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 30 | shell: Rscript {0} 31 | 32 | - name: Restore R package cache 33 | uses: actions/cache@v3 34 | with: 35 | path: ${{ env.R_LIBS_USER }} 36 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 37 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 38 | 39 | - name: Install dependencies 40 | run: | 41 | install.packages(c("remotes")) 42 | remotes::install_deps(dependencies = TRUE) 43 | remotes::install_cran("covr") 44 | shell: Rscript {0} 45 | 46 | - name: Test coverage 47 | run: covr::codecov(token = "${{ secrets.CODECOV_TOKEN }}") 48 | shell: Rscript {0} -------------------------------------------------------------------------------- /man/add_vimp.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{add_vimp} 4 | \alias{add_vimp} 5 | \title{Adds variable importance information} 6 | \usage{ 7 | add_vimp(hte_cfg, 
sample_splitting = TRUE, linear_only = FALSE) 8 | } 9 | \arguments{ 10 | \item{hte_cfg}{\code{HTE_cfg} object to update.} 11 | 12 | \item{sample_splitting}{Logical indicating whether to use sample splitting or not. 13 | Choosing not to use sample splitting means that inference will only be valid for 14 | moderators with non-null importance.} 15 | 16 | \item{linear_only}{Logical indicating whether the variable importance should use only a single 17 | linear-only model. Variable importance measure will only be consistent for the population 18 | quantity if the true model of pseudo-outcomes is linear.} 19 | } 20 | \value{ 21 | Updated \code{HTE_cfg} object 22 | } 23 | \description{ 24 | This adds a variable importance quantity of interest to the outputs. 25 | } 26 | \examples{ 27 | library("dplyr") 28 | basic_config() \%>\% 29 | add_vimp(sample_splitting = FALSE) -> hte_cfg 30 | } 31 | \references{ 32 | \itemize{ 33 | \item Williamson, B. D., Gilbert, P. B., Carone, M., & Simon, N. (2021). 34 | Nonparametric variable importance assessment using machine learning techniques. 35 | Biometrics, 77(1), 9-22. 36 | \item Williamson, B. D., Gilbert, P. B., Simon, N. R., & Carone, M. (2021). 37 | A general framework for inference on algorithm-agnostic variable importance. 38 | Journal of the American Statistical Association, 1-14. 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /man/tidyhte-errors.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/errors.R 3 | \name{tidyhte-errors} 4 | \alias{tidyhte-errors} 5 | \title{Error utilities for tidyhte package} 6 | \description{ 7 | These functions provide a standardized way to throw classed errors throughout 8 | the tidyhte package, replacing basic \code{stop()} calls with structured error 9 | conditions that can be caught and handled programmatically. 
10 | } 11 | \details{ 12 | The tidyhte package uses a hierarchical error class system based on rlang's 13 | structured condition handling. All tidyhte errors inherit from the base class 14 | \code{tidyhte_error_general} and include more specific subclasses: 15 | \itemize{ 16 | \item \code{tidyhte_error_config} - Configuration-related errors 17 | \item \code{tidyhte_error_model} - Model-related errors 18 | \item \code{tidyhte_error_data} - Data validation errors 19 | \item \code{tidyhte_error_not_implemented} - Not implemented functionality 20 | \item \code{tidyhte_error_package} - Package dependency errors 21 | } 22 | 23 | These classed errors can be caught using \code{tryCatch()} or \code{rlang::catch_cnd()} 24 | with the specific error class for more precise error handling. 25 | } 26 | \examples{ 27 | \dontrun{ 28 | # Catching specific error types 29 | tryCatch( 30 | some_tidyhte_function(), 31 | tidyhte_error_config = function(e) cat("Configuration error:", conditionMessage(e)), 32 | tidyhte_error_data = function(e) cat("Data error:", conditionMessage(e)) 33 | ) 34 | } 35 | 36 | } 37 | \keyword{internal} 38 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | name: CRAN-like-check 2 | on: 3 | push: 4 | branches: [main, master] 5 | pull_request: 6 | branches: [main, master] 7 | schedule: 8 | - cron: "0 4 * * *" 9 | 10 | jobs: 11 | R-CMD-check: 12 | runs-on: ${{ matrix.config.os }} 13 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | config: 19 | - {os: windows-latest, r: 'devel'} 20 | - {os: macOS-latest, r: 'release'} 21 | - {os: ubuntu-latest, r: 'devel'} 22 | - {os: ubuntu-latest, r: 'release'} 23 | 24 | env: 25 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 26 | R_KEEP_PKG_SOURCE: yes 27 | # CRAN-specific environment variables 28 | 
_R_CHECK_CRAN_INCOMING_: true 29 | _R_CHECK_CRAN_INCOMING_REMOTE_: false 30 | _R_CHECK_FORCE_SUGGESTS_: false 31 | _R_CHECK_PACKAGES_USED_IN_TESTS_USE_SUBDIRS_: true 32 | _R_CHECK_CRAN_INCOMING_USE_ASPELL_: true 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - uses: r-lib/actions/setup-pandoc@v2 38 | 39 | - uses: r-lib/actions/setup-r@v2 40 | with: 41 | r-version: ${{ matrix.config.r }} 42 | http-user-agent: ${{ matrix.config.http-user-agent }} 43 | use-public-rspm: true 44 | 45 | - uses: r-lib/actions/setup-r-dependencies@v2 46 | with: 47 | extra-packages: any::rcmdcheck 48 | needs: check 49 | 50 | - uses: r-lib/actions/check-r-package@v2 51 | with: 52 | upload-snapshots: true 53 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | type: software 4 | title: 'tidyhte: Tidy Estimation of Heterogeneous Treatment Effects' 5 | version: 1.0.3 6 | doi: 10.5281/zenodo.6325442 7 | date-released: 2024-09-01 8 | url: "https://github.com/ddimmery/tidyhte" 9 | repository-code: "https://github.com/ddimmery/tidyhte" 10 | authors: 11 | - family-names: "Dimmery" 12 | given-names: "Drew" 13 | orcid: "https://orcid.org/0000-0001-9602-6325" 14 | email: "cran@ddimmery.com" 15 | abstract: >- 16 | Estimates heterogeneous treatment effects using tidy semantics 17 | on experimental or observational data. Methods are based on the 18 | doubly-robust learner of Kennedy (2023). You provide a simple 19 | recipe for what machine learning algorithms to use in estimating 20 | the nuisance functions and 'tidyhte' will take care of 21 | cross-validation, estimation, model selection, diagnostics and 22 | construction of relevant quantities of interest about the 23 | variability of treatment effects. 
24 | keywords: 25 | - R 26 | - causal inference 27 | - heterogeneous treatment effects 28 | - machine learning 29 | - experimental design 30 | - doubly-robust estimation 31 | license: MIT 32 | references: 33 | - type: article 34 | title: "Towards optimal doubly robust estimation of heterogeneous causal effects" 35 | authors: 36 | - family-names: "Kennedy" 37 | given-names: "Edward H." 38 | journal: "Electronic Journal of Statistics" 39 | volume: 17 40 | issue: 2 41 | year: 2023 42 | start: 3008 43 | end: 3049 44 | doi: "10.1214/23-EJS2157" 45 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | tags: 7 | - '*' 8 | 9 | name: pkgdown 10 | 11 | jobs: 12 | pkgdown: 13 | runs-on: macOS-latest 14 | env: 15 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - uses: r-lib/actions/setup-r@v2 20 | 21 | - uses: r-lib/actions/setup-pandoc@v2 22 | 23 | - name: Query dependencies 24 | run: | 25 | install.packages('remotes') 26 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 27 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 28 | shell: Rscript {0} 29 | 30 | - name: Restore R package cache 31 | uses: actions/cache@v3 32 | with: 33 | path: ${{ env.R_LIBS_USER }} 34 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 35 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 36 | 37 | - name: Install dependencies 38 | run: | 39 | remotes::install_deps(dependencies = TRUE) 40 | install.packages("pkgdown", type = "binary") 41 | shell: Rscript {0} 42 | 43 | - name: Install package 44 | run: R CMD INSTALL . 
45 | 46 | - name: Deploy package 47 | run: | 48 | git config --local user.email "actions@github.com" 49 | git config --local user.name "GitHub Actions" 50 | Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)' 51 | -------------------------------------------------------------------------------- /man/calculate_vimp.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/variable_importance.R 3 | \name{calculate_vimp} 4 | \alias{calculate_vimp} 5 | \title{Calculate Variable Importance of HTEs} 6 | \usage{ 7 | calculate_vimp( 8 | full_data, 9 | weight_col, 10 | pseudo_outcome, 11 | ..., 12 | .VIMP_cfg, 13 | .Model_cfg 14 | ) 15 | } 16 | \arguments{ 17 | \item{full_data}{dataframe} 18 | 19 | \item{weight_col}{Unquoted name of the weight column.} 20 | 21 | \item{pseudo_outcome}{Unquoted name of the pseudo-outcome.} 22 | 23 | \item{...}{Unquoted names of covariates to include in the joint effect model. 24 | The variable importance will be calculated for each of these covariates.} 25 | 26 | \item{.VIMP_cfg}{A \code{VIMP_cfg} object defining how VIMP should be estimated.} 27 | 28 | \item{.Model_cfg}{A \code{Model_cfg} object defining how the joint effect model should be estimated.} 29 | } 30 | \description{ 31 | \code{calculate_vimp} estimates the reduction in (population) $R^2$ from 32 | removing a particular moderator from a model containing all moderators. 33 | } 34 | \references{ 35 | \itemize{ 36 | \item Williamson, B. D., Gilbert, P. B., Carone, M., & Simon, N. (2021). 37 | Nonparametric variable importance assessment using machine learning techniques. 38 | Biometrics, 77(1), 9-22. 39 | \item Williamson, B. D., Gilbert, P. B., Simon, N. R., & Carone, M. (2021). 40 | A general framework for inference on algorithm-agnostic variable importance. 41 | Journal of the American Statistical Association, 1-14. 
42 | } 43 | } 44 | \seealso{ 45 | \code{\link[=calculate_linear_vimp]{calculate_linear_vimp()}} 46 | } 47 | \keyword{internal} 48 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: tidyhte 2 | Title: Tidy Estimation of Heterogeneous Treatment Effects 3 | Version: 1.0.4 4 | Authors@R: 5 | person(given = "Drew", 6 | family = "Dimmery", 7 | role = c("aut", "cre", "cph"), 8 | email = "cran@ddimmery.com", 9 | comment = c(ORCID = "0000-0001-9602-6325")) 10 | Description: Estimates heterogeneous treatment effects using tidy semantics 11 | on experimental or observational data. Methods are based on the doubly-robust 12 | learner of Kennedy (2023) . You provide a simple 13 | recipe for what machine learning algorithms to use in estimating the nuisance 14 | functions and 'tidyhte' will take care of cross-validation, estimation, model 15 | selection, diagnostics and construction of relevant quantities of interest about 16 | the variability of treatment effects. 
17 | URL: https://github.com/ddimmery/tidyhte https://ddimmery.github.io/tidyhte/index.html 18 | BugReports: https://github.com/ddimmery/tidyhte/issues 19 | License: MIT + file LICENSE 20 | Encoding: UTF-8 21 | LazyData: true 22 | Roxygen: list(markdown = TRUE) 23 | RoxygenNote: 7.3.2 24 | Suggests: 25 | covr, 26 | devtools, 27 | estimatr, 28 | ggplot2, 29 | glmnet, 30 | knitr, 31 | mockr, 32 | nprobust, 33 | palmerpenguins, 34 | quadprog, 35 | quickblock, 36 | rmarkdown, 37 | testthat (>= 3.0.0), 38 | vimp, 39 | WeightedROC 40 | Config/testthat/edition: 3 41 | Imports: 42 | checkmate, 43 | dplyr, 44 | lifecycle, 45 | magrittr, 46 | progress, 47 | purrr, 48 | R6, 49 | rlang, 50 | SuperLearner, 51 | tibble 52 | VignetteBuilder: knitr 53 | -------------------------------------------------------------------------------- /man/tidyhte-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tidyhte-package.R 3 | \docType{package} 4 | \name{tidyhte-package} 5 | \alias{tidyhte} 6 | \alias{tidyhte-package} 7 | \title{tidyhte: Tidy Estimation of Heterogeneous Treatment Effects} 8 | \description{ 9 | Estimates heterogeneous treatment effects using tidy semantics on experimental or observational data. Methods are based on the doubly-robust learner of Kennedy (2023) \doi{10.1214/23-EJS2157}. You provide a simple recipe for what machine learning algorithms to use in estimating the nuisance functions and 'tidyhte' will take care of cross-validation, estimation, model selection, diagnostics and construction of relevant quantities of interest about the variability of treatment effects. 
10 | } 11 | \details{ 12 | The best place to get started with \code{tidyhte} is \code{vignette("experimental_analysis")} which 13 | walks through a full analysis of HTE on simulated data, or \code{vignette("methodological_details")} 14 | which gets into more of the details underlying the method. 15 | } 16 | \references{ 17 | Kennedy, E. H. (2023). Towards optimal doubly robust estimation of heterogeneous 18 | causal effects. \emph{Electronic Journal of Statistics}, 17(2), 3008-3049. 19 | } 20 | \seealso{ 21 | The core public-facing functions are \code{make_splits}, \code{produce_plugin_estimates}, 22 | \code{construct_pseudo_outcomes} and \code{estimate_QoI}. Configuration is accomplished through \code{HTE_cfg} 23 | in addition to a variety of related classes (see \code{basic_config}). 24 | } 25 | \author{ 26 | \strong{Maintainer}: Drew Dimmery \email{cran@ddimmery.com} (\href{https://orcid.org/0000-0001-9602-6325}{ORCID}) [copyright holder] 27 | 28 | } 29 | \keyword{internal} 30 | -------------------------------------------------------------------------------- /man/Model_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{Model_cfg} 4 | \alias{Model_cfg} 5 | \title{Base Class of Model Configurations} 6 | \description{ 7 | \code{Model_cfg} is the base class from which all other model configurations 8 | inherit. 9 | } 10 | \section{Public fields}{ 11 | \if{html}{\out{
}} 12 | \describe{ 13 | \item{\code{model_class}}{The class of the model, required for all classes 14 | which inherit from \code{Model_cfg}.} 15 | } 16 | \if{html}{\out{
}} 17 | } 18 | \section{Methods}{ 19 | \subsection{Public methods}{ 20 | \itemize{ 21 | \item \href{#method-Model_cfg-new}{\code{Model_cfg$new()}} 22 | \item \href{#method-Model_cfg-clone}{\code{Model_cfg$clone()}} 23 | } 24 | } 25 | \if{html}{\out{
}} 26 | \if{html}{\out{}} 27 | \if{latex}{\out{\hypertarget{method-Model_cfg-new}{}}} 28 | \subsection{Method \code{new()}}{ 29 | Create a new \code{Model_cfg} object with any necessary parameters. 30 | \subsection{Usage}{ 31 | \if{html}{\out{
}}\preformatted{Model_cfg$new()}\if{html}{\out{
}} 32 | } 33 | 34 | \subsection{Returns}{ 35 | A new \code{Model_cfg} object. 36 | } 37 | } 38 | \if{html}{\out{
}} 39 | \if{html}{\out{}} 40 | \if{latex}{\out{\hypertarget{method-Model_cfg-clone}{}}} 41 | \subsection{Method \code{clone()}}{ 42 | The objects of this class are cloneable with this method. 43 | \subsection{Usage}{ 44 | \if{html}{\out{
}}\preformatted{Model_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 45 | } 46 | 47 | \subsection{Arguments}{ 48 | \if{html}{\out{
}} 49 | \describe{ 50 | \item{\code{deep}}{Whether to make a deep clone.} 51 | } 52 | \if{html}{\out{
}} 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /tests/testthat/test-recipe-api.R: -------------------------------------------------------------------------------- 1 | test_that("basic config shortcut", { 2 | cfg <- basic_config() 3 | checkmate::expect_r6(cfg, "HTE_cfg") 4 | }) 5 | 6 | cfg <- basic_config() 7 | 8 | test_that("recipe manipulations on ps", { 9 | checkmate::expect_r6(add_propensity_score_model(cfg, "SL.glmnet"), "HTE_cfg") 10 | 11 | checkmate::expect_r6(add_known_propensity_score(cfg, "pscore"), "HTE_cfg") 12 | 13 | checkmate::expect_r6(add_propensity_diagnostic(cfg, "MSE"), "HTE_cfg") 14 | }) 15 | 16 | test_that("recipe manipulations on outcome", { 17 | checkmate::expect_r6(add_outcome_model(cfg, "SL.glmnet"), "HTE_cfg") 18 | 19 | checkmate::expect_r6(add_outcome_diagnostic(cfg, "MSE"), "HTE_cfg") 20 | }) 21 | 22 | test_that("recipe manipulations on effect", { 23 | checkmate::expect_r6(add_effect_model(cfg, "SL.glmnet"), "HTE_cfg") 24 | 25 | checkmate::expect_r6(add_effect_diagnostic(cfg, "MSE"), "HTE_cfg") 26 | }) 27 | 28 | test_that("recipe manipulations on moderators", { 29 | checkmate::expect_r6(add_moderator(cfg, "Stratified", x1), "HTE_cfg") 30 | 31 | expect_error(add_moderator(cfg, "unknown", x1), "Unknown `model_type`.") 32 | 33 | checkmate::expect_r6( 34 | add_moderator(cfg, "KernelSmooth", x2, .model_arguments = rlang::list2(neval = 50)), 35 | "HTE_cfg" 36 | ) 37 | 38 | checkmate::expect_r6(add_moderator(cfg, "KernelSmooth", x2), "HTE_cfg") 39 | }) 40 | 41 | test_that("recipe manipulations on vimp", { 42 | checkmate::expect_r6(add_vimp(cfg, sample_splitting = FALSE), "HTE_cfg") 43 | }) 44 | 45 | test_that("init treatment", { 46 | cfg$treatment <- NULL 47 | checkmate::expect_r6(add_propensity_score_model(cfg, "SL.glmnet"), "HTE_cfg") 48 | }) 49 | 50 | test_that("init outcome", { 51 | cfg$outcome <- NULL 52 | checkmate::expect_r6(add_outcome_model(cfg, "SL.glmnet"), "HTE_cfg") 53 | 
}) 54 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(predict,SL.glmnet.interaction) 4 | export(Constant_cfg) 5 | export(Diagnostics_cfg) 6 | export(HTE_cfg) 7 | export(KernelSmooth_cfg) 8 | export(Known_cfg) 9 | export(MCATE_cfg) 10 | export(Model_cfg) 11 | export(Model_data) 12 | export(PCATE_cfg) 13 | export(QoI_cfg) 14 | export(SL.glmnet.interaction) 15 | export(SLEnsemble_cfg) 16 | export(SLLearner_cfg) 17 | export(Stratified_cfg) 18 | export(VIMP_cfg) 19 | export(add_effect_diagnostic) 20 | export(add_effect_model) 21 | export(add_known_propensity_score) 22 | export(add_moderator) 23 | export(add_outcome_diagnostic) 24 | export(add_outcome_model) 25 | export(add_propensity_diagnostic) 26 | export(add_propensity_score_model) 27 | export(add_vimp) 28 | export(attach_config) 29 | export(basic_config) 30 | export(construct_pseudo_outcomes) 31 | export(estimate_QoI) 32 | export(estimate_diagnostic) 33 | export(make_splits) 34 | export(produce_plugin_estimates) 35 | export(remove_vimp) 36 | import(SuperLearner) 37 | importFrom(R6,R6Class) 38 | importFrom(dplyr,"%>%") 39 | importFrom(dplyr,bind_rows) 40 | importFrom(dplyr,left_join) 41 | importFrom(dplyr,matches) 42 | importFrom(dplyr,select) 43 | importFrom(dplyr,summarize) 44 | importFrom(dplyr,tibble) 45 | importFrom(magrittr,"%>%") 46 | importFrom(progress,progress_bar) 47 | importFrom(purrr,map) 48 | importFrom(rlang,.env) 49 | importFrom(rlang,check_installed) 50 | importFrom(rlang,enexpr) 51 | importFrom(rlang,enexprs) 52 | importFrom(rlang,list2) 53 | importFrom(stats,complete.cases) 54 | importFrom(stats,lm) 55 | importFrom(stats,model.frame) 56 | importFrom(stats,model.matrix) 57 | importFrom(stats,predict) 58 | importFrom(stats,quantile) 59 | importFrom(stats,residuals) 60 | importFrom(stats,sd) 61 | 
importFrom(stats,weighted.mean) 62 | importFrom(tibble,as_tibble) 63 | -------------------------------------------------------------------------------- /man/calculate_linear_vimp.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/variable_importance.R 3 | \name{calculate_linear_vimp} 4 | \alias{calculate_linear_vimp} 5 | \title{Calculate Linear Variable Importance of HTEs} 6 | \usage{ 7 | calculate_linear_vimp( 8 | full_data, 9 | weight_col, 10 | pseudo_outcome, 11 | ..., 12 | .VIMP_cfg, 13 | .Model_cfg 14 | ) 15 | } 16 | \arguments{ 17 | \item{full_data}{dataframe} 18 | 19 | \item{weight_col}{Unquoted name of the weight column.} 20 | 21 | \item{pseudo_outcome}{Unquoted name of the pseudo-outcome.} 22 | 23 | \item{...}{Unquoted names of covariates to include in the joint effect model. 24 | The variable importance will be calculated for each of these covariates.} 25 | 26 | \item{.VIMP_cfg}{A \code{VIMP_cfg} object defining how VIMP should be estimated.} 27 | 28 | \item{.Model_cfg}{A \code{Model_cfg} object defining how the joint effect model should be estimated.} 29 | } 30 | \description{ 31 | \code{calculate_linear_vimp} estimates the linear hypothesis test of removing a particular moderator 32 | from a linear model containing all moderators. Unlike \code{calculate_vimp}, this will only be 33 | unbiased and have correct asymptotic coverage rates if the true model is linear. This linear 34 | approach is also substantially faster, so may be useful when prototyping an analysis. 35 | } 36 | \references{ 37 | \itemize{ 38 | \item Williamson, B. D., Gilbert, P. B., Carone, M., & Simon, N. (2021). 39 | Nonparametric variable importance assessment using machine learning techniques. 40 | Biometrics, 77(1), 9-22. 41 | \item Williamson, B. D., Gilbert, P. B., Simon, N. R., & Carone, M. (2021). 
42 | A general framework for inference on algorithm-agnostic variable importance. 43 | Journal of the American Statistical Association, 1-14. 44 | } 45 | } 46 | \seealso{ 47 | \code{\link[=calculate_vimp]{calculate_vimp()}} 48 | } 49 | \keyword{internal} 50 | -------------------------------------------------------------------------------- /man/estimate_QoI.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/public_api.R 3 | \name{estimate_QoI} 4 | \alias{estimate_QoI} 5 | \title{Estimate Quantities of Interest} 6 | \usage{ 7 | estimate_QoI(data, ...) 8 | } 9 | \arguments{ 10 | \item{data}{data frame (already prepared with \code{attach_config}, \code{make_splits}, 11 | \code{produce_plugin_estimates} and \code{construct_pseudo_outcomes})} 12 | 13 | \item{...}{Unquoted names of moderators to calculate QoIs for.} 14 | } 15 | \description{ 16 | \code{estimate_QoI} takes a dataframe already prepared with split IDs, 17 | plugin estimates and pseudo-outcomes and calculates the requested 18 | quantities of interest (QoIs). 19 | } 20 | \details{ 21 | To see an example analysis, read \code{vignette("experimental_analysis")} in the context 22 | of an experiment, \code{vignette("experimental_analysis")} for an observational study, or 23 | \code{vignette("methodological_details")} for a deeper dive under the hood. 
24 | } 25 | \examples{ 26 | library("dplyr") 27 | if(require("palmerpenguins")) { 28 | data(package = 'palmerpenguins') 29 | penguins$unitid = seq_len(nrow(penguins)) 30 | penguins$propensity = rep(0.5, nrow(penguins)) 31 | penguins$treatment = rbinom(nrow(penguins), 1, penguins$propensity) 32 | cfg <- basic_config() \%>\% 33 | add_known_propensity_score("propensity") \%>\% 34 | add_outcome_model("SL.glm.interaction") \%>\% 35 | remove_vimp() 36 | attach_config(penguins, cfg) \%>\% 37 | make_splits(unitid, .num_splits = 4) \%>\% 38 | produce_plugin_estimates(outcome = body_mass_g, treatment = treatment, species, sex) \%>\% 39 | construct_pseudo_outcomes(body_mass_g, treatment) \%>\% 40 | estimate_QoI(species, sex) 41 | } 42 | } 43 | \seealso{ 44 | \code{\link[=attach_config]{attach_config()}}, \code{\link[=make_splits]{make_splits()}}, \code{\link[=produce_plugin_estimates]{produce_plugin_estimates()}}, 45 | \code{\link[=construct_pseudo_outcomes]{construct_pseudo_outcomes()}}, 46 | } 47 | -------------------------------------------------------------------------------- /tests/testthat/test-data.R: -------------------------------------------------------------------------------- 1 | test_that("listwise_deletion", { 2 | df <- dplyr::tibble( 3 | a = 1:100, 4 | b = rnorm(100), 5 | c = c(NA, NA, rnorm(98)), 6 | d = c(rnorm(98), NA, NA) 7 | ) 8 | expect_error(df1 <- listwise_deletion(df), NA) 9 | expect_true(nrow(df) == nrow(df1)) 10 | expect_message( 11 | listwise_deletion(df, c), 12 | "Dropped 2 of 100 rows (2%) through listwise deletion.", 13 | fixed = TRUE 14 | ) 15 | expect_message( 16 | listwise_deletion(df, c, d), 17 | "Dropped 4 of 100 rows (4%) through listwise deletion.", 18 | fixed = TRUE 19 | ) 20 | }) 21 | 22 | test_that("check_identifier", { 23 | df <- dplyr::tibble( 24 | uid1 = 1:10, 25 | uid2 = c(1:9, NA), 26 | uid3 = rnorm(10), 27 | uid4 = paste0("uid", 1:10) 28 | ) 29 | msg <- "Invalid identifier. Each unit / cluster must have its own unique ID." 
30 | expect_error(check_identifier(df, "uid1"), NA) 31 | expect_error(check_identifier(df, "uid2"), msg) 32 | expect_error(check_identifier(df, "uid3"), msg) 33 | expect_error(check_identifier(df, "uid4"), NA) 34 | expect_error(check_identifier(df, "uid5"), msg) 35 | }) 36 | 37 | test_that("check_weights", { 38 | df <- dplyr::tibble( 39 | uid = 1:10, 40 | w1 = rexp(10), 41 | w2 = rnorm(10), 42 | w3 = paste0("weight is ", rexp(10)) 43 | ) 44 | expect_error(check_weights(df, "w1"), NA) 45 | expect_error(check_weights(df, "test"), "Invalid weight column. Must exist in dataframe.") 46 | expect_error(check_weights(df, "w2"), "Invalid weight column. Must be non-negative.") 47 | expect_error(check_weights(df, "w3"), "Invalid weight column. Must be numeric.") 48 | }) 49 | 50 | test_that("check_data_has_hte_cfg", { 51 | df <- dplyr::tibble( 52 | uid = 1:10, 53 | x1 = rnorm(10), 54 | y = rnorm(10), 55 | a = sample(2, 10, replace = TRUE) 56 | ) 57 | expect_error( 58 | check_data_has_hte_cfg(df), 59 | "Must attach HTE_cfg with `attach_config`." 60 | ) 61 | }) 62 | -------------------------------------------------------------------------------- /man/attach_config.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/public_api.R 3 | \name{attach_config} 4 | \alias{attach_config} 5 | \title{Attach an \code{HTE_cfg} to a dataframe} 6 | \usage{ 7 | attach_config(data, .HTE_cfg) 8 | } 9 | \arguments{ 10 | \item{data}{dataframe} 11 | 12 | \item{.HTE_cfg}{\code{HTE_cfg} object representing the full configuration of the HTE analysis.} 13 | } 14 | \description{ 15 | This adds a configuration attribute to a dataframe for HTE estimation. 16 | This configuration details the full analysis of HTE that should be performed. 
17 | } 18 | \details{ 19 | For information about how to set up an \code{HTE_cfg} object, see the Recipe API 20 | documentation \code{\link[=basic_config]{basic_config()}}. 21 | 22 | To see an example analysis, read \code{vignette("experimental_analysis")} in the context 23 | of an experiment, \code{vignette("experimental_analysis")} for an observational study, or 24 | \code{vignette("methodological_details")} for a deeper dive under the hood. 25 | } 26 | \examples{ 27 | library("dplyr") 28 | if(require("palmerpenguins")) { 29 | data(package = 'palmerpenguins') 30 | penguins$unitid = seq_len(nrow(penguins)) 31 | penguins$propensity = rep(0.5, nrow(penguins)) 32 | penguins$treatment = rbinom(nrow(penguins), 1, penguins$propensity) 33 | cfg <- basic_config() \%>\% 34 | add_known_propensity_score("propensity") \%>\% 35 | add_outcome_model("SL.glm.interaction") \%>\% 36 | remove_vimp() 37 | attach_config(penguins, cfg) \%>\% 38 | make_splits(unitid, .num_splits = 4) \%>\% 39 | produce_plugin_estimates(outcome = body_mass_g, treatment = treatment, species, sex) \%>\% 40 | construct_pseudo_outcomes(body_mass_g, treatment) \%>\% 41 | estimate_QoI(species, sex) 42 | } 43 | } 44 | \seealso{ 45 | \code{\link[=basic_config]{basic_config()}}, \code{\link[=make_splits]{make_splits()}}, \code{\link[=produce_plugin_estimates]{produce_plugin_estimates()}}, 46 | \code{\link[=construct_pseudo_outcomes]{construct_pseudo_outcomes()}}, \code{\link[=estimate_QoI]{estimate_QoI()}} 47 | } 48 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | #' @noRd 2 | #' @keywords internal 3 | muffle_warnings <- function(expr, ...) 
{ 4 | regex <- paste(..., sep = "|") 5 | withCallingHandlers( 6 | expr, 7 | warning = function(w) { 8 | if (grepl(regex, conditionMessage(w))) { 9 | invokeRestart("muffleWarning") 10 | } 11 | } 12 | ) 13 | } 14 | 15 | #' @noRd 16 | #' @keywords internal 17 | muffle_messages <- function(expr, ...) { 18 | regex <- paste(..., sep = "|") 19 | withCallingHandlers( 20 | expr, 21 | message = function(m) { 22 | if (grepl(regex, conditionMessage(m))) { 23 | invokeRestart("muffleMessage") 24 | } 25 | } 26 | ) 27 | } 28 | 29 | #' @noRd 30 | #' @keywords internal 31 | #' @importFrom rlang check_installed 32 | soft_require <- function(package, load = FALSE) { 33 | rlang::check_installed(package) 34 | if (load) { 35 | try(attachNamespace(package), silent = TRUE) 36 | } 37 | } 38 | 39 | #' @noRd 40 | #' @keywords internal 41 | package_present <- function(package) { 42 | rlang::is_installed(package) 43 | } 44 | 45 | #' @noRd 46 | #' @keywords internal 47 | check_hte_cfg <- function(cfg) { 48 | checkmate::check_r6(cfg, classes = "HTE_cfg") 49 | } 50 | 51 | #' @noRd 52 | #' @keywords internal 53 | zero_range <- function(x, tol = .Machine$double.eps ^ 0.5) { 54 | if (length(x) == 1) return(TRUE) 55 | x <- range(x) / mean(x) 56 | isTRUE(all.equal(x[1], x[2], tolerance = tol)) 57 | } 58 | 59 | #' @noRd 60 | #' @keywords internal 61 | #' @importFrom stats weighted.mean 62 | clustered_se_of_mean <- function(y, cluster, weights = rep(1, length(y))) { 63 | n <- length(y) 64 | weights <- weights / sum(weights) * n 65 | H <- length(unique(cluster)) 66 | yhat <- stats::weighted.mean(y, weights) 67 | if (H < n) { 68 | dplyr::tibble(r = y - yhat, w = weights, cl = cluster) %>% 69 | dplyr::group_by(.data$cl) %>% 70 | dplyr::summarize(r = sum(tcrossprod(.data$w) * tcrossprod(.data$r))) %>% 71 | dplyr::select("r") %>% 72 | unlist() -> cl_resids 73 | } else { 74 | cl_resids <- weights ^ 2 * (y - yhat) ^ 2 75 | } 76 | sqrt(sum(cl_resids) / n ^ 2 * H / (H - 1)) 77 | } 78 | 
-------------------------------------------------------------------------------- /man/SL.glmnet.interaction.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/SL.glmnet.interaction.R 3 | \name{SL.glmnet.interaction} 4 | \alias{SL.glmnet.interaction} 5 | \title{Elastic net regression with pairwise interactions} 6 | \usage{ 7 | SL.glmnet.interaction( 8 | Y, 9 | X, 10 | newX, 11 | family, 12 | obsWeights, 13 | id, 14 | alpha = 1, 15 | nfolds = 10, 16 | nlambda = 100, 17 | useMin = TRUE, 18 | loss = "deviance", 19 | ... 20 | ) 21 | } 22 | \arguments{ 23 | \item{Y}{Outcome variable} 24 | 25 | \item{X}{Covariate dataframe} 26 | 27 | \item{newX}{Dataframe to predict the outcome} 28 | 29 | \item{family}{"gaussian" for regression, "binomial" for binary 30 | classification. Untested options: "multinomial" for multiple classification 31 | or "mgaussian" for multiple response, "poisson" for non-negative outcome 32 | with proportional mean and variance, "cox".} 33 | 34 | \item{obsWeights}{Optional observation-level weights} 35 | 36 | \item{id}{Optional id to group observations from the same unit (not used 37 | currently).} 38 | 39 | \item{alpha}{Elastic net mixing parameter, range [0, 1]. 0 = ridge regression 40 | and 1 = lasso.} 41 | 42 | \item{nfolds}{Number of folds for internal cross-validation to optimize lambda.} 43 | 44 | \item{nlambda}{Number of lambda values to check, recommended to be 100 or more.} 45 | 46 | \item{useMin}{If TRUE use lambda that minimizes risk, otherwise use 1 47 | standard-error rule which chooses a higher penalty with performance within 48 | one standard error of the minimum (see Breiman et al. 1984 on CART for 49 | background).} 50 | 51 | \item{loss}{Loss function, can be "deviance", "mse", or "mae". 
If family = 52 | binomial can also be "auc" or "class" (misclassification error).} 53 | 54 | \item{...}{Any additional arguments are passed through to cv.glmnet.} 55 | } 56 | \description{ 57 | Penalized regression using elastic net. Alpha = 0 corresponds to ridge 58 | regression and alpha = 1 corresponds to Lasso. Included in the model 59 | are pairwise interactions between covariates. 60 | 61 | See \code{vignette("glmnet_beta", package = "glmnet")} for a nice tutorial on 62 | glmnet. 63 | } 64 | -------------------------------------------------------------------------------- /man/make_splits.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/public_api.R 3 | \name{make_splits} 4 | \alias{make_splits} 5 | \title{Define splits for cross-fitting} 6 | \usage{ 7 | make_splits(data, identifier, ..., .num_splits) 8 | } 9 | \arguments{ 10 | \item{data}{dataframe} 11 | 12 | \item{identifier}{Unquoted name of unique identifier column} 13 | 14 | \item{...}{variables on which to stratify (requires that \code{quickblock} be installed.)} 15 | 16 | \item{.num_splits}{number of splits to create. If VIMP is requested in \code{QoI_cfg}, this 17 | must be an even number.} 18 | } 19 | \value{ 20 | original dataframe with additional \code{.split_id} column 21 | } 22 | \description{ 23 | This takes a dataset, a column with a unique identifier and an 24 | arbitrary number of covariates on which to stratify the splits. 25 | It returns the original dataset with an additional column \code{.split_id} 26 | corresponding to an identifier for the split. 27 | } 28 | \details{ 29 | To see an example analysis, read \code{vignette("experimental_analysis")} in the context 30 | of an experiment, \code{vignette("experimental_analysis")} for an observational study, or 31 | \code{vignette("methodological_details")} for a deeper dive under the hood. 
32 | } 33 | \examples{ 34 | library("dplyr") 35 | if(require("palmerpenguins")) { 36 | data(package = 'palmerpenguins') 37 | penguins$unitid = seq_len(nrow(penguins)) 38 | penguins$propensity = rep(0.5, nrow(penguins)) 39 | penguins$treatment = rbinom(nrow(penguins), 1, penguins$propensity) 40 | cfg <- basic_config() \%>\% 41 | add_known_propensity_score("propensity") \%>\% 42 | add_outcome_model("SL.glm.interaction") \%>\% 43 | remove_vimp() 44 | attach_config(penguins, cfg) \%>\% 45 | make_splits(unitid, .num_splits = 4) \%>\% 46 | produce_plugin_estimates(outcome = body_mass_g, treatment = treatment, species, sex) \%>\% 47 | construct_pseudo_outcomes(body_mass_g, treatment) \%>\% 48 | estimate_QoI(species, sex) 49 | } 50 | } 51 | \seealso{ 52 | \code{\link[=attach_config]{attach_config()}}, \code{\link[=produce_plugin_estimates]{produce_plugin_estimates()}}, \code{\link[=construct_pseudo_outcomes]{construct_pseudo_outcomes()}}, 53 | \code{\link[=estimate_QoI]{estimate_QoI()}} 54 | } 55 | -------------------------------------------------------------------------------- /tests/testthat/test-predictor.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("predictor factory", { 3 | model_cfg <- SLEnsemble_cfg$new() 4 | checkmate::expect_r6(predictor_factory(model_cfg), classes = "Predictor") 5 | 6 | model_cfg <- Known_cfg$new("test") 7 | checkmate::expect_r6(predictor_factory(model_cfg), classes = "Predictor") 8 | 9 | model_cfg <- Model_cfg$new() 10 | expect_error(predictor_factory(model_cfg), class = "tidyhte_error_model") 11 | 12 | expect_error(predictor_factory(list(model_class = "test")), class = "tidyhte_error_model") 13 | 14 | expect_error(predictor_factory("test"), class = "tidyhte_error_model") 15 | }) 16 | 17 | test_that("base Predictor class", { 18 | expect_warning(pred <- Predictor$new(), "Not Implemented") 19 | expect_error(pred$fit(y, x), class = "tidyhte_error_not_implemented") 20 | 
expect_error(pred$predict(data), class = "tidyhte_error_not_implemented") 21 | expect_error(pred$predict_se(data), class = "tidyhte_error_not_implemented") 22 | }) 23 | 24 | test_that("check check_nuisance_models", { 25 | df <- dplyr::tibble(a = 1:10, b = letters[1:10]) 26 | expect_error( 27 | check_nuisance_models(df), 28 | class = "tidyhte_error_data" 29 | ) 30 | }) 31 | 32 | test_that("SL_predictor gives expected output", { 33 | slpred <- predictor_factory( 34 | SLEnsemble_cfg$new(learner_cfgs = list( 35 | SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam") 36 | )) 37 | ) 38 | df <- dplyr::tibble( 39 | uid = 1:100, 40 | x1 = rnorm(100), 41 | x2 = rnorm(100), 42 | x3 = sample(4, 100, replace = TRUE) 43 | ) %>% dplyr::mutate( 44 | y = x1 + x2 + x3 + rnorm(100), 45 | x3 = factor(x3) 46 | ) 47 | df <- make_splits(df, uid, .num_splits = 5) 48 | data <- Model_data$new(df, y, x1, x2, x3) 49 | slpred$fit(data) 50 | expect_error(o <- slpred$predict(data), NA) 51 | checkmate::expect_data_frame(o, nrows = 100) 52 | expect_true(all(is.na(o$x))) 53 | 54 | data <- Model_data$new(df, y, x1) 55 | slpred$fit(data) 56 | expect_error(o <- slpred$predict(data), NA) 57 | checkmate::expect_data_frame(o, nrows = 100) 58 | expect_true(cor.test(o$x, o$estimate)$p.value < 0.05) 59 | expect_true(all(!is.na(o$x))) 60 | expect_equal(o$x, df$x1, ignore_attr = TRUE) 61 | }) 62 | -------------------------------------------------------------------------------- /R/rroc.R: -------------------------------------------------------------------------------- 1 | #' Regression ROC Curve calculation 2 | #' 3 | #' This function calculates the RegressionROC Curve of 4 | #' of Hernández-Orallo 5 | #' \doi{doi:10.1016/j.patcog.2013.06.014}. 6 | #' It provides estimates for the positive and negative 7 | #' errors when predictions are shifted by a variety 8 | #' of constants (which range across the domain of observed 9 | #' residuals). Curves closer to the axes are, in general, to be 10 | #' preferred. 
In general, this curve provides a simple way to 11 | #' visualize the error properties of a regression model. 12 | #' 13 | #' The dot shows the errors when no shift is applied, corresponding 14 | #' to the base model predictions. 15 | #' @param label True label 16 | #' @param prediction Model prediction of the label (out of sample) 17 | #' @param nbins Number of shift values to sweep over 18 | #' @references Hernández-Orallo, J. (2013). ROC curves for regression. 19 | #' Pattern Recognition, 46(12), 3395-3411. 20 | #' @return A tibble with `nbins` rows. 21 | #' @importFrom dplyr tibble bind_rows 22 | #' @importFrom stats quantile 23 | #' @keywords internal 24 | calculate_rroc <- function(label, prediction, nbins = 100) { 25 | residuals <- label - prediction 26 | n <- length(residuals) 27 | shifts <- stats::quantile(residuals, probs = seq(0, 1, length.out = nbins - 1)) 28 | result <- calculate_pos_and_neg(residuals, 0) 29 | results <- dplyr::tibble( 30 | estimand = "RROC", 31 | value = result$pos / n, 32 | level = "observed", 33 | estimate = result$neg / n, 34 | std_error = NA_real_ 35 | ) 36 | for (shift in shifts) { 37 | result <- calculate_pos_and_neg(residuals, -shift) 38 | results <- dplyr::bind_rows( 39 | results, 40 | dplyr::tibble( 41 | estimand = "RROC", 42 | value = result$pos / n, 43 | level = "shifted", 44 | estimate = result$neg / n, 45 | std_error = NA_real_ 46 | ) 47 | ) 48 | } 49 | results 50 | } 51 | 52 | #' @noRd 53 | #' @keywords internal 54 | #' @importFrom rlang list2 55 | calculate_pos_and_neg <- function(residuals, shift = 0.0) { 56 | shifted_residuals <- residuals + shift 57 | rlang::list2( 58 | pos = sum(shifted_residuals[shifted_residuals > 0]), 59 | neg = sum(shifted_residuals[shifted_residuals <= 0]) 60 | ) 61 | } 62 | -------------------------------------------------------------------------------- /man/Constant_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not 
edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{Constant_cfg} 4 | \alias{Constant_cfg} 5 | \title{Configuration of a Constant Estimator} 6 | \description{ 7 | \code{Constant_cfg} is a configuration class for estimating a constant model. 8 | That is, the model is a simple, one-parameter mean model. 9 | } 10 | \examples{ 11 | 12 | ## ------------------------------------------------ 13 | ## Method `Constant_cfg$new` 14 | ## ------------------------------------------------ 15 | 16 | Constant_cfg$new() 17 | } 18 | \section{Super class}{ 19 | \code{\link[tidyhte:Model_cfg]{tidyhte::Model_cfg}} -> \code{Constant_cfg} 20 | } 21 | \section{Public fields}{ 22 | \if{html}{\out{
}} 23 | \describe{ 24 | \item{\code{model_class}}{The class of the model, required for all classes 25 | which inherit from \code{Model_cfg}.} 26 | } 27 | \if{html}{\out{
}} 28 | } 29 | \section{Methods}{ 30 | \subsection{Public methods}{ 31 | \itemize{ 32 | \item \href{#method-Constant_cfg-new}{\code{Constant_cfg$new()}} 33 | \item \href{#method-Constant_cfg-clone}{\code{Constant_cfg$clone()}} 34 | } 35 | } 36 | \if{html}{\out{
}} 37 | \if{html}{\out{}} 38 | \if{latex}{\out{\hypertarget{method-Constant_cfg-new}{}}} 39 | \subsection{Method \code{new()}}{ 40 | Create a new \code{Constant_cfg} object. 41 | \subsection{Usage}{ 42 | \if{html}{\out{
}}\preformatted{Constant_cfg$new()}\if{html}{\out{
}} 43 | } 44 | 45 | \subsection{Returns}{ 46 | A new \code{Constant_cfg} object. 47 | } 48 | \subsection{Examples}{ 49 | \if{html}{\out{
}} 50 | \preformatted{Constant_cfg$new() 51 | } 52 | \if{html}{\out{
}} 53 | 54 | } 55 | 56 | } 57 | \if{html}{\out{
}} 58 | \if{html}{\out{}} 59 | \if{latex}{\out{\hypertarget{method-Constant_cfg-clone}{}}} 60 | \subsection{Method \code{clone()}}{ 61 | The objects of this class are cloneable with this method. 62 | \subsection{Usage}{ 63 | \if{html}{\out{
}}\preformatted{Constant_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 64 | } 65 | 66 | \subsection{Arguments}{ 67 | \if{html}{\out{
}} 68 | \describe{ 69 | \item{\code{deep}}{Whether to make a deep clone.} 70 | } 71 | \if{html}{\out{
}} 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /tests/testthat/test-utils.R: -------------------------------------------------------------------------------- 1 | test_that("muffle_messages works", { 2 | f1 <- function() { 3 | message("test message 1") 4 | } 5 | f2 <- function() { 6 | } 7 | f3 <- function() { 8 | message("test message 2") 9 | } 10 | expect_message(muffle_messages(f1(), "test"), NA) 11 | expect_message(muffle_messages(f2(), "test"), NA) 12 | expect_message(muffle_messages(f2(), ""), NA) 13 | expect_message(muffle_messages(f3(), "test message 1"), "test message 2") 14 | }) 15 | 16 | test_that("muffle_warnings works", { 17 | f1 <- function() { 18 | warning("test warning 1") 19 | } 20 | f2 <- function() { 21 | } 22 | f3 <- function() { 23 | warning("test warning 2") 24 | } 25 | expect_warning(muffle_warnings(f1(), "test"), NA) 26 | expect_warning(muffle_warnings(f2(), "test"), NA) 27 | expect_warning(muffle_warnings(f2(), ""), NA) 28 | expect_warning(muffle_warnings(f3(), "test warning 1"), "test warning 2") 29 | }) 30 | 31 | test_that("zero_range works", { 32 | expect_true(zero_range(1)) 33 | expect_true(zero_range(.Machine$double.eps, 2 * .Machine$double.eps)) 34 | }) 35 | 36 | test_that("soft_require works", { 37 | expect_error( 38 | soft_require("thispackagedoesntexist"), 39 | class = "rlib_error_package_not_found" 40 | ) 41 | 42 | expect_error(soft_require("SuperLearner"), NA) 43 | }) 44 | 45 | test_that("cluster robust SEs are correct", { 46 | skip_if_not_installed("estimatr") 47 | skip_on_cran() 48 | 49 | y <- rnorm(100) 50 | cl <- sample(1:4, size = 100, replace = TRUE) 51 | lmr <- estimatr::lm_robust(y ~ 1, clusters = cl, se_type = "stata") 52 | 53 | expect_error(result <- clustered_se_of_mean(y, cl), NA) 54 | expect_equal(result, lmr$std.error, ignore_attr = TRUE) 55 | 56 | cl <- sample(1:10, size = 125, replace = TRUE) 57 | y <- cl + rnorm(125) 58 | lmr <- estimatr::lm_robust(y ~ 
1, clusters = cl, se_type = "stata") 59 | 60 | expect_error(result <- clustered_se_of_mean(y, cl), NA) 61 | expect_equal(result, lmr$std.error, ignore_attr = TRUE) 62 | 63 | cl <- sample(1:10, size = 125, replace = TRUE) 64 | y <- cl + rnorm(125) 65 | w <- 1 / 5 + rexp(125) 66 | lmr <- estimatr::lm_robust(y ~ 1, clusters = cl, se_type = "stata", weights = w) 67 | 68 | expect_error(result <- clustered_se_of_mean(y, cl, weights = w), NA) 69 | expect_equal(result, lmr$std.error, ignore_attr = TRUE) 70 | }) 71 | -------------------------------------------------------------------------------- /man/HTEFold.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{HTEFold} 4 | \alias{HTEFold} 5 | \title{R6 class to represent partitions of the data between training and held-out} 6 | \description{ 7 | R6 class to represent partitions of the data between training and held-out 8 | 9 | R6 class to represent partitions of the data between training and held-out 10 | } 11 | \details{ 12 | This takes a set of folds calculated elsewhere and represents 13 | these folds in a consistent format. 14 | } 15 | \keyword{internal} 16 | \section{Public fields}{ 17 | \if{html}{\out{
}} 18 | \describe{ 19 | \item{\code{train}}{A dataframe containing only the training set} 20 | 21 | \item{\code{holdout}}{A dataframe containing only the held-out data} 22 | 23 | \item{\code{in_holdout}}{A logical vector indicating if the initial data 24 | lies in the holdout set.} 25 | } 26 | \if{html}{\out{
}} 27 | } 28 | \section{Methods}{ 29 | \subsection{Public methods}{ 30 | \itemize{ 31 | \item \href{#method-HTEFold-new}{\code{HTEFold$new()}} 32 | \item \href{#method-HTEFold-clone}{\code{HTEFold$clone()}} 33 | } 34 | } 35 | \if{html}{\out{
}} 36 | \if{html}{\out{}} 37 | \if{latex}{\out{\hypertarget{method-HTEFold-new}{}}} 38 | \subsection{Method \code{new()}}{ 39 | Creates an R6 object of the data split between training and test set. 40 | \subsection{Usage}{ 41 | \if{html}{\out{
}}\preformatted{HTEFold$new(data, split_id)}\if{html}{\out{
}} 42 | } 43 | 44 | \subsection{Arguments}{ 45 | \if{html}{\out{
}} 46 | \describe{ 47 | \item{\code{data}}{The dataset to be split} 48 | 49 | \item{\code{split_id}}{An identifier indicating which data should lie in the holdout set.} 50 | } 51 | \if{html}{\out{
}} 52 | } 53 | \subsection{Returns}{ 54 | Returns an object of class \code{HTEFold} 55 | } 56 | } 57 | \if{html}{\out{
}} 58 | \if{html}{\out{}} 59 | \if{latex}{\out{\hypertarget{method-HTEFold-clone}{}}} 60 | \subsection{Method \code{clone()}}{ 61 | The objects of this class are cloneable with this method. 62 | \subsection{Usage}{ 63 | \if{html}{\out{
}}\preformatted{HTEFold$clone(deep = FALSE)}\if{html}{\out{
}} 64 | } 65 | 66 | \subsection{Arguments}{ 67 | \if{html}{\out{
}} 68 | \describe{ 69 | \item{\code{deep}}{Whether to make a deep clone.} 70 | } 71 | \if{html}{\out{
}} 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /man/basic_config.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe-api.R 3 | \name{basic_config} 4 | \alias{basic_config} 5 | \title{Create a basic config for HTE estimation} 6 | \usage{ 7 | basic_config() 8 | } 9 | \value{ 10 | \code{HTE_cfg} object 11 | } 12 | \description{ 13 | This provides a basic recipe for HTE estimation that can 14 | be extended by providing additional information about models 15 | to be estimated and what quantities of interest should be 16 | returned based on those models. This basic model includes 17 | only linear models for nuisance function estimation, and 18 | basic diagnostics. 19 | } 20 | \details{ 21 | Additional models, diagnostics and quantities of interest should 22 | be added using their respective helper functions provided as part 23 | of the Recipe API. 24 | 25 | To see an example analysis, read \code{vignette("experimental_analysis")} in the context 26 | of an experiment, \code{vignette("observational_analysis")} for an observational study, or 27 | \code{vignette("methodological_details")} for a deeper dive under the hood.
28 | } 29 | \examples{ 30 | library("dplyr") 31 | basic_config() \%>\% 32 | add_known_propensity_score("ps") \%>\% 33 | add_outcome_model("SL.glm.interaction") \%>\% 34 | add_outcome_model("SL.glmnet", alpha = c(0.05, 0.15, 0.2, 0.25, 0.5, 0.75)) \%>\% 35 | add_outcome_model("SL.glmnet.interaction", alpha = c(0.05, 0.15, 0.2, 0.25, 0.5, 0.75)) \%>\% 36 | add_outcome_diagnostic("RROC") \%>\% 37 | add_effect_model("SL.glm.interaction") \%>\% 38 | add_effect_model("SL.glmnet", alpha = c(0.05, 0.15, 0.2, 0.25, 0.5, 0.75)) \%>\% 39 | add_effect_model("SL.glmnet.interaction", alpha = c(0.05, 0.15, 0.2, 0.25, 0.5, 0.75)) \%>\% 40 | add_effect_diagnostic("RROC") \%>\% 41 | add_moderator("Stratified", x2, x3) \%>\% 42 | add_moderator("KernelSmooth", x1, x4, x5) \%>\% 43 | add_vimp(sample_splitting = FALSE) -> hte_cfg 44 | } 45 | \seealso{ 46 | \code{\link[=add_propensity_score_model]{add_propensity_score_model()}}, \code{\link[=add_known_propensity_score]{add_known_propensity_score()}}, 47 | \code{\link[=add_propensity_diagnostic]{add_propensity_diagnostic()}}, \code{\link[=add_outcome_model]{add_outcome_model()}}, \code{\link[=add_outcome_diagnostic]{add_outcome_diagnostic()}}, 48 | \code{\link[=add_effect_model]{add_effect_model()}}, \code{\link[=add_effect_diagnostic]{add_effect_diagnostic()}}, \code{\link[=add_moderator]{add_moderator()}}, \code{\link[=add_vimp]{add_vimp()}} 49 | } 50 | -------------------------------------------------------------------------------- /man/produce_plugin_estimates.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/public_api.R 3 | \name{produce_plugin_estimates} 4 | \alias{produce_plugin_estimates} 5 | \title{Estimate models of nuisance functions} 6 | \usage{ 7 | produce_plugin_estimates(data, outcome, treatment, ..., .weights = NULL) 8 | } 9 | \arguments{ 10 | \item{data}{dataframe (already prepared with 
\code{attach_config} and \code{make_splits})} 11 | 12 | \item{outcome}{Unquoted name of the outcome variable.} 13 | 14 | \item{treatment}{Unquoted name of the treatment variable.} 15 | 16 | \item{...}{Unquoted names of covariates to include in the models of the nuisance functions.} 17 | 18 | \item{.weights}{Unquoted name of weights column. If NULL, all analysis will assume weights 19 | are all equal to one and sample-based quantities will be returned.} 20 | } 21 | \description{ 22 | This takes a dataset with an identified outcome and treatment column along 23 | with any number of covariates and appends three columns to the dataset corresponding 24 | to an estimate of the conditional expectation of treatment (\code{.pi_hat}), along with the 25 | conditional expectation of the control and treatment potential outcome surfaces 26 | (\code{.mu0_hat} and \code{.mu1_hat} respectively). 27 | } 28 | \details{ 29 | To see an example analysis, read \code{vignette("experimental_analysis")} in the context 30 | of an experiment, \code{vignette("observational_analysis")} for an observational study, or 31 | \code{vignette("methodological_details")} for a deeper dive under the hood.
32 | } 33 | \examples{ 34 | library("dplyr") 35 | if(require("palmerpenguins")) { 36 | data(package = 'palmerpenguins') 37 | penguins$unitid = seq_len(nrow(penguins)) 38 | penguins$propensity = rep(0.5, nrow(penguins)) 39 | penguins$treatment = rbinom(nrow(penguins), 1, penguins$propensity) 40 | cfg <- basic_config() \%>\% 41 | add_known_propensity_score("propensity") \%>\% 42 | add_outcome_model("SL.glm.interaction") \%>\% 43 | remove_vimp() 44 | attach_config(penguins, cfg) \%>\% 45 | make_splits(unitid, .num_splits = 4) \%>\% 46 | produce_plugin_estimates(outcome = body_mass_g, treatment = treatment, species, sex) \%>\% 47 | construct_pseudo_outcomes(body_mass_g, treatment) \%>\% 48 | estimate_QoI(species, sex) 49 | } 50 | } 51 | \seealso{ 52 | \code{\link[=attach_config]{attach_config()}}, \code{\link[=make_splits]{make_splits()}}, \code{\link[=construct_pseudo_outcomes]{construct_pseudo_outcomes()}}, \code{\link[=estimate_QoI]{estimate_QoI()}} 53 | } 54 | -------------------------------------------------------------------------------- /man/Stratified_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{Stratified_cfg} 4 | \alias{Stratified_cfg} 5 | \title{Configuration for a Stratification Estimator} 6 | \description{ 7 | \code{Stratified_cfg} is a configuration class for stratifying a covariate 8 | and calculating statistics within each cell. 9 | } 10 | \examples{ 11 | 12 | ## ------------------------------------------------ 13 | ## Method `Stratified_cfg$new` 14 | ## ------------------------------------------------ 15 | 16 | Stratified_cfg$new(covariate = "test_covariate") 17 | } 18 | \section{Super class}{ 19 | \code{\link[tidyhte:Model_cfg]{tidyhte::Model_cfg}} -> \code{Stratified_cfg} 20 | } 21 | \section{Public fields}{ 22 | \if{html}{\out{
}} 23 | \describe{ 24 | \item{\code{model_class}}{The class of the model, required for all classes 25 | which inherit from \code{Model_cfg}.} 26 | 27 | \item{\code{covariate}}{The name of the column in the dataset 28 | which corresponds to the covariate on which to stratify.} 29 | } 30 | \if{html}{\out{
}} 31 | } 32 | \section{Methods}{ 33 | \subsection{Public methods}{ 34 | \itemize{ 35 | \item \href{#method-Stratified_cfg-new}{\code{Stratified_cfg$new()}} 36 | \item \href{#method-Stratified_cfg-clone}{\code{Stratified_cfg$clone()}} 37 | } 38 | } 39 | \if{html}{\out{
}} 40 | \if{html}{\out{}} 41 | \if{latex}{\out{\hypertarget{method-Stratified_cfg-new}{}}} 42 | \subsection{Method \code{new()}}{ 43 | Create a new \code{Stratified_cfg} object with the specified covariate on which to stratify. 44 | \subsection{Usage}{ 45 | \if{html}{\out{
}}\preformatted{Stratified_cfg$new(covariate)}\if{html}{\out{
}} 46 | } 47 | 48 | \subsection{Arguments}{ 49 | \if{html}{\out{
}} 50 | \describe{ 51 | \item{\code{covariate}}{The name of the column in the dataset 52 | which corresponds to the covariate on which to stratify.} 53 | } 54 | \if{html}{\out{
}} 55 | } 56 | \subsection{Returns}{ 57 | A new \code{Stratified_cfg} object. 58 | } 59 | \subsection{Examples}{ 60 | \if{html}{\out{
}} 61 | \preformatted{Stratified_cfg$new(covariate = "test_covariate") 62 | } 63 | \if{html}{\out{
}} 64 | 65 | } 66 | 67 | } 68 | \if{html}{\out{
}} 69 | \if{html}{\out{}} 70 | \if{latex}{\out{\hypertarget{method-Stratified_cfg-clone}{}}} 71 | \subsection{Method \code{clone()}}{ 72 | The objects of this class are cloneable with this method. 73 | \subsection{Usage}{ 74 | \if{html}{\out{
}}\preformatted{Stratified_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 75 | } 76 | 77 | \subsection{Arguments}{ 78 | \if{html}{\out{
}} 79 | \describe{ 80 | \item{\code{deep}}{Whether to make a deep clone.} 81 | } 82 | \if{html}{\out{
}} 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /man/Known_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{Known_cfg} 4 | \alias{Known_cfg} 5 | \title{Configuration of Known Model} 6 | \description{ 7 | \code{Known_cfg} is a configuration class for when a particular model is known 8 | a-priori. The prototypical usage of this class is when heterogeneous 9 | treatment effects are estimated in the context of a randomized control 10 | trial with known propensity scores. 11 | } 12 | \examples{ 13 | 14 | ## ------------------------------------------------ 15 | ## Method `Known_cfg$new` 16 | ## ------------------------------------------------ 17 | 18 | Known_cfg$new("propensity_score") 19 | } 20 | \section{Super class}{ 21 | \code{\link[tidyhte:Model_cfg]{tidyhte::Model_cfg}} -> \code{Known_cfg} 22 | } 23 | \section{Public fields}{ 24 | \if{html}{\out{
}} 25 | \describe{ 26 | \item{\code{covariate_name}}{The name of the column in the dataset 27 | which corresponds to the known model score.} 28 | 29 | \item{\code{model_class}}{The class of the model, required for all classes 30 | which inherit from \code{Model_cfg}.} 31 | } 32 | \if{html}{\out{
}} 33 | } 34 | \section{Methods}{ 35 | \subsection{Public methods}{ 36 | \itemize{ 37 | \item \href{#method-Known_cfg-new}{\code{Known_cfg$new()}} 38 | \item \href{#method-Known_cfg-clone}{\code{Known_cfg$clone()}} 39 | } 40 | } 41 | \if{html}{\out{
}} 42 | \if{html}{\out{}} 43 | \if{latex}{\out{\hypertarget{method-Known_cfg-new}{}}} 44 | \subsection{Method \code{new()}}{ 45 | Create a new \code{Known_cfg} object with specified covariate column. 46 | \subsection{Usage}{ 47 | \if{html}{\out{
}}\preformatted{Known_cfg$new(covariate_name)}\if{html}{\out{
}} 48 | } 49 | 50 | \subsection{Arguments}{ 51 | \if{html}{\out{
}} 52 | \describe{ 53 | \item{\code{covariate_name}}{The name of the column, a string, in the dataset 54 | corresponding to the known model score (i.e. the true conditional expectation).} 55 | } 56 | \if{html}{\out{
}} 57 | } 58 | \subsection{Returns}{ 59 | A new \code{Known_cfg} object. 60 | } 61 | \subsection{Examples}{ 62 | \if{html}{\out{
}} 63 | \preformatted{Known_cfg$new("propensity_score") 64 | } 65 | \if{html}{\out{
}} 66 | 67 | } 68 | 69 | } 70 | \if{html}{\out{
}} 71 | \if{html}{\out{}} 72 | \if{latex}{\out{\hypertarget{method-Known_cfg-clone}{}}} 73 | \subsection{Method \code{clone()}}{ 74 | The objects of this class are cloneable with this method. 75 | \subsection{Usage}{ 76 | \if{html}{\out{
}}\preformatted{Known_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 77 | } 78 | 79 | \subsection{Arguments}{ 80 | \if{html}{\out{
}} 81 | \describe{ 82 | \item{\code{deep}}{Whether to make a deep clone.} 83 | } 84 | \if{html}{\out{
}} 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /R/errors.R: -------------------------------------------------------------------------------- 1 | #' Error utilities for tidyhte package 2 | #' 3 | #' These functions provide a standardized way to throw classed errors throughout 4 | #' the tidyhte package, replacing basic `stop()` calls with structured error 5 | #' conditions that can be caught and handled programmatically. 6 | #' 7 | #' @details 8 | #' The tidyhte package uses a hierarchical error class system based on rlang's 9 | #' structured condition handling. All tidyhte errors inherit from the base class 10 | #' `tidyhte_error_general` and include more specific subclasses: 11 | #' 12 | #' * `tidyhte_error_config` - Configuration-related errors 13 | #' * `tidyhte_error_model` - Model-related errors 14 | #' * `tidyhte_error_data` - Data validation errors 15 | #' * `tidyhte_error_not_implemented` - Not implemented functionality 16 | #' * `tidyhte_error_package` - Package dependency errors 17 | #' 18 | #' These classed errors can be caught using `tryCatch()` or `rlang::catch_cnd()` 19 | #' with the specific error class for more precise error handling. 20 | #' 21 | #' @examples 22 | #' \dontrun{ 23 | #' # Catching specific error types 24 | #' tryCatch( 25 | #' some_tidyhte_function(), 26 | #' tidyhte_error_config = function(e) cat("Configuration error:", conditionMessage(e)), 27 | #' tidyhte_error_data = function(e) cat("Data error:", conditionMessage(e)) 28 | #' ) 29 | #' } 30 | #' 31 | #' @name tidyhte-errors 32 | #' @keywords internal 33 | NULL 34 | 35 | #' Throw a tidyhte error 36 | #' 37 | #' @param message Error message 38 | #' @param class Error class suffix (will be prefixed with "tidyhte_error_") 39 | #' @param ... Additional arguments passed to `rlang::abort()` 40 | #' @keywords internal 41 | tidyhte_abort <- function(message, class = "general", ...) 
{ 42 | rlang::abort( 43 | message = message, 44 | class = paste0("tidyhte_error_", class), 45 | ... 46 | ) 47 | } 48 | 49 | #' Configuration-related errors 50 | #' 51 | #' @param message Error message 52 | #' @param ... Additional arguments passed to `tidyhte_abort()` 53 | #' @keywords internal 54 | abort_config <- function(message, ...) { 55 | tidyhte_abort(message, class = "config", ...) 56 | } 57 | 58 | #' Model-related errors 59 | #' 60 | #' @param message Error message 61 | #' @param ... Additional arguments passed to `tidyhte_abort()` 62 | #' @keywords internal 63 | abort_model <- function(message, ...) { 64 | tidyhte_abort(message, class = "model", ...) 65 | } 66 | 67 | #' Data validation errors 68 | #' 69 | #' @param message Error message 70 | #' @param ... Additional arguments passed to `tidyhte_abort()` 71 | #' @keywords internal 72 | abort_data <- function(message, ...) { 73 | tidyhte_abort(message, class = "data", ...) 74 | } 75 | 76 | #' Not implemented errors 77 | #' 78 | #' @param message Error message 79 | #' @param ... Additional arguments passed to `tidyhte_abort()` 80 | #' @keywords internal 81 | abort_not_implemented <- function(message = "Not implemented", ...) { 82 | tidyhte_abort(message, class = "not_implemented", ...) 83 | } 84 | 85 | #' Package dependency errors 86 | #' 87 | #' @param message Error message 88 | #' @param ... Additional arguments passed to `tidyhte_abort()` 89 | #' @keywords internal 90 | abort_package <- function(message, ...) { 91 | tidyhte_abort(message, class = "package", ...) 
92 | } 93 | -------------------------------------------------------------------------------- /man/SLLearner_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{SLLearner_cfg} 4 | \alias{SLLearner_cfg} 5 | \title{Configuration of SuperLearner Submodel} 6 | \description{ 7 | \code{SLLearner_cfg} is a configuration class for a single 8 | sublearner to be included in SuperLearner. By constructing with a named list 9 | of hyperparameters, this configuration allows distinct submodels 10 | for each unique combination of hyperparameters. To understand what models 11 | and hyperparameters are available, examine the methods listed in 12 | \code{SuperLearner::listWrappers("SL")}. 13 | } 14 | \examples{ 15 | 16 | ## ------------------------------------------------ 17 | ## Method `SLLearner_cfg$new` 18 | ## ------------------------------------------------ 19 | 20 | SLLearner_cfg$new("SL.glm") 21 | SLLearner_cfg$new("SL.gam", list(deg.gam = c(2, 3))) 22 | } 23 | \section{Public fields}{ 24 | \if{html}{\out{
}} 25 | \describe{ 26 | \item{\code{model_name}}{The name of the model as passed to \code{SuperLearner} 27 | through the \code{SL.library} parameter.} 28 | 29 | \item{\code{hyperparameters}}{Named list from hyperparameter name to a vector of 30 | values that should be swept over.} 31 | } 32 | \if{html}{\out{
}} 33 | } 34 | \section{Methods}{ 35 | \subsection{Public methods}{ 36 | \itemize{ 37 | \item \href{#method-SLLearner_cfg-new}{\code{SLLearner_cfg$new()}} 38 | \item \href{#method-SLLearner_cfg-clone}{\code{SLLearner_cfg$clone()}} 39 | } 40 | } 41 | \if{html}{\out{
}} 42 | \if{html}{\out{}} 43 | \if{latex}{\out{\hypertarget{method-SLLearner_cfg-new}{}}} 44 | \subsection{Method \code{new()}}{ 45 | Create a new \code{SLLearner_cfg} object with specified model name and hyperparameters. 46 | \subsection{Usage}{ 47 | \if{html}{\out{
}}\preformatted{SLLearner_cfg$new(model_name, hp = NULL)}\if{html}{\out{
}} 48 | } 49 | 50 | \subsection{Arguments}{ 51 | \if{html}{\out{
}} 52 | \describe{ 53 | \item{\code{model_name}}{The name of the model as passed to \code{SuperLearner} 54 | through the \code{SL.library} parameter.} 55 | 56 | \item{\code{hp}}{Named list from hyperparameter name to a vector of values that should be 57 | swept over. Hyperparameters not included in this list are left at their SuperLearner 58 | default values.} 59 | } 60 | \if{html}{\out{
}} 61 | } 62 | \subsection{Returns}{ 63 | A new \code{SLLearner_cfg} object. 64 | } 65 | \subsection{Examples}{ 66 | \if{html}{\out{
}} 67 | \preformatted{SLLearner_cfg$new("SL.glm") 68 | SLLearner_cfg$new("SL.gam", list(deg.gam = c(2, 3))) 69 | } 70 | \if{html}{\out{
}} 71 | 72 | } 73 | 74 | } 75 | \if{html}{\out{
}} 76 | \if{html}{\out{}} 77 | \if{latex}{\out{\hypertarget{method-SLLearner_cfg-clone}{}}} 78 | \subsection{Method \code{clone()}}{ 79 | The objects of this class are cloneable with this method. 80 | \subsection{Usage}{ 81 | \if{html}{\out{
}}\preformatted{SLLearner_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 82 | } 83 | 84 | \subsection{Arguments}{ 85 | \if{html}{\out{
}} 86 | \describe{ 87 | \item{\code{deep}}{Whether to make a deep clone.} 88 | } 89 | \if{html}{\out{
}} 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /.github/workflows/rhub.yaml: -------------------------------------------------------------------------------- 1 | # R-hub's generic GitHub Actions workflow file. It's canonical location is at 2 | # https://github.com/r-hub/actions/blob/v1/workflows/rhub.yaml 3 | # You can update this file to a newer version using the rhub2 package: 4 | # 5 | # rhub::rhub_setup() 6 | # 7 | # It is unlikely that you need to modify this file manually. 8 | 9 | name: R-hub 10 | run-name: "${{ github.event.inputs.id }}: ${{ github.event.inputs.name || format('Manually run by {0}', github.triggering_actor) }}" 11 | 12 | on: 13 | workflow_dispatch: 14 | inputs: 15 | config: 16 | description: 'A comma separated list of R-hub platforms to use.' 17 | type: string 18 | default: 'linux,windows,macos' 19 | name: 20 | description: 'Run name. You can leave this empty now.' 21 | type: string 22 | id: 23 | description: 'Unique ID. You can leave this empty now.' 
24 | type: string 25 | 26 | jobs: 27 | 28 | setup: 29 | runs-on: ubuntu-latest 30 | outputs: 31 | containers: ${{ steps.rhub-setup.outputs.containers }} 32 | platforms: ${{ steps.rhub-setup.outputs.platforms }} 33 | 34 | steps: 35 | # NO NEED TO CHECKOUT HERE 36 | - uses: r-hub/actions/setup@v1 37 | with: 38 | config: ${{ github.event.inputs.config }} 39 | id: rhub-setup 40 | 41 | linux-containers: 42 | needs: setup 43 | if: ${{ needs.setup.outputs.containers != '[]' }} 44 | runs-on: ubuntu-latest 45 | name: ${{ matrix.config.label }} 46 | strategy: 47 | fail-fast: false 48 | matrix: 49 | config: ${{ fromJson(needs.setup.outputs.containers) }} 50 | container: 51 | image: ${{ matrix.config.container }} 52 | 53 | steps: 54 | - uses: r-hub/actions/checkout@v1 55 | - uses: r-hub/actions/platform-info@v1 56 | with: 57 | token: ${{ secrets.RHUB_TOKEN }} 58 | job-config: ${{ matrix.config.job-config }} 59 | - uses: r-hub/actions/setup-deps@v1 60 | with: 61 | token: ${{ secrets.RHUB_TOKEN }} 62 | job-config: ${{ matrix.config.job-config }} 63 | - uses: r-hub/actions/run-check@v1 64 | with: 65 | token: ${{ secrets.RHUB_TOKEN }} 66 | job-config: ${{ matrix.config.job-config }} 67 | 68 | other-platforms: 69 | needs: setup 70 | if: ${{ needs.setup.outputs.platforms != '[]' }} 71 | runs-on: ${{ matrix.config.os }} 72 | name: ${{ matrix.config.label }} 73 | strategy: 74 | fail-fast: false 75 | matrix: 76 | config: ${{ fromJson(needs.setup.outputs.platforms) }} 77 | 78 | steps: 79 | - uses: r-hub/actions/checkout@v1 80 | - uses: r-hub/actions/setup-r@v1 81 | with: 82 | job-config: ${{ matrix.config.job-config }} 83 | token: ${{ secrets.RHUB_TOKEN }} 84 | - uses: r-hub/actions/platform-info@v1 85 | with: 86 | token: ${{ secrets.RHUB_TOKEN }} 87 | job-config: ${{ matrix.config.job-config }} 88 | - uses: r-hub/actions/setup-deps@v1 89 | with: 90 | job-config: ${{ matrix.config.job-config }} 91 | token: ${{ secrets.RHUB_TOKEN }} 92 | - uses: r-hub/actions/run-check@v1 93 | with: 94 | 
job-config: ${{ matrix.config.job-config }} 95 | token: ${{ secrets.RHUB_TOKEN }} 96 | -------------------------------------------------------------------------------- /tests/testthat/test-vimp.R: -------------------------------------------------------------------------------- 1 | n <- 500 2 | d <- dplyr::tibble( 3 | uid = rep(paste0("uid is ", as.character(1:(n / 2))), 2), 4 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 5 | cov2 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 6 | cov3 = rnorm(n), 7 | cov4 = runif(n), 8 | w = rep(1, n) 9 | ) %>% dplyr::mutate(y = 2 * cov1 + cov3 + rnorm(n)) 10 | 11 | test_that("splitting ensures even number of splits with VIMP", { 12 | attr(d, "HTE_cfg") <- HTE_cfg$new(qoi = QoI_cfg$new(vimp = VIMP_cfg$new())) 13 | expect_message( 14 | make_splits(d, uid, .num_splits = 5), 15 | "`num_splits` must be even if VIMP is requested as a QoI. Rounding up." 16 | ) 17 | expect_message(make_splits(d, cov1, .num_splits = 4), NA) 18 | }) 19 | 20 | test_that("VIMP ensures even number of splits", { 21 | d <- make_splits(d, uid, .num_splits = 5) 22 | attr(d, "HTE_cfg") <- HTE_cfg$new(qoi = QoI_cfg$new(vimp = VIMP_cfg$new())) 23 | expect_error( 24 | calculate_vimp( 25 | d, w, y, cov1, cov2, .VIMP_cfg = attr(d, "HTE_cfg")$qoi$vimp, 26 | .Model_cfg = SLEnsemble_cfg$new() 27 | ), 28 | "Number of splits must be even to calculate VIMP." 29 | ) 30 | attr(d, "num_splits") <- 4 31 | expect_error( 32 | calculate_vimp( 33 | d, w, y, cov1, cov2, .VIMP_cfg = attr(d, "HTE_cfg")$qoi$vimp, 34 | .Model_cfg = SLEnsemble_cfg$new() 35 | ), 36 | "Number of splits is inconsistent." 
37 | ) 38 | }) 39 | 40 | test_that("linear VIMP works with weights", { 41 | d <- make_splits(d, uid, .num_splits = 4) 42 | d$w <- rexp(n, 1 / 0.9) + 0.1 43 | attr(d, "HTE_cfg") <- HTE_cfg$new(qoi = QoI_cfg$new(vimp = VIMP_cfg$new())) 44 | expect_error(result <- calculate_linear_vimp( 45 | d, w, y, cov1, cov2, cov3, 46 | .VIMP_cfg = attr(d, "HTE_cfg")$qoi$vimp, .Model_cfg = SLEnsemble_cfg$new() 47 | ), NA) 48 | 49 | expect_gt(result$estimate[1], result$estimate[2]) 50 | expect_gt(result$estimate[3], result$estimate[2]) 51 | 52 | expect_gt(result$estimate[1] / result$std_error[1], 2) 53 | expect_gt(result$estimate[3] / result$std_error[3], 2) 54 | expect_lt(result$estimate[2] / result$std_error[2], 2) 55 | }) 56 | 57 | test_that("VIMP works with weights", { 58 | d <- make_splits(d, uid, .num_splits = 4) 59 | d$w <- rexp(n, 1 / 0.9) + 0.1 60 | attr(d, "HTE_cfg") <- HTE_cfg$new(qoi = QoI_cfg$new(vimp = VIMP_cfg$new())) 61 | expect_error(result <- calculate_vimp( 62 | d, w, y, cov1, cov2, cov3, 63 | .VIMP_cfg = attr(d, "HTE_cfg")$qoi$vimp, .Model_cfg = SLEnsemble_cfg$new() 64 | ), NA) 65 | 66 | expect_gt(result$estimate[1], result$estimate[2]) 67 | expect_gt(result$estimate[3], result$estimate[2]) 68 | 69 | expect_gt(result$estimate[1] / result$std_error[1], 2) 70 | expect_gt(result$estimate[3] / result$std_error[3], 2) 71 | expect_lt(result$estimate[2] / result$std_error[2], 2) 72 | }) 73 | 74 | test_that("VIMP works with weights without sample splitting", { 75 | d <- make_splits(d, uid, .num_splits = 4) 76 | d$w <- rexp(n, 1 / 0.9) + 0.1 77 | attr(d, "HTE_cfg") <- HTE_cfg$new( 78 | qoi = QoI_cfg$new(vimp = VIMP_cfg$new(sample_splitting = FALSE)) 79 | ) 80 | expect_error(result <- calculate_vimp( 81 | d, w, y, cov1, cov2, cov3, 82 | .VIMP_cfg = attr(d, "HTE_cfg")$qoi$vimp, .Model_cfg = SLEnsemble_cfg$new() 83 | ), NA) 84 | 85 | expect_gt(result$estimate[1], result$estimate[2]) 86 | expect_gt(result$estimate[3], result$estimate[2]) 87 | 88 | expect_gt(result$estimate[1] 
/ result$std_error[1], 2) 89 | expect_gt(result$estimate[3] / result$std_error[3], 2) 90 | expect_lt(result$estimate[2] / result$std_error[2], 2) 91 | }) 92 | -------------------------------------------------------------------------------- /man/VIMP_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/meta_cfg.R 3 | \name{VIMP_cfg} 4 | \alias{VIMP_cfg} 5 | \title{Configuration of Variable Importance} 6 | \description{ 7 | \code{VIMP_cfg} is a configuration class for estimating a variable importance measure 8 | across all moderators. This provides a meaningful measure of which moderators 9 | explain the most of the CATE surface. 10 | } 11 | \examples{ 12 | VIMP_cfg$new() 13 | 14 | ## ------------------------------------------------ 15 | ## Method `VIMP_cfg$new` 16 | ## ------------------------------------------------ 17 | 18 | VIMP_cfg$new() 19 | } 20 | \references{ 21 | \itemize{ 22 | \item Williamson, B. D., Gilbert, P. B., Carone, M., & Simon, N. (2021). 23 | Nonparametric variable importance assessment using machine learning techniques. 24 | Biometrics, 77(1), 9-22. 25 | \item Williamson, B. D., Gilbert, P. B., Simon, N. R., & Carone, M. (2021). 26 | A general framework for inference on algorithm-agnostic variable importance. 27 | Journal of the American Statistical Association, 1-14. 28 | } 29 | } 30 | \section{Public fields}{ 31 | \if{html}{\out{
}} 32 | \describe{ 33 | \item{\code{estimand}}{String indicating the estimand to target.} 34 | 35 | \item{\code{sample_splitting}}{Logical indicating whether to use sample 36 | splitting in the calculation of variable importance.} 37 | 38 | \item{\code{linear}}{Logical indicating whether the variable importance 39 | assuming a linear model should be estimated.} 40 | } 41 | \if{html}{\out{
}} 42 | } 43 | \section{Methods}{ 44 | \subsection{Public methods}{ 45 | \itemize{ 46 | \item \href{#method-VIMP_cfg-new}{\code{VIMP_cfg$new()}} 47 | \item \href{#method-VIMP_cfg-clone}{\code{VIMP_cfg$clone()}} 48 | } 49 | } 50 | \if{html}{\out{
}} 51 | \if{html}{\out{}} 52 | \if{latex}{\out{\hypertarget{method-VIMP_cfg-new}{}}} 53 | \subsection{Method \code{new()}}{ 54 | Create a new \code{VIMP_cfg} object with specified model configuration. 55 | \subsection{Usage}{ 56 | \if{html}{\out{
}}\preformatted{VIMP_cfg$new(sample_splitting = TRUE, linear_only = FALSE)}\if{html}{\out{
}} 57 | } 58 | 59 | \subsection{Arguments}{ 60 | \if{html}{\out{
}} 61 | \describe{ 62 | \item{\code{sample_splitting}}{Logical indicating whether to use sample splitting 63 | in the calculation of variable importance. Choosing not to use sample 64 | splitting means that inference will only be valid for moderators with 65 | non-null importance.} 66 | 67 | \item{\code{linear_only}}{Logical indicating whether the variable importance 68 | should use only a single linear-only model. Variable importance measure 69 | will only be consistent for the population quantity if the true model 70 | of pseudo-outcomes is linear.} 71 | } 72 | \if{html}{\out{
}} 73 | } 74 | \subsection{Returns}{ 75 | A new \code{VIMP_cfg} object. 76 | } 77 | \subsection{Examples}{ 78 | \if{html}{\out{
}} 79 | \preformatted{VIMP_cfg$new() 80 | } 81 | \if{html}{\out{
}} 82 | 83 | } 84 | 85 | } 86 | \if{html}{\out{
}} 87 | \if{html}{\out{}} 88 | \if{latex}{\out{\hypertarget{method-VIMP_cfg-clone}{}}} 89 | \subsection{Method \code{clone()}}{ 90 | The objects of this class are cloneable with this method. 91 | \subsection{Usage}{ 92 | \if{html}{\out{
}}\preformatted{VIMP_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 93 | } 94 | 95 | \subsection{Arguments}{ 96 | \if{html}{\out{
}} 97 | \describe{ 98 | \item{\code{deep}}{Whether to make a deep clone.} 99 | } 100 | \if{html}{\out{
}} 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /tests/testthat/test-pseudooutcomes.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | data <- dplyr::tibble( 5 | uid = 1:n 6 | ) %>% 7 | dplyr::mutate( 8 | a = rbinom(n, 1, 0.5), 9 | ps = rep(0.5, n), 10 | x1 = rnorm(n), 11 | x2 = factor(sample(1:4, n, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 12 | x3 = factor(sample(1:3, n, prob = c(1 / 5, 1 / 5, 3 / 5), replace = TRUE)), 13 | x4 = (x1 + rnorm(n)) / 2, 14 | x5 = rnorm(n), 15 | y = a + x1 - 0.5 * a * (x1 - mean(x1)) + as.double(x2) + rnorm(n) 16 | ) 17 | 18 | userid <- rlang::expr(uid) 19 | 20 | propensity_score_variable_name <- "ps" 21 | 22 | continuous_covariates <- c("x1") 23 | 24 | discrete_covariates <- c("x2", "x3") 25 | 26 | continuous_moderators <- rlang::exprs(x1, x4, x5) 27 | discrete_moderators <- rlang::exprs(x2, x3) 28 | moderators <- c(continuous_moderators, discrete_moderators) 29 | 30 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 31 | model_covariates <- rlang::syms(model_covariate_names) 32 | 33 | outcome_variable <- rlang::expr(y) 34 | treatment_variable <- rlang::expr(a) 35 | 36 | trt.cfg <- SLEnsemble_cfg$new( 37 | learner_cfgs = list( 38 | SLLearner_cfg$new( 39 | "SL.glm" 40 | ) 41 | ), 42 | family = stats::binomial() 43 | ) 44 | 45 | regression.cfg <- SLEnsemble_cfg$new( 46 | learner_cfgs = list( 47 | SLLearner_cfg$new( 48 | "SL.glm" 49 | ) 50 | ) 51 | ) 52 | 53 | qoi.list <- list() 54 | for (cov in continuous_moderators) { 55 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100) 56 | } 57 | for (cov in discrete_moderators) { 58 | qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 59 | } 60 | 61 | qoi.cfg <- QoI_cfg$new( 62 | mcate = MCATE_cfg$new(cfgs = qoi.list), 63 | diag = Diagnostics_cfg$new( 64 | 
outcome = c("SL_risk", "SL_coefs", "MSE"), 65 | ps = c("SL_risk", "SL_coefs", "AUC") 66 | ) 67 | ) 68 | 69 | cfg <- HTE_cfg$new( 70 | treatment = trt.cfg, 71 | outcome = regression.cfg, 72 | qoi = qoi.cfg 73 | ) 74 | 75 | E <- new.env(parent = emptyenv()) 76 | 77 | test_that("add config", { 78 | E$data1 <- attach_config(data, cfg) 79 | checkmate::expect_data_frame(E$data1) 80 | expect_true("HTE_cfg" %in% names(attributes(E$data1))) 81 | }) 82 | 83 | test_that("Split data", { 84 | E$data2 <- make_splits(E$data1, {{ userid }}, .num_splits = 3) 85 | checkmate::expect_data_frame(E$data2) 86 | }) 87 | 88 | test_that("Estimate Plugin Models", { 89 | E$data3 <- produce_plugin_estimates( 90 | E$data2, 91 | {{ outcome_variable }}, 92 | {{ treatment_variable }}, 93 | !!!model_covariates 94 | ) 95 | checkmate::expect_data_frame(E$data3) 96 | }) 97 | 98 | test_that("Errors on unknown type", { 99 | expect_error( 100 | construct_pseudo_outcomes( 101 | E$data3, {{ outcome_variable }}, {{ treatment_variable }}, type = "idk" 102 | ), 103 | "Unknown type of pseudo-outcome." 
104 | ) 105 | }) 106 | 107 | test_that("Construct DR Pseudo-outcomes", { 108 | data4 <- construct_pseudo_outcomes( 109 | E$data3, {{ outcome_variable }}, {{ treatment_variable }} 110 | ) 111 | checkmate::expect_data_frame(data4) 112 | }) 113 | 114 | test_that("Construct IPW Pseudo-outcomes", { 115 | data4 <- construct_pseudo_outcomes( 116 | E$data3, {{ outcome_variable }}, {{ treatment_variable }}, "ipw" 117 | ) 118 | checkmate::expect_data_frame(data4) 119 | }) 120 | 121 | test_that("Construct Plugin Pseudo-outcomes", { 122 | data4 <- construct_pseudo_outcomes( 123 | E$data3, {{ outcome_variable }}, {{ treatment_variable }}, "plugin" 124 | ) 125 | checkmate::expect_data_frame(data4) 126 | }) 127 | -------------------------------------------------------------------------------- /man/KernelSmooth_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{KernelSmooth_cfg} 4 | \alias{KernelSmooth_cfg} 5 | \title{Configuration for a Kernel Smoother} 6 | \description{ 7 | \code{KernelSmooth_cfg} is a configuration class for non-parametric local-linear 8 | regression to construct a smooth representation of the relationship between 9 | two variables. This is typically used for displaying a surface of the conditional 10 | average treatment effect over a continuous covariate. 11 | 12 | Kernel smoothing is handled by the \code{nprobust} package. 13 | } 14 | \examples{ 15 | 16 | ## ------------------------------------------------ 17 | ## Method `KernelSmooth_cfg$new` 18 | ## ------------------------------------------------ 19 | 20 | KernelSmooth_cfg$new(neval = 100) 21 | } 22 | \seealso{ 23 | \link[nprobust:lprobust]{nprobust::lprobust} 24 | } 25 | \section{Super class}{ 26 | \code{\link[tidyhte:Model_cfg]{tidyhte::Model_cfg}} -> \code{KernelSmooth_cfg} 27 | } 28 | \section{Public fields}{ 29 | \if{html}{\out{
}} 30 | \describe{ 31 | \item{\code{model_class}}{The class of the model, required for all classes 32 | which inherit from \code{Model_cfg}.} 33 | 34 | \item{\code{neval}}{The number of points at which to evaluate the local 35 | regression. More points will provide a smoother line at the cost 36 | of somewhat higher computation.} 37 | 38 | \item{\code{eval_min_quantile}}{Minimum quantile at which to evaluate the smoother.} 39 | } 40 | \if{html}{\out{
}} 41 | } 42 | \section{Methods}{ 43 | \subsection{Public methods}{ 44 | \itemize{ 45 | \item \href{#method-KernelSmooth_cfg-new}{\code{KernelSmooth_cfg$new()}} 46 | \item \href{#method-KernelSmooth_cfg-clone}{\code{KernelSmooth_cfg$clone()}} 47 | } 48 | } 49 | \if{html}{\out{
}} 50 | \if{html}{\out{}} 51 | \if{latex}{\out{\hypertarget{method-KernelSmooth_cfg-new}{}}} 52 | \subsection{Method \code{new()}}{ 53 | Create a new \code{KernelSmooth_cfg} object with specified number of evaluation points. 54 | \subsection{Usage}{ 55 | \if{html}{\out{
}}\preformatted{KernelSmooth_cfg$new(neval = 100, eval_min_quantile = 0.05)}\if{html}{\out{
}} 56 | } 57 | 58 | \subsection{Arguments}{ 59 | \if{html}{\out{
}} 60 | \describe{ 61 | \item{\code{neval}}{The number of points at which to evaluate the local 62 | regression. More points will provide a smoother line at the cost 63 | of somewhat higher computation.} 64 | 65 | \item{\code{eval_min_quantile}}{Minimum quantile at which to evaluate the smoother. 66 | A value of zero will do no clipping. Clipping is performed from both the 67 | top and the bottom of the empirical distribution. A value of alpha would 68 | evaluate over [alpha, 1 - alpha].} 69 | } 70 | \if{html}{\out{
}} 71 | } 72 | \subsection{Returns}{ 73 | A new \code{KernelSmooth_cfg} object. 74 | } 75 | \subsection{Examples}{ 76 | \if{html}{\out{
}} 77 | \preformatted{KernelSmooth_cfg$new(neval = 100) 78 | } 79 | \if{html}{\out{
}} 80 | 81 | } 82 | 83 | } 84 | \if{html}{\out{
}} 85 | \if{html}{\out{}} 86 | \if{latex}{\out{\hypertarget{method-KernelSmooth_cfg-clone}{}}} 87 | \subsection{Method \code{clone()}}{ 88 | The objects of this class are cloneable with this method. 89 | \subsection{Usage}{ 90 | \if{html}{\out{
}}\preformatted{KernelSmooth_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 91 | } 92 | 93 | \subsection{Arguments}{ 94 | \if{html}{\out{
}} 95 | \describe{ 96 | \item{\code{deep}}{Whether to make a deep clone.} 97 | } 98 | \if{html}{\out{
}} 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /tests/testthat/test-config-construction.R: -------------------------------------------------------------------------------- 1 | propensity_score_variable_name <- "ps" 2 | continuous_covariates <- c("x1") 3 | discrete_covariates <- c("x2") 4 | 5 | outcome_variable <- rlang::expr(y) 6 | treatment_variable <- rlang::expr(a) 7 | 8 | 9 | all_covariates <- rlang::syms(c(continuous_covariates, discrete_covariates)) 10 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 11 | 12 | E <- new.env(parent = emptyenv()) 13 | 14 | test_that("Configs can be constructed successfully.", { 15 | E$trt.cfg <- Known_cfg$new(propensity_score_variable_name) 16 | 17 | E$basic.regression.cfg <- SLEnsemble_cfg$new(cvControl = list(V = 5)) 18 | 19 | E$regression.cfg <- SLEnsemble_cfg$new( 20 | learner_cfgs = list( 21 | SLLearner_cfg$new( 22 | "SL.glm" 23 | ), 24 | SLLearner_cfg$new( 25 | "SL.gam", 26 | list( 27 | deg.gam = c(2, 3, 4, 5, 7, 9) 28 | ) 29 | ), 30 | SLLearner_cfg$new( 31 | "SL.glmnet", 32 | list( 33 | alpha = c(0.05, 0.15, 0.2, 0.25), 34 | loss = c("mse", "deviance") 35 | ) 36 | ), 37 | SLLearner_cfg$new( 38 | "SL.glmnet.interaction", 39 | list( 40 | alpha = c(0.05, 0.15, 0.2, 0.25), 41 | loss = c("mse", "deviance") 42 | ) 43 | ), 44 | SLLearner_cfg$new( 45 | "SL.ranger", 46 | list( 47 | num.trees = c(25, 50, 100, 250, 500), 48 | splitrule = c("gini", "extratrees", "hellinger") 49 | ) 50 | ), 51 | SLLearner_cfg$new( 52 | "SL.xgboost", 53 | list( 54 | ntrees = c(50, 100, 250, 500), 55 | max_depth = c(1, 3, 5, 7, 9), 56 | shrinkage = c(0.01, 0.1) 57 | ) 58 | ) 59 | ) 60 | ) 61 | 62 | E$regression.cfg <- E$regression.cfg$add_sublearner("SL.glm.interaction") 63 | 64 | E$qoi.list <- list() 65 | for (cov in continuous_covariates) { 66 | E$qoi.list[[cov]] <- KernelSmooth_cfg$new(neval = 100) 67 | } 68 | for (cov in discrete_covariates) { 69 | 
E$qoi.list[[cov]] <- Stratified_cfg$new(cov) 70 | } 71 | 72 | E$diag.cfg <- Diagnostics_cfg$new( 73 | outcome = c("SL_risk", "SL_coefs", "MSE"), 74 | effect = c("SL_risk", "SL_coefs") 75 | ) 76 | 77 | E$diag.cfg$add(effect = "MSE") 78 | E$diag.cfg$add(ps = "MSE") 79 | E$diag.cfg$add(outcome = "MSE") 80 | 81 | Diagnostics_cfg$new( 82 | ps = "MSE" 83 | ) 84 | 85 | E$mcate.cfg <- MCATE_cfg$new(cfgs = E$qoi.list) 86 | 87 | E$mcate.cfg$add_moderator("new_x", Stratified_cfg$new("new_x")) 88 | 89 | E$pcate.cfg <- PCATE_cfg$new( 90 | cfgs = E$qoi.list, 91 | model_covariates = model_covariate_names, 92 | num_mc_samples = 10 93 | ) 94 | 95 | E$pcate.cfg$add_moderator("new_x", Stratified_cfg$new("new_x")) 96 | 97 | expect_error(PCATE_cfg$new( 98 | cfgs = E$qoi.list, 99 | model_covariates = model_covariate_names, 100 | num_mc_samples = "not a number" 101 | ), "Unknown type of num_mc_samples") 102 | 103 | E$pcate.cfg <- PCATE_cfg$new( 104 | cfgs = E$qoi.list, 105 | model_covariates = model_covariate_names, 106 | num_mc_samples = list(x1 = 5, x2 = 10, x3 = 10, x4 = 5, x5 = 5) 107 | ) 108 | 109 | E$vimp.cfg <- VIMP_cfg$new() 110 | 111 | E$qoi.cfg <- QoI_cfg$new( 112 | mcate = E$mcate.cfg, 113 | vimp = E$vimp.cfg, 114 | pcate = E$pcate.cfg, 115 | diag = E$diag.cfg 116 | ) 117 | 118 | expect_error(QoI_cfg$new(), "Must define at least one QoI!") 119 | 120 | E$cfg <- HTE_cfg$new( 121 | treatment = E$trt.cfg, 122 | outcome = E$regression.cfg, 123 | effect = E$regression.cfg, 124 | qoi = E$qoi.cfg 125 | ) 126 | 127 | expect_error(HTE_cfg$new(), "Must define at least one QoI!") 128 | 129 | HTE_cfg$new(qoi = E$qoi.cfg) 130 | }) 131 | 132 | 133 | test_that("Config types are as expected.", { 134 | checkmate::expect_r6(E$cfg, classes = "HTE_cfg") 135 | checkmate::expect_r6(E$qoi.cfg, classes = "QoI_cfg") 136 | for (qoi in E$qoi.list) { 137 | checkmate::expect_r6(qoi, classes = "Model_cfg") 138 | } 139 | checkmate::expect_r6(E$regression.cfg, classes = "SLEnsemble_cfg") 140 | 
checkmate::expect_environment(E$regression.cfg$SL.env) 141 | checkmate::expect_character(E$regression.cfg$SL.library, len = 79) 142 | checkmate::expect_r6(E$trt.cfg, classes = c("Known_cfg", "Model_cfg")) 143 | checkmate::expect_r6(E$vimp.cfg, classes = "VIMP_cfg") 144 | checkmate::expect_r6(E$pcate.cfg, classes = "PCATE_cfg") 145 | checkmate::expect_r6(E$mcate.cfg, classes = "MCATE_cfg") 146 | checkmate::expect_r6(E$diag.cfg, classes = "Diagnostics_cfg") 147 | }) 148 | -------------------------------------------------------------------------------- /R/plugin-estimates.R: -------------------------------------------------------------------------------- 1 | #' Fits a plugin model using the appropriate settings 2 | #' 3 | #' This function prepares data, fits the appropriate models and returns the 4 | #' resulting estimates in a standardized format. 5 | #' @param full_data The full dataset of interest for the modelling problem. 6 | #' @param weight_col The unquoted weighting variable name to use in model fitting. 7 | #' @param outcome_col The unquoted column name to use as a label for the supervised 8 | #' learning problem. 9 | #' @param ... The unquoted names of covariates to use in the model. 10 | #' @param .Model_cfg A `Model_cfg` object configuring the appropriate model type to use. 11 | #' @return A new `Predictor` object of the appropriate subclass corresponding to the 12 | #' `Model_cfg` fit to the data. 13 | #' @keywords internal 14 | fit_plugin <- function(full_data, weight_col, outcome_col, ..., .Model_cfg) { 15 | dots <- rlang::enexprs(...) 
16 | predictor <- predictor_factory(.Model_cfg) 17 | data <- Model_data$new(full_data, {{ outcome_col }}, !!!dots, .weight_col = {{ weight_col }}) 18 | muffle_warnings(predictor$fit(data), "rank-deficient fit", "grouped=FALSE") 19 | } 20 | 21 | #' Fits a propensity score model using the appropriate settings 22 | #' 23 | #' This function prepares data, fits the appropriate model and returns the 24 | #' resulting estimates in a standardized format. 25 | #' @param full_data The full dataset of interest for the modelling problem. 26 | #' @param weight_col The unquoted weighting variable name to use in model fitting. 27 | #' @param a_col The unquoted column name of the treatment. 28 | #' @param ... The unquoted names of covariates to use in the model. 29 | #' @param .Model_cfg A `Model_cfg` object configuring the appropriate model type to use. 30 | #' @return A list with one element, `pi`. This element contains a `Predictor` object of 31 | #' the appropriate subclass corresponding to the `Model_cfg` fit to the data. 32 | #' @keywords internal 33 | fit_plugin_A <- function(full_data, weight_col, a_col, ..., .Model_cfg) { 34 | dots <- rlang::enexprs(...) 35 | if (.Model_cfg$model_class == "known") { 36 | cov <- rlang::sym(.Model_cfg$covariate_name) 37 | dots <- unique(c(dots, cov)) 38 | } 39 | list( 40 | pi = fit_plugin(full_data, {{ weight_col }}, {{ a_col }}, !!!dots, .Model_cfg = .Model_cfg) 41 | ) 42 | } 43 | 44 | #' Fits a T-learner using the appropriate settings 45 | #' 46 | #' This function prepares data, fits the appropriate model and returns the 47 | #' resulting estimates in a standardized format. 48 | #' @param full_data The full dataset of interest for the modelling problem. 49 | #' @param weight_col The unquoted weighting variable name to use in model fitting. 50 | #' @param y_col The unquoted column name of the outcome. 51 | #' @param a_col The unquoted column name of the treatment. 52 | #' @param ... The unquoted names of covariates to use in the model.
53 | #' @param .Model_cfg A `Model_cfg` object configuring the appropriate model type to use. 54 | #' @return A list with two elements, `mu1` and `mu0` corresponding to the models fit to 55 | #' the treatment and control potential outcomes, respectively. Each is a new `Predictor` 56 | #' object of the appropriate subclass corresponding to the `Model_cfg` fit to the data. 57 | #' @keywords internal 58 | fit_plugin_Y <- function(full_data, weight_col, y_col, a_col, ..., .Model_cfg) { 59 | dots <- rlang::enexprs(...) 60 | df_0 <- dplyr::filter(full_data, {{ a_col }} == 0) 61 | df_1 <- dplyr::filter(full_data, {{ a_col }} == 1) 62 | list( 63 | mu1 = fit_plugin(df_1, {{ weight_col }}, {{ y_col }}, !!!dots, .Model_cfg = .Model_cfg), 64 | mu0 = fit_plugin(df_0, {{ weight_col }}, {{ y_col }}, !!!dots, .Model_cfg = .Model_cfg) 65 | ) 66 | } 67 | 68 | #' Fits a treatment effect model using the appropriate settings 69 | #' 70 | #' This function prepares data, fits the appropriate model and returns the 71 | #' resulting estimates in a standardized format. 72 | #' @param full_data The full dataset of interest for the modelling problem. 73 | #' @param weight_col The unquoted weighting variable name to use in model fitting. 74 | #' @param fx_col The unquoted column name of the pseudo-outcome. 75 | #' @param ... The unquoted names of covariates to use in the model. 76 | #' @param .Model_cfg A `Model_cfg` object configuring the appropriate model type to use. 77 | #' @return A list with one element, `fx`. This element contains a `Predictor` object of 78 | #' the appropriate subclass corresponding to the `Model_cfg` fit to the data. 79 | #' @keywords internal 80 | fit_effect <- function(full_data, weight_col, fx_col, ..., .Model_cfg) { 81 | dots <- rlang::enexprs(...)
82 | list( 83 | fx = fit_plugin(full_data, {{ weight_col }}, {{ fx_col }}, !!!dots, .Model_cfg = .Model_cfg) 84 | ) 85 | } 86 | -------------------------------------------------------------------------------- /tests/testthat/test-api-mcate.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | data <- dplyr::tibble( 5 | uid = 1:n 6 | ) %>% 7 | dplyr::mutate( 8 | a = rbinom(n, 1, 0.5), 9 | ps = rep(0.5, n), 10 | x1 = rnorm(n), 11 | x2 = factor(sample(1:4, n, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 12 | x3 = factor(sample(1:3, n, prob = c(1 / 5, 1 / 5, 3 / 5), replace = TRUE)), 13 | x4 = (x1 + rnorm(n)) / 2, 14 | x5 = rnorm(n), 15 | y = a + x1 - 0.5 * a * (x1 - mean(x1)) + as.double(x2) + rnorm(n) 16 | ) 17 | 18 | userid <- rlang::expr(uid) 19 | 20 | propensity_score_variable_name <- "ps" 21 | 22 | continuous_covariates <- c("x1") 23 | 24 | discrete_covariates <- c("x2", "x3") 25 | 26 | continuous_moderators <- rlang::exprs(x1, x4, x5) 27 | discrete_moderators <- rlang::exprs(x2, x3) 28 | moderators <- c(continuous_moderators, discrete_moderators) 29 | 30 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 31 | model_covariates <- rlang::syms(model_covariate_names) 32 | 33 | outcome_variable <- rlang::expr(y) 34 | treatment_variable <- rlang::expr(a) 35 | 36 | trt.cfg <- SLEnsemble_cfg$new( 37 | learner_cfgs = list( 38 | SLLearner_cfg$new( 39 | "SL.glm" 40 | ) 41 | ), 42 | family = stats::binomial() 43 | ) 44 | 45 | regression.cfg <- SLEnsemble_cfg$new( 46 | learner_cfgs = list( 47 | SLLearner_cfg$new( 48 | "SL.glm" 49 | ), 50 | SLLearner_cfg$new( 51 | "SL.gam", 52 | list( 53 | deg.gam = c(2, 3) 54 | ) 55 | ), 56 | SLLearner_cfg$new( 57 | "SL.glmnet", 58 | list( 59 | alpha = c(0.05, 0.15) 60 | ) 61 | ), 62 | SLLearner_cfg$new( 63 | "SL.glmnet.interaction", 64 | list( 65 | alpha = c(0.05, 0.15) 66 | ) 67 | ) 68 | ) 69 | ) 70 | 
71 | qoi.list <- list() 72 | for (cov in continuous_moderators) { 73 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100) 74 | } 75 | for (cov in discrete_moderators) { 76 | qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 77 | } 78 | 79 | qoi.cfg <- QoI_cfg$new( 80 | mcate = MCATE_cfg$new(cfgs = qoi.list), 81 | diag = Diagnostics_cfg$new( 82 | outcome = c("SL_risk", "SL_coefs", "MSE"), 83 | ps = c("SL_risk", "SL_coefs", "AUC") 84 | ) 85 | ) 86 | 87 | cfg <- HTE_cfg$new( 88 | treatment = trt.cfg, 89 | outcome = regression.cfg, 90 | qoi = qoi.cfg 91 | ) 92 | 93 | E <- new.env(parent = emptyenv()) 94 | 95 | test_that("add config", { 96 | E$data1 <- attach_config(data, cfg) 97 | checkmate::expect_data_frame(E$data1) 98 | expect_true("HTE_cfg" %in% names(attributes(E$data1))) 99 | }) 100 | 101 | test_that("Split data", { 102 | E$data2 <- make_splits(E$data1, {{ userid }}, .num_splits = 3) 103 | checkmate::expect_data_frame(E$data2) 104 | }) 105 | 106 | test_that("Estimate Plugin Models", { 107 | E$data3 <- produce_plugin_estimates( 108 | E$data2, 109 | {{ outcome_variable }}, 110 | {{ treatment_variable }}, 111 | !!!model_covariates 112 | ) 113 | checkmate::expect_data_frame(E$data3) 114 | }) 115 | 116 | test_that("Construct Pseudo-outcomes", { 117 | E$data4 <- construct_pseudo_outcomes(E$data3, {{ outcome_variable }}, {{ treatment_variable }}) 118 | checkmate::expect_data_frame(E$data4) 119 | }) 120 | 121 | test_that("Estimate QoIs", { 122 | expect_message( 123 | E$results <- estimate_QoI(E$data4), 124 | "No moderators specified, so pulling list from definitions in QoI." 
125 | ) 126 | checkmate::expect_data_frame(E$results) 127 | }) 128 | 129 | n_rows <- ( 130 | 1 + # SATE estimate 131 | 2 + # MSE for y(0) & y(1) 132 | 1 + # AUC for pscore 133 | 2 * 7 + 1 + # one row per model in the ensemble for each PO + ps for SL risk 134 | 2 * 7 + 1 + # one row per model in the ensemble for each PO + ps for SL coefficient 135 | 3 * 100 + # 100 rows per continuous moderator for local regression for MCATE and for PCATE 136 | (4 + 3) # 2 rows per discrete moderator level for MCATE and for PCATE 137 | ) 138 | 139 | 140 | test_that("Check results data", { 141 | checkmate::check_character(E$results$estimand, any.missing = FALSE) 142 | checkmate::check_double(E$results$estimate, any.missing = FALSE) 143 | checkmate::check_double(E$results$std_error, any.missing = FALSE) 144 | 145 | checkmate::expect_tibble( 146 | E$results, 147 | all.missing = FALSE, 148 | nrows = n_rows, 149 | ncols = 6, 150 | types = c( 151 | estimand = "character", 152 | term = "character", 153 | value = "double", 154 | level = "character", 155 | estimate = "double", 156 | std_error = "double" 157 | ) 158 | ) 159 | }) 160 | -------------------------------------------------------------------------------- /tests/testthat/test-api-repeats.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | n0 <- 20 5 | data <- dplyr::tibble( 6 | uid = 1:(n + n0) 7 | ) %>% 8 | dplyr::mutate( 9 | a = rbinom(n + n0, 1, 0.5), 10 | ps = rep(0.5, n + n0), 11 | x1 = c(rnorm(n), rep(0, n0)), 12 | x2 = factor(sample(1:4, n + n0, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 13 | y = a + x1 - 2.5 * a * (x1 - mean(x1)) + as.double(x2) + c(rnorm(n), rep(rnorm(1), n0)) 14 | ) 15 | 16 | userid <- rlang::expr(uid) 17 | 18 | propensity_score_variable_name <- "ps" 19 | 20 | continuous_covariates <- c("x1") 21 | 22 | discrete_covariates <- c("x2") 23 | 24 | continuous_moderators <- rlang::exprs(x1) 25 
| discrete_moderators <- rlang::exprs(x2) 26 | moderators <- c(continuous_moderators, discrete_moderators) 27 | 28 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 29 | model_covariates <- rlang::syms(model_covariate_names) 30 | 31 | outcome_variable <- rlang::expr(y) 32 | treatment_variable <- rlang::expr(a) 33 | 34 | trt.cfg <- Constant_cfg$new() 35 | 36 | regression.cfg <- SLEnsemble_cfg$new( 37 | learner_cfgs = list( 38 | SLLearner_cfg$new( 39 | "SL.glm" 40 | ), 41 | SLLearner_cfg$new( 42 | "SL.gam", 43 | list( 44 | degree = c(2, 3, 5) 45 | ) 46 | ) 47 | ) 48 | ) 49 | 50 | qoi.list <- list() 51 | for (cov in continuous_moderators) { 52 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100, eval_min_quantile = 0.05) 53 | } 54 | for (cov in discrete_moderators) { 55 | qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 56 | } 57 | 58 | qoi.cfg <- QoI_cfg$new( 59 | mcate = MCATE_cfg$new(cfgs = qoi.list), 60 | vimp = VIMP_cfg$new(linear_only = TRUE), 61 | diag = Diagnostics_cfg$new( 62 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC") 63 | ) 64 | ) 65 | 66 | cfg <- HTE_cfg$new( 67 | treatment = trt.cfg, 68 | outcome = regression.cfg, 69 | effect = regression.cfg, 70 | qoi = qoi.cfg 71 | ) 72 | 73 | E <- new.env(parent = emptyenv()) 74 | 75 | test_that("add config", { 76 | E$data <- attach_config(data, cfg) 77 | checkmate::expect_data_frame(E$data) 78 | expect_true("HTE_cfg" %in% names(attributes(E$data))) 79 | }) 80 | 81 | test_that("Split data", { 82 | E$data2 <- make_splits(E$data, {{ userid }}, .num_splits = 4) 83 | checkmate::expect_data_frame(E$data2) 84 | }) 85 | 86 | test_that("Estimate Plugin Models", { 87 | E$data3 <- produce_plugin_estimates( 88 | E$data2, 89 | {{ outcome_variable }}, 90 | {{ treatment_variable }}, 91 | !!!model_covariates 92 | ) 93 | checkmate::expect_data_frame(E$data3) 94 | }) 95 | 96 | test_that("Construct Pseudo-outcomes", { 97 | E$data4 <- construct_pseudo_outcomes(E$data3, {{ 
outcome_variable }}, {{ treatment_variable }}) 98 | checkmate::expect_data_frame(E$data4) 99 | }) 100 | 101 | test_that("Estimate QoIs", { 102 | skip_on_cran() 103 | E$results <- estimate_QoI(E$data4, !!!moderators) 104 | checkmate::expect_data_frame(E$results) 105 | }) 106 | 107 | test_that("VIMP is valid", { 108 | skip_on_cran() 109 | vimp <- E$results %>% dplyr::filter(grepl("VIMP", estimand)) 110 | vimp_z <- vimp$estimate / vimp$std_error 111 | # expect small p-value for x1 which has actual HTE 112 | expect_lt(2 * pnorm(vimp_z[1], lower.tail = FALSE), 0.01) 113 | # expect large p-value for x2 which has no HTE 114 | expect_gt(2 * pnorm(vimp_z[2], lower.tail = FALSE), 0.1) 115 | }) 116 | 117 | n_rows <- ( 118 | 1 + # SATE estimate 119 | 2 + # MSE for y(0) & y(1) 120 | 2 * 4 + # one row per model in the ensemble for each PO + ps for SL risk 121 | 2 * 4 + # one row per model in the ensemble for each PO + ps for SL coefficient 122 | 2 + # one row per moderator for variable importance 123 | 1 * 100 + # 100 rows per continuous moderator for local regression for MCATE and for PCATE 124 | (2 + 2) + # 2 rows per discrete moderator level for MCATE and for PCATE 125 | n + n0 # 1 row per observation for RROC 126 | ) 127 | 128 | test_that("Check results data", { 129 | skip_on_cran() 130 | checkmate::check_character(E$results$estimand, any.missing = FALSE) 131 | checkmate::check_double(E$results$estimate, any.missing = FALSE) 132 | checkmate::check_double(E$results$std_error, any.missing = FALSE) 133 | 134 | checkmate::expect_tibble( 135 | E$results, 136 | all.missing = FALSE, 137 | nrows = n_rows, 138 | ncols = 6, 139 | types = c( 140 | estimand = "character", 141 | term = "character", 142 | value = "double", 143 | level = "character", 144 | estimate = "double", 145 | std_error = "double" 146 | ) 147 | ) 148 | }) 149 | -------------------------------------------------------------------------------- /joss/paper.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | title: 'tidyhte: Tidy Estimation of Heterogeneous Treatment Effects in R' 3 | tags: 4 | - R 5 | - causal inference 6 | - heterogeneous treatment effects 7 | - machine learning 8 | - experimental design 9 | authors: 10 | - name: Drew Dimmery 11 | orcid: 0000-0001-9602-6325 12 | affiliation: 1 13 | - name: Edward Kennedy 14 | orcid: 0000-0002-0227-158X 15 | affiliation: 2 16 | affiliations: 17 | - name: Hertie School Data Science Lab 18 | index: 1 19 | ror: 0473a4773 20 | - name: Department of Statistics and Data Science, Carnegie Mellon University, USA 21 | index: 2 22 | ror: 05x2bcf33 23 | date: 4 October 2025 24 | bibliography: paper.bib 25 | --- 26 | 27 | # Summary 28 | 29 | Heterogeneous treatment effects (HTE) describe how intervention impacts vary across individuals or subgroups. Understanding this variation is crucial for targeting interventions, optimizing resource allocation, and identifying mechanisms of action. `tidyhte` provides a principled framework for estimating heterogeneous treatment effects using modern machine learning methods. The package implements the doubly-robust learner of @kennedy2023towards, combining outcome modeling and propensity score estimation with cross-validation to produce valid statistical inference on treatment effect heterogeneity. 30 | 31 | The package uses a "recipe" design that scales naturally from single to multiple outcomes and moderators. Users specify machine learning algorithms for nuisance function estimation (treatment propensity and outcome models) and define moderators of interest; `tidyhte` handles cross-validation, model selection, diagnostics, and construction of quantities of interest. The tidy design integrates seamlessly with the broader R data science ecosystem and supports common empirical practices including weighting for population average effects and clustered treatment assignment. 
32 | 33 | # Statement of Need 34 | 35 | Existing tools for heterogeneous treatment effect estimation often require substantial statistical expertise and involve navigating complex decisions about cross-validation, model selection, doubly-robust estimation, and valid inference. Implementations are scattered across packages with inconsistent interfaces, making reliable application difficult. 36 | 37 | `tidyhte` addresses these challenges by providing a unified interface that: (1) implements state-of-the-art doubly-robust methods with automatic cross-validation, (2) uses intuitive "recipe" semantics familiar from packages like `recipes` [@recipes_package] and `parsnip` [@parsnip_package], (3) handles both experimental and observational data, (4) scales from single to multiple outcomes and moderators, (5) provides built-in diagnostics for model quality and effect heterogeneity, and (6) returns tidy data formats for downstream analysis. By automating technical details while maintaining statistical rigor, the package makes modern HTE methods accessible to applied researchers who need to understand treatment effect variation. 38 | 39 | # Research Applications 40 | 41 | `tidyhte` supports research across clinical trials (identifying patient subgroups), policy evaluation (understanding differential population impacts), technology (optimizing interventions for user segments), economics (studying policy effects across demographics), and education (evaluating differential intervention impacts). The package was stress-tested as the HTE estimation software for publications on Facebook and Instagram's effects on the 2020 US Presidential election, involving estimation across approximately ten moderators and dozens of outcomes [@guess2023social; @nyhan2023like; @guess2023reshares]. 42 | 43 | # Related Work 44 | 45 | The package builds on the theoretical framework of @kennedy2023towards for doubly-robust HTE estimation. 
It leverages the SuperLearner ensemble learning framework [@van2007super] for flexible machine learning. The tidy design philosophy follows @wickham2014tidy and integrates with the broader tidyverse ecosystem [@wickham2019welcome]. 46 | 47 | Other R packages for HTE estimation include `grf` [@grf_package; @grf] for generalized random forests, `bcf` [@bcf_package; @bcf] for Bayesian causal forests, and `FindIt` [@FindIt_package; @imai2013estimating]. `tidyhte` differentiates itself through its doubly-robust methodology [@kennedy2023towards], recipe interface, support for scaling to multiple outcomes and moderators, and native handling of weights and clustering. 48 | 49 | # Acknowledgements 50 | 51 | We gratefully acknowledge the collaboration with the US 2020 Facebook and Instagram Election Study team for being a testbed for the initial versions of `tidyhte`, particularly Pablo Barberá for testing this package during development and providing valuable feedback on the design and functionality. We received no financial support for this software project. 
52 | 53 | # References -------------------------------------------------------------------------------- /man/FX.Predictor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/joint-effect-estimation.R 3 | \name{FX.Predictor} 4 | \alias{FX.Predictor} 5 | \title{Predictor class for the cross-fit predictor of "partial" CATEs} 6 | \description{ 7 | Predictor class for the cross-fit predictor of "partial" CATEs 8 | 9 | Predictor class for the cross-fit predictor of "partial" CATEs 10 | } 11 | \details{ 12 | The class makes it easier to manage the K predictors for retrieving K-fold 13 | cross-validated estimates, as well as to measure how treatment effects change 14 | when only a single covariate is changed from its "natural" levels (in the sense 15 | "natural" used by the direct / indirect effects literature). 16 | } 17 | \note{ 18 | This is experimental functionality that hasn't been tested as thoroughly. 19 | The API may change in future versions. 20 | } 21 | \keyword{internal} 22 | \section{Public fields}{ 23 | \if{html}{\out{
}} 24 | \describe{ 25 | \item{\code{models}}{A list of the K model fits} 26 | 27 | \item{\code{num_splits}}{The number of folds used in cross-fitting.} 28 | 29 | \item{\code{num_mc_samples}}{The number of samples to retrieve across the covariate space. 30 | If num_mc_samples is larger than the sample size, then the entire dataset will be used.} 31 | 32 | \item{\code{covariates}}{The unquoted names of the covariates used in the second-stage model.} 33 | 34 | \item{\code{model_class}}{The model class (in the sense of \code{Model_cfg}). For instance, 35 | a SuperLearner model will have model class "SL".} 36 | } 37 | \if{html}{\out{
}} 38 | } 39 | \section{Methods}{ 40 | \subsection{Public methods}{ 41 | \itemize{ 42 | \item \href{#method-FX.Predictor-new}{\code{FX.Predictor$new()}} 43 | \item \href{#method-FX.Predictor-predict}{\code{FX.Predictor$predict()}} 44 | \item \href{#method-FX.Predictor-clone}{\code{FX.Predictor$clone()}} 45 | } 46 | } 47 | \if{html}{\out{
}} 48 | \if{html}{\out{}} 49 | \if{latex}{\out{\hypertarget{method-FX.Predictor-new}{}}} 50 | \subsection{Method \code{new()}}{ 51 | \code{FX.predictor} is a class which simplifies the management of a set of cross-fit 52 | prediction models of treatment effects and provides the ability to get the "partial" 53 | effects of particular covariates. 54 | \subsection{Usage}{ 55 | \if{html}{\out{
}}\preformatted{FX.Predictor$new(models, num_splits, num_mc_samples, covariates, model_class)}\if{html}{\out{
}} 56 | } 57 | 58 | \subsection{Arguments}{ 59 | \if{html}{\out{
}} 60 | \describe{ 61 | \item{\code{models}}{A list of the K model fits.} 62 | 63 | \item{\code{num_splits}}{Integer number of cross-fitting folds.} 64 | 65 | \item{\code{num_mc_samples}}{Integer number of Monte-Carlo samples across the covariate 66 | space. If this is larger than the sample size, then the whole dataset will be used.} 67 | 68 | \item{\code{covariates}}{The unquoted names of the covariates.} 69 | 70 | \item{\code{model_class}}{The model class (in the sense of \code{Model_cfg}).} 71 | } 72 | \if{html}{\out{
}} 73 | } 74 | } 75 | \if{html}{\out{
}} 76 | \if{html}{\out{}} 77 | \if{latex}{\out{\hypertarget{method-FX.Predictor-predict}{}}} 78 | \subsection{Method \code{predict()}}{ 79 | Predicts the PCATE surface over a particular covariate, returning a tibble with 80 | the predicted HTE for every Monte-Carlo sample. 81 | \subsection{Usage}{ 82 | \if{html}{\out{
}}\preformatted{FX.Predictor$predict(data, covariate)}\if{html}{\out{
}} 83 | } 84 | 85 | \subsection{Arguments}{ 86 | \if{html}{\out{
}} 87 | \describe{ 88 | \item{\code{data}}{The full dataset} 89 | 90 | \item{\code{covariate}}{The unquoted covariate name for which to calculate predicted 91 | treatment effects.} 92 | } 93 | \if{html}{\out{
}} 94 | } 95 | \subsection{Returns}{ 96 | A tibble with columns: 97 | \itemize{ 98 | \item \code{covariate_value} - The value of the covariate of interest 99 | \item \code{.hte} - An estimated HTE 100 | \item \code{.id} - The identifier for the original row (which had 101 | \code{covariate} modified to \code{covariate_value}). 102 | } 103 | } 104 | } 105 | \if{html}{\out{
}} 106 | \if{html}{\out{}} 107 | \if{latex}{\out{\hypertarget{method-FX.Predictor-clone}{}}} 108 | \subsection{Method \code{clone()}}{ 109 | The objects of this class are cloneable with this method. 110 | \subsection{Usage}{ 111 | \if{html}{\out{
}}\preformatted{FX.Predictor$clone(deep = FALSE)}\if{html}{\out{
}} 112 | } 113 | 114 | \subsection{Arguments}{ 115 | \if{html}{\out{
}} 116 | \describe{ 117 | \item{\code{deep}}{Whether to make a deep clone.} 118 | } 119 | \if{html}{\out{
}} 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /man/Model_data.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data-utils.R 3 | \name{Model_data} 4 | \alias{Model_data} 5 | \title{R6 class to represent data to be used in estimating a model} 6 | \description{ 7 | R6 class to represent data to be used in estimating a model 8 | 9 | R6 class to represent data to be used in estimating a model 10 | } 11 | \details{ 12 | This class provides consistent names and interfaces to data which will 13 | be used in a supervised regression / classification model. 14 | } 15 | \examples{ 16 | 17 | ## ------------------------------------------------ 18 | ## Method `Model_data$new` 19 | ## ------------------------------------------------ 20 | 21 | library("dplyr") 22 | df <- dplyr::tibble( 23 | uid = 1:100, 24 | x1 = rnorm(100), 25 | x2 = rnorm(100), 26 | x3 = sample(4, 100, replace = TRUE) 27 | ) \%>\% dplyr::mutate( 28 | y = x1 + x2 + x3 + rnorm(100), 29 | x3 = factor(x3) 30 | ) 31 | df <- make_splits(df, uid, .num_splits = 5) 32 | data <- Model_data$new(df, y, x1, x2, x3) 33 | } 34 | \seealso{ 35 | \link[SuperLearner:SuperLearner.CV.control]{SuperLearner::SuperLearner.CV.control} 36 | } 37 | \section{Public fields}{ 38 | \if{html}{\out{
}} 39 | \describe{ 40 | \item{\code{label}}{The labels for the eventual model as a vector.} 41 | 42 | \item{\code{features}}{The matrix representation of the data to be used for model fitting. 43 | Constructed using \code{stats::model.matrix}.} 44 | 45 | \item{\code{model_frame}}{The data-frame representation of the data as constructed by 46 | \code{stats::model.frame}.} 47 | 48 | \item{\code{split_id}}{The split identifiers as a vector.} 49 | 50 | \item{\code{num_splits}}{The integer number of splits in the data.} 51 | 52 | \item{\code{cluster}}{A cluster ID as a vector, constructed using the unit identifiers.} 53 | 54 | \item{\code{weights}}{The case-weights as a vector.} 55 | } 56 | \if{html}{\out{
}} 57 | } 58 | \section{Methods}{ 59 | \subsection{Public methods}{ 60 | \itemize{ 61 | \item \href{#method-Model_data-new}{\code{Model_data$new()}} 62 | \item \href{#method-Model_data-SL_cv_control}{\code{Model_data$SL_cv_control()}} 63 | \item \href{#method-Model_data-clone}{\code{Model_data$clone()}} 64 | } 65 | } 66 | \if{html}{\out{
}} 67 | \if{html}{\out{}} 68 | \if{latex}{\out{\hypertarget{method-Model_data-new}{}}} 69 | \subsection{Method \code{new()}}{ 70 | Creates an R6 object to represent data to be used in a prediction model. 71 | \subsection{Usage}{ 72 | \if{html}{\out{
}}\preformatted{Model_data$new(data, label_col, ..., .weight_col = NULL)}\if{html}{\out{
}} 73 | } 74 | 75 | \subsection{Arguments}{ 76 | \if{html}{\out{
}} 77 | \describe{ 78 | \item{\code{data}}{The full dataset to populate the class with.} 79 | 80 | \item{\code{label_col}}{The unquoted name of the column to use as the label in 81 | supervised learning models.} 82 | 83 | \item{\code{...}}{The unquoted names of features to use in the model.} 84 | 85 | \item{\code{.weight_col}}{The unquoted name of the column to use as case-weights 86 | in subsequent models.} 87 | } 88 | \if{html}{\out{
}} 89 | } 90 | \subsection{Returns}{ 91 | A \code{Model_data} object. 92 | } 93 | \subsection{Examples}{ 94 | \if{html}{\out{
}} 95 | \preformatted{library("dplyr") 96 | df <- dplyr::tibble( 97 | uid = 1:100, 98 | x1 = rnorm(100), 99 | x2 = rnorm(100), 100 | x3 = sample(4, 100, replace = TRUE) 101 | ) \%>\% dplyr::mutate( 102 | y = x1 + x2 + x3 + rnorm(100), 103 | x3 = factor(x3) 104 | ) 105 | df <- make_splits(df, uid, .num_splits = 5) 106 | data <- Model_data$new(df, y, x1, x2, x3) 107 | } 108 | \if{html}{\out{
}} 109 | 110 | } 111 | 112 | } 113 | \if{html}{\out{
}} 114 | \if{html}{\out{}} 115 | \if{latex}{\out{\hypertarget{method-Model_data-SL_cv_control}{}}} 116 | \subsection{Method \code{SL_cv_control()}}{ 117 | A helper function to create the cross-validation options to be used by SuperLearner. 118 | \subsection{Usage}{ 119 | \if{html}{\out{
}}\preformatted{Model_data$SL_cv_control()}\if{html}{\out{
}} 120 | } 121 | 122 | } 123 | \if{html}{\out{
}} 124 | \if{html}{\out{}} 125 | \if{latex}{\out{\hypertarget{method-Model_data-clone}{}}} 126 | \subsection{Method \code{clone()}}{ 127 | The objects of this class are cloneable with this method. 128 | \subsection{Usage}{ 129 | \if{html}{\out{
}}\preformatted{Model_data$clone(deep = FALSE)}\if{html}{\out{
}} 130 | } 131 | 132 | \subsection{Arguments}{ 133 | \if{html}{\out{
}} 134 | \describe{ 135 | \item{\code{deep}}{Whether to make a deep clone.} 136 | } 137 | \if{html}{\out{
}} 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://ddimmery.github.io/tidyhte/ 2 | template: 3 | params: 4 | ganalytics: G-R58XC167ZD 5 | bootswatch: yeti 6 | navbar: 7 | structure: 8 | left: [home, reference, articles, news] 9 | right: [github] 10 | components: 11 | home: ~ 12 | articles: 13 | text: Vignettes 14 | menu: 15 | - text: Methodological details 16 | href: articles/methodological_details.html 17 | - text: Example experimental analysis 18 | href: articles/experimental_analysis.html 19 | - text: Example observational analysis 20 | href: articles/observational_analysis.html 21 | reference: 22 | - title: Estimation API 23 | desc: > 24 | Tidy functions for performing an analysis of heterogeneous treatment effects. Once 25 | a configuration has been defined (like by the Recipe API), these functions are the 26 | workhorses that perform all estimation. 27 | - contents: 28 | - attach_config 29 | - make_splits 30 | - produce_plugin_estimates 31 | - construct_pseudo_outcomes 32 | - estimate_QoI 33 | - title: Recipe API 34 | desc: > 35 | Tidy functions for configuring a "recipe" for how to estimate heterogeneous treatment effects. 36 | This is the easiest way to get started with setting up a configuration for an HTE analysis. 37 | - contents: 38 | - basic_config 39 | - add_propensity_score_model 40 | - add_known_propensity_score 41 | - add_propensity_diagnostic 42 | - add_outcome_model 43 | - add_outcome_diagnostic 44 | - add_effect_model 45 | - add_effect_diagnostic 46 | - add_moderator 47 | - add_vimp 48 | - remove_vimp 49 | - title: Model Configuration 50 | desc: > 51 | Classes to define the configuration of models to be used in the eventual HTE analysis. 52 | These are the classes which define the underlying configurations in the Recipe API. 
They're 53 | most useful for advanced users who want the most granular control over their analysis, but 54 | most users will be best served by the Recipe API. 55 | - subtitle: Base Class 56 | - contents: 57 | - Model_cfg 58 | - subtitle: Models for displaying results 59 | desc: > 60 | These model configurations include valid standard errors and are well suited 61 | for providing usable output to be plotted or returned for formal inference. 62 | - contents: 63 | - Stratified_cfg 64 | - KernelSmooth_cfg 65 | - subtitle: Models for nuisance functions 66 | desc: > 67 | These models are most useful for estimating nuisance functions in the course of HTE 68 | estimation. In particular, `Known_cfg` may be used when the propensity score is 69 | known ex ante, while `SLEnsemble_cfg` may be used when a SuperLearner ensemble 70 | of machine learning models should be used to estimate an unknown nuisance function. 71 | - contents: 72 | - Known_cfg 73 | - Constant_cfg 74 | - SLEnsemble_cfg 75 | - SLLearner_cfg 76 | - title: Analysis Configuration 77 | desc: > 78 | These classes configure the overall shape of the HTE analysis: essentially, how 79 | all the various components and models should fit together. They explain what models 80 | should be estimated and how those models should be combined into relevant quantities 81 | of interest. These, too, underlie the Recipe API, and should rarely need to be used 82 | directly. 83 | - contents: 84 | - Diagnostics_cfg 85 | - HTE_cfg 86 | - QoI_cfg 87 | - VIMP_cfg 88 | - subtitle: Configuration of HTE estimands 89 | desc: > 90 | There are two configurations for HTE estimands. `MCATE_cfg` simply provides a one-dimensional 91 | slice of the effects over a particular dimension. It averages over all other moderators, so does not 92 | attempt to attribute heterogeneity to particular covariates. `PCATE_cfg`, however, does attempt to do 93 | this type of attribution. 
Estimation of these "partial" effects requires strong assumptions, even 94 | within the context of a randomized control trial. Essentially, one must assume that the moderator, too, 95 | is random conditional on treatment and other covariates. 96 | - contents: 97 | - MCATE_cfg 98 | - PCATE_cfg 99 | - title: Internal Functions 100 | desc: > 101 | The remaining functions are not really useful for end-users. Documentation is provided 102 | in order to provide additional details about the internal workings of the methods. 103 | - contents: 104 | - SL.glmnet.interaction 105 | - "predict.SL.glmnet.interaction" 106 | - calculate_ate 107 | - calculate_diagnostics 108 | - calculate_rroc 109 | - calculate_linear_vimp 110 | - calculate_vimp 111 | - split_data 112 | - Model_data 113 | - check_data_has_hte_cfg 114 | - check_identifier 115 | - check_nuisance_models 116 | - check_splits 117 | - check_weights 118 | - fit_plugin 119 | - fit_plugin_A 120 | - fit_plugin_Y 121 | - fit_effect 122 | - listwise_deletion 123 | -------------------------------------------------------------------------------- /man/Diagnostics_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/meta_cfg.R 3 | \name{Diagnostics_cfg} 4 | \alias{Diagnostics_cfg} 5 | \title{Configuration of Model Diagnostics} 6 | \description{ 7 | \code{Diagnostics_cfg} is a configuration class for estimating a variety of 8 | diagnostics for the models trained in the course of HTE estimation. 
9 | } 10 | \examples{ 11 | Diagnostics_cfg$new( 12 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 13 | ps = c("SL_risk", "SL_coefs", "AUC") 14 | ) 15 | 16 | ## ------------------------------------------------ 17 | ## Method `Diagnostics_cfg$new` 18 | ## ------------------------------------------------ 19 | 20 | Diagnostics_cfg$new( 21 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 22 | ps = c("SL_risk", "SL_coefs", "AUC") 23 | ) 24 | 25 | ## ------------------------------------------------ 26 | ## Method `Diagnostics_cfg$add` 27 | ## ------------------------------------------------ 28 | 29 | cfg <- Diagnostics_cfg$new( 30 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 31 | ps = c("SL_risk", "SL_coefs") 32 | ) 33 | cfg <- cfg$add(ps = "AUC") 34 | } 35 | \section{Public fields}{ 36 | \if{html}{\out{
}} 37 | \describe{ 38 | \item{\code{ps}}{Model diagnostics for the propensity score model.} 39 | 40 | \item{\code{outcome}}{Model diagnostics for the outcome models.} 41 | 42 | \item{\code{effect}}{Model diagnostics for the joint effect model.} 43 | 44 | \item{\code{params}}{Parameters for any requested diagnostics.} 45 | } 46 | \if{html}{\out{
}} 47 | } 48 | \section{Methods}{ 49 | \subsection{Public methods}{ 50 | \itemize{ 51 | \item \href{#method-Diagnostics_cfg-new}{\code{Diagnostics_cfg$new()}} 52 | \item \href{#method-Diagnostics_cfg-add}{\code{Diagnostics_cfg$add()}} 53 | \item \href{#method-Diagnostics_cfg-clone}{\code{Diagnostics_cfg$clone()}} 54 | } 55 | } 56 | \if{html}{\out{
}} 57 | \if{html}{\out{}} 58 | \if{latex}{\out{\hypertarget{method-Diagnostics_cfg-new}{}}} 59 | \subsection{Method \code{new()}}{ 60 | Create a new \code{Diagnostics_cfg} object with specified diagnostics to estimate. 61 | \subsection{Usage}{ 62 | \if{html}{\out{
}}\preformatted{Diagnostics_cfg$new(ps = NULL, outcome = NULL, effect = NULL, params = NULL)}\if{html}{\out{
}} 63 | } 64 | 65 | \subsection{Arguments}{ 66 | \if{html}{\out{
}} 67 | \describe{ 68 | \item{\code{ps}}{Model diagnostics for the propensity score model.} 69 | 70 | \item{\code{outcome}}{Model diagnostics for the outcome models.} 71 | 72 | \item{\code{effect}}{Model diagnostics for the joint effect model.} 73 | 74 | \item{\code{params}}{List providing values for parameters to any requested diagnostics.} 75 | } 76 | \if{html}{\out{
}} 77 | } 78 | \subsection{Returns}{ 79 | A new \code{Diagnostics_cfg} object. 80 | } 81 | \subsection{Examples}{ 82 | \if{html}{\out{
}} 83 | \preformatted{Diagnostics_cfg$new( 84 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 85 | ps = c("SL_risk", "SL_coefs", "AUC") 86 | ) 87 | } 88 | \if{html}{\out{
}} 89 | 90 | } 91 | 92 | } 93 | \if{html}{\out{
}} 94 | \if{html}{\out{}} 95 | \if{latex}{\out{\hypertarget{method-Diagnostics_cfg-add}{}}} 96 | \subsection{Method \code{add()}}{ 97 | Add diagnostics to the \code{Diagnostics_cfg} object. 98 | \subsection{Usage}{ 99 | \if{html}{\out{
}}\preformatted{Diagnostics_cfg$add(ps = NULL, outcome = NULL, effect = NULL)}\if{html}{\out{
}} 100 | } 101 | 102 | \subsection{Arguments}{ 103 | \if{html}{\out{
}} 104 | \describe{ 105 | \item{\code{ps}}{Model diagnostics for the propensity score model.} 106 | 107 | \item{\code{outcome}}{Model diagnostics for the outcome models.} 108 | 109 | \item{\code{effect}}{Model diagnostics for the joint effect model.} 110 | } 111 | \if{html}{\out{
}} 112 | } 113 | \subsection{Returns}{ 114 | An updated \code{Diagnostics_cfg} object. 115 | } 116 | \subsection{Examples}{ 117 | \if{html}{\out{
}} 118 | \preformatted{cfg <- Diagnostics_cfg$new( 119 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 120 | ps = c("SL_risk", "SL_coefs") 121 | ) 122 | cfg <- cfg$add(ps = "AUC") 123 | } 124 | \if{html}{\out{
}} 125 | 126 | } 127 | 128 | } 129 | \if{html}{\out{
}} 130 | \if{html}{\out{}} 131 | \if{latex}{\out{\hypertarget{method-Diagnostics_cfg-clone}{}}} 132 | \subsection{Method \code{clone()}}{ 133 | The objects of this class are cloneable with this method. 134 | \subsection{Usage}{ 135 | \if{html}{\out{
}}\preformatted{Diagnostics_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 136 | } 137 | 138 | \subsection{Arguments}{ 139 | \if{html}{\out{
}} 140 | \describe{ 141 | \item{\code{deep}}{Whether to make a deep clone.} 142 | } 143 | \if{html}{\out{
}} 144 | } 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /man/MCATE_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/meta_cfg.R 3 | \name{MCATE_cfg} 4 | \alias{MCATE_cfg} 5 | \title{Configuration of Marginal CATEs} 6 | \description{ 7 | \code{MCATE_cfg} is a configuration class for estimating marginal response 8 | surfaces based on heterogeneous treatment effect estimates. "Marginal" 9 | in this context implies that all other covariates are marginalized. 10 | Thus, if two covariates are highly correlated, it is likely that their 11 | MCATE surfaces will be extremely similar. 12 | } 13 | \examples{ 14 | MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 15 | 16 | ## ------------------------------------------------ 17 | ## Method `MCATE_cfg$new` 18 | ## ------------------------------------------------ 19 | 20 | MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 21 | 22 | ## ------------------------------------------------ 23 | ## Method `MCATE_cfg$add_moderator` 24 | ## ------------------------------------------------ 25 | 26 | cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 27 | cfg <- cfg$add_moderator("x2", KernelSmooth_cfg$new(neval = 100)) 28 | } 29 | \section{Public fields}{ 30 | \if{html}{\out{
}} 31 | \describe{ 32 | \item{\code{cfgs}}{Named list of covariates names to a \code{Model_cfg} object defining 33 | how to present that covariate's CATE surface (while marginalizing 34 | over all other covariates).} 35 | 36 | \item{\code{std_errors}}{Boolean indicating whether the results should be 37 | returned with standard errors or not.} 38 | 39 | \item{\code{estimand}}{String indicating the estimand to target.} 40 | } 41 | \if{html}{\out{
}} 42 | } 43 | \section{Methods}{ 44 | \subsection{Public methods}{ 45 | \itemize{ 46 | \item \href{#method-MCATE_cfg-new}{\code{MCATE_cfg$new()}} 47 | \item \href{#method-MCATE_cfg-add_moderator}{\code{MCATE_cfg$add_moderator()}} 48 | \item \href{#method-MCATE_cfg-clone}{\code{MCATE_cfg$clone()}} 49 | } 50 | } 51 | \if{html}{\out{
}} 52 | \if{html}{\out{}} 53 | \if{latex}{\out{\hypertarget{method-MCATE_cfg-new}{}}} 54 | \subsection{Method \code{new()}}{ 55 | Create a new \code{MCATE_cfg} object with specified model name and hyperparameters. 56 | \subsection{Usage}{ 57 | \if{html}{\out{
}}\preformatted{MCATE_cfg$new(cfgs, std_errors = TRUE)}\if{html}{\out{
}} 58 | } 59 | 60 | \subsection{Arguments}{ 61 | \if{html}{\out{
}} 62 | \describe{ 63 | \item{\code{cfgs}}{Named list from moderator name to a \code{Model_cfg} object 64 | defining how to present that covariate's CATE surface (while 65 | marginalizing over all other covariates)} 66 | 67 | \item{\code{std_errors}}{Boolean indicating whether the results should be returned with standard 68 | errors or not.} 69 | } 70 | \if{html}{\out{
}} 71 | } 72 | \subsection{Returns}{ 73 | A new \code{MCATE_cfg} object. 74 | } 75 | \subsection{Examples}{ 76 | \if{html}{\out{
}} 77 | \preformatted{MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 78 | } 79 | \if{html}{\out{
}} 80 | 81 | } 82 | 83 | } 84 | \if{html}{\out{
}} 85 | \if{html}{\out{}} 86 | \if{latex}{\out{\hypertarget{method-MCATE_cfg-add_moderator}{}}} 87 | \subsection{Method \code{add_moderator()}}{ 88 | Add a moderator to the \code{MCATE_cfg} object. This entails defining a configuration 89 | for displaying the effect surface for that moderator. 90 | \subsection{Usage}{ 91 | \if{html}{\out{
}}\preformatted{MCATE_cfg$add_moderator(var_name, cfg)}\if{html}{\out{
}} 92 | } 93 | 94 | \subsection{Arguments}{ 95 | \if{html}{\out{
}} 96 | \describe{ 97 | \item{\code{var_name}}{The name of the moderator to add (and the name of the column in 98 | the dataset).} 99 | 100 | \item{\code{cfg}}{A \code{Model_cfg} defining how to display the selected moderator's effect 101 | surface.} 102 | } 103 | \if{html}{\out{
}} 104 | } 105 | \subsection{Returns}{ 106 | An updated \code{MCATE_cfg} object. 107 | } 108 | \subsection{Examples}{ 109 | \if{html}{\out{
}} 110 | \preformatted{cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 111 | cfg <- cfg$add_moderator("x2", KernelSmooth_cfg$new(neval = 100)) 112 | } 113 | \if{html}{\out{
}} 114 | 115 | } 116 | 117 | } 118 | \if{html}{\out{
}} 119 | \if{html}{\out{}} 120 | \if{latex}{\out{\hypertarget{method-MCATE_cfg-clone}{}}} 121 | \subsection{Method \code{clone()}}{ 122 | The objects of this class are cloneable with this method. 123 | \subsection{Usage}{ 124 | \if{html}{\out{
}}\preformatted{MCATE_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 125 | } 126 | 127 | \subsection{Arguments}{ 128 | \if{html}{\out{
}} 129 | \describe{ 130 | \item{\code{deep}}{Whether to make a deep clone.} 131 | } 132 | \if{html}{\out{
}} 133 | } 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /man/SLEnsemble_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/model_cfg.R 3 | \name{SLEnsemble_cfg} 4 | \alias{SLEnsemble_cfg} 5 | \title{Configuration for a SuperLearner Ensemble} 6 | \description{ 7 | \code{SLEnsemble_cfg} is a configuration class for estimation of a model 8 | using an ensemble of models using \code{SuperLearner}. 9 | } 10 | \examples{ 11 | SLEnsemble_cfg$new( 12 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 13 | ) 14 | 15 | ## ------------------------------------------------ 16 | ## Method `SLEnsemble_cfg$new` 17 | ## ------------------------------------------------ 18 | 19 | SLEnsemble_cfg$new( 20 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 21 | ) 22 | } 23 | \section{Super class}{ 24 | \code{\link[tidyhte:Model_cfg]{tidyhte::Model_cfg}} -> \code{SLEnsemble_cfg} 25 | } 26 | \section{Public fields}{ 27 | \if{html}{\out{
}} 28 | \describe{ 29 | \item{\code{cvControl}}{A list of parameters for controlling the 30 | cross-validation used in SuperLearner.} 31 | 32 | \item{\code{SL.library}}{A vector of the names of learners to 33 | include in the SuperLearner ensemble.} 34 | 35 | \item{\code{SL.env}}{An environment containing all of the programmatically 36 | generated learners to be included 37 | in the SuperLearner ensemble.} 38 | 39 | \item{\code{family}}{\code{stats::family} object to determine how SuperLearner 40 | should be fitted.} 41 | 42 | \item{\code{model_class}}{The class of the model, required for all classes 43 | which inherit from \code{Model_cfg}.} 44 | } 45 | \if{html}{\out{
}} 46 | } 47 | \section{Methods}{ 48 | \subsection{Public methods}{ 49 | \itemize{ 50 | \item \href{#method-SLEnsemble_cfg-new}{\code{SLEnsemble_cfg$new()}} 51 | \item \href{#method-SLEnsemble_cfg-add_sublearner}{\code{SLEnsemble_cfg$add_sublearner()}} 52 | \item \href{#method-SLEnsemble_cfg-clone}{\code{SLEnsemble_cfg$clone()}} 53 | } 54 | } 55 | \if{html}{\out{
}} 56 | \if{html}{\out{}} 57 | \if{latex}{\out{\hypertarget{method-SLEnsemble_cfg-new}{}}} 58 | \subsection{Method \code{new()}}{ 59 | Create a new \code{SLEnsemble_cfg} object with specified settings. 60 | \subsection{Usage}{ 61 | \if{html}{\out{
}}\preformatted{SLEnsemble_cfg$new( 62 | cvControl = NULL, 63 | learner_cfgs = NULL, 64 | family = stats::gaussian() 65 | )}\if{html}{\out{
}} 66 | } 67 | 68 | \subsection{Arguments}{ 69 | \if{html}{\out{
}} 70 | \describe{ 71 | \item{\code{cvControl}}{A list of parameters for controlling the 72 | cross-validation used in SuperLearner. 73 | For more details, see \code{SuperLearner::SuperLearner.CV.control}.} 74 | 75 | \item{\code{learner_cfgs}}{A list of \code{SLLearner_cfg} objects.} 76 | 77 | \item{\code{family}}{\code{stats::family} object to determine how SuperLearner should be fitted.} 78 | } 79 | \if{html}{\out{
}} 80 | } 81 | \subsection{Returns}{ 82 | A new \code{SLEnsemble_cfg} object. 83 | } 84 | \subsection{Examples}{ 85 | \if{html}{\out{
}} 86 | \preformatted{SLEnsemble_cfg$new( 87 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 88 | ) 89 | } 90 | \if{html}{\out{
}} 91 | 92 | } 93 | 94 | } 95 | \if{html}{\out{
}} 96 | \if{html}{\out{}} 97 | \if{latex}{\out{\hypertarget{method-SLEnsemble_cfg-add_sublearner}{}}} 98 | \subsection{Method \code{add_sublearner()}}{ 99 | Adds a model (or class of models) to the SuperLearner ensemble. 100 | If hyperparameter values are specified, this method will 101 | add a learner for every element in the cross-product of provided 102 | hyperparameter values. 103 | \subsection{Usage}{ 104 | \if{html}{\out{
}}\preformatted{SLEnsemble_cfg$add_sublearner(learner_name, hps = NULL)}\if{html}{\out{
}} 105 | } 106 | 107 | \subsection{Arguments}{ 108 | \if{html}{\out{
}} 109 | \describe{ 110 | \item{\code{learner_name}}{Possible values 111 | use \code{SuperLearner} naming conventions. A full list is available 112 | with \code{SuperLearner::listWrappers("SL")}} 113 | 114 | \item{\code{hps}}{A named list of hyper-parameters. Every element of the 115 | cross-product of these hyper-parameters will be included in the 116 | ensemble. 117 | cfg <- SLEnsemble_cfg$new( 118 | learner_cfgs = list(SLLearner_cfg$new("SL.glm")) 119 | ) 120 | cfg <- cfg$add_sublearner("SL.gam", list(deg.gam = c(2, 3)))} 121 | } 122 | \if{html}{\out{
}} 123 | } 124 | } 125 | \if{html}{\out{
}} 126 | \if{html}{\out{}} 127 | \if{latex}{\out{\hypertarget{method-SLEnsemble_cfg-clone}{}}} 128 | \subsection{Method \code{clone()}}{ 129 | The objects of this class are cloneable with this method. 130 | \subsection{Usage}{ 131 | \if{html}{\out{
}}\preformatted{SLEnsemble_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 132 | } 133 | 134 | \subsection{Arguments}{ 135 | \if{html}{\out{
}} 136 | \describe{ 137 | \item{\code{deep}}{Whether to make a deep clone.} 138 | } 139 | \if{html}{\out{
}} 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /tests/testthat/test-api-weights.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | data <- dplyr::tibble( 5 | uid = 1:n 6 | ) %>% 7 | dplyr::mutate( 8 | a = rbinom(n, 1, 0.5), 9 | ps = rep(0.5, n), 10 | x1 = rnorm(n), 11 | x2 = factor(sample(1:4, n, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 12 | x3 = factor(sample(1:3, n, prob = c(1 / 5, 1 / 5, 3 / 5), replace = TRUE)), 13 | x4 = (x1 + rnorm(n)) / 2, 14 | x5 = rnorm(n), 15 | y = a + x1 - 2.5 * a * (x1 - mean(x1)) + as.double(x2) + rnorm(n), 16 | w = rexp(n, plogis(x1 - mean(x1))) 17 | ) 18 | 19 | userid <- rlang::expr(uid) 20 | weight_variable <- rlang::expr(w) 21 | 22 | propensity_score_variable_name <- "ps" 23 | 24 | continuous_covariates <- c("x1") 25 | 26 | discrete_covariates <- c("x2", "x3") 27 | 28 | continuous_moderators <- rlang::exprs(x1) 29 | discrete_moderators <- rlang::exprs(x2, x3) 30 | 31 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 32 | model_covariates <- rlang::syms(model_covariate_names) 33 | 34 | outcome_variable <- rlang::expr(y) 35 | treatment_variable <- rlang::expr(a) 36 | 37 | trt.cfg <- SLEnsemble_cfg$new( 38 | learner_cfgs = list( 39 | SLLearner_cfg$new( 40 | "SL.glm" 41 | ) 42 | ), 43 | family = stats::quasibinomial() 44 | ) 45 | 46 | regression.cfg <- SLEnsemble_cfg$new( 47 | learner_cfgs = list( 48 | SLLearner_cfg$new( 49 | "SL.glm" 50 | ) 51 | ) 52 | ) 53 | 54 | qoi.list <- list() 55 | for (cov in continuous_moderators) { 56 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100) 57 | } 58 | for (cov in discrete_moderators) { 59 | qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 60 | } 61 | 62 | qoi.cfg <- QoI_cfg$new( 63 | mcate = MCATE_cfg$new(cfgs = qoi.list), 64 | vimp = VIMP_cfg$new(), 
65 | diag = Diagnostics_cfg$new( 66 | outcome = c("SL_risk", "SL_coefs", "MSE"), 67 | ps = c("SL_risk", "SL_coefs", "AUC") 68 | ) 69 | ) 70 | 71 | cfg <- HTE_cfg$new( 72 | treatment = trt.cfg, 73 | outcome = regression.cfg, 74 | effect = regression.cfg, 75 | qoi = qoi.cfg 76 | ) 77 | 78 | E <- new.env(parent = emptyenv()) 79 | 80 | test_that("add config", { 81 | E$data1 <- attach_config(data, cfg) 82 | checkmate::expect_data_frame(E$data1) 83 | expect_true("HTE_cfg" %in% names(attributes(E$data1))) 84 | }) 85 | 86 | test_that("Split data", { 87 | E$data2 <- make_splits(E$data1, {{ userid }}, .num_splits = 4) 88 | checkmate::expect_data_frame(E$data2) 89 | }) 90 | 91 | test_that("Estimate Plugin Models", { 92 | E$data3 <- produce_plugin_estimates( 93 | E$data2, 94 | {{ outcome_variable }}, 95 | {{ treatment_variable }}, 96 | !!!model_covariates, 97 | .weights = {{ weight_variable }} 98 | ) 99 | checkmate::expect_data_frame(E$data3) 100 | }) 101 | 102 | test_that("Construct Pseudo-outcomes", { 103 | E$data4 <- construct_pseudo_outcomes(E$data3, {{ outcome_variable }}, {{ treatment_variable }}) 104 | checkmate::expect_data_frame(E$data4) 105 | }) 106 | 107 | test_that("Estimate QoIs (continuous)", { 108 | expect_error( 109 | estimate_QoI(E$data4, !!!continuous_moderators), 110 | "`nprobust` does not support the use of weights." 
111 | ) 112 | }) 113 | 114 | test_that("Estimate QoIs (discrete)", { 115 | E$results <- estimate_QoI(E$data4, !!!discrete_moderators) 116 | checkmate::expect_data_frame(E$results) 117 | }) 118 | 119 | n_rows <- ( 120 | 1 + # SATE estimate 121 | 1 + # PATE estimate 122 | 2 + # MSE for y(0) & y(1) 123 | 1 + # AUC for pscore 124 | 2 * 1 + 1 + # one row per model in the ensemble for each PO + ps for SL risk 125 | 2 * 1 + 1 + # one row per model in the ensemble for each PO + ps for SL coefficient 126 | 2 + # one row per moderator for variable importance 127 | # 1 * 100 + # 100 rows per continuous moderator for local regression for MCATE 128 | (4 + 3) # 1 rows per discrete moderator level for MCATE 129 | ) 130 | 131 | test_that("PATE > SATE", { 132 | # PATE has larger treatment effects (because weight is correlated with moderator) 133 | pate <- E$results %>% dplyr::filter(grepl("PATE", estimand)) %>% select(estimate) %>% unlist() 134 | sate <- E$results %>% dplyr::filter(grepl("SATE", estimand)) %>% select(estimate) %>% unlist() 135 | expect_gt(pate, sate) 136 | 137 | # Weighting makes the standard error larger 138 | pate <- E$results %>% dplyr::filter(grepl("PATE", estimand)) %>% select(std_error) %>% unlist() 139 | sate <- E$results %>% dplyr::filter(grepl("SATE", estimand)) %>% select(std_error) %>% unlist() 140 | expect_gt(pate, sate) 141 | }) 142 | 143 | test_that("Check results data", { 144 | checkmate::check_character(E$results$estimand, any.missing = FALSE) 145 | checkmate::check_double(E$results$estimate, any.missing = FALSE) 146 | checkmate::check_double(E$results$std_error, any.missing = FALSE) 147 | 148 | checkmate::expect_tibble( 149 | E$results, 150 | all.missing = FALSE, 151 | nrows = n_rows, 152 | ncols = 5, # only 5 columns because no continuous moderators 153 | types = c( 154 | estimand = "character", 155 | term = "character", 156 | value = "numeric", 157 | level = "character", 158 | estimate = "double", 159 | std_error = "double" 160 | ) 161 | ) 162 | 
}) 163 | -------------------------------------------------------------------------------- /man/QoI_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/meta_cfg.R 3 | \name{QoI_cfg} 4 | \alias{QoI_cfg} 5 | \title{Configuration of Quantities of Interest} 6 | \description{ 7 | \code{QoI_cfg} is a configuration class for the Quantities of Interest to be 8 | generated by the HTE analysis. 9 | } 10 | \examples{ 11 | mcate_cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 12 | pcate_cfg <- PCATE_cfg$new( 13 | cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100)), 14 | model_covariates = c("x1", "x2", "x3"), 15 | num_mc_samples = list(x1 = 100) 16 | ) 17 | vimp_cfg <- VIMP_cfg$new() 18 | diag_cfg <- Diagnostics_cfg$new( 19 | outcome = c("SL_risk", "SL_coefs", "MSE"), 20 | ps = c("SL_risk", "SL_coefs", "AUC") 21 | ) 22 | QoI_cfg$new( 23 | mcate = mcate_cfg, 24 | pcate = pcate_cfg, 25 | vimp = vimp_cfg, 26 | diag = diag_cfg 27 | ) 28 | 29 | ## ------------------------------------------------ 30 | ## Method `QoI_cfg$new` 31 | ## ------------------------------------------------ 32 | 33 | mcate_cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 34 | pcate_cfg <- PCATE_cfg$new( 35 | cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100)), 36 | model_covariates = c("x1", "x2", "x3"), 37 | num_mc_samples = list(x1 = 100) 38 | ) 39 | vimp_cfg <- VIMP_cfg$new() 40 | diag_cfg <- Diagnostics_cfg$new( 41 | outcome = c("SL_risk", "SL_coefs", "MSE"), 42 | ps = c("SL_risk", "SL_coefs", "AUC") 43 | ) 44 | QoI_cfg$new( 45 | mcate = mcate_cfg, 46 | pcate = pcate_cfg, 47 | vimp = vimp_cfg, 48 | diag = diag_cfg 49 | ) 50 | } 51 | \section{Public fields}{ 52 | \if{html}{\out{
}} 53 | \describe{ 54 | \item{\code{mcate}}{A configuration object of type \code{MCATE_cfg} of 55 | marginal effects to calculate.} 56 | 57 | \item{\code{pcate}}{A configuration object of type \code{PCATE_cfg} of 58 | partial effects to calculate.} 59 | 60 | \item{\code{vimp}}{A configuration object of type \code{VIMP_cfg} of 61 | variable importance to calculate.} 62 | 63 | \item{\code{diag}}{A configuration object of type \code{Diagnostics_cfg} of 64 | model diagnostics to calculate.} 65 | 66 | \item{\code{ate}}{Logical flag indicating whether an estimate of the 67 | ATE should be returned.} 68 | 69 | \item{\code{predictions}}{Logical flag indicating whether estimates of 70 | the CATE for every unit should be returned.} 71 | } 72 | \if{html}{\out{
}} 73 | } 74 | \section{Methods}{ 75 | \subsection{Public methods}{ 76 | \itemize{ 77 | \item \href{#method-QoI_cfg-new}{\code{QoI_cfg$new()}} 78 | \item \href{#method-QoI_cfg-clone}{\code{QoI_cfg$clone()}} 79 | } 80 | } 81 | \if{html}{\out{
}} 82 | \if{html}{\out{}} 83 | \if{latex}{\out{\hypertarget{method-QoI_cfg-new}{}}} 84 | \subsection{Method \code{new()}}{ 85 | Create a new \code{QoI_cfg} object with specified Quantities of Interest 86 | to estimate. 87 | \subsection{Usage}{ 88 | \if{html}{\out{
}}\preformatted{QoI_cfg$new( 89 | mcate = NULL, 90 | pcate = NULL, 91 | vimp = NULL, 92 | diag = NULL, 93 | ate = TRUE, 94 | predictions = FALSE 95 | )}\if{html}{\out{
}} 96 | } 97 | 98 | \subsection{Arguments}{ 99 | \if{html}{\out{
}} 100 | \describe{ 101 | \item{\code{mcate}}{A configuration object of type \code{MCATE_cfg} of marginal 102 | effects to calculate.} 103 | 104 | \item{\code{pcate}}{A configuration object of type \code{PCATE_cfg} of partial 105 | effects to calculate.} 106 | 107 | \item{\code{vimp}}{A configuration object of type \code{VIMP_cfg} of variable 108 | importance to calculate.} 109 | 110 | \item{\code{diag}}{A configuration object of type \code{Diagnostics_cfg} of 111 | model diagnostics to calculate.} 112 | 113 | \item{\code{ate}}{A logical flag for whether to calculate the Average 114 | Treatment Effect (ATE) or not.} 115 | 116 | \item{\code{predictions}}{A logical flag for whether to return predictions 117 | of the CATE for every unit or not.} 118 | } 119 | \if{html}{\out{
}} 120 | } 121 | \subsection{Returns}{ 122 | A new \code{QoI_cfg} object. 123 | } 124 | \subsection{Examples}{ 125 | \if{html}{\out{
}} 126 | \preformatted{mcate_cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 127 | pcate_cfg <- PCATE_cfg$new( 128 | cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100)), 129 | model_covariates = c("x1", "x2", "x3"), 130 | num_mc_samples = list(x1 = 100) 131 | ) 132 | vimp_cfg <- VIMP_cfg$new() 133 | diag_cfg <- Diagnostics_cfg$new( 134 | outcome = c("SL_risk", "SL_coefs", "MSE"), 135 | ps = c("SL_risk", "SL_coefs", "AUC") 136 | ) 137 | QoI_cfg$new( 138 | mcate = mcate_cfg, 139 | pcate = pcate_cfg, 140 | vimp = vimp_cfg, 141 | diag = diag_cfg 142 | ) 143 | } 144 | \if{html}{\out{
}} 145 | 146 | } 147 | 148 | } 149 | \if{html}{\out{
}} 150 | \if{html}{\out{}} 151 | \if{latex}{\out{\hypertarget{method-QoI_cfg-clone}{}}} 152 | \subsection{Method \code{clone()}}{ 153 | The objects of this class are cloneable with this method. 154 | \subsection{Usage}{ 155 | \if{html}{\out{
}}\preformatted{QoI_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 156 | } 157 | 158 | \subsection{Arguments}{ 159 | \if{html}{\out{
}} 160 | \describe{ 161 | \item{\code{deep}}{Whether to make a deep clone.} 162 | } 163 | \if{html}{\out{
}} 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /tests/testthat/test-api-pcate.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | data <- dplyr::tibble( 5 | uid = 1:n 6 | ) %>% 7 | dplyr::mutate( 8 | a = rbinom(n, 1, 0.5), 9 | ps = rep(0.5, n), 10 | x1 = rnorm(n), 11 | x2 = factor(sample(1:4, n, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 12 | x3 = factor(sample(1:3, n, prob = c(1 / 5, 1 / 5, 3 / 5), replace = TRUE)), 13 | x4 = (x1 + rnorm(n)) / 2, 14 | x5 = rnorm(n), 15 | y = a + x1 - 0.5 * a * (x1 - mean(x1)) + as.double(x2) + rnorm(n) 16 | ) 17 | 18 | userid <- rlang::expr(uid) 19 | 20 | propensity_score_variable_name <- "ps" 21 | 22 | continuous_covariates <- c("x1") 23 | 24 | discrete_covariates <- c("x2", "x3") 25 | 26 | continuous_moderators <- rlang::exprs(x1, x4, x5) 27 | discrete_moderators <- rlang::exprs(x2, x3) 28 | moderators <- c(continuous_moderators, discrete_moderators) 29 | 30 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 31 | model_covariates <- rlang::syms(model_covariate_names) 32 | 33 | outcome_variable <- rlang::expr(y) 34 | treatment_variable <- rlang::expr(a) 35 | 36 | trt.cfg <- Known_cfg$new(propensity_score_variable_name) 37 | 38 | regression.cfg <- SLEnsemble_cfg$new( 39 | learner_cfgs = list( 40 | SLLearner_cfg$new( 41 | "SL.glm" 42 | ), 43 | SLLearner_cfg$new( 44 | "SL.glmnet", 45 | list( 46 | alpha = c(0.05, 0.15) 47 | ) 48 | ) 49 | ) 50 | ) 51 | 52 | qoi.list <- list() 53 | for (cov in continuous_moderators) { 54 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100, eval_min_quantile = 0.05) 55 | } 56 | for (cov in discrete_moderators) { 57 | qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 58 | } 59 | 60 | qoi.cfg <- QoI_cfg$new( 61 | mcate = MCATE_cfg$new(cfgs = qoi.list), 62 | pcate = 
PCATE_cfg$new( 63 | cfgs = qoi.list, 64 | model_covariates = model_covariate_names, 65 | num_mc_samples = list(x1 = 5, x2 = 10, x3 = 10, x4 = 5, x5 = 5) 66 | ), 67 | diag = Diagnostics_cfg$new( 68 | outcome = c("SL_risk", "SL_coefs", "MSE"), 69 | effect = c("SL_risk", "SL_coefs") 70 | ), 71 | predictions = TRUE 72 | ) 73 | 74 | qoi.cfg2 <- QoI_cfg$new( 75 | mcate = MCATE_cfg$new(cfgs = qoi.list, std_errors = FALSE), 76 | pcate = PCATE_cfg$new( 77 | cfgs = qoi.list, 78 | model_covariates = model_covariate_names, 79 | num_mc_samples = list(x1 = 5, x2 = 10, x3 = 10, x4 = 5, x5 = 5) 80 | ), 81 | diag = Diagnostics_cfg$new( 82 | outcome = c("SL_risk", "SL_coefs", "MSE"), 83 | effect = c("SL_risk", "SL_coefs") 84 | ) 85 | ) 86 | 87 | cfg <- HTE_cfg$new( 88 | treatment = trt.cfg, 89 | outcome = regression.cfg, 90 | effect = regression.cfg, 91 | qoi = qoi.cfg 92 | ) 93 | 94 | cfg2 <- HTE_cfg$new( 95 | treatment = trt.cfg, 96 | outcome = regression.cfg, 97 | effect = regression.cfg, 98 | qoi = qoi.cfg2 99 | ) 100 | 101 | E <- new.env(parent = emptyenv()) 102 | 103 | test_that("add config", { 104 | E$data <- attach_config(data, cfg) 105 | checkmate::expect_data_frame(E$data) 106 | expect_true("HTE_cfg" %in% names(attributes(E$data))) 107 | }) 108 | 109 | test_that("Split data", { 110 | E$data2 <- make_splits(E$data, {{ userid }}, .num_splits = 3) 111 | checkmate::expect_data_frame(E$data2) 112 | }) 113 | 114 | test_that("Estimate Plugin Models", { 115 | E$data3 <- produce_plugin_estimates( 116 | E$data2, 117 | {{ outcome_variable }}, 118 | {{ treatment_variable }}, 119 | !!!model_covariates 120 | ) 121 | checkmate::expect_data_frame(E$data3) 122 | }) 123 | 124 | test_that("Construct Pseudo-outcomes", { 125 | E$data4 <- construct_pseudo_outcomes(E$data3, {{ outcome_variable }}, {{ treatment_variable }}) 126 | checkmate::expect_data_frame(E$data4) 127 | }) 128 | 129 | test_that("Estimate QoIs", { 130 | skip_on_cran() 131 | expect_warning( 132 | E$results <- 
estimate_QoI(E$data4, !!!moderators), 133 | "Only use PCATEs if you know what you're doing!" 134 | ) 135 | checkmate::expect_data_frame(E$results) 136 | E$data4 <- attach_config(E$data4, cfg2) 137 | expect_warning( 138 | E$results2 <- estimate_QoI(E$data4, !!!moderators), 139 | "Only use PCATEs if you know what you're doing!" 140 | ) 141 | checkmate::expect_data_frame(E$results2) 142 | }) 143 | 144 | n_rows <- ( 145 | 1 + # SATE estimate 146 | 2 + # MSE for y(0) & y(1) 147 | 3 * 3 + # one row per model in the ensemble for each PO / fx for SL risk 148 | 3 * 3 + # one row per model in the ensemble for each PO / fx for SL coefficient 149 | 2 * 3 * 100 + # 100 rows per continuous moderator for local regression for MCATE and for PCATE 150 | 2 * (4 + 3) + # 2 rows per discrete moderator level for MCATE and for PCATE 151 | n # One row for each predicted value 152 | ) 153 | 154 | 155 | test_that("Check results data", { 156 | skip_on_cran() 157 | checkmate::check_character(E$results$estimand, any.missing = FALSE) 158 | checkmate::check_double(E$results$estimate, any.missing = FALSE) 159 | checkmate::check_double(E$results$std_error, any.missing = FALSE) 160 | 161 | checkmate::expect_tibble( 162 | E$results, 163 | all.missing = FALSE, 164 | nrows = n_rows, 165 | ncols = 6, 166 | types = c( 167 | estimand = "character", 168 | term = "character", 169 | value = "double", 170 | level = "character", 171 | estimate = "double", 172 | std_error = "double" 173 | ) 174 | ) 175 | }) 176 | -------------------------------------------------------------------------------- /tests/testthat/test-splitting.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 100 4 | d <- dplyr::tibble( 5 | uid = 1:n, 6 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 7 | cov2 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE) 8 | ) 9 | 10 | test_that("make_splits output", { 11 | 
checkmate::expect_data_frame( 12 | make_splits(d, uid, .num_splits = 2), 13 | any.missing = FALSE, 14 | nrows = 100, 15 | ncols = 4, 16 | col.names = "named" 17 | ) 18 | expect_true(".split_id" %in% names(make_splits(d, uid, .num_splits = 2))) 19 | }) 20 | 21 | s_df <- make_splits(d, uid, .num_splits = 2) 22 | s_id <- s_df$.split_id 23 | test_that("2 splits are ok", { 24 | expect_length(unique(s_id), 2) 25 | checkmate::expect_integerish(s_id, len = n, any.missing = FALSE, lower = 1, upper = 2) 26 | expect_true(all(table(s_id) == (n / 2))) 27 | }) 28 | 29 | test_that("resplitting works", { 30 | expect_error(s_df2 <- make_splits(s_df, uid, .num_splits = 4), NA) 31 | checkmate::expect_data_frame(s_df2, min.rows = n, max.rows = n) 32 | expect_true(".split_id" %in% names(s_df2)) 33 | }) 34 | 35 | test_that("check_splits", { 36 | expect_error(check_splits(d), "You must first construct splits with `tidyhte::make_splits`.") 37 | }) 38 | 39 | test_that("check that splits can be used to create dataframes", { 40 | checkmate::expect_r6(split_data(s_df, 1), classes = "HTEFold") 41 | expect_error(split_data(d, 1), "Must construct split identifiers before splitting.") 42 | }) 43 | 44 | test_that("train split is larger than holdout", { 45 | s_df <- make_splits(d, uid, .num_splits = 4) 46 | fold <- split_data(s_df, 1) 47 | checkmate::expect_r6(fold, classes = "HTEFold") 48 | expect_true(nrow(fold$train) == (3 * nrow(fold$holdout))) 49 | expect_true(nrow(fold$holdout) == sum(fold$in_holdout)) 50 | }) 51 | 52 | s_id <- make_splits(d, uid, .num_splits = 7)$.split_id 53 | test_that("check that 7 splits are ok", { 54 | expect_length(unique(s_id), 7) 55 | checkmate::expect_integerish(s_id, len = n, any.missing = FALSE, lower = 1, upper = 10) 56 | checkmate::expect_integer(table(s_id), lower = floor(n / 7), upper = ceiling(n / 7)) 57 | }) 58 | 59 | n <- 100 60 | d <- dplyr::tibble( 61 | uid = 1:n, 62 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 63 | cov2 = 
sample(rep(1:2, c(n / 20, 19 * n / 20)), n, replace = TRUE) 64 | ) 65 | 66 | s_df <- make_splits(d, uid, cov2, .num_splits = 2) 67 | test_that("check that stratification on one variable works", { 68 | result <- s_df %>% 69 | dplyr::group_by(cov2) %>% 70 | dplyr::summarize(all_splits = all(1:2 %in% .split_id)) %>% 71 | dplyr::summarize(all_splits = all(all_splits)) %>% 72 | unlist() 73 | expect_true(result) 74 | }) 75 | 76 | s_df <- make_splits(d, uid, cov1, cov2, .num_splits = 2) 77 | test_that("check that stratification on multiple variables works", { 78 | result <- s_df %>% 79 | dplyr::group_by(cov2) %>% 80 | dplyr::summarize(all_splits = all(1:2 %in% .split_id)) %>% 81 | dplyr::summarize(all_splits = all(all_splits)) %>% 82 | unlist() 83 | expect_true(result) 84 | 85 | result <- s_df %>% 86 | dplyr::group_by(cov1) %>% 87 | dplyr::summarize(all_splits = all(1:2 %in% .split_id)) %>% 88 | dplyr::summarize(all_splits = all(all_splits)) %>% 89 | unlist() 90 | expect_true(result) 91 | }) 92 | 93 | 94 | d <- dplyr::tibble( 95 | uid = paste0("uid is ", as.character(1:n)), 96 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 97 | cov2 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE) 98 | ) 99 | 100 | test_that("character uids", { 101 | checkmate::expect_data_frame( 102 | make_splits(d, uid, .num_splits = 2), 103 | any.missing = FALSE, 104 | nrows = 100, 105 | ncols = 4, 106 | col.names = "named" 107 | ) 108 | expect_true(".split_id" %in% names(make_splits(d, uid, .num_splits = 2))) 109 | }) 110 | 111 | 112 | d <- dplyr::tibble( 113 | uid = paste0("uid is ", as.character(1:n)), 114 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 115 | cov2 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 116 | cov3 = rnorm(n), 117 | cov4 = runif(n) 118 | ) 119 | 120 | test_that("lots of covariates informing the strata", { 121 | newd <- make_splits(d, uid, cov1, cov2, cov3, cov4, .num_splits = 2) 122 | checkmate::expect_data_frame( 123 | newd, 124 | 
any.missing = FALSE, 125 | nrows = 100, 126 | ncols = 6, 127 | col.names = "named" 128 | ) 129 | expect_true(".split_id" %in% names(newd)) 130 | }) 131 | 132 | d <- dplyr::tibble( 133 | uid = rep(paste0("uid is ", as.character(1:(n / 2))), 2), 134 | cov1 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 135 | cov2 = sample(rep(1:2, c(n / 2, n / 2)), n, replace = TRUE), 136 | cov3 = rnorm(n), 137 | cov4 = runif(n) 138 | ) 139 | 140 | test_that("clustered data", { 141 | newd <- make_splits(d, uid, cov1, cov2, cov3, cov4, .num_splits = 2) 142 | checkmate::expect_data_frame( 143 | newd, 144 | any.missing = FALSE, 145 | nrows = 100, 146 | ncols = 6, 147 | col.names = "named" 148 | ) 149 | expect_true(".split_id" %in% names(newd)) 150 | }) 151 | 152 | test_that("splitting works when quickblock isn't installed", { 153 | expect_message( 154 | mockr::with_mock( 155 | package_present = function(x) FALSE, 156 | { 157 | make_splits(d, uid, cov1, .num_splits = 4) 158 | } 159 | ), 160 | "`quickblock` is not installed, so falling back to un-stratified CV." 
161 | ) 162 | }) 163 | -------------------------------------------------------------------------------- /tests/testthat/test-api-missing.R: -------------------------------------------------------------------------------- 1 | set.seed(20051920) # 20051920 is derived from 'test' 2 | 3 | n <- 250 4 | data <- dplyr::tibble( 5 | uid = 1:n 6 | ) %>% 7 | dplyr::mutate( 8 | a = rbinom(n, 1, 0.5), 9 | ps = rep(0.5, n), 10 | x1 = rnorm(n), 11 | x2 = factor(sample(1:4, n, prob = c(1 / 5, 1 / 5, 1 / 5, 2 / 5), replace = TRUE)), 12 | x3 = factor(sample(1:3, n, prob = c(1 / 5, 1 / 5, 3 / 5), replace = TRUE)), 13 | x4 = (x1 + rnorm(n)) / 2, 14 | x5 = rnorm(n), 15 | y = a + x1 - 2.5 * a * (x1 - mean(x1)) + as.double(x2) + rnorm(n), 16 | w = rexp(n, plogis(x1 - mean(x1))) 17 | ) 18 | 19 | n_missing_cells <- 25 20 | 21 | for (i in seq_len(n_missing_cells)) { 22 | row <- sample(n, 1) 23 | col <- sample(names(data)[c(2, 4:9)], 1) 24 | data[row, col] <- NA 25 | } 26 | 27 | userid <- rlang::expr(uid) 28 | weight_variable <- rlang::expr(w) 29 | 30 | propensity_score_variable_name <- "ps" 31 | 32 | continuous_covariates <- c("x1") 33 | 34 | discrete_covariates <- c("x2", "x3") 35 | 36 | continuous_moderators <- rlang::exprs(x1) 37 | discrete_moderators <- rlang::exprs(x2, x3) 38 | 39 | model_covariate_names <- c(continuous_covariates, discrete_covariates) 40 | model_covariates <- rlang::syms(model_covariate_names) 41 | 42 | outcome_variable <- rlang::expr(y) 43 | treatment_variable <- rlang::expr(a) 44 | 45 | trt.cfg <- SLEnsemble_cfg$new( 46 | learner_cfgs = list( 47 | SLLearner_cfg$new( 48 | "SL.glm" 49 | ) 50 | ), 51 | family = stats::quasibinomial() 52 | ) 53 | 54 | regression.cfg <- SLEnsemble_cfg$new( 55 | learner_cfgs = list( 56 | SLLearner_cfg$new( 57 | "SL.glm" 58 | ) 59 | ) 60 | ) 61 | 62 | qoi.list <- list() 63 | for (cov in continuous_moderators) { 64 | qoi.list[[rlang::as_string(cov)]] <- KernelSmooth_cfg$new(neval = 100) 65 | } 66 | for (cov in discrete_moderators) { 67 | 
qoi.list[[rlang::as_string(cov)]] <- Stratified_cfg$new(cov) 68 | } 69 | 70 | qoi.cfg <- QoI_cfg$new( 71 | mcate = MCATE_cfg$new(cfgs = qoi.list), 72 | vimp = VIMP_cfg$new(), 73 | diag = Diagnostics_cfg$new( 74 | outcome = c("SL_risk", "SL_coefs", "MSE", "RROC"), 75 | ps = c("SL_risk", "SL_coefs", "AUC"), 76 | params = list(num_bins = 100) 77 | ) 78 | ) 79 | 80 | cfg <- HTE_cfg$new( 81 | treatment = trt.cfg, 82 | outcome = regression.cfg, 83 | effect = regression.cfg, 84 | qoi = qoi.cfg 85 | ) 86 | 87 | E <- new.env(parent = emptyenv()) 88 | 89 | test_that("add config", { 90 | E$data1 <- attach_config(data, cfg) 91 | checkmate::expect_data_frame(E$data1) 92 | expect_true("HTE_cfg" %in% names(attributes(E$data1))) 93 | }) 94 | 95 | test_that("Split data", { 96 | suppressMessages(E$data2 <- make_splits(E$data1, {{ userid }}, .num_splits = 4)) 97 | checkmate::expect_data_frame(E$data2) 98 | }) 99 | 100 | test_that("Estimate Plugin Models", { 101 | suppressMessages(E$data3 <- produce_plugin_estimates( 102 | E$data2, 103 | {{ outcome_variable }}, 104 | {{ treatment_variable }}, 105 | !!!model_covariates, 106 | .weights = {{ weight_variable }} 107 | )) 108 | checkmate::expect_data_frame(E$data3) 109 | }) 110 | 111 | test_that("Construct Pseudo-outcomes", { 112 | suppressMessages( 113 | E$data4 <- construct_pseudo_outcomes( 114 | E$data3, {{ outcome_variable }}, {{ treatment_variable }} 115 | ) 116 | ) 117 | checkmate::expect_data_frame(E$data4) 118 | }) 119 | 120 | test_that("Estimate QoIs (continuous)", { 121 | expect_error( 122 | suppressMessages(estimate_QoI(E$data4, !!!continuous_moderators)), 123 | "`nprobust` does not support the use of weights."
124 | ) 125 | }) 126 | 127 | test_that("Estimate QoIs (discrete)", { 128 | suppressMessages(E$results <- estimate_QoI(E$data4, !!!discrete_moderators)) 129 | checkmate::expect_data_frame(E$results) 130 | }) 131 | 132 | n_rows <- ( 133 | 1 + # SATE estimate 134 | 1 + # PATE estimate 135 | 2 + # MSE for y(0) & y(1) 136 | 1 + # AUC for pscore 137 | 2 * 1 + 1 + # one row per model in the ensemble for each PO + ps for SL risk 138 | 2 * 1 + 1 + # one row per model in the ensemble for each PO + ps for SL coefficient 139 | 2 + # one row per moderator for variable importance 140 | # 1 * 100 + # 100 rows per continuous moderator for local regression for MCATE 141 | (4 + 3) + # 1 rows per discrete moderator level for MCATE 142 | 2 * 100 # 100 rows for RROC for each PO 143 | ) 144 | 145 | test_that("PATE > SATE", { 146 | # PATE has larger treatment effects (because weight is correlated with moderator) 147 | pate <- E$results %>% dplyr::filter(grepl("PATE", estimand)) %>% select(estimate) %>% unlist() 148 | sate <- E$results %>% dplyr::filter(grepl("SATE", estimand)) %>% select(estimate) %>% unlist() 149 | expect_gt(pate, sate) 150 | 151 | # Weighting makes the standard error larger 152 | pate <- E$results %>% dplyr::filter(grepl("PATE", estimand)) %>% select(std_error) %>% unlist() 153 | sate <- E$results %>% dplyr::filter(grepl("SATE", estimand)) %>% select(std_error) %>% unlist() 154 | expect_gt(pate, sate) 155 | }) 156 | 157 | test_that("Check results data", { 158 | checkmate::check_character(E$results$estimand, any.missing = FALSE) 159 | checkmate::check_double(E$results$estimate, any.missing = FALSE) 160 | checkmate::check_double(E$results$std_error, any.missing = FALSE) 161 | 162 | checkmate::expect_tibble( 163 | E$results, 164 | all.missing = FALSE, 165 | nrows = n_rows, 166 | ncols = 6, 167 | types = c( 168 | estimand = "character", 169 | term = "character", 170 | value = "numeric", 171 | level = "character", 172 | estimate = "double", 173 | std_error = "double" 174 | 
) 175 | ) 176 | }) 177 | -------------------------------------------------------------------------------- /man/HTE_cfg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/meta_cfg.R 3 | \name{HTE_cfg} 4 | \alias{HTE_cfg} 5 | \title{Configuration of a Full HTE Analysis} 6 | \description{ 7 | \code{HTE_cfg} is a configuration class that pulls everything together, indicating 8 | the full configuration for a given HTE analysis. This includes how to estimate 9 | models and what Quantities of Interest to calculate based off those underlying models. 10 | } 11 | \examples{ 12 | 13 | ## ------------------------------------------------ 14 | ## Method `HTE_cfg$new` 15 | ## ------------------------------------------------ 16 | 17 | mcate_cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 18 | pcate_cfg <- PCATE_cfg$new( 19 | cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100)), 20 | model_covariates = c("x1", "x2", "x3"), 21 | num_mc_samples = list(x1 = 100) 22 | ) 23 | vimp_cfg <- VIMP_cfg$new() 24 | diag_cfg <- Diagnostics_cfg$new( 25 | outcome = c("SL_risk", "SL_coefs", "MSE"), 26 | ps = c("SL_risk", "SL_coefs", "AUC") 27 | ) 28 | qoi_cfg <- QoI_cfg$new( 29 | mcate = mcate_cfg, 30 | pcate = pcate_cfg, 31 | vimp = vimp_cfg, 32 | diag = diag_cfg 33 | ) 34 | ps_cfg <- SLEnsemble_cfg$new( 35 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 36 | ) 37 | y_cfg <- SLEnsemble_cfg$new( 38 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 39 | ) 40 | fx_cfg <- SLEnsemble_cfg$new( 41 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 42 | ) 43 | HTE_cfg$new(outcome = y_cfg, treatment = ps_cfg, effect = fx_cfg, qoi = qoi_cfg) 44 | } 45 | \section{Public fields}{ 46 | \if{html}{\out{
}} 47 | \describe{ 48 | \item{\code{outcome}}{\code{Model_cfg} object indicating how outcome models should be estimated.} 49 | 50 | \item{\code{treatment}}{\code{Model_cfg} object indicating how the propensity score 51 | model should be estimated.} 52 | 53 | \item{\code{effect}}{\code{Model_cfg} object indicating how the joint effect model 54 | should be estimated.} 55 | 56 | \item{\code{qoi}}{\code{QoI_cfg} object indicating what the Quantities of Interest 57 | are and providing all 58 | necessary detail on how they should be estimated.} 59 | 60 | \item{\code{verbose}}{Logical indicating whether to print debugging information.} 61 | } 62 | \if{html}{\out{
}} 63 | } 64 | \section{Methods}{ 65 | \subsection{Public methods}{ 66 | \itemize{ 67 | \item \href{#method-HTE_cfg-new}{\code{HTE_cfg$new()}} 68 | \item \href{#method-HTE_cfg-clone}{\code{HTE_cfg$clone()}} 69 | } 70 | } 71 | \if{html}{\out{
}} 72 | \if{html}{\out{}} 73 | \if{latex}{\out{\hypertarget{method-HTE_cfg-new}{}}} 74 | \subsection{Method \code{new()}}{ 75 | Create a new \code{HTE_cfg} object with all necessary information about how 76 | to carry out an HTE analysis. 77 | \subsection{Usage}{ 78 | \if{html}{\out{
}}\preformatted{HTE_cfg$new( 79 | outcome = NULL, 80 | treatment = NULL, 81 | effect = NULL, 82 | qoi = NULL, 83 | verbose = FALSE 84 | )}\if{html}{\out{
}} 85 | } 86 | 87 | \subsection{Arguments}{ 88 | \if{html}{\out{
}} 89 | \describe{ 90 | \item{\code{outcome}}{\code{Model_cfg} object indicating how outcome models should 91 | be estimated.} 92 | 93 | \item{\code{treatment}}{\code{Model_cfg} object indicating how the propensity score 94 | model should be estimated.} 95 | 96 | \item{\code{effect}}{\code{Model_cfg} object indicating how the joint effect model 97 | should be estimated.} 98 | 99 | \item{\code{qoi}}{\code{QoI_cfg} object indicating what the Quantities of Interest 100 | are and providing all 101 | necessary detail on how they should be estimated.} 102 | 103 | \item{\code{verbose}}{Logical indicating whether to print debugging information.} 104 | } 105 | \if{html}{\out{
}} 106 | } 107 | \subsection{Examples}{ 108 | \if{html}{\out{
}} 109 | \preformatted{mcate_cfg <- MCATE_cfg$new(cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100))) 110 | pcate_cfg <- PCATE_cfg$new( 111 | cfgs = list(x1 = KernelSmooth_cfg$new(neval = 100)), 112 | model_covariates = c("x1", "x2", "x3"), 113 | num_mc_samples = list(x1 = 100) 114 | ) 115 | vimp_cfg <- VIMP_cfg$new() 116 | diag_cfg <- Diagnostics_cfg$new( 117 | outcome = c("SL_risk", "SL_coefs", "MSE"), 118 | ps = c("SL_risk", "SL_coefs", "AUC") 119 | ) 120 | qoi_cfg <- QoI_cfg$new( 121 | mcate = mcate_cfg, 122 | pcate = pcate_cfg, 123 | vimp = vimp_cfg, 124 | diag = diag_cfg 125 | ) 126 | ps_cfg <- SLEnsemble_cfg$new( 127 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 128 | ) 129 | y_cfg <- SLEnsemble_cfg$new( 130 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 131 | ) 132 | fx_cfg <- SLEnsemble_cfg$new( 133 | learner_cfgs = list(SLLearner_cfg$new("SL.glm"), SLLearner_cfg$new("SL.gam")) 134 | ) 135 | HTE_cfg$new(outcome = y_cfg, treatment = ps_cfg, effect = fx_cfg, qoi = qoi_cfg) 136 | } 137 | \if{html}{\out{
}} 138 | 139 | } 140 | 141 | } 142 | \if{html}{\out{
}} 143 | \if{html}{\out{}} 144 | \if{latex}{\out{\hypertarget{method-HTE_cfg-clone}{}}} 145 | \subsection{Method \code{clone()}}{ 146 | The objects of this class are cloneable with this method. 147 | \subsection{Usage}{ 148 | \if{html}{\out{
}}\preformatted{HTE_cfg$clone(deep = FALSE)}\if{html}{\out{
}} 149 | } 150 | 151 | \subsection{Arguments}{ 152 | \if{html}{\out{
}} 153 | \describe{ 154 | \item{\code{deep}}{Whether to make a deep clone.} 155 | } 156 | \if{html}{\out{
}} 157 | } 158 | } 159 | } 160 | --------------------------------------------------------------------------------