├── .Rbuildignore
├── .github
│   ├── .gitignore
│   ├── CODEOWNERS
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── workflows
│       ├── R-CMD-check-no-suggests.yaml
│       ├── R-CMD-check.yaml
│       ├── lock.yaml
│       ├── pkgdown.yaml
│       ├── pr-commands.yaml
│       └── test-coverage.yaml
├── .gitignore
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── DESCRIPTION
├── LICENSE
├── LICENSE.md
├── NAMESPACE
├── NEWS.md
├── R
│   ├── 0_imports.R
│   ├── as_workflow_set.R
│   ├── autoplot.R
│   ├── checks.R
│   ├── collect.R
│   ├── comments.R
│   ├── compat-dplyr.R
│   ├── compat-vctrs-helpers.R
│   ├── compat-vctrs.R
│   ├── data.R
│   ├── extract.R
│   ├── fit.R
│   ├── fit_best.R
│   ├── import-standalone-obj-type.R
│   ├── import-standalone-types-check.R
│   ├── leave_var_out_formulas.R
│   ├── misc.R
│   ├── options.R
│   ├── predict.R
│   ├── pull.R
│   ├── rank_results.R
│   ├── update.R
│   ├── workflow_map.R
│   ├── workflow_set.R
│   └── zzz.R
├── README.Rmd
├── README.md
├── _pkgdown.yml
├── air.toml
├── codecov.yml
├── cran-comments.md
├── data
│   ├── chi_features_set.rda
│   └── two_class_set.rda
├── inst
│   └── WORDLIST
├── man-roxygen
│   ├── chi_features_set.Rmd
│   ├── example_data.Rmd
│   └── two_class_set.Rmd
├── man
│   ├── as_workflow_set.Rd
│   ├── autoplot.workflow_set.Rd
│   ├── chi_features_set.Rd
│   ├── collect_metrics.workflow_set.Rd
│   ├── comment_add.Rd
│   ├── extract_workflow_set_result.Rd
│   ├── figures
│   │   ├── README-plot-1.png
│   │   ├── README-plot-1.svg
│   │   ├── README-plot-best-1.png
│   │   ├── README-plot-best-1.svg
│   │   └── lifecycle-soft-deprecated.svg
│   ├── fit_best.workflow_set.Rd
│   ├── leave_var_out_formulas.Rd
│   ├── option_add.Rd
│   ├── option_list.Rd
│   ├── pull_workflow_set_result.Rd
│   ├── rank_results.Rd
│   ├── reexports.Rd
│   ├── two_class_set.Rd
│   ├── update_workflow_model.Rd
│   ├── workflow_map.Rd
│   ├── workflow_set.Rd
│   └── workflowsets-package.Rd
├── tests
│   ├── spelling.R
│   ├── testthat.R
│   └── testthat
│       ├── _snaps
│       │   ├── autoplot.md
│       │   ├── checks.md
│       │   ├── collect-extracts.md
│       │   ├── collect-notes.md
│       │   ├── comments.md
│       │   ├── extract.md
│       │   ├── fit.md
│       │   ├── fit_best.md
│       │   ├── leave-var-out-formulas.md
│       │   ├── options.md
│       │   ├── predict.md
│       │   ├── pull.md
│       │   ├── workflow-map.md
│       │   └── workflow_set.md
│       ├── helper-compat.R
│       ├── helper-extract_parameter_set.R
│       ├── test-autoplot.R
│       ├── test-checks.R
│       ├── test-collect-extracts.R
│       ├── test-collect-metrics.R
│       ├── test-collect-notes.R
│       ├── test-collect-predictions.R
│       ├── test-comments.R
│       ├── test-compat-dplyr.R
│       ├── test-compat-vctrs.R
│       ├── test-extract.R
│       ├── test-fit.R
│       ├── test-fit_best.R
│       ├── test-leave-var-out-formulas.R
│       ├── test-options.R
│       ├── test-predict.R
│       ├── test-pull.R
│       ├── test-updates.R
│       ├── test-workflow-map.R
│       └── test-workflow_set.R
├── vignettes
│   ├── .gitignore
│   ├── articles
│   │   └── tuning-and-comparing-models.Rmd
│   └── evaluating-different-predictor-sets.Rmd
└── workflowsets.Rproj

/.Rbuildignore:
--------------------------------------------------------------------------------
1 | ^workflowsets\.Rproj$
2 | ^\.Rproj\.user$
3 | ^\.github$
4 | ^LICENSE\.md$
5 | ^README\.Rmd$
6 | ^codecov\.yml$
7 | ^_pkgdown\.yml$
8 | ^docs$
9 | ^pkgdown$
10 | ^CODE_OF_CONDUCT\.md$
11 | ^revdep$
12 | ^cran-comments\.md$
13 | ^man-roxygen$
14 | ^[\.]?air\.toml$
15 | ^\.vscode$
16 | ^CRAN-SUBMISSION$
17 | 
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 | 
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # CODEOWNERS for workflowsets
2 | # 
https://www.tidyverse.org/development/understudies 3 | .github/CODEOWNERS @topepo @juliasilge 4 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at codeofconduct@posit.co. 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 
67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | . 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder][https://github.com/mozilla/inclusion]. 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | . Translations are available at . 125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to tidymodels 2 | 3 | For more detailed information about contributing to tidymodels packages, see our [**development contributing guide**](https://www.tidymodels.org/contribute/). 4 | 5 | ## Documentation 6 | 7 | Typos or grammatical errors in documentation may be edited directly using the GitHub web interface, as long as the changes are made in the _source_ file. 8 | 9 | * YES ✅: you edit a roxygen comment in an `.R` file in the `R/` directory. 10 | * NO 🚫: you edit an `.Rd` file in the `man/` directory. 11 | 12 | We use [roxygen2](https://cran.r-project.org/package=roxygen2), with [Markdown syntax](https://cran.r-project.org/web/packages/roxygen2/vignettes/rd-formatting.html), for documentation. 
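For example, the help page `man/autoplot.workflow_set.Rd` is generated from the roxygen block at the top of `R/autoplot.R` in this repository, so that is the file to edit:

```r
#' Plot the results of a workflow set
#'
#' This `autoplot()` method plots performance metrics that have been ranked using
#' a metric. It can also run `autoplot()` on the individual results (per
#' `wflow_id`).
```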
13 | 14 | ## Code 15 | 16 | Before you submit 🎯 a pull request on a tidymodels package, always file an issue and confirm the tidymodels team agrees with your idea and is happy with your basic proposal. 17 | 18 | The [tidymodels packages](https://www.tidymodels.org/packages/) work together. Each package contains its own unit tests, while integration tests and other tests using all the packages are contained in [extratests](https://github.com/tidymodels/extratests). 19 | 20 | * We recommend that you create a Git branch for each pull request (PR). 21 | * Look at the build status before and after making changes. The `README` contains badges for any continuous integration services used by the package. 22 | * New code should follow the tidyverse [style guide](http://style.tidyverse.org). You can use the [styler](https://CRAN.R-project.org/package=styler) package to apply these styles, but please don't restyle code that has nothing to do with your PR. 23 | * For user-facing changes, add a bullet to the top of `NEWS.md` below the current development version header describing the changes made followed by your GitHub username, and links to relevant issue(s)/PR(s). 24 | * We use [testthat](https://cran.r-project.org/package=testthat). Contributions with test cases included are easier to accept. 25 | * If your contribution spans the use of more than one package, consider building [extratests](https://github.com/tidymodels/extratests) with your changes to check for breakages and/or adding new tests there. Let us know in your PR if you ran these extra tests. 26 | 27 | ### Code of Conduct 28 | 29 | This project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. 30 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check-no-suggests.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow only directly installs "hard" dependencies, i.e. Depends, 5 | # Imports, and LinkingTo dependencies. Notably, Suggests dependencies are never 6 | # installed, with the exception of testthat, knitr, and rmarkdown. The cache is 7 | # never used to avoid accidentally restoring a cache containing a suggested 8 | # dependency. 
9 | on: 10 | push: 11 | branches: [main, master] 12 | pull_request: 13 | branches: [main, master] 14 | 15 | name: R-CMD-check-no-suggests.yaml 16 | 17 | permissions: read-all 18 | 19 | jobs: 20 | check-no-suggests: 21 | runs-on: ${{ matrix.config.os }} 22 | 23 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 24 | 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | config: 29 | - {os: ubuntu-latest, r: 'release'} 30 | 31 | env: 32 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 33 | R_KEEP_PKG_SOURCE: yes 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | 38 | - uses: r-lib/actions/setup-pandoc@v2 39 | 40 | - uses: r-lib/actions/setup-r@v2 41 | with: 42 | r-version: ${{ matrix.config.r }} 43 | http-user-agent: ${{ matrix.config.http-user-agent }} 44 | use-public-rspm: true 45 | 46 | - uses: r-lib/actions/setup-r-dependencies@v2 47 | with: 48 | dependencies: '"hard"' 49 | cache: false 50 | extra-packages: | 51 | any::rcmdcheck 52 | any::testthat 53 | any::knitr 54 | any::rmarkdown 55 | needs: check 56 | 57 | - uses: r-lib/actions/check-r-package@v2 58 | with: 59 | upload-snapshots: true 60 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 61 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | 12 | name: R-CMD-check.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | R-CMD-check: 18 | runs-on: ${{ matrix.config.os }} 19 | 20 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | config: 26 | - {os: macos-latest, r: 'release'} 27 | 28 | - {os: windows-latest, r: 'release'} 29 | # use 4.0 or 4.1 to check with rtools40's older compiler 30 | - {os: windows-latest, r: 'oldrel-4'} 31 | 32 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 33 | - {os: ubuntu-latest, r: 'release'} 34 | - {os: ubuntu-latest, r: 'oldrel-1'} 35 | - {os: ubuntu-latest, r: 'oldrel-2'} 36 | - {os: ubuntu-latest, r: 'oldrel-3'} 37 | - {os: ubuntu-latest, r: 'oldrel-4'} 38 | 39 | env: 40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 41 | R_KEEP_PKG_SOURCE: yes 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - uses: r-lib/actions/setup-pandoc@v2 47 | 48 | - uses: r-lib/actions/setup-r@v2 49 | with: 50 | r-version: ${{ matrix.config.r }} 51 | http-user-agent: ${{ matrix.config.http-user-agent }} 52 | use-public-rspm: true 53 | 54 | - uses: r-lib/actions/setup-r-dependencies@v2 55 | with: 56 | extra-packages: any::rcmdcheck 57 | needs: check 58 | 59 | - uses: r-lib/actions/check-r-package@v2 60 | with: 61 | upload-snapshots: true 62 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 63 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: 
dessant/lock-threads@v2 12 | with: 13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | ) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | ) and link to this issue. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | release: 8 | types: [published] 9 | workflow_dispatch: 10 | 11 | name: pkgdown.yaml 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | pkgdown: 17 | runs-on: ubuntu-latest 18 | # Only restrict concurrency for non-PR jobs 19 | concurrency: 20 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 21 | env: 22 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 23 | permissions: 24 | contents: write 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - uses: r-lib/actions/setup-pandoc@v2 29 | 30 | - uses: r-lib/actions/setup-r@v2 31 | with: 32 | use-public-rspm: true 33 | 34 | - uses: r-lib/actions/setup-r-dependencies@v2 35 | with: 36 | extra-packages: any::pkgdown, local::. 37 | needs: website 38 | 39 | - name: Build site 40 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 41 | shell: Rscript {0} 42 | 43 | - name: Deploy to GitHub pages 🚀 44 | if: github.event_name != 'pull_request' 45 | uses: JamesIves/github-pages-deploy-action@v4.5.0 46 | with: 47 | clean: false 48 | branch: gh-pages 49 | folder: docs 50 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | run: styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | 8 | name: test-coverage.yaml 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | test-coverage: 14 | runs-on: ubuntu-latest 15 | env: 16 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | with: 23 | use-public-rspm: true 24 | 25 | - uses: r-lib/actions/setup-r-dependencies@v2 26 | with: 27 | extra-packages: any::covr, any::xml2 28 | needs: coverage 29 | 30 | - name: Test coverage 31 | run: | 32 | cov <- covr::package_coverage( 33 | quiet = FALSE, 34 | clean = FALSE, 35 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 36 | ) 37 | print(cov) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v5 42 | with: 43 | # Fail if error if not on PR, or if on PR and token is given 44 | fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} 45 | files: ./cobertura.xml 46 | plugins: noop 47 | disable_search: true 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | 50 | - name: Show testthat output 51 | if: always() 52 | run: | 53 | ## -------------------------------------------------------------------- 54 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 55 | shell: bash 56 | 57 | - name: Upload test results 58 | if: failure() 59 | uses: actions/upload-artifact@v4 60 | with: 61 | name: coverage-test-failures 62 | path: ${{ runner.temp }}/package 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .Rdata 4 | .httr-oauth 5 | .DS_Store 6 | docs 7 | inst/doc 8 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "Posit.air-vscode" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[r]": { 3 | "editor.formatOnSave": true, 4 | "editor.defaultFormatter": "Posit.air-vscode" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: workflowsets 2 | Title: Create a Collection of 'tidymodels' Workflows 3 | Version: 1.1.1.9000 4 | Authors@R: c( 5 | person("Hannah", "Frick", , "hannah@posit.co", role =c("aut", "cre"), 6 | comment = c(ORCID = "0000-0002-6049-5258")), 7 | person("Max", "Kuhn", , "max@posit.co", role = "aut", 8 | comment = c(ORCID = "0000-0003-2402-136X")), 9 | person("Simon", "Couch", , "simon.couch@posit.co", role = "aut", 10 | comment = c(ORCID = "0000-0001-5676-5107")), 11 | person("Posit Software, PBC", role = c("cph", "fnd"), 12 | comment = c(ROR = "03wc8by49")) 13 | ) 14 | Description: A workflow is a combination of a model and preprocessors 15 | (e.g, a formula, recipe, etc.) (Kuhn and Silge (2021) 16 | ). In order to try different combinations of 17 | these, an object can be created that contains many workflows. 
There 18 | are functions to create workflows en masse as well as training them 19 | and visualizing the results. 20 | License: MIT + file LICENSE 21 | URL: https://github.com/tidymodels/workflowsets, 22 | https://workflowsets.tidymodels.org 23 | BugReports: https://github.com/tidymodels/workflowsets/issues 24 | Depends: 25 | R (>= 4.1) 26 | Imports: 27 | cli, 28 | dplyr (>= 1.0.0), 29 | generics (>= 0.1.2), 30 | ggplot2, 31 | hardhat (>= 1.2.0), 32 | lifecycle (>= 1.0.0), 33 | parsnip (>= 1.2.1), 34 | pillar (>= 1.7.0), 35 | prettyunits, 36 | purrr, 37 | rlang (>= 1.1.0), 38 | rsample (>= 0.0.9), 39 | stats, 40 | tibble (>= 3.1.0), 41 | tidyr, 42 | tune (>= 1.2.0), 43 | vctrs, 44 | withr, 45 | workflows (>= 1.1.4) 46 | Suggests: 47 | covr, 48 | dials (>= 0.1.0), 49 | finetune, 50 | kknn, 51 | knitr, 52 | modeldata, 53 | recipes (>= 1.1.0), 54 | rmarkdown, 55 | spelling, 56 | testthat (>= 3.0.0), 57 | tidyclust, 58 | yardstick (>= 1.3.0) 59 | VignetteBuilder: 60 | knitr 61 | Config/Needs/website: discrim, rpart, mda, klaR, earth, tidymodels, 62 | tidyverse/tidytemplate 63 | Config/testthat/edition: 3 64 | Config/usethis/last-upkeep: 2025-04-25 65 | Encoding: UTF-8 66 | Language: en-US 67 | LazyData: true 68 | Roxygen: list(markdown = TRUE) 69 | RoxygenNote: 7.3.2 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2025 2 | COPYRIGHT HOLDER: workflowsets authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 workflowsets authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method("[",workflow_set) 4 | S3method("names<-",workflow_set) 5 | S3method(autoplot,workflow_set) 6 | S3method(collect_extracts,workflow_set) 7 | S3method(collect_metrics,workflow_set) 8 | S3method(collect_notes,workflow_set) 9 | S3method(collect_predictions,workflow_set) 10 | S3method(dplyr_reconstruct,workflow_set) 11 | S3method(extract_fit_engine,workflow_set) 12 | S3method(extract_fit_parsnip,workflow_set) 13 | S3method(extract_mold,workflow_set) 14 | S3method(extract_parameter_dials,workflow_set) 15 | S3method(extract_parameter_set_dials,workflow_set) 16 | S3method(extract_preprocessor,workflow_set) 17 | S3method(extract_recipe,workflow_set) 18 | S3method(extract_spec_parsnip,workflow_set) 19 | S3method(extract_workflow,workflow_set) 20 | S3method(fit,workflow_set) 21 | S3method(fit_best,workflow_set) 22 | S3method(obj_sum,workflow_set_options) 23 | S3method(predict,workflow_set) 24 | S3method(print,workflow_set_options) 25 | S3method(size_sum,workflow_set_options) 26 | S3method(tbl_sum,workflow_set) 27 | S3method(type_sum,workflow_set_options) 28 | S3method(vec_cast,data.frame.workflow_set) 29 | S3method(vec_cast,tbl_df.workflow_set) 30 | S3method(vec_cast,workflow_set.data.frame) 31 | S3method(vec_cast,workflow_set.tbl_df) 32 | S3method(vec_cast,workflow_set.workflow_set) 33 | S3method(vec_ptype2,data.frame.workflow_set) 34 | S3method(vec_ptype2,tbl_df.workflow_set) 35 | S3method(vec_ptype2,workflow_set.data.frame) 36 | S3method(vec_ptype2,workflow_set.tbl_df) 37 | S3method(vec_ptype2,workflow_set.workflow_set) 38 | S3method(vec_restore,workflow_set) 39 | export("%>%") 40 | export(as_workflow_set) 41 | export(autoplot) 42 | export(collect_extracts) 43 | export(collect_metrics) 44 | export(collect_notes) 45 | export(collect_predictions) 46 | export(comment_add) 47 | export(comment_get) 48 | export(comment_print) 49 | export(comment_reset) 50 | export(extract_fit_engine) 51 | export(extract_fit_parsnip) 52 | export(extract_mold) 53 | export(extract_parameter_dials) 54 | export(extract_parameter_set_dials) 55 | export(extract_preprocessor) 56 | export(extract_recipe) 57 | export(extract_spec_parsnip) 58 | export(extract_workflow) 59 | export(extract_workflow_set_result) 60 | export(fit_best) 61 | export(leave_var_out_formulas) 62 | export(option_add) 63 | export(option_add_parameters) 64 | export(option_list) 65 | export(option_remove) 66 | export(pull_workflow) 67 | export(pull_workflow_set_result) 68 | export(rank_results) 69 | export(update_workflow_model) 70 | export(update_workflow_recipe) 71 | export(workflow_map) 72 | export(workflow_set) 73 | import(ggplot2) 74 | import(rlang) 75 | import(vctrs) 76 | importFrom(dplyr,"%>%") 77 | importFrom(dplyr,dplyr_reconstruct) 78 | importFrom(generics,fit) 79 | importFrom(ggplot2,autoplot) 80 | importFrom(hardhat,extract_fit_engine) 81 | importFrom(hardhat,extract_fit_parsnip) 82 | importFrom(hardhat,extract_mold) 83 | importFrom(hardhat,extract_parameter_dials) 84 | importFrom(hardhat,extract_parameter_set_dials) 85 | importFrom(hardhat,extract_preprocessor) 86 | importFrom(hardhat,extract_recipe) 87 | importFrom(hardhat,extract_spec_parsnip) 88 | importFrom(hardhat,extract_workflow) 89 | importFrom(lifecycle,deprecated) 90 | importFrom(pillar,obj_sum) 91 | importFrom(pillar,size_sum) 92 | 
importFrom(pillar,tbl_sum) 93 | importFrom(pillar,type_sum) 94 | importFrom(stats,as.formula) 95 | importFrom(stats,model.frame) 96 | importFrom(stats,predict) 97 | importFrom(stats,qnorm) 98 | importFrom(tune,collect_extracts) 99 | importFrom(tune,collect_metrics) 100 | importFrom(tune,collect_notes) 101 | importFrom(tune,collect_predictions) 102 | importFrom(tune,fit_best) 103 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # workflowsets (development version) 2 | 3 | # workflowsets 1.1.1 4 | 5 | * Added a `collect_extracts()` method for workflow sets (@jrosell, #156). 6 | 7 | * The deprecation of the `pull_*()` functions has been moved forward. These functions now error. Please use the `extract_*()` functions instead (#178). 8 | 9 | * Increased the minimum required R version to R 4.1. 10 | 11 | # workflowsets 1.1.0 12 | 13 | * Ellipses (...) are now used consistently in the package to require optional arguments to be named; `collect_metrics()` and `collect_predictions()` are the only functions that received changes (#151, tidymodels/tune#863). 14 | * Enabled evaluating censored regression models (#139, #144). The placement of 15 | the new `eval_time` argument to `rank_results()` breaks passing-by-position 16 | for the `select_best` argument. 17 | * Added a `collect_notes()` method for workflow sets (#135). 18 | * Added methods to improve error messages when workflow sets are mistakenly 19 | passed to unsupported functions like `fit()` and `predict()` (#137). 20 | * Added a new argument, `type`, to the `workflow_set` `autoplot()` method. The 21 | default, `"class"`, retains the existing behavior of mapping model type to 22 | color and preprocessor type to shape, while the new `"wflow_id"` 23 | type maps the workflow IDs to color (#134). 24 | * Added type checking for inputted arguments (#136, #131). 25 | 26 | # workflowsets 1.0.1 27 | 28 | * The `extract_parameter_dials()` and `extract_parameter_set_dials()` extractors 29 | will now return the parameter or parameter set 30 | _that will be used by the tuning function utilized in `workflow_map()`_. 31 | The extractors previously always returned the parameter or parameter set 32 | associated with the workflow contained in the `info` column, which can be 33 | overridden by passing a `param_info` argument to `option_add()`. The 34 | extractors will now first look to the added options before extracting from 35 | workflows (#106). 36 | * Introduces support for clustering model specifications via the tidyclust 37 | package. Supplying clustering models to `workflow_set()` and set 38 | `fn = "tune_cluster"` in `workflow_map()` to use this feature (#125)! 39 | * Introduces a `fit_best()` method for workflowsets that takes in a workflow set 40 | evaluated with `workflow_map()` and returns a workflow fitted with the model 41 | configuration associated with the best performance (#126). 42 | * Transitions deprecations of `pull_*()` functions to now warn on every usage 43 | (#123). 44 | * Various bug fixes and improvements to documentation. 45 | 46 | # workflowsets 1.0.0 47 | 48 | * New `extract_parameter_set_dials()` and `extract_parameter_dials()` methods to 49 | extract parameter sets and single parameters from `workflow_set` objects. 50 | 51 | * Added support for case weights via a new `case_weights` argument 52 | to `workflow_set()` (#82). 
53 | 54 | # workflowsets 0.2.1 55 | 56 | * `update_workflow_model()` and `update_workflow_recipe()` were added. These are analogous to `workflows::add_model()` or `workflows::add_recipe()` (#64). 57 | 58 | * Updated tests related to changes in workflows 0.2.5 (#75). 59 | 60 | * `as_workflow_set()` can now take a mixture of workflows or `tune_results` objects. 61 | 62 | * `option_add()` now checks the names of the options to see if they are valid names for the functions that receive them (#66) 63 | 64 | # workflowsets 0.1.0 65 | 66 | * Fixed an `autoplot()` bug where, if one metric is selected but a ranking metric is not specified, the wrong metric is used to order the workflows (#52) 67 | 68 | * Updated pillar formatting for options objects. 69 | 70 | * New `extract_*()` functions have been added that supersede the existing `pull_*()` functions. This is part of a larger move across the tidymodels packages towards a family of generic `extract_*()` functions. The `pull_*()` functions have been soft-deprecated, and will eventually be removed 71 | 72 | # workflowsets 0.0.2 73 | 74 | * Ensured that `workflow_map()` does not fail if there are missing packages or if the function being mapped fails. 75 | 76 | # workflowsets 0.0.1 77 | 78 | * First CRAN version 79 | -------------------------------------------------------------------------------- /R/0_imports.R: -------------------------------------------------------------------------------- 1 | #' @keywords internal 2 | "_PACKAGE" 3 | 4 | ## usethis namespace: start 5 | ## usethis namespace: end 6 | NULL 7 | 8 | #' @import ggplot2 9 | #' @import vctrs 10 | #' @import rlang 11 | #' @importFrom stats qnorm as.formula model.frame 12 | #' @importFrom pillar obj_sum type_sum tbl_sum size_sum 13 | #' @importFrom dplyr dplyr_reconstruct 14 | #' @importFrom lifecycle deprecated 15 | 16 | utils::globalVariables( 17 | c( 18 | ".config", 19 | ".estimator", 20 | ".metric", 21 | "info", 22 | "metric", 23 | "mod_nm", 24 | "model", 25 | "n", 26 | "pp_nm", 27 | "preprocessor", 28 | "preproc", 29 | "object", 30 | "engine", 31 | "result", 32 | "std_err", 33 | "wflow_id", 34 | "func", 35 | "is_race", 36 | "num_rs", 37 | "option", 38 | "metrics", 39 | "predictions", 40 | "hash", 41 | "id", 42 | "workflow", 43 | "comment", 44 | "get_from_env", 45 | ".get_tune_metric_names", 46 | "select_best", 47 | "notes", 48 | "extracts" 49 | ) 50 | ) 51 | 52 | # ------------------------------------------------------------------------------ 53 | 54 | #' @importFrom tune collect_metrics 55 | #' @export 56 | tune::collect_metrics 57 | 58 | #' @importFrom tune collect_predictions 59 | #' @export 60 | tune::collect_predictions 61 | 62 | #' @importFrom tune collect_notes 63 | #' @export 64 | tune::collect_notes 65 | 66 | #' @importFrom tune collect_extracts 67 | #' @export 68 | tune::collect_extracts 69 | 70 | #' @importFrom dplyr %>% 71 | #' @export 72 | dplyr::`%>%` 73 | 74 | #' @importFrom ggplot2 autoplot 75 | #' @export 76 | ggplot2::autoplot 77 | 78 | #' @importFrom hardhat extract_spec_parsnip 79 | #' @export 80 | hardhat::extract_spec_parsnip 81 | #' 82 | #' @importFrom hardhat extract_recipe 83 | #' @export 84 | hardhat::extract_recipe 85 | #' 86 | #' @importFrom hardhat extract_fit_parsnip 87 | #' @export 88 | hardhat::extract_fit_parsnip 89 | #' 90 | #' @importFrom hardhat extract_fit_engine 91 | #' @export 92 | hardhat::extract_fit_engine 93 | #' 94 | #' @importFrom hardhat extract_mold 95 | #' @export 96 | hardhat::extract_mold 97 | #' 98 | #' @importFrom hardhat 
extract_preprocessor 99 | #' @export 100 | hardhat::extract_preprocessor 101 | #' 102 | #' @importFrom hardhat extract_workflow 103 | #' @export 104 | hardhat::extract_workflow 105 | #' 106 | #' @importFrom hardhat extract_parameter_set_dials 107 | #' @export 108 | hardhat::extract_parameter_set_dials 109 | #' 110 | #' @importFrom hardhat extract_parameter_dials 111 | #' @export 112 | hardhat::extract_parameter_dials 113 | -------------------------------------------------------------------------------- /R/as_workflow_set.R: -------------------------------------------------------------------------------- 1 | #' Convert existing objects to a workflow set 2 | #' 3 | #' Use existing objects to create a workflow set. A list of objects that are 4 | #' either simple workflows or objects that have class `"tune_results"` can be 5 | #' converted into a workflow set. 6 | #' @param ... One or more named objects. Names should be unique and the 7 | #' objects should have at least one of the following classes: `workflow`, 8 | #' `iteration_results`, `tune_results`, `resample_results`, or `tune_race`. Each 9 | #' `tune_results` element should also contain the original workflow 10 | #' (accomplished using the `save_workflow` option in the control function). 11 | #' @return A workflow set. Note that the `option` column will not reflect the 12 | #' options that were used to create each object. 13 | #' 14 | #' @includeRmd man-roxygen/example_data.Rmd note 15 | #' 16 | #' @examples 17 | #' 18 | #' # ------------------------------------------------------------------------------ 19 | #' # Existing results 20 | #' 21 | #' # Use the already worked example to show how to add tuned 22 | #' # objects to a workflow set 23 | #' two_class_res 24 | #' 25 | #' results <- two_class_res |> purrr::pluck("result") 26 | #' names(results) <- two_class_res$wflow_id 27 | #' 28 | #' # These are all objects that have been resampled or tuned: 29 | #' purrr::map_chr(results, \(x) class(x)[1]) 30 | #' 31 | #' # Use rlang's !!! operator to splice in the elements of the list 32 | #' new_set <- as_workflow_set(!!!results) 33 | #' 34 | #' # ------------------------------------------------------------------------------ 35 | #' # Make a set from unfit workflows 36 | #' 37 | #' library(parsnip) 38 | #' library(workflows) 39 | #' 40 | #' lr_spec <- logistic_reg() 41 | #' 42 | #' main_effects <- 43 | #' workflow() |> 44 | #' add_model(lr_spec) |> 45 | #' add_formula(Class ~ .) 46 | #' 47 | #' interactions <- 48 | #' workflow() |> 49 | #' add_model(lr_spec) |> 50 | #' add_formula(Class ~ (.)^2) 51 | #' 52 | #' as_workflow_set(main = main_effects, int = interactions) 53 | #' @export 54 | as_workflow_set <- function(...) { 55 | object <- rlang::list2(...) 
56 | 57 | # These could be workflows or objects of class `tune_result` 58 | is_workflow <- purrr::map_lgl(object, \(x) inherits(x, "workflow")) 59 | wflows <- vector("list", length(is_workflow)) 60 | wflows[is_workflow] <- object[is_workflow] 61 | wflows[!is_workflow] <- purrr::map( 62 | object[!is_workflow], 63 | tune::.get_tune_workflow 64 | ) 65 | names(wflows) <- names(object) 66 | 67 | check_names(wflows) 68 | check_for_workflow(wflows) 69 | 70 | res <- tibble::tibble(wflow_id = names(wflows)) 71 | res <- 72 | res |> 73 | dplyr::mutate( 74 | workflow = unname(wflows), 75 | info = purrr::map(workflow, get_info), 76 | option = purrr::map(1:nrow(res), \(i) new_workflow_set_options()) 77 | ) 78 | res$result <- vector(mode = "list", length = nrow(res)) 79 | res$result[!is_workflow] <- object[!is_workflow] 80 | 81 | res |> 82 | dplyr::select(wflow_id, info, option, result) |> 83 | new_workflow_set() 84 | } 85 | -------------------------------------------------------------------------------- /R/autoplot.R: -------------------------------------------------------------------------------- 1 | #' Plot the results of a workflow set 2 | #' 3 | #' This `autoplot()` method plots performance metrics that have been ranked using 4 | #' a metric. It can also run `autoplot()` on the individual results (per 5 | #' `wflow_id`). 6 | #' 7 | #' @param object A `workflow_set` whose elements have results. 8 | #' @param rank_metric A character string for which metric should be used to rank 9 | #' the results. If none is given, the first metric in the metric set is used 10 | #' (after filtering by the `metric` option). 11 | #' @param id A character string for what to plot. If a value of 12 | #' `"workflow_set"` is used, the results of each model (and sub-model) are ordered 13 | #' and plotted. Alternatively, a value of the workflow set's `wflow_id` can be 14 | #' given and the `autoplot()` method is executed on that workflow's results. 15 | #' @param select_best A logical; should the results only contain the numerically 16 | #' best submodel per workflow? 17 | #' @param metric A character vector for which metrics (apart from `rank_metric`) 18 | #' to be included in the visualization. 19 | #' @param std_errs The number of standard errors to plot (if the standard error 20 | #' exists). 21 | #' @param type The aesthetics with which to differentiate workflows. The 22 | #' default `"class"` maps color to the model type and shape to the preprocessor 23 | #' type. The `"workflow"` option maps a color to each `"wflow_id"`. This 24 | #' argument is ignored for values of `id` other than `"workflow_set"`. 25 | #' @param ... Other options to pass to `autoplot()`. 26 | #' @details 27 | #' This function is intended to produce a default plot to visualize helpful 28 | #' information across all possible applications of a workflow set. A more 29 | #' appropriate plot for your specific analysis can be created by 30 | #' calling [rank_results()] and using standard `ggplot2` code for plotting. 31 | #' 32 | #' The x-axis is the workflow rank in the set (a value of one being the best) 33 | #' versus the performance metric(s) on the y-axis. With multiple metrics, there 34 | #' will be facets for each metric. 35 | #' 36 | #' If multiple resamples are used, confidence bounds are shown for each result 37 | #' (90% confidence, by default). 38 | #' @return A ggplot object. 
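#'
#' As a sketch of the [rank_results()] plus ggplot2 approach mentioned in the
#' Details (using the `two_class_res` example object shipped with the package),
#' one could do:
#'
#' ```r
#' rank_results(two_class_res, rank_metric = "roc_auc") |>
#'   dplyr::filter(.metric == "roc_auc") |>
#'   ggplot2::ggplot(ggplot2::aes(x = rank, y = mean, color = model)) +
#'   ggplot2::geom_point()
#' ```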
39 | #' 40 | #' @includeRmd man-roxygen/example_data.Rmd note 41 | #' 42 | #' @examples 43 | #' autoplot(two_class_res) 44 | #' autoplot(two_class_res, select_best = TRUE) 45 | #' autoplot(two_class_res, id = "yj_trans_cart", metric = "roc_auc") 46 | #' @export 47 | autoplot.workflow_set <- function( 48 | object, 49 | rank_metric = NULL, 50 | metric = NULL, 51 | id = "workflow_set", 52 | select_best = FALSE, 53 | std_errs = qnorm(0.95), 54 | type = "class", 55 | ... 56 | ) { 57 | rlang::arg_match(type, c("class", "wflow_id")) 58 | check_string(rank_metric, allow_null = TRUE) 59 | check_character(metric, allow_null = TRUE) 60 | check_number_decimal(std_errs) 61 | check_bool(select_best) 62 | 63 | if (id == "workflow_set") { 64 | p <- rank_plot( 65 | object, 66 | rank_metric = rank_metric, 67 | metric = metric, 68 | select_best = select_best, 69 | std_errs = std_errs, 70 | type = type 71 | ) 72 | } else { 73 | p <- autoplot( 74 | object$result[[which(object$wflow_id == id)]], 75 | metric = metric, 76 | ... 77 | ) 78 | } 79 | p 80 | } 81 | 82 | rank_plot <- function( 83 | object, 84 | rank_metric = NULL, 85 | metric = NULL, 86 | select_best = FALSE, 87 | std_errs = 1, 88 | type = "class" 89 | ) { 90 | metric_info <- pick_metric(object, rank_metric, metric) 91 | metrics <- collate_metrics(object) 92 | res <- rank_results( 93 | object, 94 | rank_metric = metric_info$metric, 95 | select_best = select_best 96 | ) 97 | 98 | if (!is.null(metric)) { 99 | keep_metrics <- unique(c(rank_metric, metric)) 100 | res <- dplyr::filter(res, .metric %in% keep_metrics) 101 | } 102 | 103 | num_metrics <- length(unique(res$.metric)) 104 | has_std_error <- !all(is.na(res$std_err)) 105 | 106 | p <- 107 | switch( 108 | type, 109 | class = ggplot(res, aes(x = rank, y = mean, col = model)) + 110 | geom_point(aes(shape = preprocessor)), 111 | wflow_id = ggplot(res, aes(x = rank, y = mean, col = wflow_id)) + 112 | geom_point() 113 | ) 114 | 115 | if (num_metrics > 1) { 116 | res$.metric <- factor(as.character(res$.metric), levels = metrics$metric) 117 | p <- 118 | p + 119 | facet_wrap(~.metric, scales = "free_y", as.table = FALSE) + 120 | labs(x = "Workflow Rank", y = "Metric") 121 | } else { 122 | p <- p + labs(x = "Workflow Rank", y = metric_info$metric) 123 | } 124 | 125 | if (has_std_error) { 126 | p <- 127 | p + 128 | geom_errorbar( 129 | aes( 130 | ymin = mean - std_errs * std_err, 131 | ymax = mean + std_errs * std_err 132 | ), 133 | width = diff(range(res$rank)) / 75 134 | ) 135 | } 136 | 137 | p 138 | } 139 | -------------------------------------------------------------------------------- /R/collect.R: -------------------------------------------------------------------------------- 1 | #' Obtain and format results produced by tuning functions for workflow sets 2 | #' 3 | #' Return a tibble of performance metrics for all models or submodels. 4 | #' 5 | #' @param x A [`workflow_set`][workflow_set()] object that has been evaluated 6 | #' with [workflow_map()]. 7 | #' @param ... Not currently used. 8 | #' @param summarize A logical for whether the performance estimates should be 9 | #' summarized via the mean (over resamples) or the raw performance values (per 10 | #' resample) should be returned along with the resampling identifiers. When 11 | #' collecting predictions, these are averaged if multiple assessment sets 12 | #' contain the same row. 13 | #' @param parameters An optional tibble of tuning parameter values that can be 14 | #' used to filter the predicted values before processing. 
This tibble should 15 | #' only have columns for each tuning parameter identifier (e.g. `"my_param"` 16 | #' if `tune("my_param")` was used). 17 | #' @param select_best A single logical for whether the numerically best results 18 | #' are retained. If `TRUE`, the `parameters` argument is ignored. 19 | #' @param metric A character string for the metric that is used for 20 | #' `select_best`. 21 | #' @return A tibble. 22 | #' @details 23 | #' 24 | #' When applied to a workflow set, the metrics and predictions that are returned 25 | #' do not contain the actual tuning parameter columns and values (unlike when 26 | #' these collect functions are run on other objects). The reason is that workflow 27 | #' sets can contain different types of models or models with different tuning 28 | #' parameters. 29 | #' 30 | #' If the columns are needed, there are two options. First, the `.config` column 31 | #' can be used to merge the tuning parameter columns into an appropriate object. 32 | #' Alternatively, the `map()` function can be used to get the metrics from the 33 | #' original objects (see the example below). 34 | #' 35 | #' @seealso [tune::collect_metrics()], [rank_results()] 36 | #' 37 | #' @includeRmd man-roxygen/example_data.Rmd note 38 | #' 39 | #' @examples 40 | #' library(dplyr) 41 | #' library(purrr) 42 | #' library(tidyr) 43 | #' 44 | #' two_class_res 45 | #' 46 | #' # ------------------------------------------------------------------------------ 47 | #' \donttest{ 48 | #' collect_metrics(two_class_res) 49 | #' 50 | #' # Alternatively, if the tuning parameter values are needed: 51 | #' two_class_res |> 52 | #' dplyr::filter(grepl("cart", wflow_id)) |> 53 | #' mutate(metrics = map(result, collect_metrics)) |> 54 | #' dplyr::select(wflow_id, metrics) |> 55 | #' tidyr::unnest(cols = metrics) 56 | #' } 57 | #' 58 | #' collect_metrics(two_class_res, summarize = FALSE) 59 | #' @export 60 | collect_metrics.workflow_set <- function(x, ..., summarize = TRUE) { 61 | rlang::check_dots_empty() 62 | check_incompete(x, fail = TRUE) 63 | check_bool(summarize) 64 | x <- 65 | dplyr::mutate( 66 | x, 67 | metrics = purrr::map( 68 | result, 69 | collect_metrics, 70 | summarize = summarize 71 | ), 72 | metrics = purrr::map2(metrics, result, remove_parameters) 73 | ) 74 | info <- dplyr::bind_rows(x$info) |> dplyr::select(-workflow, -comment) 75 | x <- 76 | dplyr::select(x, wflow_id, metrics) |> 77 | dplyr::bind_cols(info) |> 78 | tidyr::unnest(cols = c(metrics)) |> 79 | reorder_cols() 80 | check_consistent_metrics(x, fail = FALSE) 81 | x 82 | } 83 | 84 | remove_parameters <- function(x, object) { 85 | prm <- tune::.get_tune_parameter_names(object) 86 | x <- dplyr::select(x, -dplyr::one_of(prm)) 87 | x 88 | } 89 | 90 | reorder_cols <- function(x) { 91 | if (any(names(x) == ".iter")) { 92 | cols <- c("wflow_id", ".config", ".iter", "preproc", "model") 93 | } else { 94 | cols <- c("wflow_id", ".config", "preproc", "model") 95 | } 96 | dplyr::relocate(x, !!!cols) 97 | } 98 | 99 | #' @export 100 | #' @rdname collect_metrics.workflow_set 101 | collect_predictions.workflow_set <- 102 | function( 103 | x, 104 | ..., 105 | summarize = TRUE, 106 | parameters = NULL, 107 | select_best = FALSE, 108 | metric = NULL 109 | ) { 110 | rlang::check_dots_empty() 111 | check_incompete(x, fail = TRUE) 112 | check_bool(summarize) 113 | check_bool(select_best) 114 | check_string(metric, allow_null = TRUE) 115 | if (select_best) { 116 | x <- 117 | dplyr::mutate( 118 | x, 119 | predictions = purrr::map( 120 | result, 121 | \(.x) 122 | 
select_bare_predictions( 123 | .x, 124 | summarize = summarize, 125 | metric = metric 126 | ) 127 | ) 128 | ) 129 | } else { 130 | x <- 131 | dplyr::mutate( 132 | x, 133 | predictions = purrr::map( 134 | result, 135 | get_bare_predictions, 136 | summarize = summarize, 137 | parameters = parameters 138 | ) 139 | ) 140 | } 141 | info <- dplyr::bind_rows(x$info) |> dplyr::select(-workflow, -comment) 142 | x <- 143 | dplyr::select(x, wflow_id, predictions) |> 144 | dplyr::bind_cols(info) |> 145 | tidyr::unnest(cols = c(predictions)) |> 146 | reorder_cols() 147 | x 148 | } 149 | 150 | select_bare_predictions <- function(x, metric, summarize) { 151 | res <- 152 | tune::collect_predictions( 153 | x, 154 | summarize = summarize, 155 | parameters = tune::select_best(x, metric = metric) 156 | ) 157 | remove_parameters(res, x) 158 | } 159 | 160 | get_bare_predictions <- function(x, ...) { 161 | res <- tune::collect_predictions(x, ...) 162 | remove_parameters(res, x) 163 | } 164 | 165 | #' @export 166 | #' @rdname collect_metrics.workflow_set 167 | collect_notes.workflow_set <- function(x, ...) { 168 | check_incompete(x) 169 | 170 | res <- dplyr::rowwise(x) 171 | res <- dplyr::mutate(res, notes = list(collect_notes(result))) 172 | res <- dplyr::ungroup(res) 173 | res <- dplyr::select(res, wflow_id, notes) 174 | res <- tidyr::unnest(res, cols = notes) 175 | 176 | res 177 | } 178 | 179 | #' 180 | #' @export 181 | #' @rdname collect_metrics.workflow_set 182 | collect_extracts.workflow_set <- function(x, ...) { 183 | check_incompete(x) 184 | 185 | res <- dplyr::rowwise(x) 186 | res <- dplyr::mutate(res, extracts = list(collect_extracts(result))) 187 | res <- dplyr::ungroup(res) 188 | res <- dplyr::select(res, wflow_id, extracts) 189 | res <- tidyr::unnest(res, cols = extracts) 190 | 191 | res 192 | } 193 | -------------------------------------------------------------------------------- /R/comments.R: -------------------------------------------------------------------------------- 1 | #' Add annotations and comments for workflows 2 | #' 3 | #' `comment_add()` can be used to log important information about the workflow or 4 | #' its results as you work. Comments can be appended or removed. 5 | #' @param x A workflow set outputted by [workflow_set()] or [workflow_map()]. 6 | #' @param id A single character string for a value in the `wflow_id` column. For 7 | #' `comment_print()`, `id` can be a vector or `NULL` (and this indicates that 8 | #' all comments are printed). 9 | #' @param ... One or more character strings. 10 | #' @param append A logical value to determine if the new comment should be added 11 | #' to the existing values. 12 | #' @param collapse A character string that separates the comments. 13 | #' @return `comment_add()` and `comment_reset()` return an updated workflow set. 14 | #' `comment_get()` returns a character string. `comment_print()` returns `NULL` 15 | #' invisibly. 
16 | #' @export 17 | #' @examples 18 | #' two_class_set 19 | #' 20 | #' two_class_set |> comment_get("none_cart") 21 | #' 22 | #' new_set <- 23 | #' two_class_set |> 24 | #' comment_add("none_cart", "What does 'cart' stand for\u2753") |> 25 | #' comment_add("none_cart", "Classification And Regression Trees.") 26 | #' 27 | #' comment_print(new_set) 28 | #' 29 | #' new_set |> comment_get("none_cart") 30 | #' 31 | #' new_set |> 32 | #' comment_reset("none_cart") |> 33 | #' comment_get("none_cart") 34 | comment_add <- function(x, id, ..., append = TRUE, collapse = "\n") { 35 | check_wf_set(x) 36 | check_bool(append) 37 | check_string(collapse) 38 | dots <- list(...) 39 | if (length(dots) == 0) { 40 | return(x) 41 | } else { 42 | is_chr <- purrr::map_lgl(dots, is.character) 43 | if (any(!is_chr)) { 44 | cli::cli_abort("The comments should be character strings.") 45 | } 46 | } 47 | 48 | check_string(id) 49 | has_id <- id == x$wflow_id 50 | if (!any(has_id)) { 51 | cli::cli_abort("The {.arg id} value is not in {.arg wflow_id}.") 52 | } 53 | id_index <- which(has_id) 54 | current_val <- x$info[[id_index]]$comment 55 | if (!is.na(current_val) && !append) { 56 | cli::cli_abort( 57 | "There is already a comment for this id and {.code append = FALSE}." 58 | ) 59 | } 60 | new_value <- c(x$info[[id_index]]$comment, unlist(dots)) 61 | new_value <- new_value[!is.na(new_value) & nchar(new_value) > 0] 62 | new_value <- paste0(new_value, collapse = "\n") 63 | x$info[[id_index]]$comment <- new_value 64 | x 65 | } 66 | 67 | #' @export 68 | #' @rdname comment_add 69 | comment_get <- function(x, id) { 70 | check_wf_set(x) 71 | if (length(id) > 1) { 72 | cli::cli_abort("{.arg id} should be a single character value.") 73 | } 74 | has_id <- id == x$wflow_id 75 | if (!any(has_id)) { 76 | cli::cli_abort("The {.arg id} value is not in {.arg wflow_id}.") 77 | } 78 | id_index <- which(has_id) 79 | x$info[[id_index]]$comment 80 | } 81 | 82 | 83 | #' @export 84 | #' @rdname comment_add 85 | comment_reset <- function(x, id) { 86 | check_wf_set(x) 87 | if (length(id) > 1) { 88 | cli::cli_abort("{.arg id} should be a single character value.") 89 | } 90 | has_id <- id == x$wflow_id 91 | if (!any(has_id)) { 92 | cli::cli_abort("The {.arg id} value is not in {.arg wflow_id}.") 93 | } 94 | id_index <- which(has_id) 95 | x$info[[id_index]]$comment <- character(1) 96 | x 97 | } 98 | 99 | #' @export 100 | #' @rdname comment_add 101 | comment_print <- function(x, id = NULL, ...) { 102 | check_wf_set(x) 103 | if (is.null(id)) { 104 | id <- x$wflow_id 105 | } 106 | 107 | x <- dplyr::filter(x, wflow_id %in% id) 108 | chr_x <- purrr::map(x$wflow_id, \(.x) comment_get(x, id = .x)) 109 | has_comment <- purrr::map_lgl(chr_x, \(.x) nchar(.x) > 0) 110 | chr_x <- chr_x[which(has_comment)] 111 | id <- x$wflow_id[which(has_comment)] 112 | 113 | for (i in seq_along(chr_x)) { 114 | cat(cli::rule(id[i]), "\n\n") 115 | tmp_chr <- comment_format(chr_x[[i]]) 116 | n_comments <- length(tmp_chr) 117 | 118 | for (j in 1:n_comments) { 119 | cat(tmp_chr[j], "\n\n") 120 | } 121 | } 122 | invisible(NULL) 123 | } 124 | 125 | comment_format <- function(x, id, ...) 
{ 126 | x <- strsplit(x, "\n")[[1]] 127 | x <- purrr::map(x, \(.x) strwrap(.x)) 128 | x <- purrr::map(x, \(.x) add_returns(.x)) 129 | paste0(x, collapse = "\n\n") 130 | } 131 | 132 | add_returns <- function(x) { 133 | paste0(x, collapse = "\n") 134 | } 135 | -------------------------------------------------------------------------------- /R/compat-dplyr.R: -------------------------------------------------------------------------------- 1 | #' @export 2 | dplyr_reconstruct.workflow_set <- function(data, template) { 3 | workflow_set_maybe_reconstruct(data) 4 | } 5 | -------------------------------------------------------------------------------- /R/compat-vctrs-helpers.R: -------------------------------------------------------------------------------- 1 | workflow_set_maybe_reconstruct <- function(x) { 2 | if (workflow_set_is_reconstructable(x)) { 3 | new_workflow_set0(x) 4 | } else { 5 | new_tibble0(x) 6 | } 7 | } 8 | 9 | workflow_set_is_reconstructable <- function(x) { 10 | has_required_container_type(x) && 11 | has_required_container_columns(x) && 12 | has_valid_column_info_structure(x) && 13 | has_valid_column_info_inner_types(x) && 14 | has_valid_column_info_inner_names(x) && 15 | has_valid_column_result_structure(x) && 16 | has_valid_column_result_inner_types(x) && 17 | has_valid_column_result_fingerprints(x) && 18 | has_valid_column_option_structure(x) && 19 | has_valid_column_option_inner_types(x) && 20 | has_valid_column_wflow_id_structure(x) && 21 | has_valid_column_wflow_id_strings(x) 22 | } 23 | -------------------------------------------------------------------------------- /R/compat-vctrs.R: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | 3 | # `vec_restore()` 4 | # 5 | # Called at the end of `vec_slice()` and `vec_ptype()` after all slicing has 6 | # been done on the proxy object. 7 | # 8 | # If all invariants still hold after modifying the proxy, then we can restore 9 | # to a workflow_set object. Otherwise, it will fall back to a bare tibble. 10 | # 11 | # Unlike rsample classes, `vec_ptype()` returns a workflow_set object here. 12 | # This allows `vec_ptype.workflow_set.workflow_set()` to be called. 13 | 14 | #' @export 15 | vec_restore.workflow_set <- function(x, to, ...) { 16 | workflow_set_maybe_reconstruct(x) 17 | } 18 | 19 | # ------------------------------------------------------------------------------ 20 | 21 | # `vec_ptype2()` 22 | # 23 | # When combining two workflow_sets together, `x` and `y` will be zero-row slices 24 | # which should always result in a new workflow_set object, as long as 25 | # `df_ptype2()` can compute a common type. 26 | # 27 | # Combining a workflow_set with a tibble/data.frame will only ever happen if 28 | # the user calls `vec_c()` or `vec_rbind()` with one of each of those inputs. 29 | # I think that it would be very difficult to expect that this returns a new 30 | # workflow_set, so instead we always return a tibble. 
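# A rough illustration of the behaviour described above (a sketch, not package
# code; it assumes the packaged `two_class_set` example object and that the
# vctrs and tibble packages are available):
#
#   vctrs::vec_slice(two_class_set, 1:3)
#   #> the invariants still hold, so the result is restored to a workflow_set
#
#   vctrs::vec_rbind(two_class_set, tibble::tibble(extra = 1))
#   #> the common type computed by `vec_ptype2()` is a tibble, so the result
#   #> falls back to a bare tibble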
31 | 32 | #' @export 33 | vec_ptype2.workflow_set.workflow_set <- function( 34 | x, 35 | y, 36 | ..., 37 | x_arg = "", 38 | y_arg = "" 39 | ) { 40 | out <- vctrs::df_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) 41 | workflow_set_maybe_reconstruct(out) 42 | } 43 | #' @export 44 | vec_ptype2.workflow_set.tbl_df <- function(x, y, ..., x_arg = "", y_arg = "") { 45 | vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) 46 | } 47 | #' @export 48 | vec_ptype2.tbl_df.workflow_set <- function(x, y, ..., x_arg = "", y_arg = "") { 49 | vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) 50 | } 51 | #' @export 52 | vec_ptype2.workflow_set.data.frame <- function( 53 | x, 54 | y, 55 | ..., 56 | x_arg = "", 57 | y_arg = "" 58 | ) { 59 | vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) 60 | } 61 | #' @export 62 | vec_ptype2.data.frame.workflow_set <- function( 63 | x, 64 | y, 65 | ..., 66 | x_arg = "", 67 | y_arg = "" 68 | ) { 69 | vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) 70 | } 71 | 72 | # ------------------------------------------------------------------------------ 73 | 74 | # `vec_cast()` 75 | # 76 | # These methods are designed with `vec_ptype2()` in mind. 77 | # 78 | # Casting from one workflow_set to another will happen "automatically" when 79 | # two workflow_sets are combined with `vec_c()`. The common type will be 80 | # computed with `vec_ptype2()`, then each input will be `vec_cast()` to that 81 | # common type. It should always be possible to reconstruct the workflow_set 82 | # if `df_cast()` is able to cast the underlying data frames successfully. 83 | # 84 | # Casting a tibble or data.frame to a workflow_set should never happen 85 | # automatically, because the ptype2 methods always push towards 86 | # tibble / data.frame. Since it is so unlikely that this will be done 87 | # correctly, we don't ever allow it. 88 | # 89 | # Casting a workflow_set to a tibble or data.frame is easy, the underlying 90 | # vctrs function does the work for us. This is used when doing 91 | # `vec_c(<workflow_set>, <tibble>)`, as the `vec_ptype2()` method will compute 92 | # a common type of tibble, and then each input will be cast to tibble. 93 | 94 | #' @export 95 | vec_cast.workflow_set.workflow_set <- function( 96 | x, 97 | to, 98 | ..., 99 | x_arg = "", 100 | to_arg = "" 101 | ) { 102 | out <- vctrs::df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) 103 | workflow_set_maybe_reconstruct(out) 104 | } 105 | #' @export 106 | vec_cast.workflow_set.tbl_df <- function(x, to, ..., x_arg = "", to_arg = "") { 107 | stop_incompatible_cast_workflow_set(x, to, x_arg = x_arg, to_arg = to_arg) 108 | } 109 | #' @export 110 | vec_cast.tbl_df.workflow_set <- function(x, to, ..., x_arg = "", to_arg = "") { 111 | vctrs::tib_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) 112 | } 113 | #' @export 114 | vec_cast.workflow_set.data.frame <- function( 115 | x, 116 | to, 117 | ..., 118 | x_arg = "", 119 | to_arg = "" 120 | ) { 121 | stop_incompatible_cast_workflow_set(x, to, x_arg = x_arg, to_arg = to_arg) 122 | } 123 | #' @export 124 | vec_cast.data.frame.workflow_set <- function( 125 | x, 126 | to, 127 | ..., 128 | x_arg = "", 129 | to_arg = "" 130 | ) { 131 | vctrs::df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) 132 | } 133 | 134 | # ------------------------------------------------------------------------------ 135 | 136 | stop_incompatible_cast_workflow_set <- function(x, to, ..., x_arg, to_arg) { 137 | details <- "Can't cast to a <workflow_set> because the resulting structure is likely invalid."
138 | vctrs::stop_incompatible_cast( 139 | x, 140 | to, 141 | x_arg = x_arg, 142 | to_arg = to_arg, 143 | details = details 144 | ) 145 | } 146 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' Two Class Example Data 2 | #' 3 | #' @includeRmd man-roxygen/example_data.Rmd description 4 | #' @includeRmd man-roxygen/two_class_set.Rmd 5 | #' 6 | #' @name two_class_set 7 | #' @aliases two_class_set two_class_res 8 | #' @docType data 9 | #' @keywords datasets 10 | #' @examples 11 | #' data(two_class_set) 12 | #' 13 | #' two_class_set 14 | NULL 15 | 16 | #' Chicago Features Example Data 17 | #' 18 | #' @includeRmd man-roxygen/example_data.Rmd description 19 | #' @includeRmd man-roxygen/chi_features_set.Rmd 20 | #' 21 | #' @name chi_features_set 22 | #' @aliases chi_features_set chi_features_res 23 | #' @docType data 24 | #' @keywords datasets 25 | #' @references Max Kuhn and Kjell Johnson (2019) _Feature Engineering and 26 | #' Selection_, \url{https://bookdown.org/max/FES/a-more-complex-example.html} 27 | #' @examples 28 | #' data(chi_features_set) 29 | #' 30 | #' chi_features_set 31 | NULL 32 | -------------------------------------------------------------------------------- /R/extract.R: -------------------------------------------------------------------------------- 1 | #' Extract elements of workflow sets 2 | #' 3 | #' @description 4 | #' These functions extract various elements from a workflow set object. If they 5 | #' do not exist yet, an error is thrown. 6 | #' 7 | #' - `extract_preprocessor()` returns the formula, recipe, or variable 8 | #' expressions used for preprocessing. 9 | #' 10 | #' - `extract_spec_parsnip()` returns the parsnip model specification. 11 | #' 12 | #' - `extract_fit_parsnip()` returns the parsnip model fit object. 13 | #' 14 | #' - `extract_fit_engine()` returns the engine specific fit embedded within 15 | #' a parsnip model fit. For example, when using [parsnip::linear_reg()] 16 | #' with the `"lm"` engine, this returns the underlying `lm` object. 17 | #' 18 | #' - `extract_mold()` returns the preprocessed "mold" object returned 19 | #' from [hardhat::mold()]. It contains information about the preprocessing, 20 | #' including either the prepped recipe, the formula terms object, or 21 | #' variable selectors. 22 | #' 23 | #' - `extract_recipe()` returns the recipe. The `estimated` argument specifies 24 | #' whether the fitted or original recipe is returned. 25 | #' 26 | #' - `extract_workflow_set_result()` returns the results of [workflow_map()] 27 | #' for a particular workflow. 28 | #' 29 | #' - `extract_workflow()` returns the workflow object. The workflow will not 30 | #' have been estimated. 31 | #' 32 | #' - `extract_parameter_set_dials()` returns the parameter set 33 | #' _that will be used to fit_ the supplied row `id` of the workflow set. 34 | #' Note that workflow sets reference a parameter set associated with the 35 | #' `workflow` contained in the `info` column by default, but can be 36 | #' fitted with a modified parameter set via the [option_add()] interface. 37 | #' This extractor returns the latter, if it exists, and returns the former 38 | #' if not, mirroring the process that [workflow_map()] follows to provide 39 | #' tuning functions a parameter set. 
40 | #' 41 | #' - `extract_parameter_dials()` returns the `parameters` object 42 | #' _that will be used to fit_ the supplied tuning `parameter` in the supplied 43 | #' row `id` of the workflow set. See the above notes in 44 | #' `extract_parameter_set_dials()` on precedence. 45 | #' 46 | #' @inheritParams comment_add 47 | #' @param id A single character string for a workflow ID. 48 | #' @param parameter A single string for the parameter ID. 49 | #' @param estimated A logical for whether the original (unfit) recipe or the 50 | #' fitted recipe should be returned. 51 | #' @param ... Other options (not currently used). 52 | #' @details 53 | #' 54 | #' These functions supersede the `pull_*()` functions (e.g., 55 | #' [extract_workflow_set_result()]). 56 | #' @return 57 | #' The extracted value from the object, `x`, as described in the 58 | #' description section. 59 | #' 60 | #' @includeRmd man-roxygen/example_data.Rmd note 61 | #' 62 | #' @examples 63 | #' library(tune) 64 | #' 65 | #' two_class_res 66 | #' 67 | #' extract_workflow_set_result(two_class_res, "none_cart") 68 | #' 69 | #' extract_workflow(two_class_res, "none_cart") 70 | #' @export 71 | extract_workflow_set_result <- function(x, id, ...) { 72 | check_wf_set(x) 73 | y <- filter_id(x, id) 74 | y$result[[1]] 75 | } 76 | 77 | #' @export 78 | #' @rdname extract_workflow_set_result 79 | extract_workflow.workflow_set <- function(x, id, ...) { 80 | y <- filter_id(x, id) 81 | y$info[[1]]$workflow[[1]] 82 | } 83 | 84 | #' @export 85 | #' @rdname extract_workflow_set_result 86 | extract_spec_parsnip.workflow_set <- function(x, id, ...) { 87 | y <- filter_id(x, id) 88 | extract_spec_parsnip(y$info[[1]]$workflow[[1]]) 89 | } 90 | 91 | 92 | #' @export 93 | #' @rdname extract_workflow_set_result 94 | extract_recipe.workflow_set <- function(x, id, ..., estimated = TRUE) { 95 | check_empty_dots(...) 96 | if (!rlang::is_bool(estimated)) { 97 | cli::cli_abort( 98 | "{.arg estimated} must be a single {.code TRUE} or {.code FALSE}." 99 | ) 100 | } 101 | y <- filter_id(x, id) 102 | extract_recipe(y$info[[1]]$workflow[[1]], estimated = estimated) 103 | } 104 | check_empty_dots <- function(..., call = caller_env()) { 105 | opts <- list(...) 106 | if (any(names(opts) == "estimated")) { 107 | cli::cli_abort("{.arg estimated} should be a named argument.", call = call) 108 | } 109 | if (length(opts) > 0) { 110 | cli::cli_abort("{.arg ...} are not used in this function.", call = call) 111 | } 112 | invisible(NULL) 113 | } 114 | 115 | 116 | #' @export 117 | #' @rdname extract_workflow_set_result 118 | extract_fit_parsnip.workflow_set <- function(x, id, ...) { 119 | y <- filter_id(x, id) 120 | extract_fit_parsnip(y$info[[1]]$workflow[[1]]) 121 | } 122 | 123 | #' @export 124 | #' @rdname extract_workflow_set_result 125 | extract_fit_engine.workflow_set <- function(x, id, ...) { 126 | y <- filter_id(x, id) 127 | extract_fit_engine(y$info[[1]]$workflow[[1]]) 128 | } 129 | 130 | #' @export 131 | #' @rdname extract_workflow_set_result 132 | extract_mold.workflow_set <- function(x, id, ...) { 133 | y <- filter_id(x, id) 134 | extract_mold(y$info[[1]]$workflow[[1]]) 135 | } 136 | 137 | #' @export 138 | #' @rdname extract_workflow_set_result 139 | extract_preprocessor.workflow_set <- function(x, id, ...) { 140 | y <- filter_id(x, id) 141 | extract_preprocessor(y$info[[1]]$workflow[[1]]) 142 | } 143 | 144 | #' @export 145 | #' @rdname extract_workflow_set_result 146 | extract_parameter_set_dials.workflow_set <- function(x, id, ...) 
{ 147 | y <- filter_id(x, id) 148 | 149 | if ("param_info" %in% names(y$option[[1]])) { 150 | return(y$option[[1]][["param_info"]]) 151 | } 152 | 153 | extract_parameter_set_dials(y$info[[1]]$workflow[[1]]) 154 | } 155 | 156 | #' @export 157 | #' @rdname extract_workflow_set_result 158 | extract_parameter_dials.workflow_set <- function(x, id, parameter, ...) { 159 | res <- extract_parameter_set_dials(x, id) 160 | res <- extract_parameter_dials(res, parameter) 161 | 162 | res 163 | } 164 | 165 | # ------------------------------------------------------------------------------ 166 | 167 | filter_id <- function(x, id, call = caller_env()) { 168 | check_string(id) 169 | out <- dplyr::filter(x, wflow_id == id) 170 | if (nrow(out) != 1L) { 171 | cli::cli_abort( 172 | "{.arg id} must correspond to a single row in {.arg x}.", 173 | call = call 174 | ) 175 | } 176 | out 177 | } 178 | -------------------------------------------------------------------------------- /R/fit.R: -------------------------------------------------------------------------------- 1 | #' @importFrom generics fit 2 | 3 | #' @noRd 4 | #' @method fit workflow_set 5 | #' @export 6 | fit.workflow_set <- function(object, ...) { 7 | msg <- "`fit()` is not well-defined for workflow sets." 8 | 9 | # supply a different message depending on whether the 10 | # workflow set has been (attempted to have been) fitted or not 11 | if (!all(purrr::map_lgl(object$result, ~ identical(.x, list())))) { 12 | # if fitted: 13 | msg <- 14 | c( 15 | msg, 16 | "i" = "Please see {.help [{.fun fit_best}](workflowsets::fit_best.workflow_set)}." 17 | ) 18 | } else { 19 | # if not fitted: 20 | msg <- 21 | c( 22 | msg, 23 | "i" = "Please see {.help [{.fun workflow_map}](workflowsets::workflow_map)}." 24 | ) 25 | } 26 | 27 | cli::cli_abort(msg) 28 | } 29 | -------------------------------------------------------------------------------- /R/fit_best.R: -------------------------------------------------------------------------------- 1 | #' @importFrom tune fit_best 2 | #' @export 3 | tune::fit_best 4 | 5 | #' Fit a model to the numerically optimal configuration 6 | #' 7 | #' `fit_best()` takes results from tuning many models and fits the workflow 8 | #' configuration associated with the best performance to the training set. 9 | #' 10 | #' @param x A [`workflow_set`][workflow_set()] object that has been evaluated 11 | #' with [workflow_map()]. Note that the workflow set must have been fitted with 12 | #' the [control option][option_add] `save_workflow = TRUE`. 13 | #' @param metric A character string giving the metric to rank results by. 14 | #' @inheritParams tune::fit_best.tune_results 15 | #' @param ... Additional options to pass to 16 | #' [tune::fit_best][tune::fit_best.tune_results]. 17 | #' 18 | #' @details 19 | #' This function is a shortcut for the steps needed to fit the 20 | #' numerically optimal configuration in a fitted workflow set. 21 | #' The function ranks results, extracts the tuning result pertaining 22 | #' to the best result, and then again calls `fit_best()` (itself a 23 | #' wrapper) on the tuning result containing the best result. 
24 | #' 25 | #' In pseudocode: 26 | #' 27 | #' ``` 28 | #' rankings <- rank_results(wf_set, metric, select_best = TRUE) 29 | #' tune_res <- extract_workflow_set_result(wf_set, rankings$wflow_id[1]) 30 | #' fit_best(tune_res, metric) 31 | #' ``` 32 | #' 33 | #' @includeRmd man-roxygen/example_data.Rmd note 34 | #' 35 | #' @examplesIf rlang::is_installed(c("kknn", "modeldata", "recipes", "yardstick", "dials")) && identical(Sys.getenv("NOT_CRAN"), "true") 36 | #' library(tune) 37 | #' library(modeldata) 38 | #' library(rsample) 39 | #' 40 | #' data(Chicago) 41 | #' Chicago <- Chicago[1:1195, ] 42 | #' 43 | #' time_val_split <- 44 | #' sliding_period( 45 | #' Chicago, 46 | #' date, 47 | #' "month", 48 | #' lookback = 38, 49 | #' assess_stop = 1 50 | #' ) 51 | #' 52 | #' chi_features_set 53 | #' 54 | #' chi_features_res_new <- 55 | #' chi_features_set |> 56 | #' # note: must set `save_workflow = TRUE` to use `fit_best()` 57 | #' option_add(control = control_grid(save_workflow = TRUE)) |> 58 | #' # evaluate with resamples 59 | #' workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) 60 | #' 61 | #' chi_features_res_new 62 | #' 63 | #' # sort models by performance metrics 64 | #' rank_results(chi_features_res_new) 65 | #' 66 | #' # fit the numerically optimal configuration to the training set 67 | #' chi_features_wf <- fit_best(chi_features_res_new) 68 | #' 69 | #' chi_features_wf 70 | #' 71 | #' # to select optimal value based on a specific metric: 72 | #' fit_best(chi_features_res_new, metric = "rmse") 73 | #' @name fit_best.workflow_set 74 | #' @export 75 | fit_best.workflow_set <- function(x, metric = NULL, eval_time = NULL, ...) { 76 | check_string(metric, allow_null = TRUE) 77 | result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]]) 78 | met_set <- tune::.get_tune_metrics(result_1) 79 | 80 | if (is.null(metric)) { 81 | metric <- .get_tune_metric_names(result_1)[1] 82 | } else { 83 | tune::check_metric_in_tune_results(tibble::as_tibble(met_set), metric) 84 | } 85 | 86 | if (is.null(eval_time) & is_dyn(met_set, metric)) { 87 | eval_time <- tune::.get_tune_eval_times(result_1)[1] 88 | } 89 | 90 | rankings <- 91 | rank_results( 92 | x, 93 | rank_metric = metric, 94 | select_best = TRUE, 95 | eval_time = eval_time 96 | ) 97 | 98 | tune_res <- extract_workflow_set_result(x, id = rankings$wflow_id[1]) 99 | 100 | best_params <- select_best(tune_res, metric = metric, eval_time = eval_time) 101 | 102 | fit_best(tune_res, parameters = best_params, ...) 103 | } 104 | 105 | # from unexported 106 | # https://github.com/tidymodels/tune/blob/5b0e10fac559f18c075eb4bd7020e217c6174e66/R/metric-selection.R#L137-L141 107 | is_dyn <- function(mtr_set, metric) { 108 | mtr_info <- tibble::as_tibble(mtr_set) 109 | mtr_cls <- mtr_info$class[mtr_info$metric == metric] 110 | mtr_cls == "dynamic_survival_metric" 111 | } 112 | -------------------------------------------------------------------------------- /R/leave_var_out_formulas.R: -------------------------------------------------------------------------------- 1 | #' Create formulas without each predictor 2 | #' 3 | #' From an initial model formula, create a list of formulas that exclude 4 | #' each predictor. 5 | #' @param formula A model formula that contains at least two predictors. 6 | #' @param data A data frame. 7 | #' @param full_model A logical; should the list include the original formula? 8 | #' @param ... 
Options to pass to [stats::model.frame()] 9 | #' @seealso [workflow_set()] 10 | #' @return A named list of formulas 11 | #' @details The new formulas obey the hierarchy rule so that interactions 12 | #' without main effects are not included (unless the original formula contains 13 | #' such terms). 14 | #' 15 | #' Factor predictors are left as-is (i.e., no indicator variables are created). 16 | #' 17 | #' @examplesIf rlang::is_installed("modeldata") 18 | #' data(penguins, package = "modeldata") 19 | #' 20 | #' leave_var_out_formulas( 21 | #' bill_length_mm ~ ., 22 | #' data = penguins 23 | #' ) 24 | #' 25 | #' leave_var_out_formulas( 26 | #' bill_length_mm ~ (island + sex)^2 + flipper_length_mm, 27 | #' data = penguins 28 | #' ) 29 | #' 30 | #' leave_var_out_formulas( 31 | #' bill_length_mm ~ (island + sex)^2 + flipper_length_mm + 32 | #' I(flipper_length_mm^2), 33 | #' data = penguins 34 | #' ) 35 | #' @export 36 | leave_var_out_formulas <- function(formula, data, full_model = TRUE, ...) { 37 | check_formula(formula) 38 | check_bool(full_model) 39 | 40 | trms <- attr(model.frame(formula, data, ...), "terms") 41 | x_vars <- attr(trms, "term.labels") 42 | if (length(x_vars) < 2) { 43 | cli::cli_abort("There should be at least 2 predictors in the formula.") 44 | } 45 | y_vars <- as.character(formula[[2]]) 46 | 47 | form_terms <- purrr::map(x_vars, rm_vars, lst = x_vars) 48 | form <- purrr::map_chr( 49 | form_terms, 50 | \(.x) paste(y_vars, "~", paste(.x, collapse = " + ")) 51 | ) 52 | form <- purrr::map(form, as.formula) 53 | form <- purrr::map(form, rm_formula_env) 54 | names(form) <- x_vars 55 | if (full_model) { 56 | form$everything <- formula 57 | } 58 | form 59 | } 60 | 61 | rm_vars <- function(x, lst) { 62 | remaining_terms(x, lst) 63 | } 64 | 65 | remaining_terms <- function(x, lst) { 66 | has_x <- purrr::map_lgl(lst, \(.x) x %in% all_terms(.x)) 67 | is_x <- lst == x 68 | lst[!has_x & !is_x] 69 | } 70 | 71 | rm_formula_env <- function(x) { 72 | attr(x, ".Environment") <- rlang::base_env() 73 | x 74 | } 75 | 76 | all_terms <- function(x) { 77 | y <- paste("~", x) 78 | y <- as.formula(y) 79 | all.vars(y) 80 | } 81 | -------------------------------------------------------------------------------- /R/misc.R: -------------------------------------------------------------------------------- 1 | make_workflow <- function(x, y, call = caller_env()) { 2 | exp_classes <- c("formula", "recipe", "workflow_variables") 3 | w <- 4 | workflows::workflow() |> 5 | workflows::add_model(y) 6 | if (inherits(x, "formula")) { 7 | w <- workflows::add_formula(w, x) 8 | } else if (inherits(x, "recipe")) { 9 | w <- workflows::add_recipe(w, x) 10 | } else if (inherits(x, "workflow_variables")) { 11 | w <- workflows::add_variables(w, variables = x) 12 | } else { 13 | cli::cli_abort( 14 | "The preprocessor must be an object with one of the 15 | following classes: {.or {.cls {exp_classes}}}.", 16 | call = call 17 | ) 18 | } 19 | w 20 | } 21 | 22 | # ------------------------------------------------------------------------------ 23 | 24 | metric_to_df <- function(x, ...) 
{ 25 | metrics <- attributes(x)$metrics 26 | names <- names(metrics) 27 | metrics <- unname(metrics) 28 | classes <- purrr::map_chr(metrics, \(.x) class(.x)[[1]]) 29 | directions <- purrr::map_chr(metrics, \(.x) attr(.x, "direction")) 30 | info <- data.frame(metric = names, class = classes, direction = directions) 31 | info 32 | } 33 | 34 | 35 | collate_metrics <- function(x) { 36 | metrics <- 37 | x$result |> 38 | purrr::map(tune::.get_tune_metrics) |> 39 | purrr::map(metric_to_df) |> 40 | purrr::map_dfr(\(.x) dplyr::mutate(.x, order = 1:nrow(.x))) 41 | 42 | mean_order <- 43 | metrics |> 44 | dplyr::group_by(metric) |> 45 | dplyr::summarize( 46 | order = mean(order, na.rm = TRUE), 47 | n = dplyr::n(), 48 | .groups = "drop" 49 | ) 50 | 51 | dplyr::full_join( 52 | dplyr::distinct(metrics) |> dplyr::select(-order), 53 | mean_order, 54 | by = "metric" 55 | ) |> 56 | dplyr::arrange(order) 57 | } 58 | 59 | pick_metric <- function( 60 | x, 61 | rank_metric, 62 | select_metrics = NULL, 63 | call = caller_env() 64 | ) { 65 | # mostly to check for completeness and consistency: 66 | tmp <- collect_metrics(x) 67 | metrics <- collate_metrics(x) 68 | 69 | if (!is.null(select_metrics)) { 70 | tmp <- dplyr::filter(tmp, .metric %in% select_metrics) 71 | metrics <- dplyr::filter(metrics, metric %in% select_metrics) 72 | } 73 | 74 | if (is.null(rank_metric)) { 75 | rank_metric <- metrics$metric[1] 76 | direction <- metrics$direction[1] 77 | } else { 78 | if (!any(metrics$metric == rank_metric)) { 79 | cli::cli_abort( 80 | "Metric {.val {rank_metric}} was not in the results.", 81 | call = call 82 | ) 83 | } 84 | direction <- metrics$direction[metrics$metric == rank_metric] 85 | } 86 | list(metric = as.character(rank_metric), direction = as.character(direction)) 87 | } 88 | -------------------------------------------------------------------------------- /R/options.R: -------------------------------------------------------------------------------- 1 | #' Add and edit options saved in a workflow set 2 | #' 3 | #' @description 4 | #' The `option` column controls options for the functions that are used to 5 | #' _evaluate_ the workflow set, such as [tune::fit_resamples()] or 6 | #' [tune::tune_grid()]. Examples of common options to set for these functions 7 | #' include `param_info` and `grid`. 8 | #' 9 | #' These functions are helpful for manipulating the information in the `option` 10 | #' column. 11 | #' 12 | #' @export 13 | #' @inheritParams comment_add 14 | #' @param ... Arguments to pass to the `tune_*()` functions (e.g. 15 | #' [tune::tune_grid()]) or [tune::fit_resamples()]. For `option_remove()` this 16 | #' can be a series of unquoted option names. 17 | #' @param id A character string of one or more values from the `wflow_id` 18 | #' column that indicates which options to update. By default, all workflows 19 | #' are updated. 20 | #' @param strict A logical; should execution stop if existing options are being 21 | #' replaced? 22 | #' @return An updated workflow set. 23 | #' @details 24 | #' `option_add()` is used to update all of the options in a workflow set. 25 | #' 26 | #' `option_remove()` will eliminate specific options across rows. 27 | #' 28 | #' `option_add_parameters()` adds a parameter object to the `option` column 29 | #' (if parameters are being tuned). 30 | #' 31 | #' Note that executing a function on the workflow set, such as `tune_grid()`, 32 | #' will add any options given to that function to the `option` column. 
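#'
#' For example (a sketch, not evaluated here; `folds` stands in for any
#' rsample resampling object, such as the one used to build the package's
#' example data):
#'
#' ```
#' two_class_set |>
#'   workflow_map(resamples = folds, grid = 20) |>
#'   dplyr::select(wflow_id, option)
#' # after mapping, each element of `option` records the `resamples` and
#' # `grid` values that were supplied
#' ```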
33 | #' 34 | #' These functions do _not_ control options for the individual workflows, such as 35 | #' the recipe blueprint. When creating a workflow manually, use 36 | #' [workflows::add_model()] or [workflows::add_recipe()] to specify 37 | #' extra options. To alter these in a workflow set, use 38 | #' [update_workflow_model()] or [update_workflow_recipe()]. 39 | #' 40 | #' @examples 41 | #' library(tune) 42 | #' 43 | #' two_class_set 44 | #' 45 | #' two_class_set |> 46 | #' option_add(grid = 10) 47 | #' 48 | #' two_class_set |> 49 | #' option_add(grid = 10) |> 50 | #' option_add(grid = 50, id = "none_cart") 51 | #' 52 | #' two_class_set |> 53 | #' option_add_parameters() 54 | option_add <- function(x, ..., id = NULL, strict = FALSE) { 55 | check_wf_set(x) 56 | dots <- list(...) 57 | if (length(dots) == 0) { 58 | return(x) 59 | } 60 | 61 | if (strict) { 62 | act <- "fail" 63 | } else { 64 | act <- "warn" 65 | } 66 | 67 | check_tune_args(names(dots)) 68 | 69 | check_string(id, allow_null = TRUE) 70 | check_bool(strict) 71 | 72 | if (!is.null(id)) { 73 | for (i in id) { 74 | ind <- which(x$wflow_id == i) 75 | if (length(ind) == 0) { 76 | cli::cli_warn("Don't have an {.arg id} value {i}") 77 | } else { 78 | check_options(x$option[[ind]], x$wflow_id[[ind]], dots, action = act) 79 | x$option[[ind]] <- append_options(x$option[[ind]], dots) 80 | } 81 | } 82 | } else { 83 | check_options(x$option, x$wflow_id, dots, action = act) 84 | x <- dplyr::mutate(x, option = purrr::map(option, append_options, dots)) 85 | } 86 | x 87 | } 88 | 89 | 90 | #' @export 91 | #' @rdname option_add 92 | option_remove <- function(x, ...) { 93 | dots <- rlang::enexprs(...) 94 | if (length(dots) == 0) { 95 | return(x) 96 | } 97 | dots <- purrr::map_chr(dots, rlang::expr_text) 98 | 99 | x <- dplyr::mutate(x, option = purrr::map(option, rm_elem, dots)) 100 | x 101 | } 102 | 103 | 104 | maybe_param <- function(x) { 105 | prm <- hardhat::extract_parameter_set_dials(x) 106 | if (nrow(prm) == 0) { 107 | x <- list() 108 | } else { 109 | x <- list(param_info = prm) 110 | } 111 | x 112 | } 113 | #' @export 114 | #' @rdname option_add 115 | option_add_parameters <- function(x, id = NULL, strict = FALSE) { 116 | prm <- purrr::map(x$info, \(.x) maybe_param(.x$workflow[[1]])) 117 | num <- purrr::map_int(prm, length) 118 | if (all(num == 0)) { 119 | return(x) 120 | } 121 | 122 | if (strict) { 123 | act <- "fail" 124 | } else { 125 | act <- "warn" 126 | } 127 | 128 | if (!is.null(id)) { 129 | for (i in id) { 130 | ind <- which(x$wflow_id == i) 131 | if (length(ind) == 0) { 132 | cli::cli_warn("Don't have an {.arg id} value {i}") 133 | } else { 134 | check_options( 135 | x$option[[ind]], 136 | x$wflow_id[[ind]], 137 | prm[[ind]], 138 | action = act 139 | ) 140 | x$option[[ind]] <- append_options(x$option[[ind]], prm[[ind]]) 141 | } 142 | } 143 | } else { 144 | check_options(x$option, x$wflow_id, prm[1], action = act) 145 | x <- dplyr::mutate(x, option = purrr::map2(option, prm, append_options)) 146 | } 147 | x 148 | } 149 | 150 | rm_elem <- function(x, nms) { 151 | x <- x[!(names(x) %in% nms)] 152 | new_workflow_set_options(!!!x) 153 | } 154 | 155 | append_options <- function(model, global) { 156 | old_names <- names(model) 157 | new_names <- names(global) 158 | common_names <- intersect(old_names, new_names) 159 | 160 | if (length(common_names) > 0) { 161 | model <- rm_elem(model, common_names) 162 | } 163 | 164 | all_opt <- c(model, global) 165 | new_workflow_set_options(!!!all_opt) 166 | } 167 | 168 | #' @export 169 | 
print.workflow_set_options <- function(x, ...) { 170 | if (length(x) > 0) { 171 | cat( 172 | "a list of options with names: ", 173 | paste0("'", names(x), "'", collapse = ", ") 174 | ) 175 | } else { 176 | cat("an empty container for options") 177 | } 178 | cat("\n") 179 | } 180 | 181 | 182 | #' Make a classed list of options 183 | #' 184 | #' This function returns a named list with an extra class of 185 | #' `"workflow_set_options"` that has corresponding formatting methods for 186 | #' printing inside of tibbles. 187 | #' @param ... A set of named options (or nothing) 188 | #' @return A classed list. 189 | #' @examples 190 | #' option_list(a = 1, b = 2) 191 | #' option_list() 192 | #' @export 193 | option_list <- function(...) new_workflow_set_options(...) 194 | 195 | new_workflow_set_options <- function(..., call = caller_env()) { 196 | res <- rlang::list2(...) 197 | if (any(names(res) == "")) { 198 | cli::cli_abort("All options should be named.", call = call) 199 | } 200 | structure(res, class = c("workflow_set_options", "list")) 201 | } 202 | 203 | #' @export 204 | type_sum.workflow_set_options <- function(x) { 205 | paste0("opts[", length(x), "]") 206 | } 207 | #' @export 208 | size_sum.workflow_set_options <- function(x) { 209 | "" 210 | } 211 | #' @export 212 | obj_sum.workflow_set_options <- function(x) { 213 | paste0("opts[", length(x), "]") 214 | } 215 | -------------------------------------------------------------------------------- /R/predict.R: -------------------------------------------------------------------------------- 1 | #' @importFrom stats predict 2 | 3 | #' @noRd 4 | #' @method predict workflow_set 5 | #' @export 6 | predict.workflow_set <- function(object, ...) { 7 | cli::cli_abort(c( 8 | "`predict()` is not well-defined for workflow sets.", 9 | "i" = "To predict with the optimal model configuration from a workflow \\ 10 | set, ensure that the workflow set was fitted with the \\ 11 | {.help [control option](workflowsets::option_add)} \\ 12 | {.help [{.code save_workflow = TRUE}](tune::control_grid)}, run \\ 13 | {.help [{.fun fit_best}](tune::fit_best)}, and then predict using \\ 14 | {.help [{.fun predict}](workflows::predict.workflow)} on its output.", 15 | "i" = "To collect predictions from a workflow set, ensure that \\ 16 | the workflow set was fitted with the \\ 17 | {.help [control option](workflowsets::option_add)} \\ 18 | {.help [{.code save_pred = TRUE}](tune::control_grid)} and run \\ 19 | {.help [{.fun collect_predictions}](tune::collect_predictions)}." 20 | )) 21 | } 22 | -------------------------------------------------------------------------------- /R/pull.R: -------------------------------------------------------------------------------- 1 | #' Extract elements from a workflow set 2 | #' 3 | #' `r lifecycle::badge("deprecated")` 4 | #' 5 | #' `pull_workflow_set_result()` retrieves the results of [workflow_map()] for a 6 | #' particular workflow while `pull_workflow()` extracts the unfitted workflow 7 | #' from the `info` column. 8 | #' 9 | #' 10 | #' @inheritParams comment_add 11 | #' @param id A single character string for a workflow ID. 12 | #' @details 13 | #' The [extract_workflow_set_result()] and [extract_workflow()] functions should 14 | #' be used instead of these functions. 15 | #' @return `pull_workflow_set_result()` produces a `tune_result` or 16 | #' `resample_results` object. `pull_workflow()` returns an unfit workflow 17 | #' object. 
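#'
#' A sketch of the replacement calls, using the packaged `two_class_res`
#' example object:
#'
#' ```
#' # deprecated:
#' # pull_workflow_set_result(two_class_res, "none_cart")
#' # pull_workflow(two_class_res, "none_cart")
#'
#' # use instead:
#' extract_workflow_set_result(two_class_res, "none_cart")
#' extract_workflow(two_class_res, "none_cart")
#' ```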
18 | #' @export 19 | pull_workflow_set_result <- function(x, id) { 20 | lifecycle::deprecate_stop( 21 | "0.1.0", 22 | "pull_workflow_set_result()", 23 | "extract_workflow_set_result()" 24 | ) 25 | } 26 | 27 | #' @export 28 | #' @rdname pull_workflow_set_result 29 | pull_workflow <- function(x, id) { 30 | lifecycle::deprecate_stop("0.1.0", "pull_workflow()", "extract_workflow()") 31 | } 32 | -------------------------------------------------------------------------------- /R/rank_results.R: -------------------------------------------------------------------------------- 1 | #' Rank the results by a metric 2 | #' 3 | #' This function sorts the results by a specific performance metric. 4 | #' 5 | #' @inheritParams collect_metrics.workflow_set 6 | #' @param rank_metric A character string for a metric. 7 | #' @inheritParams tune::fit_best.tune_results 8 | #' @param select_best A logical giving whether the results should only contain 9 | #' the numerically best submodel per workflow. 10 | #' @details 11 | #' If some models have the exact same performance, 12 | #' `rank(value, ties.method = "random")` is used (with a reproducible seed) so 13 | #' that all ranks are integers. 14 | #' 15 | #' No columns are returned for the tuning parameters since they are likely to 16 | #' be different (or not exist) for some models. The `wflow_id` and `.config` 17 | #' columns can be used to determine the corresponding parameter values. 18 | #' @return A tibble with columns: `wflow_id`, `.config`, `.metric`, `mean`, 19 | #' `std_err`, `n`, `preprocessor`, `model`, and `rank`. 20 | #' 21 | #' @includeRmd man-roxygen/example_data.Rmd note 22 | #' 23 | #' @examples 24 | #' chi_features_res 25 | #' 26 | #' rank_results(chi_features_res) 27 | #' rank_results(chi_features_res, select_best = TRUE) 28 | #' rank_results(chi_features_res, rank_metric = "rsq") 29 | #' @export 30 | rank_results <- function( 31 | x, 32 | rank_metric = NULL, 33 | eval_time = NULL, 34 | select_best = FALSE 35 | ) { 36 | check_wf_set(x) 37 | check_string(rank_metric, allow_null = TRUE) 38 | check_bool(select_best) 39 | result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]]) 40 | met_set <- tune::.get_tune_metrics(result_1) 41 | if (!is.null(rank_metric)) { 42 | tune::check_metric_in_tune_results(tibble::as_tibble(met_set), rank_metric) 43 | } 44 | 45 | metric_info <- pick_metric(x, rank_metric) 46 | metric <- metric_info$metric 47 | direction <- metric_info$direction 48 | wflow_info <- dplyr::bind_cols( 49 | purrr::map_dfr(x$info, I), 50 | dplyr::select(x, wflow_id) 51 | ) 52 | 53 | eval_time <- tune::choose_eval_time(result_1, metric, eval_time = eval_time) 54 | 55 | results <- collect_metrics(x) |> 56 | dplyr::select( 57 | wflow_id, 58 | .config, 59 | .metric, 60 | mean, 61 | std_err, 62 | n, 63 | dplyr::any_of(".eval_time") 64 | ) |> 65 | dplyr::full_join(wflow_info, by = "wflow_id") |> 66 | dplyr::select(-comment, -workflow) 67 | 68 | if (!is.null(eval_time) && ".eval_time" %in% names(results)) { 69 | results <- results[results$.eval_time == eval_time, ] 70 | } 71 | 72 | types <- x |> 73 | dplyr::full_join(wflow_info, by = "wflow_id") |> 74 | dplyr::mutate( 75 | is_race = purrr::map_lgl(result, \(.x) inherits(.x, "tune_race")), 76 | num_rs = purrr::map_int(result, get_num_resamples) 77 | ) |> 78 | dplyr::select(wflow_id, is_race, num_rs) 79 | 80 | ranked <- 81 | dplyr::full_join(results, types, by = "wflow_id") |> 82 | dplyr::filter(.metric == metric) 83 | 84 | if (any(ranked$is_race)) { 85 | # remove any racing results with less 
resamples than the total number 86 | rm_rows <- 87 | ranked |> 88 | dplyr::filter(is_race & (num_rs > n)) |> 89 | dplyr::select(wflow_id, .config) |> 90 | dplyr::distinct() 91 | if (nrow(rm_rows) > 0) { 92 | ranked <- dplyr::anti_join(ranked, rm_rows, by = c("wflow_id", ".config")) 93 | results <- dplyr::anti_join( 94 | results, 95 | rm_rows, 96 | by = c("wflow_id", ".config") 97 | ) 98 | } 99 | } 100 | 101 | if (direction == "maximize") { 102 | ranked$mean <- -ranked$mean 103 | } 104 | 105 | if (select_best) { 106 | best_by_wflow <- 107 | dplyr::group_by(ranked, wflow_id) |> 108 | dplyr::slice_min(mean, with_ties = FALSE) |> 109 | dplyr::ungroup() |> 110 | dplyr::select(wflow_id, .config) 111 | ranked <- dplyr::inner_join( 112 | ranked, 113 | best_by_wflow, 114 | by = c("wflow_id", ".config") 115 | ) 116 | } 117 | 118 | # ensure reproducible rankings when there are ties 119 | withr::with_seed( 120 | 1, 121 | { 122 | ranked <- 123 | ranked |> 124 | dplyr::mutate(rank = rank(mean, ties.method = "random")) |> 125 | dplyr::select(wflow_id, .config, rank) 126 | } 127 | ) 128 | 129 | dplyr::inner_join(results, ranked, by = c("wflow_id", ".config")) |> 130 | dplyr::arrange(rank) |> 131 | dplyr::rename(preprocessor = preproc) 132 | } 133 | 134 | get_num_resamples <- function(x) { 135 | purrr::map_dfr(x$splits, \(.x) .x$id) |> 136 | dplyr::distinct() |> 137 | nrow() 138 | } 139 | -------------------------------------------------------------------------------- /R/update.R: -------------------------------------------------------------------------------- 1 | #' Update components of a workflow within a workflow set 2 | #' 3 | #' @description 4 | #' Workflows can take special arguments for the recipe (e.g. a blueprint) or a 5 | #' model (e.g. a special formula). However, when creating a workflow set, there 6 | #' is no way to specify these extra components. 7 | #' 8 | #' `update_workflow_model()` and `update_workflow_recipe()` allow users to set 9 | #' these values _after_ the workflow set is initially created. They are 10 | #' analogous to [workflows::add_model()] or [workflows::add_recipe()]. 11 | #' 12 | #' @inheritParams comment_add 13 | #' @param id A single character string from the `wflow_id` column indicating 14 | #' which workflow to update. 
15 | #' @inheritParams workflows::add_recipe 16 | #' @inheritParams workflows::add_model 17 | #' 18 | #' @includeRmd man-roxygen/example_data.Rmd note 19 | #' 20 | #' @examples 21 | #' library(parsnip) 22 | #' 23 | #' new_mod <- 24 | #' decision_tree() |> 25 | #' set_engine("rpart", method = "anova") |> 26 | #' set_mode("classification") 27 | #' 28 | #' new_set <- update_workflow_model(two_class_res, "none_cart", spec = new_mod) 29 | #' 30 | #' new_set 31 | #' 32 | #' extract_workflow(new_set, id = "none_cart") 33 | #' @export 34 | update_workflow_model <- function(x, id, spec, formula = NULL) { 35 | check_wf_set(x) 36 | check_string(id) 37 | check_formula(formula, allow_null = TRUE) 38 | 39 | wflow <- extract_workflow(x, id = id) 40 | wflow <- workflows::update_model(wflow, spec = spec, formula = formula) 41 | id_ind <- which(x$wflow_id == id) 42 | x$info[[id_ind]]$workflow[[1]] <- wflow 43 | # Remove any existing results since they are now inconsistent 44 | if (!identical(x$result[[id_ind]], list())) { 45 | x$result[[id_ind]] <- list() 46 | } 47 | x 48 | } 49 | 50 | 51 | #' @rdname update_workflow_model 52 | #' @export 53 | update_workflow_recipe <- function(x, id, recipe, blueprint = NULL) { 54 | check_wf_set(x) 55 | check_string(id) 56 | 57 | wflow <- extract_workflow(x, id = id) 58 | wflow <- workflows::update_recipe( 59 | wflow, 60 | recipe = recipe, 61 | blueprint = blueprint 62 | ) 63 | id_ind <- which(x$wflow_id == id) 64 | x$info[[id_ind]]$workflow[[1]] <- wflow 65 | # Remove any existing results since they are now inconsistent 66 | if (!identical(x$result[[id_ind]], list())) { 67 | x$result[[id_ind]] <- list() 68 | } 69 | x 70 | } 71 | -------------------------------------------------------------------------------- /R/zzz.R: -------------------------------------------------------------------------------- 1 | .onLoad <- function(libname, pkgname) { 2 | vctrs::s3_register("pillar::obj_sum", "workflow_set_options") 3 | vctrs::s3_register("pillar::size_sum", "workflow_set_options") 4 | vctrs::s3_register("pillar::type_sum", "workflow_set_options") 5 | vctrs::s3_register("pillar::tbl_sum", "workflow_set") 6 | vctrs::s3_register("tune::collect_metrics", "workflow_set") 7 | vctrs::s3_register("tune::collect_predictions", "workflow_set") 8 | vctrs::s3_register("ggplot2::autoplot", "workflow_set") 9 | invisible() 10 | } 11 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://workflowsets.tidymodels.org 2 | 3 | template: 4 | package: tidytemplate 5 | bootstrap: 5 6 | bslib: 7 | danger: "#CA225E" 8 | primary: "#CA225E" 9 | includes: 10 | in_header: | 11 | 12 | development: 13 | mode: auto 14 | 15 | reference: 16 | - title: Core functions 17 | contents: 18 | - workflow_set 19 | - workflow_map 20 | - title: Interface with workflow sets 21 | contents: 22 | - starts_with("option") 23 | - starts_with("extract") 24 | - starts_with("comment") 25 | - starts_with("update") 26 | - title: Process workflow set results 27 | contents: 28 | - rank_results 29 | - autoplot.workflow_set 30 | - fit_best.workflow_set 31 | - starts_with("collect") 32 | - title: Miscellanous 33 | contents: 34 | - matches(".") 35 | -------------------------------------------------------------------------------- /air.toml: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/air.toml -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/cran-comments.md -------------------------------------------------------------------------------- /data/chi_features_set.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/data/chi_features_set.rda -------------------------------------------------------------------------------- /data/two_class_set.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/data/two_class_set.rda -------------------------------------------------------------------------------- /inst/WORDLIST: -------------------------------------------------------------------------------- 1 | CMD 2 | Codecov 3 | Kjell 4 | Lifecycle 5 | ORCID 6 | PBC 7 | Silge 8 | decorrelate 9 | deprecations 10 | finetune 11 | funder 12 | ggplot 13 | masse 14 | modeldata 15 | pre 16 | preprocess 17 | preprocessor 18 | preprocessors 19 | pseudocode 20 | reprex 21 | tibble 22 | tibbles 23 | tidyclust 24 | tidymodels 25 | un 26 | wc 27 | -------------------------------------------------------------------------------- /man-roxygen/chi_features_set.Rmd: -------------------------------------------------------------------------------- 1 | See below for the source code to generate the Chicago Features example workflow sets: 2 | 3 | ```{r, eval = FALSE} 4 | library(workflowsets) 5 | library(workflows) 6 | library(modeldata) 7 | library(recipes) 8 | library(parsnip) 9 | library(dplyr) 10 | library(rsample) 11 | library(tune) 12 | library(yardstick) 13 | library(dials) 14 | 15 | # ------------------------------------------------------------------------------ 16 | # Slightly smaller data size 17 | data(Chicago) 18 | Chicago <- Chicago[1:1195,] 19 | 20 | time_val_split <- 21 | sliding_period( 22 | Chicago, 23 | date, 24 | "month", 25 | lookback = 38, 26 | assess_stop = 1 27 | ) 28 | 29 | # ------------------------------------------------------------------------------ 30 | 31 | base_recipe <- 32 | recipe(ridership ~ ., data = Chicago) |> 33 | # create date features 34 | step_date(date) |> 35 | step_holiday(date) |> 36 | # remove date from the list of predictors 37 | update_role(date, new_role = "id") |> 38 | # create dummy variables from factor columns 39 | step_dummy(all_nominal()) |> 40 | # remove any columns with a single unique value 41 | step_zv(all_predictors()) |> 42 | step_normalize(all_predictors()) 43 | 44 | date_only <- 45 | recipe(ridership ~ ., data = Chicago) |> 46 | # create date features 47 | step_date(date) |> 48 | update_role(date, new_role = "id") |> 49 | # create dummy variables from factor columns 50 | 
step_dummy(all_nominal()) |> 51 | # remove any columns with a single unique value 52 | step_zv(all_predictors()) 53 | 54 | date_and_holidays <- 55 | recipe(ridership ~ ., data = Chicago) |> 56 | # create date features 57 | step_date(date) |> 58 | step_holiday(date) |> 59 | # remove date from the list of predictors 60 | update_role(date, new_role = "id") |> 61 | # create dummy variables from factor columns 62 | step_dummy(all_nominal()) |> 63 | # remove any columns with a single unique value 64 | step_zv(all_predictors()) 65 | 66 | date_and_holidays_and_pca <- 67 | recipe(ridership ~ ., data = Chicago) |> 68 | # create date features 69 | step_date(date) |> 70 | step_holiday(date) |> 71 | # remove date from the list of predictors 72 | update_role(date, new_role = "id") |> 73 | # create dummy variables from factor columns 74 | step_dummy(all_nominal()) |> 75 | # remove any columns with a single unique value 76 | step_zv(all_predictors()) |> 77 | step_pca(!!stations, num_comp = tune()) 78 | 79 | # ------------------------------------------------------------------------------ 80 | 81 | lm_spec <- linear_reg() |> set_engine("lm") 82 | 83 | # ------------------------------------------------------------------------------ 84 | 85 | pca_param <- 86 | parameters(num_comp()) |> 87 | update(num_comp = num_comp(c(0, 20))) 88 | 89 | # ------------------------------------------------------------------------------ 90 | 91 | chi_features_set <- 92 | workflow_set( 93 | preproc = list(date = date_only, 94 | plus_holidays = date_and_holidays, 95 | plus_pca = date_and_holidays_and_pca), 96 | models = list(lm = lm_spec), 97 | cross = TRUE 98 | ) 99 | 100 | # ------------------------------------------------------------------------------ 101 | 102 | chi_features_res <- 103 | chi_features_set |> 104 | option_add(param_info = pca_param, id = "plus_pca_lm") |> 105 | workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) 106 | ``` 107 | 108 | ```{r, eval = FALSE, include = FALSE} 109 | save(chi_features_set, chi_features_res, 110 | file = "data/chi_features_set.rda", 111 | version = 2, compress = "xz") 112 | ``` 113 | -------------------------------------------------------------------------------- /man-roxygen/example_data.Rmd: -------------------------------------------------------------------------------- 1 | The package supplies two pre-generated workflow sets, `two_class_set` and `chi_features_set`, and associated sets of model fits `two_class_res` and `chi_features_res`. 2 | 3 | The `two_class_*` objects are based on a binary classification problem using the `two_class_dat` data from the modeldata package. The six models utilize either a bare formula or a basic recipe utilizing `recipes::step_YeoJohnson()` as a preprocessor, and a decision tree, logistic regression, or MARS model specification. See `?two_class_set` for source code. 4 | 5 | The `chi_features_*` objects are based on a regression problem using the `Chicago` data from the modeldata package. Each of the three models utilize a linear regression model specification, with three different recipes of varying complexity. The objects are meant to approximate the sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See `?chi_features_set` for source code. 
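For a quick look at these objects (a minimal sketch; both `.rda` files ship with the package, so nothing needs to be refit):

```{r, eval = FALSE}
library(workflowsets)

data(two_class_set)     # also loads two_class_res
data(chi_features_set)  # also loads chi_features_res

rank_results(two_class_res)
```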
6 | -------------------------------------------------------------------------------- /man-roxygen/two_class_set.Rmd: -------------------------------------------------------------------------------- 1 | See below for the source code to generate the Two Class example workflow sets: 2 | 3 | ```{r, eval = FALSE} 4 | library(workflowsets) 5 | library(workflows) 6 | library(modeldata) 7 | library(recipes) 8 | library(parsnip) 9 | library(dplyr) 10 | library(rsample) 11 | library(tune) 12 | library(yardstick) 13 | 14 | # ------------------------------------------------------------------------------ 15 | 16 | data(two_class_dat, package = "modeldata") 17 | 18 | set.seed(1) 19 | folds <- vfold_cv(two_class_dat, v = 5) 20 | 21 | # ------------------------------------------------------------------------------ 22 | 23 | decision_tree_rpart_spec <- 24 | decision_tree(min_n = tune(), cost_complexity = tune()) |> 25 | set_engine('rpart') |> 26 | set_mode('classification') 27 | 28 | logistic_reg_glm_spec <- 29 | logistic_reg() |> 30 | set_engine('glm') 31 | 32 | mars_earth_spec <- 33 | mars(prod_degree = tune()) |> 34 | set_engine('earth') |> 35 | set_mode('classification') 36 | 37 | # ------------------------------------------------------------------------------ 38 | 39 | yj_recipe <- 40 | recipe(Class ~ ., data = two_class_dat) |> 41 | step_YeoJohnson(A, B) 42 | 43 | # ------------------------------------------------------------------------------ 44 | 45 | two_class_set <- 46 | workflow_set( 47 | preproc = list(none = Class ~ A + B, yj_trans = yj_recipe), 48 | models = list(cart = decision_tree_rpart_spec, glm = logistic_reg_glm_spec, 49 | mars = mars_earth_spec) 50 | ) 51 | 52 | # ------------------------------------------------------------------------------ 53 | 54 | two_class_res <- 55 | two_class_set |> 56 | workflow_map( 57 | resamples = folds, 58 | grid = 10, 59 | seed = 2, 60 | verbose = TRUE, 61 | control = control_grid(save_workflow = TRUE) 62 | ) 63 | ``` 64 | 65 | ```{r, eval = FALSE, include = FALSE} 66 | save(two_class_set, two_class_res, file = "data/two_class_set.rda", 67 | compress = "xz", version = 2) 68 | ``` 69 | -------------------------------------------------------------------------------- /man/as_workflow_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/as_workflow_set.R 3 | \name{as_workflow_set} 4 | \alias{as_workflow_set} 5 | \title{Convert existing objects to a workflow set} 6 | \usage{ 7 | as_workflow_set(...) 8 | } 9 | \arguments{ 10 | \item{...}{One or more named objects. Names should be unique and the 11 | objects should have at least one of the following classes: \code{workflow}, 12 | \code{iteration_results}, \code{tune_results}, \code{resample_results}, or \code{tune_race}. Each 13 | \code{tune_results} element should also contain the original workflow 14 | (accomplished using the \code{save_workflow} option in the control function).} 15 | } 16 | \value{ 17 | A workflow set. Note that the \code{option} column will not reflect the 18 | options that were used to create each object. 19 | } 20 | \description{ 21 | Use existing objects to create a workflow set. A list of objects that are 22 | either simple workflows or objects that have class \code{"tune_results"} can be 23 | converted into a workflow set. 
24 | } 25 | \note{ 26 | The package supplies two pre-generated workflow sets, \code{two_class_set} 27 | and \code{chi_features_set}, and associated sets of model fits 28 | \code{two_class_res} and \code{chi_features_res}. 29 | 30 | The \verb{two_class_*} objects are based on a binary classification problem 31 | using the \code{two_class_dat} data from the modeldata package. The six 32 | models utilize either a bare formula or a basic recipe utilizing 33 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 34 | logistic regression, or MARS model specification. See \code{?two_class_set} 35 | for source code. 36 | 37 | The \verb{chi_features_*} objects are based on a regression problem using the 38 | \code{Chicago} data from the modeldata package. Each of the three models 39 | utilize a linear regression model specification, with three different 40 | recipes of varying complexity. The objects are meant to approximate the 41 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 42 | \code{?chi_features_set} for source code. 43 | } 44 | \examples{ 45 | 46 | # ------------------------------------------------------------------------------ 47 | # Existing results 48 | 49 | # Use the already worked example to show how to add tuned 50 | # objects to a workflow set 51 | two_class_res 52 | 53 | results <- two_class_res |> purrr::pluck("result") 54 | names(results) <- two_class_res$wflow_id 55 | 56 | # These are all objects that have been resampled or tuned: 57 | purrr::map_chr(results, \(x) class(x)[1]) 58 | 59 | # Use rlang's !!! operator to splice in the elements of the list 60 | new_set <- as_workflow_set(!!!results) 61 | 62 | # ------------------------------------------------------------------------------ 63 | # Make a set from unfit workflows 64 | 65 | library(parsnip) 66 | library(workflows) 67 | 68 | lr_spec <- logistic_reg() 69 | 70 | main_effects <- 71 | workflow() |> 72 | add_model(lr_spec) |> 73 | add_formula(Class ~ .) 74 | 75 | interactions <- 76 | workflow() |> 77 | add_model(lr_spec) |> 78 | add_formula(Class ~ (.)^2) 79 | 80 | as_workflow_set(main = main_effects, int = interactions) 81 | } 82 | -------------------------------------------------------------------------------- /man/autoplot.workflow_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/autoplot.R 3 | \name{autoplot.workflow_set} 4 | \alias{autoplot.workflow_set} 5 | \title{Plot the results of a workflow set} 6 | \usage{ 7 | \method{autoplot}{workflow_set}( 8 | object, 9 | rank_metric = NULL, 10 | metric = NULL, 11 | id = "workflow_set", 12 | select_best = FALSE, 13 | std_errs = qnorm(0.95), 14 | type = "class", 15 | ... 16 | ) 17 | } 18 | \arguments{ 19 | \item{object}{A \code{workflow_set} whose elements have results.} 20 | 21 | \item{rank_metric}{A character string for which metric should be used to rank 22 | the results. If none is given, the first metric in the metric set is used 23 | (after filtering by the \code{metric} option).} 24 | 25 | \item{metric}{A character vector for which metrics (apart from \code{rank_metric}) 26 | to be included in the visualization.} 27 | 28 | \item{id}{A character string for what to plot. If a value of 29 | \code{"workflow_set"} is used, the results of each model (and sub-model) are ordered 30 | and plotted. 
Alternatively, a value of the workflow set's \code{wflow_id} can be 31 | given and the \code{autoplot()} method is executed on that workflow's results.} 32 | 33 | \item{select_best}{A logical; should the results only contain the numerically 34 | best submodel per workflow?} 35 | 36 | \item{std_errs}{The number of standard errors to plot (if the standard error 37 | exists).} 38 | 39 | \item{type}{The aesthetics with which to differentiate workflows. The 40 | default \code{"class"} maps color to the model type and shape to the preprocessor 41 | type. The \code{"workflow"} option maps a color to each \code{"wflow_id"}. This 42 | argument is ignored for values of \code{id} other than \code{"workflow_set"}.} 43 | 44 | \item{...}{Other options to pass to \code{autoplot()}.} 45 | } 46 | \value{ 47 | A ggplot object. 48 | } 49 | \description{ 50 | This \code{autoplot()} method plots performance metrics that have been ranked using 51 | a metric. It can also run \code{autoplot()} on the individual results (per 52 | \code{wflow_id}). 53 | } 54 | \details{ 55 | This function is intended to produce a default plot to visualize helpful 56 | information across all possible applications of a workflow set. A more 57 | appropriate plot for your specific analysis can be created by 58 | calling \code{\link[=rank_results]{rank_results()}} and using standard \code{ggplot2} code for plotting. 59 | 60 | The x-axis is the workflow rank in the set (a value of one being the best) 61 | versus the performance metric(s) on the y-axis. With multiple metrics, there 62 | will be facets for each metric. 63 | 64 | If multiple resamples are used, confidence bounds are shown for each result 65 | (90\% confidence, by default). 66 | } 67 | \note{ 68 | The package supplies two pre-generated workflow sets, \code{two_class_set} 69 | and \code{chi_features_set}, and associated sets of model fits 70 | \code{two_class_res} and \code{chi_features_res}. 71 | 72 | The \verb{two_class_*} objects are based on a binary classification problem 73 | using the \code{two_class_dat} data from the modeldata package. The six 74 | models utilize either a bare formula or a basic recipe utilizing 75 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 76 | logistic regression, or MARS model specification. See \code{?two_class_set} 77 | for source code. 78 | 79 | The \verb{chi_features_*} objects are based on a regression problem using the 80 | \code{Chicago} data from the modeldata package. Each of the three models 81 | utilize a linear regression model specification, with three different 82 | recipes of varying complexity. The objects are meant to approximate the 83 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 84 | \code{?chi_features_set} for source code. 
85 | } 86 | \examples{ 87 | autoplot(two_class_res) 88 | autoplot(two_class_res, select_best = TRUE) 89 | autoplot(two_class_res, id = "yj_trans_cart", metric = "roc_auc") 90 | } 91 | -------------------------------------------------------------------------------- /man/chi_features_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{chi_features_set} 5 | \alias{chi_features_set} 6 | \alias{chi_features_res} 7 | \title{Chicago Features Example Data} 8 | \description{ 9 | The package supplies two pre-generated workflow sets, \code{two_class_set} 10 | and \code{chi_features_set}, and associated sets of model fits 11 | \code{two_class_res} and \code{chi_features_res}. 12 | 13 | The \verb{two_class_*} objects are based on a binary classification problem 14 | using the \code{two_class_dat} data from the modeldata package. The six 15 | models utilize either a bare formula or a basic recipe utilizing 16 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 17 | logistic regression, or MARS model specification. See \code{?two_class_set} 18 | for source code. 19 | 20 | The \verb{chi_features_*} objects are based on a regression problem using the 21 | \code{Chicago} data from the modeldata package. Each of the three models 22 | utilize a linear regression model specification, with three different 23 | recipes of varying complexity. The objects are meant to approximate the 24 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 25 | \code{?chi_features_set} for source code. 26 | } 27 | \details{ 28 | See below for the source code to generate the Chicago Features example 29 | workflow sets: 30 | 31 | \if{html}{\out{
}}\preformatted{library(workflowsets) 32 | library(workflows) 33 | library(modeldata) 34 | library(recipes) 35 | library(parsnip) 36 | library(dplyr) 37 | library(rsample) 38 | library(tune) 39 | library(yardstick) 40 | library(dials) 41 | 42 | # ------------------------------------------------------------------------------ 43 | # Slightly smaller data size 44 | data(Chicago) 45 | Chicago <- Chicago[1:1195,] 46 | 47 | time_val_split <- 48 | sliding_period( 49 | Chicago, 50 | date, 51 | "month", 52 | lookback = 38, 53 | assess_stop = 1 54 | ) 55 | 56 | # ------------------------------------------------------------------------------ 57 | 58 | base_recipe <- 59 | recipe(ridership ~ ., data = Chicago) |> 60 | # create date features 61 | step_date(date) |> 62 | step_holiday(date) |> 63 | # remove date from the list of predictors 64 | update_role(date, new_role = "id") |> 65 | # create dummy variables from factor columns 66 | step_dummy(all_nominal()) |> 67 | # remove any columns with a single unique value 68 | step_zv(all_predictors()) |> 69 | step_normalize(all_predictors()) 70 | 71 | date_only <- 72 | recipe(ridership ~ ., data = Chicago) |> 73 | # create date features 74 | step_date(date) |> 75 | update_role(date, new_role = "id") |> 76 | # create dummy variables from factor columns 77 | step_dummy(all_nominal()) |> 78 | # remove any columns with a single unique value 79 | step_zv(all_predictors()) 80 | 81 | date_and_holidays <- 82 | recipe(ridership ~ ., data = Chicago) |> 83 | # create date features 84 | step_date(date) |> 85 | step_holiday(date) |> 86 | # remove date from the list of predictors 87 | update_role(date, new_role = "id") |> 88 | # create dummy variables from factor columns 89 | step_dummy(all_nominal()) |> 90 | # remove any columns with a single unique value 91 | step_zv(all_predictors()) 92 | 93 | date_and_holidays_and_pca <- 94 | recipe(ridership ~ ., data = Chicago) |> 95 | # create date features 96 | step_date(date) |> 97 | step_holiday(date) |> 98 | # remove date from the list of predictors 99 | update_role(date, new_role = "id") |> 100 | # create dummy variables from factor columns 101 | step_dummy(all_nominal()) |> 102 | # remove any columns with a single unique value 103 | step_zv(all_predictors()) |> 104 | step_pca(!!stations, num_comp = tune()) 105 | 106 | # ------------------------------------------------------------------------------ 107 | 108 | lm_spec <- linear_reg() |> set_engine("lm") 109 | 110 | # ------------------------------------------------------------------------------ 111 | 112 | pca_param <- 113 | parameters(num_comp()) |> 114 | update(num_comp = num_comp(c(0, 20))) 115 | 116 | # ------------------------------------------------------------------------------ 117 | 118 | chi_features_set <- 119 | workflow_set( 120 | preproc = list(date = date_only, 121 | plus_holidays = date_and_holidays, 122 | plus_pca = date_and_holidays_and_pca), 123 | models = list(lm = lm_spec), 124 | cross = TRUE 125 | ) 126 | 127 | # ------------------------------------------------------------------------------ 128 | 129 | chi_features_res <- 130 | chi_features_set |> 131 | option_add(param_info = pca_param, id = "plus_pca_lm") |> 132 | workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) 133 | }\if{html}{\out{
}} 134 | } 135 | \examples{ 136 | data(chi_features_set) 137 | 138 | chi_features_set 139 | } 140 | \references{ 141 | Max Kuhn and Kjell Johnson (2019) \emph{Feature Engineering and 142 | Selection}, \url{https://bookdown.org/max/FES/a-more-complex-example.html} 143 | } 144 | \keyword{datasets} 145 | -------------------------------------------------------------------------------- /man/collect_metrics.workflow_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/collect.R 3 | \name{collect_metrics.workflow_set} 4 | \alias{collect_metrics.workflow_set} 5 | \alias{collect_predictions.workflow_set} 6 | \alias{collect_notes.workflow_set} 7 | \alias{collect_extracts.workflow_set} 8 | \title{Obtain and format results produced by tuning functions for workflow sets} 9 | \usage{ 10 | \method{collect_metrics}{workflow_set}(x, ..., summarize = TRUE) 11 | 12 | \method{collect_predictions}{workflow_set}( 13 | x, 14 | ..., 15 | summarize = TRUE, 16 | parameters = NULL, 17 | select_best = FALSE, 18 | metric = NULL 19 | ) 20 | 21 | \method{collect_notes}{workflow_set}(x, ...) 22 | 23 | \method{collect_extracts}{workflow_set}(x, ...) 24 | } 25 | \arguments{ 26 | \item{x}{A \code{\link[=workflow_set]{workflow_set}} object that has been evaluated 27 | with \code{\link[=workflow_map]{workflow_map()}}.} 28 | 29 | \item{...}{Not currently used.} 30 | 31 | \item{summarize}{A logical for whether the performance estimates should be 32 | summarized via the mean (over resamples) or the raw performance values (per 33 | resample) should be returned along with the resampling identifiers. When 34 | collecting predictions, these are averaged if multiple assessment sets 35 | contain the same row.} 36 | 37 | \item{parameters}{An optional tibble of tuning parameter values that can be 38 | used to filter the predicted values before processing. This tibble should 39 | only have columns for each tuning parameter identifier (e.g. \code{"my_param"} 40 | if \code{tune("my_param")} was used).} 41 | 42 | \item{select_best}{A single logical for whether the numerically best results 43 | are retained. If \code{TRUE}, the \code{parameters} argument is ignored.} 44 | 45 | \item{metric}{A character string for the metric that is used for 46 | \code{select_best}.} 47 | } 48 | \value{ 49 | A tibble. 50 | } 51 | \description{ 52 | Return a tibble of performance metrics for all models or submodels. 53 | } 54 | \details{ 55 | When applied to a workflow set, the metrics and predictions that are returned 56 | do not contain the actual tuning parameter columns and values (unlike when 57 | these collect functions are run on other objects). The reason is that workflow 58 | sets can contain different types of models or models with different tuning 59 | parameters. 60 | 61 | If the columns are needed, there are two options. First, the \code{.config} column 62 | can be used to merge the tuning parameter columns into an appropriate object. 63 | Alternatively, the \code{map()} function can be used to get the metrics from the 64 | original objects (see the example below). 65 | } 66 | \note{ 67 | The package supplies two pre-generated workflow sets, \code{two_class_set} 68 | and \code{chi_features_set}, and associated sets of model fits 69 | \code{two_class_res} and \code{chi_features_res}. 
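To illustrate the first (.config-based) option from the details above, here is a minimal sketch that joins one workflow's tuning parameter values back onto the set-level metrics; it assumes the packaged two_class_res object, whose CART workflow tunes cost_complexity and min_n, and that dplyr is installed:

library(dplyr)

set_metrics <- collect_metrics(two_class_res) |>
  filter(wflow_id == "yj_trans_cart")

# the parameter values live in the metrics of the underlying tuning result
cart_params <- two_class_res |>
  extract_workflow_set_result("yj_trans_cart") |>
  collect_metrics() |>
  select(.config, cost_complexity, min_n) |>
  distinct()

# merge the tuning parameter columns back in via .config
left_join(set_metrics, cart_params, by = ".config")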
70 | 71 | The \verb{two_class_*} objects are based on a binary classification problem 72 | using the \code{two_class_dat} data from the modeldata package. The six 73 | models utilize either a bare formula or a basic recipe utilizing 74 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 75 | logistic regression, or MARS model specification. See \code{?two_class_set} 76 | for source code. 77 | 78 | The \verb{chi_features_*} objects are based on a regression problem using the 79 | \code{Chicago} data from the modeldata package. Each of the three models 80 | utilize a linear regression model specification, with three different 81 | recipes of varying complexity. The objects are meant to approximate the 82 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 83 | \code{?chi_features_set} for source code. 84 | } 85 | \examples{ 86 | library(dplyr) 87 | library(purrr) 88 | library(tidyr) 89 | 90 | two_class_res 91 | 92 | # ------------------------------------------------------------------------------ 93 | \donttest{ 94 | collect_metrics(two_class_res) 95 | 96 | # Alternatively, if the tuning parameter values are needed: 97 | two_class_res |> 98 | dplyr::filter(grepl("cart", wflow_id)) |> 99 | mutate(metrics = map(result, collect_metrics)) |> 100 | dplyr::select(wflow_id, metrics) |> 101 | tidyr::unnest(cols = metrics) 102 | } 103 | 104 | collect_metrics(two_class_res, summarize = FALSE) 105 | } 106 | \seealso{ 107 | \code{\link[tune:collect_predictions]{tune::collect_metrics()}}, \code{\link[=rank_results]{rank_results()}} 108 | } 109 | -------------------------------------------------------------------------------- /man/comment_add.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/comments.R 3 | \name{comment_add} 4 | \alias{comment_add} 5 | \alias{comment_get} 6 | \alias{comment_reset} 7 | \alias{comment_print} 8 | \title{Add annotations and comments for workflows} 9 | \usage{ 10 | comment_add(x, id, ..., append = TRUE, collapse = "\\n") 11 | 12 | comment_get(x, id) 13 | 14 | comment_reset(x, id) 15 | 16 | comment_print(x, id = NULL, ...) 17 | } 18 | \arguments{ 19 | \item{x}{A workflow set outputted by \code{\link[=workflow_set]{workflow_set()}} or \code{\link[=workflow_map]{workflow_map()}}.} 20 | 21 | \item{id}{A single character string for a value in the \code{wflow_id} column. For 22 | \code{comment_print()}, \code{id} can be a vector or \code{NULL} (and this indicates that 23 | all comments are printed).} 24 | 25 | \item{...}{One or more character strings.} 26 | 27 | \item{append}{A logical value to determine if the new comment should be added 28 | to the existing values.} 29 | 30 | \item{collapse}{A character string that separates the comments.} 31 | } 32 | \value{ 33 | \code{comment_add()} and \code{comment_reset()} return an updated workflow set. 34 | \code{comment_get()} returns a character string. \code{comment_print()} returns \code{NULL} 35 | invisibly. 36 | } 37 | \description{ 38 | \code{comment_add()} can be used to log important information about the workflow or 39 | its results as you work. Comments can be appended or removed. 
40 | } 41 | \examples{ 42 | two_class_set 43 | 44 | two_class_set |> comment_get("none_cart") 45 | 46 | new_set <- 47 | two_class_set |> 48 | comment_add("none_cart", "What does 'cart' stand for\u2753") |> 49 | comment_add("none_cart", "Classification And Regression Trees.") 50 | 51 | comment_print(new_set) 52 | 53 | new_set |> comment_get("none_cart") 54 | 55 | new_set |> 56 | comment_reset("none_cart") |> 57 | comment_get("none_cart") 58 | } 59 | -------------------------------------------------------------------------------- /man/extract_workflow_set_result.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/extract.R 3 | \name{extract_workflow_set_result} 4 | \alias{extract_workflow_set_result} 5 | \alias{extract_workflow.workflow_set} 6 | \alias{extract_spec_parsnip.workflow_set} 7 | \alias{extract_recipe.workflow_set} 8 | \alias{extract_fit_parsnip.workflow_set} 9 | \alias{extract_fit_engine.workflow_set} 10 | \alias{extract_mold.workflow_set} 11 | \alias{extract_preprocessor.workflow_set} 12 | \alias{extract_parameter_set_dials.workflow_set} 13 | \alias{extract_parameter_dials.workflow_set} 14 | \title{Extract elements of workflow sets} 15 | \usage{ 16 | extract_workflow_set_result(x, id, ...) 17 | 18 | \method{extract_workflow}{workflow_set}(x, id, ...) 19 | 20 | \method{extract_spec_parsnip}{workflow_set}(x, id, ...) 21 | 22 | \method{extract_recipe}{workflow_set}(x, id, ..., estimated = TRUE) 23 | 24 | \method{extract_fit_parsnip}{workflow_set}(x, id, ...) 25 | 26 | \method{extract_fit_engine}{workflow_set}(x, id, ...) 27 | 28 | \method{extract_mold}{workflow_set}(x, id, ...) 29 | 30 | \method{extract_preprocessor}{workflow_set}(x, id, ...) 31 | 32 | \method{extract_parameter_set_dials}{workflow_set}(x, id, ...) 33 | 34 | \method{extract_parameter_dials}{workflow_set}(x, id, parameter, ...) 35 | } 36 | \arguments{ 37 | \item{x}{A workflow set outputted by \code{\link[=workflow_set]{workflow_set()}} or \code{\link[=workflow_map]{workflow_map()}}.} 38 | 39 | \item{id}{A single character string for a workflow ID.} 40 | 41 | \item{...}{Other options (not currently used).} 42 | 43 | \item{estimated}{A logical for whether the original (unfit) recipe or the 44 | fitted recipe should be returned.} 45 | 46 | \item{parameter}{A single string for the parameter ID.} 47 | } 48 | \value{ 49 | The extracted value from the object, \code{x}, as described in the 50 | description section. 51 | } 52 | \description{ 53 | These functions extract various elements from a workflow set object. If they 54 | do not exist yet, an error is thrown. 55 | \itemize{ 56 | \item \code{extract_preprocessor()} returns the formula, recipe, or variable 57 | expressions used for preprocessing. 58 | \item \code{extract_spec_parsnip()} returns the parsnip model specification. 59 | \item \code{extract_fit_parsnip()} returns the parsnip model fit object. 60 | \item \code{extract_fit_engine()} returns the engine specific fit embedded within 61 | a parsnip model fit. For example, when using \code{\link[parsnip:linear_reg]{parsnip::linear_reg()}} 62 | with the \code{"lm"} engine, this returns the underlying \code{lm} object. 63 | \item \code{extract_mold()} returns the preprocessed "mold" object returned 64 | from \code{\link[hardhat:mold]{hardhat::mold()}}. It contains information about the preprocessing, 65 | including either the prepped recipe, the formula terms object, or 66 | variable selectors. 
67 | \item \code{extract_recipe()} returns the recipe. The \code{estimated} argument specifies 68 | whether the fitted or original recipe is returned. 69 | \item \code{extract_workflow_set_result()} returns the results of \code{\link[=workflow_map]{workflow_map()}} 70 | for a particular workflow. 71 | \item \code{extract_workflow()} returns the workflow object. The workflow will not 72 | have been estimated. 73 | \item \code{extract_parameter_set_dials()} returns the parameter set 74 | \emph{that will be used to fit} the supplied row \code{id} of the workflow set. 75 | Note that workflow sets reference a parameter set associated with the 76 | \code{workflow} contained in the \code{info} column by default, but can be 77 | fitted with a modified parameter set via the \code{\link[=option_add]{option_add()}} interface. 78 | This extractor returns the latter, if it exists, and returns the former 79 | if not, mirroring the process that \code{\link[=workflow_map]{workflow_map()}} follows to provide 80 | tuning functions a parameter set. 81 | \item \code{extract_parameter_dials()} returns the \code{parameters} object 82 | \emph{that will be used to fit} the supplied tuning \code{parameter} in the supplied 83 | row \code{id} of the workflow set. See the above notes in 84 | \code{extract_parameter_set_dials()} on precedence. 85 | } 86 | } 87 | \details{ 88 | These functions supersede the \verb{pull_*()} functions (e.g., 89 | \code{\link[=extract_workflow_set_result]{extract_workflow_set_result()}}). 90 | } 91 | \note{ 92 | The package supplies two pre-generated workflow sets, \code{two_class_set} 93 | and \code{chi_features_set}, and associated sets of model fits 94 | \code{two_class_res} and \code{chi_features_res}. 95 | 96 | The \verb{two_class_*} objects are based on a binary classification problem 97 | using the \code{two_class_dat} data from the modeldata package. The six 98 | models utilize either a bare formula or a basic recipe utilizing 99 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 100 | logistic regression, or MARS model specification. See \code{?two_class_set} 101 | for source code. 102 | 103 | The \verb{chi_features_*} objects are based on a regression problem using the 104 | \code{Chicago} data from the modeldata package. Each of the three models 105 | utilize a linear regression model specification, with three different 106 | recipes of varying complexity. The objects are meant to approximate the 107 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 108 | \code{?chi_features_set} for source code. 
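A minimal sketch of the precedence described above for extract_parameter_set_dials(), using the packaged two_class_res object and its yj_trans_cart workflow (which tunes min_n); the narrowed range is purely illustrative and assumes the dials package is installed:

library(dials)

# by default, the parameter set comes from the workflow itself
extract_parameter_set_dials(two_class_res, id = "yj_trans_cart")

# a modified parameter set added via option_add() takes precedence
narrower <- extract_parameter_set_dials(two_class_res, id = "yj_trans_cart") |>
  update(min_n = min_n(c(5, 20)))

two_class_res |>
  option_add(param_info = narrower, id = "yj_trans_cart") |>
  extract_parameter_set_dials(id = "yj_trans_cart")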
109 | } 110 | \examples{ 111 | library(tune) 112 | 113 | two_class_res 114 | 115 | extract_workflow_set_result(two_class_res, "none_cart") 116 | 117 | extract_workflow(two_class_res, "none_cart") 118 | } 119 | -------------------------------------------------------------------------------- /man/figures/README-plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/man/figures/README-plot-1.png -------------------------------------------------------------------------------- /man/figures/README-plot-best-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/workflowsets/ba44a8f28ebb519fa796a6199bc0cee505bd2f6e/man/figures/README-plot-best-1.png -------------------------------------------------------------------------------- /man/figures/lifecycle-soft-deprecated.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclesoft-deprecatedsoft-deprecated -------------------------------------------------------------------------------- /man/fit_best.workflow_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/fit_best.R 3 | \name{fit_best.workflow_set} 4 | \alias{fit_best.workflow_set} 5 | \title{Fit a model to the numerically optimal configuration} 6 | \usage{ 7 | \method{fit_best}{workflow_set}(x, metric = NULL, eval_time = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{x}{A \code{\link[=workflow_set]{workflow_set}} object that has been evaluated 11 | with \code{\link[=workflow_map]{workflow_map()}}. Note that the workflow set must have been fitted with 12 | the \link[=option_add]{control option} \code{save_workflow = TRUE}.} 13 | 14 | \item{metric}{A character string giving the metric to rank results by.} 15 | 16 | \item{eval_time}{A single numeric time point where dynamic event time 17 | metrics should be chosen (e.g., the time-dependent ROC curve, etc). The 18 | values should be consistent with the values used to create \code{x}. The \code{NULL} 19 | default will automatically use the first evaluation time used by \code{x}.} 20 | 21 | \item{...}{Additional options to pass to 22 | \link[tune:fit_best]{tune::fit_best}.} 23 | } 24 | \description{ 25 | \code{fit_best()} takes results from tuning many models and fits the workflow 26 | configuration associated with the best performance to the training set. 27 | } 28 | \details{ 29 | This function is a shortcut for the steps needed to fit the 30 | numerically optimal configuration in a fitted workflow set. 31 | The function ranks results, extracts the tuning result pertaining 32 | to the best result, and then again calls \code{fit_best()} (itself a 33 | wrapper) on the tuning result containing the best result. 34 | 35 | In pseudocode: 36 | 37 | \if{html}{\out{
}}\preformatted{rankings <- rank_results(wf_set, metric, select_best = TRUE) 38 | tune_res <- extract_workflow_set_result(wf_set, rankings$wflow_id[1]) 39 | fit_best(tune_res, metric) 40 | }\if{html}{\out{
}} 41 | } 42 | \note{ 43 | The package supplies two pre-generated workflow sets, \code{two_class_set} 44 | and \code{chi_features_set}, and associated sets of model fits 45 | \code{two_class_res} and \code{chi_features_res}. 46 | 47 | The \verb{two_class_*} objects are based on a binary classification problem 48 | using the \code{two_class_dat} data from the modeldata package. The six 49 | models utilize either a bare formula or a basic recipe utilizing 50 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 51 | logistic regression, or MARS model specification. See \code{?two_class_set} 52 | for source code. 53 | 54 | The \verb{chi_features_*} objects are based on a regression problem using the 55 | \code{Chicago} data from the modeldata package. Each of the three models 56 | utilize a linear regression model specification, with three different 57 | recipes of varying complexity. The objects are meant to approximate the 58 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 59 | \code{?chi_features_set} for source code. 60 | } 61 | \examples{ 62 | \dontshow{if (rlang::is_installed(c("kknn", "modeldata", "recipes", "yardstick", "dials")) && identical(Sys.getenv("NOT_CRAN"), "true")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 63 | library(tune) 64 | library(modeldata) 65 | library(rsample) 66 | 67 | data(Chicago) 68 | Chicago <- Chicago[1:1195, ] 69 | 70 | time_val_split <- 71 | sliding_period( 72 | Chicago, 73 | date, 74 | "month", 75 | lookback = 38, 76 | assess_stop = 1 77 | ) 78 | 79 | chi_features_set 80 | 81 | chi_features_res_new <- 82 | chi_features_set |> 83 | # note: must set `save_workflow = TRUE` to use `fit_best()` 84 | option_add(control = control_grid(save_workflow = TRUE)) |> 85 | # evaluate with resamples 86 | workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) 87 | 88 | chi_features_res_new 89 | 90 | # sort models by performance metrics 91 | rank_results(chi_features_res_new) 92 | 93 | # fit the numerically optimal configuration to the training set 94 | chi_features_wf <- fit_best(chi_features_res_new) 95 | 96 | chi_features_wf 97 | 98 | # to select optimal value based on a specific metric: 99 | fit_best(chi_features_res_new, metric = "rmse") 100 | \dontshow{\}) # examplesIf} 101 | } 102 | -------------------------------------------------------------------------------- /man/leave_var_out_formulas.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/leave_var_out_formulas.R 3 | \name{leave_var_out_formulas} 4 | \alias{leave_var_out_formulas} 5 | \title{Create formulas without each predictor} 6 | \usage{ 7 | leave_var_out_formulas(formula, data, full_model = TRUE, ...) 8 | } 9 | \arguments{ 10 | \item{formula}{A model formula that contains at least two predictors.} 11 | 12 | \item{data}{A data frame.} 13 | 14 | \item{full_model}{A logical; should the list include the original formula?} 15 | 16 | \item{...}{Options to pass to \code{\link[stats:model.frame]{stats::model.frame()}}} 17 | } 18 | \value{ 19 | A named list of formulas 20 | } 21 | \description{ 22 | From an initial model formula, create a list of formulas that exclude 23 | each predictor. 24 | } 25 | \details{ 26 | The new formulas obey the hierarchy rule so that interactions 27 | without main effects are not included (unless the original formula contains 28 | such terms). 
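Because the return value is a named list of formulas, it can be handed directly to workflow_set() as the preproc argument; a minimal sketch, assuming the parsnip and modeldata packages are installed (the linear regression model choice is illustrative):

library(parsnip)

data(penguins, package = "modeldata")

formulas <- leave_var_out_formulas(bill_length_mm ~ ., data = penguins)
names(formulas)

# one workflow per leave-one-variable-out formula
workflow_set(
  preproc = formulas,
  models = list(lm = linear_reg())
)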
29 | 30 | Factor predictors are left as-is (i.e., no indicator variables are created). 31 | } 32 | \examples{ 33 | \dontshow{if (rlang::is_installed("modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} 34 | data(penguins, package = "modeldata") 35 | 36 | leave_var_out_formulas( 37 | bill_length_mm ~ ., 38 | data = penguins 39 | ) 40 | 41 | leave_var_out_formulas( 42 | bill_length_mm ~ (island + sex)^2 + flipper_length_mm, 43 | data = penguins 44 | ) 45 | 46 | leave_var_out_formulas( 47 | bill_length_mm ~ (island + sex)^2 + flipper_length_mm + 48 | I(flipper_length_mm^2), 49 | data = penguins 50 | ) 51 | \dontshow{\}) # examplesIf} 52 | } 53 | \seealso{ 54 | \code{\link[=workflow_set]{workflow_set()}} 55 | } 56 | -------------------------------------------------------------------------------- /man/option_add.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/options.R 3 | \name{option_add} 4 | \alias{option_add} 5 | \alias{option_remove} 6 | \alias{option_add_parameters} 7 | \title{Add and edit options saved in a workflow set} 8 | \usage{ 9 | option_add(x, ..., id = NULL, strict = FALSE) 10 | 11 | option_remove(x, ...) 12 | 13 | option_add_parameters(x, id = NULL, strict = FALSE) 14 | } 15 | \arguments{ 16 | \item{x}{A workflow set outputted by \code{\link[=workflow_set]{workflow_set()}} or \code{\link[=workflow_map]{workflow_map()}}.} 17 | 18 | \item{...}{Arguments to pass to the \verb{tune_*()} functions (e.g. 19 | \code{\link[tune:tune_grid]{tune::tune_grid()}}) or \code{\link[tune:fit_resamples]{tune::fit_resamples()}}. For \code{option_remove()} this 20 | can be a series of unquoted option names.} 21 | 22 | \item{id}{A character string of one or more values from the \code{wflow_id} 23 | column that indicates which options to update. By default, all workflows 24 | are updated.} 25 | 26 | \item{strict}{A logical; should execution stop if existing options are being 27 | replaced?} 28 | } 29 | \value{ 30 | An updated workflow set. 31 | } 32 | \description{ 33 | The \code{option} column controls options for the functions that are used to 34 | \emph{evaluate} the workflow set, such as \code{\link[tune:fit_resamples]{tune::fit_resamples()}} or 35 | \code{\link[tune:tune_grid]{tune::tune_grid()}}. Examples of common options to set for these functions 36 | include \code{param_info} and \code{grid}. 37 | 38 | These functions are helpful for manipulating the information in the \code{option} 39 | column. 40 | } 41 | \details{ 42 | \code{option_add()} is used to update all of the options in a workflow set. 43 | 44 | \code{option_remove()} will eliminate specific options across rows. 45 | 46 | \code{option_add_parameters()} adds a parameter object to the \code{option} column 47 | (if parameters are being tuned). 48 | 49 | Note that executing a function on the workflow set, such as \code{tune_grid()}, 50 | will add any options given to that function to the \code{option} column. 51 | 52 | These functions do \emph{not} control options for the individual workflows, such as 53 | the recipe blueprint. When creating a workflow manually, use 54 | \code{\link[workflows:add_model]{workflows::add_model()}} or \code{\link[workflows:add_recipe]{workflows::add_recipe()}} to specify 55 | extra options. 
To alter these in a workflow set, use 56 | \code{\link[=update_workflow_model]{update_workflow_model()}} or \code{\link[=update_workflow_recipe]{update_workflow_recipe()}}. 57 | } 58 | \examples{ 59 | library(tune) 60 | 61 | two_class_set 62 | 63 | two_class_set |> 64 | option_add(grid = 10) 65 | 66 | two_class_set |> 67 | option_add(grid = 10) |> 68 | option_add(grid = 50, id = "none_cart") 69 | 70 | two_class_set |> 71 | option_add_parameters() 72 | } 73 | -------------------------------------------------------------------------------- /man/option_list.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/options.R 3 | \name{option_list} 4 | \alias{option_list} 5 | \title{Make a classed list of options} 6 | \usage{ 7 | option_list(...) 8 | } 9 | \arguments{ 10 | \item{...}{A set of named options (or nothing)} 11 | } 12 | \value{ 13 | A classed list. 14 | } 15 | \description{ 16 | This function returns a named list with an extra class of 17 | \code{"workflow_set_options"} that has corresponding formatting methods for 18 | printing inside of tibbles. 19 | } 20 | \examples{ 21 | option_list(a = 1, b = 2) 22 | option_list() 23 | } 24 | -------------------------------------------------------------------------------- /man/pull_workflow_set_result.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/pull.R 3 | \name{pull_workflow_set_result} 4 | \alias{pull_workflow_set_result} 5 | \alias{pull_workflow} 6 | \title{Extract elements from a workflow set} 7 | \usage{ 8 | pull_workflow_set_result(x, id) 9 | 10 | pull_workflow(x, id) 11 | } 12 | \arguments{ 13 | \item{x}{A workflow set outputted by \code{\link[=workflow_set]{workflow_set()}} or \code{\link[=workflow_map]{workflow_map()}}.} 14 | 15 | \item{id}{A single character string for a workflow ID.} 16 | } 17 | \value{ 18 | \code{pull_workflow_set_result()} produces a \code{tune_result} or 19 | \code{resample_results} object. \code{pull_workflow()} returns an unfit workflow 20 | object. 21 | } 22 | \description{ 23 | \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} 24 | } 25 | \details{ 26 | \code{pull_workflow_set_result()} retrieves the results of \code{\link[=workflow_map]{workflow_map()}} for a 27 | particular workflow while \code{pull_workflow()} extracts the unfitted workflow 28 | from the \code{info} column. 29 | 30 | The \code{\link[=extract_workflow_set_result]{extract_workflow_set_result()}} and \code{\link[=extract_workflow]{extract_workflow()}} functions should 31 | be used instead of these functions. 
32 | } 33 | -------------------------------------------------------------------------------- /man/rank_results.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/rank_results.R 3 | \name{rank_results} 4 | \alias{rank_results} 5 | \title{Rank the results by a metric} 6 | \usage{ 7 | rank_results(x, rank_metric = NULL, eval_time = NULL, select_best = FALSE) 8 | } 9 | \arguments{ 10 | \item{x}{A \code{\link[=workflow_set]{workflow_set}} object that has been evaluated 11 | with \code{\link[=workflow_map]{workflow_map()}}.} 12 | 13 | \item{rank_metric}{A character string for a metric.} 14 | 15 | \item{eval_time}{A single numeric time point where dynamic event time 16 | metrics should be chosen (e.g., the time-dependent ROC curve, etc). The 17 | values should be consistent with the values used to create \code{x}. The \code{NULL} 18 | default will automatically use the first evaluation time used by \code{x}.} 19 | 20 | \item{select_best}{A logical giving whether the results should only contain 21 | the numerically best submodel per workflow.} 22 | } 23 | \value{ 24 | A tibble with columns: \code{wflow_id}, \code{.config}, \code{.metric}, \code{mean}, 25 | \code{std_err}, \code{n}, \code{preprocessor}, \code{model}, and \code{rank}. 26 | } 27 | \description{ 28 | This function sorts the results by a specific performance metric. 29 | } 30 | \details{ 31 | If some models have the exact same performance, 32 | \code{rank(value, ties.method = "random")} is used (with a reproducible seed) so 33 | that all ranks are integers. 34 | 35 | No columns are returned for the tuning parameters since they are likely to 36 | be different (or not exist) for some models. The \code{wflow_id} and \code{.config} 37 | columns can be used to determine the corresponding parameter values. 38 | } 39 | \note{ 40 | The package supplies two pre-generated workflow sets, \code{two_class_set} 41 | and \code{chi_features_set}, and associated sets of model fits 42 | \code{two_class_res} and \code{chi_features_res}. 43 | 44 | The \verb{two_class_*} objects are based on a binary classification problem 45 | using the \code{two_class_dat} data from the modeldata package. The six 46 | models utilize either a bare formula or a basic recipe utilizing 47 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 48 | logistic regression, or MARS model specification. See \code{?two_class_set} 49 | for source code. 50 | 51 | The \verb{chi_features_*} objects are based on a regression problem using the 52 | \code{Chicago} data from the modeldata package. Each of the three models 53 | utilize a linear regression model specification, with three different 54 | recipes of varying complexity. The objects are meant to approximate the 55 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 56 | \code{?chi_features_set} for source code. 
57 | } 58 | \examples{ 59 | chi_features_res 60 | 61 | rank_results(chi_features_res) 62 | rank_results(chi_features_res, select_best = TRUE) 63 | rank_results(chi_features_res, rank_metric = "rsq") 64 | } 65 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/0_imports.R, R/fit_best.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{collect_metrics} 7 | \alias{collect_predictions} 8 | \alias{collect_notes} 9 | \alias{collect_extracts} 10 | \alias{\%>\%} 11 | \alias{autoplot} 12 | \alias{extract_spec_parsnip} 13 | \alias{extract_recipe} 14 | \alias{extract_fit_parsnip} 15 | \alias{extract_fit_engine} 16 | \alias{extract_mold} 17 | \alias{extract_preprocessor} 18 | \alias{extract_workflow} 19 | \alias{extract_parameter_set_dials} 20 | \alias{extract_parameter_dials} 21 | \alias{fit_best} 22 | \title{Objects exported from other packages} 23 | \keyword{internal} 24 | \description{ 25 | These objects are imported from other packages. Follow the links 26 | below to see their documentation. 27 | 28 | \describe{ 29 | \item{dplyr}{\code{\link[dplyr:reexports]{\%>\%}}} 30 | 31 | \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} 32 | 33 | \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_parsnip}}, \code{\link[hardhat:hardhat-extract]{extract_mold}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_preprocessor}}, \code{\link[hardhat:hardhat-extract]{extract_recipe}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat:hardhat-extract]{extract_workflow}}} 34 | 35 | \item{tune}{\code{\link[tune:collect_predictions]{collect_extracts}}, \code{\link[tune:collect_predictions]{collect_metrics}}, \code{\link[tune:collect_predictions]{collect_notes}}, \code{\link[tune]{collect_predictions}}, \code{\link[tune]{fit_best}}} 36 | }} 37 | 38 | -------------------------------------------------------------------------------- /man/two_class_set.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{two_class_set} 5 | \alias{two_class_set} 6 | \alias{two_class_res} 7 | \title{Two Class Example Data} 8 | \description{ 9 | The package supplies two pre-generated workflow sets, \code{two_class_set} 10 | and \code{chi_features_set}, and associated sets of model fits 11 | \code{two_class_res} and \code{chi_features_res}. 12 | 13 | The \verb{two_class_*} objects are based on a binary classification problem 14 | using the \code{two_class_dat} data from the modeldata package. The six 15 | models utilize either a bare formula or a basic recipe utilizing 16 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 17 | logistic regression, or MARS model specification. See \code{?two_class_set} 18 | for source code. 19 | 20 | The \verb{chi_features_*} objects are based on a regression problem using the 21 | \code{Chicago} data from the modeldata package. Each of the three models 22 | utilize a linear regression model specification, with three different 23 | recipes of varying complexity. 
The objects are meant to approximate the 24 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 25 | \code{?chi_features_set} for source code. 26 | } 27 | \details{ 28 | See below for the source code to generate the Two Class example workflow 29 | sets: 30 | 31 | \if{html}{\out{
}}\preformatted{library(workflowsets) 32 | library(workflows) 33 | library(modeldata) 34 | library(recipes) 35 | library(parsnip) 36 | library(dplyr) 37 | library(rsample) 38 | library(tune) 39 | library(yardstick) 40 | 41 | # ------------------------------------------------------------------------------ 42 | 43 | data(two_class_dat, package = "modeldata") 44 | 45 | set.seed(1) 46 | folds <- vfold_cv(two_class_dat, v = 5) 47 | 48 | # ------------------------------------------------------------------------------ 49 | 50 | decision_tree_rpart_spec <- 51 | decision_tree(min_n = tune(), cost_complexity = tune()) |> 52 | set_engine('rpart') |> 53 | set_mode('classification') 54 | 55 | logistic_reg_glm_spec <- 56 | logistic_reg() |> 57 | set_engine('glm') 58 | 59 | mars_earth_spec <- 60 | mars(prod_degree = tune()) |> 61 | set_engine('earth') |> 62 | set_mode('classification') 63 | 64 | # ------------------------------------------------------------------------------ 65 | 66 | yj_recipe <- 67 | recipe(Class ~ ., data = two_class_dat) |> 68 | step_YeoJohnson(A, B) 69 | 70 | # ------------------------------------------------------------------------------ 71 | 72 | two_class_set <- 73 | workflow_set( 74 | preproc = list(none = Class ~ A + B, yj_trans = yj_recipe), 75 | models = list(cart = decision_tree_rpart_spec, glm = logistic_reg_glm_spec, 76 | mars = mars_earth_spec) 77 | ) 78 | 79 | # ------------------------------------------------------------------------------ 80 | 81 | two_class_res <- 82 | two_class_set |> 83 | workflow_map( 84 | resamples = folds, 85 | grid = 10, 86 | seed = 2, 87 | verbose = TRUE, 88 | control = control_grid(save_workflow = TRUE) 89 | ) 90 | }\if{html}{\out{
}} 91 | } 92 | \examples{ 93 | data(two_class_set) 94 | 95 | two_class_set 96 | } 97 | \keyword{datasets} 98 | -------------------------------------------------------------------------------- /man/update_workflow_model.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/update.R 3 | \name{update_workflow_model} 4 | \alias{update_workflow_model} 5 | \alias{update_workflow_recipe} 6 | \title{Update components of a workflow within a workflow set} 7 | \usage{ 8 | update_workflow_model(x, id, spec, formula = NULL) 9 | 10 | update_workflow_recipe(x, id, recipe, blueprint = NULL) 11 | } 12 | \arguments{ 13 | \item{x}{A workflow set outputted by \code{\link[=workflow_set]{workflow_set()}} or \code{\link[=workflow_map]{workflow_map()}}.} 14 | 15 | \item{id}{A single character string from the \code{wflow_id} column indicating 16 | which workflow to update.} 17 | 18 | \item{spec}{A parsnip model specification.} 19 | 20 | \item{formula}{An optional formula override to specify the terms of the 21 | model. Typically, the terms are extracted from the formula or recipe 22 | preprocessing methods. However, some models (like survival and bayesian 23 | models) use the formula not to preprocess, but to specify the structure 24 | of the model. In those cases, a formula specifying the model structure 25 | must be passed unchanged into the model call itself. This argument is 26 | used for those purposes.} 27 | 28 | \item{recipe}{A recipe created using \code{\link[recipes:recipe]{recipes::recipe()}}. The recipe 29 | should not have been trained already with \code{\link[recipes:prep]{recipes::prep()}}; workflows 30 | will handle training internally.} 31 | 32 | \item{blueprint}{A hardhat blueprint used for fine tuning the preprocessing. 33 | 34 | If \code{NULL}, \code{\link[hardhat:default_recipe_blueprint]{hardhat::default_recipe_blueprint()}} is used. 35 | 36 | Note that preprocessing done here is separate from preprocessing that 37 | might be done automatically by the underlying model.} 38 | } 39 | \description{ 40 | Workflows can take special arguments for the recipe (e.g. a blueprint) or a 41 | model (e.g. a special formula). However, when creating a workflow set, there 42 | is no way to specify these extra components. 43 | 44 | \code{update_workflow_model()} and \code{update_workflow_recipe()} allow users to set 45 | these values \emph{after} the workflow set is initially created. They are 46 | analogous to \code{\link[workflows:add_model]{workflows::add_model()}} or \code{\link[workflows:add_recipe]{workflows::add_recipe()}}. 47 | } 48 | \note{ 49 | The package supplies two pre-generated workflow sets, \code{two_class_set} 50 | and \code{chi_features_set}, and associated sets of model fits 51 | \code{two_class_res} and \code{chi_features_res}. 52 | 53 | The \verb{two_class_*} objects are based on a binary classification problem 54 | using the \code{two_class_dat} data from the modeldata package. The six 55 | models utilize either a bare formula or a basic recipe utilizing 56 | \code{recipes::step_YeoJohnson()} as a preprocessor, and a decision tree, 57 | logistic regression, or MARS model specification. See \code{?two_class_set} 58 | for source code. 59 | 60 | The \verb{chi_features_*} objects are based on a regression problem using the 61 | \code{Chicago} data from the modeldata package. 
Each of the three models 62 | utilize a linear regression model specification, with three different 63 | recipes of varying complexity. The objects are meant to approximate the 64 | sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See 65 | \code{?chi_features_set} for source code. 66 | } 67 | \examples{ 68 | library(parsnip) 69 | 70 | new_mod <- 71 | decision_tree() |> 72 | set_engine("rpart", method = "anova") |> 73 | set_mode("classification") 74 | 75 | new_set <- update_workflow_model(two_class_res, "none_cart", spec = new_mod) 76 | 77 | new_set 78 | 79 | extract_workflow(new_set, id = "none_cart") 80 | } 81 | -------------------------------------------------------------------------------- /man/workflowsets-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/0_imports.R 3 | \docType{package} 4 | \name{workflowsets-package} 5 | \alias{workflowsets} 6 | \alias{workflowsets-package} 7 | \title{workflowsets: Create a Collection of 'tidymodels' Workflows} 8 | \description{ 9 | A workflow is a combination of a model and preprocessors (e.g, a formula, recipe, etc.) (Kuhn and Silge (2021) \url{https://www.tmwr.org/}). In order to try different combinations of these, an object can be created that contains many workflows. There are functions to create workflows en masse as well as training them and visualizing the results. 10 | } 11 | \seealso{ 12 | Useful links: 13 | \itemize{ 14 | \item \url{https://github.com/tidymodels/workflowsets} 15 | \item \url{https://workflowsets.tidymodels.org} 16 | \item Report bugs at \url{https://github.com/tidymodels/workflowsets/issues} 17 | } 18 | 19 | } 20 | \author{ 21 | \strong{Maintainer}: Simon Couch \email{simon.couch@posit.co} (\href{https://orcid.org/0000-0001-5676-5107}{ORCID}) 22 | 23 | Authors: 24 | \itemize{ 25 | \item Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) 26 | } 27 | 28 | Other contributors: 29 | \itemize{ 30 | \item Posit Software, PBC (03wc8by49) [copyright holder, funder] 31 | } 32 | 33 | } 34 | \keyword{internal} 35 | -------------------------------------------------------------------------------- /tests/spelling.R: -------------------------------------------------------------------------------- 1 | if (requireNamespace("spelling", quietly = TRUE)) { 2 | spelling::spell_check_test( 3 | vignettes = TRUE, 4 | error = FALSE, 5 | skip_on_cran = TRUE 6 | ) 7 | } 8 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(workflowsets) 3 | 4 | test_check("workflowsets") 5 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/autoplot.md: -------------------------------------------------------------------------------- 1 | # autoplot with bad type input 2 | 3 | Code 4 | autoplot(two_class_res, metric = "roc_auc", type = "banana") 5 | Condition 6 | Error in `autoplot()`: 7 | ! `type` must be one of "class" or "wflow_id", not "banana". 8 | 9 | # automatic selection of rank metric 10 | 11 | Code 12 | pick_metric(two_class_res, "roc_auc", "accuracy") 13 | Condition 14 | Error: 15 | ! Metric "roc_auc" was not in the results. 
16 | 17 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/checks.md: -------------------------------------------------------------------------------- 1 | # wf set check works 2 | 3 | Code 4 | check_wf_set("no") 5 | Condition 6 | Error: 7 | ! "no" must be a workflow set, not the string "no". 8 | 9 | --- 10 | 11 | Code 12 | check_wf_set(data.frame()) 13 | Condition 14 | Error: 15 | ! data.frame() must be a workflow set, not a object. 16 | 17 | --- 18 | 19 | Code 20 | rank_results("beeEeEEp!") 21 | Condition 22 | Error in `rank_results()`: 23 | ! x must be a workflow set, not the string "beeEeEEp!". 24 | 25 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/collect-extracts.md: -------------------------------------------------------------------------------- 1 | # collect_extracts fails gracefully without .extracts column 2 | 3 | Code 4 | collect_extracts(wflow_set_trained) 5 | Condition 6 | Error in `dplyr::mutate()`: 7 | i In argument: `extracts = list(collect_extracts(result))`. 8 | i In row 1. 9 | Caused by error in `collect_extracts()`: 10 | ! The `.extracts` column does not exist. 11 | i Please supply a control object (`?tune::control_grid()`) with a non-`NULL` `extract` argument during resample fitting. 12 | 13 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/collect-notes.md: -------------------------------------------------------------------------------- 1 | # collect_notes works 2 | 3 | Code 4 | collect_notes(wflow_set) 5 | Condition 6 | Error in `collect_notes()`: 7 | ! There were 2 workflows that had no results. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/comments.md: -------------------------------------------------------------------------------- 1 | # test comments 2 | 3 | Code 4 | comment_add(two_class_set, "toe", "foot") 5 | Condition 6 | Error in `comment_add()`: 7 | ! The `id` value is not in `wflow_id`. 8 | 9 | --- 10 | 11 | Code 12 | comment_add(two_class_set, letters, "foot") 13 | Condition 14 | Error in `comment_add()`: 15 | ! `id` must be a single string, not a character vector. 16 | 17 | --- 18 | 19 | Code 20 | comment_add(two_class_set, 1:2, "foot") 21 | Condition 22 | Error in `comment_add()`: 23 | ! `id` must be a single string, not an integer vector. 24 | 25 | --- 26 | 27 | Code 28 | comment_add(two_class_set, "none_cart", 1:2) 29 | Condition 30 | Error in `comment_add()`: 31 | ! The comments should be character strings. 32 | 33 | --- 34 | 35 | Code 36 | comment_add(comments_1, "none_cart", "Stuff.", append = FALSE) 37 | Condition 38 | Error in `comment_add()`: 39 | ! There is already a comment for this id and `append = FALSE`. 40 | 41 | --- 42 | 43 | Code 44 | comment_get(comments_1, id = letters) 45 | Condition 46 | Error in `comment_get()`: 47 | ! `id` should be a single character value. 48 | 49 | --- 50 | 51 | Code 52 | comment_get(comments_1, id = letters[1]) 53 | Condition 54 | Error in `comment_get()`: 55 | ! The `id` value is not in `wflow_id`. 56 | 57 | --- 58 | 59 | Code 60 | comment_reset(comments_1, letters) 61 | Condition 62 | Error in `comment_reset()`: 63 | ! `id` should be a single character value. 64 | 65 | --- 66 | 67 | Code 68 | comment_reset(comments_1, "none_carts") 69 | Condition 70 | Error in `comment_reset()`: 71 | ! The `id` value is not in `wflow_id`. 
72 | 73 | # print comments 74 | 75 | Code 76 | comment_print(test) 77 | Output 78 | -- none_cart ------------------------------------------------------------------- 79 | 80 | "Whenever you feel like criticizing any one," he told me, "just 81 | remember that all the people in this world haven't had the advantages 82 | that you've had." 83 | 84 | My family have been prominent, well-to-do people in this middle-western 85 | city for three generations. The Carraways are something of a clan and 86 | we have a tradition that we're descended from the Dukes of Buccleuch, 87 | but the actual founder of my line was my grandfather's brother who came 88 | here in fifty-one, sent a substitute to the Civil War and started the 89 | wholesale hardware business that my father carries on today. 90 | 91 | -- none_glm -------------------------------------------------------------------- 92 | 93 | Across the courtesy bay the white palaces of fashionable East Egg 94 | glittered along the water, and the history of the summer really begins 95 | on the evening I drove over there to have dinner with the Tom 96 | Buchanans. Daisy was my second cousin once removed and I'd known Tom in 97 | college. And just after the war I spent two days with them in Chicago. 98 | 99 | 100 | --- 101 | 102 | Code 103 | comment_print(test, id = "none_glm") 104 | Output 105 | -- none_glm -------------------------------------------------------------------- 106 | 107 | Across the courtesy bay the white palaces of fashionable East Egg 108 | glittered along the water, and the history of the summer really begins 109 | on the evening I drove over there to have dinner with the Tom 110 | Buchanans. Daisy was my second cousin once removed and I'd known Tom in 111 | college. And just after the war I spent two days with them in Chicago. 112 | 113 | 114 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/extract.md: -------------------------------------------------------------------------------- 1 | # extracts 2 | 3 | Code 4 | extract_fit_engine(car_set_1, id = "reg_lm") 5 | Condition 6 | Error in `extract_fit_parsnip()`: 7 | ! Can't extract a model fit from an untrained workflow. 8 | i Do you need to call `fit()`? 9 | 10 | --- 11 | 12 | Code 13 | extract_fit_parsnip(car_set_1, id = "reg_lm") 14 | Condition 15 | Error in `extract_fit_parsnip()`: 16 | ! Can't extract a model fit from an untrained workflow. 17 | i Do you need to call `fit()`? 18 | 19 | --- 20 | 21 | Code 22 | extract_mold(car_set_1, id = "reg_lm") 23 | Condition 24 | Error in `extract_mold()`: 25 | ! Can't extract a mold from an untrained workflow. 26 | i Do you need to call `fit()`? 27 | 28 | --- 29 | 30 | Code 31 | extract_recipe(car_set_1, id = "reg_lm") 32 | Condition 33 | Error in `extract_mold()`: 34 | ! Can't extract a mold from an untrained workflow. 35 | i Do you need to call `fit()`? 36 | 37 | --- 38 | 39 | Code 40 | extract_workflow_set_result(car_set_1, "Gideon Nav") 41 | Condition 42 | Error in `extract_workflow_set_result()`: 43 | ! `id` must correspond to a single row in `x`. 44 | 45 | --- 46 | 47 | Code 48 | extract_workflow(car_set_1, "Coronabeth Tridentarius") 49 | Condition 50 | Error in `extract_workflow()`: 51 | ! `id` must correspond to a single row in `x`. 52 | 53 | # extract single parameter from workflow set with untunable workflow 54 | 55 | Code 56 | hardhat::extract_parameter_dials(wf_set, id = "reg_lm", parameter = "non there") 57 | Condition 58 | Error in `extract_parameter_dials()`: 59 | ! 
No parameter exists with id "non there". 60 | 61 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/fit.md: -------------------------------------------------------------------------------- 1 | # fit() errors informatively with workflow sets 2 | 3 | Code 4 | fit(car_set_1) 5 | Condition 6 | Error in `fit()`: 7 | ! `fit()` is not well-defined for workflow sets. 8 | i Please see `workflow_map()` (`?workflowsets::workflow_map()`). 9 | 10 | --- 11 | 12 | Code 13 | fit(car_set_2) 14 | Condition 15 | Error in `fit()`: 16 | ! `fit()` is not well-defined for workflow sets. 17 | i Please see `fit_best()` (`?workflowsets::fit_best.workflow_set()`). 18 | 19 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/fit_best.md: -------------------------------------------------------------------------------- 1 | # fit_best errors informatively with bad inputs 2 | 3 | Code 4 | fit_best(chi_features_res) 5 | Condition 6 | Error in `fit_best()`: 7 | x The control option `save_workflow = TRUE` should be used when tuning. 8 | 9 | --- 10 | 11 | Code 12 | fit_best(chi_features_map, metric = "boop") 13 | Condition 14 | Error in `fit_best()`: 15 | ! "boop" was not in the metric set. Please choose from: "rmse" and "rsq". 16 | 17 | --- 18 | 19 | Code 20 | fit_best(chi_features_map, boop = "bop") 21 | Condition 22 | Error in `fit_best()`: 23 | ! `...` must be empty. 24 | x Problematic argument: 25 | * boop = "bop" 26 | 27 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/leave-var-out-formulas.md: -------------------------------------------------------------------------------- 1 | # LOO var formulas 2 | 3 | Code 4 | leave_var_out_formulas(y ~ 1, data = form_data) 5 | Condition 6 | Error in `leave_var_out_formulas()`: 7 | ! There should be at least 2 predictors in the formula. 8 | 9 | --- 10 | 11 | Code 12 | leave_var_out_formulas(y ~ a, data = form_data) 13 | Condition 14 | Error in `leave_var_out_formulas()`: 15 | ! There should be at least 2 predictors in the formula. 16 | 17 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/options.md: -------------------------------------------------------------------------------- 1 | # check for bad options 2 | 3 | The option `grid2` cannot be used as an argument for `fit_resamples()` or the `tune_*()` functions. 4 | 5 | --- 6 | 7 | The option `blueprint` cannot be used as an argument for `fit_resamples()` or the `tune_*()` functions. 8 | 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/predict.md: -------------------------------------------------------------------------------- 1 | # predict() errors informatively with workflow sets 2 | 3 | Code 4 | predict(car_set_1) 5 | Condition 6 | Error in `predict()`: 7 | ! `predict()` is not well-defined for workflow sets. 8 | i To predict with the optimal model configuration from a workflow set, ensure that the workflow set was fitted with the control option (`?workflowsets::option_add()`) `save_workflow = TRUE` (`?tune::control_grid()`), run `fit_best()` (`?tune::fit_best()`), and then predict using `predict()` (`?workflows::predict.workflow()`) on its output. 
9 | i To collect predictions from a workflow set, ensure that the workflow set was fitted with the control option (`?workflowsets::option_add()`) `save_pred = TRUE` (`?tune::control_grid()`) and run `collect_predictions()` (`?tune::collect_predictions()`). 10 | 11 | --- 12 | 13 | Code 14 | predict(car_set_2) 15 | Condition 16 | Error in `predict()`: 17 | ! `predict()` is not well-defined for workflow sets. 18 | i To predict with the optimal model configuration from a workflow set, ensure that the workflow set was fitted with the control option (`?workflowsets::option_add()`) `save_workflow = TRUE` (`?tune::control_grid()`), run `fit_best()` (`?tune::fit_best()`), and then predict using `predict()` (`?workflows::predict.workflow()`) on its output. 19 | i To collect predictions from a workflow set, ensure that the workflow set was fitted with the control option (`?workflowsets::option_add()`) `save_pred = TRUE` (`?tune::control_grid()`) and run `collect_predictions()` (`?tune::collect_predictions()`). 20 | 21 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/pull.md: -------------------------------------------------------------------------------- 1 | # pulling objects is deprecated 2 | 3 | Code 4 | pull_workflow_set_result(car_set_1, "reg_lm") 5 | Condition 6 | Error: 7 | ! `pull_workflow_set_result()` was deprecated in workflowsets 0.1.0 and is now defunct. 8 | i Please use `extract_workflow_set_result()` instead. 9 | 10 | --- 11 | 12 | Code 13 | pull_workflow(car_set_1, "reg_lm") 14 | Condition 15 | Error: 16 | ! `pull_workflow()` was deprecated in workflowsets 0.1.0 and is now defunct. 17 | i Please use `extract_workflow()` instead. 18 | 19 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/workflow-map.md: -------------------------------------------------------------------------------- 1 | # basic mapping 2 | 3 | Code 4 | workflow_map(two_class_set, "foo", seed = 1, resamples = folds, grid = 2) 5 | Condition 6 | Error in `workflow_map()`: 7 | ! `fn` must be one of "tune_grid", "tune_bayes", "fit_resamples", "tune_race_anova", "tune_race_win_loss", "tune_sim_anneal", or "tune_cluster", not "foo". 8 | 9 | --- 10 | 11 | Code 12 | workflow_map(two_class_set, fn = 1L, seed = 1, resamples = folds, grid = 2) 13 | Condition 14 | Error in `workflow_map()`: 15 | ! `fn` must be a character vector, not the number 1. 16 | 17 | --- 18 | 19 | Code 20 | workflow_map(two_class_set, fn = tune::tune_grid, seed = 1, resamples = folds, 21 | grid = 2) 22 | Condition 23 | Error in `workflow_map()`: 24 | ! `fn` must be a character vector, not a function. 25 | 26 | # map logging 27 | 28 | Code 29 | cat(logging_res, sep = "\n") 30 | Output 31 | i No tuning parameters. `fit_resamples()` will be attempted 32 | i 1 of 3 resampling: reg_lm 33 | i 2 of 3 tuning: reg_knn 34 | i No tuning parameters. `fit_resamples()` will be attempted 35 | i 3 of 3 resampling: nonlin_lm 36 | 37 | # failers 38 | 39 | Code 40 | res_loud <- workflow_map(car_set_3, resamples = folds, seed = 2, verbose = TRUE, 41 | grid = "a") 42 | Message 43 | i 1 of 2 tuning: reg_knn 44 | x 1 of 2 tuning: reg_knn failed with: Error in check_grid(grid = grid, workflow = workflow, pset = pset) : `grid` should be a positive integer or a data frame. 45 | i No tuning parameters. 
`fit_resamples()` will be attempted 46 | i 2 of 2 resampling: reg_lm 47 | v 2 of 2 resampling: reg_lm (ms) 48 | 49 | # fail informatively on mismatched spec/tuning function 50 | 51 | Code 52 | workflow_map(wf_set_1, resamples = folds) 53 | Condition 54 | Error in `workflow_map()`: 55 | ! To tune with `tune_grid()`, each workflow's model specification must inherit from , but `reg_km` does not. 56 | i The workflow `reg_km` is a cluster specification. Did you intend to set `fn = 'tune_cluster'`? 57 | 58 | --- 59 | 60 | Code 61 | workflow_map(wf_set_2, resamples = folds) 62 | Condition 63 | Error in `workflow_map()`: 64 | ! To tune with `tune_grid()`, each workflow's model specification must inherit from , but `reg_km` and `reg_hc` do not. 65 | i The workflows `reg_km` and `reg_hc` are cluster specifications. Did you intend to set `fn = 'tune_cluster'`? 66 | 67 | --- 68 | 69 | Code 70 | workflow_map(wf_set_1, resamples = folds, fn = "tune_cluster") 71 | Condition 72 | Error in `workflow_map()`: 73 | ! To tune with `tune_cluster()`, each workflow's model specification must inherit from , but `reg_dt` does not. 74 | 75 | --- 76 | 77 | Code 78 | workflow_map(wf_set_3, resamples = folds, fn = "tune_cluster") 79 | Condition 80 | Error in `workflow_map()`: 81 | ! To tune with `tune_cluster()`, each workflow's model specification must inherit from , but `reg_dt` and `reg_nn` do not. 82 | 83 | -------------------------------------------------------------------------------- /tests/testthat/helper-compat.R: -------------------------------------------------------------------------------- 1 | workflow_set_objects <- list( 2 | unfit = chi_features_set, 3 | fit = chi_features_res 4 | ) 5 | 6 | expect_s3_class_workflow_set <- function(x) { 7 | expect_s3_class(x, "workflow_set") 8 | } 9 | 10 | expect_s3_class_bare_tibble <- function(x) { 11 | expect_s3_class(x, c("tbl_df", "tbl", "data.frame"), exact = TRUE) 12 | } 13 | -------------------------------------------------------------------------------- /tests/testthat/helper-extract_parameter_set.R: -------------------------------------------------------------------------------- 1 | check_parameter_set_tibble <- function(x) { 2 | expect_equal( 3 | names(x), 4 | c("name", "id", "source", "component", "component_id", "object") 5 | ) 6 | expect_equal(class(x$name), "character") 7 | expect_equal(class(x$id), "character") 8 | expect_equal(class(x$source), "character") 9 | expect_equal(class(x$component), "character") 10 | expect_equal(class(x$component_id), "character") 11 | expect_true(!any(duplicated(x$id))) 12 | 13 | expect_equal(class(x$object), "list") 14 | obj_check <- purrr::map_lgl( 15 | x$object, 16 | \(.x) inherits(.x, "param") | all(is.na(.x)) 17 | ) 18 | expect_true(all(obj_check)) 19 | 20 | invisible(TRUE) 21 | } 22 | -------------------------------------------------------------------------------- /tests/testthat/test-autoplot.R: -------------------------------------------------------------------------------- 1 | test_that("autoplot with error bars (class)", { 2 | p_1 <- autoplot(two_class_res, metric = "roc_auc") 3 | expect_s3_class(p_1, "ggplot") 4 | expect_equal( 5 | names(p_1$data), 6 | c( 7 | "wflow_id", 8 | ".config", 9 | ".metric", 10 | "mean", 11 | "std_err", 12 | "n", 13 | "preprocessor", 14 | "model", 15 | "rank" 16 | ) 17 | ) 18 | expect_equal(rlang::get_expr(p_1$mapping$x), expr(rank)) 19 | expect_equal(rlang::get_expr(p_1$mapping$y), expr(mean)) 20 | expect_equal(rlang::get_expr(p_1$mapping$colour), expr(model)) 21 | 
expect_equal(as.list(p_1$facet)$params, list()) 22 | expect_equal( 23 | rlang::get_expr(as.list(p_1$layers[[2]])$mapping$ymin), 24 | expr(mean - std_errs * std_err) 25 | ) 26 | expect_equal( 27 | rlang::get_expr(as.list(p_1$layers[[2]])$mapping$ymax), 28 | expr(mean + std_errs * std_err) 29 | ) 30 | expect_equal(as.character(p_1$labels$y), "roc_auc") 31 | expect_equal(as.character(p_1$labels$x), "Workflow Rank") 32 | }) 33 | 34 | test_that("autoplot with error bars (wflow_id)", { 35 | p_1 <- autoplot(two_class_res, metric = "roc_auc", type = "wflow_id") 36 | expect_s3_class(p_1, "ggplot") 37 | expect_equal( 38 | names(p_1$data), 39 | c( 40 | "wflow_id", 41 | ".config", 42 | ".metric", 43 | "mean", 44 | "std_err", 45 | "n", 46 | "preprocessor", 47 | "model", 48 | "rank" 49 | ) 50 | ) 51 | expect_equal(rlang::get_expr(p_1$mapping$x), expr(rank)) 52 | expect_equal(rlang::get_expr(p_1$mapping$y), expr(mean)) 53 | expect_equal(rlang::get_expr(p_1$mapping$colour), expr(wflow_id)) 54 | expect_equal( 55 | rlang::get_expr(as.list(p_1$layers[[2]])$mapping$ymin), 56 | expr(mean - std_errs * std_err) 57 | ) 58 | expect_equal( 59 | rlang::get_expr(as.list(p_1$layers[[2]])$mapping$ymax), 60 | expr(mean + std_errs * std_err) 61 | ) 62 | expect_equal(as.character(p_1$labels$y), "roc_auc") 63 | expect_equal(as.character(p_1$labels$x), "Workflow Rank") 64 | }) 65 | 66 | test_that("autoplot with bad type input", { 67 | expect_snapshot( 68 | error = TRUE, 69 | autoplot(two_class_res, metric = "roc_auc", type = "banana") 70 | ) 71 | }) 72 | 73 | 74 | test_that("autoplot with without error bars", { 75 | p_2 <- autoplot(chi_features_res) 76 | expect_s3_class(p_2, "ggplot") 77 | expect_equal( 78 | names(p_2$data), 79 | c( 80 | "wflow_id", 81 | ".config", 82 | ".metric", 83 | "mean", 84 | "std_err", 85 | "n", 86 | "preprocessor", 87 | "model", 88 | "rank" 89 | ) 90 | ) 91 | expect_equal(rlang::get_expr(p_2$mapping$x), expr(rank)) 92 | expect_equal(rlang::get_expr(p_2$mapping$y), expr(mean)) 93 | expect_equal(rlang::get_expr(p_2$mapping$colour), expr(model)) 94 | expect_equal(length(p_2$layers), 1) 95 | expect_equal(names(as.list(p_2$facet)$params$facet), ".metric") 96 | expect_equal(p_2$labels$y, "Metric") 97 | expect_equal(p_2$labels$x, "Workflow Rank") 98 | }) 99 | 100 | test_that("autoplot for specific workflow result", { 101 | p_3 <- autoplot(chi_features_res, id = "plus_pca_lm") 102 | p_4 <- autoplot( 103 | extract_workflow_set_result( 104 | chi_features_res, 105 | id = "plus_pca_lm" 106 | ) 107 | ) 108 | expect_equal(p_3$data, p_4$data) 109 | expect_equal(p_3$labels, p_4$labels) 110 | expect_equal( 111 | purrr::map(as.list(p_3$mapping), rlang::get_expr), 112 | purrr::map(as.list(p_4$mapping), rlang::get_expr) 113 | ) 114 | }) 115 | 116 | test_that("automatic selection of rank metric", { 117 | expect_equal( 118 | pick_metric(two_class_res, NULL, NULL), 119 | list(metric = "roc_auc", direction = "maximize") 120 | ) 121 | expect_equal( 122 | pick_metric(two_class_res, NULL, "accuracy"), 123 | list(metric = "accuracy", direction = "maximize") 124 | ) 125 | expect_equal( 126 | pick_metric(two_class_res, "accuracy"), 127 | list(metric = "accuracy", direction = "maximize") 128 | ) 129 | expect_equal( 130 | pick_metric(two_class_res, "roc_auc"), 131 | list(metric = "roc_auc", direction = "maximize") 132 | ) 133 | expect_snapshot( 134 | error = TRUE, 135 | pick_metric(two_class_res, "roc_auc", "accuracy") 136 | ) 137 | }) 138 | -------------------------------------------------------------------------------- 
/tests/testthat/test-checks.R: -------------------------------------------------------------------------------- 1 | test_that("wf set check works", { 2 | expect_true(check_wf_set(chi_features_set)) 3 | expect_true(check_wf_set(chi_features_res)) 4 | 5 | expect_snapshot(error = TRUE, check_wf_set("no")) 6 | expect_snapshot(error = TRUE, check_wf_set(data.frame())) 7 | expect_snapshot(error = TRUE, rank_results("beeEeEEp!")) 8 | }) 9 | -------------------------------------------------------------------------------- /tests/testthat/test-collect-extracts.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | test_that("collect_extracts works", { 4 | set.seed(1) 5 | folds <- rsample::vfold_cv(mtcars, v = 3) 6 | 7 | wflow_set <- 8 | workflow_set( 9 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 10 | list(lm = parsnip::linear_reg()) 11 | ) 12 | 13 | wflow_set_trained <- 14 | wflow_set |> 15 | workflow_map( 16 | "fit_resamples", 17 | resamples = folds, 18 | control = tune::control_resamples(extract = function(x) { 19 | x 20 | }) 21 | ) 22 | 23 | extracts <- collect_extracts(wflow_set_trained) 24 | 25 | expect_equal(nrow(extracts), 6) 26 | expect_contains( 27 | class(extracts$.extracts[[1]]), 28 | "workflow" 29 | ) 30 | expect_named(extracts, c("wflow_id", "id", ".extracts", ".config")) 31 | }) 32 | 33 | 34 | test_that("collect_extracts fails gracefully without .extracts column", { 35 | set.seed(1) 36 | folds <- rsample::vfold_cv(mtcars, v = 3) 37 | 38 | wflow_set <- 39 | workflow_set( 40 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 41 | list(lm = parsnip::linear_reg()) 42 | ) 43 | 44 | wflow_set_trained <- 45 | wflow_set |> 46 | workflow_map("fit_resamples", resamples = folds) 47 | 48 | expect_snapshot( 49 | error = TRUE, 50 | collect_extracts(wflow_set_trained) 51 | ) 52 | }) 53 | -------------------------------------------------------------------------------- /tests/testthat/test-collect-metrics.R: -------------------------------------------------------------------------------- 1 | check_metric_results <- function(ind, x, ...) { 2 | id_val <- x$wflow_id[ind] 3 | 4 | if (any(names(list(...)) == "summarize")) { 5 | cols <- c(".metric", ".estimator", ".estimate", ".config", "id") 6 | } else { 7 | cols <- c(".metric", ".estimator", "mean", "n", "std_err", ".config") 8 | } 9 | 10 | orig <- 11 | collect_metrics(x$result[[ind]], ...) |> 12 | dplyr::select(dplyr::all_of(cols)) 13 | 14 | everythng <- 15 | collect_metrics(x, ...) 
|> 16 | dplyr::filter(wflow_id == id_val) |> 17 | dplyr::select(dplyr::all_of(cols)) 18 | all.equal(orig, everythng) 19 | } 20 | 21 | # ------------------------------------------------------------------------------ 22 | 23 | test_that("collect summarized metrics", { 24 | for (i in 1:nrow(two_class_res)) { 25 | expect_true(check_metric_results(i, two_class_res)) 26 | } 27 | for (i in 1:nrow(chi_features_res)) { 28 | expect_true(check_metric_results(i, chi_features_res)) 29 | } 30 | 31 | for (i in 1:nrow(two_class_res)) { 32 | expect_true(check_metric_results(i, two_class_res, summarize = FALSE)) 33 | } 34 | for (i in 1:nrow(chi_features_res)) { 35 | expect_true(check_metric_results(i, chi_features_res, summarize = FALSE)) 36 | } 37 | }) 38 | 39 | test_that("ranking models", { 40 | # expected number of rows per metric per model 41 | param_lines <- 42 | c( 43 | none_cart = 10, 44 | none_glm = 1, 45 | none_mars = 2, 46 | yj_trans_cart = 10, 47 | yj_trans_glm = 1, 48 | yj_trans_mars = 2 49 | ) 50 | 51 | expect_no_error(ranking_1 <- rank_results(two_class_res)) 52 | expect_equal(nrow(ranking_1), sum(param_lines * 2)) 53 | 54 | expect_no_error(ranking_2 <- rank_results(two_class_res, select_best = TRUE)) 55 | expect_equal(nrow(ranking_2), nrow(two_class_res) * 2) 56 | }) 57 | -------------------------------------------------------------------------------- /tests/testthat/test-collect-notes.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | test_that("collect_notes works", { 4 | set.seed(1) 5 | folds <- rsample::vfold_cv(mtcars, v = 3) 6 | 7 | wflow_set <- 8 | workflow_set( 9 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 10 | list(lm = parsnip::linear_reg()) 11 | ) 12 | 13 | wflow_set_trained <- 14 | wflow_set |> 15 | workflow_map( 16 | "fit_resamples", 17 | resamples = folds, 18 | control = tune::control_resamples(extract = function(x) { 19 | warn("hey!") 20 | }) 21 | ) 22 | 23 | expect_snapshot(error = TRUE, collect_notes(wflow_set)) 24 | notes <- collect_notes(wflow_set_trained) 25 | 26 | expect_equal(nrow(notes), 6) 27 | expect_contains(notes$note, "hey!") 28 | expect_named(notes, c("wflow_id", "id", "location", "type", "note")) 29 | }) 30 | -------------------------------------------------------------------------------- /tests/testthat/test-collect-predictions.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | skip_if_not_installed("kknn") 3 | skip_if_not_installed("modeldata") 4 | 5 | # ------------------------------------------------------------------------------ 6 | 7 | library(parsnip) 8 | suppressPackageStartupMessages(library(rsample)) 9 | suppressPackageStartupMessages(library(tune)) 10 | 11 | # ------------------------------------------------------------------------------ 12 | 13 | lr_spec <- linear_reg() |> set_engine("lm") 14 | knn_spec <- 15 | nearest_neighbor(neighbors = tune()) |> 16 | set_engine("kknn") |> 17 | set_mode("regression") 18 | 19 | set.seed(1) 20 | car_set_1 <- 21 | workflow_set( 22 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 23 | list(lm = lr_spec) 24 | ) |> 25 | workflow_map( 26 | "fit_resamples", 27 | resamples = vfold_cv(mtcars, v = 3), 28 | control = tune::control_resamples(save_pred = TRUE) 29 | ) 30 | 31 | set.seed(1) 32 | resamples <- vfold_cv(mtcars, v = 3, repeats = 2) 33 | 34 | set.seed(1) 35 | car_set_2 <- 36 | workflow_set( 37 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 38 | list(lm = lr_spec) 39 | 
) |> 40 | workflow_map( 41 | "fit_resamples", 42 | resamples = resamples, 43 | control = tune::control_resamples(save_pred = TRUE) 44 | ) 45 | 46 | set.seed(1) 47 | car_set_3 <- 48 | workflow_set( 49 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 50 | list(knn = knn_spec) 51 | ) |> 52 | workflow_map( 53 | "tune_bayes", 54 | resamples = resamples, 55 | control = tune::control_bayes(save_pred = TRUE), 56 | seed = 1, 57 | iter = 2, 58 | initial = 3 59 | ) 60 | 61 | car_set_23 <- dplyr::bind_rows(car_set_2, car_set_3) 62 | 63 | # ------------------------------------------------------------------------------ 64 | 65 | check_prediction_results <- function(ind, x, summarize = FALSE, ...) { 66 | id_val <- x$wflow_id[ind] 67 | 68 | cols <- c(".row", "mpg", ".config", ".pred") 69 | 70 | orig <- 71 | collect_predictions(x$result[[ind]], summarize = summarize, ...) |> 72 | dplyr::select(dplyr::all_of(cols)) 73 | 74 | if (any(names(list(...)) == "summarize")) { 75 | cols <- c(grep("^id", names(orig), value = TRUE), cols) 76 | } 77 | 78 | everythng <- 79 | collect_predictions(x, summarize = summarize, ...) |> 80 | dplyr::filter(wflow_id == id_val) |> 81 | dplyr::select(dplyr::all_of(cols)) 82 | all.equal(orig, everythng) 83 | } 84 | 85 | # ------------------------------------------------------------------------------ 86 | 87 | test_that("collect predictions", { 88 | expect_no_error( 89 | res_car_set_1 <- collect_predictions(car_set_1) 90 | ) 91 | expect_true(nrow(mtcars) * nrow(car_set_1) == nrow(res_car_set_1)) 92 | 93 | expect_no_error( 94 | res_car_set_2 <- collect_predictions(car_set_2) 95 | ) 96 | expect_true(nrow(mtcars) * nrow(car_set_2) == nrow(res_car_set_2)) 97 | 98 | expect_no_error( 99 | res_car_set_2_reps <- collect_predictions(car_set_2, summarize = FALSE) 100 | ) 101 | expect_true(nrow(mtcars) * nrow(car_set_2) * 2 == nrow(res_car_set_2_reps)) 102 | 103 | expect_no_error( 104 | res_car_set_3 <- collect_predictions(car_set_3) 105 | ) 106 | expect_true(nrow(mtcars) * nrow(car_set_2) * 5 == nrow(res_car_set_3)) 107 | 108 | expect_no_error( 109 | res_car_set_3_reps <- collect_predictions(car_set_3, summarize = FALSE) 110 | ) 111 | expect_true( 112 | nrow(mtcars) * nrow(car_set_2) * 5 * 2 == nrow(res_car_set_3_reps) 113 | ) 114 | 115 | # --------------------------------------------------------------------------- 116 | # These don't seem to get captured by covr 117 | for (i in 1:nrow(car_set_1)) { 118 | expect_true(check_prediction_results(i, car_set_1)) 119 | } 120 | for (i in 1:nrow(car_set_2)) { 121 | expect_true(check_prediction_results(i, car_set_2)) 122 | } 123 | 124 | for (i in 1:nrow(car_set_1)) { 125 | expect_true(check_prediction_results(i, car_set_1, summarize = FALSE)) 126 | } 127 | for (i in 1:nrow(car_set_2)) { 128 | expect_true(check_prediction_results(i, car_set_2, summarize = FALSE)) 129 | } 130 | }) 131 | 132 | skip_if(packageVersion("tune") <= "1.1.0") 133 | 134 | test_that("dropping tuning parameter columns", { 135 | expect_named( 136 | collect_predictions(car_set_1), 137 | c("wflow_id", ".config", "preproc", "model", ".row", "mpg", ".pred"), 138 | ignore.order = TRUE 139 | ) 140 | expect_named( 141 | collect_predictions(car_set_2), 142 | c("wflow_id", ".config", "preproc", "model", ".row", "mpg", ".pred"), 143 | ignore.order = TRUE 144 | ) 145 | 146 | expect_named( 147 | collect_predictions(car_set_1, summarize = FALSE), 148 | c("wflow_id", ".config", "preproc", "model", "id", ".pred", ".row", "mpg"), 149 | ignore.order = TRUE 150 | ) 151 | expect_named( 152 | 
collect_predictions(car_set_2, summarize = FALSE), 153 | c( 154 | "wflow_id", 155 | ".config", 156 | "preproc", 157 | "model", 158 | "id", 159 | "id2", 160 | ".pred", 161 | ".row", 162 | "mpg" 163 | ), 164 | ignore.order = TRUE 165 | ) 166 | 167 | expect_no_error( 168 | best_iter <- collect_predictions( 169 | car_set_3, 170 | select_best = TRUE, 171 | metric = "rmse" 172 | ) 173 | ) 174 | expect_true( 175 | nrow(dplyr::distinct(best_iter[, c(".config", "wflow_id")])) == 2 176 | ) 177 | expect_no_error( 178 | no_param <- 179 | select_bare_predictions(car_set_3$result[[1]], metric = "rmse", TRUE) 180 | ) 181 | expect_named( 182 | no_param, 183 | c(".row", "mpg", ".config", ".iter", ".pred"), 184 | ignore.order = TRUE 185 | ) 186 | }) 187 | 188 | 189 | test_that("mixed object types", { 190 | expect_true(".iter" %in% names(collect_predictions(car_set_23))) 191 | }) 192 | -------------------------------------------------------------------------------- /tests/testthat/test-comments.R: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | 3 | test_that("test comments", { 4 | comments_1 <- 5 | two_class_set |> 6 | comment_add("none_cart", "What does 'cart' stand for\u2753") 7 | 8 | expect_equal( 9 | comment_get(comments_1, id = "none_cart"), 10 | "What does 'cart' stand for\u2753" 11 | ) 12 | 13 | for (i in 2:nrow(comments_1)) { 14 | expect_equal(comments_1$info[[i]]$comment, character(1)) 15 | } 16 | comments_2 <- 17 | comments_1 |> 18 | comment_add("none_cart", "Stuff.") 19 | expect_equal( 20 | comment_get(comments_2, id = "none_cart") |> paste0(collapse = "\n"), 21 | "What does 'cart' stand for\u2753\nStuff." 22 | ) 23 | comments_3 <- 24 | comments_2 |> 25 | comment_reset("none_cart") 26 | expect_equal( 27 | comments_3$info[[1]]$comment, 28 | character(1) 29 | ) 30 | expect_equal( 31 | two_class_set |> comment_add(), 32 | two_class_set 33 | ) 34 | expect_equal( 35 | two_class_set |> comment_add("none_cart"), 36 | two_class_set 37 | ) 38 | expect_snapshot( 39 | error = TRUE, 40 | two_class_set |> comment_add("toe", "foot") 41 | ) 42 | expect_snapshot( 43 | error = TRUE, 44 | two_class_set |> comment_add(letters, "foot") 45 | ) 46 | expect_snapshot( 47 | error = TRUE, 48 | two_class_set |> comment_add(1:2, "foot") 49 | ) 50 | expect_snapshot( 51 | error = TRUE, 52 | two_class_set |> comment_add("none_cart", 1:2) 53 | ) 54 | expect_snapshot( 55 | error = TRUE, 56 | comments_1 |> comment_add("none_cart", "Stuff.", append = FALSE) 57 | ) 58 | expect_snapshot( 59 | error = TRUE, 60 | comment_get(comments_1, id = letters) 61 | ) 62 | expect_snapshot( 63 | error = TRUE, 64 | comment_get(comments_1, id = letters[1]) 65 | ) 66 | expect_snapshot( 67 | error = TRUE, 68 | comments_1 |> comment_reset(letters) 69 | ) 70 | expect_snapshot( 71 | error = TRUE, 72 | comments_1 |> comment_reset("none_carts") 73 | ) 74 | }) 75 | 76 | test_that("print comments", { 77 | gatsby_1 <- "\"Whenever you feel like criticizing any one,\" he told me, \"just remember that all the people in this world haven't had the advantages that you've had.\"" 78 | gatsby_2 <- "My family have been prominent, well-to-do people in this middle-western city for three generations. 
The Carraways are something of a clan and we have a tradition that we're descended from the Dukes of Buccleuch, but the actual founder of my line was my grandfather's brother who came here in fifty-one, sent a substitute to the Civil War and started the wholesale hardware business that my father carries on today." 79 | gatsby_3 <- "Across the courtesy bay the white palaces of fashionable East Egg glittered along the water, and the history of the summer really begins on the evening I drove over there to have dinner with the Tom Buchanans. Daisy was my second cousin once removed and I'd known Tom in college. And just after the war I spent two days with them in Chicago." 80 | 81 | test <- 82 | two_class_res |> 83 | comment_add("none_cart", gatsby_1) |> 84 | comment_add("none_cart", gatsby_2) |> 85 | comment_add("none_glm", gatsby_3) 86 | 87 | expect_snapshot(comment_print(test)) 88 | expect_snapshot(comment_print(test, id = "none_glm")) 89 | }) 90 | -------------------------------------------------------------------------------- /tests/testthat/test-compat-vctrs.R: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # vec_restore() 3 | 4 | test_that("vec_restore() returns a workflow_set subclass if `x` retains structure", { 5 | for (x in workflow_set_objects) { 6 | expect_identical(vec_restore(x, x), x) 7 | expect_s3_class_workflow_set(vec_restore(x, x)) 8 | } 9 | }) 10 | 11 | test_that("vec_restore() returns workflow_set when row slicing", { 12 | for (x in workflow_set_objects) { 13 | row1 <- x[1, ] 14 | row0 <- x[0, ] 15 | 16 | expect_s3_class_workflow_set(vec_restore(row1, x)) 17 | expect_s3_class_workflow_set(vec_restore(row0, x)) 18 | } 19 | }) 20 | 21 | test_that("vec_restore() returns bare tibble if `x` loses column structure", { 22 | for (x in workflow_set_objects) { 23 | col <- x[1] 24 | expect_s3_class_bare_tibble(vec_restore(col, x)) 25 | } 26 | }) 27 | 28 | # ------------------------------------------------------------------------------ 29 | # vec_ptype2() 30 | 31 | test_that("vec_ptype2() is working", { 32 | for (x in workflow_set_objects) { 33 | x2 <- x 34 | x2$y <- 1 35 | x3 <- x 36 | x3$z <- 2 37 | 38 | tbl <- tibble::tibble(x = 1) 39 | df <- data.frame(x = 1) 40 | 41 | # workflow_set-workflow_set 42 | expect_identical(vec_ptype2(x, x), vec_slice(x, NULL)) 43 | expect_identical(vec_ptype2(x2, x3), new_workflow_set0(df_ptype2(x2, x3))) 44 | 45 | # workflow_set-tbl_df 46 | expect_identical(vec_ptype2(x, tbl), vec_ptype2(new_tibble0(x), tbl)) 47 | expect_identical(vec_ptype2(tbl, x), vec_ptype2(tbl, new_tibble0(x))) 48 | 49 | # workflow_set-df 50 | expect_identical(vec_ptype2(x, df), vec_ptype2(new_tibble0(x), df)) 51 | expect_identical(vec_ptype2(df, x), vec_ptype2(df, new_tibble0(x))) 52 | } 53 | }) 54 | 55 | # ------------------------------------------------------------------------------ 56 | # vec_cast() 57 | 58 | test_that("vec_cast() is working", { 59 | for (x in workflow_set_objects) { 60 | x2 <- x 61 | x2$y <- 1 62 | x3 <- x 63 | x3$z <- 2 64 | 65 | tbl <- new_tibble0(x) 66 | df <- as.data.frame(tbl) 67 | 68 | # workflow_set-workflow_set 69 | expect_identical(vec_cast(x, x), x) 70 | 71 | x2_expect <- x 72 | x2_expect$y <- NA_real_ 73 | expect_identical(vec_cast(x, x2), x2_expect) 74 | 75 | expect_error(vec_cast(x2, x3), class = "vctrs_error_cast_lossy_dropped") 76 | 77 | # workflow_set-tbl_df 78 | expect_identical(vec_cast(x, tbl), tbl) 79 | 
expect_error(vec_cast(tbl, x), class = "vctrs_error_incompatible_type") 80 | 81 | # workflow_set-df 82 | expect_identical(vec_cast(x, df), df) 83 | expect_error(vec_cast(df, x), class = "vctrs_error_incompatible_type") 84 | } 85 | }) 86 | 87 | # ------------------------------------------------------------------------------ 88 | # vctrs methods 89 | 90 | test_that("vec_ptype() returns a workflow_set", { 91 | for (x in workflow_set_objects) { 92 | expect_s3_class_workflow_set(vec_ptype(x)) 93 | } 94 | }) 95 | 96 | test_that("vec_slice() generally returns a workflow_set", { 97 | for (x in workflow_set_objects) { 98 | expect_s3_class_workflow_set(vec_slice(x, 0)) 99 | expect_s3_class_workflow_set(vec_slice(x, 1:2)) 100 | } 101 | }) 102 | 103 | test_that("vec_slice() can return a tibble if wflow_ids are duplicated", { 104 | for (x in workflow_set_objects) { 105 | expect_identical(vec_slice(x, c(1, 1)), vec_slice(new_tibble0(x), c(1, 1))) 106 | } 107 | }) 108 | 109 | test_that("vec_c() works", { 110 | for (x in workflow_set_objects) { 111 | tbl <- new_tibble0(x) 112 | 113 | expect_identical(vec_c(x), x) 114 | expect_identical(vec_c(x, x), vec_c(tbl, tbl)) 115 | expect_identical(vec_c(x[1:2, ], x[3, ]), x) 116 | } 117 | }) 118 | 119 | test_that("vec_rbind() works", { 120 | for (x in workflow_set_objects) { 121 | tbl <- new_tibble0(x) 122 | 123 | expect_identical(vec_rbind(x), x) 124 | expect_identical(vec_rbind(x, x), vec_rbind(tbl, tbl)) 125 | expect_identical(vec_rbind(x[1:2, ], x[3, ]), x) 126 | } 127 | }) 128 | 129 | test_that("vec_cbind() returns a bare tibble", { 130 | for (x in workflow_set_objects) { 131 | tbl <- new_tibble0(x) 132 | 133 | # Unlike vec_c() and vec_rbind(), the prototype of the output comes 134 | # from doing `x[0]`, which will drop the workflow_set class 135 | expect_identical(vec_cbind(x), vec_cbind(tbl)) 136 | expect_identical( 137 | vec_cbind(x, x, .name_repair = "minimal"), 138 | vec_cbind(tbl, tbl, .name_repair = "minimal") 139 | ) 140 | expect_identical( 141 | vec_cbind(x, tbl, .name_repair = "minimal"), 142 | vec_cbind(tbl, tbl, .name_repair = "minimal") 143 | ) 144 | } 145 | }) 146 | -------------------------------------------------------------------------------- /tests/testthat/test-extract.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed(c("kknn", "modeldata")) 2 | 3 | library(parsnip) 4 | library(rsample) 5 | library(recipes) 6 | data(Chicago, package = "modeldata") 7 | 8 | lr_spec <- linear_reg() |> set_engine("lm") 9 | 10 | set.seed(1) 11 | car_set_1 <- 12 | workflow_set( 13 | list( 14 | reg = recipe(mpg ~ ., data = mtcars) |> step_log(disp), 15 | nonlin = mpg ~ wt + 1 / sqrt(disp) 16 | ), 17 | list(lm = lr_spec) 18 | ) |> 19 | workflow_map( 20 | "fit_resamples", 21 | resamples = vfold_cv(mtcars, v = 3), 22 | control = tune::control_resamples(save_pred = TRUE) 23 | ) 24 | 25 | # ------------------------------------------------------------------------------ 26 | 27 | test_that("extracts", { 28 | # workflows specific errors, so we don't capture their messages 29 | expect_snapshot( 30 | error = TRUE, 31 | extract_fit_engine(car_set_1, id = "reg_lm") 32 | ) 33 | expect_snapshot( 34 | error = TRUE, 35 | extract_fit_parsnip(car_set_1, id = "reg_lm") 36 | ) 37 | expect_snapshot( 38 | error = TRUE, 39 | extract_mold(car_set_1, id = "reg_lm") 40 | ) 41 | expect_snapshot( 42 | error = TRUE, 43 | extract_recipe(car_set_1, id = "reg_lm") 44 | ) 45 | 46 | expect_s3_class( 47 | extract_preprocessor(car_set_1, id 
= "reg_lm"), 48 | "recipe" 49 | ) 50 | expect_s3_class( 51 | extract_spec_parsnip(car_set_1, id = "reg_lm"), 52 | "model_spec" 53 | ) 54 | expect_s3_class( 55 | extract_workflow(car_set_1, id = "reg_lm"), 56 | "workflow" 57 | ) 58 | expect_s3_class( 59 | extract_recipe(car_set_1, id = "reg_lm", estimated = FALSE), 60 | "recipe" 61 | ) 62 | 63 | expect_equal( 64 | car_set_1 |> extract_workflow("reg_lm"), 65 | car_set_1$info[[1]]$workflow[[1]] 66 | ) 67 | 68 | expect_equal( 69 | car_set_1 |> extract_workflow_set_result("reg_lm"), 70 | car_set_1$result[[1]] 71 | ) 72 | 73 | expect_snapshot(error = TRUE, { 74 | car_set_1 |> extract_workflow_set_result("Gideon Nav") 75 | }) 76 | 77 | expect_snapshot(error = TRUE, { 78 | car_set_1 |> extract_workflow("Coronabeth Tridentarius") 79 | }) 80 | }) 81 | 82 | 83 | test_that("extract parameter set from workflow set with untunable workflow", { 84 | rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> 85 | recipes::step_rm(date, ends_with("away")) 86 | lm_model <- parsnip::linear_reg() |> 87 | parsnip::set_engine("lm") 88 | bst_model <- 89 | parsnip::boost_tree( 90 | mode = "classification", 91 | trees = hardhat::tune("funky name \n") 92 | ) |> 93 | parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) 94 | wf_set <- workflow_set( 95 | list(reg = rm_rec), 96 | list(lm = lm_model, bst = bst_model) 97 | ) 98 | 99 | lm_info <- hardhat::extract_parameter_set_dials(wf_set, id = "reg_lm") 100 | check_parameter_set_tibble(lm_info) 101 | expect_equal(nrow(lm_info), 0) 102 | }) 103 | 104 | test_that("extract parameter set from workflow set with tunable workflow", { 105 | rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> 106 | recipes::step_rm(date, ends_with("away")) 107 | lm_model <- parsnip::linear_reg() |> 108 | parsnip::set_engine("lm") 109 | bst_model <- 110 | parsnip::boost_tree( 111 | mode = "classification", 112 | trees = hardhat::tune("funky name \n") 113 | ) |> 114 | parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) 115 | wf_set <- workflow_set( 116 | list(reg = rm_rec), 117 | list(lm = lm_model, bst = bst_model) 118 | ) 119 | 120 | c5_info <- extract_parameter_set_dials(wf_set, id = "reg_bst") 121 | expect_equal( 122 | c5_info, 123 | extract_parameter_set_dials(bst_model) 124 | ) 125 | check_parameter_set_tibble(c5_info) 126 | expect_equal(nrow(c5_info), 2) 127 | expect_true(all(c5_info$source == "model_spec")) 128 | expect_true(all(c5_info$component == "boost_tree")) 129 | expect_equal(c5_info$component_id, c("main", "engine")) 130 | nms <- c("trees", "rules") 131 | expect_equal(c5_info$name, nms) 132 | ids <- c("funky name \n", "rules") 133 | expect_equal(c5_info$id, ids) 134 | 135 | expect_equal(c5_info$object[[1]], dials::trees(c(1, 100))) 136 | expect_equal(c5_info$object[[2]], NA) 137 | 138 | c5_new_info <- 139 | c5_info |> 140 | update( 141 | rules = dials::new_qual_param( 142 | "logical", 143 | values = c(TRUE, FALSE), 144 | label = c(rules = "Rules") 145 | ) 146 | ) 147 | 148 | wf_set_2 <- 149 | wf_set |> 150 | option_add(id = "reg_bst", param_info = c5_new_info) 151 | 152 | check_parameter_set_tibble(c5_new_info) 153 | expect_s3_class(c5_new_info$object[[2]], "qual_param") 154 | expect_equal( 155 | c5_new_info, 156 | extract_parameter_set_dials(wf_set_2, "reg_bst") 157 | ) 158 | }) 159 | 160 | 161 | test_that("extract single parameter from workflow set with untunable workflow", { 162 | rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> 163 | 
recipes::step_rm(date, ends_with("away")) 164 | lm_model <- parsnip::linear_reg() |> 165 | parsnip::set_engine("lm") 166 | bst_model <- 167 | parsnip::boost_tree( 168 | mode = "classification", 169 | trees = hardhat::tune("funky name \n") 170 | ) |> 171 | parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) 172 | wf_set <- workflow_set( 173 | list(reg = rm_rec), 174 | list(lm = lm_model, bst = bst_model) 175 | ) 176 | 177 | expect_snapshot( 178 | error = TRUE, 179 | hardhat::extract_parameter_dials( 180 | wf_set, 181 | id = "reg_lm", 182 | parameter = "non there" 183 | ) 184 | ) 185 | }) 186 | 187 | test_that("extract single parameter from workflow set with tunable workflow", { 188 | rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> 189 | recipes::step_rm(date, ends_with("away")) 190 | lm_model <- parsnip::linear_reg() |> 191 | parsnip::set_engine("lm") 192 | bst_model <- 193 | parsnip::boost_tree( 194 | mode = "classification", 195 | trees = hardhat::tune("funky name \n") 196 | ) |> 197 | parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) 198 | wf_set <- workflow_set( 199 | list(reg = rm_rec), 200 | list(lm = lm_model, bst = bst_model) 201 | ) 202 | 203 | expect_equal( 204 | hardhat::extract_parameter_dials( 205 | wf_set, 206 | id = "reg_bst", 207 | parameter = "funky name \n" 208 | ), 209 | dials::trees(c(1, 100)) 210 | ) 211 | expect_equal( 212 | extract_parameter_dials(wf_set, id = "reg_bst", parameter = "rules"), 213 | NA 214 | ) 215 | }) 216 | -------------------------------------------------------------------------------- /tests/testthat/test-fit.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | # ------------------------------------------------------------------------------ 4 | 5 | library(parsnip) 6 | suppressPackageStartupMessages(library(rsample)) 7 | suppressPackageStartupMessages(library(tune)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | lr_spec <- linear_reg() |> set_engine("lm") 12 | knn_spec <- 13 | nearest_neighbor(neighbors = tune()) |> 14 | set_engine("kknn") |> 15 | set_mode("regression") 16 | 17 | set.seed(1) 18 | car_set_1 <- 19 | workflow_set( 20 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 21 | list(lm = lr_spec) 22 | ) 23 | 24 | car_set_2 <- 25 | car_set_1 |> 26 | workflow_map( 27 | "fit_resamples", 28 | resamples = vfold_cv(mtcars, v = 3), 29 | control = tune::control_resamples(save_pred = TRUE) 30 | ) 31 | 32 | test_that("fit() errors informatively with workflow sets", { 33 | expect_snapshot(fit(car_set_1), error = TRUE) 34 | 35 | expect_snapshot(fit(car_set_2), error = TRUE) 36 | }) 37 | -------------------------------------------------------------------------------- /tests/testthat/test-fit_best.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("kknn") 2 | skip_if_not_installed("modeldata") 3 | 4 | test_that("fit_best fits with correct hyperparameters", { 5 | skip_on_cran() 6 | 7 | library(tune) 8 | library(modeldata) 9 | library(rsample) 10 | library(yardstick) 11 | 12 | data(Chicago) 13 | Chicago <- Chicago[1:1195, ] 14 | 15 | time_val_split <- 16 | sliding_period( 17 | Chicago, 18 | date, 19 | "month", 20 | lookback = 38, 21 | assess_stop = 1 22 | ) 23 | 24 | chi_features_map <- 25 | chi_features_set |> 26 | option_add( 27 | control = control_grid(save_workflow = TRUE), 28 | # choose metrics resulting in 
different rankings 29 | metrics = metric_set(rmse, iic) 30 | ) |> 31 | workflow_map(resamples = time_val_split, grid = 21, seed = 1) 32 | 33 | chi_features_map 34 | 35 | # metric: rmse 36 | fit_best_wf <- fit_best(chi_features_map) 37 | expect_s3_class(fit_best_wf, "workflow") 38 | 39 | rankings <- rank_results(chi_features_map, "rmse") 40 | tune_res <- extract_workflow_set_result( 41 | chi_features_map, 42 | rankings$wflow_id[1] 43 | ) 44 | tune_params <- select_best(tune_res, metric = "rmse") 45 | manual_wf <- fit_best(tune_res, parameters = tune_params) 46 | 47 | manual_wf$pre$mold$blueprint$recipe$fit_times <- 48 | fit_best_wf$pre$mold$blueprint$recipe$fit_times 49 | manual_wf$fit$fit$elapsed$elapsed <- 50 | fit_best_wf$fit$fit$elapsed$elapsed 51 | expect_equal(manual_wf, fit_best_wf) 52 | 53 | # metric: iic 54 | fit_best_wf_2 <- fit_best(chi_features_map, "iic") 55 | expect_s3_class(fit_best_wf_2, "workflow") 56 | 57 | rankings_2 <- rank_results(chi_features_map, "iic") 58 | tune_res_2 <- extract_workflow_set_result( 59 | chi_features_map, 60 | rankings_2$wflow_id[1] 61 | ) 62 | tune_params_2 <- select_best(tune_res_2, metric = "iic") 63 | manual_wf_2 <- fit_best(tune_res_2, parameters = tune_params_2) 64 | 65 | manual_wf_2$pre$mold$blueprint$recipe$fit_times <- 66 | fit_best_wf_2$pre$mold$blueprint$recipe$fit_times 67 | manual_wf_2$fit$fit$elapsed$elapsed <- 68 | fit_best_wf_2$fit$fit$elapsed$elapsed 69 | expect_equal(manual_wf_2, fit_best_wf_2) 70 | }) 71 | 72 | test_that("fit_best errors informatively with bad inputs", { 73 | skip_on_cran() 74 | 75 | library(tune) 76 | library(modeldata) 77 | library(rsample) 78 | library(yardstick) 79 | 80 | data(Chicago) 81 | Chicago <- Chicago[1:1195, ] 82 | 83 | time_val_split <- 84 | sliding_period( 85 | Chicago, 86 | date, 87 | "month", 88 | lookback = 38, 89 | assess_stop = 1 90 | ) 91 | 92 | chi_features_map <- 93 | chi_features_set |> 94 | option_add( 95 | # set needed `save_workflow` option 96 | control = control_grid(save_workflow = TRUE) 97 | ) |> 98 | workflow_map(resamples = time_val_split, grid = 21, seed = 1) 99 | 100 | expect_snapshot( 101 | fit_best(chi_features_res), 102 | error = TRUE 103 | ) 104 | 105 | expect_snapshot( 106 | fit_best(chi_features_map, metric = "boop"), 107 | error = TRUE 108 | ) 109 | 110 | expect_snapshot( 111 | fit_best(chi_features_map, boop = "bop"), 112 | error = TRUE 113 | ) 114 | }) 115 | -------------------------------------------------------------------------------- /tests/testthat/test-leave-var-out-formulas.R: -------------------------------------------------------------------------------- 1 | form_data <- data.frame( 2 | a = 1:10, 3 | b = seq(1, 7, length.out = 10), 4 | c = factor(rep(letters[1:2], 5)), 5 | y = (1:10) * 2 6 | ) 7 | 8 | 9 | num_pred <- function(f) { 10 | length(all.vars(f[-2])) 11 | } 12 | 13 | num_terms <- function(f) { 14 | length(strsplit(deparse(f[-2]), "+", fixed = TRUE)[[1]]) 15 | } 16 | 17 | # ------------------------------------------------------------------------------ 18 | 19 | test_that("LOO var formulas", { 20 | expect_snapshot( 21 | error = TRUE, 22 | leave_var_out_formulas(y ~ 1, data = form_data) 23 | ) 24 | expect_snapshot( 25 | error = TRUE, 26 | leave_var_out_formulas(y ~ a, data = form_data) 27 | ) 28 | 29 | f_1 <- leave_var_out_formulas(y ~ ., data = form_data) 30 | expect_true(length(f_1) == 4) 31 | expect_equal(names(f_1), c(letters[1:3], "everything")) 32 | expect_equal( 33 | purrr::map_int(f_1, num_pred), 34 | c(a = 2L, b = 2L, c = 2L, everything = 1L) 35 | ) 
36 | 37 | f_2 <- leave_var_out_formulas(y ~ (.)^2, data = form_data, FALSE) 38 | expect_true(length(f_2) == 6) 39 | expect_equal(names(f_2), c("a", "b", "c", "a:b", "a:c", "b:c")) 40 | expect_equal(unname(purrr::map_int(f_2, num_pred)), rep(2:3, each = 3)) 41 | 42 | f_3 <- leave_var_out_formulas(y ~ . + I(a^3), data = form_data, FALSE) 43 | expect_true(length(f_3) == 4) 44 | expect_equal(names(f_3), c("a", "b", "c", "I(a^3)")) 45 | expect_equal(unname(purrr::map_int(f_3, num_pred)), c(2, 2, 2, 3)) 46 | expect_equal(unname(purrr::map_int(f_3, num_terms)), c(2, 3, 3, 3)) 47 | }) 48 | -------------------------------------------------------------------------------- /tests/testthat/test-options.R: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | 3 | test_that("option management", { 4 | expect_no_error( 5 | set_1 <- two_class_set |> option_add(grid = 1) 6 | ) 7 | for (i in 1:nrow(set_1)) { 8 | expect_equal(unclass(set_1$option[[i]]), list(grid = 1)) 9 | } 10 | expect_no_error( 11 | set_2 <- two_class_set |> option_remove(grid) 12 | ) 13 | for (i in 1:nrow(set_2)) { 14 | expect_equal(unclass(set_2$option[[i]]), list()) 15 | } 16 | expect_no_error( 17 | set_3 <- two_class_set |> option_add(grid = 1, id = "none_cart") 18 | ) 19 | expect_equal(unclass(set_3$option[[1]]), list(grid = 1)) 20 | for (i in 2:nrow(set_3)) { 21 | expect_equal(unclass(set_3$option[[i]]), list()) 22 | } 23 | expect_no_error( 24 | set_4 <- two_class_set |> option_add_parameters() 25 | ) 26 | for (i in which(!grepl("glm", set_4$wflow_id))) { 27 | expect_true(all(names(set_4$option[[i]]) == "param_info")) 28 | expect_true(inherits(set_4$option[[i]]$param_info, "parameters")) 29 | } 30 | for (i in which(grepl("glm", set_4$wflow_id))) { 31 | expect_equal(unclass(set_4$option[[i]]), list()) 32 | } 33 | expect_no_error( 34 | set_5 <- two_class_set |> option_add_parameters(id = "none_cart") 35 | ) 36 | expect_true(all(names(set_5$option[[1]]) == "param_info")) 37 | expect_true(inherits(set_5$option[[1]]$param_info, "parameters")) 38 | for (i in 2:nrow(set_5)) { 39 | expect_equal(unclass(set_5$option[[i]]), list()) 40 | } 41 | }) 42 | 43 | 44 | test_that("option printing", { 45 | expect_output( 46 | print(two_class_res$option[[1]]), 47 | "a list of options with names: 'resamples', 'grid'" 48 | ) 49 | expect_equal( 50 | pillar::type_sum(two_class_res$option[[1]]), 51 | "opts[3]" 52 | ) 53 | }) 54 | 55 | 56 | test_that("check for bad options", { 57 | expect_snapshot_error( 58 | two_class_set |> option_add(grid2 = 1) 59 | ) 60 | expect_snapshot_error( 61 | two_class_set |> option_add(grid = 1, blueprint = 2) 62 | ) 63 | }) 64 | -------------------------------------------------------------------------------- /tests/testthat/test-predict.R: -------------------------------------------------------------------------------- 1 | skip_on_cran() 2 | 3 | # ------------------------------------------------------------------------------ 4 | 5 | library(parsnip) 6 | suppressPackageStartupMessages(library(rsample)) 7 | suppressPackageStartupMessages(library(tune)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | lr_spec <- linear_reg() |> set_engine("lm") 12 | knn_spec <- 13 | nearest_neighbor(neighbors = tune()) |> 14 | set_engine("kknn") |> 15 | set_mode("regression") 16 | 17 | set.seed(1) 18 | car_set_1 <- 19 | workflow_set( 20 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / 
sqrt(disp)), 21 | list(lm = lr_spec) 22 | ) 23 | 24 | car_set_2 <- 25 | car_set_1 |> 26 | workflow_map( 27 | "fit_resamples", 28 | resamples = vfold_cv(mtcars, v = 3), 29 | control = tune::control_resamples(save_pred = TRUE) 30 | ) 31 | 32 | test_that("predict() errors informatively with workflow sets", { 33 | expect_snapshot(predict(car_set_1), error = TRUE) 34 | 35 | expect_snapshot(predict(car_set_2), error = TRUE) 36 | }) 37 | -------------------------------------------------------------------------------- /tests/testthat/test-pull.R: -------------------------------------------------------------------------------- 1 | library(parsnip) 2 | library(rsample) 3 | 4 | lr_spec <- linear_reg() |> set_engine("lm") 5 | 6 | set.seed(1) 7 | car_set_1 <- 8 | workflow_set( 9 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 10 | list(lm = lr_spec) 11 | ) |> 12 | workflow_map( 13 | "fit_resamples", 14 | resamples = vfold_cv(mtcars, v = 3), 15 | control = tune::control_resamples(save_pred = TRUE) 16 | ) 17 | 18 | # ------------------------------------------------------------------------------ 19 | 20 | test_that("pulling objects is deprecated", { 21 | expect_snapshot( 22 | error = TRUE, 23 | car_set_1 |> pull_workflow_set_result("reg_lm") 24 | ) 25 | expect_snapshot( 26 | error = TRUE, 27 | car_set_1 |> pull_workflow("reg_lm") 28 | ) 29 | }) 30 | -------------------------------------------------------------------------------- /tests/testthat/test-updates.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("kknn") 2 | skip_if_not_installed("modeldata") 3 | 4 | library(parsnip) 5 | library(recipes) 6 | library(hardhat) 7 | 8 | data(two_class_dat, package = "modeldata") 9 | 10 | xgb <- boost_tree(trees = 3) |> set_mode("classification") 11 | rec <- 12 | recipe(Class ~ A + B, two_class_dat) |> 13 | step_normalize(A) |> 14 | step_normalize(B) 15 | 16 | sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix") 17 | 18 | 19 | test_that("update model", { 20 | expect_no_error( 21 | new_set <- update_workflow_model(two_class_res, "none_cart", spec = xgb) 22 | ) 23 | expect_true( 24 | inherits( 25 | extract_spec_parsnip(new_set, id = "none_cart"), 26 | "boost_tree" 27 | ) 28 | ) 29 | expect_equal(new_set$result[[1]], list()) 30 | 31 | expect_no_error( 32 | new_new_set <- 33 | update_workflow_model( 34 | new_set, 35 | "none_glm", 36 | spec = xgb, 37 | formula = Class ~ log(A) + B 38 | ) 39 | ) 40 | new_wflow <- extract_workflow(new_new_set, "none_glm") 41 | expect_equal( 42 | new_wflow$fit$actions$model$formula, 43 | Class ~ log(A) + B 44 | ) 45 | }) 46 | 47 | test_that("update recipe", { 48 | expect_no_error( 49 | new_set <- update_workflow_recipe( 50 | two_class_res, 51 | "yj_trans_cart", 52 | recipe = rec 53 | ) 54 | ) 55 | new_rec <- extract_recipe(new_set, id = "yj_trans_cart", estimated = FALSE) 56 | 57 | expect_true(all(tidy(new_rec)$type == "normalize")) 58 | expect_equal(new_set$result[[4]], list()) 59 | 60 | expect_no_error( 61 | new_new_set <- 62 | update_workflow_recipe( 63 | new_set, 64 | "yj_trans_cart", 65 | recipe = rec, 66 | blueprint = sparse_bp 67 | ) 68 | ) 69 | new_wflow <- extract_workflow(new_new_set, "yj_trans_cart") 70 | expect_equal(new_wflow$pre$actions$recipe$blueprint, sparse_bp) 71 | }) 72 | -------------------------------------------------------------------------------- /tests/testthat/test-workflow-map.R: -------------------------------------------------------------------------------- 1 | 
skip_if_not_installed("kknn") 2 | skip_if_not_installed("modeldata") 3 | 4 | library(parsnip) 5 | suppressPackageStartupMessages(library(rsample)) 6 | suppressPackageStartupMessages(library(tune)) 7 | library(kknn) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | lr_spec <- linear_reg() |> set_engine("lm") 12 | knn_spec <- 13 | nearest_neighbor(neighbors = tune()) |> 14 | set_engine("kknn") |> 15 | set_mode("regression") 16 | glmn_spec <- 17 | linear_reg(penalty = tune()) |> 18 | set_engine("glmnet") 19 | 20 | set.seed(1) 21 | folds <- vfold_cv(mtcars, v = 3) 22 | 23 | car_set_1 <- 24 | workflow_set( 25 | list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), 26 | list(lm = lr_spec, knn = knn_spec) 27 | ) |> 28 | dplyr::slice(-4) 29 | 30 | # ------------------------------------------------------------------------------ 31 | 32 | test_that("basic mapping", { 33 | expect_no_error({ 34 | res_1 <- 35 | car_set_1 |> 36 | workflow_map(resamples = folds, seed = 2, grid = 2) 37 | }) 38 | 39 | # check reproducibility 40 | expect_no_error({ 41 | res_2 <- 42 | car_set_1 |> 43 | workflow_map(resamples = folds, seed = 2, grid = 2) 44 | }) 45 | expect_equal(collect_metrics(res_1), collect_metrics(res_2)) 46 | 47 | # --------------------------------------------------------------------------- 48 | 49 | expect_snapshot( 50 | error = TRUE, 51 | two_class_set |> 52 | workflow_map("foo", seed = 1, resamples = folds, grid = 2) 53 | ) 54 | 55 | expect_snapshot( 56 | error = TRUE, 57 | two_class_set |> 58 | workflow_map(fn = 1L, seed = 1, resamples = folds, grid = 2) 59 | ) 60 | 61 | expect_snapshot( 62 | error = TRUE, 63 | two_class_set |> 64 | workflow_map(fn = tune::tune_grid, seed = 1, resamples = folds, grid = 2) 65 | ) 66 | }) 67 | 68 | 69 | test_that("map logging", { 70 | # since the logging prints execution times, we capture output then make a 71 | # snapshot without those lines 72 | expect_no_error({ 73 | logging_res <- 74 | capture.output( 75 | res <- 76 | car_set_1 |> 77 | workflow_map(resamples = folds, seed = 2, verbose = TRUE), 78 | type = "message" 79 | ) 80 | }) 81 | logging_res <- logging_res[!grepl("s\\)$", logging_res)] 82 | expect_snapshot( 83 | cat(logging_res, sep = "\n") 84 | ) 85 | }) 86 | 87 | test_that("missing packages", { 88 | skip_if(rlang::is_installed("glmnet")) 89 | car_set_2 <- 90 | workflow_set( 91 | list(reg = mpg ~ .), 92 | list(glmn = glmn_spec) 93 | ) 94 | 95 | expect_snapshot( 96 | { 97 | res <- 98 | car_set_2 |> 99 | workflow_map(resamples = folds, seed = 2, verbose = FALSE) 100 | }, 101 | transform = function(lines) { 102 | gsub("\\([0-9]+ms\\)", "(ms)", lines) 103 | } 104 | ) 105 | expect_true(inherits(res, "workflow_set")) 106 | expect_equal(res$result[[1]], list()) 107 | }) 108 | 109 | 110 | test_that("failers", { 111 | skip_on_cran() 112 | car_set_3 <- 113 | workflow_set( 114 | list(reg = mpg ~ .), 115 | list(knn = knn_spec, lm = lr_spec) 116 | ) 117 | 118 | expect_no_error({ 119 | res_quiet <- 120 | car_set_3 |> 121 | workflow_map(resamples = folds, seed = 2, verbose = FALSE, grid = "a") 122 | }) 123 | expect_true(inherits(res_quiet, "workflow_set")) 124 | expect_true(inherits(res_quiet$result[[1]], "try-error")) 125 | 126 | expect_snapshot( 127 | { 128 | res_loud <- 129 | car_set_3 |> 130 | workflow_map(resamples = folds, seed = 2, verbose = TRUE, grid = "a") 131 | }, 132 | transform = function(lines) { 133 | gsub("\\([0-9]+ms\\)", "(ms)", lines) 134 | } 135 | ) 136 | expect_true(inherits(res_loud, "workflow_set")) 
137 | expect_true(inherits(res_loud$result[[1]], "try-error")) 138 | }) 139 | 140 | test_that("workflow_map can handle cluster specifications", { 141 | skip_on_cran() 142 | skip_if_not_installed("tidyclust") 143 | library(tidyclust) 144 | library(recipes) 145 | 146 | set.seed(1) 147 | mtcars_tbl <- mtcars |> dplyr::select(where(is.numeric)) 148 | folds <- vfold_cv(mtcars_tbl, v = 3) 149 | 150 | wf_set_spec <- 151 | workflow_set( 152 | list(all = recipe(mtcars_tbl, ~.), some = ~ mpg + hp), 153 | list(km = k_means(num_clusters = tune())) 154 | ) 155 | 156 | wf_set_fit <- 157 | workflow_map(wf_set_spec, fn = "tune_cluster", resamples = folds) 158 | 159 | wf_set_fit 160 | }) 161 | 162 | test_that("fail informatively on mismatched spec/tuning function", { 163 | skip_on_cran() 164 | skip_if_not_installed("tidyclust") 165 | library(tidyclust) 166 | 167 | set.seed(1) 168 | mtcars_tbl <- mtcars |> dplyr::select(where(is.numeric)) 169 | folds <- vfold_cv(mtcars_tbl, v = 3) 170 | 171 | wf_set_1 <- 172 | workflow_set( 173 | list(reg = mpg ~ .), 174 | list( 175 | dt = decision_tree("regression", min_n = tune()), 176 | km = k_means(num_clusters = tune()) 177 | ) 178 | ) 179 | 180 | wf_set_2 <- 181 | workflow_set( 182 | list(reg = mpg ~ .), 183 | list( 184 | dt = decision_tree("regression", min_n = tune()), 185 | km = k_means(num_clusters = tune()), 186 | hc = hier_clust() 187 | ) 188 | ) 189 | 190 | wf_set_3 <- 191 | workflow_set( 192 | list(reg = mpg ~ .), 193 | list( 194 | dt = decision_tree("regression", min_n = tune()), 195 | nn = nearest_neighbor("regression", neighbors = tune()), 196 | km = k_means(num_clusters = tune()) 197 | ) 198 | ) 199 | 200 | # pass a cluster spec to `tune_grid()` 201 | expect_snapshot( 202 | error = TRUE, 203 | workflow_map(wf_set_1, resamples = folds) 204 | ) 205 | 206 | expect_snapshot( 207 | error = TRUE, 208 | workflow_map(wf_set_2, resamples = folds) 209 | ) 210 | 211 | # pass a model spec to `tune_cluster()` 212 | expect_snapshot( 213 | error = TRUE, 214 | workflow_map(wf_set_1, resamples = folds, fn = "tune_cluster") 215 | ) 216 | 217 | expect_snapshot( 218 | error = TRUE, 219 | workflow_map(wf_set_3, resamples = folds, fn = "tune_cluster") 220 | ) 221 | }) 222 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/evaluating-different-predictor-sets.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Evaluating different predictor sets" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Evaluating different predictor sets} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r} 11 | #| include: false 12 | knitr::opts_chunk$set( 13 | collapse = TRUE, 14 | eval = rlang::is_installed(c("modeldata", "recipes")), 15 | comment = "#>" 16 | ) 17 | library(parsnip) 18 | library(recipes) 19 | library(dplyr) 20 | library(workflowsets) 21 | library(ggplot2) 22 | theme_set(theme_bw() + theme(legend.position = "top")) 23 | ``` 24 | 25 | Workflow sets are collections of tidymodels workflow objects that are created as a set. A workflow object is a combination of a preprocessor (e.g. a formula or recipe) and a `parsnip` model specification. 
26 | 27 | For some problems, users might want to try different combinations of preprocessing options, models, and/or predictor sets. Instead of creating a large number of individual objects, a cohort of workflows can be created simultaneously. 28 | 29 | In this example, we'll fit the same model but specify different predictor sets in the preprocessor list. 30 | 31 | Let's take a look at the customer churn data from the `modeldata` package: 32 | 33 | ```{r} 34 | #| label: tidymodels 35 | data(mlc_churn, package = "modeldata") 36 | ncol(mlc_churn) 37 | ``` 38 | 39 | There are 19 predictors, mostly numeric. These include aspects of the customers' accounts, such as `number_customer_service_calls`. The outcome is a factor with two levels: "yes" and "no". 40 | 41 | We'll use a logistic regression to model the data. Since the data set is not small, we'll use basic 10-fold cross-validation to get resampled performance estimates. 42 | 43 | ```{r} 44 | #| label: churn-objects 45 | library(workflowsets) 46 | library(parsnip) 47 | library(rsample) 48 | library(dplyr) 49 | library(ggplot2) 50 | 51 | lr_model <- logistic_reg() |> set_engine("glm") 52 | 53 | set.seed(1) 54 | trn_tst_split <- initial_split(mlc_churn, strata = churn) 55 | 56 | # Resample the training set 57 | set.seed(1) 58 | folds <- vfold_cv(training(trn_tst_split), strata = churn) 59 | ``` 60 | 61 | We could make a single workflow that uses this model specification and a basic formula. However, in this application, we'd like to know which predictors are associated with the best area under the ROC curve. 62 | 63 | ```{r} 64 | #| label: churn-formulas 65 | formulas <- leave_var_out_formulas(churn ~ ., data = mlc_churn) 66 | length(formulas) 67 | 68 | formulas[["area_code"]] 69 | ``` 70 | 71 | We create our workflow set: 72 | 73 | ```{r} 74 | #| label: churn-wflow-sets 75 | churn_workflows <- 76 | workflow_set( 77 | preproc = formulas, 78 | models = list(logistic = lr_model) 79 | ) 80 | churn_workflows 81 | ``` 82 | 83 | Since we are using basic logistic regression, there is nothing to tune for these models. Instead of `tune_grid()`, we'll use `tune::fit_resamples()` by giving that function name as the first argument: 84 | 85 | ```{r} 86 | #| label: churn-wflow-set-fits 87 | churn_workflows <- 88 | churn_workflows |> 89 | workflow_map("fit_resamples", resamples = folds) 90 | churn_workflows 91 | ``` 92 | 93 | To assess the effect of each predictor, let's subtract the area under the ROC curve for each leave-one-variable-out model from the same metric for the full model. We'll match first by resampling ID, then compute the mean difference.
94 | 95 | ```{r} 96 | #| label: churn-metrics 97 | #| fig-width: 6 98 | #| fig-height: 5 99 | roc_values <- 100 | churn_workflows |> 101 | collect_metrics(summarize = FALSE) |> 102 | filter(.metric == "roc_auc") |> 103 | mutate(wflow_id = gsub("_logistic", "", wflow_id)) 104 | 105 | full_model <- 106 | roc_values |> 107 | filter(wflow_id == "everything") |> 108 | select(full_model = .estimate, id) 109 | 110 | differences <- 111 | roc_values |> 112 | filter(wflow_id != "everything") |> 113 | full_join(full_model, by = "id") |> 114 | mutate(performance_drop = full_model - .estimate) 115 | 116 | summary_stats <- 117 | differences |> 118 | group_by(wflow_id) |> 119 | summarize( 120 | std_err = sd(performance_drop) / sum(!is.na(performance_drop)), 121 | performance_drop = mean(performance_drop), 122 | lower = performance_drop - qnorm(0.975) * std_err, 123 | upper = performance_drop + qnorm(0.975) * std_err, 124 | .groups = "drop" 125 | ) |> 126 | mutate( 127 | wflow_id = factor(wflow_id), 128 | wflow_id = reorder(wflow_id, performance_drop) 129 | ) 130 | 131 | summary_stats |> filter(lower > 0) 132 | 133 | ggplot(summary_stats, aes(x = performance_drop, y = wflow_id)) + 134 | geom_point() + 135 | geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) + 136 | ylab("") 137 | ``` 138 | 139 | From this, there are a few predictors that, when not included in the model, have a significant effect on the performance metric. 140 | -------------------------------------------------------------------------------- /workflowsets.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | --------------------------------------------------------------------------------