├── .Rbuildignore ├── .github ├── .gitignore └── workflows │ ├── R-CMD-check.yaml │ └── pkgdown.yaml ├── .gitignore ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── adjust_weights.R ├── adjustr-package.R ├── logprob.R ├── make_spec.R ├── parsing.R └── use_weights.R ├── README.md ├── _pkgdown.yml ├── adjustr.Rproj ├── codecov.yml ├── man ├── adjust_weights.Rd ├── adjustr-package.Rd ├── as.data.frame.adjustr_spec.Rd ├── dplyr.adjustr_spec.Rd ├── extract_samp_stmts.Rd ├── figures │ └── logo.png ├── get_resampling_idxs.Rd ├── make_spec.Rd ├── pull.adjustr_weighted.Rd ├── spec_plot.Rd └── summarize.adjustr_weighted.Rd ├── pkgdown └── favicon │ ├── apple-touch-icon-120x120.png │ ├── apple-touch-icon-152x152.png │ ├── apple-touch-icon-180x180.png │ ├── apple-touch-icon-60x60.png │ ├── apple-touch-icon-76x76.png │ ├── apple-touch-icon.png │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ └── favicon.ico ├── tests ├── ex.stan ├── test_model.rda ├── testthat.R └── testthat │ ├── setup.R │ ├── test_logprob.R │ ├── test_parsing.R │ ├── test_spec.R │ ├── test_use.R │ └── test_weights.R └── vignettes ├── .gitignore ├── adjustr.Rmd ├── adjustr.bib └── eightschools_model.rda /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^R/mockup.R$ 4 | ^dp_.*\.so$ 5 | ^_pkgdown\.yml$ 6 | ^docs$ 7 | ^pkgdown$ 8 | ^codecov\.yml$ 9 | ^\.travis\.yml$ 10 | ^doc$ 11 | ^Meta$ 12 | ^\.github$ 13 | ^LICENSE\.md$ 14 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # For help debugging build failures open an issue on the RStudio community 
with the 'github-actions' tag. 2 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | name: R-CMD-check 14 | 15 | jobs: 16 | R-CMD-check: 17 | runs-on: ${{ matrix.config.os }} 18 | 19 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | config: 25 | - {os: windows-latest, r: 'release'} 26 | - {os: macOS-latest, r: 'release'} 27 | - {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"} 28 | - {os: ubuntu-20.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"} 29 | 30 | env: 31 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 32 | RSPM: ${{ matrix.config.rspm }} 33 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 34 | 35 | steps: 36 | - uses: actions/checkout@v2 37 | 38 | - uses: r-lib/actions/setup-r@v1 39 | with: 40 | r-version: ${{ matrix.config.r }} 41 | 42 | - uses: r-lib/actions/setup-pandoc@v1 43 | 44 | - name: Query dependencies 45 | run: | 46 | install.packages('remotes') 47 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 48 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 49 | shell: Rscript {0} 50 | 51 | - name: Restore R package cache 52 | if: runner.os != 'Windows' 53 | uses: actions/cache@v2 54 | with: 55 | path: ${{ env.R_LIBS_USER }} 56 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 57 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 58 | 59 | - name: Install system dependencies 60 | if: runner.os == 'Linux' 61 | run: | 62 | while read -r cmd 63 | do 64 | eval sudo $cmd 65 | done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))') 66 | 67 | - name: Install 
dependencies 68 | run: | 69 | remotes::install_deps(dependencies = TRUE) 70 | remotes::install_cran("rcmdcheck") 71 | shell: Rscript {0} 72 | 73 | - name: Check 74 | env: 75 | _R_CHECK_CRAN_INCOMING_REMOTE_: false 76 | run: | 77 | options(crayon.enabled = TRUE) 78 | rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check") 79 | shell: Rscript {0} 80 | 81 | - name: Upload check results 82 | if: failure() 83 | uses: actions/upload-artifact@main 84 | with: 85 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results 86 | path: check 87 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | 7 | name: pkgdown 8 | 9 | jobs: 10 | pkgdown: 11 | runs-on: macOS-latest 12 | env: 13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - uses: r-lib/actions/setup-r@v1 18 | 19 | - uses: r-lib/actions/setup-pandoc@v1 20 | 21 | - name: Query dependencies 22 | run: | 23 | install.packages('remotes') 24 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 25 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 26 | shell: Rscript {0} 27 | 28 | - name: Restore R package cache 29 | uses: actions/cache@v2 30 | with: 31 | path: ${{ env.R_LIBS_USER }} 32 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 33 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 34 | 35 | - name: Install dependencies 36 | run: | 37 | remotes::install_deps(dependencies = TRUE) 38 | install.packages("pkgdown", type = "binary") 39 | shell: Rscript {0} 40 | 41 | - name: Install package 42 | run: R CMD INSTALL . 
43 | 44 | - name: Deploy package 45 | run: | 46 | git config --local user.email "actions@github.com" 47 | git config --local user.name "GitHub Actions" 48 | Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)' 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rhistory 2 | .RData 3 | .Rproj.user/* 4 | .DS_Store 5 | *.so 6 | *.tmp 7 | *.bak 8 | 9 | .Rproj.user 10 | inst/doc 11 | doc 12 | Meta 13 | docs 14 | /doc/ 15 | /Meta/ 16 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: adjustr 2 | Encoding: UTF-8 3 | Type: Package 4 | Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling 5 | Version: 0.1.2 6 | Authors@R: person("Cory", "McCartan", email = "cmccartan@g.harvard.edu", 7 | role = c("aut", "cre")) 8 | Description: Functions to help assess the sensitivity of a Bayesian model 9 | (fitted using the rstan package) to the specification of its likelihood and 10 | priors. Users provide a series of alternate sampling specifications, and the 11 | package uses Pareto-smoothed importance sampling to estimate posterior 12 | quantities of interest under each specification. 
13 | License: MIT + file LICENSE 14 | Depends: R (>= 3.6.0), 15 | dplyr (>= 1.0.0) 16 | Imports: 17 | rlang, 18 | tidyselect, 19 | purrr, 20 | stringr, 21 | rstan, 22 | loo 23 | Suggests: 24 | ggplot2, 25 | extraDistr, 26 | tidyr, 27 | testthat, 28 | covr, 29 | knitr, 30 | rmarkdown 31 | URL: https://corymccartan.github.io/adjustr/ 32 | BugReports: https://github.com/CoryMcCartan/adjustr/issues 33 | LazyData: true 34 | RoxygenNote: 7.2.0 35 | VignetteBuilder: knitr 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2021 2 | COPYRIGHT HOLDER: Cory McCartan 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2021 Cory McCartan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(arrange,adjustr_spec) 4 | S3method(as.data.frame,adjustr_spec) 5 | S3method(length,adjustr_spec) 6 | S3method(print,adjustr_spec) 7 | S3method(pull,adjustr_weighted) 8 | S3method(rename,adjustr_spec) 9 | S3method(select,adjustr_spec) 10 | S3method(slice,adjustr_spec) 11 | S3method(summarise,adjustr_weighted) 12 | S3method(summarize,adjustr_weighted) 13 | export(adjust_weights) 14 | export(extract_samp_stmts) 15 | export(get_resampling_idxs) 16 | export(make_spec) 17 | export(spec_plot) 18 | import(dplyr) 19 | import(rlang) 20 | importFrom(purrr,map) 21 | importFrom(purrr,map2) 22 | importFrom(purrr,map_chr) 23 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # adjustr 0.1.3 2 | 3 | # adjustr 0.1.2 4 | 5 | * Add support for `cmdstanr` objects by passing a list containing the fit and the model object. 6 | 7 | * Fix bug in parsing code that caused an error with some `target +=` model statements. 8 | 9 | 10 | # adjustr 0.1.1 11 | 12 | * Improved documentation and additional references. 13 | 14 | * Fix bug in `extract_samp_stmts()` which prevented `brmsfit` objects from being used directly. 15 | 16 | 17 | # adjustr 0.1.0 18 | 19 | * Initial release. 20 | 21 | * Basic workflow implemented: `make_spec()`, `adjust_weights()`, and `summarize()`/`spec_plot()`. 
-------------------------------------------------------------------------------- /R/adjust_weights.R: -------------------------------------------------------------------------------- 1 | #' Compute Pareto-smoothed Importance Weights for Alternative Model 2 | #' Specifications 3 | #' 4 | #' Given a set of new sampling statements, which can be parametrized by a 5 | #' data frame or list, compute Pareto-smoothed importance weights and attach 6 | #' them to the specification object, for further calculation and plotting. 7 | #' 8 | #' This function does the bulk of the sensitivity analysis work. It operates 9 | #' by parsing the model code from the provided Stan object, extracting the 10 | #' parameters and their sampling statements. It then uses R 11 | #' metaprogramming/tidy evaluation tools to flexibly evaluate the log density 12 | #' for each draw and each sampling statement, under the original and alternative 13 | #' specifications. From these, the function computes the overall importance 14 | #' weight for each draw and performs Pareto-smoothed importance sampling. All 15 | #' of the work is performed in R, without recompiling or refitting the Stan 16 | #' model. 17 | #' 18 | #' @param spec An object of class \code{adjustr_spec}, probably produced by 19 | #' \code{\link{make_spec}}, containing the new sampling sampling statements 20 | #' to replace their counterparts in the original Stan model, and the data, 21 | #' if any, by which these sampling statements are parametrized. 22 | #' @param object A model object, either of type \code{\link[rstan]{stanfit}}, 23 | #' \code{\link[rstanarm:stanreg-objects]{stanreg}}, \code{\link[brms]{brmsfit}}, 24 | #' or a list with two elements: \code{model} containing a 25 | #' \code{\link[cmdstanr]{CmdStanModel}}, and \code{fit} containing a 26 | #' \code{\link[cmdstanr]{CmdStanMCMC}} object. 27 | #' @param data The data that was used to fit the model in \code{object}. 
28 | #' Required only if one of the new sampling specifications involves Stan data 29 | #' variables. 30 | #' @param keep_bad When \code{FALSE} (the strongly recommended default), 31 | #' alternate specifications which deviate too much from the original 32 | #' posterior, and which as a result cannot be reliably estimated using 33 | #' importance sampling (i.e., if the Pareto shape parameter is larger than 34 | #' 0.7), have their weights discarded—weights are set to \code{NA_real_}. 35 | #' @param incl_orig When \code{TRUE}, include a row for the original 36 | #' model specification, with all weights equal. Can facilitate comparison 37 | #' and plotting later. 38 | #' 39 | #' @return A tibble, produced by converting the provided \code{specs} to a 40 | #' tibble (see \code{\link{as.data.frame.adjustr_spec}}), and adding columns 41 | #' \code{.weights}, containing vectors of weights for each draw, and 42 | #' \code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values 43 | #' greater than 0.7 indicate that importance sampling is not reliable. 44 | #' If \code{incl_orig} is \code{TRUE}, a row is added for the original model 45 | #' specification. Weights can be extracted with the 46 | #' \code{\link{pull.adjustr_weighted}} method. The returned object also 47 | #' includes the model sample draws, in the \code{draws} attribute. 48 | #' 49 | #' @references 50 | #' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., & Gabry, J. (2015). 51 | #' Pareto smoothed importance sampling. \href{https://arxiv.org/abs/1507.02646}{arXiv preprint arXiv:1507.02646}. 
52 | #' 53 | #' @seealso \code{\link{make_spec}}, \code{\link{summarize.adjustr_weighted}}, \code{\link{spec_plot}} 54 | #' 55 | #' @examples \dontrun{ 56 | #' model_data = list( 57 | #' J = 8, 58 | #' y = c(28, 8, -3, 7, -1, 1, 18, 12), 59 | #' sigma = c(15, 10, 16, 11, 9, 11, 10, 18) 60 | #' ) 61 | #' 62 | #' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 63 | #' adjust_weights(spec, eightschools_m) 64 | #' adjust_weights(spec, eightschools_m, keep_bad=TRUE) 65 | #' 66 | #' spec = make_spec(y ~ student_t(df, theta, sigma), df=1:10) 67 | #' adjust_weights(spec, eightschools_m, data=model_data) 68 | #' # will throw an error because `y` and `sigma` aren't provided 69 | #' adjust_weights(spec, eightschools_m) 70 | #' } 71 | #' 72 | #' @export 73 | adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE, incl_orig=TRUE) { 74 | # CHECK ARGUMENTS 75 | if (is.null(data) & inherits(object, "brmsfit")) 76 | data = object$data 77 | object = get_fit_obj(object) 78 | stopifnot(is.adjustr_spec(spec)) 79 | 80 | parsed = parse_model(object@stanmodel@model_code) 81 | 82 | # if no model data provided, we can only resample distributions of parameters 83 | if (is.null(data)) { 84 | samp_vars = map(parsed$samp, ~ deparse(f_lhs(.))) %>% 85 | purrr::as_vector() 86 | prior_vars = parsed$vars[samp_vars] != "data" 87 | parsed$samp = parsed$samp[prior_vars] 88 | data = list() 89 | } 90 | 91 | matched_samp = match_sampling_stmts(spec$samp, parsed$samp) 92 | original_lp = calc_original_lp(object, matched_samp, parsed$vars, data) 93 | specs_lp = calc_specs_lp(object, spec$samp, parsed$vars, data, spec$params) 94 | 95 | # compute weights 96 | wgts = map(specs_lp, function(spec_lp) { 97 | lratio = spec_lp - original_lp 98 | dim(lratio) = c(dim(lratio), 1) 99 | r_eff = loo::relative_eff(as.array(exp(-lratio))) 100 | psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff)) 101 | pareto_k = loo::pareto_k_values(psis_wgt) 102 | if (all(psis_wgt$log_weights == 
psis_wgt$log_weights[1])) { 103 | warning("New specification equal to old specification.", call.=FALSE) 104 | pareto_k = -Inf 105 | } 106 | 107 | list( 108 | weights = loo::weights.importance_sampling(psis_wgt, log=FALSE), 109 | pareto_k = pareto_k 110 | ) 111 | }) 112 | 113 | adjust_obj = as_tibble(spec) 114 | class(adjust_obj) = c("adjustr_weighted", class(adjust_obj)) 115 | adjust_obj$.weights = map(wgts, ~ as.numeric(.$weights)) 116 | adjust_obj$.pareto_k = purrr::map_dbl(wgts, ~ .$pareto_k) 117 | if (!keep_bad) 118 | adjust_obj$.weights[adjust_obj$.pareto_k > 0.7] = list(NA_real_) 119 | attr(adjust_obj, "draws") = rstan::extract(object) 120 | attr(adjust_obj, "data") = data 121 | attr(adjust_obj, "iter") = object@sim$chains * (object@sim$iter - object@sim$warmup) 122 | if (incl_orig) { 123 | adjust_obj = bind_rows(adjust_obj, tibble( 124 | .weights=list(rep(1, attr(adjust_obj, "iter"))), 125 | .pareto_k = -Inf)) 126 | samp_cols = stringr::str_detect(names(adjust_obj), "\\.samp") 127 | adjust_obj[nrow(adjust_obj), samp_cols] = "" 128 | } 129 | 130 | adjust_obj 131 | } 132 | 133 | 134 | # Generic methods 135 | is.adjustr_weighted = function(x) inherits(x, "adjustr_weighted") 136 | #' Extract Weights From an \code{adjustr_weighted} Object 137 | #' 138 | #' This function modifies the default behavior of \code{dplyr::pull} to extract 139 | #' the \code{.weights} column. 140 | #' 141 | #' @param .data A table of data 142 | #' @param var A variable, as in \code{\link[dplyr]{pull}}. The default returns 143 | #' the \code{.weights} column, and if there is only one row, it returns the 144 | #' first element of that column 145 | #' @param name Ignored 146 | #' @param ... Ignored 147 | #' 148 | #' @export 149 | pull.adjustr_weighted = function(.data, var=".weights", name=NULL, ...) 
{ 150 | var = tidyselect::vars_pull(names(.data), !!enquo(var)) 151 | if (nrow(.data) == 1 && var == ".weights") { 152 | .data$.weights[[1]] 153 | } else { 154 | .data[[var]] 155 | } 156 | } 157 | 158 | #' Extract Model Sampling Statements From a Stan Model. 159 | #' 160 | #' Prints a list of sampling statements extracted from the \code{model} block of 161 | #' a Stan program, with each labelled "parameter" or "data" depending on the 162 | #' type of variable being sampled. 163 | #' 164 | #' @param object A \code{\link[rstan]{stanfit}} model object. 165 | #' 166 | #' @return Invisibly returns a list of sampling formulas. 167 | #' 168 | #' @examples \dontrun{ 169 | #' extract_samp_stmts(eightschools_m) 170 | #' #> Sampling statements for model 2c8d1d8a30137533422c438f23b83428: 171 | #' #> parameter eta ~ std_normal() 172 | #' #> data y ~ normal(theta, sigma) 173 | #' } 174 | #' @export 175 | extract_samp_stmts = function(object) { 176 | object = get_fit_obj(object) 177 | 178 | parsed = parse_model(object@stanmodel@model_code) 179 | 180 | samp_vars = map_chr(parsed$samp, ~ rlang::expr_text(f_lhs(.))) 181 | samp_var_names = stringr::str_replace(samp_vars, "\\[.+\\]", "") 182 | type = map_chr(samp_var_names, function(var) { 183 | if (stringr::str_ends(parsed$vars[var], "data")) "data" else "parameter" 184 | }) 185 | print_order = order(type, samp_vars, decreasing=c(TRUE, FALSE), method="radix") 186 | 187 | cat(paste0("Sampling statements for model ", object@model_name, ":\n")) 188 | purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed$samp[.])))) 189 | invisible(parsed$samp) 190 | } 191 | 192 | # Check that the model object is correct, and put it into a convenient format 193 | get_fit_obj = function(object) { 194 | if (inherits(object, "stanfit")) { 195 | object 196 | } else if (inherits(object, "stanreg")) { 197 | object$stanfit 198 | } else if (inherits(object, "brmsfit")) { 199 | object$fit 200 | } else if (inherits(object, "list") && 
all(c("fit", "model") %in% names(object)) 201 | && inherits(object$fit, "CmdStanMCMC") && inherits(object$model, "CmdStanModel")) { 202 | out = rstan::read_stan_csv(object$fit$output_files()) 203 | out@stanmodel@model_code = paste0(object$model$code(), collapse="\n") 204 | out 205 | } else { 206 | stop("`object` must be of class `stanfit`, `stanreg`, `brmsfit`, or ", 207 | "a list with `CmdStanModel` and `CmdStanMCMC` objects.") 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /R/adjustr-package.R: -------------------------------------------------------------------------------- 1 | #' adjustr: Stan Model Adjustments and Sensitivity Analyses using Importance 2 | #' Sampling 3 | #' 4 | #' 5 | #' Functions to help assess the sensitivity of a Bayesian model to the 6 | #' specification of its likelihood and priors, estimated using the rstan 7 | #' package. Users provide a series of alternate sampling specifications, and the 8 | #' package uses Pareto-smoothed importance sampling to estimate posterior 9 | #' quantities of interest under each specification. 10 | #' 11 | #' See the list of key functions and the example below. 12 | #' Full package documentation available at \url{https://corymccartan.github.io/adjustr/}. 
13 | #' 14 | #' @section Key Functions: 15 | #' \itemize{ 16 | #' \item \code{\link{make_spec}} 17 | #' \item \code{\link{adjust_weights}} 18 | #' \item \code{\link{summarize.adjustr_weighted}} 19 | #' \item \code{\link{spec_plot}} 20 | #' } 21 | #' 22 | #' @import rlang 23 | #' @importFrom purrr map_chr map map2 24 | #' @import dplyr 25 | #' 26 | #' @docType package 27 | #' @name adjustr-package 28 | NULL 29 | 30 | # internal; to store shared package objects 31 | pkg_env = new_environment() 32 | 33 | .onLoad = function(libname, pkgname) { # nocov start 34 | # create the Stan parser 35 | #tryCatch(get_parser(), error = function(e) {}) 36 | 37 | utils::globalVariables(c("name", "pos", "value", ".y", ".y_ol", ".y_ou", 38 | ".y_il", ".y_iu", ".y_med", "quantile", "median")) 39 | 40 | # Grab even more distributions from `extraDistr` if available 41 | distrs_onload() 42 | } # nocov end 43 | #> NULL -------------------------------------------------------------------------------- /R/logprob.R: -------------------------------------------------------------------------------- 1 | # Pass this function a density function. 2 | # It will return a curried version which first takes the points at 3 | # which to evaluate the density, and then takes the distribution parameters 4 | make_dens = function(f) { 5 | function(x) { 6 | function(...) { 7 | f(x, ..., log=TRUE) 8 | } 9 | } 10 | } 11 | 12 | # Given a sampling formula and a `vars_data` object containing (1) model draws, 13 | # (2) model input data (if applicable), and (3) prior specification data (if 14 | # applicable), compute the log probability of the prior distribution for each 15 | # MCMC draw of the parameter of interest 16 | calc_lp = function(samp, vars_data) { 17 | # plug in RHS sampling distribution 18 | distr = tryCatch(call_fn(f_rhs(samp), distr_env), 19 | error=function(e) 20 | stop("Distribution ", as.character(samp)[3], " not supported.")) 21 | # plug in LHS, then RHS values. 
eval_tidy will throw error if no matches found 22 | distr = distr(eval_tidy(f_lhs(samp), vars_data)) 23 | params = map(call_args(f_rhs(samp)), eval_tidy, vars_data) 24 | 25 | apply(exec(distr, !!!params), 1:2, sum) 26 | } 27 | 28 | # Given a list of sampling formulas, compile the necessary data from 29 | # `object` MCMC draws and `data` model data 30 | # NOTE: fills data first with MCMC data, so these values will override any 31 | # passed in data 32 | get_base_data = function(object, samps, parsed_vars, data, extra_names=NULL) { 33 | iter = object@sim$iter - object@sim$warmup 34 | chains = object@sim$chains 35 | reshape_data = function(x) { 36 | x = as.array(x) 37 | new_dim = c(iter, chains) 38 | new_x = array(rep(0, prod(new_dim)), dim=new_dim) 39 | apply(new_x, 1:2, function(y) x) %>% 40 | aperm(c(length(dim(x)) + 1:2, 1:length(dim(x)))) 41 | } 42 | 43 | map(samps, function(samp) { 44 | vars = get_stmt_vars(samp) 45 | # vars stored in MCMC draws 46 | vars_inmodel = intersect(vars, names(parsed_vars)) 47 | vars_indraws = vars_inmodel[!stringr::str_ends(parsed_vars[vars_inmodel], "data")] 48 | vars_indata = vars_inmodel[stringr::str_ends(parsed_vars[vars_inmodel], "data")] 49 | # check data vars provided 50 | found = vars_indata %in% names(data) 51 | if (!all(found)) stop(paste(vars_indata[!found], collapse=", "), " not found") 52 | # combine draws and data 53 | base_data = append( 54 | map(vars_indraws, ~ rstan::extract(object, ., permuted=FALSE)) %>% 55 | set_names(vars_indraws), 56 | map(vars_indata, ~ reshape_data(data[[.]])) %>% 57 | set_names(vars_indata), 58 | ) 59 | # check all data found 60 | found = vars %in% c(names(base_data), extra_names) 61 | if (!all(found)) stop(paste(vars[!found], collapse=", "), " not found") 62 | base_data 63 | }) 64 | } 65 | 66 | # Given a `samps` list of samp formula, compute the total log probability 67 | # for each draw of the parameters of interest 68 | calc_original_lp = function(object, samps, parsed_vars, data) { 69 | # 
figure out what data we need and calculate and sum lp 70 | base_data = get_base_data(object, samps, parsed_vars, data) 71 | purrr::reduce(map2(samps, base_data, calc_lp), `+`) 72 | } 73 | 74 | # Given a `samps` list of samp formula, compute the total log probability for 75 | # each draw of the parameters of interest, for each samp specification given 76 | # in `specs` 77 | calc_specs_lp = function(object, samps, parsed_vars, data, specs) { 78 | base_data = get_base_data(object, samps, parsed_vars, data, names(specs[[1]])) 79 | map(specs, function(spec) { 80 | purrr::reduce(map2(samps, base_data, ~ calc_lp(.x, append(.y, spec))), `+`) 81 | }) 82 | } 83 | 84 | 85 | # Mapping of Stan distribution names to R functions 86 | distrs = list( 87 | bernoulli = function(x, p, ...) dbinom(x, 1, p, ...), 88 | bernoulli_logit = function(x, p, ...) dbinom(x, 1, plogis(p), ...), 89 | binomial = dbinom, 90 | binomial_logit = function(x, n, p, ...) dbinom(x, n, plogis(p), ...), 91 | hypergeometric = dhyper, 92 | poisson = dpois, 93 | poisson_log = function(x, alpha, ...) dpois(x, exp(alpha), ...), 94 | normal = dnorm, 95 | std_normal = dnorm, 96 | student_t = function(x, df, loc, scale, ...) dt((x - loc)/scale, df, ...), 97 | cauchy = dcauchy, 98 | laplace = function(x, mu, sigma, ...) dexp(abs(x - mu), 1/sigma, ...), 99 | logistic = dlogis, 100 | lognormal = dlnorm, 101 | chi_square = dchisq, 102 | exponential = dexp, 103 | gamma = function(x, alpha, beta, ...) dgamma(x, shape=alpha, rate=beta, ...), 104 | weibull = dweibull, 105 | beta = dbeta, 106 | beta_proportion = function(x, mu, k, ...) 
dbeta(x, mu*k, (1-mu)*k, ...), 107 | uniform = dunif 108 | ) 109 | # Grab even more distributions from `extraDistr` if available (called from .onLoad()) 110 | distrs_onload = function() { 111 | if (requireNamespace("extraDistr", quietly=TRUE)) { 112 | distrs <<- append(distrs, list( 113 | beta_binomial = extraDistr::dbbinom, 114 | categorical = extraDistr::dcat, 115 | gumbel = extraDistr::dgumbel, 116 | inv_chi_square = extraDistr::dinvchisq, 117 | scaled_inv_chi_square = extraDistr::dinvchisq, 118 | inv_gamma = extraDistr::dinvgamma, 119 | frechet = function(x, lambda, sigma, ...) 120 | extraDistr::dfrechet(x, lambda, sigma=sigma, ...), 121 | rayleigh = extraDistr::drayleigh, 122 | pareto = extraDistr::dpareto 123 | )) 124 | } else { 125 | message("`extraDistr` package not found. Install to access more ", 126 | "distributions, like inverse chi-square and beta-binomial.") 127 | } 128 | # Turn mapping into an environment suitable for metaprogramming, 129 | # and turn each density into its curried form (see `make_dens` above) 130 | distr_env <<- new_environment(purrr::map(distrs, make_dens)) 131 | } -------------------------------------------------------------------------------- /R/make_spec.R: -------------------------------------------------------------------------------- 1 | #' Set Up Model Adjustment Specifications 2 | #' 3 | #' Takes a set of new sampling statements, which can be parametrized by other 4 | #' arguments, data frames, or lists, and creates an \code{adjustr_spec} object 5 | #' suitable for use in \code{\link{adjust_weights}}. 6 | #' 7 | #' @param ... Model specification. Each argument can either be a formula, 8 | #' a named vector, data frames, or lists. 9 | #' 10 | #' Formula arguments provide new sampling statements to replace their 11 | #' counterparts in the original Stan model. 
All such formulas must be of the 12 | #' form \code{variable ~ distribution(parameters)}, where \code{variable} and 13 | #' \code{parameters} are Stan data variables or parameters, or are provided by 14 | #' other arguments to this function (see below), and where \code{distribution} 15 | #' matches one of the univariate 16 | #' \href{https://mc-stan.org/docs/2_22/functions-reference/conventions-for-probability-functions.html}{Stan distributions}. 17 | #' Arithmetic expressions of parameters are also allowed, but care must be 18 | #' taken with multivariate parameter arguments. Since specifications are 19 | #' passed as formulas, R's arithmetic operators are used, not Stan's. As a 20 | #' result, matrix and elementwise multiplication in Stan sampling statements may 21 | #' not be interpreted correctly. Moving these computations out of sampling 22 | #' statements and into local variables will ensure correct results. 23 | #' 24 | #' For named vector arguments, each entry of the vector will be substituted 25 | #' into the corresponding parameter in the sampling statements. For data 26 | #' frames, each entry in each column will be substituted into the corresponding 27 | #' parameter in the sampling statements. 28 | #' 29 | #' List arguments are coerced to data frames. They can either be lists of named 30 | #' vectors, or lists of lists of single-element named vectors. 31 | #' 32 | #' The lengths of all parameter arguments must be consistent. Named vectors 33 | #' can have length 1 or must have length equal to the number of rows in all 34 | #' data frame arguments and the length of list arguments. 35 | #' 36 | #' @return An object of class \code{adjustr_spec}, which is essentially a list 37 | #' with two elements: \code{samp}, which is a list of sampling formulas, and 38 | #' \code{params}, which is a list of lists of parameters. 
Core 39 | #' \link[=filter.adjustr_spec]{dplyr verbs} which don't involve grouping 40 | #' (\code{\link[dplyr]{filter}}, \code{\link[dplyr]{arrange}}, 41 | #' \code{\link[dplyr]{mutate}}, \code{\link[dplyr]{select}}, 42 | #' \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are 43 | #' supported and operate on the underlying table of specification parameters. 44 | #' 45 | #' @seealso \code{\link{adjust_weights}}, \code{\link{summarize.adjustr_weighted}}, \code{\link{spec_plot}} 46 | #' 47 | #' @examples 48 | #' make_spec(eta ~ cauchy(0, 1)) 49 | #' 50 | #' make_spec(eta ~ student_t(df, 0, 1), df=1:10) 51 | #' 52 | #' params = tidyr::crossing(df=1:10, infl=c(1, 1.5, 2)) 53 | #' make_spec(eta ~ student_t(df, 0, 1), 54 | #' y ~ normal(theta, infl*sigma), 55 | #' params) 56 | #' 57 | #' @export 58 | make_spec = function(...) { 59 | args = dots_list(..., .check_assign=TRUE) 60 | 61 | spec_samp = purrr::keep(args, is_formula) 62 | if (length(spec_samp) == 0) warning("No sampling statements provided.") 63 | spec_params = purrr::imap(args, function(value, name) { 64 | if (!is.numeric(name) && name != "") { # named arguments are preserved as is 65 | list2(!!name := value) 66 | } else if (is.data.frame(value)) { 67 | as.list(value) 68 | } else if (is_list(value)) { 69 | if (is_list(value[[1]])) { 70 | tryCatch(as.list(do.call(bind_rows, value)), error = function(e) { 71 | stop("List-of-list arguments must be coercible to data frames.") 72 | }) 73 | } else if (is_vector(value[[1]])) { 74 | tryCatch(do.call(bind_cols, value), error = function(e) { 75 | stop("List-of-vector arguments must be coercible to data frames.") 76 | }) 77 | value 78 | } else { 79 | stop("List arguments must be lists of lists or lists of ", 80 | "vectors, and coercible to data frames.") 81 | } 82 | } else if (is_formula(value)) { 83 | NULL 84 | } else { 85 | stop("Arguments must be formulas, named vectors, data frames, ", 86 | "or lists. 
Use ?make_spec to see documentations.") 87 | } 88 | }) %>% 89 | purrr::compact() %>% # remove NULLS 90 | purrr::flatten() %>% 91 | as_tibble 92 | 93 | if (any(is.na(spec_params))) 94 | stop("NAs found. Check input parameters and format.") 95 | 96 | spec_obj = structure(list( 97 | samp = spec_samp, 98 | params = if (nrow(spec_params) > 0) 99 | purrr::transpose(as.list(spec_params)) 100 | else 101 | list(list()) 102 | ), class="adjustr_spec") 103 | spec_obj 104 | } 105 | 106 | # GENERIC FUNCTIONS for `adjustr_spec` 107 | is.adjustr_spec = function(x) inherits(x, "adjustr_spec") 108 | #' @export 109 | print.adjustr_spec = function(x, ...) { 110 | cat("Sampling specifications:\n") 111 | purrr::walk(x$samp, print) 112 | if (length(x$params[[1]]) > 0) { 113 | cat("\nSpecification parameters:\n") 114 | df = as.data.frame(do.call(rbind, x$params)) 115 | print(df, row.names=FALSE, max=15*ncol(df)) 116 | } 117 | } 118 | #' @export 119 | length.adjustr_spec = function(x) length(x$params) 120 | 121 | #' Convert an \code{adjustr_spec} Object Into a Data Frame 122 | #' 123 | #' Returns the data frame of specification parameters, with added columns of 124 | #' the form \code{.samp_1}, \code{.samp_2}, ... for each sampling statement 125 | #' (or just \code{.samp} if there is only one sampling statement). 126 | #' 127 | #' @param x the \code{adjustr_spec} object 128 | #' @param ... additional arguments to underlying method 129 | #' 130 | #' @export 131 | as.data.frame.adjustr_spec = function(x, ...) 
{ 132 | if (length(x$params[[1]]) == 0) { 133 | params_df = as.data.frame(matrix(nrow=1, ncol=0)) 134 | } else { 135 | params_df = do.call(bind_rows, x$params) %>% 136 | as_tibble %>% 137 | as.data.frame 138 | } 139 | 140 | n_samp = length(x$samp) 141 | if (n_samp == 1) { 142 | params_df$.samp = format(x$samp[[1]]) 143 | } else { 144 | for (i in 1:n_samp) { 145 | colname = paste0(".samp_", i) 146 | params_df[colname] = format(x$samp[[i]]) 147 | } 148 | } 149 | 150 | params_df 151 | } 152 | 153 | 154 | #' \code{dplyr} Methods for \code{adjustr_spec} Objects 155 | #' 156 | #' Core \code{\link[dplyr]{dplyr}} verbs which don't involve grouping 157 | #' (\code{\link[dplyr]{filter}}, \code{\link[dplyr]{arrange}}, 158 | #' \code{\link[dplyr]{mutate}}, \code{\link[dplyr]{select}}, 159 | #' \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are 160 | #' implemented and operate on the underlying table of specification parameters. 161 | #' 162 | #' @param .data the \code{adjustr_spec} object 163 | #' @param ... additional arguments to underlying method 164 | #' @param .preserve as in \code{filter} and \code{slice} 165 | #' @name dplyr.adjustr_spec 166 | #' 167 | #' @examples \dontrun{ 168 | #' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 169 | #' 170 | #' arrange(spec, desc(df)) 171 | #' slice(spec, 4:7) 172 | #' filter(spec, df == 2) 173 | #' } 174 | NULL 175 | # dplyr generics 176 | dplyr_handler = function(dplyr_func, x, ...) { 177 | if (length(x$params[[1]]) == 0) return(x) 178 | x$params = do.call(bind_rows, x$params) %>% 179 | as_tibble %>% 180 | dplyr_func(...) 
%>% 181 | as.list %>% 182 | purrr::transpose() 183 | x 184 | } 185 | 186 | # no @export because R CMD CHECK didn't like it 187 | #' @rdname dplyr.adjustr_spec 188 | filter.adjustr_spec = function(.data, ..., .preserve=FALSE) { 189 | dplyr_handler(dplyr::filter, .data, ..., .preserve=.preserve) 190 | } 191 | #' @rdname dplyr.adjustr_spec 192 | #' @export 193 | arrange.adjustr_spec = function(.data, ...) { 194 | dplyr_handler(dplyr::arrange, .data, ...) 195 | } 196 | #' @rdname dplyr.adjustr_spec 197 | #' @export 198 | rename.adjustr_spec = function(.data, ...) { 199 | dplyr_handler(dplyr::rename, .data, ...) 200 | } 201 | #' @rdname dplyr.adjustr_spec 202 | #' @export 203 | select.adjustr_spec = function(.data, ...) { 204 | dplyr_handler(dplyr::select, .data, ...) 205 | } 206 | #' @rdname dplyr.adjustr_spec 207 | #' @export 208 | slice.adjustr_spec = function(.data, ..., .preserve=FALSE) { 209 | dplyr_handler(dplyr::slice, .data, ..., .preserve=.preserve) 210 | } 211 | 212 | -------------------------------------------------------------------------------- /R/parsing.R: -------------------------------------------------------------------------------- 1 | # regexes 2 | identifier = "[a-zA-Z][a-zA-Z0-9_]*" 3 | re_stmt = paste0("(int|real|(?:unit_|row_)?vector|(?:positive_)?ordered|simplex", 4 | "|(?:cov_|corr_)?matrix|cholesky_factor(?:_corr|_cov)?)(?:<.+>)?(?:\\[.+\\])?", 5 | " (", identifier, ")(?:\\[.+\\])? 
?=?") 6 | #re_block = paste0(block_names, " ?\\{ ?(.+) ?\\} ?", block_names, "") 7 | re_block = "((?:transformed )?data|(?:transformed )?parameters|model|generated quantities)" 8 | re_samp = paste0("(", identifier, " ?~[^~{}]+)") 9 | re_samp2 = paste0("target ?\\+= ?(", identifier, ")_lp[md]f\\((", 10 | identifier, "(?:\\[[a-zA-Z0-9_]*\\])?)", "(?:| ?[|] ?(.+))\\)") 11 | 12 | # Extract variable name from variable declaration, or return NA if no declaration 13 | get_variables = function(statement) { 14 | matches = stringr::str_match(statement, re_stmt)[,3] 15 | matches[!is.na(matches)] 16 | } 17 | 18 | get_sampling = function(statement) { 19 | samps = stringr::str_match(statement, re_samp)[,2] 20 | samps2 = stringr::str_match(statement, re_samp2)#[,,3] 21 | samps2_rearr = paste0(samps2[,3], " ~ ", samps2[,2], "(", coalesce(samps2[,4], ""), ")") 22 | stmts = c(samps[!is.na(samps)], samps2_rearr[!is.na(samps2[,1])]) 23 | make_form = purrr::possibly(function(stmt) { 24 | stats::as.formula(stmt, env=empty_env()) 25 | }, NULL) 26 | purrr::compact(map(stmts, make_form)) 27 | } 28 | 29 | # Parse Stan `model_code` into a list with two elements: `vars` named 30 | # vector, with the names matching the model's variable names and the values 31 | # representing the program blocks they are defined in; `samp` is a list of 32 | # sampling statements (as formulas) 33 | parse_model = function(model_code) { 34 | clean_code = stringr::str_replace_all(model_code, "//.*", "") %>% 35 | stringr::str_replace_all("/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/", "") %>% 36 | stringr::str_replace_all("\\n", " ") %>% 37 | stringr::str_replace_all("\\s\\s+", " ") 38 | 39 | block_names = stringr::str_extract_all(clean_code, re_block)[[1]] 40 | if (length(block_names)==0) return(list(vars=character(0), samps=list())) 41 | 42 | block_locs = rbind(stringr::str_locate_all(clean_code, re_block)[[1]], 43 | c(nchar(clean_code), NA)) 44 | blocks = map(1:length(block_names), function(i) { 45 | block = 
stringr::str_sub(clean_code, block_locs[i,2]+1, block_locs[i+1,1]) 46 | start = stringr::str_locate_all(block, stringr::fixed("{"))[[1]][1,1] + 1 47 | end = utils::tail(stringr::str_locate_all(block, stringr::fixed("}"))[[1]][,1], 1) - 1 48 | stringr::str_trim(stringr::str_sub(block, start+1, end-1)) 49 | }) 50 | names(blocks) = block_names 51 | 52 | statements = map(blocks, ~ stringr::str_split(., "; ?", simplify=TRUE)[1,]) 53 | 54 | vars = map(statements, get_variables) 55 | vars = purrr::flatten_chr(purrr::imap(vars, function(name, block) { 56 | block = rep(block, length(name)) 57 | names(block) = name 58 | block 59 | })) 60 | 61 | 62 | samps = map(statements, get_sampling) 63 | names(samps) = NULL 64 | samps = flatten(samps) 65 | 66 | parameters = names(vars)[vars == "parameters"] 67 | sampled_pars = map(samps, ~ deparse(f_lhs(.))) %>% 68 | purrr::as_vector() 69 | uniform_pars = setdiff(parameters, sampled_pars) 70 | if (length(uniform_pars) > 0) { 71 | uniform_samp = paste0(uniform_pars, " ~ uniform(-1e100, 1e100)") 72 | uniform_samp = map(uniform_samp, ~ stats::as.formula(., env=empty_env())) 73 | } else { 74 | uniform_samp = NULL 75 | } 76 | 77 | list(vars=vars, samp=c(samps, uniform_samp)) 78 | } 79 | 80 | 81 | # Take a list of provided sampling formulas and return a matching list of 82 | # sampling statements from a reference list 83 | match_sampling_stmts = function(new_samp, ref_samp) { 84 | ref_vars = map(ref_samp, ~ deparse(f_lhs(.))) %>% 85 | purrr::as_vector() 86 | new_vars = map(new_samp, ~ deparse(f_lhs(.))) %>% 87 | purrr::as_vector() 88 | indices = match(new_vars, ref_vars) 89 | # check that every prior was matched 90 | if (any(is.na(indices))) { 91 | stop("No matching sampling statement found for ", 92 | new_samp[which.max(is.na(indices))], 93 | "\n Check sampling statements and ensure that model data ", 94 | "has been provided.") 95 | } 96 | ref_samp[indices] 97 | } 98 | 99 | # Extract a list of variables from a sampling statement 100 | # R 
versions of mathematical operators must be used 101 | get_stmt_vars = function(stmt) { 102 | get_ast = function(x) purrr::map_if(as.list(x), is_call, get_ast) 103 | if (!is_call(f_rhs(stmt))) 104 | stop("Sampling statment ", format(stmt), 105 | " does not contain a distribution on the right-hand side.") 106 | # pull out variables from RHS 107 | rhs_vars = call_args(f_rhs(stmt)) %>% 108 | get_ast %>% 109 | unlist %>% 110 | purrr::discard(is.numeric) %>% 111 | as.character %>% 112 | purrr::discard(~ . %in% c("`+`", "`-`", "`*`", "`/`", "`^`", "`%*%`", "`%%`")) 113 | c(deparse(f_lhs(stmt)), rhs_vars) 114 | } 115 | 116 | -------------------------------------------------------------------------------- /R/use_weights.R: -------------------------------------------------------------------------------- 1 | #' Get Importance Resampling Indices From Weights 2 | #' 3 | #' Takes a vector of weights, or data frame or list containing sets of weights, 4 | #' and resamples indices for use in later computation. 5 | #' 6 | #' @param x A vector of weights, a list of weight vectors, or a data frame of 7 | #' type \code{adjustr_weighted} containing a \code{.weights} list-column 8 | #' of weights. 9 | #' @param frac A real number giving the fraction of draws to resample; the 10 | #' default, 1, resamples all draws. Smaller values should be used when 11 | #' \code{replace=FALSE}. 12 | #' @param replace Whether sampling should be with replacement. When weights 13 | #' are extreme it may make sense to use \code{replace=FALSE}, but accuracy 14 | #' is not guaranteed in these cases. 15 | #' 16 | #' @return A vector, list, or data frame, depending of the type of \code{x}, 17 | #' containing the sampled indices. If any weights are \code{NA}, the indices 18 | #' will also be \code{NA}. 
 19 | #' 20 | #' @examples \dontrun{ 21 | #' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 22 | #' adjusted = adjust_weights(spec, eightschools_m) 23 | #' 24 | #' get_resampling_idxs(adjusted) 25 | #' get_resampling_idxs(adjusted, frac=0.1, replace=FALSE) 26 | #' } 27 | #' 28 | #' @export 29 | get_resampling_idxs = function(x, frac=1, replace=TRUE) { 30 | if (frac < 0) stop("`frac` parameter must be nonnegative") 31 | get_idxs = function(w) { 32 | if (all(is.na(w))) return(NA_integer_) 33 | sample.int(length(w), size=round(frac*length(w)), replace=replace, prob=w) 34 | } 35 | 36 | if (inherits(x, "list")) { 37 | map(x, get_idxs) 38 | } else if (inherits(x, "adjustr_weighted")) { 39 | x$.idxs = map(x$.weights, get_idxs) 40 | x 41 | } else { 42 | get_idxs(x) 43 | } 44 | } 45 | 46 | #' Summarize Posterior Distributions Under Alternative Model Specifications 47 | #' 48 | #' Uses weights computed in \code{\link{adjust_weights}} to compute posterior 49 | #' summary statistics. These statistics can be compared against their reference 50 | #' values to quantify the sensitivity of the model to aspects of its 51 | #' specification. 52 | #' 53 | #' @param .data An \code{adjustr_weighted} object. 54 | #' @param ... Name-value pairs of expressions. The name of each argument will be 55 | #' the name of a new variable, and the value will be computed for the 56 | #' posterior distribution of each alternative specification. For example, 57 | #' a value of \code{mean(theta)} will compute the posterior mean of 58 | #' \code{theta} for each alternative specification. 59 | #' 60 | #' Also supported is the custom function \code{wasserstein}, which computes 61 | #' the Wasserstein-p distance between the posterior distribution of the 62 | #' provided expression under the new model and under the original model, with 63 | #' \code{p=1} the default. Lower the \code{spacing} parameter from the 64 | #' default of 0.005 to compute a finer (but slower) approximation. 
65 | #' 66 | #' The arguments in \code{...} are automatically quoted and evaluated in the 67 | #' context of \code{.data}. They support unquoting and splicing. 68 | #' @param .resampling Whether to compute summary statistics by first resampling 69 | #' the data according to the weights. Defaults to \code{FALSE}, but will be 70 | #' used for any summary statistic that is not \code{mean}, \code{var} or 71 | #' \code{sd}. 72 | #' @param .model_data Stan model data, if not provided in the earlier call to 73 | #' \code{\link{adjust_weights}}. 74 | #' 75 | #' @return An \code{adjustr_weighted} object, with the new columns specified in 76 | #' \code{...} added. 77 | #' 78 | #' @seealso \code{\link{adjust_weights}}, \code{\link{spec_plot}} 79 | #' 80 | #' @examples \dontrun{ 81 | #' model_data = list( 82 | #' J = 8, 83 | #' y = c(28, 8, -3, 7, -1, 1, 18, 12), 84 | #' sigma = c(15, 10, 16, 11, 9, 11, 10, 18) 85 | #' ) 86 | #' 87 | #' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 88 | #' adjusted = adjust_weights(spec, eightschools_m) 89 | #' 90 | #' summarize(adjusted, mean(mu), var(mu)) 91 | #' summarize(adjusted, wasserstein(mu, p=2)) 92 | #' summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data) 93 | #' summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95))) 94 | #' } 95 | #' 96 | #' @rdname summarize.adjustr_weighted 97 | #' @export 98 | summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data=NULL) { 99 | stopifnot(is.adjustr_weighted(.data)) # just in case called manually 100 | args = enexprs(...) 
101 | 102 | broadcast = function(x) { 103 | dims = c(dim(as.array(x)), iter) 104 | x = array(rep(x, iter), dim=dims) 105 | aperm(x, c(length(dims), 2:length(dims) - 1)) 106 | } 107 | iter = attr(.data, "iter") 108 | if (!is_null(.model_data)) attr(.data, "data") = .model_data 109 | data = append(attr(.data, "draws"), map(attr(.data, "data"), broadcast)) 110 | 111 | n_args = length(args) 112 | for (i in seq_along(args)) { 113 | name = names(args)[i] 114 | if (name == "") name = expr_name(args[[i]]) 115 | 116 | call = args[[i]] 117 | if (!is_call(call)) { 118 | stop("Expressions must summarize posterior draws; `", 119 | expr_text(call), "` has a different value for each draw.\n", 120 | " Try `mean(", expr_text(call), ")` or `sd(", expr_text(call), ")`.") 121 | } 122 | if (!.resampling && exists(call_name(call), funs_env)) { 123 | fun = funs_env[[call_name(call)]] 124 | } else { 125 | fun = function(x, ...) apply(x, 2, call_fn(call), ...) 126 | .resampling = T 127 | } 128 | 129 | expr = expr_deparse(call_args(call)[[1]]) 130 | expr = stringr::str_replace_all(expr, "\\[(\\d)", "[,\\1") 131 | expr = stringr::str_replace_all(expr, "(? 
q)[1] 176 | if (idx == 1) return(x[1]) 177 | stats::approx(y[idx-0:1], x[idx-0:1], q)$y 178 | }) 179 | } 180 | 181 | weighted.wasserstein = function(samp, wgt, p=1, spacing=0.005) { 182 | f = weighted.ecdf(samp, wgt) 183 | q = seq(0, 1, spacing) 184 | W = mean(abs(stats::quantile(samp, q, names=FALSE, type=4) - quantile.weighted.ecdf(f, q))^p) 185 | if (W < .Machine$double.eps) 0 else W^(1/p) 186 | } 187 | 188 | # Weighted summary functions that work on arrays 189 | wtd_array_mean = function(arr, wgt) colSums(as.array(arr)*wgt) / sum(wgt) 190 | wtd_array_var = function(arr, wgt) wtd_array_mean((arr - wtd_array_mean(arr, wgt))^2, wgt) 191 | wtd_array_sd = function(arr, wgt) sqrt(wtd_array_var(arr, wgt)) 192 | wtd_array_quantile = function(arr, wgt, probs=c(0.05, 0.25, 0.5, 0.75, 0.95)) { 193 | apply(arr, 2, function(x) quantile.weighted.ecdf(weighted.ecdf(x, wgt), probs)) 194 | } 195 | wtd_array_median = function(arr, wgt) wtd_array_quantile(arr, wgt, 0.5) 196 | wtd_array_wasserstein = function(arr, wgt, ...) { 197 | apply(arr, 2, function(x) weighted.wasserstein(x, wgt, ...)) 198 | } 199 | 200 | funs_env = new_environment(list( 201 | mean = wtd_array_mean, 202 | var = wtd_array_var, 203 | sd = wtd_array_sd, 204 | quantile = wtd_array_quantile, 205 | median = wtd_array_median, 206 | wasserstein = wtd_array_wasserstein 207 | )) 208 | 209 | 210 | #' Plot Posterior Quantities of Interest Under Alternative Model Specifications 211 | #' 212 | #' Uses weights computed in \code{\link{adjust_weights}} to plot posterior 213 | #' quantities of interest versus specification parameters 214 | #' 215 | #' @param x An \code{adjustr_weighted} object. 216 | #' @param by The x-axis variable, which is usually one of the specification 217 | #' parameters. Can be set to \code{1} if there is only one specification. 218 | #' Automatically quoted and evaluated in the context of \code{x}. 
219 | #' @param post The posterior quantity of interest, to be computed for each 220 | #' resampled draw of each specification. Should evaluate to a single number 221 | #' for each draw. Automatically quoted and evaluated in the context of \code{x}. 222 | #' @param only_mean Whether to only plot the posterior mean. May be more stable. 223 | #' @param ci_level The inner credible interval to plot. Central 224 | #' 100*ci_level% intervals are computed from the quantiles of the resampled 225 | #' posterior draws. 226 | #' @param outer_level The outer credible interval to plot. 227 | #' @param ... Ignored. 228 | #' 229 | #' @return A \code{\link[ggplot2]{ggplot}} object which can be further 230 | #' customized with the \strong{ggplot2} package. 231 | #' 232 | #' @seealso \code{\link{adjust_weights}}, \code{\link{summarize.adjustr_weighted}} 233 | #' 234 | #' @examples \dontrun{ 235 | #' spec = make_spec(eta ~ student_t(df, 0, scale), 236 | #' df=1:10, scale=seq(2, 1, -1/9)) 237 | #' adjusted = adjust_weights(spec, eightschools_m) 238 | #' 239 | #' spec_plot(adjusted, df, theta[1]) 240 | #' spec_plot(adjusted, df, mu, only_mean=TRUE) 241 | #' spec_plot(adjusted, scale, tau) 242 | #' } 243 | #' 244 | #' @export 245 | spec_plot = function(x, by, post, only_mean=FALSE, ci_level=0.8, 246 | outer_level=0.95, ...) { 247 | if (!requireNamespace("ggplot2", quietly=TRUE)) { # nocov start 248 | stop("Package `ggplot2` must be installed to plot posterior quantities of interest.") 249 | } # nocov end 250 | if (ci_level > outer_level) stop("`ci_level` should be less than `outer_level`") 251 | 252 | post = enexpr(post) 253 | orig_row = filter(x, if_any(starts_with(".samp"), ~ . == "")) 254 | if (!only_mean) { 255 | outer = (1 - outer_level) / 2 256 | inner = (1 - ci_level) / 2 257 | q_probs = c(outer, inner, 0.5, 1-inner, 1-outer) 258 | sum_arg = quo(quantile(!!post, probs = !!q_probs)) 259 | 260 | filter(x, if_any(starts_with(".samp"), ~ . 
!= "")) %>% 261 | summarise.adjustr_weighted(.y = !!sum_arg) %>% 262 | rowwise() %>% 263 | mutate(.y_ol = .y[1], 264 | .y_il = .y[2], 265 | .y_med = .y[3], 266 | .y_iu = .y[4], 267 | .y_ou = .y[5]) %>% 268 | ggplot2::ggplot(ggplot2::aes({{ by }}, .y_med)) + 269 | ggplot2::geom_ribbon(ggplot2::aes(ymin=.y_ol, ymax=.y_ou), alpha=0.4) + 270 | ggplot2::geom_ribbon(ggplot2::aes(ymin=.y_il, ymax=.y_iu), alpha=0.5) + 271 | { if (nrow(orig_row) == 1) 272 | ggplot2::geom_hline(yintercept=summarise.adjustr_weighted(orig_row, .y = median(!!post))$`.y`, 273 | lty="dashed") 274 | } + 275 | ggplot2::geom_line() + 276 | ggplot2::geom_point(size=3) + 277 | ggplot2::theme_minimal() + 278 | ggplot2::labs(y = expr_name(post)) 279 | } else { 280 | filter(x, if_any(starts_with(".samp"), ~ . != "")) %>% 281 | summarise.adjustr_weighted(.y = mean(!!post)) %>% 282 | ggplot2::ggplot(ggplot2::aes({{ by }}, .y)) + 283 | { if (nrow(orig_row) == 1) 284 | ggplot2::geom_hline(yintercept=summarise.adjustr_weighted(orig_row, .y = mean(!!post))$`.y`, 285 | lty="dashed") 286 | } + 287 | ggplot2::geom_line() + 288 | ggplot2::geom_point(size=3) + 289 | ggplot2::theme_minimal() + 290 | ggplot2::labs(y = expr_name(post)) 291 | } 292 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # adjustr 2 | 3 | 4 | [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://www.tidyverse.org/lifecycle/#experimental) 5 | [![R-CMD-check](https://github.com/CoryMcCartan/adjustr/workflows/R-CMD-check/badge.svg)](https://github.com/CoryMcCartan/adjustr/actions) 6 | [![Codecov test coverage](https://codecov.io/gh/CoryMcCartan/adjustr/branch/master/graph/badge.svg)](https://codecov.io/gh/CoryMcCartan/adjustr?branch=master) 7 | 8 | 9 | Sensitivity analysis is a critical component of a good modeling workflow. 
Yet 10 | as the number and power of Bayesian computational tools has increased, the 11 | options for sensitivity analysis have remained largely the same: compute 12 | importance sampling weights manually, or fit a large number of similar models, 13 | dramatically increasing computation time. Neither option is satisfactory for 14 | most applied modeling. 15 | 16 | **adjustr** is an R package which aims to make sensitivity analysis faster 17 | and easier, and works with Bayesian models fitted with [Stan](https://mc-stan.org). 18 | Users provide a series of alternate sampling specifications, and the package 19 | uses Pareto-smoothed importance sampling to estimate the posterior under each 20 | specification. The package also provides functions to summarize and plot how 21 | posterior quantities change across specifications. 22 | 23 | The package provides simple interface that makes it as easy as possible 24 | for researchers to try out various adjustments to their Stan models, without 25 | needing to write any specific Stan code or even recompile or rerun their model. 26 | 27 | The package works by parsing Stan model code, so everything works best if the 28 | model was written by the user. Models made using **brms** may in principle be 29 | used as well. Models made using **rstanarm** are constructed using more 30 | complex model templates, and cannot be used. 31 | 32 | ## Getting Started 33 | 34 | The basic __adjustr__ workflow is as follows: 35 | 36 | 1. Use [`make_spec`](https://corymccartan.github.io/adjustr/reference/make_spec.html) 37 | to specify the set of alternative model specifications you'd like to fit. 38 | 39 | 2. Use [`adjust_weights`](https://corymccartan.github.io/adjustr/reference/adjust_weights.html) 40 | to calculate importance sampling weights which approximate the posterior of each 41 | alternative specification. 42 | 43 | 3. 
Use [`summarize`](https://corymccartan.github.io/adjustr/reference/summarize.adjustr_weighted.html) 44 | and [`spec_plot`](https://corymccartan.github.io/adjustr/reference/spec_plot.html) 45 | to examine posterior quantities of interest for each alternative specification, 46 | in order to assess the sensitivity of the underlying model. 47 | 48 | To illustrate, the package lets us do the following: 49 | ```r 50 | extract_samp_stmts(eightschools_m) 51 | #> Sampling statements for model 2c8d1d8a30137533422c438f23b83428: 52 | #> parameter eta ~ std_normal() 53 | #> data y ~ normal(theta, sigma) 54 | 55 | make_spec(eta ~ student_t(0, 1, df), df=1:10) %>% 56 | adjust_weights(eightschools_m) %>% 57 | summarize(wasserstein(mu)) 58 | #> # A tibble: 11 x 5 59 | #> df .samp .weights .pareto_k `wasserstein(mu)` 60 | #> 61 | #> 1 1 eta ~ student_t(df, 0, 1) 1.02 0.928 62 | #> 2 2 eta ~ student_t(df, 0, 1) 1.03 0.736 63 | #> 3 3 eta ~ student_t(df, 0, 1) 0.915 0.534 64 | #> 4 4 eta ~ student_t(df, 0, 1) 0.856 0.411 65 | #> 5 5 eta ~ student_t(df, 0, 1) 0.826 0.341 66 | #> 6 6 eta ~ student_t(df, 0, 1) 0.803 0.275 67 | #> 7 7 eta ~ student_t(df, 0, 1) 0.782 0.234 68 | #> 8 8 eta ~ student_t(df, 0, 1) 0.753 0.195 69 | #> 9 9 eta ~ student_t(df, 0, 1) 0.736 0.166 70 | #> 10 10 eta ~ student_t(df, 0, 1) 0.721 0.151 71 | #> 11 NA -Inf 0 72 | ``` 73 | 74 | The tutorial [vignette](https://corymccartan.github.io/adjustr/articles/eight-schools.html) 75 | walks through a full sensitivity analysis for this 8-schools example. 76 | Smaller examples are also included in the package 77 | [documentation](https://corymccartan.github.io/adjustr/reference/index.html). 78 | 79 | ## Installation 80 | 81 | Install the latest version from **GitHub**: 82 | 83 | ```r 84 | if (!require("devtools")) { 85 | install.packages("devtools") 86 | } 87 | devtools::install_github("corymccartan/adjustr@*release") 88 | ``` 89 | 90 | ## References 91 | 92 | Vehtari, A., Simpson, D., Gelman, A., Yao, Y., & Gabry, J. 
(2015). 93 | Pareto smoothed importance sampling. 94 | _[arXiv preprint arXiv:1507.02646](https://arxiv.org/abs/1507.02646)_. 95 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://corymccartan.github.io/adjustr/ 2 | 3 | destination: docs 4 | 5 | development: 6 | mode: auto 7 | 8 | template: 9 | params: 10 | bootswatch: cosmo 11 | ganalytics: UA-79274202-5 12 | 13 | navbar: 14 | title: "adjustr" 15 | left: 16 | - text: "Get Started" 17 | href: articles/adjustr.html 18 | - text: "Functions" 19 | href: reference/index.html 20 | - text: "News" 21 | href: news/index.html 22 | - text: "Other Packages" 23 | menu: 24 | - text: "rstan" 25 | href: https://mc-stan.org/rstan 26 | - text: "cmdstanr" 27 | href: https://mc-stan.org/cmdstanr 28 | - text: "rstanarm" 29 | href: https://mc-stan.org/rstanarm 30 | - text: "bayesplot" 31 | href: https://mc-stan.org/bayesplot 32 | - text: "shinystan" 33 | href: https://mc-stan.org/shinystan 34 | - text: "loo" 35 | href: https://mc-stan.org/loo 36 | - text: "projpred" 37 | href: https://mc-stan.org/projpred 38 | - text: "rstantools" 39 | href: https://mc-stan.org/rstantools 40 | - text: "Stan" 41 | href: https://mc-stan.org 42 | right: 43 | - icon: fa-twitter 44 | href: https://twitter.com/mcmc_stan 45 | - icon: fa-github 46 | href: https://github.com/CoryMcCartan/adjustr/ 47 | - icon: fa-users 48 | href: https://discourse.mc-stan.org/ 49 | 50 | home: 51 | title: "adjustr: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling" 52 | description: > 53 | An R package which provides functions to help assess the sensitivity of a 54 | Bayesian model to the specification of its likelihood and priors. 
Users 55 | provide a series of alternate sampling specifications, and the package uses 56 | Pareto-smoothed importance sampling to estimate posterior quantities of 57 | interest under each specification. The package also provides functions to 58 | summarize and plot how these quantities change across specifications. 59 | links: 60 | - text: Ask a question 61 | href: https://discourse.mc-stan.org/ 62 | 63 | authors: 64 | Cory McCartan: 65 | href: "https://corymccartan.github.io/" 66 | 67 | reference: 68 | - title: "Model Adjustments" 69 | desc: Core functions for sensitivity analysis workflow. 70 | contents: 71 | - make_spec 72 | - adjust_weights 73 | - summarize.adjustr_weighted 74 | - spec_plot 75 | - title: "Helper Functions" 76 | desc: > 77 | Various helper functions for examining a model or building sampling 78 | specifications. 79 | contents: 80 | - extract_samp_stmts 81 | - as.data.frame.adjustr_spec 82 | - dplyr.adjustr_spec 83 | - get_resampling_idxs 84 | - pull.adjustr_weighted 85 | - adjustr-package 86 | -------------------------------------------------------------------------------- /adjustr.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 4 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | StripTrailingWhitespace: Yes 16 | 17 | BuildType: Package 18 | PackageUseDevtools: Yes 19 | PackageCleanBeforeInstall: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | 
informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /man/adjust_weights.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/adjust_weights.R 3 | \name{adjust_weights} 4 | \alias{adjust_weights} 5 | \title{Compute Pareto-smoothed Importance Weights for Alternative Model 6 | Specifications} 7 | \usage{ 8 | adjust_weights(spec, object, data = NULL, keep_bad = FALSE, incl_orig = TRUE) 9 | } 10 | \arguments{ 11 | \item{spec}{An object of class \code{adjustr_spec}, probably produced by 12 | \code{\link{make_spec}}, containing the new sampling sampling statements 13 | to replace their counterparts in the original Stan model, and the data, 14 | if any, by which these sampling statements are parametrized.} 15 | 16 | \item{object}{A model object, either of type \code{\link[rstan]{stanfit}}, 17 | \code{\link[rstanarm:stanreg-objects]{stanreg}}, \code{\link[brms]{brmsfit}}, 18 | or a list with two elements: \code{model} containing a 19 | \code{\link[cmdstanr]{CmdStanModel}}, and \code{fit} containing a 20 | \code{\link[cmdstanr]{CmdStanMCMC}} object.} 21 | 22 | \item{data}{The data that was used to fit the model in \code{object}. 23 | Required only if one of the new sampling specifications involves Stan data 24 | variables.} 25 | 26 | \item{keep_bad}{When \code{FALSE} (the strongly recommended default), 27 | alternate specifications which deviate too much from the original 28 | posterior, and which as a result cannot be reliably estimated using 29 | importance sampling (i.e., if the Pareto shape parameter is larger than 30 | 0.7), have their weights discarded—weights are set to \code{NA_real_}.} 31 | 32 | \item{incl_orig}{When \code{TRUE}, include a row for the original 33 | model specification, with all weights equal. 
Can facilitate comparison 34 | and plotting later.} 35 | } 36 | \value{ 37 | A tibble, produced by converting the provided \code{specs} to a 38 | tibble (see \code{\link{as.data.frame.adjustr_spec}}), and adding columns 39 | \code{.weights}, containing vectors of weights for each draw, and 40 | \code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values 41 | greater than 0.7 indicate that importance sampling is not reliable. 42 | If \code{incl_orig} is \code{TRUE}, a row is added for the original model 43 | specification. Weights can be extracted with the 44 | \code{\link{pull.adjustr_weighted}} method. The returned object also 45 | includes the model sample draws, in the \code{draws} attribute. 46 | } 47 | \description{ 48 | Given a set of new sampling statements, which can be parametrized by a 49 | data frame or list, compute Pareto-smoothed importance weights and attach 50 | them to the specification object, for further calculation and plotting. 51 | } 52 | \details{ 53 | This function does the bulk of the sensitivity analysis work. It operates 54 | by parsing the model code from the provided Stan object, extracting the 55 | parameters and their sampling statements. It then uses R 56 | metaprogramming/tidy evaluation tools to flexibly evaluate the log density 57 | for each draw and each sampling statement, under the original and alternative 58 | specifications. From these, the function computes the overall importance 59 | weight for each draw and performs Pareto-smoothed importance sampling. All 60 | of the work is performed in R, without recompiling or refitting the Stan 61 | model. 
62 | } 63 | \examples{ 64 | \dontrun{ 65 | model_data = list( 66 | J = 8, 67 | y = c(28, 8, -3, 7, -1, 1, 18, 12), 68 | sigma = c(15, 10, 16, 11, 9, 11, 10, 18) 69 | ) 70 | 71 | spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 72 | adjust_weights(spec, eightschools_m) 73 | adjust_weights(spec, eightschools_m, keep_bad=TRUE) 74 | 75 | spec = make_spec(y ~ student_t(df, theta, sigma), df=1:10) 76 | adjust_weights(spec, eightschools_m, data=model_data) 77 | # will throw an error because `y` and `sigma` aren't provided 78 | adjust_weights(spec, eightschools_m) 79 | } 80 | 81 | } 82 | \references{ 83 | Vehtari, A., Simpson, D., Gelman, A., Yao, Y., & Gabry, J. (2015). 84 | Pareto smoothed importance sampling. \href{https://arxiv.org/abs/1507.02646}{arXiv preprint arXiv:1507.02646}. 85 | } 86 | \seealso{ 87 | \code{\link{make_spec}}, \code{\link{summarize.adjustr_weighted}}, \code{\link{spec_plot}} 88 | } 89 | -------------------------------------------------------------------------------- /man/adjustr-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/adjustr-package.R 3 | \docType{package} 4 | \name{adjustr-package} 5 | \alias{adjustr-package} 6 | \title{adjustr: Stan Model Adjustments and Sensitivity Analyses using Importance 7 | Sampling} 8 | \description{ 9 | Functions to help assess the sensitivity of a Bayesian model to the 10 | specification of its likelihood and priors, estimated using the rstan 11 | package. Users provide a series of alternate sampling specifications, and the 12 | package uses Pareto-smoothed importance sampling to estimate posterior 13 | quantities of interest under each specification. 14 | } 15 | \details{ 16 | See the list of key functions and the example below. 17 | Full package documentation available at \url{https://corymccartan.github.io/adjustr/}. 
18 | } 19 | \section{Key Functions}{ 20 | 21 | \itemize{ 22 | \item \code{\link{make_spec}} 23 | \item \code{\link{adjust_weights}} 24 | \item \code{\link{summarize}} 25 | \item \code{\link{plot}} 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /man/as.data.frame.adjustr_spec.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/make_spec.R 3 | \name{as.data.frame.adjustr_spec} 4 | \alias{as.data.frame.adjustr_spec} 5 | \title{Convert an \code{adjustr_spec} Object Into a Data Frame} 6 | \usage{ 7 | \method{as.data.frame}{adjustr_spec}(x, ...) 8 | } 9 | \arguments{ 10 | \item{x}{the \code{adjustr_spec} object} 11 | 12 | \item{...}{additional arguments to underlying method} 13 | } 14 | \description{ 15 | Returns the data frame of specification parameters, with added columns of 16 | the form \code{.samp_1}, \code{.samp_2}, ... for each sampling statement 17 | (or just \code{.samp} if there is only one sampling statement). 18 | } 19 | -------------------------------------------------------------------------------- /man/dplyr.adjustr_spec.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/make_spec.R 3 | \name{dplyr.adjustr_spec} 4 | \alias{dplyr.adjustr_spec} 5 | \alias{filter.adjustr_spec} 6 | \alias{arrange.adjustr_spec} 7 | \alias{rename.adjustr_spec} 8 | \alias{select.adjustr_spec} 9 | \alias{slice.adjustr_spec} 10 | \title{\code{dplyr} Methods for \code{adjustr_spec} Objects} 11 | \usage{ 12 | \method{filter}{adjustr_spec}(.data, ..., .preserve = FALSE) 13 | 14 | \method{arrange}{adjustr_spec}(.data, ...) 15 | 16 | \method{rename}{adjustr_spec}(.data, ...) 17 | 18 | \method{select}{adjustr_spec}(.data, ...) 
19 | 20 | \method{slice}{adjustr_spec}(.data, ..., .preserve = FALSE) 21 | } 22 | \arguments{ 23 | \item{.data}{the \code{adjustr_spec} object} 24 | 25 | \item{...}{additional arguments to underlying method} 26 | 27 | \item{.preserve}{as in \code{filter} and \code{slice}} 28 | } 29 | \description{ 30 | Core \code{\link[dplyr]{dplyr}} verbs which don't involve grouping 31 | (\code{\link[dplyr]{filter}}, \code{\link[dplyr]{arrange}}, 32 | \code{\link[dplyr]{mutate}}, \code{\link[dplyr]{select}}, 33 | \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are 34 | implemented and operate on the underlying table of specification parameters. 35 | } 36 | \examples{ 37 | \dontrun{ 38 | spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 39 | 40 | arrange(spec, desc(df)) 41 | slice(spec, 4:7) 42 | filter(spec, df == 2) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /man/extract_samp_stmts.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/adjust_weights.R 3 | \name{extract_samp_stmts} 4 | \alias{extract_samp_stmts} 5 | \title{Extract Model Sampling Statements From a Stan Model.} 6 | \usage{ 7 | extract_samp_stmts(object) 8 | } 9 | \arguments{ 10 | \item{object}{A \code{\link[rstan]{stanfit}} model object.} 11 | } 12 | \value{ 13 | Invisibly returns a list of sampling formulas. 14 | } 15 | \description{ 16 | Prints a list of sampling statements extracted from the \code{model} block of 17 | a Stan program, with each labelled "parameter" or "data" depending on the 18 | type of variable being sampled. 
19 | } 20 | \examples{ 21 | \dontrun{ 22 | extract_samp_stmts(eightschools_m) 23 | #> Sampling statements for model 2c8d1d8a30137533422c438f23b83428: 24 | #> parameter eta ~ std_normal() 25 | #> data y ~ normal(theta, sigma) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/man/figures/logo.png -------------------------------------------------------------------------------- /man/get_resampling_idxs.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/use_weights.R 3 | \name{get_resampling_idxs} 4 | \alias{get_resampling_idxs} 5 | \title{Get Importance Resampling Indices From Weights} 6 | \usage{ 7 | get_resampling_idxs(x, frac = 1, replace = TRUE) 8 | } 9 | \arguments{ 10 | \item{x}{A vector of weights, a list of weight vectors, or a data frame of 11 | type \code{adjustr_weighted} containing a \code{.weights} list-column 12 | of weights.} 13 | 14 | \item{frac}{A real number giving the fraction of draws to resample; the 15 | default, 1, resamples all draws. Smaller values should be used when 16 | \code{replace=FALSE}.} 17 | 18 | \item{replace}{Whether sampling should be with replacement. When weights 19 | are extreme it may make sense to use \code{replace=FALSE}, but accuracy 20 | is not guaranteed in these cases.} 21 | } 22 | \value{ 23 | A vector, list, or data frame, depending on the type of \code{x}, 24 | containing the sampled indices. If any weights are \code{NA}, the indices 25 | will also be \code{NA}. 26 | } 27 | \description{ 28 | Takes a vector of weights, or data frame or list containing sets of weights, 29 | and resamples indices for use in later computation. 
30 | } 31 | \examples{ 32 | \dontrun{ 33 | spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 34 | adjusted = adjust_weights(spec, eightschools_m) 35 | 36 | get_resampling_idxs(adjusted) 37 | get_resampling_idxs(adjusted, frac=0.1, replace=FALSE) 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /man/make_spec.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/make_spec.R 3 | \name{make_spec} 4 | \alias{make_spec} 5 | \title{Set Up Model Adjustment Specifications} 6 | \usage{ 7 | make_spec(...) 8 | } 9 | \arguments{ 10 | \item{...}{Model specification. Each argument can either be a formula, 11 | a named vector, data frames, or lists. 12 | 13 | Formula arguments provide new sampling statements to replace their 14 | counterparts in the original Stan model. All such formulas must be of the 15 | form \code{variable ~ distribution(parameters)}, where \code{variable} and 16 | \code{parameters} are Stan data variables or parameters, or are provided by 17 | other arguments to this function (see below), and where \code{distribution} 18 | matches one of the univariate 19 | \href{https://mc-stan.org/docs/2_22/functions-reference/conventions-for-probability-functions.html}{Stan distributions}. 20 | Arithmetic expressions of parameters are also allowed, but care must be 21 | taken with multivariate parameter arguments. Since specifications are 22 | passed as formulas, R's arithmetic operators are used, not Stan's. As a 23 | result, matrix and elementwise multiplication in Stan sampling statements may 24 | not be interpreted correctly. Moving these computations out of sampling 25 | statements and into local variables will ensure correct results. 26 | 27 | For named vector arguments, each entry of the vector will be substituted 28 | into the corresponding parameter in the sampling statements. 
For data 29 | frames, each entry in each column will be substituted into the corresponding 30 | parameter in the sampling statements. 31 | 32 | List arguments are coerced to data frames. They can either be lists of named 33 | vectors, or lists of lists of single-element named vectors. 34 | 35 | The lengths of all parameter arguments must be consistent. Named vectors 36 | can have length 1 or must have length equal to the number of rows in all 37 | data frame arguments and the length of list arguments.} 38 | } 39 | \value{ 40 | An object of class \code{adjustr_spec}, which is essentially a list 41 | with two elements: \code{samp}, which is a list of sampling formulas, and 42 | \code{params}, which is a list of lists of parameters. Core 43 | \link[=filter.adjustr_spec]{dplyr verbs} which don't involve grouping 44 | (\code{\link[dplyr]{filter}}, \code{\link[dplyr]{arrange}}, 45 | \code{\link[dplyr]{mutate}}, \code{\link[dplyr]{select}}, 46 | \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are 47 | supported and operate on the underlying table of specification parameters. 48 | } 49 | \description{ 50 | Takes a set of new sampling statements, which can be parametrized by other 51 | arguments, data frames, or lists, and creates an \code{adjustr_spec} object 52 | suitable for use in \code{\link{adjust_weights}}. 
53 | } 54 | \examples{ 55 | make_spec(eta ~ cauchy(0, 1)) 56 | 57 | make_spec(eta ~ student_t(df, 0, 1), df=1:10) 58 | 59 | params = tidyr::crossing(df=1:10, infl=c(1, 1.5, 2)) 60 | make_spec(eta ~ student_t(df, 0, 1), 61 | y ~ normal(theta, infl*sigma), 62 | params) 63 | 64 | } 65 | \seealso{ 66 | \code{\link{adjust_weights}}, \code{\link{summarize.adjustr_weighted}}, \code{\link{spec_plot}} 67 | } 68 | -------------------------------------------------------------------------------- /man/pull.adjustr_weighted.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/adjust_weights.R 3 | \name{pull.adjustr_weighted} 4 | \alias{pull.adjustr_weighted} 5 | \title{Extract Weights From an \code{adjustr_weighted} Object} 6 | \usage{ 7 | \method{pull}{adjustr_weighted}(.data, var = ".weights", name = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{.data}{A table of data} 11 | 12 | \item{var}{A variable, as in \code{\link[dplyr]{pull}}. The default returns 13 | the \code{.weights} column, and if there is only one row, it returns the 14 | first element of that column} 15 | 16 | \item{name}{Ignored} 17 | 18 | \item{...}{Ignored} 19 | } 20 | \description{ 21 | This function modifies the default behavior of \code{dplyr::pull} to extract 22 | the \code{.weights} column. 23 | } 24 | -------------------------------------------------------------------------------- /man/spec_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/use_weights.R 3 | \name{spec_plot} 4 | \alias{spec_plot} 5 | \title{Plot Posterior Quantities of Interest Under Alternative Model Specifications} 6 | \usage{ 7 | spec_plot( 8 | x, 9 | by, 10 | post, 11 | only_mean = FALSE, 12 | ci_level = 0.8, 13 | outer_level = 0.95, 14 | ... 
15 | ) 16 | } 17 | \arguments{ 18 | \item{x}{An \code{adjustr_weighted} object.} 19 | 20 | \item{by}{The x-axis variable, which is usually one of the specification 21 | parameters. Can be set to \code{1} if there is only one specification. 22 | Automatically quoted and evaluated in the context of \code{x}.} 23 | 24 | \item{post}{The posterior quantity of interest, to be computed for each 25 | resampled draw of each specification. Should evaluate to a single number 26 | for each draw. Automatically quoted and evaluated in the context of \code{x}.} 27 | 28 | \item{only_mean}{Whether to only plot the posterior mean. May be more stable.} 29 | 30 | \item{ci_level}{The inner credible interval to plot. Central 31 | 100*ci_level% intervals are computed from the quantiles of the resampled 32 | posterior draws.} 33 | 34 | \item{outer_level}{The outer credible interval to plot.} 35 | 36 | \item{...}{Ignored.} 37 | } 38 | \value{ 39 | A \code{\link[ggplot2]{ggplot}} object which can be further 40 | customized with the \strong{ggplot2} package. 
41 | } 42 | \description{ 43 | Uses weights computed in \code{\link{adjust_weights}} to plot posterior 44 | quantities of interest versus specification parameters 45 | } 46 | \examples{ 47 | \dontrun{ 48 | spec = make_spec(eta ~ student_t(df, 0, scale), 49 | df=1:10, scale=seq(2, 1, -1/9)) 50 | adjusted = adjust_weights(spec, eightschools_m) 51 | 52 | spec_plot(adjusted, df, theta[1]) 53 | spec_plot(adjusted, df, mu, only_mean=TRUE) 54 | spec_plot(adjusted, scale, tau) 55 | } 56 | 57 | } 58 | \seealso{ 59 | \code{\link{adjust_weights}}, \code{\link{summarize.adjustr_weighted}} 60 | } 61 | -------------------------------------------------------------------------------- /man/summarize.adjustr_weighted.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/use_weights.R 3 | \name{summarise.adjustr_weighted} 4 | \alias{summarise.adjustr_weighted} 5 | \alias{summarize.adjustr_weighted} 6 | \title{Summarize Posterior Distributions Under Alternative Model Specifications} 7 | \usage{ 8 | \method{summarise}{adjustr_weighted}(.data, ..., .resampling = FALSE, .model_data = NULL) 9 | 10 | \method{summarize}{adjustr_weighted}(.data, ..., .resampling = FALSE, .model_data = NULL) 11 | } 12 | \arguments{ 13 | \item{.data}{An \code{adjustr_weighted} object.} 14 | 15 | \item{...}{Name-value pairs of expressions. The name of each argument will be 16 | the name of a new variable, and the value will be computed for the 17 | posterior distribution of each alternative specification. For example, 18 | a value of \code{mean(theta)} will compute the posterior mean of 19 | \code{theta} for each alternative specification. 
20 | 21 | Also supported is the custom function \code{wasserstein}, which computes 22 | the Wasserstein-p distance between the posterior distribution of the 23 | provided expression under the new model and under the original model, with 24 | \code{p=1} the default. Lower the \code{spacing} parameter from the 25 | default of 0.005 to compute a finer (but slower) approximation. 26 | 27 | The arguments in \code{...} are automatically quoted and evaluated in the 28 | context of \code{.data}. They support unquoting and splicing.} 29 | 30 | \item{.resampling}{Whether to compute summary statistics by first resampling 31 | the data according to the weights. Defaults to \code{FALSE}, but will be 32 | used for any summary statistic that is not \code{mean}, \code{var} or 33 | \code{sd}.} 34 | 35 | \item{.model_data}{Stan model data, if not provided in the earlier call to 36 | \code{\link{adjust_weights}}.} 37 | } 38 | \value{ 39 | An \code{adjustr_weighted} object, with the new columns specified in 40 | \code{...} added. 41 | } 42 | \description{ 43 | Uses weights computed in \code{\link{adjust_weights}} to compute posterior 44 | summary statistics. These statistics can be compared against their reference 45 | values to quantify the sensitivity of the model to aspects of its 46 | specification. 
47 | } 48 | \examples{ 49 | \dontrun{ 50 | model_data = list( 51 | J = 8, 52 | y = c(28, 8, -3, 7, -1, 1, 18, 12), 53 | sigma = c(15, 10, 16, 11, 9, 11, 10, 18) 54 | ) 55 | 56 | spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 57 | adjusted = adjust_weights(spec, eightschools_m) 58 | 59 | summarize(adjusted, mean(mu), var(mu)) 60 | summarize(adjusted, wasserstein(mu, p=2)) 61 | summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data) 62 | summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95))) 63 | } 64 | 65 | } 66 | \seealso{ 67 | \code{\link{adjust_weights}}, \code{\link{spec_plot}} 68 | } 69 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon-180x180.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /tests/ex.stan: -------------------------------------------------------------------------------- 1 | // generated with brms 
2.16.1 2 | functions { 3 | /* cumulative-logit log-PDF for a single response 4 | * Args: 5 | * y: response category 6 | * mu: latent mean parameter 7 | * disc: discrimination parameter 8 | * thres: ordinal thresholds 9 | * Returns: 10 | * a scalar to be added to the log posterior 11 | */ 12 | real cumulative_logit_lpmf(int y, real mu, real disc, vector thres) { 13 | int nthres = num_elements(thres); 14 | if (y == 1) { 15 | return log_inv_logit(disc * (thres[1] - mu)); 16 | } else if (y == nthres + 1) { 17 | return log1m_inv_logit(disc * (thres[nthres] - mu)); 18 | } else { 19 | return log_diff_exp( 20 | log_inv_logit(disc * (thres[y] - mu)), 21 | log_inv_logit(disc * (thres[y - 1] - mu)) 22 | ); 23 | } 24 | } 25 | /* cumulative-logit log-PDF for a single response and merged thresholds 26 | * Args: 27 | * y: response category 28 | * mu: latent mean parameter 29 | * disc: discrimination parameter 30 | * thres: vector of merged ordinal thresholds 31 | * j: start and end index for the applied threshold within 'thres' 32 | * Returns: 33 | * a scalar to be added to the log posterior 34 | */ 35 | real cumulative_logit_merged_lpmf(int y, real mu, real disc, vector thres, int[] j) { 36 | return cumulative_logit_lpmf(y | mu, disc, thres[j[1]:j[2]]); 37 | } 38 | /* ordered-logistic log-PDF for a single response and merged thresholds 39 | * Args: 40 | * y: response category 41 | * mu: latent mean parameter 42 | * thres: vector of merged ordinal thresholds 43 | * j: start and end index for the applied threshold within 'thres' 44 | * Returns: 45 | * a scalar to be added to the log posterior 46 | */ 47 | real ordered_logistic_merged_lpmf(int y, real mu, vector thres, int[] j) { 48 | return ordered_logistic_lpmf(y | mu, thres[j[1]:j[2]]); 49 | } 50 | } 51 | data { 52 | int N; // total number of observations 53 | int Y[N]; // response variable 54 | int nthres; // number of thresholds 55 | int prior_only; // should the likelihood be ignored? 
56 | } 57 | transformed data { 58 | } 59 | parameters { 60 | ordered[nthres] Intercept; // temporary thresholds for centered predictors 61 | } 62 | transformed parameters { 63 | real disc = 1; // discrimination parameters 64 | } 65 | model { 66 | // likelihood including constants 67 | if (!prior_only) { 68 | // initialize linear predictor term 69 | vector[N] mu = rep_vector(0.0, N); 70 | for (n in 1:N) { 71 | target += ordered_logistic_lpmf(Y[n] | mu[n], Intercept); 72 | } 73 | } 74 | // priors including constants 75 | target += normal_lpdf(Intercept | 0, 3); 76 | } 77 | generated quantities { 78 | // compute actual thresholds 79 | vector[nthres] b_Intercept = Intercept; 80 | } 81 | -------------------------------------------------------------------------------- /tests/test_model.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/tests/test_model.rda -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(adjustr) 3 | 4 | test_check("adjustr") 5 | -------------------------------------------------------------------------------- /tests/testthat/setup.R: -------------------------------------------------------------------------------- 1 | load("../test_model.rda") -------------------------------------------------------------------------------- /tests/testthat/test_logprob.R: -------------------------------------------------------------------------------- 1 | context("Log probability calculations") 2 | 3 | test_that("`make_dens` curries correctly", { 4 | f = make_dens(dexp) 5 | expect_error(f(1, 2), "unused argument") # right number of arguments 6 | test_val = runif(1, 0, 10) 7 | test_x = runif(1, 0.5, 2) 8 | expect_equal(f(test_x)(test_val), dexp(test_x, test_val, log=TRUE)) 9 | }) 10 
| 11 | test_that("Constant parameter log probabilities are calculated correctly", { 12 | draws = matrix(-3:3, ncol=1) 13 | lp = calc_lp(y ~ std_normal(), list(y=draws)) 14 | expect_equal(lp, matrix(dnorm(-3:3, log=TRUE), ncol=1)) 15 | }) 16 | 17 | test_that("Model parameter log probabilities are calculated correctly", { 18 | draws = matrix(-3:3, ncol=1) 19 | mus = -4:2 20 | sigmas = c(2, 1, 1, 1, 2, 1, 2) 21 | lp = calc_lp(y ~ normal(mu, sigma), list(y=draws, mu=mus, sigma=sigmas)) 22 | expect_equal(lp, matrix(dnorm(-3:3, mus, sigmas, log=TRUE), ncol=1)) 23 | }) 24 | 25 | test_that("Data is assembled correctly", { 26 | code = eightschools_m@stanmodel@model_code 27 | parsed = parse_model(code) 28 | bd = get_base_data(eightschools_m, list(eta ~ student_t(df, 0, tau)), 29 | parsed$vars, list(df=1:2), "df") 30 | 31 | expect_length(bd, 1) 32 | expect_named(bd[[1]], c("eta", "tau"), ignore.order=TRUE) 33 | expect_equal(dim(bd[[1]]$eta), c(10, 2, 8)) 34 | expect_equal(dim(bd[[1]]$tau), c(10, 2, 1)) 35 | }) 36 | 37 | test_that("MCMC draws are preferred over provided data", { 38 | code = eightschools_m@stanmodel@model_code 39 | parsed = parse_model(code) 40 | bd = get_base_data(eightschools_m, list(eta ~ student_t(2, 0, tau)), 41 | parsed$vars, list(tau=3)) 42 | 43 | expect_length(bd, 1) 44 | expect_named(bd[[1]], c("eta", "tau"), ignore.order=TRUE) 45 | expect_equal(dim(bd[[1]]$eta), c(10, 2, 8)) 46 | expect_equal(dim(bd[[1]]$tau), c(10, 2, 1)) 47 | }) 48 | 49 | test_that("Parameter-less specification data is correctly assembled", { 50 | code = eightschools_m@stanmodel@model_code 51 | parsed = parse_model(code) 52 | bd = get_base_data(eightschools_m, list(y ~ std_normal()), parsed$vars, 53 | list(y=c(28, 8, -3, 7, -1, 1, 18, 12), J=8)) 54 | 55 | expect_length(bd, 1) 56 | expect_named(bd[[1]], "y") 57 | expect_equal(dim(bd[[1]]$y), c(10, 2, 8)) 58 | }) 59 | 60 | 61 | test_that("Error thrown for missing data", { 62 | code = eightschools_m@stanmodel@model_code 63 | parsed = 
parse_model(code) 64 | 65 | expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, sigma)), 66 | parsed$vars, list()), "sigma not found") 67 | expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, 2)), 68 | parsed$vars, list()), "gamma not found") 69 | }) 70 | 71 | test_that("Model log probability is correctly calculated", { 72 | code = eightschools_m@stanmodel@model_code 73 | parsed = parse_model(code) 74 | form = eta ~ normal(0, 1) 75 | draws = rstan::extract(eightschools_m, "eta", permuted=FALSE) 76 | exp_lp = 2*apply(dnorm(draws, 0, 1, log=TRUE), 1:2, sum) 77 | lp = calc_original_lp(eightschools_m, list(form, form), parsed$vars, list()) 78 | expect_equal(exp_lp, lp) 79 | }) 80 | 81 | test_that("Alternate specifications log probabilities are correctly calculated", { 82 | code = eightschools_m@stanmodel@model_code 83 | parsed = parse_model(code) 84 | form = eta ~ normal(0, s) 85 | draws = rstan::extract(eightschools_m, "eta", permuted=FALSE) 86 | exp_lp = 2*apply(dnorm(draws, 0, 1, log=TRUE), 1:2, sum) 87 | lp = calc_specs_lp(eightschools_m, list(form, form), parsed$vars, list(), list(list(s=1))) 88 | expect_equal(exp_lp, lp[[1]]) 89 | }) 90 | 91 | 92 | -------------------------------------------------------------------------------- /tests/testthat/test_parsing.R: -------------------------------------------------------------------------------- 1 | context("Stan model parsing") 2 | 3 | test_that("Empty model handled correctly", { 4 | parsed = parse_model("") 5 | expect_equal(length(parsed$vars), 0) 6 | expect_equal(length(parsed$samps), 0) 7 | }) 8 | 9 | test_that("Correct parsed variables", { 10 | correct_vars = c(J= "data", y= "data", sigma="data", mu="parameters", 11 | tau="parameters", eta="parameters", 12 | theta="transformed parameters") 13 | code = eightschools_m@stanmodel@model_code 14 | parsed = parse_model(code) 15 | expect_equal(parsed$vars, correct_vars) 16 | }) 17 | 18 | test_that("Correct parsed sampling statements", { 19 | 
correct_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma), 20 | mu ~ uniform(-1e+100, 1e+100), tau ~ uniform(-1e+100, 1e+100)) 21 | code = eightschools_m@stanmodel@model_code 22 | parsed = parse_model(code) 23 | expect_equal(parsed$samp, correct_samp) 24 | }) 25 | 26 | test_that("Provided sampling statements can be matched to model", { 27 | model_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma)) 28 | prov_samp = list(eta ~ exponential(5)) 29 | matched = match_sampling_stmts(prov_samp, model_samp) 30 | expect_false(identical(matched, prov_samp)) 31 | expect_length(matched, 1) 32 | expect_equal(matched[[1]], model_samp[[1]]) 33 | }) 34 | 35 | test_that("Extra sampling statements not in model throw an error", { 36 | model_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma)) 37 | prov_samp = list(eta ~ exponential(5), x ~ normal(theta, sigma)) 38 | expect_error(match_sampling_stmts(prov_samp, model_samp), 39 | "No matching sampling statement found for x ~ normal\\(theta, sigma\\)") 40 | }) 41 | 42 | test_that("Variables are correctly extracted from sampling statements", { 43 | expect_equal(get_stmt_vars(y ~ normal(theta %*% eta/2, 2%%4 + sigma)), 44 | c("y", "theta", "eta", "sigma")) 45 | expect_error(get_stmt_vars(y ~ .), "y ~ \\. 
does not") 46 | }) 47 | -------------------------------------------------------------------------------- /tests/testthat/test_spec.R: -------------------------------------------------------------------------------- 1 | context("Specification creation") 2 | 3 | test_that("Specifications can be created out of any number of formulas with no data", { 4 | spec1 = make_spec(y ~ normal(mu, sigma)) 5 | expect_equal(spec1$params, list(list())) 6 | expect_length(spec1$samp, 1) 7 | expect_is(spec1$samp[[1]], "formula") 8 | 9 | spec2 = make_spec(y ~ normal(mu, sigma), sigma ~ gamma(alpha, beta)) 10 | expect_equal(spec2$params, list(list())) 11 | expect_length(spec2$samp, 2) 12 | expect_is(spec2$samp[[1]], "formula") 13 | expect_is(spec2$samp[[2]], "formula") 14 | }) 15 | 16 | test_that("Empty specifications generate warnigns", { 17 | expect_warning(make_spec(), "No sampling statements provided.") 18 | expect_warning(make_spec(df=1:5), "No sampling statements provided.") 19 | }) 20 | 21 | test_that("Specifications can be created with named vectors", { 22 | spec = make_spec(y ~ std_normal(), df=1:5) 23 | expect_equal(spec$params, purrr::transpose(list(df=1:5))) 24 | expect_error(make_spec(y ~ std_normal(), 1:5), "named vector") 25 | }) 26 | 27 | test_that("Specifications can be created with data frames", { 28 | dat = tibble(df=1:5) 29 | spec = make_spec(y ~ std_normal(), dat) 30 | expect_equal(spec$params, purrr::transpose(as.list(dat))) 31 | }) 32 | 33 | test_that("Specifications can be created with lists", { 34 | dat = tibble(df=1:5) 35 | ldat = purrr::transpose(as.list(dat)) 36 | expect_equal(make_spec(y ~ std_normal(), ldat)$params, ldat) 37 | expect_equal(make_spec(y ~ std_normal(), as.list(dat))$params, ldat) 38 | 39 | expect_error(make_spec(y ~ std_normal(), list(list(a=1:3), 7)), 40 | "List-of-list arguments must be coercible to data frames") 41 | expect_error(make_spec(y ~ std_normal(), list(list(a=1:3), list(c=2:3))), 42 | "NAs found. 
Check input parameters and format.") 43 | expect_error(make_spec(y ~ std_normal(), list(a=1:3, b=2:3)), 44 | "List-of-vector arguments must be coercible to data frames") 45 | expect_error(make_spec(y ~ std_normal(), list(y~x)), "must be lists of lists") 46 | }) 47 | 48 | test_that("Specifications generics work correctly", { 49 | spec = make_spec(y ~ std_normal(), df=1:5) 50 | 51 | expect_output(print(spec), "Specification parameters:\n df\n 1") 52 | expect_true(is.adjustr_spec(spec)) 53 | expect_equal(length(spec), 5) 54 | 55 | spec2 = make_spec(y ~ std_normal(), df=1:5, mu=0:4) 56 | expect_equal(length(filter(spec2, df < 4)), 3) 57 | expect_equal(slice(arrange(spec2, desc(df)), 1)$params[[1]]$df, 5) 58 | expect_equal(names(rename(spec2, b=df)$params[[1]]), c("b", "mu")) 59 | expect_equal(length(select(spec2, df)$params[[1]]), 1) 60 | expect_equal(filter(make_spec(y ~ normal())), make_spec(y ~ normal())) 61 | expect_is(as.data.frame(spec2), "data.frame") 62 | expect_is(as.data.frame(make_spec(y ~ normal(), x ~ normal())), "data.frame") 63 | }) -------------------------------------------------------------------------------- /tests/testthat/test_use.R: -------------------------------------------------------------------------------- 1 | context("Summarizing and using computed weights") 2 | 3 | test_that("Resampling indices can be made from a vector", { 4 | expect_equal(get_resampling_idxs(c(0, 0, 1)), c(3, 3, 3)) 5 | expect_equal(get_resampling_idxs(c(0, 1, 0), frac=1/3), 2) 6 | expect_true(is.na(get_resampling_idxs(NA))) 7 | }) 8 | 9 | test_that("Resampling indices can be made from a list", { 10 | expect_equal(get_resampling_idxs(list(c(0, 0, 1), c(0, 1, 0))), 11 | list(c(3, 3, 3), c(2, 2, 2))) 12 | expect_equal(get_resampling_idxs(list(NA, c(0, 1, 0))), 13 | list(NA_integer_, c(2, 2, 2))) 14 | }) 15 | 16 | test_that("Resampling indices can be made from an adjustr_weighted object", { 17 | spec = make_spec(eta ~ student_t(7, 0, 1)) 18 | obj = adjust_weights(spec, 
eightschools_m, keep_bad=TRUE) 19 | obj = get_resampling_idxs(obj) 20 | expect_is(obj$.idxs, "list") 21 | expect_length(obj$.idxs[[1]], 20) 22 | }) 23 | 24 | test_that("Resampling indices fail with negative `frac`", { 25 | expect_error(get_resampling_idxs(c(0, 0, 1), frac=-0.5), "must be nonnegative") 26 | }) 27 | 28 | test_that("Weighted array functions compute correctly", { 29 | y = as.array(1:5) 30 | dim(y) = c(dim(y), 1) 31 | wgt = c(1, 1, 2, 5, 1) 32 | wtd_mean = weighted.mean(y, wgt) 33 | 34 | expect_equal(wtd_array_mean(y, wgt), wtd_mean) 35 | expect_equal(wtd_array_var(y, wgt), weighted.mean((y - wtd_mean)^2, wgt)) 36 | expect_equal(wtd_array_sd(y, wgt), sqrt(weighted.mean((y - wtd_mean)^2, wgt))) 37 | expect_equal(wtd_array_quantile(y, rep(1, 5), 0.2), 1) 38 | expect_equal(wtd_array_median(y, rep(1, 5)), 2.5) 39 | }) 40 | 41 | test_that("Empty call to `summarize` should change nothing", { 42 | obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) 43 | class(obj) = c("adjustr_weighted", class(obj)) 44 | expect_identical(summarize(obj), obj) 45 | }) 46 | 47 | test_that("Non-summary call to `summarize` should throw error", { 48 | obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) 49 | attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2)) 50 | attr(obj, "iter") = 3 51 | class(obj) = c("adjustr_weighted", class(obj)) 52 | 53 | expect_error(summarize(obj, theta), "must summarize posterior draws") 54 | }) 55 | 56 | test_that("Basic summaries are computed correctly", { 57 | obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) 58 | attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2)) 59 | attr(obj, "iter") = 3 60 | class(obj) = c("adjustr_weighted", class(obj)) 61 | 62 | sum1 = summarize(obj, mean(theta[1])) 63 | expect_is(sum1, "adjustr_weighted") 64 | expect_equal(sum1$`mean(theta[1])`, 5:6) 65 | 66 | sum2 = summarize(obj, th = mean(theta)) 67 | expect_is(sum2, "adjustr_weighted") 68 | expect_equal(sum2$th, list(c(5, 1), c(6, 1))) 69 | 70 | sum3 = 
summarize(obj, W = wasserstein(theta[1])) 71 | expect_is(sum3, "adjustr_weighted") 72 | expect_equal(sum3$W[1], 0) 73 | 74 | expect_error(summarise.adjustr_weighted(as_tibble(obj)), "is not TRUE") 75 | }) 76 | 77 | test_that("`summarize` uses data correctly", { 78 | obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) 79 | attr(obj, "draws") = list(eta=matrix(c(3,5,7), nrow=3, ncol=1)) 80 | attr(obj, "iter") = 3 81 | class(obj) = c("adjustr_weighted", class(obj)) 82 | model_d = list(theta=4) 83 | 84 | expect_equal(summarize(obj, th=mean(theta+eta), .model_data=model_d)$th, c(9,10)) 85 | 86 | attr(obj, "data") = model_d 87 | expect_equal(summarize(obj, th=mean(theta+eta))$th, c(9,10)) 88 | }) 89 | 90 | 91 | test_that("Resampling-based summaries are computed correctly", { 92 | obj = tibble(.weights=list(c(1,0,0), c(0,0,1))) 93 | attr(obj, "draws") = list(theta=matrix(c(3,5,7), nrow=3, ncol=1)) 94 | attr(obj, "iter") = 3 95 | class(obj) = c("adjustr_weighted", class(obj)) 96 | 97 | sum1 = summarize(obj, th=mean(theta), .resampling=TRUE) 98 | expect_equal(sum1$th, c(3,7)) 99 | 100 | sum2 = summarize(obj, th=quantile(theta, 0.05), .resampling=TRUE) 101 | expect_equal(sum2$th, c(3,7)) 102 | }) 103 | 104 | 105 | test_that("Plotting function handles arguments correctly", { 106 | obj = tibble(.weights=list(c(1,0,0), c(0,0,1), c(1,1,1)), 107 | .samp=c("y ~ normal(0, 1)", "y ~ normal(0, 2)", "")) 108 | attr(obj, "draws") = list(theta=matrix(c(3,5,7), nrow=3, ncol=1)) 109 | attr(obj, "iter") = 3 110 | class(obj) = c("adjustr_weighted", class(obj)) 111 | 112 | expect_is(spec_plot(obj, 1, theta), "ggplot") 113 | expect_is(spec_plot(obj, 1, theta, only_mean=TRUE), "ggplot") 114 | 115 | expect_error(spec_plot(obj, 1, theta, outer_level=0.4), "should be less than") 116 | }) 117 | -------------------------------------------------------------------------------- /tests/testthat/test_weights.R: -------------------------------------------------------------------------------- 1 | 
context("Weight generation and model helpers") 2 | 3 | setup({ 4 | pkg_env$model_d = list(J = 8, 5 | y = c(28, 8, -3, 7, -1, 1, 18, 12), 6 | sigma = c(15, 10, 16, 11, 9, 11, 10, 18)) 7 | }) 8 | 9 | test_that("Identical specification gives warning", { 10 | spec = make_spec(y ~ normal(theta, sigma)) 11 | expect_warning(adjust_weights(spec, eightschools_m, pkg_env$model_d), "equal to old") 12 | }) 13 | 14 | test_that("High Pareto k values lead to discarded weights", { 15 | spec = make_spec(y ~ normal(theta, 1.1*sigma)) 16 | obj = adjust_weights(spec, eightschools_m, pkg_env$model_d) 17 | expect_true(is.na(obj$.weights[[1]])) 18 | }) 19 | 20 | test_that("Weights calculated correctly (normal/inflated)", { 21 | theta_draws = rstan::extract(eightschools_m, "theta", permuted=FALSE) 22 | y = pkg_env$model_d$y 23 | sigma = pkg_env$model_d$sigma 24 | 25 | ref_lp = apply(theta_draws, 1:2, function(theta) sum(dnorm(y, theta, sigma, log=TRUE))) 26 | new_lp = apply(theta_draws, 1:2, function(theta) sum(dnorm(y, theta, 1.1*sigma, log=TRUE))) 27 | lratio = new_lp - ref_lp 28 | dim(lratio) = c(dim(lratio), 1) 29 | r_eff = loo::relative_eff(as.array(exp(-lratio))) 30 | psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff)) 31 | pareto_k = loo::pareto_k_values(psis_wgt) 32 | weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=FALSE)) 33 | 34 | spec = make_spec(y ~ normal(theta, 1.1*sigma)) 35 | obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=TRUE, incl_orig=FALSE) 36 | 37 | expect_s3_class(obj, "adjustr_weighted") 38 | expect_s3_class(obj, "tbl_df") 39 | expect_true(is.adjustr_weighted(obj)) 40 | expect_true("draws" %in% names(attributes(obj))) 41 | expect_true("data" %in% names(attributes(obj))) 42 | expect_equal(weights, obj$.weights[[1]]) 43 | expect_equal(pareto_k, obj$.pareto_k) 44 | }) 45 | 46 | test_that("Weights calculated correctly (normal/student_t)", { 47 | theta_draws = rstan::extract(eightschools_m, "theta", permuted=FALSE) 48 | y 
= pkg_env$model_d$y 49 | sigma = pkg_env$model_d$sigma 50 | 51 | ref_lp = apply(theta_draws, 1:2, function(theta) sum(dnorm(y, theta, sigma, log=TRUE))) 52 | new_lp = apply(theta_draws, 1:2, function(theta) sum(dt((y-theta)/sigma, 6, log=TRUE))) 53 | lratio = new_lp - ref_lp 54 | dim(lratio) = c(dim(lratio), 1) 55 | r_eff = loo::relative_eff(exp(-lratio)) 56 | psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff)) 57 | pareto_k = loo::pareto_k_values(psis_wgt) 58 | weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=FALSE)) 59 | 60 | spec = make_spec(y ~ student_t(df, theta, sigma), df=5:6) 61 | obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=TRUE, incl_orig=FALSE) 62 | 63 | expect_equal(weights, obj$.weights[[2]]) 64 | expect_equal(pareto_k, obj$.pareto_k[2]) 65 | }) 66 | 67 | test_that("Weights calculated correctly (no data normal/student_t)", { 68 | eta_draws = rstan::extract(eightschools_m, "eta", permuted=FALSE) 69 | 70 | ref_lp = apply(eta_draws, 1:2, function(eta) sum(dnorm(eta, log=TRUE))) 71 | new_lp = apply(eta_draws, 1:2, function(eta) sum(dt(eta, 4, log=TRUE))) 72 | lratio = new_lp - ref_lp 73 | dim(lratio) = c(dim(lratio), 1) 74 | r_eff = loo::relative_eff(exp(-lratio)) 75 | psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff)) 76 | pareto_k = loo::pareto_k_values(psis_wgt) 77 | weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=FALSE)) 78 | 79 | spec = make_spec(eta ~ student_t(4, 0, 1)) 80 | obj = adjust_weights(spec, eightschools_m, keep_bad=TRUE, incl_orig=FALSE) 81 | 82 | expect_equal(weights, obj$.weights[[1]]) 83 | expect_equal(pareto_k, obj$.pareto_k) 84 | }) 85 | 86 | 87 | test_that("Weights extracted correctly", { 88 | spec = make_spec(y ~ student_t(df, theta, sigma), df=5) 89 | obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=TRUE, incl_orig=FALSE) 90 | pulled = pull(obj) 91 | 92 | expect_is(pulled, "numeric") 93 | expect_length(pulled, 20) 94 | 95 | spec2 = 
make_spec(y ~ student_t(df, theta, sigma), df=5:6) 96 | obj = adjust_weights(spec2, eightschools_m, pkg_env$model_d, keep_bad=TRUE, incl_orig=FALSE) 97 | pulled = pull(obj) 98 | 99 | expect_is(pulled, "list") 100 | expect_length(pulled, 2) 101 | expect_equal(purrr::map_int(pulled, length), c(20, 20)) 102 | }) 103 | 104 | test_that("Sampling statements printed correctly", { 105 | expect_output(extract_samp_stmts(eightschools_m), 106 | "Sampling statements for model 2c8d1d8a30137533422c438f23b83428: 107 | parameter eta ~ std_normal() 108 | parameter mu ~ uniform(-1e+100, 1e+100) 109 | parameter tau ~ uniform(-1e+100, 1e+100) 110 | data y ~ normal(theta, sigma)", fixed=TRUE) 111 | }) 112 | 113 | test_that("Fit objects extracted correctly", { 114 | obj = list(stanfit="stanreg", fit="brmsfit") 115 | 116 | class(obj) = "stanreg" 117 | expect_equal(get_fit_obj(obj), "stanreg") 118 | 119 | class(obj) = "brmsfit" 120 | expect_equal(get_fit_obj(obj), "brmsfit") 121 | 122 | class(obj) = "list" 123 | expect_error(get_fit_obj(obj), "must be of class") 124 | }) 125 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/adjustr.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Sensitivity Analysis of a Simple Hierarchical Model" 3 | output: rmarkdown::html_vignette 4 | bibliography: adjustr.bib 5 | vignette: > 6 | %\VignetteIndexEntry{Sensitivity Analysis of a Simple Hierarchical Model} 7 | %\VignetteEngine{knitr::rmarkdown} 8 | %\VignetteEncoding{UTF-8} 9 | --- 10 | 11 | ```{r, include = FALSE} 12 | knitr::opts_chunk$set( 13 | collapse = TRUE, 14 | comment = "#>" 15 | ) 16 | 17 | library(dplyr) 18 | library(rstan) 19 | library(adjustr) 20 | load("eightschools_model.rda") 21 | ``` 
22 | 23 | ## Introduction 24 | 25 | This vignette walks through the process of performing sensitivity 26 | analysis using the `adjustr` package for the classic introductory 27 | hierarchical model: the "eight schools" meta-analysis from Chapter 5 28 | of @bda3. 29 | 30 | We begin by specifying and fitting the model, which should be familiar 31 | to most users of Stan. 32 | ```{r eval=F} 33 | library(dplyr) 34 | library(rstan) 35 | library(adjustr) 36 | 37 | model_code = " 38 | data { 39 | int J; // number of schools 40 | real y[J]; // estimated treatment effects 41 | real sigma[J]; // standard error of effect estimates 42 | } 43 | parameters { 44 | real mu; // population treatment effect 45 | real tau; // standard deviation in treatment effects 46 | vector[J] eta; // unscaled deviation from mu by school 47 | } 48 | transformed parameters { 49 | vector[J] theta = mu + tau * eta; // school treatment effects 50 | } 51 | model { 52 | eta ~ std_normal(); 53 | y ~ normal(theta, sigma); 54 | }" 55 | 56 | model_d = list(J = 8, 57 | y = c(28, 8, -3, 7, -1, 1, 18, 12), 58 | sigma = c(15, 10, 16, 11, 9, 11, 10, 18)) 59 | eightschools_m = stan(model_code=model_code, chains=2, data=model_d, 60 | warmup=500, iter=1000) 61 | ``` 62 | 63 | We plot the original estimates for each of the eight schools. 64 | ```{r} 65 | plot(eightschools_m, pars="theta") 66 | ``` 67 | 68 | The model partially pools information, pulling the school-level treatment effects 69 | towards the overall mean. 70 | 71 | It is natural to wonder how much these estimates depend on certain aspects of 72 | our model. The individual and school treatment effects are assumed to follow a 73 | normal distribution, and we have used a uniform prior on the population 74 | parameters `mu` and `tau`. 75 | 76 | The basic __adjustr__ workflow is as follows: 77 | 78 | 1. Use `make_spec` to specify the set of alternative model specifications you'd 79 | like to fit. 80 | 81 | 2. 
Use `adjust_weights` to calculate importance sampling weights which 82 | approximate the posterior of each alternative specification. 83 | 84 | 3. Use `summarize` and `spec_plot` to examine posterior quantities of interest 85 | for each alternative specification, in order to assess the sensitivity of the 86 | underlying model. 87 | 88 | ## Basic Workflow Example 89 | 90 | First suppose we want to examine the effect of our choice of uniform prior 91 | on `mu` and `tau`. We begin by specifying an alternative model in which 92 | these parameters have more informative priors. This just requires 93 | passing the `make_spec` function the new sampling statements we'd like to 94 | use. These replace any in the original model (`mu` and `tau` have implicit 95 | improper uniform priors, since the original model does not have any sampling 96 | statements for them). 97 | ```{r} 98 | spec = make_spec(mu ~ normal(0, 20), tau ~ exponential(5)) 99 | print(spec) 100 | ``` 101 | 102 | Then we compute importance sampling weights to approximate the posterior under 103 | this alternative model. 104 | ```{r include=F} 105 | adjusted = adjust_weights(spec, eightschools_m, keep_bad=TRUE) 106 | ``` 107 | ```{r eval=F} 108 | adjusted = adjust_weights(spec, eightschools_m) 109 | ``` 110 | 111 | The `adjust_weights` function returns a data frame 112 | containing a summary of the alternative model and a list-column named `.weights` 113 | containing the importance weights. The last row of the table by default 114 | corresponds to the original model specification. The table also includes the 115 | diagnostic Pareto *k*-value. When this value exceeds 0.7, importance sampling is 116 | unreliable, and by default `adjust_weights` discards weights with a Pareto *k* 117 | above 0.7 (the respective rows in `adjusted` are kept, but the `weights` column 118 | is set to `NA_real_`). 
119 | ```{r} 120 | print(adjusted) 121 | ``` 122 | 123 | Finally, we can examine how these alternative priors have changed our posterior 124 | inference. We use `summarize` to calculate these under the alternative model. 125 | ```{r} 126 | summarize(adjusted, mean(mu), var(mu)) 127 | ``` 128 | We see that the more informative priors have pulled the posterior distribution 129 | of `mu` towards zero and made it less variable. 130 | 131 | ## Multiple Alternative Specifications 132 | What if instead we are concerned about our distributional assumption on the 133 | school treatment effects? We could probe this assumption by fitting a series of 134 | models where `eta` had a Student's *t* distribution, with varying degrees of 135 | freedom. 136 | 137 | The `make_spec` function handles this easily. 138 | ```{r} 139 | spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) 140 | print(spec) 141 | ``` 142 | Notice how we have parameterized the alternative sampling statement with 143 | a variable `df`, and then provided the values `df` takes in another argument 144 | to `make_spec`. 145 | 146 | As before, we compute importance sampling weights to approximate the posterior 147 | under these alternative models. Here, for the purposes of illustration, 148 | we are using `keep_bad=TRUE` to compute weights even when the Pareto _k_ diagnostic 149 | value is above 0.7. In practice, the alternative models should be completely 150 | re-fit in Stan. 151 | ```{r} 152 | adjusted = adjust_weights(spec, eightschools_m, keep_bad=TRUE) 153 | ``` 154 | Now, `adjusted` has ten rows, one for each alternative model. 155 | ```{r} 156 | print(adjusted) 157 | ``` 158 | 159 | To examine the impact of these model changes, we can plot the posterior for 160 | a quantity of interest versus the degrees of freedom for the *t* distribution. 
161 | The package provides the `spec_plot` function which takes an x-axis specification 162 | parameter and a y-axis posterior quantity (which must evaluate to a single 163 | number per posterior draw). The dashed line shows the posterior median 164 | under the original model. 165 | ```{r} 166 | spec_plot(adjusted, df, mu) 167 | spec_plot(adjusted, df, theta[3]) 168 | ``` 169 | 170 | It appears that changing the distribution of `eta`/`theta` from normal to 171 | *t* has a small effect on posterior inferences (although, as noted above, 172 | these inferences are unreliable as _k_ > 0.7). 173 | 174 | By default, the function plots an inner 80\% credible interval and an outer 175 | 95\% credible interval, but these can be changed by the user. 176 | 177 | We can also measure the distance between the new and original posterior 178 | marginals by using the special `wasserstein()` function available in 179 | `summarize()`: 180 | ```{r} 181 | summarize(adjusted, wasserstein(mu)) 182 | ``` 183 | As we would expect, the 1-Wasserstein distance decreases as the degrees of 184 | freedom increase. In general, we can compute the _p_-Wasserstein distance 185 | by passing an extra `p` parameter to `wasserstein()`. 
186 | 187 | 188 | ### 189 | -------------------------------------------------------------------------------- /vignettes/adjustr.bib: -------------------------------------------------------------------------------- 1 | @book{bda3, 2 | address = {London}, 3 | author = {Andrew Gelman and J.~B.~Carlin and Hal S.~Stern and David B.~Dunson and Aki Vehtari and Donald B.~Rubin}, 4 | edition = {3rd}, 5 | publisher = {CRC Press}, 6 | title = {Bayesian Data Analysis}, 7 | year = {2013}} 8 | -------------------------------------------------------------------------------- /vignettes/eightschools_model.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoryMcCartan/adjustr/c401950a404ce30917622d2dfbd9408897e4f05b/vignettes/eightschools_model.rda --------------------------------------------------------------------------------