├── .Rbuildignore
├── .github
│   ├── .gitignore
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── workflows
│       ├── R-CMD-check-no-suggests.yaml
│       ├── R-CMD-check.yaml
│       ├── lock.yaml
│       ├── pkgdown.yaml
│       ├── pr-commands.yaml
│       └── test-coverage.yaml
├── .gitignore
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── DESCRIPTION
├── LICENSE
├── LICENSE.md
├── MAINTENANCE.md
├── NAMESPACE
├── NEWS.md
├── R
│   ├── aaa-metrics.R
│   ├── aaa-new.R
│   ├── aaa.R
│   ├── check-metric.R
│   ├── class-accuracy.R
│   ├── class-bal_accuracy.R
│   ├── class-detection_prevalence.R
│   ├── class-f_meas.R
│   ├── class-j_index.R
│   ├── class-kap.R
│   ├── class-mcc.R
│   ├── class-npv.R
│   ├── class-ppv.R
│   ├── class-precision.R
│   ├── class-recall.R
│   ├── class-sens.R
│   ├── class-spec.R
│   ├── conf_mat.R
│   ├── data.R
│   ├── deprecated-prob_helpers.R
│   ├── deprecated-template.R
│   ├── estimator-helpers.R
│   ├── event-level.R
│   ├── fair-aaa.R
│   ├── fair-demographic_parity.R
│   ├── fair-equal_opportunity.R
│   ├── fair-equalized_odds.R
│   ├── import-standalone-obj-type.R
│   ├── import-standalone-survival.R
│   ├── import-standalone-types-check.R
│   ├── metric-tweak.R
│   ├── misc.R
│   ├── missings.R
│   ├── num-ccc.R
│   ├── num-huber_loss.R
│   ├── num-iic.R
│   ├── num-mae.R
│   ├── num-mape.R
│   ├── num-mase.R
│   ├── num-mpe.R
│   ├── num-msd.R
│   ├── num-poisson_log_loss.R
│   ├── num-pseudo_huber_loss.R
│   ├── num-rmse.R
│   ├── num-rpd.R
│   ├── num-rpiq.R
│   ├── num-rsq.R
│   ├── num-rsq_trad.R
│   ├── num-smape.R
│   ├── orderedprob-ranked_prob_score.R
│   ├── prob-average_precision.R
│   ├── prob-binary-thresholds.R
│   ├── prob-brier_class.R
│   ├── prob-classification_cost.R
│   ├── prob-gain_capture.R
│   ├── prob-gain_curve.R
│   ├── prob-helpers.R
│   ├── prob-lift_curve.R
│   ├── prob-mn_log_loss.R
│   ├── prob-pr_auc.R
│   ├── prob-pr_curve.R
│   ├── prob-roc_auc.R
│   ├── prob-roc_aunp.R
│   ├── prob-roc_aunu.R
│   ├── prob-roc_curve.R
│   ├── reexports.R
│   ├── surv-brier_survival.R
│   ├── surv-brier_survival_integrated.R
│   ├── surv-concordance_survival.R
│   ├── surv-roc_auc_survival.R
│   ├── surv-roc_curve_survival.R
│   ├── template.R
│   ├── validation.R
│   └── yardstick-package.R
├── README.Rmd
├── README.md
├── _pkgdown.yml
├── air.toml
├── codecov.yml
├── cran-comments.md
├── data-raw
│   ├── dyn-surv-metrics
│   │   ├── auc_churn_res.RData
│   │   ├── brier_churn_res.RData
│   │   ├── collect_metrics.R
│   │   ├── generate_metrics.R
│   │   ├── rr_churn_data.RData
│   │   └── survival-curve-reference.R
│   └── lung_surv.R
├── data
│   ├── datalist
│   ├── hpc_cv.rda
│   ├── lung_surv.rda
│   ├── pathology.rda
│   ├── solubility_test.rda
│   └── two_class_example.rda
├── man-roxygen
│   ├── event-fair.R
│   ├── event_first.R
│   ├── examples-binary-prob.R
│   ├── examples-class.R
│   ├── examples-counts.R
│   ├── examples-fair.R
│   ├── examples-multiclass-prob.R
│   ├── examples-numeric.R
│   ├── multiclass-curve.R
│   ├── multiclass-prob.R
│   ├── multiclass.R
│   ├── return-dynamic-survival.R
│   ├── return-fair.R
│   ├── return-prob.R
│   ├── return.R
│   ├── table-positive.R
│   └── table-relevance.R
├── man
│   ├── accuracy.Rd
│   ├── average_precision.Rd
│   ├── bal_accuracy.Rd
│   ├── brier_class.Rd
│   ├── brier_survival.Rd
│   ├── brier_survival_integrated.Rd
│   ├── ccc.Rd
│   ├── check_metric.Rd
│   ├── classification_cost.Rd
│   ├── concordance_survival.Rd
│   ├── conf_mat.Rd
│   ├── demographic_parity.Rd
│   ├── detection_prevalence.Rd
│   ├── developer-helpers.Rd
│   ├── equal_opportunity.Rd
│   ├── equalized_odds.Rd
│   ├── f_meas.Rd
│   ├── figures
│   │   ├── README-roc-curves-1.png
│   │   ├── lifecycle-archived.svg
│   │   ├── lifecycle-defunct.svg
│   │   ├── lifecycle-deprecated.svg
│   │   ├── lifecycle-experimental.svg
│   │   ├── lifecycle-maturing.svg
│   │   ├── lifecycle-questioning.svg
│   │   ├── lifecycle-stable.svg
│   │   ├── lifecycle-superseded.svg
│   │   └── logo.png
│   ├── gain_capture.Rd
│   ├── gain_curve.Rd
│   ├── hpc_cv.Rd
│   ├── huber_loss.Rd
│   ├── huber_loss_pseudo.Rd
│   ├── iic.Rd
│   ├── j_index.Rd
│   ├── kap.Rd
│   ├── lift_curve.Rd
│   ├── lung_surv.Rd
│   ├── mae.Rd
│   ├── mape.Rd
│   ├── mase.Rd
│   ├── mcc.Rd
│   ├── metric-summarizers.Rd
│   ├── metric_set.Rd
│   ├── metric_summarizer.Rd
│   ├── metric_tweak.Rd
│   ├── metric_vec_template.Rd
│   ├── metrics.Rd
│   ├── mn_log_loss.Rd
│   ├── mpe.Rd
│   ├── msd.Rd
│   ├── new-metric.Rd
│   ├── new_groupwise_metric.Rd
│   ├── npv.Rd
│   ├── pathology.Rd
│   ├── poisson_log_loss.Rd
│   ├── ppv.Rd
│   ├── pr_auc.Rd
│   ├── pr_curve.Rd
│   ├── precision.Rd
│   ├── ranked_prob_score.Rd
│   ├── recall.Rd
│   ├── reexports.Rd
│   ├── rmse.Rd
│   ├── roc_auc.Rd
│   ├── roc_auc_survival.Rd
│   ├── roc_aunp.Rd
│   ├── roc_aunu.Rd
│   ├── roc_curve.Rd
│   ├── roc_curve_survival.Rd
│   ├── rpd.Rd
│   ├── rpiq.Rd
│   ├── rsq.Rd
│   ├── rsq_trad.Rd
│   ├── sens.Rd
│   ├── smape.Rd
│   ├── solubility_test.Rd
│   ├── spec.Rd
│   ├── summary.conf_mat.Rd
│   ├── two_class_example.Rd
│   ├── yardstick-package.Rd
│   └── yardstick_remove_missing.Rd
├── pkgdown
│   └── favicon
│       ├── apple-touch-icon-120x120.png
│       ├── apple-touch-icon-60x60.png
│       ├── apple-touch-icon-76x76.png
│       ├── apple-touch-icon.png
│       ├── favicon-16x16.png
│       ├── favicon-32x32.png
│       └── favicon.ico
├── revdep
│   ├── .gitignore
│   ├── README.md
│   ├── cran.md
│   ├── email.yml
│   ├── failures.md
│   └── problems.md
├── src
│   ├── .gitignore
│   ├── init.c
│   ├── mcc-multiclass.c
│   └── yardstick.h
├── tests
│   ├── pycompare
│   │   └── generate-pycompare.R
│   ├── testthat.R
│   └── testthat
│       ├── _snaps
│       │   ├── aaa-metrics.md
│       │   ├── aaa-new.md
│       │   ├── autoplot.md
│       │   ├── check_metric.md
│       │   ├── class-accuracy.md
│       │   ├── class-bal_accuracy.md
│       │   ├── class-detection_prevalence.md
│       │   ├── class-f_meas.md
│       │   ├── class-j_index.md
│       │   ├── class-kap.md
│       │   ├── class-mcc.md
│       │   ├── class-npv.md
│       │   ├── class-ppv.md
│       │   ├── class-precision.md
│       │   ├── class-recall.md
│       │   ├── class-sens.md
│       │   ├── class-spec.md
│       │   ├── conf_mat.md
│       │   ├── deprecated-template.md
│       │   ├── error-handling.md
│       │   ├── estimator-helpers.md
│       │   ├── event-level.md
│       │   ├── fair-aaa.md
│       │   ├── flatten.md
│       │   ├── metric-tweak.md
│       │   ├── misc.md
│       │   ├── num-huber_loss.md
│       │   ├── num-mase.md
│       │   ├── num-pseudo_huber_loss.md
│       │   ├── num-rsq.md
│       │   ├── orderedprob-ranked_prob_score.md
│       │   ├── prob-average_precision.md
│       │   ├── prob-brier_class.md
│       │   ├── prob-classification_cost.md
│       │   ├── prob-gain_capture.md
│       │   ├── prob-gain_curve.md
│       │   ├── prob-lift_curve.md
│       │   ├── prob-mn_log_loss.md
│       │   ├── prob-pr_auc.md
│       │   ├── prob-pr_curve.md
│       │   ├── prob-roc_auc.md
│       │   ├── prob-roc_aunp.md
│       │   ├── prob-roc_aunu.md
│       │   ├── prob-roc_curve.md
│       │   ├── probably.md
│       │   ├── surv-brier_survival_integrated.md
│       │   ├── template.md
│       │   └── validation.md
│       ├── data
│       │   ├── auc_churn_res.rds
│       │   ├── brier_churn_res.rds
│       │   ├── helper-pROC-two-class-example-curve.rds
│       │   ├── helper-soybean.rds
│       │   ├── helper-three-class-helpers.rds
│       │   ├── ref_roc_auc_survival.rds
│       │   ├── ref_roc_curve_survival.rds
│       │   ├── rr_churn_data.rds
│       │   ├── test_autoplot.rds
│       │   ├── tidy_churn.rds
│       │   ├── weights-hpc-cv.rds
│       │   ├── weights-solubility-test.rds
│       │   └── weights-two-class-example.rds
│       ├── helper-data.R
│       ├── helper-macro-micro.R
│       ├── helper-macro-prob.R
│       ├── helper-numeric.R
│       ├── helper-pROC.R
│       ├── helper-read_pydata.R
│       ├── helper-weights.R
│       ├── py-data
│       │   ├── py-accuracy.rds
│       │   ├── py-average-precision.rds
│       │   ├── py-bal-accuracy.rds
│       │   ├── py-brier-survival.rds
│       │   ├── py-demographic_parity.rds
│       │   ├── py-equal_opportunity.rds
│       │   ├── py-equalized_odds.rds
│       │   ├── py-f_meas.rds
│       │   ├── py-f_meas_beta_.5.rds
│       │   ├── py-kap.rds
│       │   ├── py-mae.rds
│       │   ├── py-mape.rds
│       │   ├── py-mcc.rds
│       │   ├── py-mn_log_loss.rds
│       │   ├── py-npv.rds
│       │   ├── py-ppv.rds
│       │   ├── py-pr-curve.rds
│       │   ├── py-precision.rds
│       │   ├── py-recall.rds
│       │   ├── py-rmse.rds
│       │   ├── py-roc-auc.rds
│       │   ├── py-roc-curve.rds
│       │   └── py-rsq-trad.rds
│       ├── test-aaa-metrics.R
│       ├── test-aaa-new.R
│       ├── test-auc.R
│       ├── test-autoplot.R
│       ├── test-check_metric.R
│       ├── test-class-accuracy.R
│       ├── test-class-bal_accuracy.R
│       ├── test-class-detection_prevalence.R
│       ├── test-class-f_meas.R
│       ├── test-class-j_index.R
│       ├── test-class-kap.R
│       ├── test-class-mcc.R
│       ├── test-class-npv.R
│       ├── test-class-ppv.R
│       ├── test-class-precision.R
│       ├── test-class-recall.R
│       ├── test-class-sens.R
│       ├── test-class-spec.R
│       ├── test-conf_mat.R
│       ├── test-deprecated-template.R
│       ├── test-error-handling.R
│       ├── test-estimator-helpers.R
│       ├── test-event-level.R
│       ├── test-fair-aaa.R
│       ├── test-fair-demographic_parity.R
│       ├── test-fair-equal_opportunity.R
│       ├── test-fair-equalized_odds.R
│       ├── test-flatten.R
│       ├── test-global-option.R
│       ├── test-handle_missings.R
│       ├── test-metric-tweak.R
│       ├── test-misc.R
│       ├── test-num-ccc.R
│       ├── test-num-huber_loss.R
│       ├── test-num-iic.R
│       ├── test-num-mae.R
│       ├── test-num-mape.R
│       ├── test-num-mase.R
│       ├── test-num-mpe.R
│       ├── test-num-msd.R
│       ├── test-num-poisson_log_loss.R
│       ├── test-num-pseudo_huber_loss.R
│       ├── test-num-rmse.R
│       ├── test-num-rpd.R
│       ├── test-num-rpiq.R
│       ├── test-num-rsq.R
│       ├── test-num-rsq_trad.R
│       ├── test-num-smape.R
│       ├── test-orderedprob-ranked_prob_score.R
│       ├── test-prob-average_precision.R
│       ├── test-prob-brier_class.R
│       ├── test-prob-classification_cost.R
│       ├── test-prob-gain_capture.R
│       ├── test-prob-gain_curve.R
│       ├── test-prob-lift_curve.R
│       ├── test-prob-mn_log_loss.R
│       ├── test-prob-pr_auc.R
│       ├── test-prob-pr_curve.R
│       ├── test-prob-roc_auc.R
│       ├── test-prob-roc_aunp.R
│       ├── test-prob-roc_aunu.R
│       ├── test-prob-roc_curve.R
│       ├── test-probably.R
│       ├── test-surv-brier_survival.R
│       ├── test-surv-brier_survival_integrated.R
│       ├── test-surv-concordance_survival.R
│       ├── test-surv-roc_auc_survival.R
│       ├── test-surv-roc_curve_survival.R
│       ├── test-template.R
│       └── test-validation.R
├── vignettes
│   ├── .gitignore
│   ├── grouping.Rmd
│   ├── metric-types.Rmd
│   └── multiclass.Rmd
└── yardstick.Rproj

/.Rbuildignore:
--------------------------------------------------------------------------------
 1 | ^revdep$
 2 | ^CRAN-RELEASE$
 3 | ^cran-comments\.md$
 4 | ^.*\.Rproj$
 5 | ^\.Rproj\.user$
 6 | ^\.travis\.yml$
 7 | ^codecov\.yml$
 8 | ^docs$
 9 | README.Rmd
10 | README.html
11 | ^_pkgdown\.yml$
12 | contributors.md
13 | ^pkgdown$
14 | ^man-roxygen$
15 | ^\.github$
16 | ^CODE_OF_CONDUCT\.md$
17 | ^LICENSE\.md$
18 | ^CRAN-SUBMISSION$
19 | ^MAINTENANCE\.md$
20 | ^data-raw$
21 | ^[\.]?air\.toml$
22 | ^\.vscode$
23 | 
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 | 
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to tidymodels
 2 | 
 3 | For more detailed information about contributing to tidymodels packages, see our [**development contributing guide**](https://www.tidymodels.org/contribute/).
 4 | 
 5 | ## Documentation
 6 | 
 7 | Typos or grammatical errors in documentation may be edited directly using the GitHub web interface, as long as the changes are made in the _source_ file.
 8 | 
 9 | * YES ✅: you edit a roxygen comment in an `.R` file in the `R/` directory.
10 | * NO 🚫: you edit an `.Rd` file in the `man/` directory.
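To make the distinction concrete, here is a minimal sketch of the intended workflow (an editor's illustration; it uses this package's own `accuracy()` documentation, shown later in this dump, and assumes the usual `devtools` tooling):

```r
# YES: fix the documentation in the roxygen comment in R/class-accuracy.R ...
#' Accuracy
#'
#' Accuracy is the proportion of the data that are predicted correctly.
accuracy <- function(data, ...) {
  UseMethod("accuracy")
}

# ... then regenerate man/accuracy.Rd rather than editing the .Rd by hand:
devtools::document()
```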
11 | 12 | We use [roxygen2](https://cran.r-project.org/package=roxygen2), with [Markdown syntax](https://cran.r-project.org/web/packages/roxygen2/vignettes/rd-formatting.html), for documentation. 13 | 14 | ## Code 15 | 16 | Before you submit a pull request on a tidymodels package, always file an issue and confirm the tidymodels team agrees with your idea and is happy with your basic proposal. 17 | 18 | The [tidymodels packages](https://www.tidymodels.org/packages/) work together. Each package contains its own unit tests, while integration tests and other tests using all the packages are contained in [extratests](https://github.com/tidymodels/extratests). 19 | 20 | * For pull requests, we recommend that you [create a fork of this repo](https://usethis.r-lib.org/articles/articles/pr-functions.html) with `usethis::create_from_github()`, and then initiate a new branch with `usethis::pr_init()`. 21 | * Look at the build status before and after making changes. The `README` contains badges for any continuous integration services used by the package. 22 | * New code should follow the tidyverse [style guide](https://style.tidyverse.org). You can use the [styler](https://CRAN.R-project.org/package=styler) package to apply these styles, but please don't restyle code that has nothing to do with your PR. 23 | * For user-facing changes, add a bullet to the top of `NEWS.md` below the current development version header describing the changes made followed by your GitHub username, and links to relevant issue(s)/PR(s). 24 | * We use [testthat](https://cran.r-project.org/package=testthat). Contributions with test cases included are easier to accept. 25 | * If your contribution spans the use of more than one package, consider building [extratests](https://github.com/tidymodels/extratests) with your changes to check for breakages and/or adding new tests there. Let us know in your PR if you ran these extra tests. 26 | * Here in the yardstick package, some test objects are created via helper functions in the `tests/testthat/` folder. 27 | 28 | ### Code of Conduct 29 | 30 | This project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. 31 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check-no-suggests.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow only directly installs "hard" dependencies, i.e. Depends, 5 | # Imports, and LinkingTo dependencies. Notably, Suggests dependencies are never 6 | # installed, with the exception of testthat, knitr, and rmarkdown. The cache is 7 | # never used to avoid accidentally restoring a cache containing a suggested 8 | # dependency.
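# A rough local approximation of this constraint (an editor's sketch, not part
# of the workflow itself) is to install only the hard dependencies plus the
# three exceptions before running the check:
#
#   Rscript -e 'remotes::install_deps(dependencies = c("Depends", "Imports", "LinkingTo"))'
#   Rscript -e 'install.packages(c("testthat", "knitr", "rmarkdown"))'
#   Rscript -e 'rcmdcheck::rcmdcheck(args = "--no-manual")'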
9 | on: 10 | push: 11 | branches: [main, master] 12 | pull_request: 13 | 14 | name: R-CMD-check-no-suggests.yaml 15 | 16 | permissions: read-all 17 | 18 | jobs: 19 | check-no-suggests: 20 | runs-on: ${{ matrix.config.os }} 21 | 22 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 23 | 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | config: 28 | - {os: ubuntu-latest, r: 'release'} 29 | 30 | env: 31 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 32 | R_KEEP_PKG_SOURCE: yes 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - uses: r-lib/actions/setup-pandoc@v2 38 | 39 | - uses: r-lib/actions/setup-r@v2 40 | with: 41 | r-version: ${{ matrix.config.r }} 42 | http-user-agent: ${{ matrix.config.http-user-agent }} 43 | use-public-rspm: true 44 | 45 | - uses: r-lib/actions/setup-r-dependencies@v2 46 | with: 47 | dependencies: '"hard"' 48 | cache: false 49 | extra-packages: | 50 | any::rcmdcheck 51 | any::testthat 52 | any::knitr 53 | any::rmarkdown 54 | needs: check 55 | 56 | - uses: r-lib/actions/check-r-package@v2 57 | with: 58 | upload-snapshots: true 59 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 60 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | 12 | name: R-CMD-check.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | R-CMD-check: 18 | runs-on: ${{ matrix.config.os }} 19 | 20 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | config: 26 | - {os: macos-latest, r: 'release'} 27 | 28 | - {os: windows-latest, r: 'release'} 29 | # use 4.0 or 4.1 to check with rtools40's older compiler 30 | - {os: windows-latest, r: 'oldrel-4'} 31 | 32 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 33 | - {os: ubuntu-latest, r: 'release'} 34 | - {os: ubuntu-latest, r: 'oldrel-1'} 35 | - {os: ubuntu-latest, r: 'oldrel-2'} 36 | - {os: ubuntu-latest, r: 'oldrel-3'} 37 | - {os: ubuntu-latest, r: 'oldrel-4'} 38 | 39 | env: 40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 41 | R_KEEP_PKG_SOURCE: yes 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - uses: r-lib/actions/setup-pandoc@v2 47 | 48 | - uses: r-lib/actions/setup-r@v2 49 | with: 50 | r-version: ${{ matrix.config.r }} 51 | http-user-agent: ${{ matrix.config.http-user-agent }} 52 | use-public-rspm: true 53 | 54 | - uses: r-lib/actions/setup-r-dependencies@v2 55 | with: 56 | extra-packages: any::rcmdcheck 57 | needs: check 58 | 59 | - uses: r-lib/actions/check-r-package@v2 60 | with: 61 | upload-snapshots: true 62 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 63 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: dessant/lock-threads@v2 12 | with: 
13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | <https://reprex.tidyverse.org>) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | <https://reprex.tidyverse.org>) and link to this issue. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | release: 8 | types: [published] 9 | workflow_dispatch: 10 | 11 | name: pkgdown.yaml 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | pkgdown: 17 | runs-on: ubuntu-latest 18 | # Only restrict concurrency for non-PR jobs 19 | concurrency: 20 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 21 | env: 22 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 23 | permissions: 24 | contents: write 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - uses: r-lib/actions/setup-pandoc@v2 29 | 30 | - uses: r-lib/actions/setup-r@v2 31 | with: 32 | use-public-rspm: true 33 | 34 | - uses: r-lib/actions/setup-r-dependencies@v2 35 | with: 36 | extra-packages: any::pkgdown, local::. 37 | needs: website 38 | 39 | - name: Build site 40 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 41 | shell: Rscript {0} 42 | 43 | - name: Deploy to GitHub pages 🚀 44 | if: github.event_name != 'pull_request' 45 | uses: JamesIves/github-pages-deploy-action@v4.5.0 46 | with: 47 | clean: false 48 | branch: gh-pages 49 | folder: docs 50 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures?
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | run: styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | 8 | name: test-coverage.yaml 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | test-coverage: 14 | runs-on: ubuntu-latest 15 | env: 16 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | with: 23 | use-public-rspm: true 24 | 25 | - uses: r-lib/actions/setup-r-dependencies@v2 26 | with: 27 | extra-packages: any::covr, any::xml2 28 | needs: coverage 29 | 30 | - name: Test coverage 31 | run: | 32 | cov <- covr::package_coverage( 33 | quiet = FALSE, 34 | clean = FALSE, 35 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 36 | ) 37 | print(cov) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v5 42 | with: 43 | # Fail if error if not on PR, or if on PR and token is given 44 | fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} 45 | files: ./cobertura.xml 46 | plugins: noop 47 | disable_search: true 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | 50 | - name: Show testthat output 51 | if: always() 52 | run: | 53 | ## -------------------------------------------------------------------- 54 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 55 | shell: bash 56 | 57 | - name: Upload test results 58 | if: failure() 59 | uses: actions/upload-artifact@v4 60 | with: 61 | name: coverage-test-failures 62 | path: ${{ runner.temp }}/package 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | inst/doc 2 | .Rproj.user 3 | .Rhistory 4 | .RData 5 | .Ruserdata 6 | .DS_Store 7 | docs 8 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "Posit.air-vscode" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[r]": { 3 | "editor.formatOnSave": true, 4 | "editor.defaultFormatter": "Posit.air-vscode" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2025 2 | COPYRIGHT HOLDER: yardstick authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 yardstick authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /R/aaa.R: -------------------------------------------------------------------------------- 1 | # nocov start 2 | 3 | # Global vars ------------------------------------------------------------------ 4 | 5 | utils::globalVariables( 6 | c( 7 | # for class prob metrics 8 | "estimate", 9 | ".estimator", 10 | "threshold", 11 | "specificity", 12 | ".level", 13 | ".", 14 | 15 | # for survival metrics 16 | ".estimate", 17 | ".eval_time", 18 | ".pred_survival", 19 | ".weight_censored", 20 | 21 | # for autoplot methods 22 | ".n_events", 23 | ".n", 24 | "slope", 25 | "perfect", 26 | "sensitivity", 27 | ".percent_found", 28 | ".percent_tested", 29 | "Prediction", 30 | "Truth", 31 | "Freq", 32 | "xmin", 33 | "xmax", 34 | "ymin", 35 | "ymax" 36 | ) 37 | ) 38 | 39 | # Onload ----------------------------------------------------------------------- 40 | 41 | ## Taken from https://github.com/tidyverse/dplyr/blob/d310ad1cef1c14d770c94e1a9a4c79c888f46af6/R/zzz.r#L2-L9 42 | 43 | .onLoad <- function(libname, pkgname) { 44 | # dynamically register autoplot methods 45 | s3_register("ggplot2::autoplot", "gain_df") 46 | s3_register("ggplot2::autoplot", "lift_df") 47 | s3_register("ggplot2::autoplot", "roc_df") 48 | s3_register("ggplot2::autoplot", "roc_survival_df") 49 | s3_register("ggplot2::autoplot", "pr_df") 50 | s3_register("ggplot2::autoplot", "conf_mat") 51 | 52 | invisible() 53 | } 54 | 55 | # Dynamic reg helper ----------------------------------------------------------- 56 | 57 | # vctrs/register-s3.R 58 | # https://github.com/r-lib/vctrs/blob/master/R/register-s3.R 59 | s3_register <- function(generic, class, method = NULL) { 60 | stopifnot(is.character(generic), length(generic) == 1) 61 | stopifnot(is.character(class), length(class) == 1) 62 | 63 | pieces <- strsplit(generic, "::")[[1]] 64 | stopifnot(length(pieces) == 2) 65 | package <- pieces[[1]] 66 | generic <- pieces[[2]] 67 | 68 | if (is.null(method)) { 69 | method <- get(paste0(generic, ".", class), envir = parent.frame()) 70 | } 71 | stopifnot(is.function(method)) 72 | 73 | if (package %in% loadedNamespaces()) { 74 | registerS3method(generic, class, method, envir = asNamespace(package)) 75 | } 76 | 77 | # Always register hook in case package is later unloaded & reloaded 78 | setHook( 79 | packageEvent(package, "onLoad"), 80 | function(...) { 81 | registerS3method(generic, class, method, envir = asNamespace(package)) 82 | } 83 | ) 84 | } 85 | 86 | # nocov end 87 | -------------------------------------------------------------------------------- /R/class-accuracy.R: -------------------------------------------------------------------------------- 1 | #' Accuracy 2 | #' 3 | #' Accuracy is the proportion of the data that are predicted correctly. 4 | #' 5 | #' @family class metrics 6 | #' @templateVar fn accuracy 7 | #' @template return 8 | #' 9 | #' @section Multiclass: 10 | #' 11 | #' Accuracy extends naturally to multiclass scenarios. 
Because 12 | #' of this, macro and micro averaging are not implemented. 13 | #' 14 | #' @inheritParams sens 15 | #' 16 | #' @author Max Kuhn 17 | #' 18 | #' @export 19 | #' @examples 20 | #' library(dplyr) 21 | #' data("two_class_example") 22 | #' data("hpc_cv") 23 | #' 24 | #' # Two class 25 | #' accuracy(two_class_example, truth, predicted) 26 | #' 27 | #' # Multiclass 28 | #' # accuracy() has a natural multiclass extension 29 | #' hpc_cv |> 30 | #' filter(Resample == "Fold01") |> 31 | #' accuracy(obs, pred) 32 | #' 33 | #' # Groups are respected 34 | #' hpc_cv |> 35 | #' group_by(Resample) |> 36 | #' accuracy(obs, pred) 37 | accuracy <- function(data, ...) { 38 | UseMethod("accuracy") 39 | } 40 | accuracy <- new_class_metric( 41 | accuracy, 42 | direction = "maximize" 43 | ) 44 | 45 | #' @export 46 | #' @rdname accuracy 47 | accuracy.data.frame <- function( 48 | data, 49 | truth, 50 | estimate, 51 | na_rm = TRUE, 52 | case_weights = NULL, 53 | ... 54 | ) { 55 | class_metric_summarizer( 56 | name = "accuracy", 57 | fn = accuracy_vec, 58 | data = data, 59 | truth = !!enquo(truth), 60 | estimate = !!enquo(estimate), 61 | na_rm = na_rm, 62 | case_weights = !!enquo(case_weights) 63 | ) 64 | } 65 | 66 | #' @export 67 | accuracy.table <- function(data, ...) { 68 | check_table(data) 69 | estimator <- finalize_estimator(data, metric_class = "accuracy") 70 | 71 | metric_tibbler( 72 | .metric = "accuracy", 73 | .estimator = estimator, 74 | .estimate = accuracy_table_impl(data) 75 | ) 76 | } 77 | 78 | #' @export 79 | accuracy.matrix <- function(data, ...) { 80 | data <- as.table(data) 81 | accuracy.table(data) 82 | } 83 | 84 | #' @export 85 | #' @rdname accuracy 86 | accuracy_vec <- function( 87 | truth, 88 | estimate, 89 | na_rm = TRUE, 90 | case_weights = NULL, 91 | ... 92 | ) { 93 | abort_if_class_pred(truth) 94 | estimate <- as_factor_from_class_pred(estimate) 95 | 96 | estimator <- finalize_estimator(truth, metric_class = "accuracy") 97 | check_class_metric(truth, estimate, case_weights, estimator) 98 | 99 | if (na_rm) { 100 | result <- yardstick_remove_missing(truth, estimate, case_weights) 101 | 102 | truth <- result$truth 103 | estimate <- result$estimate 104 | case_weights <- result$case_weights 105 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 106 | return(NA_real_) 107 | } 108 | 109 | data <- yardstick_table(truth, estimate, case_weights = case_weights) 110 | accuracy_table_impl(data) 111 | } 112 | 113 | accuracy_table_impl <- function(x) { 114 | sum(diag(x)) / sum(x) 115 | } 116 | -------------------------------------------------------------------------------- /R/deprecated-prob_helpers.R: -------------------------------------------------------------------------------- 1 | # nocov start 2 | 3 | # `...` -> estimate matrix / vector helper ------------------------------------- 4 | 5 | #' Developer helpers 6 | #' 7 | #' Helpers to be used alongside [check_metric], [yardstick_remove_missing] and 8 | #' [metric summarizers][class_metric_summarizer()] when creating new metrics. 9 | #' See [Custom performance 10 | #' metrics](https://www.tidymodels.org/learn/develop/metrics/) for more 11 | #' information. 12 | #' 13 | #' @section Dots -> Estimate: 14 | #' `r lifecycle::badge("deprecated")` 15 | #' 16 | #' `dots_to_estimate()` is useful with class probability metrics that take 17 | #' `...` rather than `estimate` as an argument. 
It constructs either a single 18 | #' name if 1 input is provided to `...` or it constructs a quosure where the 19 | #' expression constructs a matrix of as many columns as are provided to `...`. 20 | #' These are eventually evaluated in the `summarise()` call in 21 | #' [metric-summarizers] and evaluate to either a vector or a matrix for 22 | #' further use in the underlying vector functions. 23 | #' 24 | #' 25 | #' @name developer-helpers 26 | #' 27 | #' @aliases dots_to_estimate 28 | #' 29 | #' @export 30 | #' 31 | #' @inheritParams roc_auc 32 | dots_to_estimate <- function(data, ...) { 33 | lifecycle::deprecate_soft( 34 | when = "1.2.0", 35 | what = "dots_to_estimate()", 36 | details = I( 37 | paste( 38 | "No longer needed with", 39 | "`prob_metric_summarizer()`, or `curve_metric_summarizer()`." 40 | ) 41 | ) 42 | ) 43 | 44 | # Capture dots 45 | dot_vars <- with_handlers( 46 | tidyselect::vars_select(names(data), !!!enquos(...)), 47 | tidyselect_empty_dots = function(cnd) { 48 | abort("No valid variables provided to `...`.") 49 | } 50 | ) 51 | 52 | # estimate is a matrix of the selected columns if >1 selected 53 | dot_nms <- lapply(dot_vars, as.name) 54 | 55 | if (length(dot_nms) > 1) { 56 | estimate <- quo( 57 | matrix( 58 | data = c(!!!dot_nms), 59 | ncol = !!length(dot_nms), 60 | dimnames = list(NULL, !!dot_vars) 61 | ) 62 | ) 63 | } else { 64 | estimate <- dot_nms[[1]] 65 | } 66 | 67 | estimate 68 | } 69 | 70 | # nocov end 71 | -------------------------------------------------------------------------------- /R/event-level.R: -------------------------------------------------------------------------------- 1 | # Internal helper to query a default `event_level` 2 | # 3 | # 1) Respect `yardstick.event_first` if set, but warn about deprecation 4 | # 2) Return `"first"` otherwise as the default event level 5 | # 6 | # Metric functions that use this helper can completely ignore the global option 7 | # by setting the `event_first` argument to `"first"` or `"second"` directly. 8 | yardstick_event_level <- function() { 9 | opt <- getOption("yardstick.event_first") 10 | 11 | if (!is.null(opt)) { 12 | lifecycle::deprecate_warn( 13 | when = "0.0.7", 14 | what = I("The global option `yardstick.event_first`"), 15 | with = I("the metric function argument `event_level`"), 16 | details = "The global option is being ignored entirely." 17 | ) 18 | } 19 | 20 | "first" 21 | } 22 | 23 | is_event_first <- function(event_level) { 24 | validate_event_level(event_level) 25 | identical(event_level, "first") 26 | } 27 | 28 | validate_event_level <- function(event_level) { 29 | if (identical(event_level, "first")) { 30 | return(invisible()) 31 | } 32 | if (identical(event_level, "second")) { 33 | return(invisible()) 34 | } 35 | 36 | cli::cli_abort("{.arg event_level} must be {.val first} or {.val second}.") 37 | } 38 | -------------------------------------------------------------------------------- /R/fair-demographic_parity.R: -------------------------------------------------------------------------------- 1 | #' Demographic parity 2 | #' 3 | #' @description 4 | #' Demographic parity is satisfied when a model's predictions have the 5 | #' same predicted positive rate across groups. A value of 0 indicates parity 6 | #' across groups. Note that this definition does not depend on the true 7 | #' outcome; the `truth` argument is included in outputted metrics 8 | #' for consistency. 
9 | #' 10 | #' `demographic_parity()` is calculated as the difference between the largest 11 | #' and smallest value of [detection_prevalence()] across groups. 12 | #' 13 | #' Demographic parity is sometimes referred to as group fairness, 14 | #' disparate impact, or statistical parity. 15 | #' 16 | #' See the "Measuring Disparity" section for details on implementation. 17 | #' 18 | #' @param by The column identifier for the sensitive feature. This should be an 19 | #' unquoted column name referring to a column in the un-preprocessed data. 20 | #' 21 | #' @templateVar fn demographic_parity 22 | #' @templateVar internal_fn detection_prevalence 23 | #' @template return-fair 24 | #' @template event-fair 25 | #' @template examples-fair 26 | #' 27 | #' @family fairness metrics 28 | #' 29 | #' @references 30 | #' 31 | #' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018). 32 | #' "A Reductions Approach to Fair Classification." Proceedings of the 35th 33 | #' International Conference on Machine Learning, in Proceedings of Machine 34 | #' Learning Research. 80:60-69. 35 | #' 36 | #' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In 37 | #' Proceedings of the international workshop on software fairness (pp. 1-7). 38 | #' 39 | #' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, 40 | #' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". 41 | #' Microsoft, Tech. Rep. MSR-TR-2020-32. 42 | #' 43 | #' @export 44 | demographic_parity <- 45 | new_groupwise_metric( 46 | fn = detection_prevalence, 47 | name = "demographic_parity", 48 | aggregate = diff_range 49 | ) 50 | -------------------------------------------------------------------------------- /R/fair-equal_opportunity.R: -------------------------------------------------------------------------------- 1 | #' Equal opportunity 2 | #' 3 | #' @description 4 | #' 5 | #' Equal opportunity is satisfied when a model's predictions have the same 6 | #' true positive and false negative rates across protected groups. A value of 7 | #' 0 indicates parity across groups. 8 | #' 9 | #' `equal_opportunity()` is calculated as the difference between the largest 10 | #' and smallest value of [sens()] across groups. 11 | #' 12 | #' Equal opportunity is sometimes referred to as equality of opportunity. 13 | #' 14 | #' See the "Measuring Disparity" section for details on implementation. 15 | #' 16 | #' @inheritParams demographic_parity 17 | #' 18 | #' @templateVar fn equal_opportunity 19 | #' @templateVar internal_fn sens 20 | #' @template return-fair 21 | #' @template event-fair 22 | #' @template examples-fair 23 | #' 24 | #' @family fairness metrics 25 | #' 26 | #' @references 27 | #' 28 | #' Hardt, M., Price, E., & Srebro, N. (2016). "Equality of opportunity in 29 | #' supervised learning". Advances in neural information processing systems, 29. 30 | #' 31 | #' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In 32 | #' Proceedings of the international workshop on software fairness (pp. 1-7). 33 | #' 34 | #' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, 35 | #' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". 36 | #' Microsoft, Tech. Rep. MSR-TR-2020-32. 
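# Editor's sketch of the definition above: equal_opportunity() is the range of
# sens() across groups, so the equivalent manual computation is group-wise
# sens() followed by diff(range(...)). `Resample` from the hpc_cv example data
# stands in for a sensitive feature purely for illustration.
library(dplyr)
library(yardstick)
data("hpc_cv")

by_group <- hpc_cv |>
  group_by(Resample) |>
  sens(obs, pred)

diff(range(by_group$.estimate))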
37 | #' 38 | #' @export 39 | equal_opportunity <- 40 | new_groupwise_metric( 41 | fn = sens, 42 | name = "equal_opportunity", 43 | aggregate = diff_range 44 | ) 45 | -------------------------------------------------------------------------------- /R/fair-equalized_odds.R: -------------------------------------------------------------------------------- 1 | max_positive_rate_diff <- function(x) { 2 | metric_values <- vec_split(x, x$.metric) 3 | 4 | positive_rate_diff <- vapply(metric_values$val, diff_range, numeric(1)) 5 | 6 | max(positive_rate_diff) 7 | } 8 | 9 | #' Equalized odds 10 | #' 11 | #' @description 12 | #' 13 | #' Equalized odds is satisfied when a model's predictions have the same false 14 | #' positive, true positive, false negative, and true negative rates across 15 | #' protected groups. A value of 0 indicates parity across groups. 16 | #' 17 | #' By default, this function takes the maximum difference in range of [sens()] 18 | #' and [spec()] `.estimate`s across groups. That is, the maximum pair-wise 19 | #' disparity in [sens()] or [spec()] between groups is the return value of 20 | #' `equalized_odds()`'s `.estimate`. 21 | #' 22 | #' Equalized odds is sometimes referred to as conditional procedure accuracy 23 | #' equality or disparate mistreatment. 24 | #' 25 | #' See the "Measuring disparity" section for details on implementation. 26 | #' 27 | #' @inheritParams demographic_parity 28 | #' 29 | #' @templateVar fn equalized_odds 30 | #' @templateVar internal_fn [sens()] and [spec()] 31 | #' @template return-fair 32 | #' @template examples-fair 33 | #' 34 | #' @section Measuring Disparity: 35 | #' For finer control of group treatment, construct a context-aware fairness 36 | #' metric with the [new_groupwise_metric()] function by passing a custom `aggregate` 37 | #' function: 38 | #' 39 | #' ``` 40 | #' # see yardstick:::max_positive_rate_diff for the actual `aggregate()` 41 | #' diff_range <- function(x, ...) {diff(range(x$.estimate))} 42 | #' 43 | #' equalized_odds_2 <- 44 | #' new_groupwise_metric( 45 | #' fn = metric_set(sens, spec), 46 | #' name = "equalized_odds_2", 47 | #' aggregate = diff_range 48 | #' ) 49 | #' ``` 50 | #' 51 | #' In `aggregate()`, `x` is the [metric_set()] output with [sens()] and [spec()] 52 | #' values for each group, and `...` gives additional arguments (such as a grouping 53 | #' level to refer to as the "baseline") to pass to the function outputted 54 | #' by `equalized_odds_2()` for context. 55 | #' 56 | #' @family fairness metrics 57 | #' 58 | #' @references 59 | #' 60 | #' Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., & Wallach, H. (2018). 61 | #' "A Reductions Approach to Fair Classification." Proceedings of the 35th 62 | #' International Conference on Machine Learning, in Proceedings of Machine 63 | #' Learning Research. 80:60-69. 64 | #' 65 | #' Verma, S., & Rubin, J. (2018). "Fairness definitions explained". In 66 | #' Proceedings of the international workshop on software fairness (pp. 1-7). 67 | #' 68 | #' Bird, S., Dudík, M., Edgar, R., Horn, B., Lutz, R., Milan, V., ... & Walker, 69 | #' K. (2020). "Fairlearn: A toolkit for assessing and improving fairness in AI". 70 | #' Microsoft, Tech. Rep. MSR-TR-2020-32. 
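# Editor's sketch of max_positive_rate_diff() above: compute sens() and spec()
# within each group, take each metric's range across groups, then the maximum
# of those two ranges. `Resample` is again only a stand-in grouping variable.
library(dplyr)
library(yardstick)
data("hpc_cv")

ms <- metric_set(sens, spec)

hpc_cv |>
  group_by(Resample) |>
  ms(truth = obs, estimate = pred) |>
  group_by(.metric) |>
  summarise(spread = diff(range(.estimate)), .groups = "drop") |>
  summarise(.estimate = max(spread))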
71 | #' 72 | #' @export 73 | equalized_odds <- 74 | new_groupwise_metric( 75 | fn = metric_set(sens, spec), 76 | name = "equalized_odds", 77 | aggregate = max_positive_rate_diff 78 | ) 79 | -------------------------------------------------------------------------------- /R/missings.R: -------------------------------------------------------------------------------- 1 | #' Developer function for handling missing values in new metrics 2 | #' 3 | #' `yardstick_remove_missing()` and `yardstick_any_missing()` are useful 4 | #' alongside the [metric-summarizers] functions for implementing new custom 5 | #' metrics. `yardstick_remove_missing()` removes any observations that contain 6 | #' missing values across `truth`, `estimate`, and `case_weights`. 7 | #' `yardstick_any_missing()` returns `TRUE` if there are any missing values in 8 | #' the inputs. 9 | #' 10 | #' @param truth,estimate Vectors of the same length. 11 | #' 12 | #' @param case_weights A vector of the same length as `truth` and `estimate`, or 13 | #' `NULL` if case weights are not being used. 14 | #' 15 | #' @seealso [metric-summarizers] 16 | #' 17 | #' @name yardstick_remove_missing 18 | NULL 19 | 20 | #' @rdname yardstick_remove_missing 21 | #' @export 22 | yardstick_remove_missing <- function(truth, estimate, case_weights) { 23 | complete_cases <- stats::complete.cases( 24 | truth, 25 | estimate, 26 | case_weights 27 | ) 28 | 29 | if (.is_surv(truth, fail = FALSE)) { 30 | Surv_type <- .extract_surv_type(truth) 31 | 32 | truth <- truth[complete_cases, ] 33 | 34 | attr(truth, "type") <- Surv_type 35 | attr(truth, "class") <- "Surv" 36 | } else { 37 | truth <- truth[complete_cases] 38 | } 39 | 40 | if (is.matrix(estimate)) { 41 | estimate <- estimate[complete_cases, , drop = FALSE] 42 | } else { 43 | estimate <- estimate[complete_cases] 44 | } 45 | 46 | case_weights <- case_weights[complete_cases] 47 | 48 | list( 49 | truth = truth, 50 | estimate = estimate, 51 | case_weights = case_weights 52 | ) 53 | } 54 | 55 | #' @rdname yardstick_remove_missing 56 | #' @export 57 | yardstick_any_missing <- function(truth, estimate, case_weights) { 58 | anyNA(truth) || 59 | anyNA(estimate) || 60 | (!is.null(case_weights) && anyNA(case_weights)) 61 | } 62 | -------------------------------------------------------------------------------- /R/num-huber_loss.R: -------------------------------------------------------------------------------- 1 | #' Huber loss 2 | #' 3 | #' Calculate the Huber loss, a loss function used in robust regression. This 4 | #' loss function is less sensitive to outliers than [rmse()]. This function is 5 | #' quadratic for small residual values and linear for large residual values. 6 | #' 7 | #' @family numeric metrics 8 | #' @family accuracy metrics 9 | #' @templateVar fn huber_loss 10 | #' @template return 11 | #' 12 | #' @inheritParams rmse 13 | #' 14 | #' @param delta A single `numeric` value. Defines the boundary where the loss function 15 | #' transitions from quadratic to linear. Defaults to 1. 16 | #' 17 | #' @author James Blair 18 | #' 19 | #' @references 20 | #' 21 | #' Huber, P. (1964). Robust Estimation of a Location Parameter. 22 | #' _Annals of Statistics_, 53 (1), 73-101. 23 | #' 24 | #' @template examples-numeric 25 | #' 26 | #' @export 27 | huber_loss <- function(data, ...)
{ 28 | UseMethod("huber_loss") 29 | } 30 | huber_loss <- new_numeric_metric( 31 | huber_loss, 32 | direction = "minimize" 33 | ) 34 | 35 | #' @rdname huber_loss 36 | #' @export 37 | huber_loss.data.frame <- function( 38 | data, 39 | truth, 40 | estimate, 41 | delta = 1, 42 | na_rm = TRUE, 43 | case_weights = NULL, 44 | ... 45 | ) { 46 | numeric_metric_summarizer( 47 | name = "huber_loss", 48 | fn = huber_loss_vec, 49 | data = data, 50 | truth = !!enquo(truth), 51 | estimate = !!enquo(estimate), 52 | na_rm = na_rm, 53 | case_weights = !!enquo(case_weights), 54 | # Extra argument for huber_loss_impl() 55 | fn_options = list(delta = delta) 56 | ) 57 | } 58 | 59 | #' @export 60 | #' @rdname huber_loss 61 | huber_loss_vec <- function( 62 | truth, 63 | estimate, 64 | delta = 1, 65 | na_rm = TRUE, 66 | case_weights = NULL, 67 | ... 68 | ) { 69 | check_numeric_metric(truth, estimate, case_weights) 70 | 71 | if (na_rm) { 72 | result <- yardstick_remove_missing(truth, estimate, case_weights) 73 | 74 | truth <- result$truth 75 | estimate <- result$estimate 76 | case_weights <- result$case_weights 77 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 78 | return(NA_real_) 79 | } 80 | 81 | huber_loss_impl(truth, estimate, delta, case_weights) 82 | } 83 | 84 | huber_loss_impl <- function( 85 | truth, 86 | estimate, 87 | delta, 88 | case_weights, 89 | call = caller_env() 90 | ) { 91 | # Weighted Huber Loss implementation confirmed against matlab: 92 | # https://www.mathworks.com/help/deeplearning/ref/dlarray.huber.html 93 | 94 | check_number_decimal(delta, min = 0, call = call) 95 | 96 | a <- truth - estimate 97 | abs_a <- abs(a) 98 | 99 | loss <- ifelse( 100 | abs_a <= delta, 101 | 0.5 * a^2, 102 | delta * (abs_a - 0.5 * delta) 103 | ) 104 | 105 | yardstick_mean(loss, case_weights = case_weights) 106 | } 107 | -------------------------------------------------------------------------------- /R/num-mae.R: -------------------------------------------------------------------------------- 1 | #' Mean absolute error 2 | #' 3 | #' Calculate the mean absolute error. This metric is in the same units as the 4 | #' original data. 5 | #' 6 | #' @family numeric metrics 7 | #' @family accuracy metrics 8 | #' @templateVar fn mae 9 | #' @template return 10 | #' 11 | #' @inheritParams rmse 12 | #' 13 | #' @author Max Kuhn 14 | #' 15 | #' @template examples-numeric 16 | #' 17 | #' @export 18 | mae <- function(data, ...) { 19 | UseMethod("mae") 20 | } 21 | mae <- new_numeric_metric( 22 | mae, 23 | direction = "minimize" 24 | ) 25 | 26 | #' @rdname mae 27 | #' @export 28 | mae.data.frame <- function( 29 | data, 30 | truth, 31 | estimate, 32 | na_rm = TRUE, 33 | case_weights = NULL, 34 | ... 35 | ) { 36 | numeric_metric_summarizer( 37 | name = "mae", 38 | fn = mae_vec, 39 | data = data, 40 | truth = !!enquo(truth), 41 | estimate = !!enquo(estimate), 42 | na_rm = na_rm, 43 | case_weights = !!enquo(case_weights) 44 | ) 45 | } 46 | 47 | #' @export 48 | #' @rdname mae 49 | mae_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) 
{ 50 | check_numeric_metric(truth, estimate, case_weights) 51 | 52 | if (na_rm) { 53 | result <- yardstick_remove_missing(truth, estimate, case_weights) 54 | 55 | truth <- result$truth 56 | estimate <- result$estimate 57 | case_weights <- result$case_weights 58 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 59 | return(NA_real_) 60 | } 61 | 62 | mae_impl(truth, estimate, case_weights) 63 | } 64 | 65 | mae_impl <- function(truth, estimate, case_weights) { 66 | errors <- abs(truth - estimate) 67 | yardstick_mean(errors, case_weights = case_weights) 68 | } 69 | -------------------------------------------------------------------------------- /R/num-mape.R: -------------------------------------------------------------------------------- 1 | #' Mean absolute percent error 2 | #' 3 | #' Calculate the mean absolute percentage error. This metric is in _relative 4 | #' units_. 5 | #' 6 | #' Note that a value of `Inf` is returned for `mape()` when the 7 | #' observed value is zero. 8 | #' 9 | #' @family numeric metrics 10 | #' @family accuracy metrics 11 | #' @templateVar fn mape 12 | #' @template return 13 | #' 14 | #' @inheritParams rmse 15 | #' 16 | #' @author Max Kuhn 17 | #' 18 | #' @template examples-numeric 19 | #' 20 | #' @export 21 | #' 22 | mape <- function(data, ...) { 23 | UseMethod("mape") 24 | } 25 | mape <- new_numeric_metric( 26 | mape, 27 | direction = "minimize" 28 | ) 29 | 30 | #' @rdname mape 31 | #' @export 32 | mape.data.frame <- function( 33 | data, 34 | truth, 35 | estimate, 36 | na_rm = TRUE, 37 | case_weights = NULL, 38 | ... 39 | ) { 40 | numeric_metric_summarizer( 41 | name = "mape", 42 | fn = mape_vec, 43 | data = data, 44 | truth = !!enquo(truth), 45 | estimate = !!enquo(estimate), 46 | na_rm = na_rm, 47 | case_weights = !!enquo(case_weights) 48 | ) 49 | } 50 | 51 | #' @export 52 | #' @rdname mape 53 | mape_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) { 54 | check_numeric_metric(truth, estimate, case_weights) 55 | 56 | if (na_rm) { 57 | result <- yardstick_remove_missing(truth, estimate, case_weights) 58 | 59 | truth <- result$truth 60 | estimate <- result$estimate 61 | case_weights <- result$case_weights 62 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 63 | return(NA_real_) 64 | } 65 | 66 | mape_impl(truth, estimate, case_weights) 67 | } 68 | 69 | mape_impl <- function(truth, estimate, case_weights) { 70 | errors <- abs((truth - estimate) / truth) 71 | out <- yardstick_mean(errors, case_weights = case_weights) 72 | out <- out * 100 73 | out 74 | } 75 | -------------------------------------------------------------------------------- /R/num-msd.R: -------------------------------------------------------------------------------- 1 | #' Mean signed deviation 2 | #' 3 | #' @description 4 | #' Mean signed deviation (also known as mean signed difference, or mean signed 5 | #' error) computes the average differences between `truth` and `estimate`. A 6 | #' related metric is the mean absolute error ([mae()]). 7 | #' 8 | #' @details 9 | #' Mean signed deviation is rarely used, since positive and negative errors 10 | #' cancel each other out. For example, `msd_vec(c(100, -100), c(0, 0))` would 11 | #' return a seemingly "perfect" value of `0`, even though `estimate` is wildly 12 | #' different from `truth`. [mae()] attempts to remedy this by taking the 13 | #' absolute value of the differences before computing the mean.
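# Editor's sketch of the cancellation described above, using the example from
# the paragraph itself:
library(yardstick)

msd_vec(c(100, -100), c(0, 0)) # 0: the +100 and -100 errors cancel exactly
mae_vec(c(100, -100), c(0, 0)) # 100: absolute errors do not cancel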
14 | #' 15 | #' This metric is computed as `mean(truth - estimate)`, following the convention 16 | #' that an "error" is computed as `observed - predicted`. If you expected this 17 | #' metric to be computed as `mean(estimate - truth)`, reverse the sign of the 18 | #' result. 19 | #' 20 | #' @family numeric metrics 21 | #' @family accuracy metrics 22 | #' @templateVar fn msd 23 | #' @template return 24 | #' 25 | #' @inheritParams rmse 26 | #' 27 | #' @author Thomas Bierhance 28 | #' 29 | #' @template examples-numeric 30 | #' 31 | #' @export 32 | msd <- function(data, ...) { 33 | UseMethod("msd") 34 | } 35 | msd <- new_numeric_metric( 36 | msd, 37 | direction = "zero" 38 | ) 39 | 40 | #' @rdname msd 41 | #' @export 42 | msd.data.frame <- function( 43 | data, 44 | truth, 45 | estimate, 46 | na_rm = TRUE, 47 | case_weights = NULL, 48 | ... 49 | ) { 50 | numeric_metric_summarizer( 51 | name = "msd", 52 | fn = msd_vec, 53 | data = data, 54 | truth = !!enquo(truth), 55 | estimate = !!enquo(estimate), 56 | na_rm = na_rm, 57 | case_weights = !!enquo(case_weights) 58 | ) 59 | } 60 | 61 | #' @export 62 | #' @rdname msd 63 | msd_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) { 64 | check_numeric_metric(truth, estimate, case_weights) 65 | 66 | if (na_rm) { 67 | result <- yardstick_remove_missing(truth, estimate, case_weights) 68 | 69 | truth <- result$truth 70 | estimate <- result$estimate 71 | case_weights <- result$case_weights 72 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 73 | return(NA_real_) 74 | } 75 | 76 | msd_impl(truth, estimate, case_weights) 77 | } 78 | 79 | msd_impl <- function(truth, estimate, case_weights) { 80 | yardstick_mean(truth - estimate, case_weights = case_weights) 81 | } 82 | -------------------------------------------------------------------------------- /R/num-poisson_log_loss.R: -------------------------------------------------------------------------------- 1 | #' Mean log loss for Poisson data 2 | #' 3 | #' Calculate the loss function for the Poisson distribution. 4 | #' 5 | #' @family numeric metrics 6 | #' @family accuracy metrics 7 | #' @templateVar fn poisson_log_loss 8 | #' @template return 9 | #' 10 | #' @inheritParams rmse 11 | #' 12 | #' @param truth The column identifier for the true counts (that is `integer`). 13 | #' This should be an unquoted column name although this argument is passed by 14 | #' expression and supports [quasiquotation][rlang::quasiquotation] (you can 15 | #' unquote column names). For `_vec()` functions, an `integer` vector. 16 | #' 17 | #' @author Max Kuhn 18 | #' 19 | #' @template examples-counts 20 | #' 21 | #' @export 22 | #' 23 | poisson_log_loss <- function(data, ...) { 24 | UseMethod("poisson_log_loss") 25 | } 26 | poisson_log_loss <- new_numeric_metric( 27 | poisson_log_loss, 28 | direction = "minimize" 29 | ) 30 | 31 | #' @rdname poisson_log_loss 32 | #' @export 33 | poisson_log_loss.data.frame <- function( 34 | data, 35 | truth, 36 | estimate, 37 | na_rm = TRUE, 38 | case_weights = NULL, 39 | ... 40 | ) { 41 | numeric_metric_summarizer( 42 | name = "poisson_log_loss", 43 | fn = poisson_log_loss_vec, 44 | data = data, 45 | truth = !!enquo(truth), 46 | estimate = !!enquo(estimate), 47 | na_rm = na_rm, 48 | case_weights = !!enquo(case_weights) 49 | ) 50 | } 51 | 52 | #' @export 53 | #' @rdname poisson_log_loss 54 | poisson_log_loss_vec <- function( 55 | truth, 56 | estimate, 57 | na_rm = TRUE, 58 | case_weights = NULL, 59 | ... 
60 | ) { 61 | check_numeric_metric(truth, estimate, case_weights) 62 | 63 | if (na_rm) { 64 | result <- yardstick_remove_missing(truth, estimate, case_weights) 65 | 66 | truth <- result$truth 67 | estimate <- result$estimate 68 | case_weights <- result$case_weights 69 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 70 | return(NA_real_) 71 | } 72 | 73 | poisson_log_loss_impl(truth, estimate, case_weights) 74 | } 75 | 76 | poisson_log_loss_impl <- function(truth, estimate, case_weights) { 77 | if (!is.integer(truth)) { 78 | truth <- as.integer(truth) 79 | } 80 | eps <- 1e-15 81 | estimate <- pmax(estimate, eps) 82 | loss <- log(gamma(truth + 1)) + estimate - log(estimate) * truth 83 | 84 | yardstick_mean(loss, case_weights = case_weights) 85 | } 86 | -------------------------------------------------------------------------------- /R/num-pseudo_huber_loss.R: -------------------------------------------------------------------------------- 1 | #' Pseudo-Huber Loss 2 | #' 3 | #' Calculate the Pseudo-Huber Loss, a smooth approximation of [huber_loss()]. 4 | #' Like [huber_loss()], this is less sensitive to outliers than [rmse()]. 5 | #' 6 | #' @family numeric metrics 7 | #' @family accuracy metrics 8 | #' @templateVar fn huber_loss_pseudo 9 | #' @template return 10 | #' 11 | #' @inheritParams huber_loss 12 | #' 13 | #' @author James Blair 14 | #' 15 | #' @references 16 | #' 17 | #' Huber, P. (1964). Robust Estimation of a Location Parameter. 18 | #' _Annals of Statistics_, 53 (1), 73-101. 19 | #' 20 | #' Hartley, Richard (2004). Multiple View Geometry in Computer Vision. 21 | #' (Second Edition). Page 619. 22 | #' 23 | #' @template examples-numeric 24 | #' 25 | #' @export 26 | huber_loss_pseudo <- function(data, ...) { 27 | UseMethod("huber_loss_pseudo") 28 | } 29 | huber_loss_pseudo <- new_numeric_metric( 30 | huber_loss_pseudo, 31 | direction = "minimize" 32 | ) 33 | 34 | #' @rdname huber_loss_pseudo 35 | #' @export 36 | huber_loss_pseudo.data.frame <- function( 37 | data, 38 | truth, 39 | estimate, 40 | delta = 1, 41 | na_rm = TRUE, 42 | case_weights = NULL, 43 | ... 44 | ) { 45 | numeric_metric_summarizer( 46 | name = "huber_loss_pseudo", 47 | fn = huber_loss_pseudo_vec, 48 | data = data, 49 | truth = !!enquo(truth), 50 | estimate = !!enquo(estimate), 51 | na_rm = na_rm, 52 | case_weights = !!enquo(case_weights), 53 | # Extra argument for huber_loss_pseudo_impl() 54 | fn_options = list(delta = delta) 55 | ) 56 | } 57 | 58 | #' @export 59 | #' @rdname huber_loss_pseudo 60 | huber_loss_pseudo_vec <- function( 61 | truth, 62 | estimate, 63 | delta = 1, 64 | na_rm = TRUE, 65 | case_weights = NULL, 66 | ...
67 | ) { 68 | check_numeric_metric(truth, estimate, case_weights) 69 | 70 | if (na_rm) { 71 | result <- yardstick_remove_missing(truth, estimate, case_weights) 72 | 73 | truth <- result$truth 74 | estimate <- result$estimate 75 | case_weights <- result$case_weights 76 | } else if (yardstick_any_missing(truth, estimate, case_weights)) { 77 | return(NA_real_) 78 | } 79 | 80 | huber_loss_pseudo_impl( 81 | truth = truth, 82 | estimate = estimate, 83 | delta = delta, 84 | case_weights = case_weights 85 | ) 86 | } 87 | 88 | huber_loss_pseudo_impl <- function( 89 | truth, 90 | estimate, 91 | delta, 92 | case_weights, 93 | call = caller_env() 94 | ) { 95 | check_number_decimal(delta, min = 0, call = call) 96 | 97 | a <- truth - estimate 98 | loss <- delta^2 * (sqrt(1 + (a / delta)^2) - 1) 99 | 100 | yardstick_mean(loss, case_weights = case_weights) 101 | } 102 | -------------------------------------------------------------------------------- /R/num-rmse.R: -------------------------------------------------------------------------------- 1 | #' Root mean squared error 2 | #' 3 | #' Calculate the root mean squared error. `rmse()` is a metric that is in 4 | #' the same units as the original data. 5 | #' 6 | #' @family numeric metrics 7 | #' @family accuracy metrics 8 | #' @templateVar fn rmse 9 | #' @template return 10 | #' 11 | #' @param data A `data.frame` containing the columns specified by the `truth` 12 | #' and `estimate` arguments. 13 | #' 14 | #' @param truth The column identifier for the true results 15 | #' (that is `numeric`). This should be an unquoted column name although 16 | #' this argument is passed by expression and supports 17 | #' [quasiquotation][rlang::quasiquotation] (you can unquote column 18 | #' names). For `_vec()` functions, a `numeric` vector. 19 | #' 20 | #' @param estimate The column identifier for the predicted 21 | #' results (that is also `numeric`). As with `truth` this can be 22 | #' specified different ways but the primary method is to use an 23 | #' unquoted variable name. For `_vec()` functions, a `numeric` vector. 24 | #' 25 | #' @param na_rm A `logical` value indicating whether `NA` 26 | #' values should be stripped before the computation proceeds. 27 | #' 28 | #' @param case_weights The optional column identifier for case weights. This 29 | #' should be an unquoted column name that evaluates to a numeric column in 30 | #' `data`. For `_vec()` functions, a numeric vector, 31 | #' [hardhat::importance_weights()], or [hardhat::frequency_weights()]. 32 | #' 33 | #' @param ... Not currently used. 34 | #' 35 | #' @author Max Kuhn 36 | #' 37 | #' @template examples-numeric 38 | #' 39 | #' @export 40 | #' 41 | rmse <- function(data, ...) { 42 | UseMethod("rmse") 43 | } 44 | rmse <- new_numeric_metric( 45 | rmse, 46 | direction = "minimize" 47 | ) 48 | 49 | #' @rdname rmse 50 | #' @export 51 | rmse.data.frame <- function( 52 | data, 53 | truth, 54 | estimate, 55 | na_rm = TRUE, 56 | case_weights = NULL, 57 | ... 58 | ) { 59 | numeric_metric_summarizer( 60 | name = "rmse", 61 | fn = rmse_vec, 62 | data = data, 63 | truth = !!enquo(truth), 64 | estimate = !!enquo(estimate), 65 | na_rm = na_rm, 66 | case_weights = !!enquo(case_weights) 67 | ) 68 | } 69 | 70 | #' @export 71 | #' @rdname rmse 72 | rmse_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) 
{
73 |   check_numeric_metric(truth, estimate, case_weights)
74 | 
75 |   if (na_rm) {
76 |     result <- yardstick_remove_missing(truth, estimate, case_weights)
77 | 
78 |     truth <- result$truth
79 |     estimate <- result$estimate
80 |     case_weights <- result$case_weights
81 |   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
82 |     return(NA_real_)
83 |   }
84 | 
85 |   rmse_impl(truth, estimate, case_weights = case_weights)
86 | }
87 | 
88 | rmse_impl <- function(truth, estimate, case_weights) {
89 |   errors <- (truth - estimate)^2
90 |   sqrt(yardstick_mean(errors, case_weights = case_weights))
91 | }
92 | 
-------------------------------------------------------------------------------- /R/num-rpiq.R: --------------------------------------------------------------------------------
1 | #' Ratio of performance to inter-quartile
2 | #'
3 | #' These functions are appropriate for cases where the model outcome is a
4 | #' numeric. The ratio of performance to deviation
5 | #' ([rpd()]) and the ratio of performance to inter-quartile ([rpiq()])
6 | #' are both measures of consistency/correlation between observed
7 | #' and predicted values (and not of accuracy).
8 | #'
9 | #' @inherit rpd details
10 | #' @inherit rpd references
11 | #'
12 | #' @family numeric metrics
13 | #' @family consistency metrics
14 | #' @templateVar fn rpiq
15 | #' @template return
16 | #'
17 | #' @inheritParams rmse
18 | #'
19 | #' @author Pierre Roudier
20 | #'
21 | #' @seealso
22 | #'
23 | #' The closely related deviation metric: [rpd()]
24 | #'
25 | #' @template examples-numeric
26 | #'
27 | #' @export
28 | rpiq <- function(data, ...) {
29 |   UseMethod("rpiq")
30 | }
31 | rpiq <- new_numeric_metric(
32 |   rpiq,
33 |   direction = "maximize"
34 | )
35 | 
36 | #' @rdname rpiq
37 | #' @export
38 | rpiq.data.frame <- function(
39 |   data,
40 |   truth,
41 |   estimate,
42 |   na_rm = TRUE,
43 |   case_weights = NULL,
44 |   ...
45 | ) {
46 |   numeric_metric_summarizer(
47 |     name = "rpiq",
48 |     fn = rpiq_vec,
49 |     data = data,
50 |     truth = !!enquo(truth),
51 |     estimate = !!enquo(estimate),
52 |     na_rm = na_rm,
53 |     case_weights = !!enquo(case_weights)
54 |   )
55 | }
56 | 
57 | #' @export
58 | #' @rdname rpiq
59 | rpiq_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
60 |   check_numeric_metric(truth, estimate, case_weights)
61 | 
62 |   if (na_rm) {
63 |     result <- yardstick_remove_missing(truth, estimate, case_weights)
64 | 
65 |     truth <- result$truth
66 |     estimate <- result$estimate
67 |     case_weights <- result$case_weights
68 |   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
69 |     return(NA_real_)
70 |   }
71 | 
72 |   rpiq_impl(truth, estimate, case_weights)
73 | }
74 | 
75 | rpiq_impl <- function(truth, estimate, case_weights) {
76 |   quantiles <- yardstick_quantile(
77 |     x = truth,
78 |     probabilities = c(0.25, 0.75),
79 |     case_weights = case_weights
80 |   )
81 | 
82 |   iqr <- quantiles[[2L]] - quantiles[[1L]]
83 |   rmse <- rmse_vec(truth, estimate, case_weights = case_weights)
84 | 
85 |   iqr / rmse
86 | }
87 | 
-------------------------------------------------------------------------------- /R/num-smape.R: --------------------------------------------------------------------------------
1 | #' Symmetric mean absolute percentage error
2 | #'
3 | #' Calculate the symmetric mean absolute percentage error. This metric is in
4 | #' _relative units_.
5 | #'
6 | #' This implementation of `smape()` is the "usual definition" where the
7 | #' denominator is divided by two.
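#'
#' In equation form (an illustrative sketch, ignoring case weights, of what
#' the internal `smape_impl()` below computes):
#' \deqn{SMAPE = \frac{100}{n} \sum_{i=1}^{n} \frac{|truth_i - estimate_i|}{(|truth_i| + |estimate_i|) / 2}}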
8 | #'
9 | #' @family numeric metrics
10 | #' @family accuracy metrics
11 | #' @templateVar fn smape
12 | #' @template return
13 | #'
14 | #' @inheritParams rmse
15 | #'
16 | #' @author Max Kuhn, Riaz Hedayati
17 | #'
18 | #' @template examples-numeric
19 | #'
20 | #' @export
21 | #'
22 | smape <- function(data, ...) {
23 |   UseMethod("smape")
24 | }
25 | smape <- new_numeric_metric(
26 |   smape,
27 |   direction = "minimize"
28 | )
29 | 
30 | #' @rdname smape
31 | #' @export
32 | smape.data.frame <- function(
33 |   data,
34 |   truth,
35 |   estimate,
36 |   na_rm = TRUE,
37 |   case_weights = NULL,
38 |   ...
39 | ) {
40 |   numeric_metric_summarizer(
41 |     name = "smape",
42 |     fn = smape_vec,
43 |     data = data,
44 |     truth = !!enquo(truth),
45 |     estimate = !!enquo(estimate),
46 |     na_rm = na_rm,
47 |     case_weights = !!enquo(case_weights)
48 |   )
49 | }
50 | 
51 | #' @export
52 | #' @rdname smape
53 | smape_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
54 |   check_numeric_metric(truth, estimate, case_weights)
55 | 
56 |   if (na_rm) {
57 |     result <- yardstick_remove_missing(truth, estimate, case_weights)
58 | 
59 |     truth <- result$truth
60 |     estimate <- result$estimate
61 |     case_weights <- result$case_weights
62 |   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
63 |     return(NA_real_)
64 |   }
65 | 
66 |   smape_impl(truth, estimate, case_weights)
67 | }
68 | 
69 | smape_impl <- function(truth, estimate, case_weights) {
70 |   numer <- abs(estimate - truth)
71 |   denom <- (abs(truth) + abs(estimate)) / 2
72 |   error <- numer / denom
73 | 
74 |   out <- yardstick_mean(error, case_weights = case_weights)
75 |   out <- out * 100
76 | 
77 |   out
78 | }
79 | 
-------------------------------------------------------------------------------- /R/prob-helpers.R: --------------------------------------------------------------------------------
1 | # AUC helper -------------------------------------------------------------------
2 | 
3 | # AUC by trapezoidal rule:
4 | # https://en.wikipedia.org/wiki/Trapezoidal_rule
5 | # assumes x is a partition and that x & y are the same length
6 | auc <- function(x, y, na_rm = TRUE) {
7 |   if (na_rm) {
8 |     comp <- stats::complete.cases(x, y)
9 |     x <- x[comp]
10 |     y <- y[comp]
11 |   }
12 | 
13 |   if (is.unsorted(x, na.rm = TRUE, strictly = FALSE)) {
14 |     # should not be reachable
15 |     cli::cli_abort(
16 |       "{.arg x} must already be in weakly increasing order.",
17 |       .internal = TRUE
18 |     )
19 |   }
20 | 
21 |   # length x = length y
22 |   n <- length(x)
23 | 
24 |   # dx
25 |   dx <- x[-1] - x[-n]
26 | 
27 |   # mid height of y
28 |   height <- (y[-n] + y[-1]) / 2
29 | 
30 |   auc <- sum(height * dx)
31 | 
32 |   auc
33 | }
34 | 
35 | # One vs all helper ------------------------------------------------------------
36 | 
37 | one_vs_all_impl <- function(fn, truth, estimate, case_weights, call, ...) {
38 |   lvls <- levels(truth)
39 |   other <- "..other"
40 | 
41 |   metric_lst <- new_list(n = length(lvls))
42 | 
43 |   # one vs all
44 |   for (i in seq_along(lvls)) {
45 |     # Recode truth into 2 levels, relevant and other
46 |     # Pull out estimate prob column corresponding to relevant
47 |     # Pulls by order, so they have to be in the same order as the levels!
48 |     # (cannot pull by name because they aren't always the same name, i.e.
.pred_{level}) 49 | lvl <- lvls[i] 50 | 51 | truth_temp <- factor( 52 | x = ifelse(truth == lvl, lvl, other), 53 | levels = c(lvl, other) 54 | ) 55 | 56 | estimate_temp <- as.numeric(estimate[, i]) 57 | 58 | # `one_vs_all_impl()` always ignores the event level ordering when 59 | # computing each individual binary metric 60 | metric_lst[[i]] <- fn( 61 | truth_temp, 62 | estimate_temp, 63 | case_weights = case_weights, 64 | event_level = "first", 65 | ... 66 | ) 67 | } 68 | 69 | metric_lst 70 | } 71 | 72 | one_vs_all_with_level <- function( 73 | fn, 74 | truth, 75 | estimate, 76 | case_weights, 77 | call, 78 | ... 79 | ) { 80 | res <- one_vs_all_impl( 81 | fn = fn, 82 | truth = truth, 83 | estimate = estimate, 84 | case_weights = case_weights, 85 | call = call, 86 | ... 87 | ) 88 | 89 | lvls <- levels(truth) 90 | 91 | with_level <- function(df, lvl) { 92 | df$.level <- lvl 93 | dplyr::select(df, .level, tidyselect::everything()) 94 | } 95 | 96 | res <- mapply( 97 | with_level, 98 | df = res, 99 | lvl = lvls, 100 | SIMPLIFY = FALSE, 101 | USE.NAMES = FALSE 102 | ) 103 | 104 | dplyr::bind_rows(res) 105 | } 106 | -------------------------------------------------------------------------------- /R/reexports.R: -------------------------------------------------------------------------------- 1 | #' @importFrom generics tidy 2 | #' @export 3 | generics::tidy 4 | -------------------------------------------------------------------------------- /R/yardstick-package.R: -------------------------------------------------------------------------------- 1 | #' @keywords internal 2 | "_PACKAGE" 3 | 4 | ## usethis namespace: start 5 | #' @import rlang 6 | #' @import vctrs 7 | #' @importFrom dplyr as_tibble 8 | #' @importFrom lifecycle deprecated 9 | #' @useDynLib yardstick, .registration = TRUE 10 | ## usethis namespace: end 11 | NULL 12 | 13 | # Importing something from utils so we don't get dinged about having an 14 | # Import we don't use. We use `utils::globalVariables()` at a global scope, and 15 | # R CMD check doesn't detect that. Usually shows up as a NOTE on rhub's Linux 16 | # check machines. 17 | #' @importFrom utils globalVariables 18 | NULL 19 | -------------------------------------------------------------------------------- /air.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/air.toml -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## Submission 1.3.2 2 | 3 | We checked 41 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 0 packages 7 | 8 | ## Submission 1.3.1 9 | 10 | We checked 37 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 
11 | 
12 | * We saw 0 new problems
13 | * We failed to check 0 packages
14 | 
15 | ## Submission 1.3.0
16 | 
17 | Final additions to make survival metrics possible.
18 | 
19 | We checked 36 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package.
20 | 
21 | * We saw 0 new problems
22 | * We failed to check 0 packages
23 | 
24 | ## Submission 1.2.0
25 | 
26 | Major refactor done to allow new types of metrics to be used.
27 | 
28 | ### revdepcheck results
29 | 
30 | We checked 32 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package.
31 | 
32 | * We saw 1 new problem
33 | * We failed to check 0 packages
34 | 
35 | #### Problems
36 | 
37 | * sknifedatar
38 | The sknifedatar maintainers were contacted on February 6th with a PR to fix the problem: https://github.com/rafzamb/sknifedatar/pull/23.
39 | 
40 | ## Submission 1.1.0
41 | 
42 | This minor release of yardstick changes the maintainer from Davis to Emil.
43 | 
44 | ## Submission 1.0.0
45 | 
46 | This major release of yardstick adds case weight support throughout the package.
47 | 
48 | ## Submission 0.0.9
49 | 
50 | This release is mainly to avoid any issues from the upcoming dplyr 1.0.8
51 | release, but it also includes small bug fixes and a new metric.
52 | 
53 | ## Submission 0.0.8
54 | 
55 | This release contains a number of new metrics, features, and bug fixes. We
56 | have also re-licensed yardstick to MIT.
57 | 
58 | ## Submission 0.0.7
59 | 
60 | This release contains a number of small bug fixes, along with a new argument
61 | for class and class probability metrics, `event_level`, that replaces the
62 | now soft-deprecated global option, `yardstick.event_first`.
63 | 
-------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/auc_churn_res.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data-raw/dyn-surv-metrics/auc_churn_res.RData -------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/brier_churn_res.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data-raw/dyn-surv-metrics/brier_churn_res.RData -------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/collect_metrics.R: --------------------------------------------------------------------------------
1 | library(tidymodels)
2 | library(survival)
3 | 
4 | # ------------------------------------------------------------------------------
5 | 
6 | tidymodels_prefer()
7 | theme_set(theme_bw())
8 | options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
9 | 
10 | # ------------------------------------------------------------------------------
11 | 
12 | # Objects saved from riskRegression::Score() via browser()
13 | load("data-raw/dyn-surv-metrics/rr_churn_data.RData")
14 | load("data-raw/dyn-surv-metrics/brier_churn_res.RData")
15 | load("data-raw/dyn-surv-metrics/auc_churn_res.RData")
16 | 
17 | brier_churn_res |>
18 |   filter(grepl("churn", model)) |>
19 |   readr::write_rds("tests/testthat/data/brier_churn_res.rds")
20 | 
21 | auc_churn_res |>
22 |   readr::write_rds("tests/testthat/data/auc_churn_res.rds")
23 | 
24 | # ------------------------------------------------------------------------------
25 | 
26 | rr_churn_data <-
rr_churn_data[rr_churn_data$model == 1, ] 27 | rr_churn_data$surv <- Surv(rr_churn_data$time, rr_churn_data$status) 28 | rr_churn_data$surv_prob <- 1 - rr_churn_data$risk 29 | 30 | # From Brier.survival() 31 | # DT[time<=times & status==1,residuals:=(1-risk)^2/WTi] 32 | # DT[time<=times & status==0,residuals:=0] 33 | # DT[time>times,residuals:=(risk)^2/Wt] 34 | 35 | g_1 <- rr_churn_data$time <= rr_churn_data$times & rr_churn_data$status == 1 36 | g_2 <- rr_churn_data$time <= rr_churn_data$times & rr_churn_data$status == 0 37 | g_3 <- rr_churn_data$time > rr_churn_data$times 38 | 39 | rr_churn_data$ipcw <- NA_real_ 40 | rr_churn_data$ipcw[g_1] <- 1 / rr_churn_data$WTi[g_1] 41 | rr_churn_data$ipcw[g_2] <- 0 42 | rr_churn_data$ipcw[g_3] <- 1 / rr_churn_data$Wt[g_3] 43 | 44 | rr_churn_data |> 45 | readr::write_rds("tests/testthat/data/rr_churn_data.rds") 46 | 47 | tidy_churn <- readRDS(test_path("data/rr_churn_data.rds")) |> 48 | dplyr::rename( 49 | .eval_time = times, 50 | .pred_survival = surv_prob, 51 | .weight_censored = ipcw 52 | ) |> 53 | dplyr::mutate( 54 | .weight_censored = dplyr::if_else( 55 | status == 0 & time < .eval_time, 56 | NA, 57 | .weight_censored 58 | ) 59 | ) |> 60 | tidyr::nest(.pred = -c(ID, time, status, model)) |> 61 | dplyr::mutate(surv_obj = survival::Surv(time, status)) 62 | 63 | tidy_churn |> 64 | readr::write_rds("tests/testthat/data/tidy_churn.rds") 65 | -------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/generate_metrics.R: -------------------------------------------------------------------------------- 1 | library(survival) # survival_3.5-3 2 | library(riskRegression) # riskRegression_2023.03.10 3 | library(prodlim) # prodlim_2022.10.13 4 | library(modeldata) 5 | 6 | # ------------------------------------------------------------------------------ 7 | 8 | data(wa_churn) 9 | 10 | wa_churn <- 11 | wa_churn |> 12 | filter(!is.na(total_charges)) |> 13 | mutate( 14 | status = ifelse(churn == "No", 1, 0) 15 | ) |> 16 | select(tenure, status, female, total_charges) 17 | 18 | # ------------------------------------------------------------------------------ 19 | 20 | cox_fit <- coxph( 21 | Surv(tenure, status) ~ female + total_charges, 22 | data = wa_churn, 23 | y = TRUE, 24 | x = TRUE 25 | ) 26 | 27 | # ------------------------------------------------------------------------------ 28 | 29 | xs_auc <- Score( 30 | list("churn" = cox_fit), 31 | formula = Surv(tenure, status) ~ 1, 32 | data = wa_churn, 33 | conf.int = FALSE, 34 | times = c(1, 23, 70), 35 | metrics = "AUC", 36 | cens.method = "ipcw", 37 | cens.model = "km", 38 | seed = 1 39 | ) 40 | 41 | xs_brier <- Score( 42 | list("churn" = cox_fit), 43 | formula = Surv(tenure, status) ~ 1, 44 | data = wa_churn, 45 | conf.int = FALSE, 46 | times = c(1, 23, 70), 47 | metrics = "Brier", 48 | cens.method = "ipcw", 49 | cens.model = "km", 50 | seed = 1 51 | ) 52 | 53 | # ------------------------------------------------------------------------------ 54 | 55 | # after getPerformanceData() 56 | if (FALSE) { 57 | rr_churn_data <- as.data.frame(DT) 58 | save(rr_churn_data, file = "rr_churn_data.RData") 59 | } 60 | 61 | # after computePerformance() when metrics = "AUC" 62 | if (FALSE) { 63 | auc_churn_res <- as.data.frame(noSplit$AUC$score) 64 | save(auc_churn_res, file = "auc_churn_res.RData") 65 | } 66 | 67 | # after computePerformance() when metrics = "Brier" 68 | if (FALSE) { 69 | brier_churn_res <- as.data.frame(noSplit$Brier$score) 70 | save(brier_churn_res, file = 
"brier_churn_res.RData") 71 | } 72 | -------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/rr_churn_data.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data-raw/dyn-surv-metrics/rr_churn_data.RData -------------------------------------------------------------------------------- /data-raw/dyn-surv-metrics/survival-curve-reference.R: -------------------------------------------------------------------------------- 1 | tidy_churn <- readRDS(test_path("data/tidy_churn.rds")) 2 | 3 | ref_roc_auc_survival <- tidy_churn |> 4 | roc_auc_survival( 5 | truth = surv_obj, 6 | .pred 7 | ) 8 | 9 | yardstick_res <- ref_roc_auc_survival |> 10 | readr::write_rds("tests/testthat/data/ref_roc_auc_survival.rds") 11 | 12 | ref_roc_curve_survival <- tidy_churn |> 13 | roc_curve_survival( 14 | truth = surv_obj, 15 | .pred 16 | ) 17 | 18 | yardstick_res <- ref_roc_curve_survival |> 19 | readr::write_rds("tests/testthat/data/ref_roc_curve_survival.rds") 20 | -------------------------------------------------------------------------------- /data-raw/lung_surv.R: -------------------------------------------------------------------------------- 1 | library(tidymodels) 2 | library(censored) 3 | 4 | # ------------------------------------------------------------------------------ 5 | 6 | tidymodels_prefer() 7 | theme_set(theme_bw()) 8 | options(pillar.advice = FALSE, pillar.min_title_chars = Inf) 9 | 10 | # ------------------------------------------------------------------------------ 11 | 12 | lung_data <- 13 | survival::lung |> 14 | select(time, status, age, sex, ph.ecog) 15 | 16 | model_fit <- 17 | survival_reg() |> 18 | fit(Surv(time, status) ~ age + sex + ph.ecog, data = lung_data) 19 | 20 | # ------------------------------------------------------------------------------ 21 | 22 | pred_times <- (1:5) * 100 23 | 24 | # Data to compute metrics: 25 | lung_surv <- 26 | # Now dynamic predictions at 5 time points 27 | predict(model_fit, lung_data, type = "survival", eval_time = pred_times) |> 28 | bind_cols( 29 | # Static predictions 30 | predict(model_fit, lung_data, type = "time"), 31 | # We'll need the surv object 32 | lung_data |> transmute(surv_obj = Surv(time, status)) 33 | ) |> 34 | .censoring_weights_graf(model_fit, .) 
35 | 36 | usethis::use_data(lung_surv, overwrite = TRUE) 37 | -------------------------------------------------------------------------------- /data/datalist: -------------------------------------------------------------------------------- 1 | hpc_cv: hpc_cv 2 | pathology: pathology 3 | solubility_test: solubility_test 4 | two_class_example: two_class_example -------------------------------------------------------------------------------- /data/hpc_cv.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data/hpc_cv.rda -------------------------------------------------------------------------------- /data/lung_surv.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data/lung_surv.rda -------------------------------------------------------------------------------- /data/pathology.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data/pathology.rda -------------------------------------------------------------------------------- /data/solubility_test.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data/solubility_test.rda -------------------------------------------------------------------------------- /data/two_class_example.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/data/two_class_example.rda -------------------------------------------------------------------------------- /man-roxygen/event-fair.R: -------------------------------------------------------------------------------- 1 | #' @section Measuring Disparity: 2 | #' By default, this function takes the difference in range of <%=internal_fn %> 3 | #' `.estimate`s across groups. That is, the maximum pair-wise disparity between 4 | #' groups is the return value of `<%=fn %>()`'s `.estimate`. 5 | #' 6 | #' For finer control of group treatment, construct a context-aware fairness 7 | #' metric with the [new_groupwise_metric()] function by passing a custom `aggregate` 8 | #' function: 9 | #' 10 | #' ``` 11 | #' # the actual default `aggregate` is: 12 | #' diff_range <- function(x, ...) {diff(range(x$.estimate))} 13 | #' 14 | #' <%=fn %>_2 <- 15 | #' new_groupwise_metric( 16 | #' fn = <%=internal_fn %>, 17 | #' name = "<%=fn %>_2", 18 | #' aggregate = diff_range 19 | #' ) 20 | #' ``` 21 | #' 22 | #' In `aggregate()`, `x` is the `metric_set()` output with <%=internal_fn %> values 23 | #' for each group, and `...` gives additional arguments (such as a grouping 24 | #' level to refer to as the "baseline") to pass to the function outputted 25 | #' by `<%=fn %>_2()` for context. 26 | -------------------------------------------------------------------------------- /man-roxygen/event_first.R: -------------------------------------------------------------------------------- 1 | #' @section Relevant Level: 2 | #' 3 | #' There is no common convention on which factor level should 4 | #' automatically be considered the "event" or "positive" result 5 | #' when computing binary classification metrics. 
In `yardstick`, the default 6 | #' is to use the _first_ level. To alter this, change the argument 7 | #' `event_level` to `"second"` to consider the _last_ level of the factor the 8 | #' level of interest. For multiclass extensions involving one-vs-all 9 | #' comparisons (such as macro averaging), this option is ignored and 10 | #' the "one" level is always the relevant result. 11 | -------------------------------------------------------------------------------- /man-roxygen/examples-binary-prob.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' # --------------------------------------------------------------------------- 3 | #' # Two class example 4 | #' 5 | #' # `truth` is a 2 level factor. The first level is `"Class1"`, which is the 6 | #' # "event of interest" by default in yardstick. See the Relevant Level 7 | #' # section above. 8 | #' data(two_class_example) 9 | #' 10 | #' # Binary metrics using class probabilities take a factor `truth` column, 11 | #' # and a single class probability column containing the probabilities of 12 | #' # the event of interest. Here, since `"Class1"` is the first level of 13 | #' # `"truth"`, it is the event of interest and we pass in probabilities for it. 14 | #' <%=fn %>(two_class_example, truth, Class1) 15 | #' 16 | -------------------------------------------------------------------------------- /man-roxygen/examples-class.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' # Two class 3 | #' data("two_class_example") 4 | #' <%=fn %>(two_class_example, truth, predicted) 5 | #' 6 | #' # Multiclass 7 | #' library(dplyr) 8 | #' data(hpc_cv) 9 | #' 10 | #' hpc_cv |> 11 | #' filter(Resample == "Fold01") |> 12 | #' <%=fn %>(obs, pred) 13 | #' 14 | #' # Groups are respected 15 | #' hpc_cv |> 16 | #' group_by(Resample) |> 17 | #' <%=fn %>(obs, pred) 18 | #' 19 | #' # Weighted macro averaging 20 | #' hpc_cv |> 21 | #' group_by(Resample) |> 22 | #' <%=fn %>(obs, pred, estimator = "macro_weighted") 23 | #' 24 | #' # Vector version 25 | #' <%=fn %>_vec( 26 | #' two_class_example$truth, 27 | #' two_class_example$predicted 28 | #' ) 29 | #' 30 | #' # Making Class2 the "relevant" level 31 | #' <%=fn %>_vec( 32 | #' two_class_example$truth, 33 | #' two_class_example$predicted, 34 | #' event_level = "second" 35 | #' ) 36 | -------------------------------------------------------------------------------- /man-roxygen/examples-counts.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' count_truth <- c(2L, 7L, 1L, 1L, 0L, 3L) 3 | #' count_pred <- c(2.14, 5.35, 1.65, 1.56, 1.3, 2.71) 4 | #' count_results <- dplyr::tibble(count = count_truth, pred = count_pred) 5 | #' 6 | #' # Supply truth and predictions as bare column names 7 | #' <%=fn %>(count_results, count, pred) 8 | #' 9 | 10 | -------------------------------------------------------------------------------- /man-roxygen/examples-fair.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' library(dplyr) 3 | #' 4 | #' data(hpc_cv) 5 | #' 6 | #' head(hpc_cv) 7 | #' 8 | #' # evaluate `<%=fn %>()` by Resample 9 | #' m_set <- metric_set(<%=fn %>(Resample)) 10 | #' 11 | #' # use output like any other metric set 12 | #' hpc_cv |> 13 | #' m_set(truth = obs, estimate = pred) 14 | #' 15 | #' # can mix fairness metrics and regular metrics 16 | #' m_set_2 <- metric_set(sens, <%=fn %>(Resample)) 17 | #' 18 | #' 
hpc_cv |> 19 | #' m_set_2(truth = obs, estimate = pred) 20 | -------------------------------------------------------------------------------- /man-roxygen/examples-multiclass-prob.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' # --------------------------------------------------------------------------- 3 | #' # Multiclass example 4 | #' 5 | #' # `obs` is a 4 level factor. The first level is `"VF"`, which is the 6 | #' # "event of interest" by default in yardstick. See the Relevant Level 7 | #' # section above. 8 | #' data(hpc_cv) 9 | #' 10 | #' # You can use the col1:colN tidyselect syntax 11 | #' library(dplyr) 12 | #' hpc_cv |> 13 | #' filter(Resample == "Fold01") |> 14 | #' <%=fn %>(obs, VF:L) 15 | #' 16 | #' # Change the first level of `obs` from `"VF"` to `"M"` to alter the 17 | #' # event of interest. The class probability columns should be supplied 18 | #' # in the same order as the levels. 19 | #' hpc_cv |> 20 | #' filter(Resample == "Fold01") |> 21 | #' mutate(obs = relevel(obs, "M")) |> 22 | #' <%=fn %>(obs, M, VF:L) 23 | #' 24 | #' # Groups are respected 25 | #' hpc_cv |> 26 | #' group_by(Resample) |> 27 | #' <%=fn %>(obs, VF:L) 28 | #' 29 | #' # Weighted macro averaging 30 | #' hpc_cv |> 31 | #' group_by(Resample) |> 32 | #' <%=fn %>(obs, VF:L, estimator = "macro_weighted") 33 | #' 34 | #' # Vector version 35 | #' # Supply a matrix of class probabilities 36 | #' fold1 <- hpc_cv |> 37 | #' filter(Resample == "Fold01") 38 | #' 39 | #' <%=fn %>_vec( 40 | #' truth = fold1$obs, 41 | #' matrix( 42 | #' c(fold1$VF, fold1$F, fold1$M, fold1$L), 43 | #' ncol = 4 44 | #' ) 45 | #' ) 46 | #' 47 | -------------------------------------------------------------------------------- /man-roxygen/examples-numeric.R: -------------------------------------------------------------------------------- 1 | #' @examples 2 | #' # Supply truth and predictions as bare column names 3 | #' <%=fn %>(solubility_test, solubility, prediction) 4 | #' 5 | #' library(dplyr) 6 | #' 7 | #' set.seed(1234) 8 | #' size <- 100 9 | #' times <- 10 10 | #' 11 | #' # create 10 resamples 12 | #' solubility_resampled <- bind_rows( 13 | #' replicate( 14 | #' n = times, 15 | #' expr = sample_n(solubility_test, size, replace = TRUE), 16 | #' simplify = FALSE 17 | #' ), 18 | #' .id = "resample" 19 | #' ) 20 | #' 21 | #' # Compute the metric by group 22 | #' metric_results <- solubility_resampled |> 23 | #' group_by(resample) |> 24 | #' <%=fn %>(solubility, prediction) 25 | #' 26 | #' metric_results 27 | #' 28 | #' # Resampled mean estimate 29 | #' metric_results |> 30 | #' summarise(avg_estimate = mean(.estimate)) 31 | -------------------------------------------------------------------------------- /man-roxygen/multiclass-curve.R: -------------------------------------------------------------------------------- 1 | #' @section Multiclass: 2 | #' 3 | #' If a multiclass `truth` column is provided, a one-vs-all 4 | #' approach will be taken to calculate multiple curves, one per level. 5 | #' In this case, there will be an additional column, `.level`, 6 | #' identifying the "one" column in the one-vs-all calculation. 7 | -------------------------------------------------------------------------------- /man-roxygen/multiclass-prob.R: -------------------------------------------------------------------------------- 1 | #' @section Multiclass: 2 | #' 3 | #' Macro and macro-weighted averaging is available for this metric. 
4 | #' The default is to select macro averaging if a `truth` factor with more
5 | #' than 2 levels is provided. Otherwise, a standard binary calculation is done.
6 | #' See `vignette("multiclass", "yardstick")` for more information.
7 | 
-------------------------------------------------------------------------------- /man-roxygen/multiclass.R: --------------------------------------------------------------------------------
1 | #' @section Multiclass:
2 | #'
3 | #' Macro, micro, and macro-weighted averaging is available for this metric.
4 | #' The default is to select macro averaging if a `truth` factor with more
5 | #' than 2 levels is provided. Otherwise, a standard binary calculation is done.
6 | #' See `vignette("multiclass", "yardstick")` for more information.
7 | 
-------------------------------------------------------------------------------- /man-roxygen/return-dynamic-survival.R: --------------------------------------------------------------------------------
1 | #' @return
2 | #'
3 | #' A `tibble` with columns `.metric`, `.estimator`, and `.estimate`.
4 | #'
5 | #' For an ungrouped data frame, the result has one row of values. For a grouped data frame,
6 | #' the number of rows returned is the same as the number of groups.
7 | #'
8 | #' For `<%=fn %>_vec()`, a `numeric` vector the same length as the input
9 | #' argument `eval_time` (or `NA`).
10 | 
-------------------------------------------------------------------------------- /man-roxygen/return-fair.R: --------------------------------------------------------------------------------
1 | #' @return
2 | #'
3 | #' This function outputs a yardstick _fairness metric_ function. Given a
4 | #' grouping variable `by`, `<%=fn %>()` will return a yardstick metric
5 | #' function that is associated with the data-variable grouping `by` and a
6 | #' post-processor. The outputted function will first generate a set
7 | #' of <%=internal_fn %> metric values by group before summarizing across
8 | #' groups using the post-processing function.
9 | #'
10 | #' The outputted function only has a data frame method and is intended to
11 | #' be used as part of a metric set.
12 | 
-------------------------------------------------------------------------------- /man-roxygen/return-prob.R: --------------------------------------------------------------------------------
1 | #' @return
2 | #'
3 | #' A `tibble` with columns `.metric`, `.estimator`,
4 | #' and `.estimate` and 1 row of values.
5 | #'
6 | #' For grouped data frames, the number of rows returned will be the same as
7 | #' the number of groups.
8 | 
9 | # these don't have _vec() methods
10 | 
-------------------------------------------------------------------------------- /man-roxygen/return.R: --------------------------------------------------------------------------------
1 | #' @return
2 | #'
3 | #' A `tibble` with columns `.metric`, `.estimator`,
4 | #' and `.estimate` and 1 row of values.
5 | #'
6 | #' For grouped data frames, the number of rows returned will be the same as
7 | #' the number of groups.
8 | #'
9 | #' For `<%=fn %>_vec()`, a single `numeric` value (or `NA`).
10 | -------------------------------------------------------------------------------- /man-roxygen/table-positive.R: -------------------------------------------------------------------------------- 1 | #' @section Implementation: 2 | #' 3 | #' Suppose a 2x2 table with notation: 4 | #' 5 | #' \tabular{rcc}{ \tab Reference \tab \cr Predicted \tab Positive \tab Negative 6 | #' \cr Positive \tab A \tab B \cr Negative \tab C \tab D \cr } 7 | #' 8 | #' The formulas used here are: 9 | #' 10 | #' \deqn{Sensitivity = A/(A+C)} 11 | #' \deqn{Specificity = D/(B+D)} 12 | #' \deqn{Prevalence = (A+C)/(A+B+C+D)} 13 | #' \deqn{PPV = (Sensitivity * Prevalence) / ((Sensitivity * Prevalence) + ((1-Specificity) * (1-Prevalence)))} 14 | #' \deqn{NPV = (Specificity * (1-Prevalence)) / (((1-Sensitivity) * Prevalence) + ((Specificity) * (1-Prevalence)))} 15 | #' 16 | #' See the references for discussions of the statistics. 17 | -------------------------------------------------------------------------------- /man-roxygen/table-relevance.R: -------------------------------------------------------------------------------- 1 | #' @section Implementation: 2 | #' 3 | #' Suppose a 2x2 table with notation: 4 | #' 5 | #' \tabular{rcc}{ \tab Reference \tab \cr Predicted \tab Relevant \tab 6 | #' Irrelevant \cr Relevant \tab A \tab B \cr Irrelevant \tab C \tab D \cr } 7 | #' 8 | #' The formulas used here are: 9 | #' 10 | #' \deqn{recall = A/(A+C)} 11 | #' \deqn{precision = A/(A+B)} 12 | #' \deqn{F_{meas} = (1+\beta^2) * precision * recall/((\beta^2 * precision)+recall)} 13 | #' 14 | #' See the references for discussions of the statistics. 15 | -------------------------------------------------------------------------------- /man/figures/README-roc-curves-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/man/figures/README-roc-curves-1.png -------------------------------------------------------------------------------- /man/figures/lifecycle-archived.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclearchivedarchived -------------------------------------------------------------------------------- /man/figures/lifecycle-defunct.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledefunctdefunct -------------------------------------------------------------------------------- /man/figures/lifecycle-deprecated.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledeprecateddeprecated -------------------------------------------------------------------------------- /man/figures/lifecycle-experimental.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycleexperimentalexperimental -------------------------------------------------------------------------------- /man/figures/lifecycle-maturing.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclematuringmaturing -------------------------------------------------------------------------------- /man/figures/lifecycle-questioning.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclequestioningquestioning -------------------------------------------------------------------------------- /man/figures/lifecycle-stable.svg: 
--------------------------------------------------------------------------------
1 | lifecyclelifecyclestablestable
-------------------------------------------------------------------------------- /man/figures/lifecycle-superseded.svg: --------------------------------------------------------------------------------
1 | lifecyclelifecyclesupersededsuperseded
-------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/man/figures/logo.png -------------------------------------------------------------------------------- /man/hpc_cv.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{hpc_cv}
5 | \alias{hpc_cv}
6 | \title{Multiclass Probability Predictions}
7 | \source{
8 | Kuhn, M., Johnson, K. (2013) \emph{Applied Predictive
9 | Modeling}, Springer
10 | }
11 | \value{
12 | \item{hpc_cv}{a data frame}
13 | }
14 | \description{
15 | Multiclass Probability Predictions
16 | }
17 | \details{
18 | This data frame contains the predicted classes and
19 | class probabilities for a linear discriminant analysis model fit
20 | to the HPC data set from Kuhn and Johnson (2013). These data are
21 | the assessment sets from a 10-fold cross-validation scheme. The
22 | data contains columns for the true class (\code{obs}), the class
23 | prediction (\code{pred}), and columns for each class probability
24 | (columns \code{VF}, \code{F}, \code{M}, and \code{L}). Additionally, a column for
25 | the resample indicator is included.
26 | }
27 | \examples{
28 | data(hpc_cv)
29 | str(hpc_cv)
30 | 
31 | # `obs` is a 4 level factor. The first level is `"VF"`, which is the
32 | # "event of interest" by default in yardstick. See the Relevant Level
33 | # section in any classification function (such as `?pr_auc`) to see how
34 | # to change this.
35 | levels(hpc_cv$obs)
36 | }
37 | \keyword{datasets}
38 | 
-------------------------------------------------------------------------------- /man/lung_surv.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{lung_surv}
5 | \alias{lung_surv}
6 | \title{Survival Analysis Results}
7 | \value{
8 | \item{lung_surv}{a data frame}
9 | }
10 | \description{
11 | Survival Analysis Results
12 | }
13 | \details{
14 | These data contain plausible results from applying predictive
15 | survival models to the \link[survival]{lung} data set using the censored
16 | package.
17 | }
18 | \examples{
19 | data(lung_surv)
20 | str(lung_surv)
21 | 
22 | # `surv_obj` is a `Surv()` object
23 | }
24 | \keyword{datasets}
25 | 
-------------------------------------------------------------------------------- /man/metric_tweak.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/metric-tweak.R
3 | \name{metric_tweak}
4 | \alias{metric_tweak}
5 | \title{Tweak a metric function}
6 | \usage{
7 | metric_tweak(.name, .fn, ...)
8 | }
9 | \arguments{
10 | \item{.name}{A single string giving the name of the new metric.
This will be 11 | used in the \code{".metric"} column of the output.} 12 | 13 | \item{.fn}{An existing yardstick metric function to tweak.} 14 | 15 | \item{...}{Name-value pairs specifying which optional arguments to override 16 | and the values to replace them with. 17 | 18 | Arguments \code{data}, \code{truth}, and \code{estimate} are considered \emph{protected}, 19 | and cannot be overridden, but all other optional arguments can be 20 | altered.} 21 | } 22 | \value{ 23 | A tweaked version of \code{.fn}, updated to use new defaults supplied in \code{...}. 24 | } 25 | \description{ 26 | \code{metric_tweak()} allows you to tweak an existing metric \code{.fn}, giving it a 27 | new \code{.name} and setting new optional argument defaults through \code{...}. It 28 | is similar to \code{purrr::partial()}, but is designed specifically for yardstick 29 | metrics. 30 | 31 | \code{metric_tweak()} is especially useful when constructing a \code{\link[=metric_set]{metric_set()}} for 32 | tuning with the tune package. After the metric set has been constructed, 33 | there is no way to adjust the value of any optional arguments (such as 34 | \code{beta} in \code{\link[=f_meas]{f_meas()}}). Using \code{metric_tweak()}, you can set optional arguments 35 | to custom values ahead of time, before they go into the metric set. 36 | } 37 | \details{ 38 | The function returned from \code{metric_tweak()} only takes \code{...} as arguments, 39 | which are passed through to the original \code{.fn}. Passing \code{data}, \code{truth}, 40 | and \code{estimate} through by position should generally be safe, but it is 41 | recommended to pass any other optional arguments through by name to ensure 42 | that they are evaluated correctly. 43 | } 44 | \examples{ 45 | mase12 <- metric_tweak("mase12", mase, m = 12) 46 | 47 | # Defaults to `m = 1` 48 | mase(solubility_test, solubility, prediction) 49 | 50 | # Updated to use `m = 12`. `mase12()` has this set already. 51 | mase(solubility_test, solubility, prediction, m = 12) 52 | mase12(solubility_test, solubility, prediction) 53 | 54 | # This is most useful to set optional argument values ahead of time when 55 | # using a metric set 56 | mase10 <- metric_tweak("mase10", mase, m = 10) 57 | metrics <- metric_set(mase, mase10, mase12) 58 | metrics(solubility_test, solubility, prediction) 59 | } 60 | -------------------------------------------------------------------------------- /man/new-metric.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/aaa-new.R 3 | \name{new-metric} 4 | \alias{new-metric} 5 | \alias{new_class_metric} 6 | \alias{new_prob_metric} 7 | \alias{new_ordered_prob_metric} 8 | \alias{new_numeric_metric} 9 | \alias{new_dynamic_survival_metric} 10 | \alias{new_integrated_survival_metric} 11 | \alias{new_static_survival_metric} 12 | \title{Construct a new metric function} 13 | \usage{ 14 | new_class_metric(fn, direction) 15 | 16 | new_prob_metric(fn, direction) 17 | 18 | new_ordered_prob_metric(fn, direction) 19 | 20 | new_numeric_metric(fn, direction) 21 | 22 | new_dynamic_survival_metric(fn, direction) 23 | 24 | new_integrated_survival_metric(fn, direction) 25 | 26 | new_static_survival_metric(fn, direction) 27 | } 28 | \arguments{ 29 | \item{fn}{A function. The metric function to attach a metric-specific class 30 | and \code{direction} attribute to.} 31 | 32 | \item{direction}{A string. 
One of:
33 | \itemize{
34 | \item \code{"maximize"}
35 | \item \code{"minimize"}
36 | \item \code{"zero"}
37 | }}
38 | }
39 | \description{
40 | These functions provide convenient wrappers to create the various types of
41 | metric functions in yardstick, such as numeric metrics, class metrics, and
42 | class probability metrics. They add a metric-specific class to \code{fn} and
43 | attach a \code{direction} attribute. These features are used by \code{\link[=metric_set]{metric_set()}}
44 | and by \href{https://tune.tidymodels.org/}{tune} when model tuning.
45 | 
46 | See \href{https://www.tidymodels.org/learn/develop/metrics/}{Custom performance metrics} for more
47 | information about creating custom metrics.
48 | }
49 | 
-------------------------------------------------------------------------------- /man/pathology.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{pathology}
5 | \alias{pathology}
6 | \title{Liver Pathology Data}
7 | \source{
8 | Altman, D.G., Bland, J.M. (1994) ``Diagnostic tests 1:
9 | sensitivity and specificity,'' \emph{British Medical Journal},
10 | vol 308, 1552.
11 | }
12 | \value{
13 | \item{pathology}{a data frame}
14 | }
15 | \description{
16 | Liver Pathology Data
17 | }
18 | \details{
19 | These data have the results of an \emph{x}-ray examination
20 | to determine whether the liver is abnormal or not (in the \code{scan}
21 | column) versus the more extensive pathology results that
22 | approximate the truth (in \code{pathology}).
23 | }
24 | \examples{
25 | data(pathology)
26 | str(pathology)
27 | }
28 | \keyword{datasets}
29 | 
-------------------------------------------------------------------------------- /man/reexports.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/reexports.R
3 | \docType{import}
4 | \name{reexports}
5 | \alias{reexports}
6 | \alias{tidy}
7 | \title{Objects exported from other packages}
8 | \keyword{internal}
9 | \description{
10 | These objects are imported from other packages. Follow the links
11 | below to see their documentation.
12 | 
13 | \describe{
14 | \item{generics}{\code{\link[generics]{tidy}}}
15 | }}
16 | 
-------------------------------------------------------------------------------- /man/solubility_test.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{solubility_test}
5 | \alias{solubility_test}
6 | \title{Solubility Predictions from MARS Model}
7 | \source{
8 | Kuhn, M., Johnson, K. (2013) \emph{Applied Predictive
9 | Modeling}, Springer
10 | }
11 | \value{
12 | \item{solubility_test}{a data frame}
13 | }
14 | \description{
15 | Solubility Predictions from MARS Model
16 | }
17 | \details{
18 | For the solubility data in Kuhn and Johnson (2013),
19 | these data are the test set results for the MARS model. The
20 | observed solubility (in column \code{solubility}) and the model
21 | results (\code{prediction}) are contained in the data.
22 | }
23 | \examples{
24 | data(solubility_test)
25 | str(solubility_test)
26 | }
27 | \keyword{datasets}
28 | 
-------------------------------------------------------------------------------- /man/two_class_example.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{two_class_example}
5 | \alias{two_class_example}
6 | \title{Two Class Predictions}
7 | \value{
8 | \item{two_class_example}{a data frame}
9 | }
10 | \description{
11 | Two Class Predictions
12 | }
13 | \details{
14 | These data are a test set from a model built for two
15 | classes ("Class1" and "Class2"). There are columns for the true
16 | and predicted classes and columns for the probabilities for each
17 | class.
18 | }
19 | \examples{
20 | data(two_class_example)
21 | str(two_class_example)
22 | 
23 | # `truth` is a 2 level factor. The first level is `"Class1"`, which is the
24 | # "event of interest" by default in yardstick. See the Relevant Level
25 | # section in any classification function (such as `?pr_auc`) to see how
26 | # to change this.
27 | levels(two_class_example$truth)
28 | }
29 | \keyword{datasets}
30 | 
-------------------------------------------------------------------------------- /man/yardstick-package.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/yardstick-package.R
3 | \docType{package}
4 | \name{yardstick-package}
5 | \alias{yardstick}
6 | \alias{yardstick-package}
7 | \title{yardstick: Tidy Characterizations of Model Performance}
8 | \description{
9 | \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}}
10 | 
11 | Tidy tools for quantifying how well a model fits a data set, such as confusion matrices, class probability curve summaries, and regression metrics (e.g., RMSE).
12 | }
13 | \seealso{
14 | Useful links:
15 | \itemize{
16 | \item \url{https://github.com/tidymodels/yardstick}
17 | \item \url{https://yardstick.tidymodels.org}
18 | \item Report bugs at \url{https://github.com/tidymodels/yardstick/issues}
19 | }
20 | 
21 | }
22 | \author{
23 | \strong{Maintainer}: Emil Hvitfeldt \email{emil.hvitfeldt@posit.co} (\href{https://orcid.org/0000-0002-0679-1945}{ORCID})
24 | 
25 | Authors:
26 | \itemize{
27 | \item Max Kuhn \email{max@posit.co}
28 | \item Davis Vaughan \email{davis@posit.co}
29 | }
30 | 
31 | Other contributors:
32 | \itemize{
33 | \item Posit Software, PBC (03wc8by49) [copyright holder, funder]
34 | }
35 | 
36 | }
37 | \keyword{internal}
38 | 
-------------------------------------------------------------------------------- /man/yardstick_remove_missing.Rd: --------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/missings.R
3 | \name{yardstick_remove_missing}
4 | \alias{yardstick_remove_missing}
5 | \alias{yardstick_any_missing}
6 | \title{Developer function for handling missing values in new metrics}
7 | \usage{
8 | yardstick_remove_missing(truth, estimate, case_weights)
9 | 
10 | yardstick_any_missing(truth, estimate, case_weights)
11 | }
12 | \arguments{
13 | \item{truth, estimate}{Vectors of the same length.}
14 | 
15 | \item{case_weights}{A vector of the same length as \code{truth} and \code{estimate}, or
16 | \code{NULL} if case weights are not being used.}
17 | }
18 | \description{
19 | \code{yardstick_remove_missing()} and \code{yardstick_any_missing()} are useful
20 | alongside the \link{metric-summarizers} functions for implementing new custom
21 | metrics. \code{yardstick_remove_missing()} removes any observations that contain
22 | missing values across \code{truth}, \code{estimate}, and \code{case_weights}.
23 | \code{yardstick_any_missing()} returns \code{TRUE} if there are any missing values in
24 | the inputs.
25 | } 26 | \seealso{ 27 | \link{metric-summarizers} 28 | } 29 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /revdep/.gitignore: -------------------------------------------------------------------------------- 1 | checks 2 | library 3 | checks.noindex 4 | library.noindex 5 | data.sqlite 6 | *.html 7 | cloud.noindex 8 | -------------------------------------------------------------------------------- /revdep/README.md: -------------------------------------------------------------------------------- 1 | # Platform 2 | 3 | |field |value | 4 | |:--------|:---------------------------------------| 5 | |version |R version 4.4.2 (2024-10-31) | 6 | |os |macOS Sequoia 15.2 | 7 | |system |aarch64, darwin20 | 8 | |ui |X11 | 9 | |language |(EN) | 10 | |collate |en_US.UTF-8 | 11 | |ctype |en_US.UTF-8 | 12 | |tz |America/Los_Angeles | 13 | |date |2025-01-22 | 14 | |pandoc |3.6.1 @ /usr/local/bin/ (via rmarkdown) | 15 | 16 | # Dependencies 17 | 18 | |package |old |new |Δ | 19 | |:----------|:------|:----------|:--| 20 | |yardstick |1.3.1 |1.3.1.9000 |* | 21 | |cli |3.6.3 |3.6.3 | | 22 | |dplyr |1.1.4 |1.1.4 | | 23 | |fansi |1.0.6 |1.0.6 | | 24 | |generics |0.1.3 
|0.1.3 | | 25 | |glue |1.8.0 |1.8.0 | | 26 | |hardhat |1.4.0 |1.4.0 | | 27 | |lifecycle |1.0.4 |1.0.4 | | 28 | |magrittr |2.0.3 |2.0.3 | | 29 | |pillar |1.10.1 |1.10.1 | | 30 | |pkgconfig |2.0.3 |2.0.3 | | 31 | |R6 |2.5.1 |2.5.1 | | 32 | |rlang |1.1.4 |1.1.4 | | 33 | |tibble |3.2.1 |3.2.1 | | 34 | |tidyselect |1.2.1 |1.2.1 | | 35 | |utf8 |1.2.4 |1.2.4 | | 36 | |vctrs |0.6.5 |0.6.5 | | 37 | |withr |3.0.2 |3.0.2 | | 38 | 39 | # Revdeps 40 | 41 | ## Failed to check (3) 42 | 43 | |package |version |error |warning |note | 44 | |:-------|:-------|:-----|:-------|:----| 45 | |ldmppr |1.0.3 |1 | | | 46 | |rTwig |1.3.0 |1 | | | 47 | |shapr |1.0.1 |1 | | | 48 | 49 | -------------------------------------------------------------------------------- /revdep/cran.md: -------------------------------------------------------------------------------- 1 | ## revdepcheck results 2 | 3 | We checked 41 reverse dependencies (40 from CRAN + 1 from Bioconductor), comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 3 packages 7 | 8 | Issues with CRAN packages are summarised below. 9 | 10 | ### Failed to check 11 | 12 | * ldmppr (NA) 13 | * rTwig (NA) 14 | * shapr (NA) 15 | -------------------------------------------------------------------------------- /revdep/email.yml: -------------------------------------------------------------------------------- 1 | release_date: ??? 2 | rel_release_date: ??? 3 | my_news_url: ??? 4 | release_version: ??? 5 | release_details: ??? 6 | -------------------------------------------------------------------------------- /revdep/problems.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.so 3 | *.dll 4 | -------------------------------------------------------------------------------- /src/init.c: -------------------------------------------------------------------------------- 1 | #include <R.h> 2 | #include <Rinternals.h> 3 | #include <stdlib.h> // for NULL 4 | #include <R_ext/Rdynload.h> 5 | #include 6 | 7 | extern SEXP yardstick_mcc_multiclass_impl(SEXP); 8 | 9 | static const R_CallMethodDef CallEntries[] = { 10 | {"yardstick_mcc_multiclass_impl", (DL_FUNC) &yardstick_mcc_multiclass_impl, 1}, 11 | {NULL, NULL, 0} 12 | }; 13 | 14 | void R_init_yardstick(DllInfo *dll) { 15 | R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); 16 | R_useDynamicSymbols(dll, FALSE); 17 | } 18 | -------------------------------------------------------------------------------- /src/yardstick.h: -------------------------------------------------------------------------------- 1 | #ifndef YARDSTICK_H 2 | #define YARDSTICK_H 3 | 4 | #define R_NO_REMAP 5 | #include <R.h> 6 | #include <Rinternals.h> 7 | 8 | #endif 9 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(yardstick) 3 | 4 | test_check("yardstick") 5 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/aaa-new.md: -------------------------------------------------------------------------------- 1 | # `fn` is validated 2 | 3 | Code 4 | new_class_metric(1, "maximize") 5 | Condition 6 | Error in `new_class_metric()`: 7 | ! `fn` must be a function, not the number 1.
8 | 9 | # `direction` is validated 10 | 11 | Code 12 | new_class_metric(function() 1, "min") 13 | Condition 14 | Error in `new_class_metric()`: 15 | ! `direction` must be one of "maximize", "minimize", or "zero", not "min". 16 | i Did you mean "minimize"? 17 | 18 | # metric print method works 19 | 20 | Code 21 | rmse 22 | Output 23 | A numeric metric | direction: minimize 24 | 25 | --- 26 | 27 | Code 28 | roc_auc 29 | Output 30 | A probability metric | direction: maximize 31 | 32 | --- 33 | 34 | Code 35 | demographic_parity(boop) 36 | Output 37 | A class metric | direction: minimize, group-wise on: boop 38 | 39 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/autoplot.md: -------------------------------------------------------------------------------- 1 | # Confusion Matrix - type argument 2 | 3 | Code 4 | ggplot2::autoplot(res, type = "wrong") 5 | Condition 6 | Error in `ggplot2::autoplot()`: 7 | ! `type` must be one of "mosaic" or "heatmap", not "wrong". 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-accuracy.md: -------------------------------------------------------------------------------- 1 | # work with class_pred input 2 | 3 | Code 4 | accuracy_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `accuracy_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-bal_accuracy.md: -------------------------------------------------------------------------------- 1 | # work with class_pred input 2 | 3 | Code 4 | bal_accuracy_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `bal_accuracy_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-detection_prevalence.md: -------------------------------------------------------------------------------- 1 | # work with class_pred input 2 | 3 | Code 4 | detection_prevalence_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `detection_prevalence_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-j_index.md: -------------------------------------------------------------------------------- 1 | # Binary `j_index()` returns `NA` with a warning when sensitivity is undefined (tp + fn = 0) (#265) 2 | 3 | Code 4 | out <- j_index_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `sens()`, no true events were detected (i.e. `true_positive + false_negative = 0`). 8 | Sensitivity is undefined in this case, and `NA` will be returned. 9 | Note that 1 predicted event(s) actually occurred for the problematic event level, a 10 | 11 | # Binary `j_index()` returns `NA` with a warning when specificity is undefined (tn + fp = 0) (#265) 12 | 13 | Code 14 | out <- j_index_vec(truth, estimate) 15 | Condition 16 | Warning: 17 | While computing binary `spec()`, no true negatives were detected (i.e. `true_negative + false_positive = 0`). 18 | Specificity is undefined in this case, and `NA` will be returned.
19 | Note that 1 predicted negatives(s) actually occurred for the problematic event level, a 20 | 21 | # Multiclass `j_index()` returns averaged value with `NA`s removed + a warning when sensitivity is undefined (tp + fn = 0) (#265) 22 | 23 | Code 24 | out <- j_index_vec(truth, estimate) 25 | Condition 26 | Warning: 27 | While computing multiclass `sens()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 28 | Sensitivity is undefined in this case, and those levels will be removed from the averaged result. 29 | Note that the following number of predicted events actually occurred for each problematic event level: 30 | 'c': 1 31 | 32 | # Multiclass `j_index()` returns averaged value with `NA`s removed + a warning when specificity is undefined (tn + fp = 0) (#265) 33 | 34 | Code 35 | out <- j_index_vec(truth, estimate) 36 | Condition 37 | Warning: 38 | While computing multiclass `sens()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 39 | Sensitivity is undefined in this case, and those levels will be removed from the averaged result. 40 | Note that the following number of predicted events actually occurred for each problematic event level: 41 | 'b': 1, 'c': 1 42 | Warning: 43 | While computing multiclass `spec()`, some levels had no true negatives (i.e. `true_negative + false_positive = 0`). 44 | Specificity is undefined in this case, and those levels will be removed from the averaged result. 45 | Note that the following number of predicted negatives actually occurred for each problematic event level: 46 | 'a': 2 47 | 48 | # work with class_pred input 49 | 50 | Code 51 | j_index_vec(cp_truth, cp_estimate) 52 | Condition 53 | Error in `j_index_vec()`: 54 | ! `truth` should not a <class_pred> object. 55 | 56 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-kap.md: -------------------------------------------------------------------------------- 1 | # kap errors with wrong `weighting` 2 | 3 | Code 4 | kap(three_class, truth = "obs", estimate = "pred", weighting = 1) 5 | Condition 6 | Error in `kap()`: 7 | ! `weighting` must be a single string, not the number 1. 8 | 9 | --- 10 | 11 | Code 12 | kap(three_class, truth = "obs", estimate = "pred", weighting = "not right") 13 | Condition 14 | Error in `kap()`: 15 | ! `weighting` must be "none", "linear", or "quadratic", not "not right". 16 | 17 | # work with class_pred input 18 | 19 | Code 20 | kap_vec(cp_truth, cp_estimate) 21 | Condition 22 | Error in `kap_vec()`: 23 | ! `truth` should not a <class_pred> object. 24 | 25 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-mcc.md: -------------------------------------------------------------------------------- 1 | # work with class_pred input 2 | 3 | Code 4 | mcc_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `mcc_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-npv.md: -------------------------------------------------------------------------------- 1 | # work with class_pred input 2 | 3 | Code 4 | accuracy_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `accuracy_vec()`: 7 | ! `truth` should not a <class_pred> object.
8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-ppv.md: -------------------------------------------------------------------------------- 1 | # Binary `ppv()` returns `NA` with a warning when `sens()` is undefined (tp + fn = 0) (#101) 2 | 3 | Code 4 | out <- ppv_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `sens()`, no true events were detected (i.e. `true_positive + false_negative = 0`). 8 | Sensitivity is undefined in this case, and `NA` will be returned. 9 | Note that 1 predicted event(s) actually occurred for the problematic event level, a 10 | 11 | # work with class_pred input 12 | 13 | Code 14 | ppv_vec(cp_truth, cp_estimate) 15 | Condition 16 | Error in `ppv_vec()`: 17 | ! `truth` should not a <class_pred> object. 18 | 19 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-precision.md: -------------------------------------------------------------------------------- 1 | # Binary `precision()` returns `NA` with a warning when undefined (tp + fp = 0) (#98) 2 | 3 | Code 4 | out <- precision_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `precision()`, no predicted events were detected (i.e. `true_positive + false_positive = 0`). 8 | Precision is undefined in this case, and `NA` will be returned. 9 | Note that 1 true event(s) actually occurred for the problematic event level, a 10 | 11 | # Multiclass `precision()` returns averaged value with `NA`s removed + a warning when undefined (tp + fp = 0) (#98) 12 | 13 | Code 14 | out <- precision_vec(truth, estimate) 15 | Condition 16 | Warning: 17 | While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 18 | Precision is undefined in this case, and those levels will be removed from the averaged result. 19 | Note that the following number of true events actually occurred for each problematic event level: 20 | 'a': 1, 'b': 1, 'c': 1 21 | 22 | --- 23 | 24 | Code 25 | out <- precision_vec(truth, estimate) 26 | Condition 27 | Warning: 28 | While computing multiclass `precision()`, some levels had no predicted events (i.e. `true_positive + false_positive = 0`). 29 | Precision is undefined in this case, and those levels will be removed from the averaged result. 30 | Note that the following number of true events actually occurred for each problematic event level: 31 | 'c': 1 32 | 33 | # work with class_pred input 34 | 35 | Code 36 | precision_vec(cp_truth, cp_estimate) 37 | Condition 38 | Error in `precision_vec()`: 39 | ! `truth` should not a <class_pred> object. 40 | 41 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-recall.md: -------------------------------------------------------------------------------- 1 | # Binary `recall()` returns `NA` with a warning when undefined (tp + fn = 0) (#98) 2 | 3 | Code 4 | out <- recall_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `recall()`, no true events were detected (i.e. `true_positive + false_negative = 0`). 8 | Recall is undefined in this case, and `NA` will be returned.
9 | Note that 1 predicted event(s) actually occurred for the problematic event level a 10 | 11 | # Multiclass `recall()` returns averaged value with `NA`s removed + a warning when undefined (tp + fn = 0) (#98) 12 | 13 | Code 14 | out <- recall_vec(truth, estimate) 15 | Condition 16 | Warning: 17 | While computing multiclass `recall()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 18 | Recall is undefined in this case, and those levels will be removed from the averaged result. 19 | Note that the following number of predicted events actually occurred for each problematic event level: 20 | 'b': 0, 'c': 1 21 | 22 | # work with class_pred input 23 | 24 | Code 25 | recall_vec(cp_truth, cp_estimate) 26 | Condition 27 | Error in `recall_vec()`: 28 | ! `truth` should not a <class_pred> object. 29 | 30 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-sens.md: -------------------------------------------------------------------------------- 1 | # Binary `sens()` returns `NA` with a warning when undefined (tp + fn = 0) (#98) 2 | 3 | Code 4 | out <- sens_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `sens()`, no true events were detected (i.e. `true_positive + false_negative = 0`). 8 | Sensitivity is undefined in this case, and `NA` will be returned. 9 | Note that 1 predicted event(s) actually occurred for the problematic event level, a 10 | 11 | # Multiclass `sens()` returns averaged value with `NA`s removed + a warning when undefined (tp + fn = 0) (#98) 12 | 13 | Code 14 | out <- sens_vec(truth, estimate) 15 | Condition 16 | Warning: 17 | While computing multiclass `sens()`, some levels had no true events (i.e. `true_positive + false_negative = 0`). 18 | Sensitivity is undefined in this case, and those levels will be removed from the averaged result. 19 | Note that the following number of predicted events actually occurred for each problematic event level: 20 | 'b': 0, 'c': 1 21 | 22 | # work with class_pred input 23 | 24 | Code 25 | sensitivity_vec(cp_truth, cp_estimate) 26 | Condition 27 | Error in `sensitivity_vec()`: 28 | ! `truth` should not a <class_pred> object. 29 | 30 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-spec.md: -------------------------------------------------------------------------------- 1 | # Binary `spec()` returns `NA` with a warning when undefined (tn + fp = 0) (#98) 2 | 3 | Code 4 | out <- spec_vec(truth, estimate) 5 | Condition 6 | Warning: 7 | While computing binary `spec()`, no true negatives were detected (i.e. `true_negative + false_positive = 0`). 8 | Specificity is undefined in this case, and `NA` will be returned. 9 | Note that 1 predicted negatives(s) actually occurred for the problematic event level, a 10 | 11 | # Multiclass `spec()` returns averaged value with `NA`s removed + a warning when undefined (tn + fp = 0) (#98) 12 | 13 | Code 14 | out <- spec_vec(truth, estimate) 15 | Condition 16 | Warning: 17 | While computing multiclass `spec()`, some levels had no true negatives (i.e. `true_negative + false_positive = 0`). 18 | Specificity is undefined in this case, and those levels will be removed from the averaged result. 19 | Note that the following number of predicted negatives actually occurred for each problematic event level: 20 | 'a': 2 21 | 22 | # work with class_pred input 23 | 24 | Code 25 | specificity_vec(cp_truth, cp_estimate) 26 | Condition 27 | Error in `specificity_vec()`: 28 | !
`truth` should not a <class_pred> object. 29 | 30 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/deprecated-template.md: -------------------------------------------------------------------------------- 1 | # metric_summarizer() is soft-deprecated 2 | 3 | Code 4 | tmp <- metric_summarizer(metric_nm = "rmse", metric_fn = rmse_vec, data = mtcars, 5 | truth = mpg, estimate = disp, na_rm = TRUE, case_weights = NULL) 6 | Condition 7 | Warning: 8 | `metric_summarizer()` was deprecated in yardstick 1.2.0. 9 | i Please use `numeric_metric_summarizer()`, `class_metric_summarizer()`, `prob_metric_summarizer()`, or `curve_metric_summarizer()` instead. 10 | 11 | # metric_summarizer()'s errors when wrong things are passes 12 | 13 | Code 14 | metric_summarizer(metric_nm = "rmse", metric_fn = rmse_vec, data = mtcars, 15 | truth = not_a_real_column_name, estimate = disp) 16 | Condition 17 | Error in `dplyr::summarise()`: 18 | i In argument: `.estimator = eval_tidy(finalize_estimator_expr)`. 19 | Caused by error: 20 | ! object 'not_a_real_column_name' not found 21 | 22 | --- 23 | 24 | Code 25 | metric_summarizer(metric_nm = "rmse", metric_fn = rmse_vec, data = mtcars, 26 | truth = mpg, estimate = not_a_real_column_name) 27 | Condition 28 | Error in `dplyr::summarise()`: 29 | i In argument: `.estimate = metric_fn(truth = mpg, estimate = not_a_real_column_name, na_rm = na_rm)`. 30 | Caused by error: 31 | ! object 'not_a_real_column_name' not found 32 | 33 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/estimator-helpers.md: -------------------------------------------------------------------------------- 1 | # get_weights() errors with wrong estimator 2 | 3 | Code 4 | get_weights(mtcars, "wrong") 5 | Condition 6 | Error in `get_weights()`: 7 | ! `estimator` type "wrong" is unknown. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/event-level.md: -------------------------------------------------------------------------------- 1 | # `yardstick_event_level()` ignores option - TRUE, with a warning 2 | 3 | Code 4 | out <- yardstick_event_level() 5 | Condition 6 | Warning: 7 | The global option `yardstick.event_first` was deprecated in yardstick 0.0.7. 8 | i Please use the metric function argument `event_level` instead. 9 | i The global option is being ignored entirely. 10 | 11 | # `yardstick_event_level()` ignores option - FALSE, with a warning 12 | 13 | Code 14 | out <- yardstick_event_level() 15 | Condition 16 | Warning: 17 | The global option `yardstick.event_first` was deprecated in yardstick 0.0.7. 18 | i Please use the metric function argument `event_level` instead. 19 | i The global option is being ignored entirely. 20 | 21 | # validate_event_level() works 22 | 23 | Code 24 | recall(two_class_example, truth, predicted, event_level = "wrong") 25 | Condition 26 | Error in `recall()`: 27 | ! `event_level` must be "first" or "second". 28 | 29 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/flatten.md: -------------------------------------------------------------------------------- 1 | # flat tables 2 | 3 | Code 4 | yardstick:::flatten(three_class_tb[, 1:2]) 5 | Condition 6 | Error: 7 | ! `x` must have equal dimensions. `x` has 2 columns and 3 rows.
8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/metric-tweak.md: -------------------------------------------------------------------------------- 1 | # cannot use protected names 2 | 3 | Code 4 | metric_tweak("f_meas2", f_meas, data = 2) 5 | Condition 6 | Error in `metric_tweak()`: 7 | ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. 8 | 9 | --- 10 | 11 | Code 12 | metric_tweak("f_meas2", f_meas, truth = 2) 13 | Condition 14 | Error in `metric_tweak()`: 15 | ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. 16 | 17 | --- 18 | 19 | Code 20 | metric_tweak("f_meas2", f_meas, estimate = 2) 21 | Condition 22 | Error in `metric_tweak()`: 23 | ! Arguments passed through `...` cannot be named any of: `data`, `truth`, and `estimate`. 24 | 25 | # `name` must be a string 26 | 27 | Code 28 | metric_tweak(1, f_meas, beta = 2) 29 | Condition 30 | Error in `metric_tweak()`: 31 | ! `.name` must be a single string, not the number 1. 32 | 33 | # `fn` must be a metric function 34 | 35 | Code 36 | metric_tweak("foo", function() { }, beta = 2) 37 | Condition 38 | Error in `metric_tweak()`: 39 | ! `.fn` must be a metric function, not a function. 40 | 41 | # All `...` must be named 42 | 43 | Code 44 | metric_tweak("foo", accuracy, 1) 45 | Condition 46 | Error in `metric_tweak()`: 47 | ! All arguments passed through `...` must be named. 48 | 49 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/num-huber_loss.md: -------------------------------------------------------------------------------- 1 | # Huber Loss 2 | 3 | Code 4 | huber_loss(ex_dat, truth = "obs", estimate = "pred_na", delta = -1) 5 | Condition 6 | Error in `huber_loss()`: 7 | ! `delta` must be a number larger than or equal to 0, not the number -1. 8 | 9 | --- 10 | 11 | Code 12 | huber_loss(ex_dat, truth = "obs", estimate = "pred_na", delta = c(1, 2)) 13 | Condition 14 | Error in `huber_loss()`: 15 | ! `delta` must be a number, not a double vector. 16 | 17 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/num-mase.md: -------------------------------------------------------------------------------- 1 | # Mean Absolute Scaled Error 2 | 3 | Code 4 | mase_vec(truth, pred, m = "x") 5 | Condition 6 | Error in `mase_vec()`: 7 | ! `m` must be a whole number, not the string "x". 8 | 9 | --- 10 | 11 | Code 12 | mase_vec(truth, pred, m = -1) 13 | Condition 14 | Error in `mase_vec()`: 15 | ! `m` must be a whole number larger than or equal to 0, not the number -1. 16 | 17 | --- 18 | 19 | Code 20 | mase_vec(truth, pred, m = 1.5) 21 | Condition 22 | Error in `mase_vec()`: 23 | ! `m` must be a whole number, not the number 1.5. 24 | 25 | --- 26 | 27 | Code 28 | mase_vec(truth, pred, mae_train = -1) 29 | Condition 30 | Error in `mase_vec()`: 31 | ! `mae_train` must be a number larger than or equal to 0 or `NULL`, not the number -1. 32 | 33 | --- 34 | 35 | Code 36 | mase_vec(truth, pred, mae_train = "x") 37 | Condition 38 | Error in `mase_vec()`: 39 | ! `mae_train` must be a number or `NULL`, not the string "x". 40 | 41 | # mase() errors if m is larger than number of observations 42 | 43 | Code 44 | mase(mtcars, mpg, disp, m = 100) 45 | Condition 46 | Error in `mase()`: 47 | ! `truth` (32) must have a length greater than `m` (100) to compute the out-of-sample naive mean absolute error. 
48 | 49 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/num-pseudo_huber_loss.md: -------------------------------------------------------------------------------- 1 | # Pseudo-Huber Loss 2 | 3 | Code 4 | huber_loss_pseudo(ex_dat, truth = "obs", estimate = "pred_na", delta = -1) 5 | Condition 6 | Error in `huber_loss_pseudo()`: 7 | ! `delta` must be a number larger than or equal to 0, not the number -1. 8 | 9 | --- 10 | 11 | Code 12 | huber_loss_pseudo(ex_dat, truth = "obs", estimate = "pred_na", delta = c(1, 2)) 13 | Condition 14 | Error in `huber_loss_pseudo()`: 15 | ! `delta` must be a number, not a double vector. 16 | 17 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/num-rsq.md: -------------------------------------------------------------------------------- 1 | # yardstick correlation warnings are thrown 2 | 3 | Code 4 | (expect_warning(object = out <- rsq_vec(1, 1), class = "yardstick_warning_correlation_undefined_size_zero_or_one") 5 | ) 6 | Output 7 | 8 | Warning: 9 | A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned. 10 | 11 | --- 12 | 13 | Code 14 | (expect_warning(object = out <- rsq_vec(double(), double()), class = "yardstick_warning_correlation_undefined_size_zero_or_one") 15 | ) 16 | Output 17 | 18 | Warning: 19 | A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned. 20 | 21 | --- 22 | 23 | Code 24 | (expect_warning(object = out <- rsq_vec(c(1, 2), c(1, 1)), class = "yardstick_warning_correlation_undefined_constant_estimate") 25 | ) 26 | Output 27 | 28 | Warning: 29 | A correlation computation is required, but `estimate` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned. 30 | 31 | --- 32 | 33 | Code 34 | (expect_warning(object = out <- rsq_vec(c(1, 1), c(1, 2)), class = "yardstick_warning_correlation_undefined_constant_truth") 35 | ) 36 | Output 37 | 38 | Warning: 39 | A correlation computation is required, but `truth` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned. 40 | 41 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/orderedprob-ranked_prob_score.md: -------------------------------------------------------------------------------- 1 | # errors with bad input 2 | 3 | Code 4 | ranked_prob_score_vec(cp_truth, estimate) 5 | Condition 6 | Error in `ranked_prob_score_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | --- 10 | 11 | Code 12 | ranked_prob_score_vec(two_class_example$truth, estimate) 13 | Condition 14 | Error in `ranked_prob_score_vec()`: 15 | ! `truth` should be a ordered factor, not a a <factor> object. 16 | 17 | --- 18 | 19 | Code 20 | ranked_prob_score_vec(ord_truth, estimate_1D) 21 | Condition 22 | Error in `ranked_prob_score_vec()`: 23 | ! The number of levels in `truth` (2) must match the number of columns supplied in `...` (1).
24 | 25 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-average_precision.md: -------------------------------------------------------------------------------- 1 | # known corner cases are correct 2 | 3 | Code 4 | out <- average_precision(df, truth, estimate)$.estimate 5 | Condition 6 | Warning: 7 | There are 0 event cases in `truth`, results will be meaningless. 8 | 9 | --- 10 | 11 | Code 12 | out <- average_precision(df, truth, estimate)$.estimate 13 | Condition 14 | Warning: 15 | There are 0 event cases in `truth`, results will be meaningless. 16 | 17 | --- 18 | 19 | Code 20 | expect <- pr_auc(df, truth, estimate)$.estimate 21 | Condition 22 | Warning: 23 | There are 0 event cases in `truth`, results will be meaningless. 24 | 25 | # errors with class_pred input 26 | 27 | Code 28 | average_precision_vec(cp_truth, estimate) 29 | Condition 30 | Error in `average_precision_vec()`: 31 | ! `truth` should not a <class_pred> object. 32 | 33 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-brier_class.md: -------------------------------------------------------------------------------- 1 | # errors with class_pred input 2 | 3 | Code 4 | brier_class_vec(cp_truth, estimate) 5 | Condition 6 | Error in `brier_class_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-classification_cost.md: -------------------------------------------------------------------------------- 1 | # binary - requires 1 column of probabilities 2 | 3 | Code 4 | classification_cost(two_class_example, truth, Class1:Class2) 5 | Condition 6 | Error in `classification_cost()`: 7 | ! You are using a binary metric but have passed multiple columns to `...`. 8 | 9 | # costs must be a data frame with the right column names 10 | 11 | Code 12 | classification_cost(df, obs, A, costs = 1) 13 | Condition 14 | Error in `classification_cost()`: 15 | ! `costs` must be a data frame or `NULL`, not the number 1. 16 | 17 | --- 18 | 19 | Code 20 | classification_cost(df, obs, A, costs = data.frame()) 21 | Condition 22 | Error in `classification_cost()`: 23 | ! `costs` must be a data.frame with 3 columns, not 0. 24 | 25 | --- 26 | 27 | Code 28 | classification_cost(df, obs, A, costs = data.frame(x = 1, y = 2, z = 3)) 29 | Condition 30 | Error in `classification_cost()`: 31 | ! `costs` must have columns: "truth", "estimate", and "cost". Not x, y, and z. 32 | 33 | # costs$estimate must contain the right levels 34 | 35 | Code 36 | classification_cost(df, obs, A, costs = costs) 37 | Condition 38 | Error in `classification_cost()`: 39 | ! `costs$estimate` can only contain 'A', 'B'. 40 | 41 | # costs$truth must contain the right levels 42 | 43 | Code 44 | classification_cost(df, obs, A, costs = costs) 45 | Condition 46 | Error in `classification_cost()`: 47 | ! `costs$truth` can only contain 'A', 'B'. 48 | 49 | # costs$truth, costs$estimate, and costs$cost must have the right type 50 | 51 | Code 52 | classification_cost(df, obs, A, costs = costs) 53 | Condition 54 | Error in `classification_cost()`: 55 | ! `costs$truth` must be a character or factor column, not a double vector. 56 | 57 | --- 58 | 59 | Code 60 | classification_cost(df, obs, A, costs = costs) 61 | Condition 62 | Error in `classification_cost()`: 63 | ! `costs$estimate` must be a character or factor column, not a double vector.
64 | 65 | --- 66 | 67 | Code 68 | classification_cost(df, obs, A, costs = costs) 69 | Condition 70 | Error in `classification_cost()`: 71 | ! `costs$cost` must be a numeric column, not a character vector. 72 | 73 | # costs$truth and costs$estimate cannot contain duplicate pairs 74 | 75 | Code 76 | classification_cost(df, obs, A, costs = costs) 77 | Condition 78 | Error in `classification_cost()`: 79 | ! costs cannot have duplicate truth / estimate combinations. 80 | 81 | # errors with class_pred input 82 | 83 | Code 84 | classification_cost_vec(cp_truth, estimate) 85 | Condition 86 | Error in `classification_cost_vec()`: 87 | ! `truth` should not a <class_pred> object. 88 | 89 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-gain_capture.md: -------------------------------------------------------------------------------- 1 | # errors with class_pred input 2 | 3 | Code 4 | gain_capture_vec(cp_truth, estimate) 5 | Condition 6 | Error in `gain_capture_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-gain_curve.md: -------------------------------------------------------------------------------- 1 | # error handling 2 | 3 | Code 4 | gain_curve(df, truth, estimate) 5 | Condition 6 | Error in `gain_curve()`: 7 | ! `truth` should be a factor, not a a number. 8 | 9 | # na_rm = FALSE errors if missing values are present 10 | 11 | Code 12 | gain_curve_vec(df$truth, df$Class1, na_rm = FALSE) 13 | Condition 14 | Error in `gain_curve_vec()`: 15 | x Missing values were detected and `na_ra = FALSE`. 16 | i Not able to perform calculations. 17 | 18 | # errors with class_pred input 19 | 20 | Code 21 | gain_curve_vec_vec(cp_truth, estimate) 22 | Condition 23 | Error in `gain_curve_vec_vec()`: 24 | ! could not find function "gain_curve_vec_vec" 25 | 26 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-lift_curve.md: -------------------------------------------------------------------------------- 1 | # error handling 2 | 3 | Code 4 | lift_curve(df, truth, estimate) 5 | Condition 6 | Error in `lift_curve()`: 7 | ! `truth` should be a factor, not a a number. 8 | 9 | # errors with class_pred input 10 | 11 | Code 12 | lift_curve_vec(cp_truth, estimate) 13 | Condition 14 | Error in `gain_curve_vec()`: 15 | ! `truth` should not a <class_pred> object. 16 | 17 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-mn_log_loss.md: -------------------------------------------------------------------------------- 1 | # errors with class_pred input 2 | 3 | Code 4 | mn_log_loss_vec(cp_truth, estimate) 5 | Condition 6 | Error in `mn_log_loss_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-pr_auc.md: -------------------------------------------------------------------------------- 1 | # errors with class_pred input 2 | 3 | Code 4 | pr_auc_vec(cp_truth, estimate) 5 | Condition 6 | Error in `pr_auc_vec()`: 7 | ! `truth` should not a <class_pred> object.
8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-pr_curve.md: -------------------------------------------------------------------------------- 1 | # PR - zero row data frame works 2 | 3 | Code 4 | out <- pr_curve(df, y, x) 5 | Condition 6 | Warning: 7 | There are 0 event cases in `truth`, results will be meaningless. 8 | 9 | # errors with class_pred input 10 | 11 | Code 12 | pr_curve_vec(cp_truth, estimate) 13 | Condition 14 | Error in `pr_curve_vec()`: 15 | ! `truth` should not a <class_pred> object. 16 | 17 | # na_rm = FALSE errors if missing values are present 18 | 19 | Code 20 | pr_curve_vec(df$truth, df$Class1, na_rm = FALSE) 21 | Condition 22 | Error in `pr_curve_vec()`: 23 | x Missing values were detected and `na_ra = FALSE`. 24 | i Not able to perform calculations. 25 | 26 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-roc_aunp.md: -------------------------------------------------------------------------------- 1 | # AUNP errors on binary case 2 | 3 | Code 4 | roc_aunp(two_class_example, truth, Class1) 5 | Condition 6 | Error in `roc_aunp()`: 7 | ! The number of levels in `truth` (2) must match the number of columns supplied in `...` (1). 8 | 9 | # roc_aunp() - `options` is deprecated 10 | 11 | Code 12 | out <- roc_aunp(two_class_example, truth, Class1, Class2, options = 1) 13 | Condition 14 | Warning: 15 | The `options` argument of `roc_aunp()` was deprecated in yardstick 1.0.0. 16 | i This argument no longer has any effect, and is being ignored. Use the pROC package directly if you need these features. 17 | 18 | --- 19 | 20 | Code 21 | out <- roc_aunp_vec(truth = two_class_example$truth, estimate = as.matrix( 22 | two_class_example[c("Class1", "Class2")]), options = 1) 23 | Condition 24 | Warning: 25 | The `options` argument of `roc_aunp_vec()` was deprecated in yardstick 1.0.0. 26 | i This argument no longer has any effect, and is being ignored. Use the pROC package directly if you need these features. 27 | 28 | # work with class_pred input 29 | 30 | Code 31 | roc_aunp_vec(cp_truth, estimate) 32 | Condition 33 | Error in `roc_aunp_vec()`: 34 | ! `truth` should not a <class_pred> object. 35 | 36 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-roc_aunu.md: -------------------------------------------------------------------------------- 1 | # AUNU errors on binary case 2 | 3 | Code 4 | roc_aunu(two_class_example, truth, Class1) 5 | Condition 6 | Error in `roc_aunu()`: 7 | ! The number of levels in `truth` (2) must match the number of columns supplied in `...` (1). 8 | 9 | # roc_aunu() - `options` is deprecated 10 | 11 | Code 12 | out <- roc_aunu(two_class_example, truth, Class1, Class2, options = 1) 13 | Condition 14 | Warning: 15 | The `options` argument of `roc_aunu()` was deprecated in yardstick 1.0.0. 16 | i This argument no longer has any effect, and is being ignored. Use the pROC package directly if you need these features. 17 | 18 | --- 19 | 20 | Code 21 | out <- roc_aunu_vec(truth = two_class_example$truth, estimate = as.matrix( 22 | two_class_example[c("Class1", "Class2")]), options = 1) 23 | Condition 24 | Warning: 25 | The `options` argument of `roc_aunu_vec()` was deprecated in yardstick 1.0.0. 26 | i This argument no longer has any effect, and is being ignored. Use the pROC package directly if you need these features.
27 | 28 | # errors with class_pred input 29 | 30 | Code 31 | roc_aunu_vec(cp_truth, estimate) 32 | Condition 33 | Error in `roc_aunu_vec()`: 34 | ! `truth` should not a <class_pred> object. 35 | 36 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/prob-roc_curve.md: -------------------------------------------------------------------------------- 1 | # roc_curve() - error is thrown when missing events 2 | 3 | Code 4 | roc_curve_vec(no_event$truth, no_event$Class1)[[".estimate"]] 5 | Condition 6 | Error in `roc_curve_vec()`: 7 | ! No event observations were detected in `truth` with event level 'Class1'. 8 | 9 | # roc_curve() - error is thrown when missing controls 10 | 11 | Code 12 | roc_curve_vec(no_control$truth, no_control$Class1)[[".estimate"]] 13 | Condition 14 | Error in `roc_curve()`: 15 | ! No control observations were detected in `truth` with control level 'Class2'. 16 | 17 | # roc_curve() - multiclass one-vs-all approach results in error 18 | 19 | Code 20 | roc_curve_vec(no_event$obs, as.matrix(dplyr::select(no_event, VF:L)))[[ 21 | ".estimate"]] 22 | Condition 23 | Error in `roc_curve()`: 24 | ! No control observations were detected in `truth` with control level '..other'. 25 | 26 | # roc_curve() - `options` is deprecated 27 | 28 | Code 29 | out <- roc_curve(two_class_example, truth, Class1, options = 1) 30 | Condition 31 | Warning: 32 | The `options` argument of `roc_curve()` was deprecated in yardstick 1.0.0. 33 | i This argument no longer has any effect, and is being ignored. Use the pROC package directly if you need these features. 34 | 35 | # errors with class_pred input 36 | 37 | Code 38 | roc_curve_vec(cp_truth, estimate) 39 | Condition 40 | Error in `roc_curve_vec()`: 41 | ! `truth` should not a <class_pred> object. 42 | 43 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/probably.md: -------------------------------------------------------------------------------- 1 | # `class_pred` can be converted to `factor` when computing metrics 2 | 3 | Code 4 | accuracy_vec(cp_truth, cp_estimate) 5 | Condition 6 | Error in `accuracy_vec()`: 7 | ! `truth` should not a <class_pred> object. 8 | 9 | # `class_pred` errors when passed to `conf_mat()` 10 | 11 | Code 12 | conf_mat(cp_hpc_cv, obs, pred) 13 | Condition 14 | Error in `conf_mat()`: 15 | ! `truth` should not a <class_pred> object. 16 | 17 | --- 18 | 19 | Code 20 | conf_mat(dplyr::group_by(cp_hpc_cv, Resample), obs, pred) 21 | Condition 22 | Error in `conf_mat()`: 23 | ! `truth` should not a <class_pred> object. 24 | 25 | # `class_pred` errors when passed to `metrics()` 26 | 27 | Code 28 | metrics(cp_df, truth, estimate, class1) 29 | Condition 30 | Error in `metric_set()`: 31 | ! Failed to compute `accuracy()`. 32 | Caused by error: 33 | ! `truth` should not a <class_pred> object. 34 | 35 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/surv-brier_survival_integrated.md: -------------------------------------------------------------------------------- 1 | # brier_survival_integrated calculations 2 | 3 | Code 4 | brier_survival_integrated(data = lung_surv, truth = surv_obj, .pred) 5 | Condition 6 | Error in `brier_survival_integrated()`: 7 | ! At least 2 evaluation times are required. Only 1 unique time was given.
8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/data/auc_churn_res.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/auc_churn_res.rds -------------------------------------------------------------------------------- /tests/testthat/data/brier_churn_res.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/brier_churn_res.rds -------------------------------------------------------------------------------- /tests/testthat/data/helper-pROC-two-class-example-curve.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/helper-pROC-two-class-example-curve.rds -------------------------------------------------------------------------------- /tests/testthat/data/helper-soybean.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/helper-soybean.rds -------------------------------------------------------------------------------- /tests/testthat/data/helper-three-class-helpers.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/helper-three-class-helpers.rds -------------------------------------------------------------------------------- /tests/testthat/data/ref_roc_auc_survival.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/ref_roc_auc_survival.rds -------------------------------------------------------------------------------- /tests/testthat/data/ref_roc_curve_survival.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/ref_roc_curve_survival.rds -------------------------------------------------------------------------------- /tests/testthat/data/rr_churn_data.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/rr_churn_data.rds -------------------------------------------------------------------------------- /tests/testthat/data/test_autoplot.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/test_autoplot.rds -------------------------------------------------------------------------------- /tests/testthat/data/tidy_churn.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/tidy_churn.rds -------------------------------------------------------------------------------- 
/tests/testthat/data/weights-hpc-cv.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/weights-hpc-cv.rds -------------------------------------------------------------------------------- /tests/testthat/data/weights-solubility-test.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/weights-solubility-test.rds -------------------------------------------------------------------------------- /tests/testthat/data/weights-two-class-example.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/data/weights-two-class-example.rds -------------------------------------------------------------------------------- /tests/testthat/helper-macro-micro.R: -------------------------------------------------------------------------------- 1 | # These helpers are used to test macro and macro weighted methods 2 | 3 | data_three_by_three <- function() { 4 | as.table( 5 | matrix( 6 | c( 7 | 3, 8 | 1, 9 | 1, 10 | 0, 11 | 4, 12 | 2, 13 | 1, 14 | 3, 15 | 5 16 | ), 17 | ncol = 3, 18 | byrow = TRUE, 19 | dimnames = list(c("c1", "c2", "c3"), c("c1", "c2", "c3")) 20 | ) 21 | ) 22 | } 23 | 24 | multi_ex <- data_three_by_three() 25 | weighted_macro_weights <- colSums(multi_ex) / sum(colSums(multi_ex)) 26 | 27 | # turn a 3x3 conf mat into a 2x2 submatrix in a one vs all approach 28 | make_submat <- function(data, col) { 29 | top_left <- data[col, col] 30 | top_righ <- sum(data[col, -col]) 31 | bot_left <- sum(data[-col, col]) 32 | bot_righ <- sum(data[-col, -col]) 33 | as.table( 34 | matrix( 35 | c(top_left, top_righ, bot_left, bot_righ), 36 | ncol = 2, 37 | byrow = TRUE 38 | ) 39 | ) 40 | } 41 | 42 | # These are the "one vs all" sub matrices 43 | # for macro / weighted macro, calculate the binary version of each metric 44 | # and then average them together 45 | multi_submats <- list( 46 | c1 = make_submat(multi_ex, 1), 47 | c2 = make_submat(multi_ex, 2), 48 | c3 = make_submat(multi_ex, 3) 49 | ) 50 | 51 | # Just pass in a binary metric function 52 | macro_metric <- function(binary_metric, event_level = "first", ...) { 53 | mean( 54 | vapply( 55 | multi_submats, 56 | binary_metric, 57 | numeric(1), 58 | event_level = event_level, 59 | ... 60 | ) 61 | ) 62 | } 63 | 64 | macro_weighted_metric <- function(binary_metric, event_level = "first", ...) { 65 | stats::weighted.mean( 66 | vapply( 67 | multi_submats, 68 | binary_metric, 69 | numeric(1), 70 | event_level = event_level, 71 | ... 
72 | ), 73 | weighted_macro_weights 74 | ) 75 | } 76 | 77 | # For micro examples, we calculate the pieces by hand and use them individually 78 | data_three_by_three_micro <- function() { 79 | res <- list( 80 | tp = vapply( 81 | multi_submats, 82 | function(x) { 83 | x[1, 1] 84 | }, 85 | double(1) 86 | ), 87 | p = vapply( 88 | multi_submats, 89 | function(x) { 90 | colSums(x)[1] 91 | }, 92 | double(1) 93 | ), 94 | tn = vapply( 95 | multi_submats, 96 | function(x) { 97 | x[2, 2] 98 | }, 99 | double(1) 100 | ), 101 | n = vapply( 102 | multi_submats, 103 | function(x) { 104 | colSums(x)[2] 105 | }, 106 | double(1) 107 | ) 108 | ) 109 | 110 | res <- c( 111 | res, 112 | list( 113 | fp = res$p - res$tp, 114 | fn = res$n - res$tn 115 | ) 116 | ) 117 | 118 | res 119 | } 120 | -------------------------------------------------------------------------------- /tests/testthat/helper-macro-prob.R: -------------------------------------------------------------------------------- 1 | # These are helpers for class prob macro / macro_weighted tests 2 | 3 | data_hpc_fold1 <- function() { 4 | data("hpc_cv") 5 | dplyr::filter(hpc_cv, Resample == "Fold01") 6 | } 7 | 8 | hpc_fold1_macro_metric <- function(binary_metric, ...) { 9 | hpc_f1 <- data_hpc_fold1() 10 | truth <- hpc_f1$obs 11 | prob_mat <- as.matrix(dplyr::select(hpc_f1, VF:L)) 12 | case_weights <- NULL 13 | 14 | res <- one_vs_all_impl( 15 | fn = binary_metric, 16 | truth = truth, 17 | estimate = prob_mat, 18 | case_weights = case_weights, 19 | ... 20 | ) 21 | res <- vapply(res, FUN.VALUE = numeric(1), function(x) x) 22 | 23 | mean(res) 24 | } 25 | 26 | hpc_fold1_macro_weighted_metric <- function(binary_metric, ...) { 27 | hpc_f1 <- data_hpc_fold1() 28 | wt <- as.vector(table(hpc_f1$obs)) 29 | macro_wt <- wt / sum(wt) 30 | truth <- hpc_f1$obs 31 | prob_mat <- as.matrix(dplyr::select(hpc_f1, VF:L)) 32 | case_weights <- NULL 33 | 34 | res <- one_vs_all_impl( 35 | fn = binary_metric, 36 | truth = truth, 37 | estimate = prob_mat, 38 | case_weights = case_weights, 39 | ... 
40 | ) 41 | 42 | res <- vapply(res, FUN.VALUE = numeric(1), function(x) x) 43 | 44 | stats::weighted.mean(res, macro_wt) 45 | } 46 | -------------------------------------------------------------------------------- /tests/testthat/helper-numeric.R: -------------------------------------------------------------------------------- 1 | generate_numeric_test_data <- function() { 2 | set.seed(1812) 3 | out <- data.frame(obs = rnorm(50)) 4 | out$pred <- .2 + 1.1 * out$obs + rnorm(50, sd = 0.5) 5 | out$pred_na <- out$pred 6 | ind <- (1:5) * 10 7 | out$pred_na[ind] <- NA 8 | out$rand <- sample(out$pred) 9 | out$rand_na <- out$rand 10 | out$rand_na[ind] <- NA 11 | out 12 | } 13 | -------------------------------------------------------------------------------- /tests/testthat/helper-pROC.R: -------------------------------------------------------------------------------- 1 | # # For comparison against pROC in the `roc_curve()` tests 2 | # 3 | # curve <- pROC::roc( 4 | # two_class_example$truth, 5 | # two_class_example$Class1, 6 | # levels = rev(levels(two_class_example$truth)), 7 | # direction = "<" 8 | # ) 9 | # 10 | # points <- pROC::coords( 11 | # curve, 12 | # x = unique(c(-Inf, two_class_example$Class1, Inf)), 13 | # input = "threshold", 14 | # transpose = FALSE 15 | # ) 16 | # 17 | # points <- dplyr::as_tibble(points) 18 | # points <- dplyr::arrange(points, threshold) 19 | # points <- dplyr::rename(points, .threshold = threshold) 20 | # class(points) <- c("roc_df", class(points)) 21 | # 22 | # saveRDS( 23 | # object = points, 24 | # file = test_path("data", "helper-pROC-two-class-example-curve.rds"), 25 | # version = 2L 26 | # ) 27 | data_pROC_two_class_example_curve <- function() { 28 | readRDS(test_path("data", "helper-pROC-two-class-example-curve.rds")) 29 | } 30 | -------------------------------------------------------------------------------- /tests/testthat/helper-read_pydata.R: -------------------------------------------------------------------------------- 1 | read_pydata <- function(py_path) { 2 | py_path <- paste0(py_path, ".rds") 3 | py_path <- test_path("py-data", py_path) 4 | readRDS(py_path) 5 | } 6 | -------------------------------------------------------------------------------- /tests/testthat/helper-weights.R: -------------------------------------------------------------------------------- 1 | # # Weights used in `two_class_example` tests 2 | # save_weights_two_class_example <- function(seed) { 3 | # withr::with_seed(seed, { 4 | # weights_two_class_example <- runif( 5 | # n = nrow(two_class_example), 6 | # min = 0, 7 | # max = 100 8 | # ) 9 | # }) 10 | # 11 | # saveRDS( 12 | # object = weights_two_class_example, 13 | # file = test_path("data/weights-two-class-example.rds"), 14 | # version = 2 15 | # ) 16 | # } 17 | # 18 | # save_weights_two_class_example(12345) 19 | 20 | read_weights_two_class_example <- function() { 21 | readRDS(test_path("data/weights-two-class-example.rds")) 22 | } 23 | 24 | # # Weights used in `hpc_cv` tests 25 | # save_weights_hpc_cv <- function(seed) { 26 | # withr::with_seed(seed, { 27 | # weights_hpc_cv <- runif( 28 | # n = nrow(hpc_cv), 29 | # min = 0, 30 | # max = 100 31 | # ) 32 | # }) 33 | # 34 | # saveRDS( 35 | # object = weights_hpc_cv, 36 | # file = test_path("data/weights-hpc-cv.rds"), 37 | # version = 2 38 | # ) 39 | # } 40 | # 41 | # save_weights_hpc_cv(4321) 42 | 43 | read_weights_hpc_cv <- function() { 44 | readRDS(test_path("data/weights-hpc-cv.rds")) 45 | } 46 | 47 | # # Weights used in `solubility_test` tests 48 | # 
save_weights_solubility_test <- function(seed) { 49 | # withr::with_seed(seed, { 50 | # weights_solubility_test <- runif( 51 | # n = nrow(solubility_test), 52 | # min = 0, 53 | # max = 100 54 | # ) 55 | # }) 56 | # 57 | # saveRDS( 58 | # object = weights_solubility_test, 59 | # file = test_path("data/weights-solubility-test.rds"), 60 | # version = 2 61 | # ) 62 | # } 63 | # 64 | # save_weights_solubility_test(55555) 65 | 66 | read_weights_solubility_test <- function() { 67 | readRDS(test_path("data/weights-solubility-test.rds")) 68 | } 69 | -------------------------------------------------------------------------------- /tests/testthat/py-data/py-accuracy.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-accuracy.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-average-precision.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-average-precision.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-bal-accuracy.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-bal-accuracy.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-brier-survival.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-brier-survival.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-demographic_parity.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-demographic_parity.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-equal_opportunity.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-equal_opportunity.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-equalized_odds.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-equalized_odds.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-f_meas.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-f_meas.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-f_meas_beta_.5.rds: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-f_meas_beta_.5.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-kap.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-kap.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-mae.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-mae.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-mape.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-mape.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-mcc.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-mcc.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-mn_log_loss.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-mn_log_loss.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-npv.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-npv.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-ppv.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-ppv.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-pr-curve.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-pr-curve.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-precision.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-precision.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-recall.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-recall.rds -------------------------------------------------------------------------------- 
/tests/testthat/py-data/py-rmse.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-rmse.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-roc-auc.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-roc-auc.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-roc-curve.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-roc-curve.rds -------------------------------------------------------------------------------- /tests/testthat/py-data/py-rsq-trad.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/yardstick/14b6aea050f80ce154a758a8c4b20f234e6f7b4e/tests/testthat/py-data/py-rsq-trad.rds -------------------------------------------------------------------------------- /tests/testthat/test-aaa-new.R: -------------------------------------------------------------------------------- 1 | test_that("can create metric functions", { 2 | fn1 <- new_class_metric(function() 1, "maximize") 3 | fn2 <- new_prob_metric(function() 1, "maximize") 4 | fn3 <- new_numeric_metric(function() 1, "minimize") 5 | fn3zero <- new_numeric_metric(function() 1, "zero") 6 | fn4 <- new_dynamic_survival_metric(function() 1, "minimize") 7 | fn5 <- new_static_survival_metric(function() 1, "minimize") 8 | fn6 <- new_integrated_survival_metric(function() 1, "minimize") 9 | 10 | expect_identical(class(fn1), c("class_metric", "metric", "function")) 11 | expect_identical(class(fn2), c("prob_metric", "metric", "function")) 12 | expect_identical(class(fn3), c("numeric_metric", "metric", "function")) 13 | expect_identical(class(fn3zero), c("numeric_metric", "metric", "function")) 14 | expect_identical( 15 | class(fn4), 16 | c("dynamic_survival_metric", "metric", "function") 17 | ) 18 | expect_identical( 19 | class(fn5), 20 | c("static_survival_metric", "metric", "function") 21 | ) 22 | expect_identical( 23 | class(fn6), 24 | c("integrated_survival_metric", "metric", "function") 25 | ) 26 | 27 | expect_identical(attr(fn1, "direction"), "maximize") 28 | expect_identical(attr(fn2, "direction"), "maximize") 29 | expect_identical(attr(fn3, "direction"), "minimize") 30 | expect_identical(attr(fn3zero, "direction"), "zero") 31 | expect_identical(attr(fn4, "direction"), "minimize") 32 | expect_identical(attr(fn5, "direction"), "minimize") 33 | expect_identical(attr(fn6, "direction"), "minimize") 34 | }) 35 | 36 | test_that("`fn` is validated", { 37 | expect_snapshot( 38 | error = TRUE, 39 | new_class_metric(1, "maximize") 40 | ) 41 | }) 42 | 43 | test_that("`direction` is validated", { 44 | expect_snapshot( 45 | error = TRUE, 46 | new_class_metric(function() 1, "min") 47 | ) 48 | }) 49 | 50 | test_that("metric print method works", { 51 | expect_snapshot(rmse) 52 | expect_snapshot(roc_auc) 53 | expect_snapshot(demographic_parity(boop)) 54 | }) 55 | -------------------------------------------------------------------------------- /tests/testthat/test-auc.R: 
-------------------------------------------------------------------------------- 1 | test_that("Matches MLmetrics", { 2 | x <- c(1, 1.2, 1.6, 2) 3 | y <- c(4, 3.8, 4.2, 5) 4 | # MLmetrics::Area_Under_Curve(x, y, "trapezoid") 5 | auc_known <- 4.22 6 | 7 | expect_equal(auc(x, y), auc_known) 8 | }) 9 | -------------------------------------------------------------------------------- /tests/testthat/test-check_metric.R: -------------------------------------------------------------------------------- 1 | test_that("check_numeric_metric() validates case_weights", { 2 | expect_snapshot( 3 | error = TRUE, 4 | check_numeric_metric(1:10, 1:10, 1:11) 5 | ) 6 | }) 7 | 8 | test_that("check_numeric_metric() validates inputs", { 9 | expect_snapshot( 10 | error = TRUE, 11 | check_numeric_metric(1, "1", 1) 12 | ) 13 | }) 14 | 15 | test_that("check_class_metric() validates case_weights", { 16 | expect_snapshot( 17 | error = TRUE, 18 | check_class_metric(letters, letters, 1:5) 19 | ) 20 | }) 21 | 22 | test_that("check_class_metric() validates inputs", { 23 | expect_snapshot( 24 | error = TRUE, 25 | check_class_metric(1, "1", 1) 26 | ) 27 | }) 28 | 29 | test_that("check_class_metric() validates estimator", { 30 | expect_snapshot( 31 | error = TRUE, 32 | check_class_metric( 33 | factor(c("a", "b", "a"), levels = c("a", "b", "c")), 34 | factor(c("a", "b", "a"), levels = c("a", "b", "c")), 35 | case_weights = 1:3, 36 | estimator = "binary" 37 | ) 38 | ) 39 | }) 40 | 41 | test_that("check_prob_metric() validates case_weights", { 42 | expect_snapshot( 43 | error = TRUE, 44 | check_prob_metric( 45 | factor(c("a", "b", "a")), 46 | matrix(1:6, nrow = 2), 47 | 1:4, 48 | estimator = "binary" 49 | ) 50 | ) 51 | }) 52 | 53 | test_that("check_prob_metric() validates inputs", { 54 | expect_snapshot( 55 | error = TRUE, 56 | check_prob_metric( 57 | factor(c("a", "b", "a")), 58 | matrix(1:6, nrow = 2), 59 | 1:3, 60 | estimator = "binary" 61 | ) 62 | ) 63 | }) 64 | 65 | test_that("check_ordered_prob_metric() validates case_weights", { 66 | expect_snapshot( 67 | error = TRUE, 68 | check_ordered_prob_metric( 69 | ordered(c("a", "b", "a")), 70 | matrix(1:6, nrow = 2), 71 | 1:4, 72 | estimator = "binary" 73 | ) 74 | ) 75 | }) 76 | 77 | test_that("check_ordered_prob_metric() validates inputs", { 78 | expect_snapshot( 79 | error = TRUE, 80 | check_ordered_prob_metric( 81 | ordered(c("a", "b", "a")), 82 | matrix(1:6, nrow = 2), 83 | 1:3, 84 | estimator = "binary" 85 | ) 86 | ) 87 | }) 88 | 89 | test_that("check_static_survival_metric() validates case_weights", { 90 | lung_surv <- data_lung_surv() 91 | 92 | expect_snapshot( 93 | error = TRUE, 94 | check_static_survival_metric( 95 | truth = lung_surv$surv_obj, 96 | estimate = lung_surv$.pred_survival, 97 | case_weights = 1:151 98 | ) 99 | ) 100 | }) 101 | 102 | test_that("check_static_survival_metric() validates inputs", { 103 | lung_surv <- data_lung_surv() 104 | 105 | expect_snapshot( 106 | error = TRUE, 107 | check_static_survival_metric( 108 | truth = lung_surv$surv_obj, 109 | estimate = as.character(lung_surv$inst), 110 | case_weights = 1:150 111 | ) 112 | ) 113 | }) 114 | -------------------------------------------------------------------------------- /tests/testthat/test-estimator-helpers.R: -------------------------------------------------------------------------------- 1 | test_that("get_weights() errors with wrong estimator", { 2 | expect_snapshot( 3 | error = TRUE, 4 | get_weights(mtcars, "wrong") 5 | ) 6 | }) 7 | 
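# Aside (illustrative, not part of this test file): the trapezoidal AUC value
# hard-coded as `auc_known` in test-auc.R above can be verified by hand. For
# the points (1, 4), (1.2, 3.8), (1.6, 4.2), (2, 5), the trapezoid rule gives
#   0.2 * (4 + 3.8) / 2 + 0.4 * (3.8 + 4.2) / 2 + 0.4 * (4.2 + 5) / 2
#     = 0.78 + 1.60 + 1.84
#     = 4.22
# A standalone sketch of the same computation (not yardstick's internal
# implementation):
# trapezoid_auc <- function(x, y) {
#   sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
# }
# trapezoid_auc(c(1, 1.2, 1.6, 2), c(4, 3.8, 4.2, 5))  # 4.22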
-------------------------------------------------------------------------------- /tests/testthat/test-event-level.R: -------------------------------------------------------------------------------- 1 | test_that("`yardstick_event_level()` defaults to 'first'", { 2 | expect_identical(yardstick_event_level(), "first") 3 | }) 4 | 5 | test_that("`yardstick_event_level()` ignores option - TRUE, with a warning", { 6 | skip_if( 7 | getRversion() <= "3.5.3", 8 | "Base R used a different deprecated warning class." 9 | ) 10 | rlang::local_options(lifecycle_verbosity = "warning") 11 | rlang::local_options(yardstick.event_first = TRUE) 12 | expect_snapshot(out <- yardstick_event_level()) 13 | expect_identical(out, "first") 14 | }) 15 | 16 | test_that("`yardstick_event_level()` ignores option - FALSE, with a warning", { 17 | skip_if( 18 | getRversion() <= "3.5.3", 19 | "Base R used a different deprecated warning class." 20 | ) 21 | rlang::local_options(lifecycle_verbosity = "warning") 22 | rlang::local_options(yardstick.event_first = FALSE) 23 | expect_snapshot(out <- yardstick_event_level()) 24 | expect_identical(out, "first") 25 | }) 26 | 27 | test_that("validate_event_level() works", { 28 | expect_snapshot( 29 | error = TRUE, 30 | recall(two_class_example, truth, predicted, event_level = "wrong") 31 | ) 32 | }) 33 | -------------------------------------------------------------------------------- /tests/testthat/test-fair-demographic_parity.R: -------------------------------------------------------------------------------- 1 | test_that("result matches reference implementation (fairlearn)", { 2 | data("hpc_cv") 3 | py_res <- read_pydata("py-demographic_parity") 4 | 5 | hpc_cv$obs_vf = as.factor(hpc_cv$obs == "VF") 6 | hpc_cv$pred_vf = as.factor(hpc_cv$pred == "VF") 7 | hpc_cv$case_weights <- read_weights_hpc_cv() 8 | 9 | dp <- demographic_parity(Resample) 10 | 11 | expect_equal( 12 | dp( 13 | hpc_cv, 14 | truth = obs_vf, 15 | estimate = pred_vf, 16 | event_level = "second" 17 | )$.estimate, 18 | py_res$binary 19 | ) 20 | 21 | expect_equal( 22 | dp( 23 | hpc_cv, 24 | truth = obs_vf, 25 | estimate = pred_vf, 26 | event_level = "second", 27 | case_weights = case_weights 28 | )$.estimate, 29 | py_res$weighted 30 | ) 31 | }) 32 | -------------------------------------------------------------------------------- /tests/testthat/test-fair-equal_opportunity.R: -------------------------------------------------------------------------------- 1 | test_that("result matches reference implementation (fairlearn)", { 2 | data("hpc_cv") 3 | py_res <- read_pydata("py-equal_opportunity") 4 | 5 | hpc_cv$obs_vf = as.factor(hpc_cv$obs == "VF") 6 | hpc_cv$pred_vf = as.factor(hpc_cv$pred == "VF") 7 | hpc_cv$case_weights <- read_weights_hpc_cv() 8 | 9 | eo <- equal_opportunity(Resample) 10 | 11 | expect_equal( 12 | eo( 13 | hpc_cv, 14 | truth = obs_vf, 15 | estimate = pred_vf, 16 | event_level = "second" 17 | )$.estimate, 18 | py_res$binary 19 | ) 20 | 21 | expect_equal( 22 | eo( 23 | hpc_cv, 24 | truth = obs_vf, 25 | estimate = pred_vf, 26 | event_level = "second", 27 | case_weights = case_weights 28 | )$.estimate, 29 | py_res$weighted 30 | ) 31 | }) 32 | -------------------------------------------------------------------------------- /tests/testthat/test-fair-equalized_odds.R: -------------------------------------------------------------------------------- 1 | test_that("result matches reference implementation (fairlearn)", { 2 | data("hpc_cv") 3 | py_res <- read_pydata("py-equalized_odds") 4 | 5 | hpc_cv$obs_vf = 
as.factor(hpc_cv$obs == "VF") 6 | hpc_cv$pred_vf = as.factor(hpc_cv$pred == "VF") 7 | hpc_cv$case_weights <- read_weights_hpc_cv() 8 | 9 | eo <- equalized_odds(Resample) 10 | 11 | expect_equal( 12 | eo( 13 | hpc_cv, 14 | truth = obs_vf, 15 | estimate = pred_vf, 16 | event_level = "second" 17 | )$.estimate, 18 | py_res$binary 19 | ) 20 | 21 | expect_equal( 22 | eo( 23 | hpc_cv, 24 | truth = obs_vf, 25 | estimate = pred_vf, 26 | event_level = "second", 27 | case_weights = case_weights 28 | )$.estimate, 29 | py_res$weighted 30 | ) 31 | }) 32 | -------------------------------------------------------------------------------- /tests/testthat/test-flatten.R: -------------------------------------------------------------------------------- 1 | test_that("flat tables", { 2 | lst <- data_three_class() 3 | three_class <- lst$three_class 4 | three_class_tb <- lst$three_class_tb 5 | 6 | expect_identical( 7 | unname(yardstick:::flatten(three_class_tb)), 8 | as.vector(three_class_tb) 9 | ) 10 | expect_equal( 11 | names(yardstick:::flatten(three_class_tb[1:2, 1:2])), 12 | c("cell_1_1", "cell_2_1", "cell_1_2", "cell_2_2") 13 | ) 14 | 15 | expect_snapshot( 16 | error = TRUE, 17 | yardstick:::flatten(three_class_tb[, 1:2]) 18 | ) 19 | }) 20 | -------------------------------------------------------------------------------- /tests/testthat/test-handle_missings.R: -------------------------------------------------------------------------------- 1 | test_that("yardstick_remove_missing works", { 2 | expect_identical( 3 | yardstick_remove_missing(1:10, 1:10, 1:10), 4 | list( 5 | truth = 1:10, 6 | estimate = 1:10, 7 | case_weights = 1:10 8 | ) 9 | ) 10 | 11 | expect_identical( 12 | yardstick_remove_missing(c(1:4, NA, NA, 7:10), 1:10, 1:10), 13 | list( 14 | truth = c(1:4, 7:10), 15 | estimate = c(1:4, 7:10), 16 | case_weights = c(1:4, 7:10) 17 | ) 18 | ) 19 | 20 | expect_identical( 21 | yardstick_remove_missing(1:10, c(1:4, NA, NA, 7:10), 1:10), 22 | list( 23 | truth = c(1:4, 7:10), 24 | estimate = c(1:4, 7:10), 25 | case_weights = c(1:4, 7:10) 26 | ) 27 | ) 28 | 29 | expect_identical( 30 | yardstick_remove_missing(1:10, 1:10, c(1:4, NA, NA, 7:10)), 31 | list( 32 | truth = c(1:4, 7:10), 33 | estimate = c(1:4, 7:10), 34 | case_weights = c(1:4, 7:10) 35 | ) 36 | ) 37 | 38 | expect_identical( 39 | yardstick_remove_missing(1:10, c(1:4, NA, NA, 7:10), 1:10), 40 | list( 41 | truth = c(1:4, 7:10), 42 | estimate = c(1:4, 7:10), 43 | case_weights = c(1:4, 7:10) 44 | ) 45 | ) 46 | 47 | expect_identical( 48 | yardstick_remove_missing(c(NA, 2:10), c(1:9, NA), c(1:4, NA, NA, 7:10)), 49 | list( 50 | truth = c(2:4, 7:9), 51 | estimate = c(2:4, 7:9), 52 | case_weights = c(2:4, 7:9) 53 | ) 54 | ) 55 | }) 56 | 57 | test_that("yardstick_any_missing works", { 58 | expect_false( 59 | yardstick_any_missing(1:10, 1:10, 1:10) 60 | ) 61 | 62 | expect_true( 63 | yardstick_any_missing(c(1:4, NA, NA, 7:10), 1:10, 1:10) 64 | ) 65 | 66 | expect_true( 67 | yardstick_any_missing(1:10, c(1:4, NA, NA, 7:10), 1:10) 68 | ) 69 | 70 | expect_true( 71 | yardstick_any_missing(1:10, 1:10, c(1:4, NA, NA, 7:10)) 72 | ) 73 | 74 | expect_true( 75 | yardstick_any_missing(1:10, c(1:4, NA, NA, 7:10), 1:10) 76 | ) 77 | 78 | expect_true( 79 | yardstick_any_missing(c(NA, 2:10), c(1:9, NA), c(1:4, NA, NA, 7:10)) 80 | ) 81 | }) 82 | -------------------------------------------------------------------------------- /tests/testthat/test-num-huber_loss.R: -------------------------------------------------------------------------------- 1 | test_that("Huber Loss", { 2 | ex_dat 
<- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | delta <- 2 6 | 7 | expect_equal( 8 | huber_loss(ex_dat, truth = "obs", estimate = "pred", delta = delta)[[ 9 | ".estimate" 10 | ]], 11 | { 12 | a <- ex_dat$obs - ex_dat$pred 13 | mean( 14 | ifelse(abs(a) <= delta, 0.5 * a^2, delta * (abs(a) - 0.5 * delta)) 15 | ) 16 | } 17 | ) 18 | 19 | expect_equal( 20 | huber_loss(ex_dat, truth = "obs", estimate = "pred_na", delta = delta)[[ 21 | ".estimate" 22 | ]], 23 | { 24 | a <- ex_dat$obs[not_na] - ex_dat$pred[not_na] 25 | mean( 26 | ifelse(abs(a) <= delta, 0.5 * a^2, delta * (abs(a) - 0.5 * delta)) 27 | ) 28 | } 29 | ) 30 | 31 | expect_snapshot( 32 | error = TRUE, 33 | huber_loss(ex_dat, truth = "obs", estimate = "pred_na", delta = -1) 34 | ) 35 | 36 | expect_snapshot( 37 | error = TRUE, 38 | huber_loss(ex_dat, truth = "obs", estimate = "pred_na", delta = c(1, 2)) 39 | ) 40 | }) 41 | 42 | test_that("Weighted results are working", { 43 | truth <- c(1, 2, 3) 44 | estimate <- c(2, 4, 3) 45 | weights <- c(1, 2, 1) 46 | 47 | expect_identical( 48 | huber_loss_vec(truth, estimate, case_weights = weights), 49 | 3.5 / 4 50 | ) 51 | }) 52 | 53 | test_that("works with hardhat case weights", { 54 | solubility_test$weights <- floor(read_weights_solubility_test()) 55 | df <- solubility_test 56 | 57 | imp_wgt <- hardhat::importance_weights(df$weights) 58 | freq_wgt <- hardhat::frequency_weights(df$weights) 59 | 60 | expect_no_error( 61 | huber_loss_vec(df$solubility, df$prediction, case_weights = imp_wgt) 62 | ) 63 | 64 | expect_no_error( 65 | huber_loss_vec(df$solubility, df$prediction, case_weights = freq_wgt) 66 | ) 67 | }) 68 | -------------------------------------------------------------------------------- /tests/testthat/test-num-iic.R: -------------------------------------------------------------------------------- 1 | # All tests (except weighted ones) confirmed against the software: 2 | # http://www.insilico.eu/coral/SOFTWARECORAL.html 3 | 4 | test_that("iic() returns known correct results", { 5 | ex_dat <- generate_numeric_test_data() 6 | 7 | expect_equal(iic(ex_dat, obs, pred)[[".estimate"]], 0.43306222006167) 8 | }) 9 | 10 | test_that("iic() can be negative", { 11 | expect_equal(iic_vec(c(1, 2, 3), c(2, 1, 1)), -0.577350269189626) 12 | }) 13 | 14 | test_that("iic() is NaN if truth/estimate are equivalent", { 15 | expect_equal(iic_vec(c(1, 2), c(1, 2)), NaN) 16 | }) 17 | 18 | test_that("case weights are applied", { 19 | df <- dplyr::tibble( 20 | truth = c(1, 2, 3, 4, 5), 21 | estimate = c(1, 3, 1, 3, 2), 22 | weight = c(1, 2, 1, 2, 0) 23 | ) 24 | 25 | expect_equal( 26 | iic(df, truth, estimate, case_weights = weight)[[".estimate"]], 27 | 0.4264014327112208846415 28 | ) 29 | }) 30 | 31 | test_that("yardstick correlation warnings are thrown", { 32 | cnd <- rlang::catch_cnd(iic_vec(c(1, 2), c(1, 1))) 33 | expect_s3_class( 34 | cnd, 35 | "yardstick_warning_correlation_undefined_constant_estimate" 36 | ) 37 | 38 | cnd <- rlang::catch_cnd(iic_vec(c(1, 1), c(1, 2))) 39 | expect_s3_class(cnd, "yardstick_warning_correlation_undefined_constant_truth") 40 | }) 41 | 42 | test_that("works with hardhat case weights", { 43 | solubility_test$weights <- floor(read_weights_solubility_test()) 44 | df <- solubility_test 45 | 46 | imp_wgt <- hardhat::importance_weights(df$weights) 47 | freq_wgt <- hardhat::frequency_weights(df$weights) 48 | 49 | expect_no_error( 50 | iic_vec(df$solubility, df$prediction, case_weights = imp_wgt) 51 | ) 52 | 53 | expect_no_error( 54 | iic_vec(df$solubility, 
df$prediction, case_weights = freq_wgt) 55 | ) 56 | }) 57 | -------------------------------------------------------------------------------- /tests/testthat/test-num-mae.R: -------------------------------------------------------------------------------- 1 | test_that("mean absolute error", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | mae(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 7 | mean(abs(ex_dat$obs - ex_dat$pred)) 8 | ) 9 | expect_equal( 10 | mae(ex_dat, obs, pred_na)[[".estimate"]], 11 | mean(abs(ex_dat$obs[not_na] - ex_dat$pred[not_na])) 12 | ) 13 | }) 14 | 15 | test_that("Weighted results are the same as scikit-learn", { 16 | solubility_test$weights <- read_weights_solubility_test() 17 | 18 | expect_equal( 19 | mae(solubility_test, solubility, prediction, case_weights = weights)[[ 20 | ".estimate" 21 | ]], 22 | read_pydata("py-mae")$case_weight 23 | ) 24 | }) 25 | 26 | test_that("works with hardhat case weights", { 27 | solubility_test$weights <- floor(read_weights_solubility_test()) 28 | df <- solubility_test 29 | 30 | imp_wgt <- hardhat::importance_weights(df$weights) 31 | freq_wgt <- hardhat::frequency_weights(df$weights) 32 | 33 | expect_no_error( 34 | mae_vec(df$solubility, df$prediction, case_weights = imp_wgt) 35 | ) 36 | 37 | expect_no_error( 38 | mae_vec(df$solubility, df$prediction, case_weights = freq_wgt) 39 | ) 40 | }) 41 | -------------------------------------------------------------------------------- /tests/testthat/test-num-mape.R: -------------------------------------------------------------------------------- 1 | test_that("Mean Absolute Percentage Error", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | mape(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 7 | 100 * mean(abs((ex_dat$obs - ex_dat$pred) / ex_dat$obs)) 8 | ) 9 | expect_equal( 10 | mape(ex_dat, obs, pred_na)[[".estimate"]], 11 | 100 * 12 | mean(abs((ex_dat$obs[not_na] - ex_dat$pred[not_na]) / ex_dat$obs[not_na])) 13 | ) 14 | }) 15 | 16 | test_that("`mape()` computes expected values when singular `truth` is `0` (#271)", { 17 | expect_identical( 18 | mape_vec(truth = 0, estimate = 1), 19 | Inf 20 | ) 21 | 22 | expect_identical( 23 | mape_vec(truth = 0, estimate = -1), 24 | Inf 25 | ) 26 | 27 | expect_identical( 28 | mape_vec(truth = 0, estimate = 0), 29 | NaN 30 | ) 31 | }) 32 | 33 | test_that("Weighted results are the same as scikit-learn", { 34 | solubility_test$weights <- read_weights_solubility_test() 35 | zero_solubility <- solubility_test$solubility == 0 36 | solubility_test_not_zero <- solubility_test[!zero_solubility, ] 37 | 38 | expect_equal( 39 | mape( 40 | solubility_test_not_zero, 41 | solubility, 42 | prediction, 43 | case_weights = weights 44 | )[[".estimate"]], 45 | read_pydata("py-mape")$case_weight * 100 46 | ) 47 | }) 48 | 49 | test_that("works with hardhat case weights", { 50 | solubility_test$weights <- floor(read_weights_solubility_test()) 51 | df <- solubility_test 52 | 53 | imp_wgt <- hardhat::importance_weights(df$weights) 54 | freq_wgt <- hardhat::frequency_weights(df$weights) 55 | 56 | expect_no_error( 57 | mape_vec(df$solubility, df$prediction, case_weights = imp_wgt) 58 | ) 59 | 60 | expect_no_error( 61 | mape_vec(df$solubility, df$prediction, case_weights = freq_wgt) 62 | ) 63 | }) 64 | -------------------------------------------------------------------------------- /tests/testthat/test-num-mase.R: 
-------------------------------------------------------------------------------- 1 | test_that("Mean Absolute Scaled Error", { 2 | ex_dat <- generate_numeric_test_data() 3 | 4 | truth <- ex_dat$obs 5 | pred <- ex_dat$pred 6 | 7 | truth_lag <- dplyr::lag(truth, 1L) 8 | naive_error <- truth - truth_lag 9 | mae_denom <- mean(abs(naive_error)[-1]) 10 | scaled_error <- (truth - pred) / mae_denom 11 | known_mase <- mean(abs(scaled_error)) 12 | 13 | m <- 2 14 | 15 | truth_lag <- dplyr::lag(truth, m) 16 | naive_error <- truth - truth_lag 17 | mae_denom <- mean(abs(naive_error)[-c(1, 2)]) 18 | scaled_error <- (truth - pred) / mae_denom 19 | known_mase_with_m <- mean(abs(scaled_error)) 20 | 21 | mae_train <- .5 22 | 23 | mae_denom <- mae_train 24 | scaled_error <- (truth - pred) / mae_denom 25 | known_mase_with_mae_train <- mean(abs(scaled_error)) 26 | 27 | expect_equal( 28 | mase(ex_dat, obs, pred)[[".estimate"]], 29 | known_mase 30 | ) 31 | 32 | expect_equal( 33 | mase(ex_dat, obs, pred, m = 2)[[".estimate"]], 34 | known_mase_with_m 35 | ) 36 | 37 | expect_equal( 38 | mase(ex_dat, obs, pred, mae_train = mae_train)[[".estimate"]], 39 | known_mase_with_mae_train 40 | ) 41 | 42 | expect_snapshot( 43 | error = TRUE, 44 | mase_vec(truth, pred, m = "x") 45 | ) 46 | 47 | expect_snapshot( 48 | error = TRUE, 49 | mase_vec(truth, pred, m = -1) 50 | ) 51 | 52 | expect_snapshot( 53 | error = TRUE, 54 | mase_vec(truth, pred, m = 1.5) 55 | ) 56 | 57 | expect_snapshot( 58 | error = TRUE, 59 | mase_vec(truth, pred, mae_train = -1) 60 | ) 61 | 62 | expect_snapshot( 63 | error = TRUE, 64 | mase_vec(truth, pred, mae_train = "x") 65 | ) 66 | }) 67 | 68 | test_that("Weighted results are working", { 69 | truth <- c(1, 2, 3) 70 | estimate <- c(2, 4, 3) 71 | weights <- c(1, 2, 1) 72 | 73 | expect_identical( 74 | mase_vec(truth, estimate, case_weights = weights), 75 | 5 / 4 76 | ) 77 | }) 78 | 79 | test_that("works with hardhat case weights", { 80 | solubility_test$weights <- floor(read_weights_solubility_test()) 81 | df <- solubility_test 82 | 83 | imp_wgt <- hardhat::importance_weights(df$weights) 84 | freq_wgt <- hardhat::frequency_weights(df$weights) 85 | 86 | expect_no_error( 87 | mase_vec(df$solubility, df$prediction, case_weights = imp_wgt) 88 | ) 89 | 90 | expect_no_error( 91 | mase_vec(df$solubility, df$prediction, case_weights = freq_wgt) 92 | ) 93 | }) 94 | 95 | test_that("mase() errors if m is larger than number of observations", { 96 | expect_snapshot( 97 | error = TRUE, 98 | mase(mtcars, mpg, disp, m = 100) 99 | ) 100 | }) 101 | -------------------------------------------------------------------------------- /tests/testthat/test-num-mpe.R: -------------------------------------------------------------------------------- 1 | test_that("`mpe()` works", { 2 | set.seed(1812) 3 | df <- data.frame(obs = rnorm(50)) 4 | df$pred <- .2 + 1.1 * df$obs + rnorm(50, sd = 0.5) 5 | 6 | expect_identical( 7 | mpe(df, truth = "obs", estimate = "pred")[[".estimate"]], 8 | mean((df$obs - df$pred) / df$obs) * 100 9 | ) 10 | 11 | ind <- c(10, 20, 30, 40, 50) 12 | df$pred[ind] <- NA 13 | 14 | expect_identical( 15 | mpe(df, obs, pred)[[".estimate"]], 16 | mean((df$obs[-ind] - df$pred[-ind]) / df$obs[-ind]) * 100 17 | ) 18 | }) 19 | 20 | test_that("`mpe()` computes expected values when singular `truth` is `0`", { 21 | expect_identical( 22 | mpe_vec(truth = 0, estimate = 1), 23 | -Inf 24 | ) 25 | 26 | expect_identical( 27 | mpe_vec(truth = 0, estimate = -1), 28 | Inf 29 | ) 30 | 31 | expect_identical( 32 | mpe_vec(truth = 0, estimate = 
0), 33 | NaN 34 | ) 35 | }) 36 | 37 | test_that("Weighted results are working", { 38 | truth <- c(1, 2, 3) 39 | estimate <- c(2, 4, 3) 40 | weights <- c(1, 2, 1) 41 | 42 | expect_identical( 43 | mpe_vec(truth, estimate, case_weights = weights), 44 | -3 / 4 * 100 45 | ) 46 | }) 47 | 48 | test_that("works with hardhat case weights", { 49 | solubility_test$weights <- floor(read_weights_solubility_test()) 50 | df <- solubility_test 51 | 52 | imp_wgt <- hardhat::importance_weights(df$weights) 53 | freq_wgt <- hardhat::frequency_weights(df$weights) 54 | 55 | expect_no_error( 56 | mpe_vec(df$solubility, df$prediction, case_weights = imp_wgt) 57 | ) 58 | 59 | expect_no_error( 60 | mpe_vec(df$solubility, df$prediction, case_weights = freq_wgt) 61 | ) 62 | }) 63 | -------------------------------------------------------------------------------- /tests/testthat/test-num-msd.R: -------------------------------------------------------------------------------- 1 | test_that("`msd()` works", { 2 | set.seed(1812) 3 | df <- data.frame(obs = rnorm(50)) 4 | df$pred <- .2 + 1.1 * df$obs + rnorm(50, sd = 0.5) 5 | 6 | expect_identical( 7 | msd(df, truth = "obs", estimate = "pred")[[".estimate"]], 8 | mean(df$obs - df$pred) 9 | ) 10 | 11 | # add some NA values and check that they are ignored 12 | ind <- c(10, 20, 30, 40, 50) 13 | df$pred[ind] <- NA 14 | 15 | expect_identical( 16 | msd(df, obs, pred)[[".estimate"]], 17 | mean(df$obs[-ind] - df$pred[-ind]) 18 | ) 19 | }) 20 | 21 | test_that("positive and negative errors cancel each other out", { 22 | expect_identical(msd_vec(c(100, -100), c(0, 0)), 0) 23 | }) 24 | 25 | test_that("differences are computed as `truth - estimate`", { 26 | expect_identical(msd_vec(0, 1), -1) 27 | }) 28 | 29 | test_that("weighted results are correct", { 30 | truth <- c(1, 2, 3) 31 | estimate <- c(1, 4, 4) 32 | weights <- c(0, 1, 2) 33 | 34 | expect_identical( 35 | msd_vec(truth, estimate, case_weights = weights), 36 | -4 / 3 37 | ) 38 | }) 39 | 40 | test_that("works with hardhat case weights", { 41 | solubility_test$weights <- floor(read_weights_solubility_test()) 42 | df <- solubility_test 43 | 44 | imp_wgt <- hardhat::importance_weights(df$weights) 45 | freq_wgt <- hardhat::frequency_weights(df$weights) 46 | 47 | expect_no_error( 48 | msd_vec(df$solubility, df$prediction, case_weights = imp_wgt) 49 | ) 50 | 51 | expect_no_error( 52 | msd_vec(df$solubility, df$prediction, case_weights = freq_wgt) 53 | ) 54 | }) 55 | -------------------------------------------------------------------------------- /tests/testthat/test-num-poisson_log_loss.R: -------------------------------------------------------------------------------- 1 | test_that("poisson log-loss", { 2 | count_results <- data_counts()$basic 3 | count_missing <- data_counts()$missing 4 | count_poor <- data_counts()$poor 5 | 6 | expect_equal( 7 | poisson_log_loss(count_results, count, pred)[[".estimate"]], 8 | mean(-stats::dpois(count_results$count, count_results$pred, log = TRUE)) 9 | ) 10 | 11 | expect_equal( 12 | poisson_log_loss(count_missing, count, pred)[[".estimate"]], 13 | mean( 14 | -stats::dpois(count_results$count[-1], count_results$pred[-1], log = TRUE) 15 | ) 16 | ) 17 | 18 | expect_true( 19 | poisson_log_loss(count_results, count, pred)[[".estimate"]] < 20 | poisson_log_loss(count_poor, count, pred)[[".estimate"]] 21 | ) 22 | }) 23 | 24 | test_that("poisson log-loss handles 0 valued estimates (#513)", { 25 | count_results <- data_counts()$basic 26 | 27 | count_results$pred <- 0 28 | 29 | expect_false( 30 | 
is.nan(poisson_log_loss(count_results, count, pred)[[".estimate"]]), 31 | ) 32 | expect_false( 33 | is.infinite(poisson_log_loss(count_results, count, pred)[[".estimate"]]), 34 | ) 35 | }) 36 | 37 | test_that("weighted results are working", { 38 | count_results <- data_counts()$basic 39 | count_results$weights <- c(1, 2, 1, 1, 2, 1) 40 | 41 | expect_identical( 42 | poisson_log_loss(count_results, count, pred, case_weights = weights)[[ 43 | ".estimate" 44 | ]], 45 | yardstick_mean( 46 | log(gamma(count_results$count + 1)) + 47 | count_results$pred - 48 | log(count_results$pred) * count_results$count, 49 | case_weights = count_results$weights 50 | ) 51 | ) 52 | }) 53 | 54 | test_that("works with hardhat case weights", { 55 | count_results <- data_counts()$basic 56 | count_results$weights <- c(1, 2, 1, 1, 2, 1) 57 | 58 | df <- count_results 59 | 60 | imp_wgt <- hardhat::importance_weights(df$weights) 61 | freq_wgt <- hardhat::frequency_weights(df$weights) 62 | 63 | expect_no_error( 64 | poisson_log_loss_vec(df$count, df$pred, case_weights = imp_wgt) 65 | ) 66 | 67 | expect_no_error( 68 | poisson_log_loss_vec(df$count, df$pred, case_weights = freq_wgt) 69 | ) 70 | }) 71 | -------------------------------------------------------------------------------- /tests/testthat/test-num-pseudo_huber_loss.R: -------------------------------------------------------------------------------- 1 | test_that("Pseudo-Huber Loss", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | delta <- 2 6 | expect_equal( 7 | huber_loss_pseudo(ex_dat, truth = "obs", estimate = "pred", delta = delta)[[ 8 | ".estimate" 9 | ]], 10 | { 11 | a <- ex_dat$obs - ex_dat$pred 12 | mean(delta^2 * (sqrt(1 + (a / delta)^2) - 1)) 13 | } 14 | ) 15 | expect_equal( 16 | huber_loss_pseudo( 17 | ex_dat, 18 | truth = "obs", 19 | estimate = "pred_na", 20 | delta = delta 21 | )[[".estimate"]], 22 | { 23 | a <- ex_dat$obs[not_na] - ex_dat$pred[not_na] 24 | mean(delta^2 * (sqrt(1 + (a / delta)^2) - 1)) 25 | } 26 | ) 27 | 28 | expect_snapshot( 29 | error = TRUE, 30 | huber_loss_pseudo(ex_dat, truth = "obs", estimate = "pred_na", delta = -1) 31 | ) 32 | 33 | expect_snapshot( 34 | error = TRUE, 35 | huber_loss_pseudo( 36 | ex_dat, 37 | truth = "obs", 38 | estimate = "pred_na", 39 | delta = c(1, 2) 40 | ) 41 | ) 42 | }) 43 | 44 | test_that("Weighted results are working", { 45 | truth <- c(1, 2, 3) 46 | estimate <- c(2, 4, 3) 47 | weights <- c(1, 2, 1) 48 | 49 | expect_identical( 50 | huber_loss_pseudo_vec(truth, estimate, case_weights = weights), 51 | yardstick_mean(sqrt(1 + (truth - estimate)^2) - 1, case_weights = weights) 52 | ) 53 | }) 54 | 55 | test_that("works with hardhat case weights", { 56 | solubility_test$weights <- floor(read_weights_solubility_test()) 57 | df <- solubility_test 58 | 59 | imp_wgt <- hardhat::importance_weights(df$weights) 60 | freq_wgt <- hardhat::frequency_weights(df$weights) 61 | 62 | expect_no_error( 63 | huber_loss_pseudo_vec(df$solubility, df$prediction, case_weights = imp_wgt) 64 | ) 65 | 66 | expect_no_error( 67 | huber_loss_pseudo_vec(df$solubility, df$prediction, case_weights = freq_wgt) 68 | ) 69 | }) 70 | -------------------------------------------------------------------------------- /tests/testthat/test-num-rmse.R: -------------------------------------------------------------------------------- 1 | test_that("rmse", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | rmse(ex_dat, truth = "obs", estimate = 
"pred")[[".estimate"]], 7 | sqrt(mean((ex_dat$obs - ex_dat$pred)^2)) 8 | ) 9 | expect_equal( 10 | rmse(ex_dat, truth = obs, estimate = "pred_na")[[".estimate"]], 11 | sqrt(mean((ex_dat$obs[not_na] - ex_dat$pred[not_na])^2)) 12 | ) 13 | }) 14 | 15 | test_that("Weighted results are the same as scikit-learn", { 16 | solubility_test$weights <- read_weights_solubility_test() 17 | 18 | expect_identical( 19 | rmse(solubility_test, solubility, prediction, case_weights = weights)[[ 20 | ".estimate" 21 | ]], 22 | read_pydata("py-rmse")$case_weight 23 | ) 24 | }) 25 | 26 | test_that("Integer columns are allowed (#44)", { 27 | ex_dat <- generate_numeric_test_data() 28 | ex_dat$obs <- as.integer(ex_dat$obs) 29 | 30 | expect_equal( 31 | rmse(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 32 | sqrt(mean((ex_dat$obs - ex_dat$pred)^2)) 33 | ) 34 | }) 35 | 36 | test_that("works with hardhat case weights", { 37 | solubility_test$weights <- floor(read_weights_solubility_test()) 38 | df <- solubility_test 39 | 40 | imp_wgt <- hardhat::importance_weights(df$weights) 41 | freq_wgt <- hardhat::frequency_weights(df$weights) 42 | 43 | expect_no_error( 44 | rmse_vec(df$solubility, df$prediction, case_weights = imp_wgt) 45 | ) 46 | 47 | expect_no_error( 48 | rmse_vec(df$solubility, df$prediction, case_weights = freq_wgt) 49 | ) 50 | }) 51 | -------------------------------------------------------------------------------- /tests/testthat/test-num-rpd.R: -------------------------------------------------------------------------------- 1 | test_that("rpd", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | rpd(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 7 | stats::sd(ex_dat$obs) / (sqrt(mean((ex_dat$obs - ex_dat$pred)^2))) 8 | ) 9 | expect_equal( 10 | rpd(ex_dat, truth = "obs", estimate = "pred_na")[[".estimate"]], 11 | stats::sd(ex_dat$obs[not_na]) / 12 | (sqrt(mean((ex_dat$obs[not_na] - ex_dat$pred[not_na])^2))) 13 | ) 14 | }) 15 | 16 | test_that("case weights are applied", { 17 | solubility_test$weights <- read_weights_solubility_test() 18 | 19 | expect_identical( 20 | rpd(solubility_test, solubility, prediction, case_weights = weights)[[ 21 | ".estimate" 22 | ]], 23 | { 24 | sd <- yardstick_sd( 25 | solubility_test$solubility, 26 | case_weights = solubility_test$weights 27 | ) 28 | rmse <- rmse_vec( 29 | solubility_test$solubility, 30 | solubility_test$prediction, 31 | case_weights = solubility_test$weights 32 | ) 33 | sd / rmse 34 | } 35 | ) 36 | }) 37 | 38 | test_that("works with hardhat case weights", { 39 | solubility_test$weights <- floor(read_weights_solubility_test()) 40 | df <- solubility_test 41 | 42 | imp_wgt <- hardhat::importance_weights(df$weights) 43 | freq_wgt <- hardhat::frequency_weights(df$weights) 44 | 45 | expect_no_error( 46 | rpd_vec(df$solubility, df$prediction, case_weights = imp_wgt) 47 | ) 48 | 49 | expect_no_error( 50 | rpd_vec(df$solubility, df$prediction, case_weights = freq_wgt) 51 | ) 52 | }) 53 | -------------------------------------------------------------------------------- /tests/testthat/test-num-rpiq.R: -------------------------------------------------------------------------------- 1 | test_that("rpiq", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | # Note: Uses `quantile(type = 7)` when case weights aren't provided 6 | expect_equal( 7 | rpiq(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 8 | stats::IQR(ex_dat$obs) / sqrt(mean((ex_dat$obs - 
ex_dat$pred)^2)) 9 | ) 10 | expect_equal( 11 | rpiq(ex_dat, truth = "obs", estimate = "pred_na")[[".estimate"]], 12 | stats::IQR(ex_dat$obs[not_na]) / 13 | sqrt(mean((ex_dat$obs[not_na] - ex_dat$pred[not_na])^2)) 14 | ) 15 | }) 16 | 17 | test_that("case weights are applied", { 18 | solubility_test$weights <- read_weights_solubility_test() 19 | 20 | expect_equal( 21 | rpiq(solubility_test, solubility, prediction, case_weights = weights)[[ 22 | ".estimate" 23 | ]], 24 | 3.401406885440771965534 25 | ) 26 | }) 27 | 28 | test_that("works with hardhat case weights", { 29 | count_results <- data_counts()$basic 30 | count_results$weights <- c(1, 2, 1, 1, 2, 1) 31 | 32 | df <- count_results 33 | 34 | imp_wgt <- hardhat::importance_weights(df$weights) 35 | freq_wgt <- hardhat::frequency_weights(df$weights) 36 | 37 | expect_no_error( 38 | rpiq_vec(df$count, df$pred, case_weights = imp_wgt) 39 | ) 40 | 41 | expect_no_error( 42 | rpiq_vec(df$count, df$pred, case_weights = freq_wgt) 43 | ) 44 | }) 45 | -------------------------------------------------------------------------------- /tests/testthat/test-num-rsq.R: -------------------------------------------------------------------------------- 1 | test_that("R^2", { 2 | ex_dat <- generate_numeric_test_data() 3 | 4 | expect_equal( 5 | rsq(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 6 | stats::cor(ex_dat[, 1:2])[1, 2]^2 7 | ) 8 | expect_equal( 9 | rsq(ex_dat, truth = "obs", estimate = "pred_na")[[".estimate"]], 10 | stats::cor(ex_dat[, c(1, 3)], use = "complete.obs")[1, 2]^2 11 | ) 12 | expect_equal( 13 | rsq(ex_dat, truth = "obs", estimate = "rand")[[".estimate"]], 14 | stats::cor(ex_dat[, c(1, 4)])[1, 2]^2 15 | ) 16 | expect_equal( 17 | rsq(ex_dat, estimate = rand_na, truth = obs)[[".estimate"]], 18 | stats::cor(ex_dat[, c(1, 5)], use = "complete.obs")[1, 2]^2 19 | ) 20 | }) 21 | 22 | test_that("case weights are applied", { 23 | df <- dplyr::tibble( 24 | truth = c(1, 2, 3, 4, 5), 25 | estimate = c(1, 3, 1, 3, 2), 26 | weight = c(1, 0, 1, 0, 1) 27 | ) 28 | 29 | expect_identical( 30 | rsq(df, truth, estimate, case_weights = weight), 31 | rsq(df[as.logical(df$weight), ], truth, estimate) 32 | ) 33 | }) 34 | 35 | test_that("yardstick correlation warnings are thrown", { 36 | expect_snapshot({ 37 | (expect_warning( 38 | object = out <- rsq_vec(1, 1), 39 | class = "yardstick_warning_correlation_undefined_size_zero_or_one" 40 | )) 41 | }) 42 | expect_identical(out, NA_real_) 43 | 44 | expect_snapshot({ 45 | (expect_warning( 46 | object = out <- rsq_vec(double(), double()), 47 | class = "yardstick_warning_correlation_undefined_size_zero_or_one" 48 | )) 49 | }) 50 | expect_identical(out, NA_real_) 51 | 52 | expect_snapshot({ 53 | (expect_warning( 54 | object = out <- rsq_vec(c(1, 2), c(1, 1)), 55 | class = "yardstick_warning_correlation_undefined_constant_estimate" 56 | )) 57 | }) 58 | expect_identical(out, NA_real_) 59 | 60 | expect_snapshot({ 61 | (expect_warning( 62 | object = out <- rsq_vec(c(1, 1), c(1, 2)), 63 | class = "yardstick_warning_correlation_undefined_constant_truth" 64 | )) 65 | }) 66 | expect_identical(out, NA_real_) 67 | }) 68 | 69 | test_that("works with hardhat case weights", { 70 | solubility_test$weights <- floor(read_weights_solubility_test()) 71 | df <- solubility_test 72 | 73 | imp_wgt <- hardhat::importance_weights(df$weights) 74 | freq_wgt <- hardhat::frequency_weights(df$weights) 75 | 76 | expect_no_error( 77 | rsq_vec(df$solubility, df$prediction, case_weights = imp_wgt) 78 | ) 79 | 80 | expect_no_error( 81 | 
rsq_vec(df$solubility, df$prediction, case_weights = freq_wgt) 82 | ) 83 | }) 84 | -------------------------------------------------------------------------------- /tests/testthat/test-num-rsq_trad.R: -------------------------------------------------------------------------------- 1 | test_that("Traditional R^2", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | rsq_trad(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 7 | 1 - 8 | (sum((ex_dat$obs - ex_dat$pred)^2) / 9 | sum((ex_dat$obs - mean(ex_dat$obs))^2)) 10 | ) 11 | expect_equal( 12 | rsq_trad(ex_dat, truth = "obs", estimate = "pred_na")[[".estimate"]], 13 | 1 - 14 | (sum((ex_dat$obs[not_na] - ex_dat$pred[not_na])^2) / 15 | sum((ex_dat$obs[not_na] - mean(ex_dat$obs[not_na]))^2)) 16 | ) 17 | expect_equal( 18 | rsq_trad(ex_dat, truth = "obs", estimate = rand)[[".estimate"]], 19 | 1 - 20 | (sum((ex_dat$obs - ex_dat$rand)^2) / 21 | sum((ex_dat$obs - mean(ex_dat$obs))^2)) 22 | ) 23 | expect_equal( 24 | rsq_trad(ex_dat, obs, rand_na)[[".estimate"]], 25 | 1 - 26 | (sum((ex_dat$obs[not_na] - ex_dat$rand[not_na])^2) / 27 | sum((ex_dat$obs[not_na] - mean(ex_dat$obs[not_na]))^2)) 28 | ) 29 | }) 30 | 31 | test_that("Weighted results are the same as scikit-learn", { 32 | solubility_test$weights <- read_weights_solubility_test() 33 | 34 | expect_equal( 35 | rsq_trad(solubility_test, solubility, prediction, case_weights = weights)[[ 36 | ".estimate" 37 | ]], 38 | read_pydata("py-rsq-trad")$case_weight 39 | ) 40 | }) 41 | 42 | test_that("works with hardhat case weights", { 43 | solubility_test$weights <- floor(read_weights_solubility_test()) 44 | df <- solubility_test 45 | 46 | imp_wgt <- hardhat::importance_weights(df$weights) 47 | freq_wgt <- hardhat::frequency_weights(df$weights) 48 | 49 | expect_no_error( 50 | rsq_trad_vec(df$solubility, df$prediction, case_weights = imp_wgt) 51 | ) 52 | 53 | expect_no_error( 54 | rsq_trad_vec(df$solubility, df$prediction, case_weights = freq_wgt) 55 | ) 56 | }) 57 | -------------------------------------------------------------------------------- /tests/testthat/test-num-smape.R: -------------------------------------------------------------------------------- 1 | test_that("Symmetric Mean Absolute Percentage Error", { 2 | ex_dat <- generate_numeric_test_data() 3 | not_na <- !is.na(ex_dat$pred_na) 4 | 5 | expect_equal( 6 | smape(ex_dat, truth = "obs", estimate = "pred")[[".estimate"]], 7 | 100 * 8 | mean( 9 | abs( 10 | (ex_dat$obs - ex_dat$pred) / 11 | ((abs(ex_dat$obs) + abs(ex_dat$pred)) / 2) 12 | ) 13 | ) 14 | ) 15 | expect_equal( 16 | smape(ex_dat, obs, pred_na)[[".estimate"]], 17 | 100 * 18 | mean( 19 | abs( 20 | (ex_dat$obs[not_na] - ex_dat$pred[not_na]) / 21 | ((abs(ex_dat$obs[not_na]) + abs(ex_dat$pred[not_na])) / 2) 22 | ) 23 | ) 24 | ) 25 | }) 26 | 27 | test_that("Weighted results are working", { 28 | truth <- c(1, 2, 3) 29 | estimate <- c(2, 4, 3) 30 | weights <- c(1, 2, 1) 31 | 32 | expect_identical( 33 | smape_vec(truth, estimate, case_weights = weights), 34 | 50 35 | ) 36 | }) 37 | 38 | test_that("works with hardhat case weights", { 39 | solubility_test$weights <- floor(read_weights_solubility_test()) 40 | df <- solubility_test 41 | 42 | imp_wgt <- hardhat::importance_weights(df$weights) 43 | freq_wgt <- hardhat::frequency_weights(df$weights) 44 | 45 | expect_no_error( 46 | smape_vec(df$solubility, df$prediction, case_weights = imp_wgt) 47 | ) 48 | 49 | expect_no_error( 50 | smape_vec(df$solubility, df$prediction, case_weights = 
freq_wgt) 51 | ) 52 | }) 53 | -------------------------------------------------------------------------------- /tests/testthat/test-orderedprob-ranked_prob_score.R: -------------------------------------------------------------------------------- 1 | test_that("basic results", { 2 | hpc_cv$obs <- as.ordered(hpc_cv$obs) 3 | 4 | # With orf:::rps(as.matrix(hpc_cv[, 3:6]), hpc_cv$obs) 5 | hpc_exp <- 0.08566779 6 | 7 | expect_equal( 8 | yardstick:::ranked_prob_score_vec( 9 | hpc_cv$obs, 10 | as.matrix(hpc_cv |> dplyr::select(VF:L)) 11 | ), 12 | hpc_exp, 13 | tolerance = 0.01 14 | ) 15 | 16 | expect_equal( 17 | yardstick:::ranked_prob_score(hpc_cv, obs, VF:L), 18 | dplyr::tibble( 19 | .metric = "ranked_prob_score", 20 | .estimator = "multiclass", 21 | .estimate = hpc_exp 22 | ), 23 | tolerance = 0.01 24 | ) 25 | 26 | # ---------------------------------------------------------------------------- 27 | # with missing data 28 | hpc_miss <- hpc_cv 29 | hpc_miss$obs <- as.ordered(hpc_miss$obs) 30 | hpc_miss$obs[1] <- NA 31 | hpc_miss$L[2] <- NA 32 | 33 | cmlpt_ind <- complete.cases(hpc_miss) 34 | 35 | # With orf:::rps(as.matrix(hpc_cv[cmlpt_ind, 3:6]), hpc_cv$obs[cmlpt_ind]) 36 | hpc_miss_exp <- 0.08571614 37 | expect_equal( 38 | ranked_prob_score(hpc_miss, obs, VF:L)$.estimate, 39 | hpc_miss_exp, 40 | tolerance = 0.01 41 | ) 42 | 43 | expect_equal( 44 | ranked_prob_score(hpc_miss, obs, VF:L, na_rm = FALSE)$.estimate, 45 | NA_real_ 46 | ) 47 | }) 48 | 49 | test_that("works with hardhat case weights", { 50 | df <- two_class_example 51 | df$truth <- as.ordered(df$truth) 52 | 53 | imp_wgt <- hardhat::importance_weights(seq_len(nrow(df))) 54 | freq_wgt <- hardhat::frequency_weights(seq_len(nrow(df))) 55 | 56 | expect_no_error( 57 | ranked_prob_score_vec( 58 | df$truth, 59 | as.matrix(df[c("Class1", "Class2")]), 60 | case_weights = imp_wgt 61 | ) 62 | ) 63 | 64 | expect_no_error( 65 | ranked_prob_score_vec( 66 | df$truth, 67 | as.matrix(df[c("Class1", "Class2")]), 68 | case_weights = freq_wgt 69 | ) 70 | ) 71 | }) 72 | 73 | test_that("errors with bad input", { 74 | skip_if_not_installed("probably") 75 | 76 | cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1) 77 | fct_truth <- two_class_example$truth 78 | fct_truth[1] <- NA 79 | ord_truth <- as.ordered(two_class_example$truth) 80 | 81 | estimate_1D <- two_class_example$Class1 82 | estimate <- two_class_example[, 2:3] 83 | 84 | expect_snapshot( 85 | error = TRUE, 86 | ranked_prob_score_vec(cp_truth, estimate) 87 | ) 88 | expect_snapshot( 89 | error = TRUE, 90 | ranked_prob_score_vec(two_class_example$truth, estimate) 91 | ) 92 | expect_snapshot( 93 | error = TRUE, 94 | ranked_prob_score_vec(ord_truth, estimate_1D) 95 | ) 96 | }) 97 | -------------------------------------------------------------------------------- /tests/testthat/test-probably.R: -------------------------------------------------------------------------------- 1 | test_that("`class_pred` can be converted to `factor` when computing metrics", { 2 | skip_if_not_installed("probably") 3 | 4 | cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1) 5 | cp_estimate <- probably::as_class_pred(two_class_example$predicted, which = 2) 6 | 7 | fct_truth <- two_class_example$truth 8 | fct_truth[1] <- NA 9 | 10 | fct_estimate <- two_class_example$predicted 11 | fct_estimate[2] <- NA 12 | 13 | expect_identical( 14 | accuracy_vec(fct_truth, cp_estimate), 15 | accuracy_vec(fct_truth, fct_estimate) 16 | ) 17 | 18 | expect_identical( 19 | accuracy_vec(fct_truth, cp_estimate, 
na_rm = FALSE), 20 | NA_real_ 21 | ) 22 | 23 | expect_snapshot( 24 | error = TRUE, 25 | accuracy_vec(cp_truth, cp_estimate) 26 | ) 27 | }) 28 | 29 | test_that("`class_pred` errors when passed to `conf_mat()`", { 30 | skip_if_not_installed("probably") 31 | 32 | cp_hpc_cv <- hpc_cv 33 | cp_hpc_cv$obs <- probably::as_class_pred(cp_hpc_cv$obs, which = 1) 34 | cp_hpc_cv$pred <- probably::as_class_pred(cp_hpc_cv$pred, which = 2) 35 | 36 | expect_snapshot( 37 | error = TRUE, 38 | conf_mat(cp_hpc_cv, obs, pred) 39 | ) 40 | 41 | expect_snapshot( 42 | error = TRUE, 43 | conf_mat(dplyr::group_by(cp_hpc_cv, Resample), obs, pred) 44 | ) 45 | }) 46 | 47 | test_that("`class_pred` errors when passed to `metrics()`", { 48 | skip_if_not_installed("probably") 49 | 50 | cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1) 51 | cp_estimate <- probably::as_class_pred(two_class_example$predicted, which = 2) 52 | cp_df <- data.frame( 53 | truth = cp_truth, 54 | estimate = cp_estimate, 55 | class1 = two_class_example$Class1 56 | ) 57 | 58 | expect_snapshot( 59 | error = TRUE, 60 | metrics(cp_df, truth, estimate, class1) 61 | ) 62 | }) 63 | -------------------------------------------------------------------------------- /tests/testthat/test-surv-brier_survival.R: -------------------------------------------------------------------------------- 1 | test_that("brier_survival returns expected output columns", { 2 | skip_if_not_installed("tidyr") 3 | 4 | lung_surv <- data_lung_surv() 5 | 6 | brier_res <- brier_survival( 7 | data = lung_surv, 8 | truth = surv_obj, 9 | .pred 10 | ) 11 | 12 | expect_equal( 13 | names(brier_res), 14 | c(".metric", ".estimator", ".eval_time", ".estimate") 15 | ) 16 | }) 17 | 18 | test_that("case weights", { 19 | skip_if_not_installed("tidyr") 20 | 21 | lung_surv <- data_lung_surv() 22 | lung_surv$case_wts <- rep(2, nrow(lung_surv)) 23 | 24 | brier_res <- brier_survival( 25 | data = lung_surv, 26 | truth = surv_obj, 27 | .pred 28 | ) 29 | 30 | brier_res_case_wts <- brier_survival( 31 | data = lung_surv, 32 | truth = surv_obj, 33 | .pred, 34 | case_weights = case_wts 35 | ) 36 | 37 | expect_equal( 38 | brier_res$.estimate, 39 | brier_res_case_wts$.estimate 40 | ) 41 | }) 42 | 43 | test_that("works with hardhat case weights", { 44 | skip_if_not_installed("tidyr") 45 | 46 | lung_surv <- data_lung_surv() 47 | lung_surv$case_wts <- rep(2, nrow(lung_surv)) 48 | 49 | df <- lung_surv 50 | 51 | df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) 52 | df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) 53 | 54 | expect_no_error( 55 | brier_survival(df, truth = surv_obj, .pred, case_weights = imp_wgt) 56 | ) 57 | 58 | expect_no_error( 59 | brier_survival(df, truth = surv_obj, .pred, case_weights = freq_wgt) 60 | ) 61 | }) 62 | 63 | # riskRegression compare ------------------------------------------------------- 64 | 65 | test_that("riskRegression equivalent", { 66 | skip_if_not_installed("tidyr") 67 | 68 | riskRegression_res <- readRDS(test_path("data/brier_churn_res.rds")) 69 | 70 | yardstick_res <- readRDS(test_path("data/tidy_churn.rds")) |> 71 | brier_survival( 72 | truth = surv_obj, 73 | .pred 74 | ) 75 | 76 | expect_identical( 77 | riskRegression_res$times, 78 | yardstick_res$.eval_time 79 | ) 80 | 81 | expect_equal( 82 | riskRegression_res$Brier, 83 | yardstick_res$.estimate 84 | ) 85 | }) 86 | -------------------------------------------------------------------------------- /tests/testthat/test-surv-brier_survival_integrated.R: 
-------------------------------------------------------------------------------- 1 | test_that("brier_survival_integrated matches the integrated brier_survival curve", { 2 | skip_if_not_installed("tidyr") 3 | 4 | lung_surv <- data_lung_surv() 5 | 6 | brier_res <- brier_survival( 7 | data = lung_surv, 8 | truth = surv_obj, 9 | .pred 10 | ) |> 11 | dplyr::summarise( 12 | .estimate = yardstick:::auc(.eval_time, .estimate) / max(.eval_time) 13 | ) 14 | 15 | brier_integrated_res <- brier_survival_integrated( 16 | data = lung_surv, 17 | truth = surv_obj, 18 | .pred 19 | ) 20 | 21 | expect_equal( 22 | brier_res$.estimate, 23 | brier_integrated_res$.estimate 24 | ) 25 | }) 26 | 27 | test_that("brier_survival_integrated calculations", { 28 | skip_if_not_installed("tidyr") 29 | 30 | lung_surv <- data_lung_surv() 31 | 32 | lung_surv$.pred <- lapply(lung_surv$.pred, function(x) x[1, ]) 33 | 34 | expect_snapshot( 35 | error = TRUE, 36 | brier_survival_integrated( 37 | data = lung_surv, 38 | truth = surv_obj, 39 | .pred 40 | ) 41 | ) 42 | }) 43 | 44 | test_that("case weights", { 45 | skip_if_not_installed("tidyr") 46 | 47 | lung_surv <- data_lung_surv() 48 | lung_surv$case_wts <- seq_len(nrow(lung_surv)) 49 | 50 | brier_res <- brier_survival( 51 | data = lung_surv, 52 | truth = surv_obj, 53 | .pred, 54 | case_weights = case_wts 55 | ) |> 56 | dplyr::summarise( 57 | .estimate = yardstick:::auc(.eval_time, .estimate) / max(.eval_time) 58 | ) 59 | 60 | brier_integrated_res <- brier_survival_integrated( 61 | data = lung_surv, 62 | truth = surv_obj, 63 | .pred, 64 | case_weights = case_wts 65 | ) 66 | 67 | expect_equal( 68 | brier_res$.estimate, 69 | brier_integrated_res$.estimate 70 | ) 71 | }) 72 | 73 | test_that("works with hardhat case weights", { 74 | skip_if_not_installed("tidyr") 75 | 76 | lung_surv <- data_lung_surv() 77 | lung_surv$case_wts <- rep(2, nrow(lung_surv)) 78 | 79 | df <- lung_surv 80 | 81 | df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) 82 | df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) 83 | 84 | expect_no_error( 85 | brier_survival_integrated( 86 | df, 87 | truth = surv_obj, 88 | .pred, 89 | case_weights = imp_wgt 90 | ) 91 | ) 92 | 93 | expect_no_error( 94 | brier_survival_integrated( 95 | df, 96 | truth = surv_obj, 97 | .pred, 98 | case_weights = freq_wgt 99 | ) 100 | ) 101 | }) 102 | -------------------------------------------------------------------------------- /tests/testthat/test-surv-concordance_survival.R: -------------------------------------------------------------------------------- 1 | test_that("comparison test with survival", { 2 | res <- concordance_survival( 3 | data = lung_surv, 4 | truth = surv_obj, 5 | estimate = .pred_time 6 | ) 7 | 8 | expect_equal( 9 | res[[".estimate"]], 10 | survival::concordance(surv_obj ~ .pred_time, data = lung_surv)$concordance 11 | ) 12 | }) 13 | 14 | test_that("case weights works", { 15 | lung_surv$wts <- seq_len(nrow(lung_surv)) 16 | 17 | res <- concordance_survival( 18 | data = lung_surv, 19 | truth = surv_obj, 20 | estimate = .pred_time, 21 | case_weights = wts 22 | ) 23 | 24 | expect_equal( 25 | res[[".estimate"]], 26 | survival::concordance( 27 | surv_obj ~ .pred_time, 28 | weights = wts, 29 | data = lung_surv 30 | )$concordance 31 | ) 32 | }) 33 | 34 | test_that("works with infinite time predictions", { 35 | exp_res <- concordance_survival( 36 | data = lung_surv, 37 | truth = surv_obj, 38 | estimate = .pred_time 39 | ) 40 | 41 | lung_surv$.pred_time[which.max(lung_surv$.pred_time)] <- Inf 42 | 43 | expect_no_error( 44 | res <- concordance_survival( 45 | 
data = lung_surv, 46 | truth = surv_obj, 47 | estimate = .pred_time 48 | ) 49 | ) 50 | 51 | expect_identical(res, exp_res) 52 | 53 | exp_res <- concordance_survival( 54 | data = lung_surv, 55 | truth = surv_obj, 56 | estimate = .pred_time 57 | ) 58 | 59 | lung_surv$.pred_time[which.min(lung_surv$.pred_time)] <- Inf 60 | 61 | expect_no_error( 62 | res <- concordance_survival( 63 | data = lung_surv, 64 | truth = surv_obj, 65 | estimate = .pred_time 66 | ) 67 | ) 68 | 69 | expect_true(!identical(res, exp_res)) 70 | }) 71 | 72 | test_that("works with hardhat case weights", { 73 | lung_surv <- data_lung_surv() 74 | lung_surv$case_wts <- rep(2, nrow(lung_surv)) 75 | 76 | df <- lung_surv 77 | 78 | df$imp_wgt <- hardhat::importance_weights(lung_surv$case_wts) 79 | df$freq_wgt <- hardhat::frequency_weights(lung_surv$case_wts) 80 | 81 | expect_no_error( 82 | concordance_survival( 83 | df, 84 | truth = surv_obj, 85 | .pred_time, 86 | case_weights = imp_wgt 87 | ) 88 | ) 89 | 90 | expect_no_error( 91 | concordance_survival( 92 | df, 93 | truth = surv_obj, 94 | .pred_time, 95 | case_weights = freq_wgt 96 | ) 97 | ) 98 | }) 99 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /yardstick.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | --------------------------------------------------------------------------------
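A recurring pattern in the test files above is the "works with hardhat case weights" block: a plain numeric vector is wrapped in either hardhat::importance_weights() or hardhat::frequency_weights() and passed to a *_vec() metric via case_weights. A minimal self-contained sketch of that pattern, using rmse_vec() on toy data rather than the package's test fixtures:

library(yardstick)
library(hardhat)

truth <- c(1, 2, 3, 4)
estimate <- c(1.1, 1.9, 3.2, 4.4)
wts <- c(1, 2, 1, 2)

# importance weights accept arbitrary non-negative doubles
rmse_vec(truth, estimate, case_weights = importance_weights(wts))

# frequency weights accept non-negative whole numbers (observation counts)
rmse_vec(truth, estimate, case_weights = frequency_weights(as.integer(wts)))

For rmse_vec(), both wrapper classes reduce to the same weighted mean of squared errors; the class mainly signals to downstream tidymodels tooling how the weights should be treated.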