├── .Rbuildignore
├── .github
│   ├── .gitignore
│   ├── CODE_OF_CONDUCT.md
│   └── workflows
│       ├── R-CMD-check-hard.yaml
│       ├── R-CMD-check.yaml
│       ├── lock.yaml
│       ├── pkgdown.yaml
│       ├── pr-commands.yaml
│       └── test-coverage.yaml
├── .gitignore
├── DESCRIPTION
├── LICENSE
├── LICENSE.md
├── NAMESPACE
├── NEWS.md
├── R
│   ├── bound_prediction.R
│   ├── cal-apply-binary.R
│   ├── cal-apply-impl.R
│   ├── cal-apply-multi.R
│   ├── cal-apply-regression.R
│   ├── cal-apply.R
│   ├── cal-estimate-beta.R
│   ├── cal-estimate-isotonic.R
│   ├── cal-estimate-linear.R
│   ├── cal-estimate-logistic.R
│   ├── cal-estimate-multinom.R
│   ├── cal-estimate-none.R
│   ├── cal-estimate-utils.R
│   ├── cal-pkg-check.R
│   ├── cal-plot-breaks.R
│   ├── cal-plot-logistic.R
│   ├── cal-plot-regression.R
│   ├── cal-plot-utils.R
│   ├── cal-plot-windowed.R
│   ├── cal-utils.R
│   ├── cal-validate.R
│   ├── class-pred.R
│   ├── conformal_infer.R
│   ├── conformal_infer_cv.R
│   ├── conformal_infer_quantile.R
│   ├── conformal_infer_split.R
│   ├── data.R
│   ├── import-standalone-obj-type.R
│   ├── import-standalone-types-check.R
│   ├── make_class_pred.R
│   ├── printing.R
│   ├── probably-package.R
│   ├── reexports.R
│   ├── threshold_perf.R
│   ├── utils.R
│   ├── vctrs-compat.R
│   └── zzz.R
├── README.Rmd
├── README.md
├── _pkgdown.yml
├── codecov.yml
├── data
│   ├── boosting_predictions.RData
│   ├── datalist
│   ├── segment_logistic.RData
│   ├── segment_naive_bayes.RData
│   └── species_probs.RData
├── man-roxygen
│   ├── metrics_both.R
│   ├── metrics_cls.R
│   ├── metrics_reg.R
│   └── multiclass.R
├── man
│   ├── append_class_pred.Rd
│   ├── as_class_pred.Rd
│   ├── boosting_predictions.Rd
│   ├── bound_prediction.Rd
│   ├── cal_apply.Rd
│   ├── cal_binary_tables.Rd
│   ├── cal_estimate_beta.Rd
│   ├── cal_estimate_isotonic.Rd
│   ├── cal_estimate_isotonic_boot.Rd
│   ├── cal_estimate_linear.Rd
│   ├── cal_estimate_logistic.Rd
│   ├── cal_estimate_multinomial.Rd
│   ├── cal_estimate_none.Rd
│   ├── cal_plot_breaks.Rd
│   ├── cal_plot_logistic.Rd
│   ├── cal_plot_regression.Rd
│   ├── cal_plot_windowed.Rd
│   ├── cal_validate_beta.Rd
│   ├── cal_validate_isotonic.Rd
│   ├── cal_validate_isotonic_boot.Rd
│   ├── cal_validate_linear.Rd
│   ├── cal_validate_logistic.Rd
│   ├── cal_validate_multinomial.Rd
│   ├── cal_validate_none.Rd
│   ├── class_pred.Rd
│   ├── collect_metrics.cal_rset.Rd
│   ├── collect_predictions.cal_rset.Rd
│   ├── control_conformal_full.Rd
│   ├── figures
│   │   └── logo.png
│   ├── int_conformal_cv.Rd
│   ├── int_conformal_full.Rd
│   ├── int_conformal_quantile.Rd
│   ├── int_conformal_split.Rd
│   ├── is_class_pred.Rd
│   ├── levels.class_pred.Rd
│   ├── locate-equivocal.Rd
│   ├── make_class_pred.Rd
│   ├── predict.int_conformal_full.Rd
│   ├── probably-package.Rd
│   ├── reexports.Rd
│   ├── reportable_rate.Rd
│   ├── required_pkgs.cal_object.Rd
│   ├── rmd
│   │   └── parallel_intervals.Rmd
│   ├── segment_naive_bayes.Rd
│   ├── species_probs.Rd
│   └── threshold_perf.Rd
├── probably.Rproj
├── revdep
│   ├── .gitignore
│   ├── README.md
│   ├── cran.md
│   ├── email.yml
│   ├── failures.md
│   └── problems.md
├── tests
│   ├── testthat.R
│   └── testthat
│       ├── _snaps
│       │   ├── bound-prediction.md
│       │   ├── cal-estimate-beta.md
│       │   ├── cal-estimate-isotonic.md
│       │   ├── cal-estimate-linear.md
│       │   ├── cal-estimate-logistic.md
│       │   ├── cal-estimate-multinomial.md
│       │   ├── cal-estimate-none.md
│       │   ├── cal-estimate.md
│       │   ├── cal-plot.md
│       │   ├── cal-validate.md
│       │   ├── class-pred.md
│       │   ├── conformal-intervals-quantile.md
│       │   ├── conformal-intervals-split.md
│       │   ├── conformal-intervals.md
│       │   ├── make-class-pred.md
│       │   └── threshold-perf.md
│       ├── cal_files
│       │   ├── binary_sim.rds
│       │   ├── fit_rs.rds
│       │   ├── multiclass_ames.rds
│       │   ├── reg_sim.rds
│       │   └── sim_multi.rds
│       ├── helper-cal.R
│       ├── test-bound-prediction.R
│       ├── test-cal-apply.R
│       ├── test-cal-estimate-beta.R
│       ├── test-cal-estimate-isotonic.R
│       ├── test-cal-estimate-linear.R
│       ├── test-cal-estimate-logistic.R
│       ├── test-cal-estimate-multinomial.R
│       ├── test-cal-estimate-none.R
│       ├── test-cal-estimate.R
│       ├── test-cal-pkg-check.R
│       ├── test-cal-plot.R
│       ├── test-cal-validate-multiclass.R
│       ├── test-cal-validate.R
│       ├── test-class-pred.R
│       ├── test-conformal-intervals-quantile.R
│       ├── test-conformal-intervals-split.R
│       ├── test-conformal-intervals.R
│       ├── test-make-class-pred.R
│       ├── test-threshold-perf.R
│       └── test-vctrs-compat.R
└── vignettes
    ├── .gitignore
    ├── equivocal-zones.Rmd
    └── where-to-use.Rmd
/.Rbuildignore:
--------------------------------------------------------------------------------
 1 | ^CRAN-RELEASE$
 2 | ^docs$
 3 | ^_pkgdown\.yml$
 4 | ^codecov\.yml$
 5 | ^README\.Rmd$
 6 | ^\.travis\.yml$
 7 | ^probably\.Rproj$
 8 | ^\.Rproj\.user$
 9 | ^revdep$
10 | ^\.github$
11 | ^LICENSE\.md$
12 | ^pkgdown$
13 | ^CRAN-SUBMISSION$
14 | ^man-roxygen$
15 | 
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
 1 | *.html
 2 | 
--------------------------------------------------------------------------------
/.github/workflows/R-CMD-check-hard.yaml:
--------------------------------------------------------------------------------
 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
 3 | #
 4 | # NOTE: This workflow only directly installs "hard" dependencies, i.e. Depends,
 5 | # Imports, and LinkingTo dependencies. Notably, Suggests dependencies are never
 6 | # installed, with the exception of testthat, knitr, and rmarkdown. The cache is
 7 | # never used to avoid accidentally restoring a cache containing a suggested
 8 | # dependency.
 9 | on:
10 |   push:
11 |     branches: [main, master]
12 |   pull_request:
13 | 
14 | name: R-CMD-check-hard.yaml
15 | 
16 | permissions: read-all
17 | 
18 | jobs:
19 |   check-no-suggests:
20 |     runs-on: ${{ matrix.config.os }}
21 | 
22 |     name: ${{ matrix.config.os }} (${{ matrix.config.r }})
23 | 
24 |     strategy:
25 |       fail-fast: false
26 |       matrix:
27 |         config:
28 |           - {os: ubuntu-latest, r: 'release'}
29 | 
30 |     env:
31 |       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
32 |       R_KEEP_PKG_SOURCE: yes
33 | 
34 |     steps:
35 |       - uses: actions/checkout@v4
36 | 
37 |       - uses: r-lib/actions/setup-pandoc@v2
38 | 
39 |       - uses: r-lib/actions/setup-r@v2
40 |         with:
41 |           r-version: ${{ matrix.config.r }}
42 |           http-user-agent: ${{ matrix.config.http-user-agent }}
43 |           use-public-rspm: true
44 | 
45 |       - uses: r-lib/actions/setup-r-dependencies@v2
46 |         with:
47 |           dependencies: '"hard"'
48 |           cache: false
49 |           extra-packages: |
50 |             any::rcmdcheck
51 |             any::testthat
52 |             any::knitr
53 |             any::rmarkdown
54 |           needs: check
55 | 
56 |       - uses: r-lib/actions/check-r-package@v2
57 |         with:
58 |           upload-snapshots: true
59 |           build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")'
--------------------------------------------------------------------------------
/.github/workflows/R-CMD-check.yaml:
--------------------------------------------------------------------------------
 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
 3 | #
 4 | # NOTE: This workflow is overkill for most R packages and
 5 | #       check-standard.yaml is likely a better choice.
 6 | #       usethis::use_github_action("check-standard") will install it.
7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | branches: [main, master] 12 | 13 | name: R-CMD-check.yaml 14 | 15 | permissions: read-all 16 | 17 | jobs: 18 | R-CMD-check: 19 | runs-on: ${{ matrix.config.os }} 20 | 21 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 22 | 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | config: 27 | - {os: macos-latest, r: 'release'} 28 | 29 | - {os: windows-latest, r: 'release'} 30 | # use 4.0 or 4.1 to check with rtools40's older compiler 31 | - {os: windows-latest, r: 'oldrel-4'} 32 | 33 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 34 | - {os: ubuntu-latest, r: 'release'} 35 | - {os: ubuntu-latest, r: 'oldrel-1'} 36 | - {os: ubuntu-latest, r: 'oldrel-2'} 37 | - {os: ubuntu-latest, r: 'oldrel-3'} 38 | #- {os: ubuntu-latest, r: 'oldrel-4'} 39 | 40 | env: 41 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 42 | R_KEEP_PKG_SOURCE: yes 43 | 44 | steps: 45 | - uses: actions/checkout@v4 46 | 47 | - uses: r-lib/actions/setup-pandoc@v2 48 | 49 | - uses: r-lib/actions/setup-r@v2 50 | with: 51 | r-version: ${{ matrix.config.r }} 52 | http-user-agent: ${{ matrix.config.http-user-agent }} 53 | use-public-rspm: true 54 | 55 | - uses: r-lib/actions/setup-r-dependencies@v2 56 | with: 57 | extra-packages: any::rcmdcheck 58 | needs: check 59 | 60 | - uses: r-lib/actions/check-r-package@v2 61 | with: 62 | upload-snapshots: true 63 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 64 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: dessant/lock-threads@v2 12 | with: 13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | ) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | ) and link to this issue. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | pkgdown: 18 | runs-on: ubuntu-latest 19 | # Only restrict concurrency for non-PR jobs 20 | concurrency: 21 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 22 | env: 23 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 24 | permissions: 25 | contents: write 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - uses: r-lib/actions/setup-pandoc@v2 30 | 31 | - uses: r-lib/actions/setup-r@v2 32 | with: 33 | use-public-rspm: true 34 | 35 | - uses: r-lib/actions/setup-r-dependencies@v2 36 | with: 37 | extra-packages: any::pkgdown, local::. 38 | needs: website 39 | 40 | - name: Build site 41 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 42 | shell: Rscript {0} 43 | 44 | - name: Deploy to GitHub pages 🚀 45 | if: github.event_name != 'pull_request' 46 | uses: JamesIves/github-pages-deploy-action@v4.5.0 47 | with: 48 | clean: false 49 | branch: gh-pages 50 | folder: docs 51 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | run: 
styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | 9 | name: test-coverage.yaml 10 | 11 | permissions: read-all 12 | 13 | jobs: 14 | test-coverage: 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - uses: r-lib/actions/setup-r@v2 23 | with: 24 | use-public-rspm: true 25 | 26 | - uses: r-lib/actions/setup-r-dependencies@v2 27 | with: 28 | extra-packages: any::covr, any::xml2 29 | needs: coverage 30 | 31 | - name: Test coverage 32 | run: | 33 | cov <- covr::package_coverage( 34 | quiet = FALSE, 35 | clean = FALSE, 36 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 37 | ) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v4 42 | with: 43 | fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }} 44 | file: ./cobertura.xml 45 | plugin: noop 46 | disable_search: true 47 | token: ${{ secrets.CODECOV_TOKEN }} 48 | 49 | - name: Show testthat output 50 | if: always() 51 | run: | 52 | ## -------------------------------------------------------------------- 53 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 54 | shell: bash 55 | 56 | - name: Upload test results 57 | if: failure() 58 | uses: actions/upload-artifact@v4 59 | with: 60 | name: coverage-test-failures 61 | path: ${{ runner.temp }}/package 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | inst/doc 2 | .Rproj.user 3 | .Rhistory 4 | .RData 5 | .DS_Store 6 | docs 7 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: probably 2 | Title: Tools for Post-Processing Predicted Values 3 | Version: 1.1.0.9000 4 | Authors@R: c( 5 | person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), 6 | person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), 7 | person("Edgar", "Ruiz", , "edgar@posit.co", role = "aut"), 8 | person("Posit Software, PBC", role = c("cph", "fnd"), 9 | comment = c(ROR = "03wc8by49")) 10 | ) 11 | Description: Models can be improved by post-processing class 12 | probabilities, by: recalibration, conversion to hard probabilities, 13 | assessment of equivocal zones, and other activities. 'probably' 14 | contains tools for conducting these operations as well as calibration 15 | tools and conformal inference techniques for regression models. 
16 | License: MIT + file LICENSE 17 | URL: https://github.com/tidymodels/probably, 18 | https://probably.tidymodels.org 19 | BugReports: https://github.com/tidymodels/probably/issues 20 | Depends: 21 | R (>= 4.1) 22 | Imports: 23 | butcher, 24 | cli, 25 | dplyr (>= 1.1.0), 26 | furrr, 27 | generics (>= 0.1.3), 28 | ggplot2, 29 | hardhat, 30 | pillar, 31 | purrr, 32 | rlang (>= 1.1.0), 33 | tidyr (>= 1.3.0), 34 | tidyselect (>= 1.1.2), 35 | tune (>= 1.1.2), 36 | vctrs (>= 0.4.1), 37 | withr, 38 | workflows (>= 1.1.4), 39 | yardstick (>= 1.3.0) 40 | Suggests: 41 | betacal, 42 | covr, 43 | knitr, 44 | MASS, 45 | mgcv, 46 | modeldata (>= 1.1.0), 47 | nnet, 48 | parsnip (>= 1.2.0), 49 | quantregForest, 50 | randomForest, 51 | recipes, 52 | rmarkdown, 53 | rsample, 54 | testthat (>= 3.0.0) 55 | VignetteBuilder: 56 | knitr 57 | ByteCompile: true 58 | Config/Needs/website: tidyverse/tidytemplate 59 | Config/testthat/edition: 3 60 | Encoding: UTF-8 61 | LazyData: true 62 | Roxygen: list(markdown = TRUE) 63 | RoxygenNote: 7.3.2 64 | Collate: 65 | 'bound_prediction.R' 66 | 'cal-apply-binary.R' 67 | 'cal-apply-impl.R' 68 | 'cal-apply-multi.R' 69 | 'cal-apply-regression.R' 70 | 'cal-apply.R' 71 | 'cal-estimate-beta.R' 72 | 'cal-estimate-isotonic.R' 73 | 'cal-estimate-linear.R' 74 | 'cal-estimate-logistic.R' 75 | 'cal-estimate-multinom.R' 76 | 'cal-estimate-utils.R' 77 | 'cal-estimate-none.R' 78 | 'cal-pkg-check.R' 79 | 'cal-plot-breaks.R' 80 | 'cal-plot-logistic.R' 81 | 'cal-plot-regression.R' 82 | 'cal-plot-utils.R' 83 | 'cal-plot-windowed.R' 84 | 'cal-utils.R' 85 | 'cal-validate.R' 86 | 'class-pred.R' 87 | 'conformal_infer.R' 88 | 'conformal_infer_cv.R' 89 | 'conformal_infer_quantile.R' 90 | 'conformal_infer_split.R' 91 | 'data.R' 92 | 'import-standalone-obj-type.R' 93 | 'import-standalone-types-check.R' 94 | 'make_class_pred.R' 95 | 'printing.R' 96 | 'probably-package.R' 97 | 'reexports.R' 98 | 'threshold_perf.R' 99 | 'utils.R' 100 | 'vctrs-compat.R' 101 | 'zzz.R' 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2023 2 | COPYRIGHT HOLDER: probably authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2023 probably authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/R/bound_prediction.R:
--------------------------------------------------------------------------------
 1 | #' Truncate a numeric prediction column
 2 | #'
 3 | #' For user-defined `lower_limit` and/or `upper_limit` bounds, ensure that the
 4 | #' values in the `.pred` column are coerced to these bounds.
 5 | #'
 6 | #' @param x A data frame that contains a numeric column named `.pred`.
 7 | #' @param lower_limit,upper_limit Single numerics (or `NA`) that define
 8 | #'   constraints on `.pred`.
 9 | #' @param call The call to be displayed in warnings or errors.
10 | #' @return `x` with potentially adjusted values.
11 | #' @examples
12 | #' data(solubility_test, package = "yardstick")
13 | #'
14 | #' names(solubility_test) <- c("solubility", ".pred")
15 | #'
16 | #' bound_prediction(solubility_test, lower_limit = -1)
17 | #' @export
18 | bound_prediction <- function(x, lower_limit = -Inf, upper_limit = Inf,
19 |                              call = rlang::current_env()) {
20 |   check_data_frame(x, call = call)
21 | 
22 |   if (!any(names(x) == ".pred")) {
23 |     cli::cli_abort("The argument {.arg x} should have a column named {.code .pred}.",
24 |       call = call)
25 |   }
26 |   if (!is.numeric(x$.pred)) {
27 |     cli::cli_abort("Column {.code .pred} should be numeric.", call = call)
28 |   }
29 | 
30 |   check_number_decimal(lower_limit, allow_na = TRUE, call = call)
31 |   check_number_decimal(upper_limit, allow_na = TRUE, call = call)
32 | 
33 |   if (!is.na(lower_limit)) {
34 |     x$.pred <- ifelse(x$.pred < lower_limit, lower_limit, x$.pred)
35 |   }
36 | 
37 |   if (!is.na(upper_limit)) {
38 |     x$.pred <- ifelse(x$.pred > upper_limit, upper_limit, x$.pred)
39 |   }
40 |   x
41 | }
42 | 
--------------------------------------------------------------------------------
/R/cal-apply-binary.R:
--------------------------------------------------------------------------------
 1 | # ------------------------------- Methods --------------------------------------
 2 | 
 3 | cal_apply_binary <- function(object, .data, pred_class) {
 4 |   UseMethod("cal_apply_binary")
 5 | }
 6 | 
 7 | #' @export
 8 | cal_apply_binary.cal_estimate_logistic <- function(object,
 9 |                                                    .data,
10 |                                                    pred_class = NULL,
11 |                                                    ...) {
12 |   apply_model_predict(
13 |     object = object,
14 |     .data = .data
15 |   )
16 | }
17 | 
18 | #' @export
19 | cal_apply_binary.cal_estimate_logistic_spline <- function(object,
20 |                                                           .data,
21 |                                                           pred_class = NULL,
22 |                                                           ...)
{ 23 | apply_model_predict( 24 | object = object, 25 | .data = .data 26 | ) 27 | } 28 | 29 | #---------------------------- Adjust implementations --------------------------- 30 | 31 | binary_repredict <- function(x, predict_data, object) { 32 | if (is.null(x$filter)) { 33 | new_data <- predict_data 34 | } else { 35 | new_data <- dplyr::filter(predict_data, !!x$filter) 36 | } 37 | preds <- predict(x$estimate, newdata = new_data, type = "response") 38 | preds <- 1 - preds 39 | lvls <- nm_levels(object$levels) 40 | new_data[lvls[1]] <- preds 41 | new_data[lvls[2]] <- 1 - preds 42 | new_data 43 | } 44 | 45 | apply_model_predict <- function(object, .data) { 46 | if (object$type == "binary") { 47 | .data <- 48 | purrr::map( 49 | object$estimates, 50 | binary_repredict, 51 | predict_data = .data, 52 | object = object 53 | ) |> 54 | purrr::reduce(dplyr::bind_rows) 55 | } 56 | .data 57 | } 58 | -------------------------------------------------------------------------------- /R/cal-apply-impl.R: -------------------------------------------------------------------------------- 1 | #---------------------------------- >> Interval -------------------------------- 2 | apply_interval_impl <- function(object, .data, multi = FALSE, method = "auto") { 3 | # Iterates through each group 4 | new_data <- object$estimates |> 5 | purrr::map( 6 | ~ apply_interval_column( 7 | .data = .data, 8 | est_filter = .x$filter, 9 | estimates = .x$estimates 10 | ) 11 | ) |> 12 | purrr::reduce(dplyr::bind_rows) 13 | 14 | apply_adjustment(new_data, object) 15 | } 16 | 17 | # Iterates through each prediction column 18 | apply_interval_column <- function(.data, est_filter, estimates) { 19 | if (is.null(est_filter)) { 20 | df <- .data 21 | } else { 22 | df <- dplyr::filter(.data, !!est_filter) 23 | } 24 | 25 | ret <- estimates |> 26 | purrr::list_transpose(simplify = FALSE) |> 27 | purrr::imap( 28 | ~ apply_interval_estimate(estimate = .x, df = df, est_name = .y) 29 | ) 30 | 31 | names_ret <- names(ret) 32 | for (i in seq_along(names_ret)) { 33 | df[, names_ret[i]] <- ret[[names_ret[i]]] 34 | } 35 | df 36 | } 37 | 38 | # Iterates through each model run 39 | apply_interval_estimate <- function(estimate, df, est_name) { 40 | # Handles single quoted variable names, which are typically created 41 | # when there are spaces in the original variable name 42 | df_names <- names(df) 43 | if (!(est_name %in% df_names)) { 44 | test_name <- sub("`", "", est_name) 45 | test_name <- sub("`", "", test_name) 46 | if (test_name %in% df_names) { 47 | est_name <- test_name 48 | } else { 49 | cli::cli_abort("Variable {.var {est_name}} was not found in data.") 50 | } 51 | } 52 | 53 | ret <- estimate |> 54 | purrr::map( 55 | apply_interval_single, 56 | df = df, 57 | est_name = est_name 58 | ) 59 | 60 | if (length(estimate) > 1) { 61 | ret <- ret |> 62 | data.frame() |> 63 | rowMeans() 64 | } else { 65 | ret <- ret[[1]] 66 | } 67 | 68 | ret 69 | } 70 | 71 | apply_interval_single <- function(estimates_table, df, est_name) { 72 | y <- estimates_table$.adj_estimate 73 | find_interval <- findInterval( 74 | x = df[[est_name]], 75 | vec = estimates_table$.estimate 76 | ) 77 | find_interval[find_interval == 0] <- 1 78 | ret <- y[find_interval] 79 | ret 80 | } 81 | 82 | 83 | #---------------------------- >> Beta Predict ---------------------------------- 84 | 85 | apply_beta_impl <- function(object, .data) { 86 | # Iterates through each group 87 | new_data <- 88 | purrr::map( 89 | object$estimates, 90 | ~ apply_beta_column( 91 | .data = .data, 92 | est_filter = 
.x$filter, 93 | estimates = .x$estimate 94 | ) 95 | ) |> 96 | purrr::reduce(dplyr::bind_rows) 97 | 98 | apply_adjustment(new_data, object) 99 | } 100 | 101 | # Iterates through each prediction column 102 | apply_beta_column <- function(.data, est_filter, estimates) { 103 | if (is.null(est_filter)) { 104 | df <- .data 105 | } else { 106 | df <- dplyr::filter(.data, !!est_filter) 107 | } 108 | 109 | ret <- 110 | purrr::imap(estimates, ~ apply_beta_single(model = .x, df = df, est_name = .y)) 111 | 112 | names_ret <- names(ret) 113 | for (i in seq_along(names_ret)) { 114 | df[, names_ret[i]] <- ret[[names_ret[i]]] 115 | } 116 | df 117 | } 118 | 119 | apply_beta_single <- function(model, df, est_name) { 120 | p <- df[[est_name]] 121 | betacal::beta_predict( 122 | p = p, 123 | calib = model 124 | ) 125 | } 126 | 127 | # ------------------------------ Adjustment ----------------------------------- 128 | 129 | apply_adjustment <- function(new_data, object) { 130 | if (object$type == "binary") { 131 | lvls <- nm_levels(object$levels) 132 | new_data[, lvls[[2]]] <- 1 - new_data[, lvls[[1]]] 133 | } 134 | 135 | if (object$type == "one_vs_all") { 136 | ols <- purrr::map_chr(object$levels, rlang::as_name) 137 | rs <- rowSums(new_data[, ols]) 138 | for (i in seq_along(ols)) { 139 | new_data[, ols[i]] <- new_data[, ols[i]] / rs 140 | } 141 | } 142 | 143 | new_data 144 | } 145 | -------------------------------------------------------------------------------- /R/cal-apply-multi.R: -------------------------------------------------------------------------------- 1 | # ------------------------------- Methods -------------------------------------- 2 | 3 | cal_apply_multi <- function(object, .data, pred_class) { 4 | UseMethod("cal_apply_multi") 5 | } 6 | 7 | #' @export 8 | cal_apply_multi.cal_estimate_multinomial <- 9 | function(object, .data, pred_class = NULL, ...) { 10 | apply_multi_predict( 11 | object = object, 12 | .data = .data 13 | ) 14 | } 15 | 16 | #---------------------------- Adjust implementations --------------------------- 17 | 18 | #---------------------------- >> Single Predict -------------------------------- 19 | 20 | apply_multi_predict <- function(object, .data) { 21 | if (inherits(object$estimates[[1]]$estimate, "gam")) { 22 | prob_type <- "response" 23 | } else { 24 | prob_type <- "probs" 25 | } 26 | preds <- object$estimates[[1]]$estimate |> 27 | predict(newdata = .data, type = prob_type) 28 | 29 | lvls <- nm_levels(object$levels) 30 | colnames(preds) <- lvls 31 | preds <- dplyr::as_tibble(preds) 32 | 33 | for (i in seq_along(lvls)) { 34 | .data[, lvls[i]] <- preds[, lvls[i]] 35 | } 36 | .data 37 | } 38 | -------------------------------------------------------------------------------- /R/cal-apply-regression.R: -------------------------------------------------------------------------------- 1 | # ------------------------------- Methods -------------------------------------- 2 | 3 | cal_apply_regression <- function(object, .data, pred_class) { 4 | UseMethod("cal_apply_regression") 5 | } 6 | 7 | #' @export 8 | cal_apply_regression.cal_estimate_linear_spline <- 9 | function(object, .data, pred_class = NULL, ...) 
{
10 |     apply_reg_predict(
11 |       object = object,
12 |       .data = .data
13 |     )
14 |   }
15 | 
16 | #' @export
17 | cal_apply_regression.cal_estimate_linear <-
18 |   cal_apply_regression.cal_estimate_linear_spline
19 | 
20 | #---------------------------- Adjust implementations ---------------------------
21 | 
22 | numeric_repredict <- function(x, predict_data, prd_nm) {
23 |   if (is.null(x$filter)) {
24 |     new_data <- predict_data
25 |   } else {
26 |     new_data <- dplyr::filter(predict_data, !!x$filter)
27 |   }
28 |   preds <- predict(x$estimate, newdata = new_data, type = "response")
29 |   new_data[prd_nm] <- preds
30 |   new_data
31 | }
32 | 
33 | apply_reg_predict <- function(object, .data) {
34 |   .data <-
35 |     purrr::map(
36 |       object$estimates,
37 |       numeric_repredict,
38 |       predict_data = .data,
39 |       prd_nm = rlang::expr_deparse(object$levels$predictions)
40 |     ) |>
41 |     purrr::reduce(dplyr::bind_rows)
42 |   .data
43 | }
44 | 
--------------------------------------------------------------------------------
/R/cal-pkg-check.R:
--------------------------------------------------------------------------------
 1 | cal_pkg_check <- function(pkgs = NULL) {
 2 |   installed <- purrr::map_lgl(pkgs, rlang::is_installed)
 3 | 
 4 |   not_installed <- pkgs[!installed]
 5 | 
 6 |   if (length(not_installed)) {
 7 |     n_pkgs <- length(not_installed)
 8 | 
 9 |     pkg_str <- paste0('"', not_installed, '"', collapse = ", ")
10 |     install_cmd <- paste0("install.packages(c(", pkg_str, "))")
11 | 
12 |     cli::cli_abort(
13 |       c(
14 |         "{n_pkgs} package{?s} ({.pkg {not_installed}}) {?is/are} needed for
15 |         this calibration but {?is/are} not installed.",
16 |         "i" = "To install run: {.run {install_cmd}}"
17 |       )
18 |     )
19 |   }
20 |   invisible()
21 | }
22 | 
23 | #' S3 methods to track which additional packages are needed for specific
24 | #' calibrations
25 | #' @param x A calibration object
26 | #' @inheritParams generics::required_pkgs
27 | #' @export
28 | required_pkgs.cal_object <- function(x, ...) {
29 |   c("probably")
30 | }
31 | 
--------------------------------------------------------------------------------
/R/cal-utils.R:
--------------------------------------------------------------------------------
 1 | # Centralizes figuring out which probability variable maps to which factor
 2 | # level of the "truth" variable. This is where the logic of finding and
 3 | # mapping tidymodels' explicit column names happens. If there are no .pred_
 4 | # named variables, it will map the variables based on their position.
 5 | # It returns a named list, with the variable names as syms and the assigned
 6 | # levels as the names.
 7 | truth_estimate_map <- function(.data, truth, estimate, validate = FALSE) {
 8 |   truth_str <- tidyselect_cols(.data, {{ truth }})
 9 | 
10 |   if (is.integer(truth_str)) {
11 |     truth_str <- names(truth_str)
12 |   }
13 | 
14 |   # Get the name(s) of the column(s) that have the predicted values. For binary
15 |   # data, this is a single column name.
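  # (For multiclass outcomes there is one probability column per factor
  # level; the mapping below pairs each level with its column, using the
  # ".pred_<level>" naming convention when present and position otherwise.)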
16 |   estimate_str <- .data |>
17 |     tidyselect_cols({{ estimate }}) |>
18 |     names()
19 | 
20 |   if (length(estimate_str) == 0) {
21 |     cli::cli_abort("{.arg estimate} must select at least one column.")
22 |   }
23 | 
24 |   truth_levels <- levels(.data[[truth_str]])
25 | 
26 |   # `est_map` maps the levels of the outcome to the corresponding column(s) in
27 |   # the data
28 |   if (length(truth_levels) > 0) {
29 |     if (all(substr(estimate_str, 1, 6) == ".pred_")) {
30 |       est_map <- purrr::map(
31 |         truth_levels,
32 |         ~ {
33 |           match <- paste0(".pred_", .x) == estimate_str
34 |           if (any(match)) {
35 |             sym(estimate_str[match])
36 |           }
37 |         }
38 |       )
39 |     } else {
40 |       if (length(estimate_str) == 1) {
41 |         est_map <- list(sym(estimate_str), NULL)
42 |       } else {
43 |         est_map <- purrr::map(seq_along(truth_levels), ~ sym(estimate_str[[.x]]))
44 |       }
45 |     }
46 |     if (validate) {
47 |       check_level_consistency(truth_levels, est_map)
48 |     }
49 |     res <- set_names(est_map, truth_levels)
50 |   } else {
51 |     # regression case
52 |     res <- list(sym(estimate_str))
53 |     names(res) <- "predictions"
54 |   }
55 |   purrr::discard(res, is.null)
56 | }
57 | 
58 | check_level_consistency <- function(lvls, mapping) {
59 |   null_map <- purrr::map_lgl(mapping, is_null)
60 |   if (any(null_map) | length(lvls) != length(mapping)) {
61 |     missings <- lvls[null_map]
62 |     missings <- paste0(missings, collapse = ", ")
63 |     cols <- mapping[!null_map]
64 |     cols <- purrr::map_chr(cols, as.character)
65 |     cols <- paste0(cols, collapse = ", ")
66 |     cli::cli_abort(
67 |       c(
68 |         "We can't connect the specified prediction columns to {.val {missings}}.",
69 |         "i" = "The selected columns were {.val {cols}}.",
70 |         "i" = "Are there more columns to add in the function call?"
71 |       )
72 |     )
73 |   }
74 |   invisible(NULL)
75 | }
76 | 
77 | nm_levels <- function(x) {
78 |   purrr::map_chr(x, rlang::as_name)
79 | }
80 | 
--------------------------------------------------------------------------------
/R/conformal_infer_split.R:
--------------------------------------------------------------------------------
 1 | #' Prediction intervals via split conformal inference
 2 | #'
 3 | #' Nonparametric prediction intervals can be computed for fitted regression
 4 | #' workflow objects using the split conformal inference method described by
 5 | #' Lei _et al_ (2018).
 6 | #' @param object A fitted [workflows::workflow()] object.
 7 | #' @param ... Not currently used.
 8 | #' @param cal_data A data frame with the _original predictor and outcome data_
 9 | #'   used to produce predictions (and residuals). If the workflow used a recipe,
10 | #'   this should be the data that were inputs to the recipe (and not the product
11 | #'   of a recipe).
12 | #'
13 | #' @return An object of class `"int_conformal_split"` containing the
14 | #'   information to create intervals (which includes `object`).
15 | #'   The `predict()` method is used to produce the intervals.
16 | #' @details
17 | #' This function implements what is usually called "split conformal inference"
18 | #' (see Algorithm 1 in Lei _et al_ (2018)).
19 | #'
20 | #' This function prepares the statistics for the interval computations. The
21 | #' [predict()] method computes the intervals for new data and the significance
22 | #' level is specified there.
23 | #'
24 | #' `cal_data` should be large enough to get a good estimate of an extreme
25 | #' quantile (e.g., the 95th for a 95% interval) and should not include rows that
26 | #' were in the original training set.
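#'
#' As a rough sketch of the underlying computation (not part of the exported
#' API; `cal_resid` is a placeholder for the calibration-set residuals), the
#' interval half-width is a finite-sample quantile of the absolute residuals,
#' added to and subtracted from each new prediction:
#'
#' ```r
#' abs_resid <- sort(abs(cal_resid))
#' q_ind <- ceiling(level * (length(abs_resid) + 1))
#' half_width <- abs_resid[q_ind]
#' # interval: .pred - half_width, .pred + half_width
#' ```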
27 | #' @seealso [predict.int_conformal_split()] 28 | #' @references 29 | #' Lei, Jing, et al. "Distribution-free predictive inference for regression." 30 | #' _Journal of the American Statistical Association_ 113.523 (2018): 1094-1111. 31 | #' @examplesIf !probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "nnet")) 32 | #' library(workflows) 33 | #' library(dplyr) 34 | #' library(parsnip) 35 | #' library(rsample) 36 | #' library(tune) 37 | #' library(modeldata) 38 | #' 39 | #' set.seed(2) 40 | #' sim_train <- sim_regression(500) 41 | #' sim_cal <- sim_regression(200) 42 | #' sim_new <- sim_regression(5) |> select(-outcome) 43 | #' 44 | #' # We'll use a neural network model 45 | #' mlp_spec <- 46 | #' mlp(hidden_units = 5, penalty = 0.01) |> 47 | #' set_mode("regression") 48 | #' 49 | #' mlp_wflow <- 50 | #' workflow() |> 51 | #' add_model(mlp_spec) |> 52 | #' add_formula(outcome ~ .) 53 | #' 54 | #' mlp_fit <- fit(mlp_wflow, data = sim_train) 55 | #' 56 | #' mlp_int <- int_conformal_split(mlp_fit, sim_cal) 57 | #' mlp_int 58 | #' 59 | #' predict(mlp_int, sim_new, level = 0.90) 60 | #' @export 61 | int_conformal_split <- function(object, ...) { 62 | UseMethod("int_conformal_split") 63 | } 64 | 65 | #' @export 66 | #' @rdname int_conformal_split 67 | int_conformal_split.default <- function(object, ...) { 68 | cli::cli_abort("No known {.fn int_conformal_split} methods for this type of object.") 69 | } 70 | 71 | #' @export 72 | #' @rdname int_conformal_split 73 | int_conformal_split.workflow <- function(object, cal_data, ...) { 74 | rlang::check_dots_empty() 75 | check_data_all(cal_data, object) 76 | 77 | y_name <- names(hardhat::extract_mold(object)$outcomes) 78 | cal_pred <- generics::augment(object, cal_data) 79 | cal_pred$.resid <- cal_pred[[y_name]] - cal_pred$.pred 80 | res <- list(resid = sort(abs(cal_pred$.resid)), wflow = object, n = nrow(cal_pred)) 81 | class(res) <- c("conformal_reg_split", "int_conformal_split") 82 | res 83 | } 84 | 85 | #' @export 86 | print.int_conformal_split <- function(x, ...) { 87 | cat("Split Conformal inference\n") 88 | 89 | cat("preprocessor:", .get_pre_type(x$wflow), "\n") 90 | cat("model:", .get_fit_type(x$wflow), "\n") 91 | cat("calibration set size:", format(x$n, big.mark = ","), "\n\n") 92 | 93 | cat("Use `predict(object, new_data, level)` to compute prediction intervals\n") 94 | invisible(x) 95 | } 96 | 97 | #' @export 98 | #' @rdname predict.int_conformal_full 99 | predict.int_conformal_split <- function(object, new_data, level = 0.95, ...) 
{ 100 | check_data(new_data, object$wflow) 101 | rlang::check_dots_empty() 102 | 103 | new_pred <- predict(object$wflow, new_data) 104 | 105 | alpha <- 1 - level 106 | q_ind <- ceiling(level * (object$n + 1)) 107 | q_val <- object$resid[q_ind] 108 | 109 | new_pred$.pred_lower <- new_pred$.pred - q_val 110 | new_pred$.pred_upper <- new_pred$.pred + q_val 111 | new_pred 112 | } 113 | 114 | check_data_all <- function(.data, wflow) { 115 | mold <- hardhat::extract_mold(wflow) 116 | ptypes <- mold$blueprint$ptypes 117 | ptypes <- dplyr::bind_cols(ptypes$predictors, ptypes$outcomes) 118 | hardhat::shrink(.data, ptypes) 119 | invisible(NULL) 120 | } 121 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' Predictions on animal species 2 | #' 3 | #' @details These data are holdout predictions from resampling for the animal 4 | #' scat data of Reid (2015) based on a C5.0 classification model. 5 | #' 6 | #' @name species_probs 7 | #' @aliases species_probs 8 | #' @docType data 9 | #' @return \item{species_probs}{a tibble} 10 | #' 11 | #' @source Reid, R. E. B. (2015). A morphometric modeling approach to 12 | #' distinguishing among bobcat, coyote and gray fox scats. \emph{Wildlife 13 | #' Biology}, 21(5), 254-262 14 | #' 15 | #' @keywords datasets 16 | #' @examples 17 | #' data(species_probs) 18 | #' str(species_probs) 19 | NULL 20 | 21 | 22 | #' Image segmentation predictions 23 | #' 24 | #' @details These objects contain test set predictions for the cell segmentation 25 | #' data from Hill, LaPan, Li and Haney (2007). Each data frame are the results 26 | #' from different models (naive Bayes and logistic regression). 27 | #' 28 | #' @name segment_naive_bayes 29 | #' @aliases segment_naive_bayes segment_logistic 30 | #' @docType data 31 | #' @return \item{segment_naive_bayes,segment_logistic}{a tibble} 32 | #' 33 | #' @source Hill, LaPan, Li and Haney (2007). Impact of image segmentation on 34 | #' high-content screening data quality for SK-BR-3 cells, \emph{BMC 35 | #' Bioinformatics}, Vol. 8, pg. 340, 36 | #' \url{https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-8-340}. 37 | #' 38 | #' @keywords datasets 39 | #' @examples 40 | #' data(segment_naive_bayes) 41 | #' data(segment_logistic) 42 | NULL 43 | 44 | 45 | #' Boosted regression trees predictions 46 | #' 47 | #' @details These data have a set of holdout predictions from 10-fold 48 | #' cross-validation and a separate collection of test set predictions from the 49 | #' same boosted tree model. The data were generated using the `sim_regression` 50 | #' function in the \pkg{modeldata} package. 
51 | #' 52 | #' @name boosting_predictions 53 | #' @aliases boosting_predictions_oob boosting_predictions_test 54 | #' @docType data 55 | #' @return \item{boosting_predictions_oob,boosting_predictions_test}{tibbles} 56 | #' 57 | #' @keywords datasets 58 | #' @examples 59 | #' data(boosting_predictions_oob) 60 | #' str(boosting_predictions_oob) 61 | #' str(boosting_predictions_test) 62 | NULL 63 | -------------------------------------------------------------------------------- /R/printing.R: -------------------------------------------------------------------------------- 1 | cat_class_pred <- function(x) { 2 | if (length(x) == 0) { 3 | cat("class_pred(0)", "\n") 4 | } else { 5 | print(format(x), quote = FALSE) 6 | } 7 | } 8 | 9 | # Adapted from print.factor 10 | # Smart enough to truncate the levels if they get too long 11 | cat_levels <- function(x, width = getOption("width")) { 12 | ord <- is_ordered_class_pred(x) 13 | 14 | if (ord) { 15 | colsep <- " < " 16 | } else { 17 | colsep <- " " 18 | } 19 | 20 | lev <- levels(x) 21 | n_lev <- length(lev) 22 | 23 | header <- "Levels: " 24 | 25 | maxl <- { 26 | width <- width - (nchar(header, "w") + 3L + 1L + 3L) 27 | lenl <- cumsum(nchar(lev, "w") + nchar(colsep, "w")) 28 | 29 | if (n_lev <= 1L || lenl[n_lev] <= width) { 30 | n_lev 31 | } else { 32 | max(1L, which.max(lenl > width) - 1L) 33 | } 34 | } 35 | 36 | # do we need to drop levels? 37 | drop <- n_lev > maxl 38 | 39 | cat( 40 | 41 | # Print number of levels if we had to drop some 42 | if (drop) { 43 | paste(format(n_lev), "") 44 | }, 45 | 46 | # Print `Levels: ` 47 | header, 48 | paste( 49 | 50 | # `first levels ... last levels` 51 | if (drop) { 52 | c(lev[1L:max(1, maxl - 1)], "...", if (maxl > 1) lev[n_lev]) 53 | } 54 | 55 | # print all levels 56 | else { 57 | lev 58 | }, 59 | collapse = colsep 60 | ), 61 | 62 | # Newline 63 | "\n", 64 | sep = "" 65 | ) 66 | } 67 | 68 | cat_reportable <- function(x) { 69 | reportable <- 100 * reportable_rate(x) 70 | 71 | if (rlang::is_scalar_integerish(reportable)) { 72 | reportable <- vec_cast(reportable, integer()) 73 | } 74 | 75 | digits <- function(x) { 76 | if (is.integer(x)) { 77 | 0 78 | } else { 79 | 1 80 | } 81 | } 82 | 83 | reportable <- paste0( 84 | formatC(reportable, format = "f", digits = digits(reportable)), 85 | "%" 86 | ) 87 | 88 | cat_report <- "Reportable: " 89 | cat_report <- paste0(cat_report, reportable) 90 | cat(cat_report) 91 | cat("\n") 92 | } 93 | -------------------------------------------------------------------------------- /R/probably-package.R: -------------------------------------------------------------------------------- 1 | #' @keywords internal 2 | "_PACKAGE" 3 | 4 | ## usethis namespace: start 5 | #' @import rlang 6 | #' @import vctrs 7 | #' @import ggplot2 8 | #' @importFrom purrr map 9 | #' @importFrom utils head 10 | #' @importFrom yardstick sens spec j_index 11 | #' @importFrom stats binomial median predict qnorm as.stepfun glm isoreg prop.test 12 | ## usethis namespace: end 13 | NULL 14 | 15 | utils::globalVariables(c( 16 | ".bin", ".is_val", "event_rate", "events", "lower", 17 | "predicted_midpoint", "total", "upper", ".config", 18 | ".adj_estimate", ".rounded", ".pred", ".bound", "pred_val", ".extracts", 19 | ".x", ".type", ".metrics", "cal_data" 20 | )) 21 | -------------------------------------------------------------------------------- /R/reexports.R: -------------------------------------------------------------------------------- 1 | #' @importFrom generics fit 2 | #' @export 3 | generics::fit 4 | 5 | #' 
@importFrom generics augment 6 | #' @export 7 | generics::augment 8 | 9 | #' @importFrom generics required_pkgs 10 | #' @export 11 | generics::required_pkgs 12 | 13 | #' @importFrom tune collect_metrics 14 | #' @export 15 | tune::collect_metrics 16 | 17 | #' @importFrom tune collect_predictions 18 | #' @export 19 | tune::collect_predictions 20 | 21 | # from tune 22 | # nocov start 23 | 24 | is_cran_check <- function() { 25 | if (identical(Sys.getenv("NOT_CRAN"), "true")) { 26 | FALSE 27 | } else { 28 | Sys.getenv("_R_CHECK_PACKAGE_NAME_", "") != "" 29 | } 30 | } 31 | 32 | # nocov end 33 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | # is there a forcats for this? 2 | recode_data <- function(obs, prob, threshold, event_level) { 3 | lvl <- levels(obs) 4 | if (identical(event_level, "first")) { 5 | pred <- ifelse(prob >= threshold, lvl[1], lvl[2]) 6 | } else { 7 | pred <- ifelse(prob >= threshold, lvl[2], lvl[1]) 8 | } 9 | factor(pred, levels = lvl) 10 | } 11 | 12 | quote_collapse <- function(x, quote = "`", collapse = ", ") { 13 | paste(encodeString(x, quote = quote), collapse = collapse) 14 | } 15 | 16 | abort_default <- function(x, fn) { 17 | cls <- quote_collapse(class(x)) 18 | cli::cli_abort("No implementation of {.fn {fn}} for {.obj_type_friendly {cls}}.") 19 | } 20 | 21 | # Check if a class_pred object came from an ordered factor 22 | is_ordered_class_pred <- function(x) { 23 | attr(x, "ordered") 24 | } 25 | 26 | get_equivocal_label <- function(x) { 27 | attr(x, "equivocal") 28 | } 29 | 30 | is_ordered <- function(x) { 31 | UseMethod("is_ordered") 32 | } 33 | 34 | # Must export internal methods for testing 35 | #' @export 36 | is_ordered.class_pred <- function(x) { 37 | is_ordered_class_pred(x) 38 | } 39 | 40 | # Must export internal methods for testing 41 | #' @export 42 | is_ordered.default <- function(x) { 43 | is.ordered(x) 44 | } 45 | 46 | get_group_argument <- function(group, .data, call = rlang::env_parent()) { 47 | group <- rlang::enquo(group) 48 | 49 | group_names <- tidyselect::eval_select( 50 | expr = group, 51 | data = .data, 52 | allow_rename = FALSE, 53 | allow_empty = TRUE, 54 | allow_predicates = TRUE, 55 | error_call = call 56 | ) 57 | 58 | n_group_names <- length(group_names) 59 | 60 | useable_config <- n_group_names == 0 && 61 | ".config" %in% names(.data) && 62 | dplyr::n_distinct(.data[[".config"]]) > 1 63 | 64 | if (useable_config) { 65 | return(quo(.config)) 66 | } 67 | 68 | if (n_group_names > 1) { 69 | cli::cli_abort( 70 | c( 71 | x = "{.arg .by} cannot select more than one column.", 72 | i = "The following {n_group_names} columns were selected:", 73 | i = "{names(group_names)}" 74 | ) 75 | ) 76 | } 77 | 78 | return(group) 79 | } 80 | 81 | abort_if_tune_result <- function(call = rlang::caller_env()) { 82 | cli::cli_abort( 83 | c( 84 | "This function can only be used with an {.cls rset} object or the \\ 85 | results of {.fn tune::fit_resamples} with a {.field .predictions} \\ 86 | column.", 87 | i = "Not an {.cls tune_results} object." 88 | ), 89 | call = call 90 | ) 91 | } 92 | 93 | abort_if_grouped_df <- function(call = rlang::caller_env()) { 94 | cli::cli_abort( 95 | c( 96 | "x" = "This function does not work with grouped data frames.", 97 | "i" = "Apply {.fn dplyr::ungroup} and use the {.arg .by} argument." 
 98 |     ),
 99 |     call = call
100 |   )
101 | }
102 | 
--------------------------------------------------------------------------------
/R/zzz.R:
--------------------------------------------------------------------------------
 1 | .onLoad <- function(libname, pkgname) {
 2 |   vctrs::s3_register("tune::collect_metrics", "cal_rset")
 3 |   vctrs::s3_register("tune::collect_predictions", "cal_rset")
 4 | }
 5 | 
--------------------------------------------------------------------------------
/README.Rmd:
--------------------------------------------------------------------------------
 1 | ---
 2 | output: github_document
 3 | editor_options:
 4 |   chunk_output_type: console
 5 | ---
 6 | 
 7 | 
 8 | 
 9 | ```{r setup, include = FALSE}
10 | knitr::opts_chunk$set(
11 |   collapse = TRUE,
12 |   comment = "#>",
13 |   fig.path = "man/figures/README-",
14 |   out.width = "100%"
15 | )
16 | ```
17 | 
18 | # probably
19 | 
20 | 
21 | [![Codecov test coverage](https://codecov.io/gh/tidymodels/probably/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/probably?branch=main)
22 | [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html)
23 | [![R-CMD-check](https://github.com/tidymodels/probably/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/probably/actions/workflows/R-CMD-check.yaml)
24 | 
25 | 
26 | ## Introduction
27 | 
28 | probably contains tools to facilitate activities such as:
29 | 
30 | * Conversion of probabilities to discrete class predictions.
31 | 
32 | * Investigating and estimating optimal probability thresholds.
33 | 
34 | * Calibration assessments and remediation for classification and regression models.
35 | 
36 | * Inclusion of _equivocal zones_ where the probabilities are too uncertain to report a prediction.
37 | 
38 | ## Installation
39 | 
40 | You can install probably from CRAN with:
41 | 
42 | ```{r, eval = FALSE}
43 | install.packages("probably")
44 | ```
45 | 
46 | You can install the development version of probably from GitHub with:
47 | 
48 | ```{r, eval = FALSE}
49 | # install.packages("pak")
50 | pak::pak("tidymodels/probably")
51 | ```
52 | 
53 | ## Examples
54 | 
55 | Good places to look for examples of using probably are the vignettes.
56 | 
57 | * `vignette("equivocal-zones", "probably")` discusses the new `class_pred` class that probably provides for working with equivocal zones.
58 | 
59 | * `vignette("where-to-use", "probably")` discusses how probably fits in with the rest of the tidymodels ecosystem, and provides an example of optimizing class probability thresholds.
60 | 
61 | ## Contributing
62 | 
63 | This project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms.
64 | 
65 | - For questions and discussions about tidymodels packages, modeling, and machine learning, please [post on Posit Community](https://forum.posit.co/new-topic?category_id=15&tags=tidymodels,question).
66 | 
67 | - If you think you have encountered a bug, please [submit an issue](https://github.com/tidymodels/probably/issues).
68 | 
69 | - Either way, learn how to create and share a [reprex](https://reprex.tidyverse.org/articles/articles/learn-reprex.html) (a minimal, reproducible example), to clearly communicate about your code.
70 | 
71 | - Check out further details on [contributing guidelines for tidymodels packages](https://www.tidymodels.org/contribute/) and [how to get help](https://www.tidymodels.org/help/).
72 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | 
 4 | # probably
 5 | 
 6 | 
 7 | 
 8 | [![Codecov test
 9 | coverage](https://codecov.io/gh/tidymodels/probably/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/probably?branch=main)
10 | [![Lifecycle:
11 | experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html)
12 | [![R-CMD-check](https://github.com/tidymodels/probably/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/probably/actions/workflows/R-CMD-check.yaml)
13 | 
14 | 
15 | ## Introduction
16 | 
17 | probably contains tools to facilitate activities such as:
18 | 
19 | - Conversion of probabilities to discrete class predictions.
20 | 
21 | - Investigating and estimating optimal probability thresholds.
22 | 
23 | - Calibration assessments and remediation for classification and
24 |   regression models.
25 | 
26 | - Inclusion of *equivocal zones* where the probabilities are too
27 |   uncertain to report a prediction.
28 | 
29 | ## Installation
30 | 
31 | You can install probably from CRAN with:
32 | 
33 | ``` r
34 | install.packages("probably")
35 | ```
36 | 
37 | You can install the development version of probably from GitHub with:
38 | 
39 | ``` r
40 | # install.packages("pak")
41 | pak::pak("tidymodels/probably")
42 | ```
43 | 
44 | ## Examples
45 | 
46 | Good places to look for examples of using probably are the vignettes.
47 | 
48 | - `vignette("equivocal-zones", "probably")` discusses the new
49 |   `class_pred` class that probably provides for working with equivocal
50 |   zones.
51 | 
52 | - `vignette("where-to-use", "probably")` discusses how probably fits in
53 |   with the rest of the tidymodels ecosystem, and provides an example of
54 |   optimizing class probability thresholds.
55 | 
56 | ## Contributing
57 | 
58 | This project is released with a [Contributor Code of
59 | Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html).
60 | By contributing to this project, you agree to abide by its terms.
61 | 
62 | - For questions and discussions about tidymodels packages, modeling, and
63 |   machine learning, please [post on Posit
64 |   Community](https://forum.posit.co/new-topic?category_id=15&tags=tidymodels,question).
65 | 
66 | - If you think you have encountered a bug, please [submit an
67 |   issue](https://github.com/tidymodels/probably/issues).
68 | 
69 | - Either way, learn how to create and share a
70 |   [reprex](https://reprex.tidyverse.org/articles/articles/learn-reprex.html)
71 |   (a minimal, reproducible example), to clearly communicate about your
72 |   code.
73 | 
74 | - Check out further details on [contributing guidelines for tidymodels
75 |   packages](https://www.tidymodels.org/contribute/) and [how to get
76 |   help](https://www.tidymodels.org/help/).
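
## A minimal example

As a quick sketch of the equivocal-zone tools listed above (the vignettes
give the full treatment; this uses the `segment_logistic` data shipped with
probably), class probabilities can be converted to class predictions with a
buffer around the threshold that is reported as equivocal:

``` r
library(probably)

data("segment_logistic")

# Probabilities within 0.50 +/- 0.025 are labeled as equivocal ([EQ])
preds <- make_two_class_pred(
  estimate = segment_logistic$.pred_good,
  levels = levels(segment_logistic$Class),
  threshold = 0.50,
  buffer = 0.025
)

head(preds)
```

Equivocal values become `NA` when the result is converted to a factor, so
downstream metrics only see the predictions that were confident enough to
report.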
77 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://probably.tidymodels.org 2 | 3 | template: 4 | bootstrap: 5 5 | package: tidytemplate 6 | bslib: 7 | primary: '#CA225E' 8 | includes: 9 | in_header: | 10 | 11 | 12 | development: 13 | mode: auto 14 | 15 | reference: 16 | - title: Thresholds 17 | contents: 18 | - threshold_perf 19 | 20 | - title: Create class predictions 21 | contents: 22 | - append_class_pred 23 | - make_class_pred 24 | - make_two_class_pred 25 | 26 | - title: Class predictions 27 | contents: 28 | - class_pred 29 | - as_class_pred 30 | - is_class_pred 31 | - reportable_rate 32 | - locate-equivocal 33 | - levels.class_pred 34 | 35 | - title: Regression predictions 36 | contents: 37 | - starts_with("int_") 38 | - starts_with("control_conformal_") 39 | - starts_with("predict.int") 40 | - bound_prediction 41 | 42 | - title: Data 43 | contents: 44 | - segment_naive_bayes 45 | - segment_logistic 46 | - species_probs 47 | - boosting_predictions 48 | 49 | - title: Calibration 50 | contents: 51 | - starts_with("cal_estimate") 52 | - cal_apply 53 | 54 | - title: Calibration Validation 55 | contents: 56 | - starts_with("cal_validate") 57 | - starts_with("collect_") 58 | 59 | - title: Calibration Plots 60 | contents: 61 | - starts_with("cal_plot") 62 | 63 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /data/boosting_predictions.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/data/boosting_predictions.RData -------------------------------------------------------------------------------- /data/datalist: -------------------------------------------------------------------------------- 1 | segment_logistic: segment_logistic 2 | segment_naive_bayes: segment_naive_bayes 3 | species_probs: species_probs 4 | boosting_predictions: boosting_predictions_oob boosting_predictions_test -------------------------------------------------------------------------------- /data/segment_logistic.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/data/segment_logistic.RData -------------------------------------------------------------------------------- /data/segment_naive_bayes.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/data/segment_naive_bayes.RData -------------------------------------------------------------------------------- /data/species_probs.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/data/species_probs.RData -------------------------------------------------------------------------------- 
/man-roxygen/metrics_both.R: -------------------------------------------------------------------------------- 1 | #' @section Performance Metrics: 2 | #' 3 | #' By default, the average of the Brier scores (classification calibration) or the 4 | #' root mean squared error (regression) is returned. Any appropriate 5 | #' [yardstick::metric_set()] can be used. The validation function compares the 6 | #' average of the metrics before and after the calibration. 7 | -------------------------------------------------------------------------------- /man-roxygen/metrics_cls.R: -------------------------------------------------------------------------------- 1 | #' @section Performance Metrics: 2 | #' 3 | #' By default, the average of the Brier scores is returned. Any appropriate 4 | #' [yardstick::metric_set()] can be used. The validation function compares the 5 | #' average of the metrics before and after the calibration. 6 | -------------------------------------------------------------------------------- /man-roxygen/metrics_reg.R: -------------------------------------------------------------------------------- 1 | #' @section Performance Metrics: 2 | #' 3 | #' By default, the average of the root mean square error (RMSE) is returned. 4 | #' Any appropriate [yardstick::metric_set()] can be used. The validation 5 | #' function compares the average of the metrics before and after the calibration. 6 | -------------------------------------------------------------------------------- /man-roxygen/multiclass.R: -------------------------------------------------------------------------------- 1 | #' @section Multiclass Extension: 2 | #' 3 | #' This method is designed to work with two classes. For multiclass, it creates 4 | #' a set of "one versus all" calibrations for each class. After they are 5 | #' applied to the data, the probability estimates are re-normalized to add to 6 | #' one. This final step might compromise the calibration. 7 | #' 8 | -------------------------------------------------------------------------------- /man/append_class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/make_class_pred.R 3 | \name{append_class_pred} 4 | \alias{append_class_pred} 5 | \title{Add a \code{class_pred} column} 6 | \usage{ 7 | append_class_pred( 8 | .data, 9 | ..., 10 | levels, 11 | ordered = FALSE, 12 | min_prob = 1/length(levels), 13 | name = ".class_pred" 14 | ) 15 | } 16 | \arguments{ 17 | \item{.data}{A data frame or tibble.} 18 | 19 | \item{...}{One or more unquoted expressions separated by commas 20 | to capture the columns of \code{.data} containing the class 21 | probabilities. You can treat variable names like they are 22 | positions, so you can use expressions like \code{x:y} to select ranges 23 | of variables or use selector functions to choose which columns. 24 | For \code{make_class_pred}, the columns for all class probabilities 25 | should be selected (in the same order as the \code{levels} object). 26 | For \code{make_two_class_pred}, a vector of class probabilities should be 27 | selected.}
This results in a \code{class_pred} object 35 | that is flagged as ordered.} 36 | 37 | \item{min_prob}{A single numeric value. If any probabilities are less than 38 | this value (by row), the row is marked as \emph{equivocal}.} 39 | 40 | \item{name}{A single character value for the name of the appended 41 | \code{class_pred} column.} 42 | } 43 | \value{ 44 | \code{.data} with an extra \code{class_pred} column appended onto it. 45 | } 46 | \description{ 47 | This function is similar to \code{\link[=make_class_pred]{make_class_pred()}}, but is useful when you have 48 | a large number of class probability columns and want to use \code{tidyselect} 49 | helpers. It appends the new \code{class_pred} vector as a column on the original 50 | data frame. 51 | } 52 | \examples{ 53 | 54 | # The following two examples are equivalent and demonstrate 55 | # the helper, append_class_pred() 56 | 57 | library(dplyr) 58 | 59 | species_probs |> 60 | mutate( 61 | .class_pred = make_class_pred( 62 | .pred_bobcat, .pred_coyote, .pred_gray_fox, 63 | levels = levels(Species), 64 | min_prob = .5 65 | ) 66 | ) 67 | 68 | lvls <- levels(species_probs$Species) 69 | 70 | append_class_pred( 71 | .data = species_probs, 72 | contains(".pred_"), 73 | levels = lvls, 74 | min_prob = .5 75 | ) 76 | 77 | } 78 | -------------------------------------------------------------------------------- /man/as_class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{as_class_pred} 4 | \alias{as_class_pred} 5 | \title{Coerce to a \code{class_pred} object} 6 | \usage{ 7 | as_class_pred(x, which = integer(), equivocal = "[EQ]") 8 | } 9 | \arguments{ 10 | \item{x}{A factor or ordered factor.} 11 | 12 | \item{which}{An integer vector specifying the locations of \code{x} to declare 13 | as equivocal.} 14 | 15 | \item{equivocal}{A single character specifying the equivocal label used 16 | when printing.} 17 | } 18 | \description{ 19 | \code{as_class_pred()} provides coercion to \code{class_pred} from other 20 | existing objects. 21 | } 22 | \examples{ 23 | 24 | x <- factor(c("Yes", "No", "Yes", "Yes")) 25 | as_class_pred(x) 26 | 27 | } 28 | -------------------------------------------------------------------------------- /man/boosting_predictions.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{boosting_predictions} 5 | \alias{boosting_predictions} 6 | \alias{boosting_predictions_oob} 7 | \alias{boosting_predictions_test} 8 | \title{Boosted regression trees predictions} 9 | \value{ 10 | \item{boosting_predictions_oob,boosting_predictions_test}{tibbles} 11 | } 12 | \description{ 13 | Boosted regression trees predictions 14 | } 15 | \details{ 16 | These data have a set of holdout predictions from 10-fold 17 | cross-validation and a separate collection of test set predictions from the 18 | same boosted tree model. The data were generated using the \code{sim_regression} 19 | function in the \pkg{modeldata} package. 
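% An illustrative sketch (not the exact generating code): a similar dataset
% could be simulated with, e.g., set.seed(1); modeldata::sim_regression(num_samples = 500).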
20 | } 21 | \examples{ 22 | data(boosting_predictions_oob) 23 | str(boosting_predictions_oob) 24 | str(boosting_predictions_test) 25 | } 26 | \keyword{datasets} 27 | -------------------------------------------------------------------------------- /man/bound_prediction.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/bound_prediction.R 3 | \name{bound_prediction} 4 | \alias{bound_prediction} 5 | \title{Truncate a numeric prediction column} 6 | \usage{ 7 | bound_prediction( 8 | x, 9 | lower_limit = -Inf, 10 | upper_limit = Inf, 11 | call = rlang::current_env() 12 | ) 13 | } 14 | \arguments{ 15 | \item{x}{A data frame that contains a numeric column named \code{.pred}.} 16 | 17 | \item{lower_limit, upper_limit}{Single numerics (or \code{NA}) that define 18 | constraints on \code{.pred}.} 19 | 20 | \item{call}{The call to be displayed in warnings or errors.} 21 | } 22 | \value{ 23 | \code{x} with potentially adjusted values. 24 | } 25 | \description{ 26 | For user-defined \code{lower_limit} and/or \code{upper_limit} bounds, ensure that the values in the 27 | \code{.pred} column are coerced to these bounds. 28 | } 29 | \examples{ 30 | data(solubility_test, package = "yardstick") 31 | 32 | names(solubility_test) <- c("solubility", ".pred") 33 | 34 | bound_prediction(solubility_test, lower_limit = -1) 35 | } 36 | -------------------------------------------------------------------------------- /man/cal_apply.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-apply.R 3 | \name{cal_apply} 4 | \alias{cal_apply} 5 | \alias{cal_apply.data.frame} 6 | \alias{cal_apply.tune_results} 7 | \alias{cal_apply.cal_object} 8 | \title{Applies a calibration to a set of existing predictions} 9 | \usage{ 10 | cal_apply(.data, object, pred_class = NULL, parameters = NULL, ...) 11 | 12 | \method{cal_apply}{data.frame}(.data, object, pred_class = NULL, parameters = NULL, ...) 13 | 14 | \method{cal_apply}{tune_results}(.data, object, pred_class = NULL, parameters = NULL, ...) 15 | 16 | \method{cal_apply}{cal_object}(.data, object, pred_class = NULL, parameters = NULL, ...) 17 | } 18 | \arguments{ 19 | \item{.data}{An object that can process a calibration object.} 20 | 21 | \item{object}{The calibration object (\code{cal_object}).} 22 | 23 | \item{pred_class}{(Optional, classification only) Column identifier for the 24 | hard class predictions (a factor vector). This column will be adjusted based 25 | on changes to the calibrated probability columns.} 26 | 27 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 28 | that can be used to filter the predicted values before processing. Applies 29 | only to \code{tune_results} objects.} 30 | 31 | \item{...}{Optional arguments; currently unused.} 32 | } 33 | \description{ 34 | Applies a calibration to a set of existing predictions 35 | } 36 | \details{ 37 | \code{cal_apply()} currently supports data.frames only. It extracts the \code{truth} and 38 | the estimate column names from the calibration object.
39 | } 40 | \examples{ 41 | 42 | # ------------------------------------------------------------------------------ 43 | # classification example 44 | 45 | w_calibration <- cal_estimate_logistic(segment_logistic, Class) 46 | 47 | cal_apply(segment_logistic, w_calibration) 48 | } 49 | \seealso{ 50 | \url{https://www.tidymodels.org/learn/models/calibration/}, 51 | \code{\link[=cal_estimate_beta]{cal_estimate_beta()}}, \code{\link[=cal_estimate_isotonic]{cal_estimate_isotonic()}}, 52 | \code{\link[=cal_estimate_isotonic_boot]{cal_estimate_isotonic_boot()}}, \code{\link[=cal_estimate_linear]{cal_estimate_linear()}}, 53 | \code{\link[=cal_estimate_logistic]{cal_estimate_logistic()}}, \code{\link[=cal_estimate_multinomial]{cal_estimate_multinomial()}} 54 | } 55 | -------------------------------------------------------------------------------- /man/cal_binary_tables.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-plot-breaks.R, R/cal-plot-logistic.R, 3 | % R/cal-plot-windowed.R 4 | \name{.cal_table_breaks} 5 | \alias{.cal_table_breaks} 6 | \alias{.cal_table_logistic} 7 | \alias{.cal_table_windowed} 8 | \title{Probability Calibration table} 9 | \usage{ 10 | .cal_table_breaks( 11 | .data, 12 | truth = NULL, 13 | estimate = NULL, 14 | .by = NULL, 15 | num_breaks = 10, 16 | conf_level = 0.9, 17 | event_level = c("auto", "first", "second"), 18 | ... 19 | ) 20 | 21 | .cal_table_logistic( 22 | .data, 23 | truth = NULL, 24 | estimate = NULL, 25 | .by = NULL, 26 | conf_level = 0.9, 27 | smooth = TRUE, 28 | event_level = c("auto", "first", "second"), 29 | ... 30 | ) 31 | 32 | .cal_table_windowed( 33 | .data, 34 | truth = NULL, 35 | estimate = NULL, 36 | .by = NULL, 37 | window_size = 0.1, 38 | step_size = window_size/2, 39 | conf_level = 0.9, 40 | event_level = c("auto", "first", "second"), 41 | ... 42 | ) 43 | } 44 | \arguments{ 45 | \item{.data}{An ungrouped data frame object containing predictions and 46 | probability columns.} 47 | 48 | \item{truth}{The column identifier for the true class results 49 | (that is a factor). This should be an unquoted column name.} 50 | 51 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 52 | functions to choose which variables contain the class probabilities. It 53 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 54 | identifiers will be considered the same as the order of the levels of the 55 | \code{truth} variable.} 56 | 57 | \item{.by}{The column identifier for the grouping variable. This should be 58 | a single unquoted column name that selects a qualitative variable for 59 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 60 | 61 | \item{num_breaks}{The number of segments to group the probabilities. It 62 | defaults to 10.} 63 | 64 | \item{conf_level}{Confidence level to use in the visualization. It defaults 65 | to 0.9.} 66 | 67 | \item{event_level}{Single string. Either "first" or "second" to specify which 68 | level of truth to consider as the "event". Defaults to "auto", which allows 69 | the function to decide which one to use based on the type of model (binary, 70 | multi-class, or linear).} 71 | 72 | \item{...}{Additional arguments passed to the \code{tune_results} object.} 73 | } 74 | \description{ 75 | Calibration table functions. They require a data.frame that 76 | contains the predictions and probability columns.
The output is another 77 | \code{tibble} with segmented data that compares the accuracy of the probability 78 | to the actual outcome. 79 | } 80 | \details{ 81 | \itemize{ 82 | \item \code{.cal_table_breaks()} - Splits the data into bins, based on the 83 | number of breaks provided (\code{num_breaks}). The bins are even ranges, starting 84 | at 0, and ending at 1. 85 | \item \code{.cal_table_logistic()} - Fits a logistic spline regression (GAM) 86 | against the data. It then creates a table with the predictions based on 100 87 | probabilities starting at 0, and ending at 1. 88 | \item \code{.cal_table_windowed()} - Creates a running percentage of the 89 | probability that moves across the proportion of events. 90 | } 91 | } 92 | \examples{ 93 | .cal_table_breaks( 94 | segment_logistic, 95 | Class, 96 | .pred_good 97 | ) 98 | 99 | .cal_table_logistic( 100 | segment_logistic, 101 | Class, 102 | .pred_good 103 | ) 104 | 105 | .cal_table_windowed( 106 | segment_logistic, 107 | Class, 108 | .pred_good 109 | ) 110 | } 111 | \keyword{internal} 112 | -------------------------------------------------------------------------------- /man/cal_estimate_beta.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-beta.R 3 | \name{cal_estimate_beta} 4 | \alias{cal_estimate_beta} 5 | \alias{cal_estimate_beta.data.frame} 6 | \alias{cal_estimate_beta.tune_results} 7 | \alias{cal_estimate_beta.grouped_df} 8 | \title{Uses a Beta calibration model to calculate new probabilities} 9 | \usage{ 10 | cal_estimate_beta( 11 | .data, 12 | truth = NULL, 13 | shape_params = 2, 14 | location_params = 1, 15 | estimate = dplyr::starts_with(".pred_"), 16 | parameters = NULL, 17 | ... 18 | ) 19 | 20 | \method{cal_estimate_beta}{data.frame}( 21 | .data, 22 | truth = NULL, 23 | shape_params = 2, 24 | location_params = 1, 25 | estimate = dplyr::starts_with(".pred_"), 26 | parameters = NULL, 27 | ..., 28 | .by = NULL 29 | ) 30 | 31 | \method{cal_estimate_beta}{tune_results}( 32 | .data, 33 | truth = NULL, 34 | shape_params = 2, 35 | location_params = 1, 36 | estimate = dplyr::starts_with(".pred_"), 37 | parameters = NULL, 38 | ... 39 | ) 40 | 41 | \method{cal_estimate_beta}{grouped_df}( 42 | .data, 43 | truth = NULL, 44 | shape_params = 2, 45 | location_params = 1, 46 | estimate = NULL, 47 | parameters = NULL, 48 | ... 49 | ) 50 | } 51 | \arguments{ 52 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 53 | that contains predictions and probability columns.} 54 | 55 | \item{truth}{The column identifier for the true class results 56 | (that is a factor). This should be an unquoted column name.} 57 | 58 | \item{shape_params}{Number of shape parameters to use. Accepted values are 59 | 1 and 2. Defaults to 2.} 60 | 61 | \item{location_params}{Number of location parameters to use. Accepted values are 62 | 1 and 0. Defaults to 1.} 63 | 64 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 65 | functions to choose which variables contain the class probabilities. It 66 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 67 | identifiers will be considered the same as the order of the levels of the 68 | \code{truth} variable.} 69 | 70 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 71 | that can be used to filter the predicted values before processing.
Applies 72 | only to \code{tune_results} objects.} 73 | 74 | \item{...}{Additional arguments passed to the models or routines used to 75 | calculate the new probabilities.} 76 | 77 | \item{.by}{The column identifier for the grouping variable. This should be 78 | a single unquoted column name that selects a qualitative variable for 79 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 80 | } 81 | \description{ 82 | Uses a Beta calibration model to calculate new probabilities 83 | } 84 | \details{ 85 | This function uses the \code{\link[betacal:beta_calibration]{betacal::beta_calibration()}} function, and 86 | retains the resulting model. 87 | } 88 | \section{Multiclass Extension}{ 89 | 90 | 91 | This method is designed to work with two classes. For multiclass, it creates 92 | a set of "one versus all" calibrations for each class. After they are 93 | applied to the data, the probability estimates are re-normalized to add to 94 | one. This final step might compromise the calibration. 95 | } 96 | 97 | \examples{ 98 | if (rlang::is_installed("betacal")) { 99 | # It will automatically identify the probability columns 100 | # if passed a model fitted with tidymodels 101 | cal_estimate_beta(segment_logistic, Class) 102 | } 103 | } 104 | \references{ 105 | Meelis Kull, Telmo M. Silva Filho, Peter Flach "Beyond sigmoids: 106 | How to obtain well-calibrated probabilities from binary classifiers with beta 107 | calibration," \emph{Electronic Journal of Statistics} 11(2), 5052-5080, (2017) 108 | } 109 | \seealso{ 110 | \url{https://www.tidymodels.org/learn/models/calibration/}, 111 | \code{\link[=cal_validate_beta]{cal_validate_beta()}} 112 | } 113 | -------------------------------------------------------------------------------- /man/cal_estimate_isotonic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-isotonic.R 3 | \name{cal_estimate_isotonic} 4 | \alias{cal_estimate_isotonic} 5 | \alias{cal_estimate_isotonic.data.frame} 6 | \alias{cal_estimate_isotonic.tune_results} 7 | \alias{cal_estimate_isotonic.grouped_df} 8 | \title{Uses an Isotonic regression model to calibrate model predictions.} 9 | \usage{ 10 | cal_estimate_isotonic( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | parameters = NULL, 15 | ... 16 | ) 17 | 18 | \method{cal_estimate_isotonic}{data.frame}( 19 | .data, 20 | truth = NULL, 21 | estimate = dplyr::starts_with(".pred"), 22 | parameters = NULL, 23 | ..., 24 | .by = NULL 25 | ) 26 | 27 | \method{cal_estimate_isotonic}{tune_results}( 28 | .data, 29 | truth = NULL, 30 | estimate = dplyr::starts_with(".pred"), 31 | parameters = NULL, 32 | ... 33 | ) 34 | 35 | \method{cal_estimate_isotonic}{grouped_df}( 36 | .data, 37 | truth = NULL, 38 | estimate = NULL, 39 | parameters = NULL, 40 | ... 41 | ) 42 | } 43 | \arguments{ 44 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 45 | that contains predictions and probability columns.} 46 | 47 | \item{truth}{The column identifier for the true class results 48 | (that is a factor). This should be an unquoted column name.} 49 | 50 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 51 | functions to choose which variables contain the class probabilities. It 52 | defaults to the prefix used by tidymodels (\code{.pred_}).
The order of the 53 | identifiers will be considered the same as the order of the levels of the 54 | \code{truth} variable.} 55 | 56 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 57 | that can be used to filter the predicted values before processing. Applies 58 | only to \code{tune_results} objects.} 59 | 60 | \item{...}{Additional arguments passed to the models or routines used to 61 | calculate the new probabilities.} 62 | 63 | \item{.by}{The column identifier for the grouping variable. This should be 64 | a single unquoted column name that selects a qualitative variable for 65 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 66 | } 67 | \description{ 68 | Uses an Isotonic regression model to calibrate model predictions. 69 | } 70 | \details{ 71 | This function uses \code{\link[stats:isoreg]{stats::isoreg()}} to obtain the calibration 72 | values for binary classification or numeric regression. 73 | } 74 | \section{Multiclass Extension}{ 75 | 76 | 77 | This method is designed to work with two classes. For multiclass, it creates 78 | a set of "one versus all" calibrations for each class. After they are 79 | applied to the data, the probability estimates are re-normalized to add to 80 | one. This final step might compromise the calibration. 81 | } 82 | 83 | \examples{ 84 | # ------------------------------------------------------------------------------ 85 | # Binary Classification 86 | 87 | # It will automatically identify the probability columns 88 | # if passed a model fitted with tidymodels 89 | cal_estimate_isotonic(segment_logistic, Class) 90 | 91 | # Specify the variable names in a vector of unquoted names 92 | cal_estimate_isotonic(segment_logistic, Class, c(.pred_poor, .pred_good)) 93 | 94 | # dplyr selector functions are also supported 95 | cal_estimate_isotonic(segment_logistic, Class, dplyr::starts_with(".pred_")) 96 | 97 | # ------------------------------------------------------------------------------ 98 | # Regression (numeric outcomes) 99 | 100 | cal_estimate_isotonic(boosting_predictions_oob, outcome, .pred) 101 | } 102 | \references{ 103 | Zadrozny, Bianca and Elkan, Charles. (2002). Transforming Classifier Scores 104 | into Accurate Multiclass Probability Estimates. \emph{Proceedings of the ACM SIGKDD 105 | International Conference on Knowledge Discovery and Data Mining.} 106 | } 107 | \seealso{ 108 | \url{https://www.tidymodels.org/learn/models/calibration/}, 109 | \code{\link[=cal_validate_isotonic]{cal_validate_isotonic()}} 110 | } 111 | -------------------------------------------------------------------------------- /man/cal_estimate_isotonic_boot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-isotonic.R 3 | \name{cal_estimate_isotonic_boot} 4 | \alias{cal_estimate_isotonic_boot} 5 | \alias{cal_estimate_isotonic_boot.data.frame} 6 | \alias{cal_estimate_isotonic_boot.tune_results} 7 | \alias{cal_estimate_isotonic_boot.grouped_df} 8 | \title{Uses a bootstrapped Isotonic regression model to calibrate probabilities} 9 | \usage{ 10 | cal_estimate_isotonic_boot( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | times = 10, 15 | parameters = NULL, 16 | ...
17 | ) 18 | 19 | \method{cal_estimate_isotonic_boot}{data.frame}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred"), 23 | times = 10, 24 | parameters = NULL, 25 | ..., 26 | .by = NULL 27 | ) 28 | 29 | \method{cal_estimate_isotonic_boot}{tune_results}( 30 | .data, 31 | truth = NULL, 32 | estimate = dplyr::starts_with(".pred"), 33 | times = 10, 34 | parameters = NULL, 35 | ... 36 | ) 37 | 38 | \method{cal_estimate_isotonic_boot}{grouped_df}( 39 | .data, 40 | truth = NULL, 41 | estimate = NULL, 42 | times = 10, 43 | parameters = NULL, 44 | ... 45 | ) 46 | } 47 | \arguments{ 48 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 49 | that contains predictions and probability columns.} 50 | 51 | \item{truth}{The column identifier for the true class results 52 | (that is a factor). This should be an unquoted column name.} 53 | 54 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 55 | functions to choose which variables contain the class probabilities. It 56 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 57 | identifiers will be considered the same as the order of the levels of the 58 | \code{truth} variable.} 59 | 60 | \item{times}{Number of bootstraps.} 61 | 62 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 63 | that can be used to filter the predicted values before processing. Applies 64 | only to \code{tune_results} objects.} 65 | 66 | \item{...}{Additional arguments passed to the models or routines used to 67 | calculate the new probabilities.} 68 | 69 | \item{.by}{The column identifier for the grouping variable. This should be 70 | a single unquoted column name that selects a qualitative variable for 71 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 72 | } 73 | \description{ 74 | Uses a bootstrapped Isotonic regression model to calibrate probabilities 75 | } 76 | \details{ 77 | This function uses \code{\link[stats:isoreg]{stats::isoreg()}} to obtain the calibration 78 | values. It runs \code{\link[stats:isoreg]{stats::isoreg()}} multiple times, each time with a different 79 | seed. The results are saved inside the returned \code{cal_object}. 80 | } 81 | \section{Multiclass Extension}{ 82 | 83 | 84 | This method is designed to work with two classes. For multiclass, it creates 85 | a set of "one versus all" calibrations for each class. After they are 86 | applied to the data, the probability estimates are re-normalized to add to 87 | one. This final step might compromise the calibration.
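% An illustrative sketch (not the package internals) of the re-normalization
% step described above, assuming three "one versus all" calibrated columns:
%   p <- data.frame(.pred_a = 0.50, .pred_b = 0.40, .pred_c = 0.30)
%   p / rowSums(p)  # each row rescaled so the class probabilities sum to one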
88 | } 89 | 90 | \examples{ 91 | # It will automatically identify the probability columns 92 | # if passed a model fitted with tidymodels 93 | cal_estimate_isotonic_boot(segment_logistic, Class) 94 | # Specify the variable names in a vector of unquoted names 95 | cal_estimate_isotonic_boot(segment_logistic, Class, c(.pred_poor, .pred_good)) 96 | # dplyr selector functions are also supported 97 | cal_estimate_isotonic_boot(segment_logistic, Class, dplyr::starts_with(".pred")) 98 | } 99 | \seealso{ 100 | \url{https://www.tidymodels.org/learn/models/calibration/}, 101 | \code{\link[=cal_validate_isotonic_boot]{cal_validate_isotonic_boot()}} 102 | } 103 | -------------------------------------------------------------------------------- /man/cal_estimate_linear.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-linear.R 3 | \name{cal_estimate_linear} 4 | \alias{cal_estimate_linear} 5 | \alias{cal_estimate_linear.data.frame} 6 | \alias{cal_estimate_linear.tune_results} 7 | \alias{cal_estimate_linear.grouped_df} 8 | \title{Uses a linear regression model to calibrate numeric predictions} 9 | \usage{ 10 | cal_estimate_linear( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::matches("^.pred$"), 14 | smooth = TRUE, 15 | parameters = NULL, 16 | ..., 17 | .by = NULL 18 | ) 19 | 20 | \method{cal_estimate_linear}{data.frame}( 21 | .data, 22 | truth = NULL, 23 | estimate = dplyr::matches("^.pred$"), 24 | smooth = TRUE, 25 | parameters = NULL, 26 | ..., 27 | .by = NULL 28 | ) 29 | 30 | \method{cal_estimate_linear}{tune_results}( 31 | .data, 32 | truth = NULL, 33 | estimate = dplyr::matches("^.pred$"), 34 | smooth = TRUE, 35 | parameters = NULL, 36 | ... 37 | ) 38 | 39 | \method{cal_estimate_linear}{grouped_df}( 40 | .data, 41 | truth = NULL, 42 | estimate = NULL, 43 | smooth = TRUE, 44 | parameters = NULL, 45 | ... 46 | ) 47 | } 48 | \arguments{ 49 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 50 | that contains a prediction column.} 51 | 52 | \item{truth}{The column identifier for the observed outcome data (that is 53 | numeric). This should be an unquoted column name.} 54 | 55 | \item{estimate}{Column identifier for the predicted values.} 56 | 57 | \item{smooth}{Applies to the linear models. It switches between a generalized 58 | additive model using spline terms when \code{TRUE}, and simple linear regression 59 | when \code{FALSE}.} 60 | 61 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 62 | that can be used to filter the predicted values before processing. Applies 63 | only to \code{tune_results} objects.} 64 | 65 | \item{...}{Additional arguments passed to the models or routines used to 66 | calculate the new predictions.} 67 | 68 | \item{.by}{The column identifier for the grouping variable. This should be 69 | a single unquoted column name that selects a qualitative variable for 70 | grouping. Defaults to \code{NULL}.
When \code{.by = NULL} no grouping will take place.} 71 | } 72 | \description{ 73 | Uses a linear regression model to calibrate numeric predictions 74 | } 75 | \details{ 76 | This function uses existing modeling functions from other packages to create 77 | the calibration: 78 | \itemize{ 79 | \item \code{\link[stats:glm]{stats::glm()}} is used when \code{smooth} is set to \code{FALSE} 80 | \item \code{\link[mgcv:gam]{mgcv::gam()}} is used when \code{smooth} is set to \code{TRUE} 81 | } 82 | 83 | These methods estimate the relationship in the unmodified predicted values 84 | and then remove that trend when \code{\link[=cal_apply]{cal_apply()}} is invoked. 85 | } 86 | \examples{ 87 | library(dplyr) 88 | library(ggplot2) 89 | 90 | head(boosting_predictions_test) 91 | 92 | # ------------------------------------------------------------------------------ 93 | # Before calibration 94 | 95 | y_rng <- extendrange(boosting_predictions_test$outcome) 96 | 97 | boosting_predictions_test |> 98 | ggplot(aes(outcome, .pred)) + 99 | geom_abline(lty = 2) + 100 | geom_point(alpha = 1 / 2) + 101 | geom_smooth(se = FALSE, col = "blue", linewidth = 1.2, alpha = 3 / 4) + 102 | coord_equal(xlim = y_rng, ylim = y_rng) + 103 | ggtitle("Before calibration") 104 | 105 | # ------------------------------------------------------------------------------ 106 | # Smoothed trend removal 107 | 108 | smoothed_cal <- 109 | boosting_predictions_oob |> 110 | # It will automatically identify the predicted value columns when the 111 | # standard tidymodels naming conventions are used. 112 | cal_estimate_linear(outcome) 113 | smoothed_cal 114 | 115 | boosting_predictions_test |> 116 | cal_apply(smoothed_cal) |> 117 | ggplot(aes(outcome, .pred)) + 118 | geom_abline(lty = 2) + 119 | geom_point(alpha = 1 / 2) + 120 | geom_smooth(se = FALSE, col = "blue", linewidth = 1.2, alpha = 3 / 4) + 121 | coord_equal(xlim = y_rng, ylim = y_rng) + 122 | ggtitle("After calibration") 123 | 124 | } 125 | \seealso{ 126 | \url{https://www.tidymodels.org/learn/models/calibration/}, 127 | \code{\link[=cal_validate_linear]{cal_validate_linear()}} 128 | } 129 | -------------------------------------------------------------------------------- /man/cal_estimate_logistic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-logistic.R 3 | \name{cal_estimate_logistic} 4 | \alias{cal_estimate_logistic} 5 | \alias{cal_estimate_logistic.data.frame} 6 | \alias{cal_estimate_logistic.tune_results} 7 | \alias{cal_estimate_logistic.grouped_df} 8 | \title{Uses a logistic regression model to calibrate probabilities} 9 | \usage{ 10 | cal_estimate_logistic( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | smooth = TRUE, 15 | parameters = NULL, 16 | ... 17 | ) 18 | 19 | \method{cal_estimate_logistic}{data.frame}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | smooth = TRUE, 24 | parameters = NULL, 25 | ..., 26 | .by = NULL 27 | ) 28 | 29 | \method{cal_estimate_logistic}{tune_results}( 30 | .data, 31 | truth = NULL, 32 | estimate = dplyr::starts_with(".pred_"), 33 | smooth = TRUE, 34 | parameters = NULL, 35 | ... 36 | ) 37 | 38 | \method{cal_estimate_logistic}{grouped_df}( 39 | .data, 40 | truth = NULL, 41 | estimate = NULL, 42 | smooth = TRUE, 43 | parameters = NULL, 44 | ... 
45 | ) 46 | } 47 | \arguments{ 48 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 49 | that contains predictions and probability columns.} 50 | 51 | \item{truth}{The column identifier for the true class results 52 | (that is a factor). This should be an unquoted column name.} 53 | 54 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 55 | functions to choose which variables contain the class probabilities. It 56 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 57 | identifiers will be considered the same as the order of the levels of the 58 | \code{truth} variable.} 59 | 60 | \item{smooth}{Applies to the logistic models. It switches between logistic 61 | spline when \code{TRUE}, and simple logistic regression when \code{FALSE}.} 62 | 63 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 64 | that can be used to filter the predicted values before processing. Applies 65 | only to \code{tune_results} objects.} 66 | 67 | \item{...}{Additional arguments passed to the models or routines used to 68 | calculate the new probabilities.} 69 | 70 | \item{.by}{The column identifier for the grouping variable. This should be 71 | a single unquoted column name that selects a qualitative variable for 72 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 73 | } 74 | \description{ 75 | Uses a logistic regression model to calibrate probabilities 76 | } 77 | \details{ 78 | This function uses existing modeling functions from other packages to create 79 | the calibration: 80 | \itemize{ 81 | \item \code{\link[stats:glm]{stats::glm()}} is used when \code{smooth} is set to \code{FALSE} 82 | \item \code{\link[mgcv:gam]{mgcv::gam()}} is used when \code{smooth} is set to \code{TRUE} 83 | } 84 | \subsection{Multiclass Extension}{ 85 | 86 | This method has \emph{not} been extended to multiclass outcomes. However, the 87 | natural multiclass extension is \code{\link[=cal_estimate_multinomial]{cal_estimate_multinomial()}}. 88 | } 89 | } 90 | \examples{ 91 | # It will automatically identify the probability columns 92 | # if passed a model fitted with tidymodels 93 | cal_estimate_logistic(segment_logistic, Class) 94 | 95 | # Specify the variable names in a vector of unquoted names 96 | cal_estimate_logistic(segment_logistic, Class, c(.pred_poor, .pred_good)) 97 | 98 | # dplyr selector functions are also supported 99 | cal_estimate_logistic(segment_logistic, Class, dplyr::starts_with(".pred_")) 100 | } 101 | \seealso{ 102 | \url{https://www.tidymodels.org/learn/models/calibration/}, 103 | \code{\link[=cal_validate_logistic]{cal_validate_logistic()}} 104 | } 105 | -------------------------------------------------------------------------------- /man/cal_estimate_multinomial.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-multinom.R 3 | \name{cal_estimate_multinomial} 4 | \alias{cal_estimate_multinomial} 5 | \alias{cal_estimate_multinomial.data.frame} 6 | \alias{cal_estimate_multinomial.tune_results} 7 | \alias{cal_estimate_multinomial.grouped_df} 8 | \title{Uses a Multinomial calibration model to calculate new probabilities} 9 | \usage{ 10 | cal_estimate_multinomial( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | smooth = TRUE, 15 | parameters = NULL, 16 | ...
17 | ) 18 | 19 | \method{cal_estimate_multinomial}{data.frame}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | smooth = TRUE, 24 | parameters = NULL, 25 | ..., 26 | .by = NULL 27 | ) 28 | 29 | \method{cal_estimate_multinomial}{tune_results}( 30 | .data, 31 | truth = NULL, 32 | estimate = dplyr::starts_with(".pred_"), 33 | smooth = TRUE, 34 | parameters = NULL, 35 | ... 36 | ) 37 | 38 | \method{cal_estimate_multinomial}{grouped_df}( 39 | .data, 40 | truth = NULL, 41 | estimate = NULL, 42 | smooth = TRUE, 43 | parameters = NULL, 44 | ... 45 | ) 46 | } 47 | \arguments{ 48 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 49 | that contains predictions and probability columns.} 50 | 51 | \item{truth}{The column identifier for the true class results 52 | (that is a factor). This should be an unquoted column name.} 53 | 54 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 55 | functions to choose which variables contain the class probabilities. It 56 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 57 | identifiers will be considered the same as the order of the levels of the 58 | \code{truth} variable.} 59 | 60 | \item{smooth}{Applies to the logistic models. It switches between logistic 61 | spline when \code{TRUE}, and simple logistic regression when \code{FALSE}.} 62 | 63 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 64 | that can be used to filter the predicted values before processing. Applies 65 | only to \code{tune_results} objects.} 66 | 67 | \item{...}{Additional arguments passed to the models or routines used to 68 | calculate the new probabilities.} 69 | 70 | \item{.by}{The column identifier for the grouping variable. This should be 71 | a single unquoted column name that selects a qualitative variable for 72 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 73 | } 74 | \description{ 75 | Uses a Multinomial calibration model to calculate new probabilities 76 | } 77 | \details{ 78 | When \code{smooth = FALSE}, the \code{\link[nnet:multinom]{nnet::multinom()}} function is used to estimate the 79 | model; otherwise \code{\link[mgcv:gam]{mgcv::gam()}} is used.
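% For example (illustrative call; cal_pred is created in the example below):
%   cal_estimate_multinomial(cal_pred, truth = class, smooth = FALSE)
% would fit nnet::multinom() instead of the smoothed mgcv::gam() model.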
80 | } 81 | \examples{ 82 | \dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "randomForest"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 83 | library(modeldata) 84 | library(parsnip) 85 | library(dplyr) 86 | 87 | f <- 88 | list( 89 | ~ -0.5 + 0.6 * abs(A), 90 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, -2), 91 | ~ -0.6 * A + 0.50 * B - A * B 92 | ) 93 | 94 | set.seed(1) 95 | tr_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]]) 96 | cal_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]]) 97 | te_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]]) 98 | 99 | set.seed(2) 100 | rf_fit <- 101 | rand_forest() |> 102 | set_mode("classification") |> 103 | set_engine("randomForest") |> 104 | fit(class ~ ., data = tr_dat) 105 | 106 | cal_pred <- 107 | predict(rf_fit, cal_dat, type = "prob") |> 108 | bind_cols(cal_dat) 109 | te_pred <- 110 | predict(rf_fit, te_dat, type = "prob") |> 111 | bind_cols(te_dat) 112 | 113 | cal_plot_windowed(cal_pred, truth = class, window_size = 0.1, step_size = 0.03) 114 | 115 | smoothed_mn <- cal_estimate_multinomial(cal_pred, truth = class) 116 | 117 | new_test_pred <- cal_apply(te_pred, smoothed_mn) 118 | 119 | cal_plot_windowed(new_test_pred, truth = class, window_size = 0.1, step_size = 0.03) 120 | \dontshow{\}) # examplesIf} 121 | } 122 | \seealso{ 123 | \url{https://www.tidymodels.org/learn/models/calibration/}, 124 | \code{\link[=cal_validate_multinomial]{cal_validate_multinomial()}} 125 | } 126 | -------------------------------------------------------------------------------- /man/cal_estimate_none.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-none.R 3 | \name{cal_estimate_none} 4 | \alias{cal_estimate_none} 5 | \alias{cal_estimate_none.data.frame} 6 | \alias{cal_estimate_none.tune_results} 7 | \alias{cal_estimate_none.grouped_df} 8 | \title{Do not calibrate model predictions.} 9 | \usage{ 10 | cal_estimate_none( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | parameters = NULL, 15 | ... 16 | ) 17 | 18 | \method{cal_estimate_none}{data.frame}( 19 | .data, 20 | truth = NULL, 21 | estimate = dplyr::starts_with(".pred"), 22 | parameters = NULL, 23 | ..., 24 | .by = NULL 25 | ) 26 | 27 | \method{cal_estimate_none}{tune_results}( 28 | .data, 29 | truth = NULL, 30 | estimate = dplyr::starts_with(".pred"), 31 | parameters = NULL, 32 | ... 33 | ) 34 | 35 | \method{cal_estimate_none}{grouped_df}(.data, truth = NULL, estimate = NULL, parameters = NULL, ...) 36 | } 37 | \arguments{ 38 | \item{.data}{An ungrouped \code{data.frame} object, or \code{tune_results} object, 39 | that contains predictions and probability columns.} 40 | 41 | \item{truth}{The column identifier for the true outcome results 42 | (that is factor or numeric). This should be an unquoted column name.} 43 | 44 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 45 | functions to choose which variables contain the class probabilities or 46 | numeric predictions. It defaults to the prefix used by tidymodels (\code{.pred_}).
47 | For classification problems, the order of the identifiers will be considered 48 | the same as the order of the levels of the \code{truth} variable.} 49 | 50 | \item{parameters}{(Optional) An optional tibble of tuning parameter values 51 | that can be used to filter the predicted values before processing. Applies 52 | only to \code{tune_results} objects.} 53 | 54 | \item{...}{Additional arguments passed to the models or routines used to 55 | calculate the new probabilities.} 56 | 57 | \item{.by}{The column identifier for the grouping variable. This should be 58 | a single unquoted column name that selects a qualitative variable for 59 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 60 | } 61 | \description{ 62 | Do not calibrate model predictions. 63 | } 64 | \details{ 65 | This function does nothing to the predictions. It is used as a 66 | reference when tuning over different calibration methods. 67 | } 68 | \examples{ 69 | 70 | nada <- cal_estimate_none(boosting_predictions_oob, outcome, .pred) 71 | nada 72 | 73 | identical( 74 | cal_apply(boosting_predictions_oob, nada), 75 | boosting_predictions_oob 76 | ) 77 | 78 | # ------------------------------------------------------------------------------ 79 | 80 | nichts <- cal_estimate_none(segment_logistic, Class) 81 | 82 | identical( 83 | cal_apply(segment_logistic, nichts), 84 | segment_logistic 85 | ) 86 | } 87 | -------------------------------------------------------------------------------- /man/cal_plot_logistic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-plot-logistic.R 3 | \name{cal_plot_logistic} 4 | \alias{cal_plot_logistic} 5 | \alias{cal_plot_logistic.data.frame} 6 | \alias{cal_plot_logistic.tune_results} 7 | \alias{cal_plot_logistic.grouped_df} 8 | \title{Probability calibration plots via logistic regression} 9 | \usage{ 10 | cal_plot_logistic( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | conf_level = 0.9, 15 | smooth = TRUE, 16 | include_rug = TRUE, 17 | include_ribbon = TRUE, 18 | event_level = c("auto", "first", "second"), 19 | ... 20 | ) 21 | 22 | \method{cal_plot_logistic}{data.frame}( 23 | .data, 24 | truth = NULL, 25 | estimate = dplyr::starts_with(".pred"), 26 | conf_level = 0.9, 27 | smooth = TRUE, 28 | include_rug = TRUE, 29 | include_ribbon = TRUE, 30 | event_level = c("auto", "first", "second"), 31 | ..., 32 | .by = NULL 33 | ) 34 | 35 | \method{cal_plot_logistic}{tune_results}( 36 | .data, 37 | truth = NULL, 38 | estimate = dplyr::starts_with(".pred"), 39 | conf_level = 0.9, 40 | smooth = TRUE, 41 | include_rug = TRUE, 42 | include_ribbon = TRUE, 43 | event_level = c("auto", "first", "second"), 44 | ... 45 | ) 46 | 47 | \method{cal_plot_logistic}{grouped_df}( 48 | .data, 49 | truth = NULL, 50 | estimate = NULL, 51 | conf_level = 0.9, 52 | smooth = TRUE, 53 | include_rug = TRUE, 54 | include_ribbon = TRUE, 55 | event_level = c("auto", "first", "second"), 56 | ... 57 | ) 58 | } 59 | \arguments{ 60 | \item{.data}{An ungrouped data frame object containing predictions and 61 | probability columns.} 62 | 63 | \item{truth}{The column identifier for the true class results 64 | (that is a factor). This should be an unquoted column name.} 65 | 66 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 67 | functions to choose which variables contain the class probabilities.
It 68 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 69 | identifiers will be considered the same as the order of the levels of the 70 | \code{truth} variable.} 71 | 72 | \item{conf_level}{Confidence level to use in the visualization. It defaults 73 | to 0.9.} 74 | 75 | \item{smooth}{A logical for using a generalized additive model with smooth 76 | terms for the predictor via \code{\link[mgcv:gam]{mgcv::gam()}} and \code{\link[mgcv:s]{mgcv::s()}}.} 77 | 78 | \item{include_rug}{Flag that indicates if the Rug layer is to be included. 79 | It defaults to \code{TRUE}. In the plot, the top side shows the frequency of the 80 | event occurring, and the bottom the frequency of the event not occurring.} 81 | 82 | \item{include_ribbon}{Flag that indicates if the ribbon layer is to be 83 | included. It defaults to \code{TRUE}.} 84 | 85 | \item{event_level}{Single string. Either "first" or "second" to specify which 86 | level of truth to consider as the "event". Defaults to "auto", which allows 87 | the function to decide which one to use based on the type of model (binary, 88 | multi-class, or linear).} 89 | 90 | \item{...}{Additional arguments passed to the \code{tune_results} object.} 91 | 92 | \item{.by}{The column identifier for the grouping variable. This should be 93 | a single unquoted column name that selects a qualitative variable for 94 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 95 | } 96 | \value{ 97 | A ggplot object. 98 | } 99 | \description{ 100 | A logistic regression model is fit where the original outcome data are used 101 | as the outcome and the estimated class probabilities for one class are used 102 | as the predictor. If \code{smooth = TRUE}, a generalized additive model is fit 103 | using \code{\link[mgcv:gam]{mgcv::gam()}} and the default smoothing method. Otherwise, a simple 104 | logistic regression is used. 105 | 106 | If the predictions are well calibrated, the fitted curve should align with 107 | the diagonal line. Confidence intervals for the fitted line are also 108 | shown. 109 | } 110 | \examples{ 111 | 112 | library(ggplot2) 113 | library(dplyr) 114 | 115 | cal_plot_logistic( 116 | segment_logistic, 117 | Class, 118 | .pred_good 119 | ) 120 | 121 | cal_plot_logistic( 122 | segment_logistic, 123 | Class, 124 | .pred_good, 125 | smooth = FALSE 126 | ) 127 | } 128 | \seealso{ 129 | \url{https://www.tidymodels.org/learn/models/calibration/}, 130 | \code{\link[=cal_plot_breaks]{cal_plot_breaks()}}, \code{\link[=cal_plot_windowed]{cal_plot_windowed()}} 131 | } 132 | -------------------------------------------------------------------------------- /man/cal_plot_regression.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-plot-regression.R 3 | \name{cal_plot_regression} 4 | \alias{cal_plot_regression} 5 | \alias{cal_plot_regression.data.frame} 6 | \alias{cal_plot_regression.tune_results} 7 | \alias{cal_plot_regression.grouped_df} 8 | \title{Regression calibration plots} 9 | \usage{ 10 | cal_plot_regression(.data, truth = NULL, estimate = NULL, smooth = TRUE, ...)
11 | 12 | \method{cal_plot_regression}{data.frame}( 13 | .data, 14 | truth = NULL, 15 | estimate = NULL, 16 | smooth = TRUE, 17 | ..., 18 | .by = NULL 19 | ) 20 | 21 | \method{cal_plot_regression}{tune_results}(.data, truth = NULL, estimate = NULL, smooth = TRUE, ...) 22 | 23 | \method{cal_plot_regression}{grouped_df}(.data, truth = NULL, estimate = NULL, smooth = TRUE, ...) 24 | } 25 | \arguments{ 26 | \item{.data}{An ungrouped data frame object containing a prediction 27 | column.} 28 | 29 | \item{truth}{The column identifier for the true results 30 | (numeric). This should be an unquoted column name.} 31 | 32 | \item{estimate}{The column identifier for the predictions. 33 | This should be an unquoted column name.} 34 | 35 | \item{smooth}{A logical: should a smoother curve be added.} 36 | 37 | \item{...}{Additional arguments passed to \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}}.} 38 | 39 | \item{.by}{The column identifier for the grouping variable. This should be 40 | a single unquoted column name that selects a qualitative variable for 41 | grouping. Defaults to \code{NULL}. When \code{.by = NULL} no grouping will take place.} 42 | } 43 | \value{ 44 | A ggplot object. 45 | } 46 | \description{ 47 | A scatter plot of the observed and predicted values is computed where the 48 | axes are the same. When \code{smooth = TRUE}, a generalized additive model fit 49 | is shown. If the predictions are well calibrated, the fitted curve should align with 50 | the diagonal line. 51 | } 52 | \examples{ 53 | cal_plot_regression(boosting_predictions_oob, outcome, .pred) 54 | 55 | cal_plot_regression(boosting_predictions_oob, outcome, .pred, 56 | alpha = 1 / 6, cex = 3, smooth = FALSE 57 | ) 58 | 59 | cal_plot_regression(boosting_predictions_oob, outcome, .pred, 60 | .by = id, 61 | alpha = 1 / 6, cex = 3, smooth = FALSE 62 | ) 63 | } 64 | -------------------------------------------------------------------------------- /man/cal_validate_beta.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_beta} 4 | \alias{cal_validate_beta} 5 | \alias{cal_validate_beta.resample_results} 6 | \alias{cal_validate_beta.rset} 7 | \alias{cal_validate_beta.tune_results} 8 | \title{Measure performance with and without using Beta calibration} 9 | \usage{ 10 | cal_validate_beta( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_beta}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_beta}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred_"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_beta}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ... 44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor).
This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to save a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to \code{\link[=cal_estimate_beta]{cal_estimate_beta()}}, such as the 64 | \code{shape_params} and \code{location_params} arguments.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | These functions are designed to calculate performance with and without 76 | calibration. They use resampling to measure out-of-sample effectiveness. 77 | There are two ways to pass the data in: 78 | \itemize{ 79 | \item If you have a data frame of predictions, an \code{rset} object can be created 80 | via \pkg{rsample} functions. See the example below. 81 | \item If you have already made a resampling object from the original data and 82 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 83 | calibration function and it will use the same resampling scheme. If a 84 | different resampling scheme should be used, run 85 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 86 | previous bullet point. 87 | } 88 | 89 | Please note that these functions do not apply to \code{tune_result} objects. The 90 | notion of "validation" implies that the tuning parameter selection has been 91 | resolved. 92 | 93 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 94 | } 95 | \section{Performance Metrics}{ 96 | 97 | 98 | By default, the average of the Brier scores is returned. Any appropriate 99 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 100 | average of the metrics before and after the calibration.
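% An illustrative sketch of supplying a custom metric set (assumes the
% yardstick metrics shown are available in the installed version):
%   cls_met <- yardstick::metric_set(yardstick::brier_class, yardstick::roc_auc)
%   segment_logistic |>
%     rsample::vfold_cv() |>
%     cal_validate_beta(Class, metrics = cls_met, save_pred = TRUE)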
101 | } 102 | 103 | \examples{ 104 | 105 | library(dplyr) 106 | 107 | if (rlang::is_installed("betacal")) { 108 | segment_logistic |> 109 | rsample::vfold_cv() |> 110 | cal_validate_beta(Class) 111 | } 112 | } 113 | \seealso{ 114 | \url{https://www.tidymodels.org/learn/models/calibration/}, 115 | \code{\link[=cal_estimate_beta]{cal_estimate_beta()}} 116 | } 117 | -------------------------------------------------------------------------------- /man/cal_validate_isotonic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_isotonic} 4 | \alias{cal_validate_isotonic} 5 | \alias{cal_validate_isotonic.resample_results} 6 | \alias{cal_validate_isotonic.rset} 7 | \alias{cal_validate_isotonic.tune_results} 8 | \title{Measure performance with and without using isotonic regression calibration} 9 | \usage{ 10 | cal_validate_isotonic( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_isotonic}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_isotonic}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_isotonic}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ... 44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor). This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to save a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to 64 | \code{\link[=cal_estimate_isotonic]{cal_estimate_isotonic()}}.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | These functions are designed to calculate performance with and without 76 | calibration. They use resampling to measure out-of-sample effectiveness. 77 | There are two ways to pass the data in: 78 | \itemize{ 79 | \item If you have a data frame of predictions, an \code{rset} object can be created 80 | via \pkg{rsample} functions. See the example below.
81 | \item If you have already made a resampling object from the original data and 82 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 83 | calibration function and it will use the same resampling scheme. If a 84 | different resampling scheme should be used, run 85 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 86 | previous bullet point. 87 | } 88 | 89 | Please note that these functions do not apply to \code{tune_results} objects. The 90 | notion of "validation" implies that the tuning parameter selection has been 91 | resolved. 92 | 93 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 94 | } 95 | \section{Performance Metrics}{ 96 | 97 | 98 | By default, the average of the Brier scores (classification calibration) or the 99 | root mean squared error (regression) is returned. Any appropriate 100 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 101 | average of the metrics before and after calibration. 102 | } 103 | 104 | \examples{ 105 | 106 | library(dplyr) 107 | 108 | segment_logistic |> 109 | rsample::vfold_cv() |> 110 | cal_validate_isotonic(Class) 111 | 112 | } 113 | \seealso{ 114 | \url{https://www.tidymodels.org/learn/models/calibration/}, 115 | \code{\link[=cal_estimate_isotonic]{cal_estimate_isotonic()}} 116 | } 117 | -------------------------------------------------------------------------------- /man/cal_validate_isotonic_boot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_isotonic_boot} 4 | \alias{cal_validate_isotonic_boot} 5 | \alias{cal_validate_isotonic_boot.resample_results} 6 | \alias{cal_validate_isotonic_boot.rset} 7 | \alias{cal_validate_isotonic_boot.tune_results} 8 | \title{Measure performance with and without using bagged isotonic regression calibration} 9 | \usage{ 10 | cal_validate_isotonic_boot( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_isotonic_boot}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_isotonic_boot}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_isotonic_boot}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ... 44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor). This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of the \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}).
The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to keep a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to \code{\link[=cal_estimate_isotonic_boot]{cal_estimate_isotonic_boot()}}, such as the 64 | \code{times} argument.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | These functions are designed to calculate performance with and without 76 | calibration. They use resampling to measure out-of-sample effectiveness. 77 | There are two ways to pass the data in: 78 | \itemize{ 79 | \item If you have a data frame of predictions, an \code{rset} object can be created 80 | via \pkg{rsample} functions. See the example below. 81 | \item If you have already made a resampling object from the original data and 82 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 83 | calibration function and it will use the same resampling scheme. If a 84 | different resampling scheme should be used, run 85 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 86 | previous bullet point. 87 | } 88 | 89 | Please note that these functions do not apply to \code{tune_results} objects. The 90 | notion of "validation" implies that the tuning parameter selection has been 91 | resolved. 92 | 93 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 94 | } 95 | \section{Performance Metrics}{ 96 | 97 | 98 | By default, the average of the Brier scores (classification calibration) or the 99 | root mean squared error (regression) is returned. Any appropriate 100 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 101 | average of the metrics before and after calibration. 102 | } 103 | 104 | \examples{ 105 | 106 | library(dplyr) 107 | 108 | segment_logistic |> 109 | rsample::vfold_cv() |> 110 | cal_validate_isotonic_boot(Class) 111 | 112 | } 113 | \seealso{ 114 | \url{https://www.tidymodels.org/learn/models/calibration/}, 115 | \code{\link[=cal_estimate_isotonic_boot]{cal_estimate_isotonic_boot()}} 116 | } 117 | -------------------------------------------------------------------------------- /man/cal_validate_linear.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_linear} 4 | \alias{cal_validate_linear} 5 | \alias{cal_validate_linear.resample_results} 6 | \alias{cal_validate_linear.rset} 7 | \title{Measure performance with and without using linear regression calibration} 8 | \usage{ 9 | cal_validate_linear( 10 | .data, 11 | truth = NULL, 12 | estimate = dplyr::starts_with(".pred"), 13 | metrics = NULL, 14 | save_pred = FALSE, 15 | ... 16 | ) 17 | 18 | \method{cal_validate_linear}{resample_results}( 19 | .data, 20 | truth = NULL, 21 | estimate = dplyr::starts_with(".pred"), 22 | metrics = NULL, 23 | save_pred = FALSE, 24 | ...
25 | ) 26 | 27 | \method{cal_validate_linear}{rset}( 28 | .data, 29 | truth = NULL, 30 | estimate = dplyr::starts_with(".pred"), 31 | metrics = NULL, 32 | save_pred = FALSE, 33 | ... 34 | ) 35 | } 36 | \arguments{ 37 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 38 | a \code{.predictions} column.} 39 | 40 | \item{truth}{The column identifier for the true numeric results 41 | (that is a numeric vector). This should be an unquoted column name.} 42 | 43 | \item{estimate}{A vector of column identifiers, or one of the \code{dplyr} selector 44 | functions to choose which variables contain the predictions. It 45 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 46 | identifiers will be considered the same as the order of the levels of the 47 | \code{truth} variable.} 48 | 49 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 50 | 51 | \item{save_pred}{Indicates whether to keep a column of post-calibration predictions.} 52 | 53 | \item{...}{Options to pass to \code{\link[=cal_estimate_linear]{cal_estimate_linear()}}, such as the \code{smooth} 54 | argument.} 55 | } 56 | \description{ 57 | Measure performance with and without using linear regression calibration 58 | } 59 | \section{Performance Metrics}{ 60 | 61 | 62 | By default, the average root mean squared error (RMSE) is returned. 63 | Any appropriate \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation 64 | function compares the average of the metrics before and after calibration. 65 | } 66 | 67 | \examples{ 68 | library(dplyr) 69 | library(yardstick) 70 | library(rsample) 71 | 72 | head(boosting_predictions_test) 73 | 74 | reg_stats <- metric_set(rmse, ccc) 75 | 76 | set.seed(828) 77 | boosting_predictions_oob |> 78 | # Resample with 10-fold cross-validation 79 | vfold_cv() |> 80 | cal_validate_linear(truth = outcome, smooth = FALSE, metrics = reg_stats) 81 | } 82 | \seealso{ 83 | \url{https://www.tidymodels.org/learn/models/calibration/}, 84 | \code{\link[=cal_estimate_linear]{cal_estimate_linear()}} 85 | } 86 | -------------------------------------------------------------------------------- /man/cal_validate_logistic.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_logistic} 4 | \alias{cal_validate_logistic} 5 | \alias{cal_validate_logistic.resample_results} 6 | \alias{cal_validate_logistic.rset} 7 | \alias{cal_validate_logistic.tune_results} 8 | \title{Measure performance with and without using logistic calibration} 9 | \usage{ 10 | cal_validate_logistic( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_logistic}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_logistic}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred_"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_logistic}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ...
44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor). This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of the \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to keep a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to \code{\link[=cal_estimate_logistic]{cal_estimate_logistic()}}, such as the \code{smooth} 64 | argument.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | These functions are designed to calculate performance with and without 76 | calibration. They use resampling to measure out-of-sample effectiveness. 77 | There are two ways to pass the data in: 78 | \itemize{ 79 | \item If you have a data frame of predictions, an \code{rset} object can be created 80 | via \pkg{rsample} functions. See the example below. 81 | \item If you have already made a resampling object from the original data and 82 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 83 | calibration function and it will use the same resampling scheme. If a 84 | different resampling scheme should be used, run 85 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 86 | previous bullet point. 87 | } 88 | 89 | Please note that these functions do not apply to \code{tune_results} objects. The 90 | notion of "validation" implies that the tuning parameter selection has been 91 | resolved. 92 | 93 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 94 | } 95 | \section{Performance Metrics}{ 96 | 97 | 98 | By default, the average of the Brier scores is returned. Any appropriate 99 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 100 | average of the metrics before and after calibration.
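For instance, the pre- and post-calibration metrics can be pulled out with \code{collect_metrics()}; a minimal sketch:

\preformatted{
## Compare Brier scores before and after calibration:
segment_logistic |>
  rsample::vfold_cv() |>
  cal_validate_logistic(Class) |>
  collect_metrics()
}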
101 | } 102 | 103 | \examples{ 104 | 105 | library(dplyr) 106 | 107 | # --------------------------------------------------------------------------- 108 | # classification example 109 | 110 | segment_logistic |> 111 | rsample::vfold_cv() |> 112 | cal_validate_logistic(Class) 113 | 114 | } 115 | \seealso{ 116 | \url{https://www.tidymodels.org/learn/models/calibration/}, 117 | \code{\link[=cal_estimate_logistic]{cal_estimate_logistic()}} 118 | } 119 | -------------------------------------------------------------------------------- /man/cal_validate_multinomial.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_multinomial} 4 | \alias{cal_validate_multinomial} 5 | \alias{cal_validate_multinomial.resample_results} 6 | \alias{cal_validate_multinomial.rset} 7 | \alias{cal_validate_multinomial.tune_results} 8 | \title{Measure performance with and without using multinomial calibration} 9 | \usage{ 10 | cal_validate_multinomial( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_multinomial}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_multinomial}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred_"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_multinomial}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ... 44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor). This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of the \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}). The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to keep a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to \code{\link[=cal_estimate_multinomial]{cal_estimate_multinomial()}}, such as the \code{smooth} 64 | argument.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | These functions are designed to calculate performance with and without 76 | calibration. They use resampling to measure out-of-sample effectiveness. 77 | There are two ways to pass the data in: 78 | \itemize{ 79 | \item If you have a data frame of predictions, an \code{rset} object can be created 80 | via \pkg{rsample} functions. See the example below.
81 | \item If you have already made a resampling object from the original data and 82 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 83 | calibration function and it will use the same resampling scheme. If a 84 | different resampling scheme should be used, run 85 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 86 | previous bullet point. 87 | } 88 | 89 | Please note that these functions do not apply to \code{tune_results} objects. The 90 | notion of "validation" implies that the tuning parameter selection has been 91 | resolved. 92 | 93 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 94 | } 95 | \section{Performance Metrics}{ 96 | 97 | 98 | By default, the average of the Brier scores is returned. Any appropriate 99 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 100 | average of the metrics before and after calibration. 101 | } 102 | 103 | \examples{ 104 | 105 | library(dplyr) 106 | 107 | species_probs |> 108 | rsample::vfold_cv() |> 109 | cal_validate_multinomial(Species) 110 | 111 | } 112 | \seealso{ 113 | \code{\link[=cal_apply]{cal_apply()}}, \code{\link[=cal_estimate_multinomial]{cal_estimate_multinomial()}} 114 | } 115 | -------------------------------------------------------------------------------- /man/cal_validate_none.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{cal_validate_none} 4 | \alias{cal_validate_none} 5 | \alias{cal_validate_none.resample_results} 6 | \alias{cal_validate_none.rset} 7 | \alias{cal_validate_none.tune_results} 8 | \title{Measure performance without using calibration} 9 | \usage{ 10 | cal_validate_none( 11 | .data, 12 | truth = NULL, 13 | estimate = dplyr::starts_with(".pred_"), 14 | metrics = NULL, 15 | save_pred = FALSE, 16 | ... 17 | ) 18 | 19 | \method{cal_validate_none}{resample_results}( 20 | .data, 21 | truth = NULL, 22 | estimate = dplyr::starts_with(".pred_"), 23 | metrics = NULL, 24 | save_pred = FALSE, 25 | ... 26 | ) 27 | 28 | \method{cal_validate_none}{rset}( 29 | .data, 30 | truth = NULL, 31 | estimate = dplyr::starts_with(".pred_"), 32 | metrics = NULL, 33 | save_pred = FALSE, 34 | ... 35 | ) 36 | 37 | \method{cal_validate_none}{tune_results}( 38 | .data, 39 | truth = NULL, 40 | estimate = NULL, 41 | metrics = NULL, 42 | save_pred = FALSE, 43 | ... 44 | ) 45 | } 46 | \arguments{ 47 | \item{.data}{An \code{rset} object or the results of \code{\link[tune:fit_resamples]{tune::fit_resamples()}} with 48 | a \code{.predictions} column.} 49 | 50 | \item{truth}{The column identifier for the true class results 51 | (that is a factor). This should be an unquoted column name.} 52 | 53 | \item{estimate}{A vector of column identifiers, or one of the \code{dplyr} selector 54 | functions to choose which variables contain the class probabilities. It 55 | defaults to the prefix used by tidymodels (\code{.pred_}).
The order of the 56 | identifiers will be considered the same as the order of the levels of the 57 | \code{truth} variable.} 58 | 59 | \item{metrics}{A set of metrics created via \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} 60 | 61 | \item{save_pred}{Indicates whether to keep a column of post-calibration predictions.} 62 | 63 | \item{...}{Options to pass to 64 | \code{\link[=cal_estimate_none]{cal_estimate_none()}}.} 65 | } 66 | \value{ 67 | The original object with a \code{.metrics_cal} column and, optionally, 68 | an additional \code{.predictions_cal} column. The class \code{cal_rset} is also added. 69 | } 70 | \description{ 71 | This function uses resampling to measure the effect of calibrating predicted 72 | values. 73 | } 74 | \details{ 75 | This function exists to have a complete API for all calibration methods. It 76 | returns the same results "with and without calibration" which, in this case, 77 | is always without calibration. 78 | 79 | There are two ways to pass the data in: 80 | \itemize{ 81 | \item If you have a data frame of predictions, an \code{rset} object can be created 82 | via \pkg{rsample} functions. See the example below. 83 | \item If you have already made a resampling object from the original data and 84 | used it with \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, you can pass that object to the 85 | calibration function and it will use the same resampling scheme. If a 86 | different resampling scheme should be used, run 87 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} on the object and use the process in the 88 | previous bullet point. 89 | } 90 | 91 | Please note that these functions do not apply to \code{tune_results} objects. The 92 | notion of "validation" implies that the tuning parameter selection has been 93 | resolved. 94 | 95 | \code{collect_metrics()} can be used to aggregate the metrics for analysis. 96 | } 97 | \section{Performance Metrics}{ 98 | 99 | 100 | By default, the average of the Brier scores is returned. Any appropriate 101 | \code{\link[yardstick:metric_set]{yardstick::metric_set()}} can be used. The validation function compares the 102 | average of the metrics before and after calibration. 103 | } 104 | 105 | \examples{ 106 | 107 | library(dplyr) 108 | 109 | species_probs |> 110 | rsample::vfold_cv() |> 111 | cal_validate_none(Species) |> 112 | collect_metrics() 113 | 114 | } 115 | \seealso{ 116 | \code{\link[=cal_apply]{cal_apply()}}, \code{\link[=cal_estimate_none]{cal_estimate_none()}} 117 | } 118 | -------------------------------------------------------------------------------- /man/class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{class_pred} 4 | \alias{class_pred} 5 | \title{Create a class prediction object} 6 | \usage{ 7 | class_pred(x = factor(), which = integer(), equivocal = "[EQ]") 8 | } 9 | \arguments{ 10 | \item{x}{A factor or ordered factor.} 11 | 12 | \item{which}{An integer vector specifying the locations of \code{x} to declare 13 | as equivocal.} 14 | 15 | \item{equivocal}{A single character specifying the equivocal label used 16 | when printing.} 17 | } 18 | \description{ 19 | \code{class_pred()} creates a \code{class_pred} object from a factor or ordered 20 | factor. You can optionally specify values of the factor to be set 21 | as \emph{equivocal}.
22 | } 23 | \details{ 24 | Equivocal values are those that you feel unsure about, and would like to 25 | exclude from performance calculations or other metrics. 26 | } 27 | \examples{ 28 | 29 | x <- factor(c("Yes", "No", "Yes", "Yes")) 30 | 31 | # Create a class_pred object from a factor 32 | class_pred(x) 33 | 34 | # Say you aren't sure about that 2nd "Yes" value. You could mark it as 35 | # equivocal. 36 | class_pred(x, which = 3) 37 | 38 | # Maybe you want a different equivocal label 39 | class_pred(x, which = 3, equivocal = "eq_value") 40 | 41 | } 42 | -------------------------------------------------------------------------------- /man/collect_metrics.cal_rset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{collect_metrics.cal_rset} 4 | \alias{collect_metrics.cal_rset} 5 | \title{Obtain and format metrics produced by calibration validation} 6 | \usage{ 7 | \method{collect_metrics}{cal_rset}(x, summarize = TRUE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object produced by one of the validation functions (of class 11 | \code{cal_rset}).} 12 | 13 | \item{summarize}{A logical; should metrics be summarized over resamples 14 | (\code{TRUE}) or returned for each individual resample (\code{FALSE})? See 15 | \code{\link[tune:collect_predictions]{tune::collect_metrics()}} for more details.} 16 | 17 | \item{...}{Not currently used.} 18 | } 19 | \value{ 20 | A tibble 21 | } 22 | \description{ 23 | Obtain and format metrics produced by calibration validation 24 | } 25 | -------------------------------------------------------------------------------- /man/collect_predictions.cal_rset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-validate.R 3 | \name{collect_predictions.cal_rset} 4 | \alias{collect_predictions.cal_rset} 5 | \title{Obtain and format predictions produced by calibration validation} 6 | \usage{ 7 | \method{collect_predictions}{cal_rset}(x, summarize = TRUE, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object produced by one of the validation functions (of class 11 | \code{cal_rset}).} 12 | 13 | \item{summarize}{A logical; should predictions be summarized over resamples 14 | (\code{TRUE}) or returned for each individual resample (\code{FALSE})? See 15 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} for more details.} 16 | 17 | \item{...}{Not currently used.} 18 | } 19 | \value{ 20 | A tibble 21 | } 22 | \description{ 23 | Obtain and format predictions produced by calibration validation 24 | } 25 | -------------------------------------------------------------------------------- /man/control_conformal_full.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/conformal_infer.R 3 | \name{control_conformal_full} 4 | \alias{control_conformal_full} 5 | \title{Controlling the numeric details for conformal inference} 6 | \usage{ 7 | control_conformal_full( 8 | method = "iterative", 9 | trial_points = 100, 10 | var_multiplier = 10, 11 | max_iter = 100, 12 | tolerance = .Machine$double.eps^0.25, 13 | progress = FALSE, 14 | required_pkgs = character(0), 15 | seed = sample.int(10^5, 1) 16 | ) 17 | } 18 | \arguments{ 19 | \item{method}{The method for computing the intervals.
The options are 20 | \code{'search'} (using \code{\link[stats:uniroot]{stats::uniroot()}}) and \code{'grid'}.} 21 | 22 | \item{trial_points}{When \code{method = "grid"}, how many points should be 23 | evaluated?} 24 | 25 | \item{var_multiplier}{A multiplier for the variance model that determines the 26 | possible range of the bounds.} 27 | 28 | \item{max_iter}{When \code{method = "iterative"}, the maximum number of iterations.} 29 | 30 | \item{tolerance}{Tolerance value passed to \code{\link[=all.equal]{all.equal()}} to determine 31 | convergence during the search computations.} 32 | 33 | \item{progress}{Should a progress bar be used to track execution?} 34 | 35 | \item{required_pkgs}{An optional character vector of packages that are 36 | required.} 37 | 38 | \item{seed}{A single integer used to control randomness when models are 39 | (re)fit.} 40 | } 41 | \value{ 42 | A list object with the options given by the user. 43 | } 44 | \description{ 45 | Controlling the numeric details for conformal inference 46 | } 47 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/man/figures/logo.png -------------------------------------------------------------------------------- /man/int_conformal_cv.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/conformal_infer_cv.R 3 | \name{int_conformal_cv} 4 | \alias{int_conformal_cv} 5 | \alias{int_conformal_cv.default} 6 | \alias{int_conformal_cv.resample_results} 7 | \alias{int_conformal_cv.tune_results} 8 | \title{Prediction intervals via conformal inference CV+} 9 | \usage{ 10 | int_conformal_cv(object, ...) 11 | 12 | \method{int_conformal_cv}{default}(object, ...) 13 | 14 | \method{int_conformal_cv}{resample_results}(object, ...) 15 | 16 | \method{int_conformal_cv}{tune_results}(object, parameters, ...) 17 | } 18 | \arguments{ 19 | \item{object}{An object from a tidymodels resampling or tuning function such 20 | as \code{\link[tune:fit_resamples]{tune::fit_resamples()}}, \code{\link[tune:tune_grid]{tune::tune_grid()}}, or similar. The object 21 | should have been produced in a way that the \code{.extracts} column contains the 22 | fitted workflow for each resample (see the Details below).} 23 | 24 | \item{...}{Not currently used.} 25 | 26 | \item{parameters}{A tibble of tuning parameter values that can be 27 | used to filter the predicted values before processing. This tibble should 28 | select a single set of hyper-parameter values from the tuning results. This is 29 | only required when a tuning object is passed to \code{object}.} 30 | } 31 | \value{ 32 | An object of class \code{"int_conformal_cv"} containing the information 33 | to create intervals. The \code{predict()} method is used to produce the intervals. 34 | } 35 | \description{ 36 | Nonparametric prediction intervals can be computed for fitted regression 37 | workflow objects using the CV+ conformal inference method described by 38 | Barber \emph{et al} (2018). 39 | } 40 | \details{ 41 | This function implements the CV+ method found in Section 3 of Barber \emph{et al} 42 | (2018). It uses the resampled model fits and their associated holdout 43 | residuals to make prediction intervals for regression models.
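For example, a control function along these lines (a sketch; the full example below shows the complete workflow) keeps the fitted workflow for each resample in the \code{.extracts} column:

\preformatted{
## Keep the holdout predictions and each fitted workflow:
ctrl <- tune::control_resamples(save_pred = TRUE, extract = I)
}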
44 | 45 | This function prepares the objects for the computations. The \code{\link[=predict]{predict()}} 46 | method computes the intervals for new data. 47 | 48 | This method was developed for V-fold cross-validation (no repeats). Interval 49 | coverage is unknown for any other resampling methods. The function will not 50 | stop the computations for other types of resamples, but we have no way of 51 | knowing whether the results are appropriate. 52 | } 53 | \examples{ 54 | \dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 55 | library(workflows) 56 | library(dplyr) 57 | library(parsnip) 58 | library(rsample) 59 | library(tune) 60 | library(modeldata) 61 | 62 | set.seed(2) 63 | sim_train <- sim_regression(200) 64 | sim_new <- sim_regression(5) |> select(-outcome) 65 | 66 | sim_rs <- vfold_cv(sim_train) 67 | 68 | # We'll use a neural network model 69 | mlp_spec <- 70 | mlp(hidden_units = 5, penalty = 0.01) |> 71 | set_mode("regression") 72 | 73 | # Use a control function that saves the predictions as well as the models. 74 | # Consider using the butcher package in the extracts function to have smaller 75 | # object sizes 76 | 77 | ctrl <- control_resamples(save_pred = TRUE, extract = I) 78 | 79 | set.seed(3) 80 | nnet_res <- 81 | mlp_spec |> 82 | fit_resamples(outcome ~ ., resamples = sim_rs, control = ctrl) 83 | 84 | nnet_int_obj <- int_conformal_cv(nnet_res) 85 | nnet_int_obj 86 | 87 | predict(nnet_int_obj, sim_new) 88 | \dontshow{\}) # examplesIf} 89 | } 90 | \references{ 91 | Rina Foygel Barber, Emmanuel J. Candès, Aaditya Ramdas, Ryan J. Tibshirani 92 | "Predictive inference with the jackknife+," \emph{The Annals of Statistics}, 93 | 49(1), 486-507, 2021 94 | } 95 | \seealso{ 96 | \code{\link[=predict.int_conformal_cv]{predict.int_conformal_cv()}} 97 | } 98 | -------------------------------------------------------------------------------- /man/int_conformal_quantile.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/conformal_infer_quantile.R 3 | \name{int_conformal_quantile} 4 | \alias{int_conformal_quantile} 5 | \alias{int_conformal_quantile.workflow} 6 | \title{Prediction intervals via conformal inference and quantile regression} 7 | \usage{ 8 | int_conformal_quantile(object, ...) 9 | 10 | \method{int_conformal_quantile}{workflow}(object, train_data, cal_data, level = 0.95, ...) 11 | } 12 | \arguments{ 13 | \item{object}{A fitted \code{\link[workflows:workflow]{workflows::workflow()}} object.} 14 | 15 | \item{...}{Options to pass to \code{\link[quantregForest:quantregForest]{quantregForest::quantregForest()}} (such as the 16 | number of trees).} 17 | 18 | \item{train_data, cal_data}{Data frames with the \emph{predictor and outcome data}. 19 | \code{train_data} should be the same data used to produce \code{object} and \code{cal_data} is 20 | used to produce predictions (and residuals). If the workflow used a recipe, 21 | these should be the data that were inputs to the recipe (and not the product 22 | of a recipe).} 23 | 24 | \item{level}{The confidence level for the intervals.} 25 | } 26 | \value{ 27 | An object of class \code{"int_conformal_quantile"} containing the 28 | information to create intervals (which includes \code{object}). 29 | The \code{predict()} method is used to produce the intervals. 
30 | } 31 | \description{ 32 | Nonparametric prediction intervals can be computed for fitted regression 33 | workflow objects using the split conformal inference method described by 34 | Romano \emph{et al} (2019). To compute quantiles, this function uses Quantile 35 | Random Forests instead of classic quantile regression. 36 | } 37 | \details{ 38 | Note that the significance level should be specified in this function 39 | (instead of the \code{predict()} method). 40 | 41 | \code{cal_data} should be large enough to get a good estimate of an extreme 42 | quantile (e.g., the 95th for a 95\% interval) and should not include rows that 43 | were in the original training set. 44 | 45 | Note that, because of the method used to construct the interval, it is 46 | possible that the prediction intervals will not include the predicted value. 47 | } 48 | \examples{ 49 | \dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "quantregForest"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 50 | library(workflows) 51 | library(dplyr) 52 | library(parsnip) 53 | library(rsample) 54 | library(tune) 55 | library(modeldata) 56 | 57 | set.seed(2) 58 | sim_train <- sim_regression(500) 59 | sim_cal <- sim_regression(200) 60 | sim_new <- sim_regression(5) |> select(-outcome) 61 | 62 | # We'll use a neural network model 63 | mlp_spec <- 64 | mlp(hidden_units = 5, penalty = 0.01) |> 65 | set_mode("regression") 66 | 67 | mlp_wflow <- 68 | workflow() |> 69 | add_model(mlp_spec) |> 70 | add_formula(outcome ~ .) 71 | 72 | mlp_fit <- fit(mlp_wflow, data = sim_train) 73 | 74 | mlp_int <- int_conformal_quantile(mlp_fit, sim_train, sim_cal, 75 | level = 0.90 76 | ) 77 | mlp_int 78 | 79 | predict(mlp_int, sim_new) 80 | \dontshow{\}) # examplesIf} 81 | } 82 | \references{ 83 | Romano, Yaniv, Evan Patterson, and Emmanuel Candes. "Conformalized quantile 84 | regression." \emph{Advances in neural information processing systems} 32 (2019). 85 | } 86 | \seealso{ 87 | \code{\link[=predict.int_conformal_quantile]{predict.int_conformal_quantile()}} 88 | } 89 | -------------------------------------------------------------------------------- /man/int_conformal_split.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/conformal_infer_split.R 3 | \name{int_conformal_split} 4 | \alias{int_conformal_split} 5 | \alias{int_conformal_split.default} 6 | \alias{int_conformal_split.workflow} 7 | \title{Prediction intervals via split conformal inference} 8 | \usage{ 9 | int_conformal_split(object, ...) 10 | 11 | \method{int_conformal_split}{default}(object, ...) 12 | 13 | \method{int_conformal_split}{workflow}(object, cal_data, ...) 14 | } 15 | \arguments{ 16 | \item{object}{A fitted \code{\link[workflows:workflow]{workflows::workflow()}} object.} 17 | 18 | \item{...}{Not currently used.} 19 | 20 | \item{cal_data}{A data frame with the \emph{original predictor and outcome data} 21 | used to produce predictions (and residuals). If the workflow used a recipe, 22 | this should be the data that were inputs to the recipe (and not the product 23 | of a recipe).} 24 | } 25 | \value{ 26 | An object of class \code{"int_conformal_split"} containing the 27 | information to create intervals (which includes \code{object}). 28 | The \code{predict()} method is used to produce the intervals.
29 | } 30 | \description{ 31 | Nonparametric prediction intervals can be computed for fitted regression 32 | workflow objects using the split conformal inference method described by 33 | Lei \emph{et al} (2018). 34 | } 35 | \details{ 36 | This function implements what is usually called "split conformal inference" 37 | (see Algorithm 1 in Lei \emph{et al} (2018)). 38 | 39 | This function prepares the statistics for the interval computations. The 40 | \code{\link[=predict]{predict()}} method computes the intervals for new data and the significance 41 | level is specified there. 42 | 43 | \code{cal_data} should be large enough to get a good estimate of an extreme 44 | quantile (e.g., the 95th for a 95\% interval) and should not include rows that 45 | were in the original training set. 46 | } 47 | \examples{ 48 | \dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "nnet"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 49 | library(workflows) 50 | library(dplyr) 51 | library(parsnip) 52 | library(rsample) 53 | library(tune) 54 | library(modeldata) 55 | 56 | set.seed(2) 57 | sim_train <- sim_regression(500) 58 | sim_cal <- sim_regression(200) 59 | sim_new <- sim_regression(5) |> select(-outcome) 60 | 61 | # We'll use a neural network model 62 | mlp_spec <- 63 | mlp(hidden_units = 5, penalty = 0.01) |> 64 | set_mode("regression") 65 | 66 | mlp_wflow <- 67 | workflow() |> 68 | add_model(mlp_spec) |> 69 | add_formula(outcome ~ .) 70 | 71 | mlp_fit <- fit(mlp_wflow, data = sim_train) 72 | 73 | mlp_int <- int_conformal_split(mlp_fit, sim_cal) 74 | mlp_int 75 | 76 | predict(mlp_int, sim_new, level = 0.90) 77 | \dontshow{\}) # examplesIf} 78 | } 79 | \references{ 80 | Lei, Jing, et al. "Distribution-free predictive inference for regression." 81 | \emph{Journal of the American Statistical Association} 113.523 (2018): 1094-1111. 82 | } 83 | \seealso{ 84 | \code{\link[=predict.int_conformal_split]{predict.int_conformal_split()}} 85 | } 86 | -------------------------------------------------------------------------------- /man/is_class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{is_class_pred} 4 | \alias{is_class_pred} 5 | \title{Test if an object inherits from \code{class_pred}} 6 | \usage{ 7 | is_class_pred(x) 8 | } 9 | \arguments{ 10 | \item{x}{An object.} 11 | } 12 | \description{ 13 | \code{is_class_pred()} checks if an object is a \code{class_pred} object. 14 | } 15 | \examples{ 16 | 17 | x <- class_pred(factor(1:5)) 18 | 19 | is_class_pred(x) 20 | 21 | } 22 | -------------------------------------------------------------------------------- /man/levels.class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{levels.class_pred} 4 | \alias{levels.class_pred} 5 | \title{Extract \code{class_pred} levels} 6 | \usage{ 7 | \method{levels}{class_pred}(x) 8 | } 9 | \arguments{ 10 | \item{x}{A \code{class_pred} object.} 11 | } 12 | \description{ 13 | The levels of a \code{class_pred} object do \emph{not} include the equivocal value.
14 | } 15 | \examples{ 16 | 17 | x <- class_pred(factor(1:5), which = 1) 18 | 19 | # notice that even though `1` is not in the `class_pred` vector, the 20 | # level remains from the original factor 21 | levels(x) 22 | 23 | } 24 | -------------------------------------------------------------------------------- /man/locate-equivocal.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{locate-equivocal} 4 | \alias{locate-equivocal} 5 | \alias{is_equivocal} 6 | \alias{which_equivocal} 7 | \alias{any_equivocal} 8 | \title{Locate equivocal values} 9 | \usage{ 10 | is_equivocal(x) 11 | 12 | which_equivocal(x) 13 | 14 | any_equivocal(x) 15 | } 16 | \arguments{ 17 | \item{x}{A \code{class_pred} object.} 18 | } 19 | \value{ 20 | \code{is_equivocal()} returns a logical vector the same length as \code{x} 21 | where \code{TRUE} means the value is equivocal. 22 | 23 | \code{which_equivocal()} returns an integer vector specifying the locations 24 | of the equivocal values. 25 | 26 | \code{any_equivocal()} returns \code{TRUE} if there are any equivocal values. 27 | } 28 | \description{ 29 | These functions provide multiple methods of checking for equivocal values, 30 | and finding their locations. 31 | } 32 | \examples{ 33 | 34 | x <- class_pred(factor(1:10), which = c(2, 5)) 35 | 36 | is_equivocal(x) 37 | 38 | which_equivocal(x) 39 | 40 | any_equivocal(x) 41 | 42 | } 43 | -------------------------------------------------------------------------------- /man/make_class_pred.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/make_class_pred.R 3 | \name{make_class_pred} 4 | \alias{make_class_pred} 5 | \alias{make_two_class_pred} 6 | \title{Create a \code{class_pred} vector from class probabilities} 7 | \usage{ 8 | make_class_pred(..., levels, ordered = FALSE, min_prob = 1/length(levels)) 9 | 10 | make_two_class_pred( 11 | estimate, 12 | levels, 13 | threshold = 0.5, 14 | ordered = FALSE, 15 | buffer = NULL 16 | ) 17 | } 18 | \arguments{ 19 | \item{...}{Numeric vectors corresponding to class probabilities. There should 20 | be one for each level in \code{levels}, and \emph{it is assumed that the vectors 21 | are in the same order as \code{levels}}.} 22 | 23 | \item{levels}{A character vector of class levels. The length should be the 24 | same as the number of selections made through \code{...}, or length \code{2} 25 | for \code{make_two_class_pred()}.} 26 | 27 | \item{ordered}{A single logical to determine if the levels should be regarded 28 | as ordered (in the order given). This results in a \code{class_pred} object 29 | that is flagged as ordered.} 30 | 31 | \item{min_prob}{A single numeric value. If any probabilities are less than 32 | this value (by row), the row is marked as \emph{equivocal}.} 33 | 34 | \item{estimate}{A single numeric vector corresponding to the class 35 | probabilities of the first level in \code{levels}.} 36 | 37 | \item{threshold}{A single numeric value for the threshold to call a row to 38 | be labeled as the first value of \code{levels}.} 39 | 40 | \item{buffer}{A numeric vector of length 1 or 2 for the buffer around 41 | \code{threshold} that defines the equivocal zone (i.e., \code{threshold - buffer[1]} to 42 | \code{threshold + buffer[2]}). A length 1 vector is recycled to length 2. 
The 43 | default, \code{NULL}, is interpreted as no equivocal zone.} 44 | } 45 | \value{ 46 | A vector of class \code{\link{class_pred}}. 47 | } 48 | \description{ 49 | These functions can be used to convert class probability estimates to 50 | \code{class_pred} objects with an optional equivocal zone. 51 | } 52 | \examples{ 53 | 54 | library(dplyr) 55 | 56 | good <- segment_logistic$.pred_good 57 | lvls <- levels(segment_logistic$Class) 58 | 59 | # Equivocal zone of .5 +/- .15 60 | make_two_class_pred(good, lvls, buffer = 0.15) 61 | 62 | # Equivocal zone of c(.5 - .05, .5 + .15) 63 | make_two_class_pred(good, lvls, buffer = c(0.05, 0.15)) 64 | 65 | # These functions are useful alongside dplyr::mutate() 66 | segment_logistic |> 67 | mutate( 68 | .class_pred = make_two_class_pred( 69 | estimate = .pred_good, 70 | levels = levels(Class), 71 | buffer = 0.15 72 | ) 73 | ) 74 | 75 | # Multi-class example 76 | # Note that we provide class probability columns in the same 77 | # order as the levels 78 | species_probs |> 79 | mutate( 80 | .class_pred = make_class_pred( 81 | .pred_bobcat, .pred_coyote, .pred_gray_fox, 82 | levels = levels(Species), 83 | min_prob = .5 84 | ) 85 | ) 86 | 87 | } 88 | -------------------------------------------------------------------------------- /man/predict.int_conformal_full.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/conformal_infer.R, R/conformal_infer_cv.R, 3 | % R/conformal_infer_quantile.R, R/conformal_infer_split.R 4 | \name{predict.int_conformal_full} 5 | \alias{predict.int_conformal_full} 6 | \alias{predict.int_conformal_cv} 7 | \alias{predict.int_conformal_quantile} 8 | \alias{predict.int_conformal_split} 9 | \title{Prediction intervals from conformal methods} 10 | \usage{ 11 | \method{predict}{int_conformal_full}(object, new_data, level = 0.95, ...) 12 | 13 | \method{predict}{int_conformal_cv}(object, new_data, level = 0.95, ...) 14 | 15 | \method{predict}{int_conformal_quantile}(object, new_data, ...) 16 | 17 | \method{predict}{int_conformal_split}(object, new_data, level = 0.95, ...) 18 | } 19 | \arguments{ 20 | \item{object}{An object produced by \code{\link[=predict.int_conformal_full]{predict.int_conformal_full()}}.} 21 | 22 | \item{new_data}{A data frame of predictors.} 23 | 24 | \item{level}{The confidence level for the intervals.} 25 | 26 | \item{...}{Not currently used.} 27 | } 28 | \value{ 29 | A tibble with columns \code{.pred_lower} and \code{.pred_upper}. If 30 | the computations for the prediction bound fail, a missing value is used. For 31 | objects produced by \code{\link[=int_conformal_cv]{int_conformal_cv()}}, an additional \code{.pred} column 32 | is also returned (see Details below). 33 | } 34 | \description{ 35 | Prediction intervals from conformal methods 36 | } 37 | \details{ 38 | For the CV+. estimator produced by \code{\link[=int_conformal_cv]{int_conformal_cv()}}, the intervals 39 | are centered around the mean of the predictions produced by the 40 | resample-specific model. For example, with 10-fold cross-validation, \code{.pred} 41 | is the average of the predictions from the 10 models produced by each fold. 42 | This may differ from the prediction generated from a model fit that was 43 | trained on the entire training set, especially if the training sets are 44 | small. 
45 | } 46 | \seealso{ 47 | \code{\link[=int_conformal_full]{int_conformal_full()}}, \code{\link[=int_conformal_cv]{int_conformal_cv()}} 48 | } 49 | -------------------------------------------------------------------------------- /man/probably-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/probably-package.R 3 | \docType{package} 4 | \name{probably-package} 5 | \alias{probably} 6 | \alias{probably-package} 7 | \title{probably: Tools for Post-Processing Predicted Values} 8 | \description{ 9 | \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} 10 | 11 | Models can be improved by post-processing class probabilities, by: recalibration, conversion to hard probabilities, assessment of equivocal zones, and other activities. 'probably' contains tools for conducting these operations as well as calibration tools and conformal inference techniques for regression models. 12 | } 13 | \seealso{ 14 | Useful links: 15 | \itemize{ 16 | \item \url{https://github.com/tidymodels/probably} 17 | \item \url{https://probably.tidymodels.org} 18 | \item Report bugs at \url{https://github.com/tidymodels/probably/issues} 19 | } 20 | 21 | } 22 | \author{ 23 | \strong{Maintainer}: Max Kuhn \email{max@posit.co} 24 | 25 | Authors: 26 | \itemize{ 27 | \item Davis Vaughan \email{davis@posit.co} 28 | \item Edgar Ruiz \email{edgar@posit.co} 29 | } 30 | 31 | Other contributors: 32 | \itemize{ 33 | \item Posit Software, PBC (03wc8by49) [copyright holder, funder] 34 | } 35 | 36 | } 37 | \keyword{internal} 38 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/reexports.R, R/vctrs-compat.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{fit} 7 | \alias{augment} 8 | \alias{required_pkgs} 9 | \alias{collect_metrics} 10 | \alias{collect_predictions} 11 | \alias{as.factor} 12 | \alias{as.ordered} 13 | \title{Objects exported from other packages} 14 | \keyword{internal} 15 | \description{ 16 | These objects are imported from other packages. Follow the links 17 | below to see their documentation. 18 | 19 | \describe{ 20 | \item{generics}{\code{\link[generics:coercion-factor]{as.factor}}, \code{\link[generics:coercion-factor]{as.ordered}}, \code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{required_pkgs}}} 21 | 22 | \item{tune}{\code{\link[tune:collect_predictions]{collect_metrics}}, \code{\link[tune]{collect_predictions}}} 23 | }} 24 | 25 | -------------------------------------------------------------------------------- /man/reportable_rate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/class-pred.R 3 | \name{reportable_rate} 4 | \alias{reportable_rate} 5 | \title{Calculate the reportable rate} 6 | \usage{ 7 | reportable_rate(x) 8 | } 9 | \arguments{ 10 | \item{x}{A \code{class_pred} object.} 11 | } 12 | \description{ 13 | The \emph{reportable rate} is defined as the percentage of class predictions 14 | that are \emph{not} equivocal. 15 | } 16 | \details{ 17 | The reportable rate is calculated as \code{(n_not_equivocal / n)}. 
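Equivalently (a small sketch), the same value can be computed directly from \code{is_equivocal()}:

\preformatted{
x <- class_pred(factor(1:5), which = c(1, 2))

## The proportion of predictions that are not equivocal:
mean(!is_equivocal(x)) # 0.6, matching reportable_rate(x)
}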
18 | } 19 | \examples{ 20 | 21 | x <- class_pred(factor(1:5), which = c(1, 2)) 22 | 23 | # 3 / 5 24 | reportable_rate(x) 25 | 26 | } 27 | -------------------------------------------------------------------------------- /man/required_pkgs.cal_object.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cal-estimate-beta.R, R/cal-estimate-linear.R, 3 | % R/cal-estimate-logistic.R, R/cal-estimate-multinom.R, R/cal-estimate-none.R, 4 | % R/cal-pkg-check.R 5 | \name{required_pkgs.cal_estimate_beta} 6 | \alias{required_pkgs.cal_estimate_beta} 7 | \alias{required_pkgs.cal_estimate_linear_spline} 8 | \alias{required_pkgs.cal_estimate_linear} 9 | \alias{required_pkgs.cal_estimate_logistic_spline} 10 | \alias{required_pkgs.cal_estimate_logistic} 11 | \alias{required_pkgs.cal_estimate_multinomial_spline} 12 | \alias{required_pkgs.cal_estimate_multinomial} 13 | \alias{required_pkgs.cal_estimate_none} 14 | \alias{required_pkgs.cal_object} 15 | \title{S3 methods to track which additional packages are needed for specific 16 | calibrations} 17 | \usage{ 18 | \method{required_pkgs}{cal_estimate_beta}(x, ...) 19 | 20 | \method{required_pkgs}{cal_estimate_linear_spline}(x, ...) 21 | 22 | \method{required_pkgs}{cal_estimate_linear}(x, ...) 23 | 24 | \method{required_pkgs}{cal_estimate_logistic_spline}(x, ...) 25 | 26 | \method{required_pkgs}{cal_estimate_logistic}(x, ...) 27 | 28 | \method{required_pkgs}{cal_estimate_multinomial_spline}(x, ...) 29 | 30 | \method{required_pkgs}{cal_estimate_multinomial}(x, ...) 31 | 32 | \method{required_pkgs}{cal_estimate_multinomial}(x, ...) 33 | 34 | \method{required_pkgs}{cal_estimate_none}(x, ...) 35 | 36 | \method{required_pkgs}{cal_object}(x, ...) 37 | } 38 | \arguments{ 39 | \item{x}{A calibration object} 40 | 41 | \item{...}{Other arguments passed to methods} 42 | } 43 | \description{ 44 | S3 methods to track which additional packages are needed for specific 45 | calibrations 46 | } 47 | \keyword{internal} 48 | -------------------------------------------------------------------------------- /man/rmd/parallel_intervals.Rmd: -------------------------------------------------------------------------------- 1 | ## Speed 2 | 3 | The time it takes to compute the intervals depends on the training set size, search parameters (i.e., convergence criterion, number of iterations), the grid size, and the number of worker processes that are used. For the last item, the computations can be parallelized using the future and furrr packages. 4 | 5 | To use parallelism, the [future::plan()] function can be invoked to create a parallel backend. For example, let's make an initial workflow: 6 | 7 | 8 | ```{r, include = FALSE} 9 | library(tidymodels) 10 | library(probably) 11 | tidymodels_prefer() 12 | ``` 13 | ```{r} 14 | library(tidymodels) 15 | library(probably) 16 | library(future) 17 | 18 | tidymodels_prefer() 19 | 20 | ## Make a fitted workflow from some simulated data: 21 | set.seed(121) 22 | train_dat <- sim_regression(200) 23 | new_dat <- sim_regression( 5) |> select(-outcome) 24 | 25 | lm_fit <- 26 | workflow() |> 27 | add_model(linear_reg()) |> 28 | add_formula(outcome ~ .) 
|> 29 | fit(data = train_dat) 30 | 31 | # Create the object to be used to make prediction intervals 32 | lm_conform <- int_conformal_full(lm_fit, train_dat) 33 | ``` 34 | 35 | We'll use a `"multisession"` parallel processing plan to compute the intervals for the five new samples in parallel: 36 | 37 | ```{r} 38 | plan("multisession") 39 | 40 | # This is run in parallel: 41 | predict(lm_conform, new_dat) 42 | ``` 43 | 44 | Simulations show slightly sub-linear speed-ups when using parallel processing to compute the row-wise intervals. 45 | 46 | For comparison, here are the parametric intervals: 47 | 48 | ```{r} 49 | predict(lm_fit, new_dat, type = "pred_int") 50 | ``` 51 | -------------------------------------------------------------------------------- /man/segment_naive_bayes.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{segment_naive_bayes} 5 | \alias{segment_naive_bayes} 6 | \alias{segment_logistic} 7 | \title{Image segmentation predictions} 8 | \source{ 9 | Hill, LaPan, Li and Haney (2007). Impact of image segmentation on 10 | high-content screening data quality for SK-BR-3 cells, \emph{BMC 11 | Bioinformatics}, Vol. 8, pg. 340, 12 | \url{https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-8-340}. 13 | } 14 | \value{ 15 | \item{segment_naive_bayes,segment_logistic}{a tibble} 16 | } 17 | \description{ 18 | Image segmentation predictions 19 | } 20 | \details{ 21 | These objects contain test set predictions for the cell segmentation 22 | data from Hill, LaPan, Li and Haney (2007). Each data frame contains the results 23 | from a different model (naive Bayes and logistic regression, respectively). 24 | } 25 | \examples{ 26 | data(segment_naive_bayes) 27 | data(segment_logistic) 28 | } 29 | \keyword{datasets} 30 | -------------------------------------------------------------------------------- /man/species_probs.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{species_probs} 5 | \alias{species_probs} 6 | \title{Predictions on animal species} 7 | \source{ 8 | Reid, R. E. B. (2015). A morphometric modeling approach to 9 | distinguishing among bobcat, coyote and gray fox scats. \emph{Wildlife 10 | Biology}, 21(5), 254-262 11 | } 12 | \value{ 13 | \item{species_probs}{a tibble} 14 | } 15 | \description{ 16 | Predictions on animal species 17 | } 18 | \details{ 19 | These data are holdout predictions from resampling for the animal 20 | scat data of Reid (2015) based on a C5.0 classification model. 21 | } 22 | \examples{ 23 | data(species_probs) 24 | str(species_probs) 25 | } 26 | \keyword{datasets} 27 | -------------------------------------------------------------------------------- /man/threshold_perf.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/threshold_perf.R 3 | \name{threshold_perf} 4 | \alias{threshold_perf} 5 | \alias{threshold_perf.data.frame} 6 | \title{Generate performance metrics across probability thresholds} 7 | \usage{ 8 | threshold_perf(.data, ...) 9 | 10 | \method{threshold_perf}{data.frame}( 11 | .data, 12 | truth, 13 | estimate, 14 | thresholds = NULL, 15 | metrics = NULL, 16 | na_rm = TRUE, 17 | event_level = "first", 18 | ...
19 | ) 20 | } 21 | \arguments{ 22 | \item{.data}{A tibble, potentially grouped.} 23 | 24 | \item{...}{Currently unused.} 25 | 26 | \item{truth}{The column identifier for the true two-class results 27 | (that is, a factor). This should be an unquoted column name.} 28 | 29 | \item{estimate}{The column identifier for the predicted class probabilities 30 | (that is, a numeric). This should be an unquoted column name.} 31 | 32 | \item{thresholds}{A numeric vector of values for the probability 33 | threshold. If unspecified, a series 34 | of values between 0.5 and 1.0 is used. \strong{Note}: if this 35 | argument is used, it must be named.} 36 | 37 | \item{metrics}{Either \code{NULL} or a \code{\link[yardstick:metric_set]{yardstick::metric_set()}} with a list of 38 | performance metrics to calculate. The metrics should all be oriented towards 39 | hard class predictions (e.g. \code{\link[yardstick:sens]{yardstick::sensitivity()}}, 40 | \code{\link[yardstick:accuracy]{yardstick::accuracy()}}, \code{\link[yardstick:recall]{yardstick::recall()}}, etc.) and not 41 | class probabilities. A set of default metrics is used when \code{NULL} (see 42 | Details below).} 43 | 44 | \item{na_rm}{A single logical: should missing data be removed?} 45 | 46 | \item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify 47 | which level of \code{truth} to consider as the "event".} 48 | } 49 | \value{ 50 | A tibble with columns: \code{.threshold}, \code{.estimator}, \code{.metric}, 51 | \code{.estimate}, and any existing groups. 52 | } 53 | \description{ 54 | \code{threshold_perf()} can take a set of class probability predictions 55 | and determine performance characteristics across different values 56 | of the probability threshold and any existing groups. 57 | } 58 | \details{ 59 | Note that the global option \code{yardstick.event_first} will be 60 | used to determine which level is the event of interest. For more details, 61 | see the Relevant level section of \code{\link[yardstick:sens]{yardstick::sens()}}. 62 | 63 | The default calculated metrics are: 64 | \itemize{ 65 | \item \code{\link[yardstick:j_index]{yardstick::j_index()}} 66 | \item \code{\link[yardstick:sens]{yardstick::sens()}} 67 | \item \code{\link[yardstick:spec]{yardstick::spec()}} 68 | \item \code{distance = (1 - sens) ^ 2 + (1 - spec) ^ 2} 69 | } 70 | 71 | If a custom metric is passed that does not compute sensitivity and 72 | specificity, the distance metric is not computed.
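For example, a threshold that yields a sensitivity of 0.90 and a specificity of 0.80 has \code{distance = (1 - 0.9)^2 + (1 - 0.8)^2 = 0.05}; smaller values are better, and a distance of zero corresponds to perfect sensitivity and specificity.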
73 | } 74 | \examples{ 75 | library(dplyr) 76 | data("segment_logistic") 77 | 78 | # Set the threshold to 0.6 79 | # > 0.6 = good 80 | # < 0.6 = poor 81 | threshold_perf(segment_logistic, Class, .pred_good, thresholds = 0.6) 82 | 83 | # Set the threshold to multiple values 84 | thresholds <- seq(0.5, 0.9, by = 0.1) 85 | 86 | segment_logistic |> 87 | threshold_perf(Class, .pred_good, thresholds) 88 | 89 | # --------------------------------------------------------------------------- 90 | 91 | # It works with grouped data frames as well 92 | # Let's mock some resampled data 93 | resamples <- 5 94 | 95 | mock_resamples <- resamples |> 96 | replicate( 97 | expr = sample_n(segment_logistic, 100, replace = TRUE), 98 | simplify = FALSE 99 | ) |> 100 | bind_rows(.id = "resample") 101 | 102 | resampled_threshold_perf <- mock_resamples |> 103 | group_by(resample) |> 104 | threshold_perf(Class, .pred_good, thresholds) 105 | 106 | resampled_threshold_perf 107 | 108 | # Average over the resamples 109 | resampled_threshold_perf |> 110 | group_by(.metric, .threshold) |> 111 | summarise(.estimate = mean(.estimate)) 112 | 113 | } 114 | -------------------------------------------------------------------------------- /probably.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | ProjectId: cdf2db78-fccc-43cc-9265-e5f5194ef54c 3 | 4 | RestoreWorkspace: No 5 | SaveWorkspace: No 6 | AlwaysSaveHistory: Default 7 | 8 | EnableCodeIndexing: Yes 9 | UseSpacesForTab: Yes 10 | NumSpacesForTab: 2 11 | Encoding: UTF-8 12 | 13 | RnwWeave: Sweave 14 | LaTeX: pdfLaTeX 15 | 16 | AutoAppendNewline: Yes 17 | StripTrailingWhitespace: Yes 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /revdep/.gitignore: -------------------------------------------------------------------------------- 1 | checks 2 | library 3 | checks.noindex 4 | library.noindex 5 | data.sqlite 6 | *.html 7 | cloud.noindex 8 | -------------------------------------------------------------------------------- /revdep/README.md: -------------------------------------------------------------------------------- 1 | # Platform 2 | 3 | |field |value | 4 | |:--------|:---------------------------------------------------------------------------------| 5 | |version |R version 4.1.3 (2022-03-10) | 6 | |os |macOS Monterey 12.5.1 | 7 | |system |x86_64, darwin17.0 | 8 | |ui |RStudio | 9 | |language |(EN) | 10 | |collate |en_US.UTF-8 | 11 | |ctype |en_US.UTF-8 | 12 | |tz |America/New_York | 13 | |date |2022-08-29 | 14 | |rstudio |2022.07.1+554 Spotted Wakerobin (desktop) | 15 | |pandoc |2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown) | 16 | 17 | # Dependencies 18 | 19 | |package |old |new |Δ | 20 | |:----------|:-----|:----------|:--| 21 | |probably |0.0.6 |0.0.6.9000 |* | 22 | |cli |3.3.0 |3.3.0 | | 23 | |dplyr |1.0.9 |1.0.9 | | 24 | |ellipsis |0.3.2 |0.3.2 | | 25 | |fansi |1.0.3 |1.0.3 | | 26 | |generics |0.1.3 |0.1.3 | | 27 | |glue |1.6.2 |1.6.2 | | 28 | |hardhat |1.2.0 |1.2.0 | | 29 | |lifecycle |1.0.1 |1.0.1 | | 30 | |magrittr |2.0.3 |2.0.3 | | 31 | |pillar |1.8.1 |1.8.1 | | 32 | |pkgconfig |2.0.3 |2.0.3 | | 33 | |purrr |0.3.4 |0.3.4 | | 34 | |R6 |2.5.1 |2.5.1 | | 35 | |rlang |1.0.4 |1.0.4 | | 36 | |tibble |3.1.8 |3.1.8 | | 37 | |tidyselect |1.1.2 |1.1.2 | | 38 | |utf8 |1.2.2 |1.2.2 | | 39 
| |vctrs |0.4.1 |0.4.1 | | 40 | |yardstick |1.0.0 |1.0.0 | | 41 | 42 | # Revdeps 43 | 44 | -------------------------------------------------------------------------------- /revdep/cran.md: -------------------------------------------------------------------------------- 1 | ## revdepcheck results 2 | 3 | We checked 1 reverse dependency, comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 0 packages 7 | 8 | -------------------------------------------------------------------------------- /revdep/email.yml: -------------------------------------------------------------------------------- 1 | release_date: ??? 2 | rel_release_date: ??? 3 | my_news_url: ??? 4 | release_version: ??? 5 | release_details: ??? 6 | -------------------------------------------------------------------------------- /revdep/failures.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /revdep/problems.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | # This file is part of the standard setup for testthat. 2 | # It is recommended that you do not modify it. 3 | # 4 | # Where should you do additional test configuration? 5 | # Learn more about the roles of various files in: 6 | # * https://r-pkgs.org/tests.html 7 | # * https://testthat.r-lib.org/reference/test_package.html#special-files 8 | 9 | library(testthat) 10 | library(probably) 11 | 12 | test_check("probably") 13 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/bound-prediction.md: -------------------------------------------------------------------------------- 1 | # lower_limit bounds for numeric predictions 2 | 3 | Code 4 | bound_prediction(modeldata::solubility_test, lower_limit = 2) 5 | Condition 6 | Error in `bound_prediction()`: 7 | ! The argument `x` should have a column named `.pred`. 8 | 9 | --- 10 | 11 | Code 12 | bound_prediction(mutate(modeldata::solubility_test, .pred = format(prediction)), 13 | lower_limit = 2) 14 | Condition 15 | Error in `bound_prediction()`: 16 | ! Column `.pred` should be numeric. 17 | 18 | --- 19 | 20 | Code 21 | bound_prediction(sol, lower_limit = tune2()) 22 | Condition 23 | Error in `bound_prediction()`: 24 | ! `lower_limit` must be a number or `NA`, not a call. 25 | 26 | --- 27 | 28 | Code 29 | bound_prediction(as.matrix(sol), lower_limit = 1) 30 | Condition 31 | Error in `bound_prediction()`: 32 | ! `x` must be a data frame, not a double matrix. 33 | 34 | # upper_limit bounds for numeric predictions 35 | 36 | Code 37 | bound_prediction(modeldata::solubility_test, lower_limit = 2) 38 | Condition 39 | Error in `bound_prediction()`: 40 | ! The argument `x` should have a column named `.pred`. 41 | 42 | --- 43 | 44 | Code 45 | bound_prediction(mutate(modeldata::solubility_test, .pred = format(prediction)), 46 | lower_limit = 2) 47 | Condition 48 | Error in `bound_prediction()`: 49 | ! Column `.pred` should be numeric. 50 | 51 | --- 52 | 53 | Code 54 | bound_prediction(sol, upper_limit = tune2()) 55 | Condition 56 | Error in `bound_prediction()`: 57 | ! `upper_limit` must be a number or `NA`, not a call.
58 | 59 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-estimate-beta.md: -------------------------------------------------------------------------------- 1 | # Beta estimates work - data.frame 2 | 3 | Code 4 | print(sl_beta) 5 | Message 6 | 7 | -- Probability Calibration 8 | Method: Beta calibration 9 | Type: Binary 10 | Source class: Data Frame 11 | Data points: 1,010 12 | Truth variable: `Class` 13 | Estimate variables: 14 | `.pred_good` ==> good 15 | `.pred_poor` ==> poor 16 | 17 | --- 18 | 19 | Code 20 | print(sl_beta_group) 21 | Message 22 | 23 | -- Probability Calibration 24 | Method: Beta calibration 25 | Type: Binary 26 | Source class: Data Frame 27 | Data points: 1,010, split in 2 groups 28 | Truth variable: `Class` 29 | Estimate variables: 30 | `.pred_good` ==> good 31 | `.pred_poor` ==> poor 32 | 33 | --- 34 | 35 | x `.by` cannot select more than one column. 36 | i The following columns were selected: 37 | i group1 and group2 38 | 39 | # Beta estimates work - tune_results 40 | 41 | Code 42 | print(tl_beta) 43 | Message 44 | 45 | -- Probability Calibration 46 | Method: Beta calibration 47 | Type: Binary 48 | Source class: Tune Results 49 | Data points: 4,000, split in 8 groups 50 | Truth variable: `class` 51 | Estimate variables: 52 | `.pred_class_1` ==> class_1 53 | `.pred_class_2` ==> class_2 54 | 55 | --- 56 | 57 | Code 58 | print(mtnl_beta) 59 | Message 60 | 61 | -- Probability Calibration 62 | Method: Beta calibration 63 | Type: Multiclass (1 v All) 64 | Source class: Tune Results 65 | Data points: 5,000, split in 10 groups 66 | Truth variable: `class` 67 | Estimate variables: 68 | `.pred_one` ==> one 69 | `.pred_two` ==> two 70 | `.pred_three` ==> three 71 | 72 | # Beta estimates errors - grouped_df 73 | 74 | x This function does not work with grouped data frames. 75 | i Apply `dplyr::ungroup()` and use the `.by` argument. 76 | 77 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-estimate-logistic.md: -------------------------------------------------------------------------------- 1 | # Logistic estimates work - data.frame 2 | 3 | Code 4 | print(sl_logistic) 5 | Message 6 | 7 | -- Probability Calibration 8 | Method: Logistic regression calibration 9 | Type: Binary 10 | Source class: Data Frame 11 | Data points: 1,010 12 | Truth variable: `Class` 13 | Estimate variables: 14 | `.pred_good` ==> good 15 | `.pred_poor` ==> poor 16 | 17 | --- 18 | 19 | The selectors in `estimate` resolves to 1 values (".pred_poor") but there are 2 class levels ("good" and "poor"). 20 | 21 | --- 22 | 23 | The `truth` column has 4 levels ("VF", "F", "M", and "L"), but only two-class factors are allowed for this calibration method. 24 | 25 | --- 26 | 27 | Code 28 | print(sl_logistic_group) 29 | Message 30 | 31 | -- Probability Calibration 32 | Method: Logistic regression calibration 33 | Type: Binary 34 | Source class: Data Frame 35 | Data points: 1,010, split in 2 groups 36 | Truth variable: `Class` 37 | Estimate variables: 38 | `.pred_good` ==> good 39 | `.pred_poor` ==> poor 40 | 41 | --- 42 | 43 | x `.by` cannot select more than one column. 
44 | i The following columns were selected: 45 | i group1 and group2 46 | 47 | # Logistic estimates work - tune_results 48 | 49 | Code 50 | print(tl_logistic) 51 | Message 52 | 53 | -- Probability Calibration 54 | Method: Logistic regression calibration 55 | Type: Binary 56 | Source class: Tune Results 57 | Data points: 4,000, split in 8 groups 58 | Truth variable: `class` 59 | Estimate variables: 60 | `.pred_class_1` ==> class_1 61 | `.pred_class_2` ==> class_2 62 | 63 | --- 64 | 65 | The `truth` column has 3 levels ("one", "two", and "three"), but only two-class factors are allowed for this calibration method. 66 | 67 | # Logistic estimates errors - grouped_df 68 | 69 | x This function does not work with grouped data frames. 70 | i Apply `dplyr::ungroup()` and use the `.by` argument. 71 | 72 | # Logistic spline estimates work - data.frame 73 | 74 | Code 75 | print(sl_gam) 76 | Message 77 | 78 | -- Probability Calibration 79 | Method: Generalized additive model calibration 80 | Type: Binary 81 | Source class: Data Frame 82 | Data points: 1,010 83 | Truth variable: `Class` 84 | Estimate variables: 85 | `.pred_good` ==> good 86 | `.pred_poor` ==> poor 87 | 88 | --- 89 | 90 | Code 91 | print(sl_gam_group) 92 | Message 93 | 94 | -- Probability Calibration 95 | Method: Generalized additive model calibration 96 | Type: Binary 97 | Source class: Data Frame 98 | Data points: 1,010, split in 2 groups 99 | Truth variable: `Class` 100 | Estimate variables: 101 | `.pred_good` ==> good 102 | `.pred_poor` ==> poor 103 | 104 | --- 105 | 106 | x `.by` cannot select more than one column. 107 | i The following columns were selected: 108 | i group1 and group2 109 | 110 | # Logistic spline estimates work - tune_results 111 | 112 | Code 113 | print(tl_gam) 114 | Message 115 | 116 | -- Probability Calibration 117 | Method: Generalized additive model calibration 118 | Type: Binary 119 | Source class: Tune Results 120 | Data points: 4,000, split in 8 groups 121 | Truth variable: `class` 122 | Estimate variables: 123 | `.pred_class_1` ==> class_1 124 | `.pred_class_2` ==> class_2 125 | 126 | # Logistic spline switches to linear if too few unique 127 | 128 | Code 129 | sl_gam <- cal_estimate_logistic(segment_logistic, Class, smooth = TRUE) 130 | Condition 131 | Warning: 132 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 133 | 134 | --- 135 | 136 | Code 137 | sl_gam <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = TRUE) 138 | Condition 139 | Warning: 140 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 141 | Warning: 142 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 
143 | 144 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-estimate-multinomial.md: -------------------------------------------------------------------------------- 1 | # Multinomial estimates work - data.frame 2 | 3 | Code 4 | print(sp_multi) 5 | Message 6 | 7 | -- Probability Calibration 8 | Method: Multinomial regression calibration 9 | Type: Multiclass 10 | Source class: Data Frame 11 | Data points: 110 12 | Truth variable: `Species` 13 | Estimate variables: 14 | `.pred_bobcat` ==> bobcat 15 | `.pred_coyote` ==> coyote 16 | `.pred_gray_fox` ==> gray_fox 17 | 18 | --- 19 | 20 | Code 21 | print(sp_smth_multi) 22 | Message 23 | 24 | -- Probability Calibration 25 | Method: Generalized additive model calibration 26 | Type: Multiclass 27 | Source class: Data Frame 28 | Data points: 110 29 | Truth variable: `Species` 30 | Estimate variables: 31 | `.pred_bobcat` ==> bobcat 32 | `.pred_coyote` ==> coyote 33 | `.pred_gray_fox` ==> gray_fox 34 | 35 | --- 36 | 37 | Code 38 | print(sl_multi_group) 39 | Message 40 | 41 | -- Probability Calibration 42 | Method: Multinomial regression calibration 43 | Type: Multiclass 44 | Source class: Data Frame 45 | Data points: 110, split in 2 groups 46 | Truth variable: `Species` 47 | Estimate variables: 48 | `.pred_bobcat` ==> bobcat 49 | `.pred_coyote` ==> coyote 50 | `.pred_gray_fox` ==> gray_fox 51 | 52 | --- 53 | 54 | x `.by` cannot select more than one column. 55 | i The following columns were selected: 56 | i group1 and group2 57 | 58 | # Multinomial estimates work - tune_results 59 | 60 | Code 61 | print(tl_multi) 62 | Message 63 | 64 | -- Probability Calibration 65 | Method: Multinomial regression calibration 66 | Type: Multiclass 67 | Source class: Tune Results 68 | Data points: 5,000, split in 10 groups 69 | Truth variable: `class` 70 | Estimate variables: 71 | `.pred_one` ==> one 72 | `.pred_two` ==> two 73 | `.pred_three` ==> three 74 | 75 | --- 76 | 77 | Code 78 | print(tl_smth_multi) 79 | Message 80 | 81 | -- Probability Calibration 82 | Method: Generalized additive model calibration 83 | Type: Multiclass 84 | Source class: Tune Results 85 | Data points: 5,000, split in 10 groups 86 | Truth variable: `class` 87 | Estimate variables: 88 | `.pred_one` ==> one 89 | `.pred_two` ==> two 90 | `.pred_three` ==> three 91 | 92 | # Multinomial estimates errors - grouped_df 93 | 94 | x This function does not work with grouped data frames. 95 | i Apply `dplyr::ungroup()` and use the `.by` argument. 96 | 97 | # Multinomial spline switches to linear if too few unique 98 | 99 | Code 100 | sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE) 101 | Condition 102 | Warning: 103 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 104 | 105 | --- 106 | 107 | Code 108 | sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, 109 | smooth = TRUE) 110 | Condition 111 | Warning: 112 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 113 | Warning: 114 | Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. 
115 | 116 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-estimate-none.md: -------------------------------------------------------------------------------- 1 | # no calibration works - data.frame 2 | 3 | Code 4 | print(nope_reg) 5 | Message 6 | 7 | -- Regression Calibration 8 | Method: No calibration 9 | Source class: Data Frame 10 | Data points: 2,000 11 | Truth variable: `outcome` 12 | Estimate variable: `.pred` 13 | 14 | --- 15 | 16 | Code 17 | print(nope_reg_group) 18 | Message 19 | 20 | -- Regression Calibration 21 | Method: No calibration 22 | Source class: Data Frame 23 | Data points: 2,000, split in 2 groups 24 | Truth variable: `outcome` 25 | Estimate variable: `.pred` 26 | 27 | --- 28 | 29 | x `.by` cannot select more than one column. 30 | i The following columns were selected: 31 | i group1 and group2 32 | 33 | --- 34 | 35 | `...` must be empty. 36 | x Problematic argument: 37 | * smooth = TRUE 38 | 39 | --- 40 | 41 | Code 42 | print(nope_binary) 43 | Message 44 | 45 | -- Probability Calibration 46 | Method: No calibration 47 | Type: Binary 48 | Source class: Data Frame 49 | Data points: 1,010 50 | Truth variable: `Class` 51 | Estimate variables: 52 | `.pred_good` ==> good 53 | `.pred_poor` ==> poor 54 | 55 | --- 56 | 57 | The selectors in `estimate` resolves to 1 values (".pred_poor") but there are 2 class levels ("good" and "poor"). 58 | 59 | --- 60 | 61 | x `.by` cannot select more than one column. 62 | i The following columns were selected: 63 | i group1 and group2 64 | 65 | --- 66 | 67 | Code 68 | print(nope_multi) 69 | Message 70 | 71 | -- Probability Calibration 72 | Method: No calibration 73 | Type: Multiclass 74 | Source class: Data Frame 75 | Data points: 110 76 | Truth variable: `Species` 77 | Estimate variables: 78 | `.pred_bobcat` ==> bobcat 79 | `.pred_coyote` ==> coyote 80 | `.pred_gray_fox` ==> gray_fox 81 | 82 | --- 83 | 84 | x `.by` cannot select more than one column. 85 | i The following columns were selected: 86 | i group1 and group2 87 | 88 | # no calibration works - tune_results 89 | 90 | Code 91 | print(nope_reg) 92 | Message 93 | 94 | -- Regression Calibration 95 | Method: No calibration 96 | Source class: Tune Results 97 | Data points: 750, split in 10 groups 98 | Truth variable: `outcome` 99 | Estimate variable: `.pred` 100 | 101 | --- 102 | 103 | `...` must be empty. 104 | x Problematic argument: 105 | * do_something = FALSE 106 | 107 | --- 108 | 109 | Code 110 | print(nope_binary) 111 | Message 112 | 113 | -- Probability Calibration 114 | Method: No calibration 115 | Type: Binary 116 | Source class: Tune Results 117 | Data points: 4,000, split in 8 groups 118 | Truth variable: `class` 119 | Estimate variables: 120 | `.pred_class_1` ==> class_1 121 | `.pred_class_2` ==> class_2 122 | 123 | --- 124 | 125 | Code 126 | print(nope_multi) 127 | Message 128 | 129 | -- Probability Calibration 130 | Method: No calibration 131 | Type: Multiclass 132 | Source class: Tune Results 133 | Data points: 5,000, split in 10 groups 134 | Truth variable: `class` 135 | Estimate variables: 136 | `.pred_one` ==> one 137 | `.pred_two` ==> two 138 | `.pred_three` ==> three 139 | 140 | # no calibration fails - grouped_df 141 | 142 | x This function does not work with grouped data frames. 143 | i Apply `dplyr::ungroup()` and use the `.by` argument. 
144 | 145 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-plot.md: -------------------------------------------------------------------------------- 1 | # Binary breaks functions work with group argument 2 | 3 | x `.by` cannot select more than one column. 4 | i The following 2 columns were selected: 5 | i group1 and group2 6 | 7 | # breaks plot function errors - grouped_df 8 | 9 | x This function does not work with grouped data frames. 10 | i Apply `dplyr::ungroup()` and use the `.by` argument. 11 | 12 | # Binary logistic functions work with group argument 13 | 14 | x `.by` cannot select more than one column. 15 | i The following 2 columns were selected: 16 | i group1 and group2 17 | 18 | # logistic plot function errors - grouped_df 19 | 20 | x This function does not work with grouped data frames. 21 | i Apply `dplyr::ungroup()` and use the `.by` argument. 22 | 23 | # windowed plot function errors - grouped_df 24 | 25 | x This function does not work with grouped data frames. 26 | i Apply `dplyr::ungroup()` and use the `.by` argument. 27 | 28 | # Event level handling works 29 | 30 | i In argument: `res = map(...)`. 31 | Caused by error in `map()`: 32 | i In index: 1. 33 | Caused by error: 34 | ! Invalid `event_level` entry: invalid. Valid entries are "first", "second", or "auto". 35 | 36 | # regression plot function errors - grouped_df 37 | 38 | x This function does not work with grouped data frames. 39 | i Apply `dplyr::ungroup()` and use the `.by` argument. 40 | 41 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/cal-validate.md: -------------------------------------------------------------------------------- 1 | # Logistic validation with data frame input 2 | 3 | There are no saved prediction columns to collect. 4 | 5 | # Validation without calibration with data frame input 6 | 7 | There are no saved prediction columns to collect. 8 | 9 | # Isotonic classification validation with `fit_resamples` 10 | 11 | `truth` is automatically set when this type of object is used. 12 | 13 | # validation functions error with tune_results input 14 | 15 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 16 | i Not an <rset> object. 17 | 18 | --- 19 | 20 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 21 | i Not an <rset> object. 22 | 23 | --- 24 | 25 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 26 | i Not an <rset> object. 27 | 28 | --- 29 | 30 | no applicable method for 'cal_validate_linear' applied to an object of class "c('tune_results', 'tbl_df', 'tbl', 'data.frame')" 31 | 32 | --- 33 | 34 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 35 | i Not an <rset> object. 36 | 37 | --- 38 | 39 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 40 | i Not an <rset> object. 41 | 42 | --- 43 | 44 | This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column. 45 | i Not an <rset> object.
46 | 47 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-pred.md: -------------------------------------------------------------------------------- 1 | # slicing 2 | 3 | Code 4 | manual_creation_eq[1:6] 5 | Condition 6 | Error in `vec_slice()`: 7 | ! Can't subset elements past the end. 8 | i Location 6 doesn't exist. 9 | i There are only 5 elements. 10 | 11 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/conformal-intervals-quantile.md: -------------------------------------------------------------------------------- 1 | # split conformal quantile intervals 2 | 3 | no applicable method for 'int_conformal_quantile' applied to an object of class "lm" 4 | 5 | --- 6 | 7 | The required column "predictor_01" is missing. 8 | 9 | --- 10 | 11 | The required column "predictor_01" is missing. 12 | 13 | --- 14 | 15 | `...` must be empty. 16 | x Problematic argument: 17 | * level = 0.9 18 | 19 | --- 20 | 21 | Code 22 | lm_int 23 | Output 24 | Split Conformal inference via Quantile Regression 25 | preprocessor: formula 26 | model: linear_reg (engine = lm) 27 | calibration set size: 100 28 | confidence level: 0.9 29 | 30 | Use `predict(object, new_data)` to compute prediction intervals 31 | 32 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/conformal-intervals-split.md: -------------------------------------------------------------------------------- 1 | # split conformal intervals 2 | 3 | No known `int_conformal_split()` methods for this type of object. 4 | 5 | --- 6 | 7 | The required column "outcome" is missing. 8 | 9 | --- 10 | 11 | The required column "predictor_01" is missing. 12 | 13 | --- 14 | 15 | `...` must be empty. 16 | x Problematic argument: 17 | * level = 0.1 18 | 19 | --- 20 | 21 | `...` must be empty. 22 | x Problematic argument: 23 | * potato = 3 24 | 25 | --- 26 | 27 | Code 28 | lm_int 29 | Output 30 | Split Conformal inference 31 | preprocessor: formula 32 | model: linear_reg (engine = lm) 33 | calibration set size: 100 34 | 35 | Use `predict(object, new_data, level)` to compute prediction intervals 36 | 37 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/make-class-pred.md: -------------------------------------------------------------------------------- 1 | # fails with different length `...` 2 | 3 | All vectors passed to `...` must be of the same length. 4 | 5 | # fails with different type `...` 6 | 7 | x The index supplied to `...` are not numeric: 8 | i 2 9 | 10 | # fails with different length `...` VS levels 11 | 12 | `levels` must be a character vector with the same length as the number of vectors passed to `...`. 13 | 14 | # validates type of `levels` (#42) 15 | 16 | `levels` must be a character vector of length 2. 17 | 18 | --- 19 | 20 | `levels` must be a character vector with the same length as the number of vectors passed to `...`. 21 | 22 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/threshold-perf.md: -------------------------------------------------------------------------------- 1 | # custom metrics 2 | 3 | All metrics must be of type 'class_metric' (e.g. 
`sensitivity()`, etc) 4 | 5 | --- 6 | 7 | Code 8 | dplyr::count(threshold_perf(segment_logistic, Class, .pred_good, metrics = cls_met_good), 9 | .metric) 10 | Output 11 | # A tibble: 5 x 2 12 | .metric n 13 | <chr> <int> 14 | 1 accuracy 21 15 | 2 distance 21 16 | 3 mcc 21 17 | 4 sensitivity 21 18 | 5 specificity 21 19 | 20 | --- 21 | 22 | Code 23 | dplyr::count(threshold_perf(segment_logistic, Class, .pred_good, metrics = cls_met_other), 24 | .metric) 25 | Output 26 | # A tibble: 2 x 2 27 | .metric n 28 | <chr> <int> 29 | 1 accuracy 21 30 | 2 mcc 21 31 | 32 | -------------------------------------------------------------------------------- /tests/testthat/cal_files/binary_sim.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/tests/testthat/cal_files/binary_sim.rds -------------------------------------------------------------------------------- /tests/testthat/cal_files/fit_rs.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/tests/testthat/cal_files/fit_rs.rds -------------------------------------------------------------------------------- /tests/testthat/cal_files/multiclass_ames.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/tests/testthat/cal_files/multiclass_ames.rds -------------------------------------------------------------------------------- /tests/testthat/cal_files/reg_sim.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/tests/testthat/cal_files/reg_sim.rds -------------------------------------------------------------------------------- /tests/testthat/cal_files/sim_multi.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/probably/4dab76fce47ac7d75334f57c8cb4ecba21cadf15/tests/testthat/cal_files/sim_multi.rds -------------------------------------------------------------------------------- /tests/testthat/test-bound-prediction.R: -------------------------------------------------------------------------------- 1 | test_that("lower_limit bounds for numeric predictions", { 2 | skip_if_not_installed("modeldata") 3 | library(dplyr) 4 | library(rlang) 5 | data("solubility_test", package = "modeldata") 6 | tune2 <- function() call("tune", "test") 7 | 8 | # ------------------------------------------------------------------------------ 9 | 10 | expect_snapshot(bound_prediction(modeldata::solubility_test, lower_limit = 2), error = TRUE) 11 | expect_snapshot( 12 | modeldata::solubility_test |> 13 | mutate(.pred = format(prediction)) |> 14 | bound_prediction(lower_limit = 2), 15 | error = TRUE) 16 | 17 | sol <- modeldata::solubility_test |> set_names(c("solubility", ".pred")) 18 | 19 | expect_equal(bound_prediction(sol), sol) 20 | expect_equal(bound_prediction(sol, lower_limit = NA), sol) 21 | 22 | res_1 <- bound_prediction(sol, lower_limit = -1) 23 | expect_true(all(res_1$.pred[res_1$.pred < -1] == -1)) 24 | expect_equal(res_1$.pred[sol$.pred >= -1], sol$.pred[sol$.pred >= -1]) 25 | 26 | expect_snapshot(bound_prediction(sol, lower_limit = tune2()), error = TRUE) 27 | expect_snapshot(bound_prediction(as.matrix(sol),
lower_limit = 1), error = TRUE) 28 | }) 29 | 30 | test_that("upper_limit bounds for numeric predictions", { 31 | skip_if_not_installed("modeldata") 32 | library(dplyr) 33 | library(rlang) 34 | data("solubility_test", package = "modeldata") 35 | tune2 <- function() call("tune", "test") 36 | 37 | # ------------------------------------------------------------------------------ 38 | 39 | expect_snapshot(bound_prediction(modeldata::solubility_test, lower_limit = 2), error = TRUE) 40 | expect_snapshot( 41 | modeldata::solubility_test |> 42 | mutate(.pred = format(prediction)) |> 43 | bound_prediction(lower_limit = 2), 44 | error = TRUE) 45 | 46 | sol <- modeldata::solubility_test |> set_names(c("solubility", ".pred")) 47 | 48 | expect_equal(bound_prediction(sol), sol) 49 | expect_equal(bound_prediction(sol, upper_limit = NA), sol) 50 | 51 | res_1 <- bound_prediction(sol, upper_limit = -1) 52 | expect_true(all(res_1$.pred[res_1$.pred > -1] == -1)) 53 | expect_equal(res_1$.pred[sol$.pred <= -1], sol$.pred[sol$.pred <= -1]) 54 | 55 | expect_snapshot(bound_prediction(sol, upper_limit = tune2()), error = TRUE) 56 | }) 57 | -------------------------------------------------------------------------------- /tests/testthat/test-cal-estimate-beta.R: -------------------------------------------------------------------------------- 1 | test_that("Beta estimates work - data.frame", { 2 | skip_if_not_installed("betacal") 3 | sl_beta <- cal_estimate_beta(segment_logistic, Class, smooth = FALSE) 4 | expect_cal_type(sl_beta, "binary") 5 | expect_cal_method(sl_beta, "Beta calibration") 6 | expect_cal_rows(sl_beta) 7 | expect_snapshot(print(sl_beta)) 8 | 9 | sl_beta_group <- segment_logistic |> 10 | dplyr::mutate(group = .pred_poor > 0.5) |> 11 | cal_estimate_beta(Class, smooth = FALSE, .by = group) 12 | 13 | expect_cal_type(sl_beta_group, "binary") 14 | expect_cal_method(sl_beta_group, "Beta calibration") 15 | expect_cal_rows(sl_beta_group) 16 | expect_snapshot(print(sl_beta_group)) 17 | 18 | expect_snapshot_error( 19 | segment_logistic |> 20 | dplyr::mutate(group1 = 1, group2 = 2) |> 21 | cal_estimate_beta(Class, smooth = FALSE, .by = c(group1, group2)) 22 | ) 23 | 24 | }) 25 | 26 | test_that("Beta estimates work - tune_results", { 27 | skip_if_not_installed("betacal") 28 | skip_if_not_installed("modeldata") 29 | 30 | tl_beta <- cal_estimate_beta(testthat_cal_binary()) 31 | expect_cal_type(tl_beta, "binary") 32 | expect_cal_method(tl_beta, "Beta calibration") 33 | expect_snapshot(print(tl_beta)) 34 | 35 | expect_equal( 36 | testthat_cal_binary_count(), 37 | nrow(cal_apply(testthat_cal_binary(), tl_beta)) 38 | ) 39 | 40 | # ------------------------------------------------------------------------------ 41 | # multinomial outcomes 42 | 43 | set.seed(100) 44 | suppressWarnings( 45 | mtnl_beta <- cal_estimate_beta(testthat_cal_multiclass()) 46 | ) 47 | expect_cal_type(mtnl_beta, "one_vs_all") 48 | expect_cal_method(mtnl_beta, "Beta calibration") 49 | expect_snapshot(print(mtnl_beta)) 50 | 51 | expect_equal( 52 | testthat_cal_multiclass_count(), 53 | nrow(cal_apply(testthat_cal_multiclass(), mtnl_beta)) 54 | ) 55 | }) 56 | 57 | test_that("Beta estimates errors - grouped_df", { 58 | skip_if_not_installed("betacal") 59 | expect_snapshot_error( 60 | cal_estimate_beta(dplyr::group_by(mtcars, vs)) 61 | ) 62 | }) 63 | -------------------------------------------------------------------------------- /tests/testthat/test-cal-estimate-multinomial.R: -------------------------------------------------------------------------------- 
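# Tests for cal_estimate_multinomial(): the data.frame and tune_results
# methods, the smooth (GAM) and non-smooth (multinomial regression)
# estimators, grouped data frame errors, and the fallback to the non-smooth
# fit when there are too few unique observations for a spline.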
1 | test_that("Multinomial estimates work - data.frame", { 2 | skip_if_not_installed("modeldata") 3 | skip_if_not_installed("nnet") 4 | 5 | sp_multi <- cal_estimate_multinomial(species_probs, Species, smooth = FALSE) 6 | expect_cal_type(sp_multi, "multiclass") 7 | expect_cal_method(sp_multi, "Multinomial regression calibration") 8 | expect_cal_rows(sp_multi, n = 110) 9 | expect_snapshot(print(sp_multi)) 10 | 11 | sp_smth_multi <- cal_estimate_multinomial(species_probs, Species, smooth = TRUE) 12 | expect_cal_type(sp_smth_multi, "multiclass") 13 | expect_cal_method(sp_smth_multi, "Generalized additive model calibration") 14 | expect_cal_rows(sp_smth_multi, n = 110) 15 | expect_snapshot(print(sp_smth_multi)) 16 | 17 | sl_multi_group <- species_probs |> 18 | dplyr::mutate(group = .pred_bobcat > 0.5) |> 19 | cal_estimate_multinomial(Species, smooth = FALSE, .by = group) 20 | 21 | expect_cal_type(sl_multi_group, "multiclass") 22 | expect_cal_method(sl_multi_group, "Multinomial regression calibration") 23 | expect_cal_rows(sl_multi_group, n = 110) 24 | expect_snapshot(print(sl_multi_group)) 25 | 26 | expect_snapshot_error( 27 | species_probs |> 28 | dplyr::mutate(group1 = 1, group2 = 2) |> 29 | cal_estimate_multinomial(Species, smooth = FALSE, .by = c(group1, group2)) 30 | ) 31 | 32 | mltm_configs <- 33 | mnl_with_configs() |> 34 | cal_estimate_multinomial(truth = obs, estimate = c(VF:L), smooth = FALSE) 35 | }) 36 | 37 | test_that("Multinomial estimates work - tune_results", { 38 | skip_if_not_installed("modeldata") 39 | skip_if_not_installed("nnet") 40 | 41 | tl_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = FALSE) 42 | expect_cal_type(tl_multi, "multiclass") 43 | expect_cal_method(tl_multi, "Multinomial regression calibration") 44 | expect_snapshot(print(tl_multi)) 45 | 46 | expect_equal( 47 | testthat_cal_multiclass() |> 48 | tune::collect_predictions(summarize = TRUE) |> 49 | nrow(), 50 | testthat_cal_multiclass() |> 51 | cal_apply(tl_multi) |> 52 | nrow() 53 | ) 54 | 55 | tl_smth_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = TRUE) 56 | expect_cal_type(tl_smth_multi, "multiclass") 57 | expect_cal_method(tl_smth_multi, "Generalized additive model calibration") 58 | expect_snapshot(print(tl_smth_multi)) 59 | 60 | expect_equal( 61 | testthat_cal_multiclass() |> 62 | tune::collect_predictions(summarize = TRUE) |> 63 | nrow(), 64 | testthat_cal_multiclass() |> 65 | cal_apply(tl_smth_multi) |> 66 | nrow() 67 | ) 68 | }) 69 | 70 | test_that("Multinomial estimates errors - grouped_df", { 71 | skip_if_not_installed("modeldata") 72 | skip_if_not_installed("nnet") 73 | 74 | expect_snapshot_error( 75 | cal_estimate_multinomial(dplyr::group_by(mtcars, vs)) 76 | ) 77 | }) 78 | 79 | test_that("Passing a binary outcome causes error", { 80 | expect_error( 81 | cal_estimate_multinomial(segment_logistic, Class) 82 | ) 83 | }) 84 | 85 | test_that("Multinomial spline switches to linear if too few unique", { 86 | skip_if_not_installed("modeldata") 87 | 88 | smol_species_probs <- 89 | species_probs |> 90 | dplyr::slice_head(n = 2, by = Species) 91 | 92 | expect_snapshot( 93 | sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE) 94 | ) 95 | sl_glm <- cal_estimate_multinomial(smol_species_probs, Species, smooth = FALSE) 96 | 97 | expect_identical( 98 | sl_gam$estimates, 99 | sl_glm$estimates 100 | ) 101 | 102 | smol_by_species_probs <- 103 | species_probs |> 104 | dplyr::slice_head(n = 4, by = Species) |> 105 | dplyr::mutate(id = rep(1:2, 6)) 
106 | 107 | expect_snapshot( 108 | sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = TRUE) 109 | ) 110 | sl_glm <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = FALSE) 111 | 112 | expect_identical( 113 | sl_gam$estimates, 114 | sl_glm$estimates 115 | ) 116 | }) 117 | -------------------------------------------------------------------------------- /tests/testthat/test-cal-estimate-none.R: -------------------------------------------------------------------------------- 1 | test_that("no calibration works - data.frame", { 2 | skip_if_not_installed("modeldata") 3 | 4 | ## Regression 5 | 6 | nope_reg <- cal_estimate_none(boosting_predictions_oob, outcome) 7 | expect_cal_type(nope_reg, "regression") 8 | expect_cal_method(nope_reg, "No calibration") 9 | expect_cal_rows(nope_reg, 2000) 10 | expect_snapshot(print(nope_reg)) 11 | expect_equal( 12 | cal_apply(boosting_predictions_oob, nope_reg), 13 | boosting_predictions_oob 14 | ) 15 | 16 | reg_group_data <- boosting_predictions_oob |> 17 | dplyr::mutate(group = .pred > 0.5) 18 | 19 | nope_reg_group <- cal_estimate_none(reg_group_data, outcome, .by = group) 20 | expect_cal_type(nope_reg_group, "regression") 21 | expect_cal_method(nope_reg_group, "No calibration") 22 | expect_cal_rows(nope_reg_group, 2000) 23 | expect_snapshot(print(nope_reg_group)) 24 | expect_equal( 25 | cal_apply(reg_group_data, nope_reg_group), 26 | reg_group_data 27 | ) 28 | 29 | expect_snapshot_error( 30 | boosting_predictions_oob |> 31 | dplyr::mutate(group1 = 1, group2 = 2) |> 32 | cal_estimate_none(outcome, .by = c(group1, group2)) 33 | ) 34 | 35 | expect_snapshot_error( 36 | cal_estimate_none(boosting_predictions_oob, outcome, smooth = TRUE) 37 | ) 38 | 39 | ## Binary classification 40 | 41 | nope_binary <- cal_estimate_none(segment_logistic, Class) 42 | expect_cal_type(nope_binary, "binary") 43 | expect_cal_method(nope_binary, "No calibration") 44 | expect_cal_rows(nope_binary) 45 | expect_snapshot(print(nope_binary)) 46 | expect_equal( 47 | cal_apply(segment_logistic, nope_binary), 48 | segment_logistic 49 | ) 50 | 51 | expect_snapshot_error( 52 | segment_logistic |> cal_estimate_none(truth = Class, estimate = .pred_poor) 53 | ) 54 | 55 | expect_snapshot_error( 56 | segment_logistic |> 57 | dplyr::mutate(group1 = 1, group2 = 2) |> 58 | cal_estimate_none(Class, .by = c(group1, group2)) 59 | ) 60 | 61 | ## Multinomial classification 62 | 63 | nope_multi <- cal_estimate_none(species_probs, Species) 64 | expect_cal_type(nope_multi, "multiclass") 65 | expect_cal_method(nope_multi, "No calibration") 66 | expect_cal_rows(nope_multi, n = 110) 67 | expect_snapshot(print(nope_multi)) 68 | expect_equal( 69 | cal_apply(species_probs, nope_multi), 70 | species_probs 71 | ) 72 | 73 | expect_snapshot_error( 74 | species_probs |> 75 | dplyr::mutate(group1 = 1, group2 = 2) |> 76 | cal_estimate_none(Species, .by = c(group1, group2)) 77 | ) 78 | 79 | }) 80 | 81 | test_that("no calibration works - tune_results", { 82 | skip_if_not_installed("modeldata") 83 | 84 | ## Regression 85 | reg_pred <- collect_predictions(testthat_cal_reg()) 86 | nope_reg <- cal_estimate_none(testthat_cal_reg(), outcome) 87 | expect_cal_type(nope_reg, "regression") 88 | expect_cal_method(nope_reg, "No calibration") 89 | expect_snapshot(print(nope_reg)) 90 | expect_equal( 91 | cal_apply(reg_pred, nope_reg), 92 | reg_pred 93 | ) 94 | 95 | expect_snapshot_error( 96 | cal_estimate_none(testthat_cal_reg(), outcome, do_something = FALSE) 97 | ) 98 | 99 | ## 
Binary classification 100 | 101 | binary_pred <- collect_predictions(testthat_cal_binary()) 102 | nope_binary <- cal_estimate_none(testthat_cal_binary()) 103 | expect_cal_type(nope_binary, "binary") 104 | expect_cal_method(nope_binary, "No calibration") 105 | expect_snapshot(print(nope_binary)) 106 | expect_equal( 107 | cal_apply(binary_pred, nope_binary), 108 | binary_pred 109 | ) 110 | 111 | ## Multinomial classification 112 | 113 | multi_pred <- collect_predictions(testthat_cal_multiclass()) 114 | nope_multi <- cal_estimate_none(testthat_cal_multiclass()) 115 | expect_cal_type(nope_multi, "multiclass") 116 | expect_cal_method(nope_multi, "No calibration") 117 | expect_snapshot(print(nope_multi)) 118 | expect_equal( 119 | cal_apply(multi_pred, nope_multi), 120 | multi_pred 121 | ) 122 | 123 | }) 124 | 125 | test_that("no calibration fails - grouped_df", { 126 | 127 | expect_snapshot_error( 128 | cal_estimate_none(dplyr::group_by(mtcars, vs)) 129 | ) 130 | 131 | }) 132 | 133 | -------------------------------------------------------------------------------- /tests/testthat/test-cal-pkg-check.R: -------------------------------------------------------------------------------- 1 | test_that("Missing package returns error", { 2 | expect_error( 3 | cal_pkg_check(c("NotReal")) 4 | ) 5 | }) 6 | -------------------------------------------------------------------------------- /tests/testthat/test-cal-validate-multiclass.R: -------------------------------------------------------------------------------- 1 | test_that("Isotonic validation with data frame input - Multiclass", { 2 | df <- rsample::vfold_cv(testthat_cal_sim_multi()) 3 | val_obj <- cal_validate_isotonic(df, class) 4 | val_with_pred <- cal_validate_isotonic(df, class, save_pred = TRUE) 5 | 6 | expect_s3_class(val_obj, "data.frame") 7 | expect_s3_class(val_obj, "cal_rset") 8 | expect_equal(nrow(val_obj), nrow(df)) 9 | expect_equal( 10 | names(val_obj), 11 | c("splits", "id", ".metrics", ".metrics_cal") 12 | ) 13 | 14 | expect_s3_class(val_with_pred, "data.frame") 15 | expect_s3_class(val_with_pred, "cal_rset") 16 | expect_equal(nrow(val_with_pred), nrow(df)) 17 | expect_equal( 18 | names(val_with_pred), 19 | c("splits", "id", ".metrics", ".metrics_cal", ".predictions_cal") 20 | ) 21 | expect_equal( 22 | names(val_with_pred$.predictions_cal[[1]]), 23 | c(".pred_one", ".pred_two", ".pred_three", "class", ".row", ".pred_class") 24 | ) 25 | expect_equal( 26 | purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), 27 | purrr::map_int(val_with_pred$.predictions_cal, nrow) 28 | ) 29 | }) 30 | 31 | test_that("Isotonic validation with `fit_resamples` - Multiclass", { 32 | res <- testthat_cal_fit_rs() 33 | val_obj <- cal_validate_isotonic(res$multin) 34 | val_with_pred <- cal_validate_isotonic(res$multin, save_pred = TRUE) 35 | 36 | expect_s3_class(val_obj, "data.frame") 37 | expect_s3_class(val_obj, "cal_rset") 38 | expect_equal(nrow(val_obj), nrow(res$multin)) 39 | expect_equal( 40 | names(val_obj), 41 | c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal") 42 | ) 43 | 44 | expect_s3_class(val_with_pred, "data.frame") 45 | expect_s3_class(val_with_pred, "cal_rset") 46 | expect_equal(nrow(val_with_pred), nrow(res$multin)) 47 | expect_equal( 48 | names(val_with_pred), 49 | c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") 50 | ) 51 | skip_if_not_installed("tune", "1.2.0") 52 | expect_equal( 53 | names(val_with_pred$.predictions_cal[[1]]), 54 | c(".pred_one", ".pred_two", ".pred_three", 
".row", "outcome", ".config", ".pred_class") 55 | ) 56 | expect_equal( 57 | purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), 58 | purrr::map_int(val_with_pred$.predictions_cal, nrow) 59 | ) 60 | }) 61 | -------------------------------------------------------------------------------- /tests/testthat/test-conformal-intervals-quantile.R: -------------------------------------------------------------------------------- 1 | test_that("split conformal quantile intervals", { 2 | skip_if_not_installed("modeldata") 3 | skip_if_not_installed("nnet") 4 | skip_if_not_installed("quantregForest") 5 | 6 | # ---------------------------------------------------------------------------- 7 | 8 | suppressPackageStartupMessages(library(workflows)) 9 | suppressPackageStartupMessages(library(modeldata)) 10 | suppressPackageStartupMessages(library(dplyr)) 11 | 12 | # ---------------------------------------------------------------------------- 13 | 14 | set.seed(111) 15 | sim_data <- sim_regression(500) 16 | sim_cal <- sim_regression(100) 17 | sim_new <- sim_regression(2) 18 | 19 | wflow <- 20 | workflow() |> 21 | add_model(parsnip::linear_reg()) |> 22 | add_formula(outcome ~ .) |> 23 | fit(sim_data) 24 | 25 | # ------------------------------------------------------------------------------ 26 | 27 | expect_snapshot_error( 28 | int_conformal_quantile(lm(outcome ~ ., sim_data), sim_cal) 29 | ) 30 | 31 | expect_snapshot_error( 32 | int_conformal_quantile(wflow, sim_data[, -2], sim_cal[, -1]) 33 | ) 34 | expect_snapshot_error( 35 | int_conformal_quantile(wflow, sim_data, sim_cal[, -2]) 36 | ) 37 | 38 | # ------------------------------------------------------------------------------ 39 | 40 | lm_int <- 41 | int_conformal_quantile(wflow, sim_data, sim_cal, level = 0.90, trees = 20) 42 | expect_snapshot_error( 43 | predict(lm_int, sim_new, level = 0.90) 44 | ) 45 | expect_snapshot(lm_int) 46 | expect_true(inherits(lm_int, "int_conformal_quantile")) 47 | 48 | new_int <- predict(lm_int, sim_new) 49 | exp_ptype <- 50 | dplyr::tibble( 51 | .pred = numeric(0), 52 | .pred_lower = numeric(0), 53 | .pred_upper = numeric(0) 54 | ) 55 | 56 | expect_true(inherits(new_int, "tbl_df")) 57 | expect_equal(new_int[0, ], exp_ptype) 58 | expect_equal( 59 | colnames(new_int), 60 | c(".pred", ".pred_lower", ".pred_upper") 61 | ) 62 | expect_equal( 63 | nrow(new_int), 64 | nrow(sim_new) 65 | ) 66 | }) 67 | -------------------------------------------------------------------------------- /tests/testthat/test-conformal-intervals-split.R: -------------------------------------------------------------------------------- 1 | test_that("split conformal intervals", { 2 | skip_if_not_installed("modeldata") 3 | skip_if_not_installed("nnet") 4 | 5 | # ---------------------------------------------------------------------------- 6 | 7 | suppressPackageStartupMessages(library(workflows)) 8 | suppressPackageStartupMessages(library(modeldata)) 9 | suppressPackageStartupMessages(library(dplyr)) 10 | 11 | # ---------------------------------------------------------------------------- 12 | 13 | set.seed(111) 14 | sim_data <- sim_regression(500) 15 | sim_cal <- sim_regression(100) 16 | sim_new <- sim_regression(2) 17 | 18 | wflow <- 19 | workflow() |> 20 | add_model(parsnip::linear_reg()) |> 21 | add_formula(outcome ~ .) 
|> 22 | fit(sim_data) 23 | 24 | # ------------------------------------------------------------------------------ 25 | 26 | expect_snapshot_error( 27 | int_conformal_split(lm(outcome ~ ., sim_data), sim_cal) 28 | ) 29 | 30 | expect_snapshot_error( 31 | int_conformal_split(wflow, sim_cal[, -1]) 32 | ) 33 | expect_snapshot_error( 34 | int_conformal_split(wflow, sim_cal[, -2]) 35 | ) 36 | expect_snapshot_error( 37 | int_conformal_split(wflow, sim_cal, level = .1) 38 | ) 39 | 40 | # ------------------------------------------------------------------------------ 41 | 42 | lm_int <- int_conformal_split(wflow, sim_cal) 43 | expect_snapshot_error( 44 | predict(lm_int, sim_new, potato = 3) 45 | ) 46 | expect_snapshot(lm_int) 47 | expect_true(inherits(lm_int, "int_conformal_split")) 48 | 49 | new_int <- predict(lm_int, sim_new, level = 0.90) 50 | exp_ptype <- 51 | dplyr::tibble( 52 | .pred = numeric(0), 53 | .pred_lower = numeric(0), 54 | .pred_upper = numeric(0) 55 | ) 56 | 57 | expect_true(inherits(new_int, "tbl_df")) 58 | expect_equal(new_int[0, ], exp_ptype) 59 | expect_equal( 60 | colnames(new_int), 61 | c(".pred", ".pred_lower", ".pred_upper") 62 | ) 63 | expect_equal( 64 | nrow(new_int), 65 | nrow(sim_new) 66 | ) 67 | }) 68 | -------------------------------------------------------------------------------- /tests/testthat/test-make-class-pred.R: -------------------------------------------------------------------------------- 1 | test_data <- segment_logistic[1:5, ] 2 | good <- test_data$.pred_good 3 | poor <- test_data$.pred_poor 4 | lvls <- levels(test_data$Class) 5 | 6 | test_data2 <- species_probs[1:5, ] 7 | bobcat <- test_data2$.pred_bobcat 8 | coyote <- test_data2$.pred_coyote 9 | gray_fox <- test_data2$.pred_gray_fox 10 | lvls2 <- levels(test_data2$Species) 11 | 12 | test_that("two class succeeds with vector interface", { 13 | res <- make_two_class_pred(good, levels = lvls, threshold = .5, buffer = .4) 14 | 15 | fct <- factor(c("poor", "poor", "good", "good", "poor")) 16 | known <- class_pred(fct, which = c(2, 3, 4)) 17 | 18 | expect_s3_class(res, "class_pred") 19 | expect_equal(res, known) 20 | }) 21 | 22 | test_that("multi class succeeds with vector interface", { 23 | res <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, min_prob = 0.5) 24 | 25 | fct <- factor(c("gray_fox", "gray_fox", "bobcat", "gray_fox", "coyote")) 26 | known <- class_pred(fct, which = 5) 27 | 28 | expect_s3_class(res, "class_pred") 29 | expect_equal(res, known) 30 | }) 31 | 32 | test_that("multi class succeeds with data frame helper", { 33 | res <- append_class_pred( 34 | test_data2, 35 | contains(".pred_"), 36 | levels = lvls2, 37 | min_prob = 0.5, 38 | name = "cp_test" 39 | ) 40 | 41 | known <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, min_prob = 0.5) 42 | 43 | expect_s3_class(res, "data.frame") 44 | expect_equal(res[["cp_test"]], known) 45 | }) 46 | 47 | 48 | 49 | test_that("ordered passes through to class_pred", { 50 | res <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, ordered = TRUE) 51 | res2 <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, ordered = TRUE) 52 | 53 | expect_true(is_ordered_class_pred(res)) 54 | expect_true(is_ordered_class_pred(res2)) 55 | }) 56 | 57 | test_that("fails with different length `...`", { 58 | v1 <- c(1, 2, 3) 59 | v2 <- c(1, 2) 60 | 61 | expect_snapshot_error( 62 | make_class_pred(v1, v2) 63 | ) 64 | }) 65 | 66 | test_that("fails with different type `...`", { 67 | v1 <- c(1) 68 | v2 <- c("a") 69 | 70 | expect_snapshot_error( 71 | 
make_class_pred(v1, v2) 72 | ) 73 | }) 74 | 75 | test_that("fails with different length `...` VS levels", { 76 | v1 <- c(1, 2, 3) 77 | v2 <- c(1, 2, 3) 78 | 79 | expect_snapshot_error( 80 | make_class_pred(v1, v2, levels = c("one", "two", "three")) 81 | ) 82 | }) 83 | 84 | test_that("validates type of `levels` (#42)", { 85 | expect_snapshot_error( 86 | make_two_class_pred(1, levels = NULL) 87 | ) 88 | 89 | expect_snapshot_error( 90 | make_class_pred(1, 2, levels = c(0L, 4L)) 91 | ) 92 | }) 93 | -------------------------------------------------------------------------------- /tests/testthat/test-threshold-perf.R: -------------------------------------------------------------------------------- 1 | sim_n <- 120 2 | 3 | set.seed(1094) 4 | ex_data <- 5 | data.frame( 6 | group_1 = sample(month.abb, replace = TRUE, size = sim_n), 7 | group_2 = sample(LETTERS[1:2], replace = TRUE, size = sim_n), 8 | outcome = factor(sample(paste0("Cl", 1:2), replace = TRUE, size = sim_n)), 9 | prob_est = runif(sim_n), 10 | x = rnorm(sim_n) 11 | ) 12 | 13 | ex_data_miss <- ex_data 14 | ex_data_miss$group_1[c(1, 10, 29)] <- NA 15 | ex_data_miss$prob_est[c(56, 117)] <- NA 16 | ex_data_miss$outcome[c(49, 85, 57, 110)] <- NA 17 | 18 | thr <- c(0, .5, .78, 1) 19 | 20 | get_res <- function(prob, obs, cut) { 21 | suppressPackageStartupMessages(require(yardstick)) 22 | cls <- recode_data(obs, prob, cut, event_level = "first") 23 | dat <- data.frame( 24 | obs = obs, 25 | cls = cls 26 | ) 27 | 28 | mets <- yardstick::metric_set(sensitivity, specificity, j_index) 29 | 30 | .data_metrics <- dat |> 31 | mets(obs, estimate = cls) 32 | 33 | # Create the `distance` metric data frame 34 | sens_vec <- .data_metrics |> 35 | dplyr::filter(.metric == "sensitivity") |> 36 | dplyr::pull(.estimate) 37 | 38 | dist <- .data_metrics |> 39 | dplyr::filter(.metric == "specificity") |> 40 | dplyr::mutate( 41 | .metric = "distance", 42 | # .estimate is spec currently 43 | .estimate = (1 - sens_vec)^2 + (1 - .estimate)^2 44 | ) 45 | 46 | dplyr::bind_rows(.data_metrics, dist) 47 | } 48 | 49 | # ---------------------------------------------------------------- 50 | 51 | test_that("factor from numeric", { 52 | new_fac_1 <- 53 | recode_data( 54 | obs = ex_data$outcome, 55 | prob = ex_data$prob_est, 56 | threshold = ex_data$prob_est[1], 57 | event_level = "first" 58 | ) 59 | tab_1 <- table(new_fac_1) 60 | expect_s3_class(new_fac_1, "factor") 61 | expect_true(isTRUE(all.equal(levels(new_fac_1), levels(ex_data$outcome)))) 62 | expect_equal(unname(tab_1["Cl1"]), sum(ex_data$prob_est >= ex_data$prob_est[1])) 63 | expect_equal(unname(tab_1["Cl2"]), sum(ex_data$prob_est < ex_data$prob_est[1])) 64 | 65 | # missing data 66 | new_fac_2 <- 67 | recode_data( 68 | obs = ex_data_miss$outcome, 69 | prob = ex_data_miss$prob_est, 70 | threshold = ex_data_miss$prob_est[1], 71 | event_level = "first" 72 | ) 73 | tab_2 <- table(new_fac_2) 74 | expect_s3_class(new_fac_2, "factor") 75 | cmpl_probs <- ex_data_miss$prob_est[!is.na(ex_data_miss$prob_est)] 76 | expect_true(isTRUE(all.equal(is.na(new_fac_2), is.na(ex_data_miss$prob_est)))) 77 | expect_true(isTRUE(all.equal(levels(new_fac_2), levels(ex_data_miss$outcome)))) 78 | expect_equal(unname(tab_2["Cl1"]), sum(cmpl_probs >= ex_data_miss$prob_est[1])) 79 | expect_equal(unname(tab_2["Cl2"]), sum(cmpl_probs < ex_data_miss$prob_est[1])) 80 | 81 | new_fac_3 <- 82 | recode_data( 83 | obs = ex_data$outcome, 84 | prob = ex_data$prob_est, 85 | threshold = ex_data$prob_est[1], 86 | event_level = "second" 87 | ) 88 | tab_3 <- 
table(new_fac_3) 89 | expect_s3_class(new_fac_3, "factor") 90 | expect_true(isTRUE(all.equal(levels(new_fac_3), levels(ex_data$outcome)))) 91 | expect_equal(unname(tab_3["Cl1"]), sum(ex_data$prob_est < ex_data$prob_est[1])) 92 | expect_equal(unname(tab_3["Cl2"]), sum(ex_data$prob_est >= ex_data$prob_est[1])) 93 | }) 94 | 95 | test_that("single group", { 96 | one_group_data <- 97 | ex_data |> 98 | dplyr::group_by(group_2) |> 99 | threshold_perf( 100 | outcome, 101 | prob_est, 102 | thresholds = thr 103 | ) 104 | 105 | for (i in thr) { 106 | one_group_data_obs <- one_group_data |> 107 | dplyr::filter(group_2 == "A" & .threshold == i) |> 108 | dplyr::select(-group_2, -.threshold) |> 109 | as.data.frame() 110 | 111 | one_group_data_exp <- 112 | get_res( 113 | ex_data$prob_est[ex_data$group_2 == "A"], 114 | ex_data$outcome[ex_data$group_2 == "A"], 115 | i 116 | ) |> 117 | as.data.frame() 118 | expect_equal(one_group_data_obs, one_group_data_exp) 119 | } 120 | }) 121 | 122 | test_that("custom metrics", { 123 | suppressPackageStartupMessages(require(yardstick)) 124 | suppressPackageStartupMessages(require(dplyr)) 125 | 126 | cls_met_bad <- metric_set(sensitivity, specificity, accuracy, roc_auc) 127 | cls_met_good <- metric_set(sensitivity, specificity, accuracy, mcc) 128 | cls_met_other <- metric_set(accuracy, mcc) 129 | 130 | expect_snapshot_error( 131 | segment_logistic |> 132 | threshold_perf(Class, .pred_good, metrics = cls_met_bad) 133 | ) 134 | 135 | expect_snapshot( 136 | segment_logistic |> 137 | threshold_perf(Class, .pred_good, metrics = cls_met_good) |> 138 | dplyr::count(.metric) 139 | ) 140 | 141 | expect_snapshot( 142 | segment_logistic |> 143 | threshold_perf(Class, .pred_good, metrics = cls_met_other) |> 144 | dplyr::count(.metric) 145 | ) 146 | }) 147 | -------------------------------------------------------------------------------- /tests/testthat/test-vctrs-compat.R: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # vec_proxy_compare() 3 | 4 | test_that("can take the comparison proxy", { 5 | x <- class_pred(factor(c("a", "b", NA)), which = 2) 6 | expect_identical(vec_proxy_compare(x), unclass(x)) 7 | }) 8 | 9 | # ------------------------------------------------------------------------------ 10 | # vctrs miscellaneous 11 | 12 | test_that("can order by level with equivocal as smallest value using vec_order()", { 13 | x <- factor(c("a", "b", NA, "b"), levels = c("b", "a")) 14 | x <- class_pred(x, which = 2) 15 | 16 | expect <- factor(c("b", "b", "a", NA), levels = c("b", "a")) 17 | expect <- class_pred(expect, which = 1) 18 | 19 | expect_identical(x[vec_order(x)], expect) 20 | }) 21 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | calibrate.qmd 4 | calibrate_files 5 | rsconnect 6 | --------------------------------------------------------------------------------