├── .Rbuildignore ├── .github ├── .gitignore ├── CODE_OF_CONDUCT.md └── workflows │ ├── R-CMD-check-hard.yaml │ ├── R-CMD-check.yaml │ ├── lock.yaml │ ├── pkgdown.yaml │ ├── pr-commands.yaml │ └── test-coverage.yaml ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── control_race.R ├── control_sim_anneal.R ├── finetune-package.R ├── plot_race.R ├── racing_helpers.R ├── s3_register.R ├── sim_anneal_helpers.R ├── tune_race_anova.R ├── tune_race_win_loss.R ├── tune_sim_anneal.R └── zzz.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── air.toml ├── codecov.yml ├── cran-comments.md ├── docs └── CNAME ├── finetune.Rproj ├── inst ├── WORDLIST └── data-raw │ └── sa_cart_test_objects.R ├── man ├── collect_predictions.Rd ├── control_race.Rd ├── control_sim_anneal.Rd ├── figures │ └── logo.png ├── finetune-package.Rd ├── plot_race.Rd ├── rmd │ └── anova-benchmark.md ├── show_best.Rd ├── tune_race_anova.Rd ├── tune_race_win_loss.Rd └── tune_sim_anneal.Rd └── tests ├── spelling.R ├── testthat.R └── testthat ├── _snaps ├── anova-filter.md ├── anova-overall.md ├── race-control.md ├── sa-control.md ├── sa-misc.md ├── sa-overall.md └── win-loss-overall.md ├── helper.R ├── sa_cart_test_objects.RData ├── test-anova-filter.R ├── test-anova-overall.R ├── test-condense_control.R ├── test-race-control.R ├── test-race-s3.R ├── test-random-integer-neighbors.R ├── test-sa-control.R ├── test-sa-decision.R ├── test-sa-misc.R ├── test-sa-overall.R ├── test-sa-perturb.R ├── test-win-loss-filter.R └── test-win-loss-overall.R /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^CODE_OF_CONDUCT\.md$ 4 | ^LICENSE\.md$ 5 | ^README\.Rmd$ 6 | ^_pkgdown\.yml$ 7 | ^cran-comments\.md$ 8 | ^codecov\.yml$ 9 | ^\.github$ 10 | ^docs$ 11 | ^pkgdown$ 12 | ^revdep 13 | ^[\.]?air\.toml$ 14 | ^\.vscode$ 15 | 
-------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at codeofconduct@posit.co. 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 
98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | <https://www.contributor-covenant.org/version/2/1/code_of_conduct.html>. 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | <https://www.contributor-covenant.org/faq>. Translations are available at <https://www.contributor-covenant.org/translations>. 125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check-hard.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow only directly installs "hard" dependencies, i.e. Depends, 5 | # Imports, and LinkingTo dependencies. Notably, Suggests dependencies are never 6 | # installed, with the exception of testthat, knitr, and rmarkdown.
The cache is 7 | # never used to avoid accidentally restoring a cache containing a suggested 8 | # dependency. 9 | on: 10 | push: 11 | branches: [main, master] 12 | pull_request: 13 | 14 | name: R-CMD-check-hard.yaml 15 | 16 | permissions: read-all 17 | 18 | jobs: 19 | check-no-suggests: 20 | runs-on: ${{ matrix.config.os }} 21 | 22 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 23 | 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | config: 28 | - {os: ubuntu-latest, r: 'release'} 29 | 30 | env: 31 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 32 | R_KEEP_PKG_SOURCE: yes 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - uses: r-lib/actions/setup-pandoc@v2 38 | 39 | - uses: r-lib/actions/setup-r@v2 40 | with: 41 | r-version: ${{ matrix.config.r }} 42 | http-user-agent: ${{ matrix.config.http-user-agent }} 43 | use-public-rspm: true 44 | 45 | - uses: r-lib/actions/setup-r-dependencies@v2 46 | with: 47 | dependencies: '"hard"' 48 | cache: false 49 | extra-packages: | 50 | any::rcmdcheck 51 | any::testthat 52 | any::knitr 53 | any::rmarkdown 54 | needs: check 55 | 56 | - uses: r-lib/actions/check-r-package@v2 57 | with: 58 | upload-snapshots: true 59 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 60 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 
7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | 12 | name: R-CMD-check.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | R-CMD-check: 18 | runs-on: ${{ matrix.config.os }} 19 | 20 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | config: 26 | - {os: macos-latest, r: 'release'} 27 | 28 | - {os: windows-latest, r: 'release'} 29 | # use 4.0 or 4.1 to check with rtools40's older compiler 30 | - {os: windows-latest, r: 'oldrel-4'} 31 | 32 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 33 | - {os: ubuntu-latest, r: 'release'} 34 | - {os: ubuntu-latest, r: 'oldrel-1'} 35 | - {os: ubuntu-latest, r: 'oldrel-2'} 36 | - {os: ubuntu-latest, r: 'oldrel-3'} 37 | - {os: ubuntu-latest, r: 'oldrel-4'} 38 | 39 | env: 40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 41 | R_KEEP_PKG_SOURCE: yes 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - uses: r-lib/actions/setup-pandoc@v2 47 | 48 | - uses: r-lib/actions/setup-r@v2 49 | with: 50 | r-version: ${{ matrix.config.r }} 51 | http-user-agent: ${{ matrix.config.http-user-agent }} 52 | use-public-rspm: true 53 | 54 | - uses: r-lib/actions/setup-r-dependencies@v2 55 | with: 56 | extra-packages: any::rcmdcheck 57 | needs: check 58 | 59 | - uses: r-lib/actions/check-r-package@v2 60 | with: 61 | upload-snapshots: true 62 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 63 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: dessant/lock-threads@v2 12 | with: 13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | 
issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | ) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | ) and link to this issue. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | release: 8 | types: [published] 9 | workflow_dispatch: 10 | 11 | name: pkgdown.yaml 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | pkgdown: 17 | runs-on: ubuntu-latest 18 | # Only restrict concurrency for non-PR jobs 19 | concurrency: 20 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 21 | env: 22 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 23 | permissions: 24 | contents: write 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - uses: r-lib/actions/setup-pandoc@v2 29 | 30 | - uses: r-lib/actions/setup-r@v2 31 | with: 32 | use-public-rspm: true 33 | 34 | - uses: r-lib/actions/setup-r-dependencies@v2 35 | with: 36 | extra-packages: any::pkgdown, local::. 
37 | needs: website 38 | 39 | - name: Build site 40 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 41 | shell: Rscript {0} 42 | 43 | - name: Deploy to GitHub pages 🚀 44 | if: github.event_name != 'pull_request' 45 | uses: JamesIves/github-pages-deploy-action@v4.5.0 46 | with: 47 | clean: false 48 | branch: gh-pages 49 | folder: docs 50 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ 
secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | run: styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | 8 | name: test-coverage.yaml 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | test-coverage: 14 | runs-on: ubuntu-latest 15 | env: 16 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | with: 23 | use-public-rspm: true 24 | 25 | - uses: r-lib/actions/setup-r-dependencies@v2 26 | with: 27 | extra-packages: any::covr, any::xml2 28 | needs: coverage 29 | 30 | - name: Test coverage 31 | run: | 32 | cov <- covr::package_coverage( 33 | quiet = FALSE, 34 | clean = FALSE, 35 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 36 | ) 37 | print(cov) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v5 42 | with: 43 | # Fail if error if not on PR, or if on PR and token is given 44 | fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} 45 | files: ./cobertura.xml 46 | plugins: noop 47 | disable_search: true 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | 50 | - name: Show testthat output 51 | if: always() 52 | run: | 53 | ## -------------------------------------------------------------------- 54 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 55 | shell: bash 56 | 57 | - name: Upload test results 58 | if: failure() 59 | uses: actions/upload-artifact@v4 60 | with: 61 | name: coverage-test-failures 62 | path: ${{ runner.temp }}/package 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | .DS_Store 6 | docs 7 | revdep/* 8 | -------------------------------------------------------------------------------- 
/.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "Posit.air-vscode" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[r]": { 3 | "editor.formatOnSave": true, 4 | "editor.defaultFormatter": "Posit.air-vscode" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: finetune 2 | Title: Additional Functions for Model Tuning 3 | Version: 1.2.1.9000 4 | Authors@R: c( 5 | person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), 6 | comment = c(ORCID = "0000-0003-2402-136X")), 7 | person("Posit Software, PBC", role = c("cph", "fnd"), 8 | comment = c(ROR = "03wc8by49")) 9 | ) 10 | Description: The ability to tune models is important. 'finetune' enhances 11 | the 'tune' package by providing more specialized methods for finding 12 | reasonable values of model tuning parameters. Two racing methods 13 | described by Kuhn (2014) are included. An 14 | iterative search method using generalized simulated annealing (Bohachevsky, 15 | Johnson and Stein, 1986) is also 16 | included. 
17 | License: MIT + file LICENSE 18 | URL: https://github.com/tidymodels/finetune, 19 | https://finetune.tidymodels.org 20 | BugReports: https://github.com/tidymodels/finetune/issues 21 | Depends: 22 | R (>= 4.1), 23 | tune (>= 1.2.0) 24 | Imports: 25 | cli, 26 | dials (>= 0.3.0), 27 | dplyr (>= 1.1.1), 28 | ggplot2, 29 | parsnip (>= 1.1.0), 30 | purrr (>= 1.0.0), 31 | rlang, 32 | tibble, 33 | tidyr, 34 | tidyselect, 35 | utils, 36 | vctrs, 37 | workflows (>= 0.2.6) 38 | Suggests: 39 | BradleyTerry2, 40 | covr, 41 | discrim, 42 | kknn, 43 | klaR, 44 | lme4, 45 | modeldata, 46 | ranger, 47 | recipes (>= 0.2.0), 48 | rpart, 49 | rsample, 50 | spelling, 51 | testthat, 52 | yardstick 53 | Config/Needs/website: tidyverse/tidytemplate 54 | Config/testthat/edition: 3 55 | Config/usethis/last-upkeep: 2025-05-20 56 | Encoding: UTF-8 57 | Language: en-US 58 | Roxygen: list(markdown = TRUE) 59 | RoxygenNote: 7.3.2 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2025 2 | COPYRIGHT HOLDER: finetune authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 finetune authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(collect_metrics,tune_race) 4 | S3method(collect_predictions,tune_race) 5 | S3method(print,control_race) 6 | S3method(print,control_sim_anneal) 7 | S3method(show_best,tune_race) 8 | S3method(tune_race_anova,default) 9 | S3method(tune_race_anova,formula) 10 | S3method(tune_race_anova,model_spec) 11 | S3method(tune_race_anova,recipe) 12 | S3method(tune_race_anova,workflow) 13 | S3method(tune_race_win_loss,default) 14 | S3method(tune_race_win_loss,formula) 15 | S3method(tune_race_win_loss,model_spec) 16 | S3method(tune_race_win_loss,recipe) 17 | S3method(tune_race_win_loss,workflow) 18 | S3method(tune_sim_anneal,default) 19 | S3method(tune_sim_anneal,formula) 20 | S3method(tune_sim_anneal,model_spec) 21 | S3method(tune_sim_anneal,recipe) 22 | S3method(tune_sim_anneal,workflow) 23 | export(control_race) 24 | export(control_sim_anneal) 25 | export(plot_race) 26 | export(tune_race_anova) 27 | export(tune_race_win_loss) 28 | export(tune_sim_anneal) 29 | import(tune) 30 | importFrom(dplyr,distinct) 31 | importFrom(rlang,caller_env) 32 | importFrom(rlang,syms) 33 | importFrom(stats,coef) 34 | importFrom(stats,confint) 35 | importFrom(stats,dist) 36 | importFrom(stats,lm) 37 | importFrom(stats,qt) 38 | importFrom(stats,reorder) 39 | 
importFrom(stats,rnorm) 40 | importFrom(stats,runif) 41 | importFrom(stats,setNames) 42 | importFrom(utils,globalVariables) 43 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # finetune (development version) 2 | 3 | # finetune 1.2.1 4 | 5 | * Maintenance release required by CRAN. 6 | 7 | * Transition from the magrittr pipe to the base R pipe. 8 | 9 | # finetune 1.2.0 10 | 11 | ## New Features 12 | 13 | * finetune now fully supports models in the "censored regression" mode. These models can be fit, tuned, and evaluated like the regression and classification modes. [tidymodels.org](https://www.tidymodels.org/learn/#category=survival%20analysis) has more information and tutorials on how to work with survival analysis models. 14 | 15 | * Improved error message from `tune_sim_anneal()` when values in the supplied `param_info` do not encompass all values evaluated in the `initial` grid. This most often happens when a user mistakenly supplies different parameter sets to the function that generated the initial results and `tune_sim_anneal()`. 16 | 17 | * `autoplot()` methods for racing objects will now use integers in x-axis breaks (#75). 18 | 19 | * Enabling the `verbose_elim` control option for `tune_race_anova()` will now additionally introduce a message confirming that the function is evaluating against the burn-in resamples. 20 | 21 | * Updates based on the new version of tune, primarily for survival analysis models (#104). 22 | 23 | ## Bug Fixes 24 | 25 | * Fixed bug where `tune_sim_anneal()` would fail when supplied parameters needing finalization. The function will now finalize needed parameter ranges internally (#39). 26 | 27 | * Fixed bug where packages specified in `control_race(pkgs)` were not actually loaded in `tune_race_anova()` (#74). 28 | 29 | ## Breaking Change 30 | 31 | * Ellipses (...) 
are now used consistently in the package to require optional arguments to be named. `collect_predictions()`, `collect_metrics()` and `show_best()` methods previously had ellipses at the end of the function signature that have been moved to follow the last argument without a default value. Optional arguments previously passed by position will now error informatively prompting them to be named (#105). 32 | 33 | # finetune 1.1.0 34 | 35 | * Various minor changes to keep up with developments in the tune and dplyr packages (#60) (#62) (#67) (#68). 36 | 37 | * Corrects `.config` output with `save_pred = TRUE` in `tune_sim_anneal()`. The function previously outputted a constant `Model1_Preprocessor1` in the `.predictions` slot, and now provides `.config` values that align with those in `.metrics` (#57). 38 | 39 | * An `eval_time` attribute was added to tune objects produced by finetune. 40 | 41 | # finetune 1.0.1 42 | 43 | * For racing: 44 | - `collect_metrics()` and `collect_predictions()` have a `'complete'` argument that only returns results for model configurations that were fully resampled. 45 | - `select_best()` and `show_best()` now only show results for model configurations that were fully resampled. 46 | 47 | * `tune_race_anova()`, `tune_race_win_loss()`, and `tune_sim_anneal()` no longer error if `control` argument isn't the corresponding `control_*()` object. Will work as long as the object passed to `control` includes the same elements as the required `control_*()` object. 48 | 49 | * The `control_sim_anneal()` got a new argument `verbose_iter` that is used to control the verbosity of the iterative calculations. This change means that the `verbose` argument is being passed to `tune_grid()` to control its verbosity. 50 | 51 | 52 | # finetune 1.0.0 53 | 54 | * An informative error is given when there are not enough resamples for racing (#33). 55 | 56 | * `tune_sim_anneal()` was not passing all arguments to `tune_grid()` (#40).
57 | 58 | # finetune 0.2.0 59 | 60 | * Maintenance release for CRAN requirements. 61 | 62 | * Use `extract_parameter_set_dials()` instead of `parameters()` to get parameter sets. 63 | 64 | * Removed some pillar-related S3 methods that currently live in tune. 65 | 66 | # finetune 0.1.1 67 | 68 | * `tune_sim_anneal()` only overwrites tuning parameter information when they originally contain unknowns. 69 | 70 | # finetune 0.1.0 71 | 72 | * A check was added to make sure that `lme4` or `BradleyTerry2` are installed (#8) 73 | 74 | * Added `pillar` methods for formatting `tune` objects in list columns. 75 | 76 | * Fixed bug in `random_integer_neighbor_calc()` to keep values inside range (#10) 77 | 78 | * `tune_sim_anneal()` now retains a finalized parameter set and replaces any existing parameter set that was not finalized (#14) 79 | 80 | * A bug in win/loss racing was fixed for cases when one tuning parameter had results that were so bad that it broke the Bradley-Terry model (#7) 81 | 82 | # finetune 0.0.1 83 | 84 | * First CRAN release 85 | -------------------------------------------------------------------------------- /R/control_race.R: -------------------------------------------------------------------------------- 1 | #' Control aspects of the grid search racing process 2 | #' 3 | #' @inheritParams tune::control_grid 4 | #' @param verbose_elim A logical for whether logging of the elimination of 5 | #' tuning parameter combinations should occur. 6 | #' @param burn_in An integer for how many resamples should be completed for all 7 | #' grid combinations before parameter filtering begins. 8 | #' @param num_ties An integer for when tie-breaking should occur. If there are 9 | #' two final parameter combinations being evaluated, `num_ties` specifies how 10 | #' many more resampling iterations should be evaluated. After `num_ties` more 11 | #' iterations, the parameter combination with the current best results is 12 | #' retained.
13 | #' @param alpha The alpha level for a one-sided confidence interval for each 14 | #' parameter combination. 15 | #' @param randomize Should the resamples be evaluated in a random order? By 16 | #' default, the resamples are evaluated in a random order so the random number 17 | #' seed should be controlled prior to calling this method (to be reproducible). 18 | #' For repeated cross-validation the randomization occurs within each repeat. 19 | #' @return An object of class `control_race` that echoes the argument values. 20 | #' @examples 21 | #' control_race() 22 | #' @export 23 | control_race <- 24 | function( 25 | verbose = FALSE, 26 | verbose_elim = FALSE, 27 | allow_par = TRUE, 28 | extract = NULL, 29 | save_pred = FALSE, 30 | burn_in = 3, 31 | num_ties = 10, 32 | alpha = 0.05, 33 | randomize = TRUE, 34 | pkgs = NULL, 35 | save_workflow = FALSE, 36 | event_level = "first", 37 | parallel_over = "everything", 38 | backend_options = NULL 39 | ) { 40 | # Any added arguments should also be added in superset control functions 41 | # in other package.
In other words, if tune_grid adds an option, the same 42 | # object should be added here (regardless) 43 | 44 | tune::val_class_and_single(verbose, "logical", "control_race()") 45 | tune::val_class_and_single(verbose_elim, "logical", "control_race()") 46 | tune::val_class_and_single(allow_par, "logical", "control_race()") 47 | tune::val_class_and_single(alpha, "numeric", "control_race()") 48 | tune::val_class_and_single(burn_in, "numeric", "control_race()") 49 | tune::val_class_and_single(randomize, "logical", "control_race()") 50 | tune::val_class_and_single(num_ties, "numeric", "control_race()") 51 | tune::val_class_and_single(save_pred, "logical", "control_race()") 52 | tune::val_class_or_null(pkgs, "character", "control_race()") 53 | tune::val_class_and_single(event_level, "character", "control_race()") 54 | tune::val_class_or_null(extract, "function", "control_race()") 55 | tune::val_class_and_single(save_workflow, "logical", "control_race()") 56 | if (!is.null(parallel_over)) { 57 | val_parallel_over(parallel_over, "control_sim_anneal()") 58 | } 59 | 60 | if (alpha <= 0 | alpha >= 1) { 61 | cli::cli_abort("{.arg alpha} should be on (0, 1).") 62 | } 63 | 64 | if (burn_in < 2) { 65 | cli::cli_abort("{.arg burn_in} should be at least two.") 66 | } 67 | 68 | res <- list( 69 | verbose = verbose, 70 | verbose_elim = verbose_elim, 71 | allow_par = allow_par, 72 | extract = extract, 73 | save_pred = save_pred, 74 | alpha = alpha, 75 | burn_in = burn_in, 76 | num_ties = num_ties, 77 | randomize = randomize, 78 | pkgs = pkgs, 79 | save_workflow = save_workflow, 80 | parallel_over = parallel_over, 81 | event_level = event_level, 82 | backend_options = backend_options 83 | ) 84 | 85 | class(res) <- c("control_race") 86 | res 87 | } 88 | 89 | #' @export 90 | print.control_race <- function(x, ...) 
{ 91 | cat("Racing method control object\n") 92 | invisible(x) 93 | } 94 | -------------------------------------------------------------------------------- /R/control_sim_anneal.R: -------------------------------------------------------------------------------- 1 | #' Control aspects of the simulated annealing search process 2 | #' @inheritParams tune::control_grid 3 | #' @param verbose_iter A logical for logging results of the search 4 | #' process. Defaults to FALSE. If using a dark IDE theme, some logging 5 | #' messages might be hard to see; try setting the `tidymodels.dark` option 6 | #' with `options(tidymodels.dark = TRUE)` to print lighter colors. 7 | #' @param no_improve The integer cutoff for the number of iterations without 8 | #' better results. 9 | #' @param restart The number of iterations with no improvement before new tuning 10 | #' parameter candidates are generated from the last, overall best conditions. 11 | #' @param radius Two real numbers on `(0, 1)` describing what a value "in the 12 | #' neighborhood" of the current result should be. If all numeric parameters were 13 | #' scaled to be on the `[0, 1]` scale, these values set the min. and max. 14 | #' of a radius of a circle used to generate new numeric parameter values. 15 | #' @param flip A real number between `[0, 1]` for the probability of changing 16 | #' any non-numeric parameter values at each iteration. 17 | #' @param cooling_coef A real, positive number to influence the cooling 18 | #' schedule. Larger values decrease the probability of accepting a sub-optimal 19 | #' parameter setting. 20 | #' @param time_limit A number for the minimum number of _minutes_ (elapsed) that 21 | #' the function should execute. The elapsed time is evaluated at internal 22 | #' checkpoints and, if over time, the results at that time are returned (with 23 | #' a warning). This means that the `time_limit` is not an exact limit, but a 24 | #' minimum time limit. 
#' @param save_history A logical to save the iteration details of the search.
#' These are saved to `tempdir()` named `sa_history.RData`. These results are
#' deleted when the R session ends. This option is only useful for teaching
#' purposes.
#' @return An object of class `control_sim_anneal` that echoes the argument values.
#' @examples
#' control_sim_anneal()
#' @export
control_sim_anneal <-
  function(
    verbose = FALSE,
    verbose_iter = TRUE,
    no_improve = Inf,
    restart = 8L,
    radius = c(0.05, 0.15),
    flip = 3 / 4,
    cooling_coef = 0.02,
    extract = NULL,
    save_pred = FALSE,
    time_limit = NA,
    pkgs = NULL,
    save_workflow = FALSE,
    save_history = FALSE,
    event_level = "first",
    parallel_over = NULL,
    allow_par = TRUE,
    backend_options = NULL
  ) {
    # Any added arguments should also be added in superset control functions
    # in other package. In other words, if tune_grid adds an option, the same
    # object should be added here (regardless)

    # Validate the class and length of each option before tuning begins.
    tune::val_class_and_single(verbose, "logical", "control_sim_anneal()")
    tune::val_class_and_single(verbose_iter, "logical", "control_sim_anneal()")
    tune::val_class_and_single(save_pred, "logical", "control_sim_anneal()")
    tune::val_class_and_single(
      no_improve,
      c("numeric", "integer"),
      "control_sim_anneal()"
    )
    tune::val_class_and_single(
      restart,
      c("numeric", "integer"),
      "control_sim_anneal()"
    )
    tune::val_class_and_single(flip, "numeric", "control_sim_anneal()")
    tune::val_class_and_single(cooling_coef, "numeric", "control_sim_anneal()")
    tune::val_class_or_null(extract, "function", "control_sim_anneal()")
    # `time_limit` defaults to NA (logical), so both types are allowed.
    tune::val_class_and_single(
      time_limit,
      c("logical", "numeric"),
      "control_sim_anneal()"
    )
    tune::val_class_or_null(pkgs, "character", "control_sim_anneal()")
    tune::val_class_and_single(save_workflow, "logical", "control_sim_anneal()")
    tune::val_class_and_single(save_history, "logical", "control_sim_anneal()")
    tune::val_class_and_single(allow_par, "logical", "control_sim_anneal()")
    # Consistency fix: `event_level` was accepted but never validated here,
    # while the sibling control_race() checks it; validate it the same way.
    tune::val_class_and_single(event_level, "character", "control_sim_anneal()")

    if (!is.null(parallel_over)) {
      val_parallel_over(parallel_over, "control_sim_anneal()")
    }

    # `radius` must be a pair of numbers; after sorting, clamp both ends
    # into the open unit interval rather than erroring.
    if (!is.numeric(radius) || length(radius) != 2) {
      cli::cli_abort("Argument {.arg radius} should be two numeric values.")
    }
    radius <- sort(radius)
    radius[radius <= 0] <- 0.001
    radius[radius >= 1] <- 0.999

    # Probabilities and the cooling coefficient are silently clamped into
    # valid ranges instead of aborting.
    flip[flip < 0] <- 0
    flip[flip > 1] <- 1
    cooling_coef[cooling_coef <= 0] <- 0.0001

    if (no_improve < 2) {
      cli::cli_abort("{.arg no_improve} should be > 1.")
    }
    if (restart < 2) {
      cli::cli_abort("{.arg restart} should be > 1.")
    }
    # A restart scheduled after the stopping point can never trigger; warn.
    if (!is.infinite(restart) && restart > no_improve) {
      cli::cli_alert_warning(
        "Parameter restart is scheduled after {restart} poor iterations but the search will stop after {no_improve}."
      )
    }

    res <-
      list(
        verbose = verbose,
        verbose_iter = verbose_iter,
        no_improve = no_improve,
        restart = restart,
        radius = radius,
        flip = flip,
        cooling_coef = cooling_coef,
        extract = extract,
        save_pred = save_pred,
        time_limit = time_limit,
        pkgs = pkgs,
        save_workflow = save_workflow,
        save_history = save_history,
        event_level = event_level,
        parallel_over = parallel_over,
        allow_par = allow_par,
        backend_options = backend_options
      )

    class(res) <- "control_sim_anneal"
    res
  }

#' @export
print.control_sim_anneal <- function(x, ...)
{ 137 | cat("Simulated annealing control object\n") 138 | invisible(x) 139 | } 140 | 141 | 142 | val_parallel_over <- function(parallel_over, where) { 143 | val_class_and_single(parallel_over, "character", where) 144 | rlang::arg_match0( 145 | parallel_over, 146 | c("resamples", "everything"), 147 | "parallel_over" 148 | ) 149 | invisible(NULL) 150 | } 151 | -------------------------------------------------------------------------------- /R/finetune-package.R: -------------------------------------------------------------------------------- 1 | #' @keywords internal 2 | "_PACKAGE" 3 | 4 | ## usethis namespace: start 5 | ## usethis namespace: end 6 | NULL 7 | 8 | # ------------------------------------------------------------------------------ 9 | 10 | #' @importFrom stats qt runif coef confint lm reorder rnorm setNames dist 11 | #' @importFrom utils globalVariables 12 | #' @importFrom rlang syms caller_env 13 | #' @importFrom dplyr distinct 14 | #' @importFrom utils globalVariables 15 | #' @import tune 16 | NULL 17 | 18 | # ------------------------------------------------------------------------------ 19 | 20 | # fmt: skip 21 | utils::globalVariables( 22 | c( 23 | ".config", ".estimate", ".iter", ".metric", ".parent", "B", "Estimate", 24 | "Std. Error", "lower", "metric_1", "metric_2", "n", "no_improve", "p1", 25 | "p2", "pair", "pass", "player", "player_1", "player_2", "std_err", "upper", 26 | "value", "wins", "wins_1", "wins_2", ".metrics", ".order", "id", "new", 27 | "orig", "stage", "symb", "id2", ".rand", ".eval_time" 28 | ) 29 | ) 30 | -------------------------------------------------------------------------------- /R/plot_race.R: -------------------------------------------------------------------------------- 1 | #' Plot racing results 2 | #' 3 | #' Plot the model results over stages of the racing results. A line is given 4 | #' for each submodel that was tested. 5 | #' @param x A object with class `tune_results` 6 | #' @return A ggplot object. 
7 | #' @export 8 | plot_race <- function(x) { 9 | metric <- tune::.get_tune_metric_names(x)[1] 10 | ex_mtrc <- collect_metrics(x) 11 | 12 | if (any(names(ex_mtrc) == ".eval_time")) { 13 | eval_time <- min(ex_mtrc$.eval_time, na.rm = TRUE) 14 | } else { 15 | eval_time <- NULL 16 | } 17 | 18 | rs <- 19 | x |> 20 | dplyr::select(id, .order, .metrics) |> 21 | tidyr::unnest(cols = .metrics) |> 22 | dplyr::filter(.metric == metric) 23 | 24 | if (!is.null(eval_time) && any(names(rs) == ".eval_time")) { 25 | rs <- dplyr::filter(rs, .eval_time == eval_time) 26 | } 27 | 28 | .order <- sort(unique(rs$.order)) 29 | purrr::map(.order, \(x) stage_results(x, rs)) |> 30 | purrr::list_rbind() |> 31 | ggplot2::ggplot(ggplot2::aes( 32 | x = stage, 33 | y = mean, 34 | group = .config, 35 | col = .config 36 | )) + 37 | ggplot2::geom_line(alpha = .5, show.legend = FALSE) + 38 | ggplot2::xlab("Analysis Stage") + 39 | ggplot2::ylab(metric) + 40 | ggplot2::scale_x_continuous(breaks = integer_breaks) 41 | } 42 | 43 | integer_breaks <- function(lims) { 44 | breaks <- pretty(lims) 45 | 46 | unique(round(breaks)) 47 | } 48 | 49 | stage_results <- function(ind, x) { 50 | res <- 51 | x |> 52 | dplyr::filter(.order <= ind) |> 53 | dplyr::group_by(.config) |> 54 | dplyr::summarize( 55 | mean = mean(.estimate, na.rm = TRUE), 56 | n = sum(!is.na(.estimate)), 57 | .groups = "drop" 58 | ) |> 59 | dplyr::mutate(stage = ind) |> 60 | dplyr::ungroup() |> 61 | dplyr::filter(n == ind) 62 | } 63 | -------------------------------------------------------------------------------- /R/s3_register.R: -------------------------------------------------------------------------------- 1 | # nocov start 2 | s3_register <- function(generic, class, method = NULL) { 3 | stopifnot(is.character(generic), length(generic) == 1) 4 | stopifnot(is.character(class), length(class) == 1) 5 | 6 | pieces <- strsplit(generic, "::")[[1]] 7 | stopifnot(length(pieces) == 2) 8 | package <- pieces[[1]] 9 | generic <- pieces[[2]] 10 | 11 | 
caller <- parent.frame() 12 | 13 | get_method_env <- function() { 14 | top <- topenv(caller) 15 | if (isNamespace(top)) { 16 | asNamespace(environmentName(top)) 17 | } else { 18 | caller 19 | } 20 | } 21 | get_method <- function(method, env) { 22 | if (is.null(method)) { 23 | get(paste0(generic, ".", class), envir = get_method_env()) 24 | } else { 25 | method 26 | } 27 | } 28 | 29 | method_fn <- get_method(method) 30 | stopifnot(is.function(method_fn)) 31 | 32 | # Always register hook in case package is later unloaded & reloaded 33 | setHook( 34 | packageEvent(package, "onLoad"), 35 | function(...) { 36 | ns <- asNamespace(package) 37 | 38 | # Refresh the method, it might have been updated by `devtools::load_all()` 39 | method_fn <- get_method(method) 40 | 41 | registerS3method(generic, class, method_fn, envir = ns) 42 | } 43 | ) 44 | 45 | # Avoid registration failures during loading (pkgload or regular) 46 | if (!isNamespaceLoaded(package)) { 47 | return(invisible()) 48 | } 49 | 50 | envir <- asNamespace(package) 51 | 52 | # Only register if generic can be accessed 53 | if (exists(generic, envir)) { 54 | registerS3method(generic, class, method_fn, envir = envir) 55 | } 56 | 57 | invisible() 58 | } 59 | 60 | # nocov end 61 | -------------------------------------------------------------------------------- /R/sim_anneal_helpers.R: -------------------------------------------------------------------------------- 1 | maximize_metric <- function(x, metric) { 2 | metrics <- tune::.get_tune_metrics(x) 3 | metrics_data <- tune::metrics_info(metrics) 4 | x <- metrics_data$.metric[1] 5 | metrics_data$direction[metrics_data$.metric == metric] == "maximize" 6 | } 7 | 8 | # Might not use this function 9 | treat_as_integer <- function(x, num_unique = 10) { 10 | param_type <- purrr::map_chr(x$object, \(x) x$type) 11 | is_int <- param_type == "integer" 12 | x_vals <- purrr::map(x$object, \(x) dials::value_seq(x, n = 200)) 13 | x_vals <- purrr::map_int(x_vals, \(x) 
length(unique(x))) 14 | x_vals < num_unique & is_int 15 | } 16 | 17 | new_in_neighborhood <- function( 18 | current, 19 | hist_values, 20 | pset, 21 | radius = c(0.05, 0.15), 22 | flip = 0.1 23 | ) { 24 | current <- dplyr::select(current, !!!pset$id) 25 | param_type <- purrr::map_chr(pset$object, \(x) x$type) 26 | if (any(param_type == "double")) { 27 | dbl_nms <- pset$id[param_type == "double"] 28 | new_dbl <- 29 | random_real_neighbor( 30 | current |> dplyr::select(dplyr::all_of(dbl_nms)), 31 | hist_values = hist_values |> dplyr::select(dplyr::all_of(dbl_nms)), 32 | pset |> dplyr::filter(id %in% dbl_nms), 33 | r = radius 34 | ) 35 | current[, dbl_nms] <- new_dbl 36 | } 37 | 38 | if (any(param_type == "integer")) { 39 | int_nms <- pset$id[param_type == "integer"] 40 | flip_one <- all(param_type == "integer") 41 | new_int <- 42 | random_integer_neighbor( 43 | current |> dplyr::select(dplyr::all_of(int_nms)), 44 | hist_values = hist_values |> dplyr::select(dplyr::all_of(int_nms)), 45 | pset |> dplyr::filter(id %in% int_nms), 46 | prob = flip, 47 | change = flip_one 48 | ) 49 | current[, int_nms] <- new_int 50 | } 51 | 52 | if (any(param_type == "character")) { 53 | chr_nms <- pset$id[param_type == "character"] 54 | flip_one <- all(param_type == "character") 55 | new_chr <- 56 | random_discrete_neighbor( 57 | current |> dplyr::select(!!!chr_nms), 58 | pset |> dplyr::filter(id %in% chr_nms), 59 | prob = flip, 60 | change = flip_one 61 | ) 62 | current[, chr_nms] <- new_chr 63 | } 64 | current 65 | } 66 | 67 | random_discrete_neighbor <- function(current, pset, prob, change) { 68 | pnames <- pset$id 69 | change_val <- runif(length(pnames)) <= prob 70 | if (change & !any(change_val)) { 71 | change_val[sample(seq_along(change_val), 1)] <- TRUE 72 | } 73 | if (any(change_val)) { 74 | new_vals <- pnames[change_val] 75 | for (i in new_vals) { 76 | current_val <- current[[i]] 77 | parm_obj <- pset$object[[which(pset$id == i)]] 78 | parm_obj$values <- setdiff(parm_obj$values, 
current_val) 79 | current[[i]] <- dials::value_sample(parm_obj, 1) 80 | } 81 | } 82 | current 83 | } 84 | 85 | 86 | random_integer_neighbor <- function( 87 | current, 88 | hist_values, 89 | pset, 90 | prob, 91 | change, 92 | retain = 1, 93 | tries = 500 94 | ) { 95 | candidates <- 96 | purrr::map( 97 | 1:tries, 98 | \(x) random_integer_neighbor_calc(current, pset, prob, change) 99 | ) |> 100 | purrr::list_rbind() 101 | 102 | rnd <- tune::encode_set(candidates, pset, as_matrix = TRUE) 103 | sample_by_distance(rnd, hist_values, retain = retain, pset = pset) 104 | } 105 | 106 | random_integer_neighbor_calc <- function(current, pset, prob, change) { 107 | change_val <- runif(nrow(pset)) <= prob 108 | if (change & !any(change_val)) { 109 | change_val[sample(seq_along(change_val), 1)] <- TRUE 110 | } 111 | if (any(change_val)) { 112 | param_change <- pset$id[change_val] 113 | for (i in param_change) { 114 | prm <- pset$object[[which(pset$id == i)]] 115 | prm_rng <- prm$range$upper - prm$range$lower 116 | tries <- min(prm_rng + 1, 500) 117 | pool <- dials::value_seq(prm, n = tries) 118 | smol_range <- floor(prm_rng / 10) + 1 119 | val_diff <- abs(current[[i]] - pool) 120 | pool <- pool[val_diff <= smol_range & val_diff > 0] 121 | if (length(pool) > 1) { 122 | current[[i]] <- sample(pool, 1) 123 | } else if (length(pool) == 1) { 124 | current[[i]] <- pool 125 | } 126 | } 127 | } 128 | current 129 | } 130 | 131 | random_real_neighbor <- function( 132 | current, 133 | hist_values, 134 | pset, 135 | retain = 1, 136 | tries = 500, 137 | r = c(0.05, 0.15) 138 | ) { 139 | is_quant <- purrr::map_lgl(pset$object, inherits, "quant_param") 140 | current <- current[, is_quant] 141 | pset <- pset[is_quant, ] 142 | encoded <- tune::encode_set(current, pset, as_matrix = TRUE) 143 | 144 | num_param <- ncol(encoded) 145 | if (num_param > 1) { 146 | rnd <- rnorm(num_param * tries) 147 | rnd <- matrix(rnd, ncol = num_param) 148 | rnd <- t(apply(rnd, 1, function(x) x / sqrt(sum(x^2)))) 149 | 
rnd <- rnd * runif(tries, min = min(r), max = max(r)) 150 | rnd <- sweep(rnd, 2, as.vector(encoded), "+") 151 | outside <- apply(rnd, 1, function(x) any(x > 1 | x < 0)) 152 | rnd <- rnd[!outside, , drop = FALSE] 153 | } else { 154 | rnd <- runif(tries, min = -max(r), max = max(r)) + encoded[[1]] 155 | rnd <- ifelse(rnd > 1, 1, rnd) 156 | rnd <- ifelse(rnd < 0, 0, rnd) 157 | rnd <- matrix(rnd, ncol = 1) 158 | rnd <- rnd[!duplicated(rnd), , drop = FALSE] 159 | } 160 | colnames(rnd) <- names(current) 161 | retain <- min(retain, nrow(rnd)) 162 | 163 | sample_by_distance(rnd, hist_values, retain = retain, pset = pset) 164 | } 165 | 166 | encode_set_backwards <- function(x, pset, ...) { 167 | pset <- pset[pset$id %in% names(x), ] 168 | mapply( 169 | check_backwards_encode, 170 | pset$object, 171 | x, 172 | pset$id, 173 | SIMPLIFY = FALSE, 174 | USE.NAMES = FALSE 175 | ) 176 | new_vals <- purrr::map2( 177 | pset$object, 178 | x, 179 | dials::encode_unit, 180 | direction = "backward" 181 | ) 182 | names(new_vals) <- names(x) 183 | tibble::as_tibble(new_vals) 184 | } 185 | 186 | check_backwards_encode <- function(x, value, id) { 187 | if (!dials::has_unknowns(x)) { 188 | compl <- value[!is.na(value)] 189 | if (any(compl < 0) | any(compl > 1)) { 190 | cli::cli_abort( 191 | c( 192 | "!" = "The range for parameter {.val {noquote(id)}} used when \\ 193 | generating initial results isn't compatible with the range \\ 194 | supplied in {.arg param_info}.", 195 | "i" = "Possible values of parameters in {.arg param_info} should \\ 196 | encompass all values evaluated in the initial grid." 
197 | ), 198 | call = rlang::call2("tune_sim_anneal()") 199 | ) 200 | } 201 | } 202 | } 203 | 204 | sample_by_distance <- function(candidates, existing, retain, pset) { 205 | if (nrow(existing) > 0) { 206 | existing <- tune::encode_set(existing, pset, as_matrix = TRUE) 207 | hist_index <- 1:nrow(existing) 208 | all_values <- rbind(existing, candidates) 209 | all_values <- stats::dist(all_values) 210 | all_values <- as.matrix(all_values) 211 | all_values <- all_values[hist_index, -hist_index, drop = FALSE] 212 | min_dist <- apply(all_values, 2, min) 213 | min_dist <- min_dist / max(min_dist) 214 | prob_wt <- min_dist^2 215 | prob_wt[is.na(prob_wt)] <- 0.0001 216 | 217 | if (diff(range(prob_wt)) < 0.0001) { 218 | prob_wt <- rep(1 / nrow(candidates), nrow(candidates)) 219 | } 220 | } else { 221 | prob_wt <- rep(1 / nrow(candidates), nrow(candidates)) 222 | } 223 | retain <- min(retain, nrow(candidates)) 224 | 225 | candidates <- tibble::as_tibble(candidates) 226 | candidates <- encode_set_backwards(candidates, pset) 227 | 228 | selected <- sample(seq_along(prob_wt), size = retain, prob = prob_wt) 229 | candidates[selected, ] 230 | } 231 | 232 | ## ----------------------------------------------------------------------------- 233 | 234 | update_history <- function(history, x, iter, eval_time) { 235 | analysis_metric <- tune::.get_tune_metric_names(x)[1] 236 | res <- 237 | tune::show_best(x, metric = analysis_metric, eval_time = eval_time) |> 238 | dplyr::mutate( 239 | .config = paste0("iter", iter), 240 | .iter = iter, 241 | random = runif(1), 242 | accept = NA_real_, 243 | results = NA_character_ 244 | ) 245 | if (is.null(history)) { 246 | history <- res 247 | } else { 248 | history <- dplyr::bind_rows(history, res) 249 | } 250 | 251 | if (maximize_metric(x, analysis_metric)) { 252 | best_res <- which.max(history$mean) 253 | } else { 254 | best_res <- which.min(history$mean) 255 | } 256 | 257 | history$global_best <- FALSE 258 | history$global_best[best_res] <- TRUE 
259 | history 260 | } 261 | 262 | sa_decide <- function(x, parent, metric, maximize, coef) { 263 | res <- dplyr::filter(x, .metric == metric) 264 | latest_ind <- which.max(res$.iter) 265 | prev_ind <- which(res$.config == parent) 266 | prev_metric <- res$mean[prev_ind] 267 | latest_metric <- res$mean[latest_ind] 268 | all_prev <- res$mean[1:prev_ind] 269 | 270 | if (maximize) { 271 | is_best <- latest_metric > max(all_prev, na.rm = TRUE) 272 | is_better <- isTRUE(latest_metric > prev_metric) 273 | } else { 274 | is_best <- latest_metric < min(all_prev, na.rm = TRUE) 275 | is_better <- isTRUE(latest_metric < prev_metric) 276 | } 277 | 278 | m <- nrow(x) 279 | 280 | x$accept[m] <- 281 | acceptance_prob( 282 | current = prev_metric, 283 | new = latest_metric, 284 | iter = max(x$.iter), 285 | maximize = maximize, 286 | coef = coef 287 | ) 288 | 289 | if (is_best) { 290 | x$results[m] <- "new best" 291 | x$random[m] <- x$accept[m] <- NA_real_ 292 | } else if (is_better) { 293 | x$results[m] <- "better suboptimal" 294 | x$random[m] <- x$accept[m] <- NA_real_ 295 | } else { 296 | if (x$random[m] <= x$accept[m]) { 297 | x$results[m] <- "accept suboptimal" 298 | } else { 299 | x$results[m] <- "discard suboptimal" 300 | } 301 | } 302 | x 303 | } 304 | 305 | initialize_history <- function(x, eval_time = NULL, ...) 
{ 306 | # check to see if there is existing history 307 | res <- 308 | tune::collect_metrics(x) |> 309 | dplyr::filter(.metric == tune::.get_tune_metric_names(x)[1]) 310 | if (!is.na(eval_time) && any(names(res) == ".eval_time")) { 311 | res <- res |> dplyr::filter(.eval_time == eval_time) 312 | } 313 | 314 | if (!any(names(res) == ".iter")) { 315 | res$.iter <- 0 316 | } 317 | 318 | res <- 319 | res |> 320 | dplyr::mutate( 321 | random = NA_real_, 322 | accept = NA_real_, 323 | results = "initial" 324 | ) 325 | res 326 | } 327 | 328 | 329 | percent_diff <- function(current, new, maximize = TRUE) { 330 | if (isTRUE(all.equal(current, new))) { 331 | return(0.0) 332 | } 333 | if (maximize) { 334 | pct_diff <- (new - current) / current 335 | } else { 336 | pct_diff <- (current - new) / current 337 | } 338 | pct_diff * 100 339 | } 340 | 341 | acceptance_prob <- function( 342 | current, 343 | new, 344 | iter, 345 | maximize = TRUE, 346 | coef = 2 / 100 347 | ) { 348 | pct_diff <- percent_diff(current, new, maximize) 349 | if (pct_diff > 0) { 350 | return(1.0) 351 | } 352 | exp(pct_diff * coef * iter) 353 | } 354 | 355 | log_sa_progress <- function( 356 | control = list(verbose_iter = TRUE), 357 | x, 358 | metric, 359 | max_iter, 360 | maximize = TRUE, 361 | digits = 5 362 | ) { 363 | if (!control$verbose_iter) { 364 | return(invisible(NULL)) 365 | } 366 | is_initial <- all(x$results == "initial") 367 | if (is_initial) { 368 | m <- max(which(x$global_best)) 369 | new_res <- x$mean[m] 370 | new_std <- x$std_err[m] 371 | new_event <- x$results[m] 372 | } else { 373 | m <- nrow(x) 374 | new_res <- x$mean[m] 375 | new_std <- x$std_err[m] 376 | new_event <- x$results[m] 377 | } 378 | iter <- max(x$.iter) 379 | if (iter > 0 & !is_initial) { 380 | is_best <- isTRUE(x$global_best[m]) 381 | prev_res <- x$mean[m - 1] 382 | pct_diff <- percent_diff(prev_res, new_res, maximize) * 100 383 | pct_diff <- sprintf("%6.2f", pct_diff) 384 | } else { 385 | is_best <- FALSE 386 | pct_diff <- 
NA_real_ 387 | } 388 | 389 | chr_iter <- format(1:max_iter)[iter] 390 | dig <- paste0("%.", digits, "f") 391 | 392 | cols <- tune::get_tune_colors() 393 | if (iter > 0) { 394 | msg <- paste0(metric, "=", signif(new_res, digits = digits)) 395 | if (!is.na(new_std) && new_std > 0) { 396 | msg <- paste0(msg, "\t(+/-", signif(new_std, digits = digits - 1), ")") 397 | } 398 | msg <- paste(chr_iter, format_event(new_event), msg) 399 | } else { 400 | if (maximize) { 401 | initial_res <- max(x$mean[x$.iter == 0], na.rm = TRUE) 402 | } else { 403 | initial_res <- min(x$mean[x$.iter == 0], na.rm = TRUE) 404 | } 405 | msg <- paste0( 406 | "Initial best: ", 407 | sprintf(dig, signif(initial_res, digits = digits)) 408 | ) 409 | } 410 | 411 | cli::cli_bullets(cols$message$info(msg)) 412 | } 413 | 414 | # fmt: skip 415 | format_event <- function(x) { 416 | result_key <- tibble::tribble( 417 | ~orig, ~symb, 418 | "initial", cli::symbol$tick, 419 | "new best", cli::symbol$heart, 420 | "better suboptimal", "+", 421 | "discard suboptimal", cli::symbol$line, 422 | "accept suboptimal", cli::symbol$circle, 423 | "restart from best", cli::symbol$cross 424 | ) |> 425 | dplyr::mutate( 426 | new = format(orig, justify = "left"), 427 | new = gsub(" ", "\u00a0", new, fixed = TRUE), 428 | result = paste(symb, new) 429 | ) 430 | color_event(result_key$result[result_key$orig == x]) 431 | } 432 | 433 | color_event <- function(x) { 434 | cols <- tune::get_tune_colors() 435 | dplyr::case_when( 436 | grepl("initial", x) ~ cols$symbol$info(x), 437 | grepl("new", x) ~ cols$symbol$success(x), 438 | grepl("better", x) ~ cols$symbol$success(x), 439 | grepl("discard", x) ~ cols$message$danger(x), 440 | grepl("accept", x) ~ cols$message$warning(x), 441 | grepl("restart", x) ~ cols$message$danger(x), 442 | TRUE ~ cols$message$info(x) 443 | ) 444 | } 445 | 446 | get_outcome_names <- function(x, rs) { 447 | preproc <- extract_preprocessor(x) 448 | if (inherits(preproc, "workflow_variables")) { 449 | if 
(any(names(preproc) == "outcomes")) { 450 | dat <- rs$splits[[1]]$data 451 | res <- tidyselect::eval_select(preproc$outcomes, data = dat) 452 | res <- names(res) 453 | } else { 454 | cli::cli_abort("Cannot obtain the outcome name(s)") 455 | } 456 | } else { 457 | res <- outcome_names(x) 458 | } 459 | res 460 | } 461 | 462 | update_config <- function(x, prefix = NULL, config = "new", save_pred) { 463 | if (!is.null(prefix)) { 464 | x$.metrics <- 465 | purrr::map( 466 | x$.metrics, 467 | \(x) dplyr::mutate(x, .config = paste0(prefix, "_", .config)) 468 | ) 469 | if (save_pred) { 470 | x$.predictions <- 471 | purrr::map( 472 | x$.predictions, 473 | \(x) dplyr::mutate(x, .config = paste0(prefix, "_", .config)) 474 | ) 475 | } 476 | } else { 477 | x$.metrics <- 478 | purrr::map( 479 | x$.metrics, 480 | \(x) dplyr::mutate(x, .config = config) 481 | ) 482 | if (save_pred) { 483 | x$.predictions <- 484 | purrr::map( 485 | x$.predictions, 486 | \(x) dplyr::mutate(x, .config = config) 487 | ) 488 | } 489 | } 490 | x 491 | } 492 | -------------------------------------------------------------------------------- /R/tune_race_anova.R: -------------------------------------------------------------------------------- 1 | #' Efficient grid search via racing with ANOVA models 2 | #' 3 | #' [tune_race_anova()] computes a set of performance metrics (e.g. accuracy or RMSE) 4 | #' for a pre-defined set of tuning parameters that correspond to a model or 5 | #' recipe across one or more resamples of the data. After an initial number of 6 | #' resamples have been evaluated, the process eliminates tuning parameter 7 | #' combinations that are unlikely to be the best results using a repeated 8 | #' measure ANOVA model. 9 | #' 10 | #' @param object A `parsnip` model specification or a [workflows::workflow()]. 11 | #' @param preprocessor A traditional model formula or a recipe created using 12 | #' [recipes::recipe()]. This is only required when `object` is not a workflow. 
13 | #' @param resamples An `rset()` object that has multiple resamples (i.e., is not 14 | #' a validation set). 15 | #' @param param_info A [dials::parameters()] object or `NULL`. If none is given, 16 | #' a parameters set is derived from other arguments. Passing this argument can 17 | #' be useful when parameter ranges need to be customized. 18 | #' @param grid A data frame of tuning combinations or a positive integer. The 19 | #' data frame should have columns for each parameter being tuned and rows for 20 | #' tuning parameter candidates. An integer denotes the number of candidate 21 | #' parameter sets to be created automatically. 22 | #' @param metrics A [yardstick::metric_set()] or `NULL`. 23 | #' @param eval_time A numeric vector of time points where dynamic event time 24 | #' metrics should be computed (e.g. the time-dependent ROC curve, etc). The 25 | #' values must be non-negative and should probably be no greater than the 26 | #' largest event time in the training set (See Details below). 27 | #' @param control An object used to modify the tuning process. See 28 | #' [control_race()] for more details. 29 | #' @param ... Not currently used. 30 | #' @references 31 | #' Kuhn, M 2014. "Futility Analysis in the Cross-Validation of Machine Learning 32 | #' Models." \url{https://arxiv.org/abs/1405.6974}. 33 | #' @details 34 | #' The technical details of this method are described in Kuhn (2014). 35 | #' 36 | #' Racing methods are efficient approaches to grid search. Initially, the 37 | #' function evaluates all tuning parameters on a small initial set of 38 | #' resamples. The `burn_in` argument of [control_race()] sets the number of 39 | #' initial resamples. 40 | #' 41 | #' The performance statistics from these resamples are analyzed to determine 42 | #' which tuning parameters are _not_ statistically different from the current 43 | #' best setting. If a parameter is statistically different, it is excluded from 44 | #' further resampling. 
#'
#' The next resample is used with the remaining parameter combinations and the
#' statistical analysis is updated. More candidate parameters may be excluded
#' with each new resample that is processed.
#'
#' This function determines statistical significance using a repeated measures ANOVA
#' model where the performance statistic (e.g., RMSE, accuracy, etc.) is the
#' outcome data and the random effect is due to resamples. The
#' [control_race()] function contains a parameter for the significance cutoff
#' applied to the ANOVA results as well as other relevant arguments.
#'
#' There is benefit to using racing methods in conjunction with parallel
#' processing. The following section shows a benchmark of results for one
#' dataset and model.
#'
#' ## Censored regression models
#'
#' With dynamic performance metrics (e.g. Brier or ROC curves), performance is
#' calculated for every value of `eval_time` but the _first_ evaluation time
#' given by the user (e.g., `eval_time[1]`) is analyzed during racing.
#'
#' Also, values of `eval_time` should be less than the largest observed event
#' time in the training data. For many non-parametric models, the results beyond
#' the largest time corresponding to an event are constant (or `NA`).
#'
#' @return An object with primary class `tune_race` in the same standard format
#' as objects produced by [tune::tune_grid()].
72 | #' @includeRmd man/rmd/anova-benchmark.md details 73 | #' @examples 74 | #' \donttest{ 75 | #' library(parsnip) 76 | #' library(rsample) 77 | #' library(dials) 78 | #' 79 | #' ## ----------------------------------------------------------------------------- 80 | #' 81 | #' if (rlang::is_installed(c("discrim", "lme4", "modeldata"))) { 82 | #' library(discrim) 83 | #' data(two_class_dat, package = "modeldata") 84 | #' 85 | #' set.seed(6376) 86 | #' rs <- bootstraps(two_class_dat, times = 10) 87 | #' 88 | #' ## ----------------------------------------------------------------------------- 89 | #' 90 | #' # optimize an regularized discriminant analysis model 91 | #' rda_spec <- 92 | #' discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 93 | #' set_engine("klaR") 94 | #' 95 | #' ## ----------------------------------------------------------------------------- 96 | #' 97 | #' ctrl <- control_race(verbose_elim = TRUE) 98 | #' set.seed(11) 99 | #' grid_anova <- 100 | #' rda_spec |> 101 | #' tune_race_anova(Class ~ ., resamples = rs, grid = 10, control = ctrl) 102 | #' 103 | #' # Shows only the fully resampled parameters 104 | #' show_best(grid_anova, metric = "roc_auc", n = 2) 105 | #' 106 | #' plot_race(grid_anova) 107 | #' } 108 | #' } 109 | #' @seealso [tune::tune_grid()], [control_race()], [tune_race_win_loss()] 110 | #' @export 111 | tune_race_anova <- function(object, ...) { 112 | UseMethod("tune_race_anova") 113 | } 114 | 115 | #' @export 116 | tune_race_anova.default <- function(object, ...) { 117 | msg <- paste0( 118 | "The first argument to [tune_race_anova()] should be either ", 119 | "a model or workflow." 120 | ) 121 | cli::cli_abort(msg) 122 | } 123 | 124 | #' @export 125 | tune_race_anova.recipe <- 126 | function( 127 | object, 128 | model, 129 | resamples, 130 | ..., 131 | param_info = NULL, 132 | grid = 10, 133 | metrics = NULL, 134 | eval_time = NULL, 135 | control = control_race() 136 | ) { 137 | tune::empty_ellipses(...) 
138 | 139 | control <- parsnip::condense_control(control, control_race()) 140 | 141 | tune_race_anova( 142 | model, 143 | preprocessor = object, 144 | resamples = resamples, 145 | param_info = param_info, 146 | grid = grid, 147 | metrics = metrics, 148 | eval_time = eval_time, 149 | control = control 150 | ) 151 | } 152 | 153 | #' @export 154 | tune_race_anova.formula <- 155 | function( 156 | formula, 157 | model, 158 | resamples, 159 | ..., 160 | param_info = NULL, 161 | grid = 10, 162 | metrics = NULL, 163 | eval_time = NULL, 164 | control = control_race() 165 | ) { 166 | tune::empty_ellipses(...) 167 | 168 | control <- parsnip::condense_control(control, control_race()) 169 | 170 | tune_race_anova( 171 | model, 172 | preprocessor = formula, 173 | resamples = resamples, 174 | param_info = param_info, 175 | grid = grid, 176 | metrics = metrics, 177 | eval_time = eval_time, 178 | control = control 179 | ) 180 | } 181 | 182 | #' @export 183 | #' @rdname tune_race_anova 184 | tune_race_anova.model_spec <- 185 | function( 186 | object, 187 | preprocessor, 188 | resamples, 189 | ..., 190 | param_info = NULL, 191 | grid = 10, 192 | metrics = NULL, 193 | eval_time = NULL, 194 | control = control_race() 195 | ) { 196 | if ( 197 | rlang::is_missing(preprocessor) || !tune::is_preprocessor(preprocessor) 198 | ) { 199 | cli::cli_abort( 200 | "To tune a model spec, you must preprocess with a formula, recipe, \\ 201 | or variable specification." 202 | ) 203 | } 204 | 205 | tune::empty_ellipses(...) 
206 | 207 | control <- parsnip::condense_control(control, control_race()) 208 | 209 | wflow <- workflows::add_model(workflows::workflow(), object) 210 | 211 | if (tune::is_recipe(preprocessor)) { 212 | wflow <- workflows::add_recipe(wflow, preprocessor) 213 | } else if (rlang::is_formula(preprocessor)) { 214 | wflow <- workflows::add_formula(wflow, preprocessor) 215 | } 216 | 217 | tune_race_anova_workflow( 218 | wflow, 219 | resamples = resamples, 220 | grid = grid, 221 | metrics = metrics, 222 | eval_time = eval_time, 223 | param_info = param_info, 224 | control = control 225 | ) 226 | } 227 | 228 | #' @export 229 | #' @rdname tune_race_anova 230 | tune_race_anova.workflow <- 231 | function( 232 | object, 233 | resamples, 234 | ..., 235 | param_info = NULL, 236 | grid = 10, 237 | metrics = NULL, 238 | eval_time = NULL, 239 | control = control_race() 240 | ) { 241 | tune::empty_ellipses(...) 242 | 243 | control <- parsnip::condense_control(control, control_race()) 244 | 245 | tune_race_anova_workflow( 246 | object, 247 | resamples = resamples, 248 | grid = grid, 249 | metrics = metrics, 250 | eval_time = eval_time, 251 | param_info = param_info, 252 | control = control 253 | ) 254 | } 255 | 256 | ## ----------------------------------------------------------------------------- 257 | 258 | tune_race_anova_workflow <- 259 | function( 260 | object, 261 | resamples, 262 | param_info = NULL, 263 | grid = 10, 264 | metrics = NULL, 265 | eval_time = NULL, 266 | control = control_race(), 267 | call = caller_env() 268 | ) { 269 | rlang::check_installed("lme4") 270 | 271 | tune::initialize_catalog(control = control) 272 | 273 | B <- nrow(resamples) 274 | if (control$randomize) { 275 | resamples <- randomize_resamples(resamples) 276 | } 277 | resamples <- dplyr::mutate(resamples, .order = dplyr::row_number()) 278 | 279 | min_rs <- control$burn_in 280 | check_num_resamples(B, min_rs) 281 | tmp_resamples <- restore_rset(resamples, 1:min_rs) 282 | 283 | metrics <- 
tune::check_metrics_arg(metrics, object, call = call) 284 | eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call) 285 | 286 | control$pkgs <- c( 287 | control$pkgs, 288 | tune::required_pkgs(object), 289 | "workflows", 290 | "tidyr", 291 | "rlang" 292 | ) 293 | 294 | if (control$verbose_elim) { 295 | tune_cols <- tune::get_tune_colors() 296 | msg <- tune_cols$message$info( 297 | paste0( 298 | cli::symbol$info, 299 | " Evaluating against the initial {min_rs} burn-in resamples." 300 | ) 301 | ) 302 | 303 | cli::cli_inform(msg) 304 | } 305 | 306 | grid_control <- parsnip::condense_control(control, tune::control_grid()) 307 | res <- 308 | object |> 309 | tune::tune_grid( 310 | tmp_resamples, 311 | param_info = param_info, 312 | grid = grid, 313 | metrics = metrics, 314 | eval_time = eval_time, 315 | control = grid_control 316 | ) 317 | 318 | param_names <- tune::.get_tune_parameter_names(res) 319 | 320 | opt_metric <- tune::first_metric(metrics) 321 | opt_metric_name <- opt_metric$metric 322 | maximize <- opt_metric$direction == "maximize" 323 | 324 | opt_metric_time <- tune::first_eval_time( 325 | metrics, 326 | metric = opt_metric_name, 327 | eval_time = eval_time, 328 | call = call 329 | ) 330 | 331 | racing_obj_log( 332 | opt_metric_name, 333 | opt_metric$direction, 334 | control, 335 | opt_metric_time 336 | ) 337 | 338 | filters_results <- test_parameters_gls(res, control$alpha, opt_metric_time) 339 | n_grid <- nrow(filters_results) 340 | 341 | log_final <- TRUE 342 | num_ties <- 0 343 | for (rs in (min_rs + 1):B) { 344 | if (sum(filters_results$pass) == 2) { 345 | num_ties <- num_ties + 1 346 | } 347 | new_grid <- 348 | filters_results |> 349 | dplyr::filter(pass) |> 350 | dplyr::select(!!!param_names) 351 | 352 | if (nrow(new_grid) > 1) { 353 | tmp_resamples <- restore_rset(resamples, rs) 354 | log_racing( 355 | control, 356 | filters_results, 357 | res$splits, 358 | n_grid, 359 | opt_metric_name 360 | ) 361 | } else { 362 | tmp_resamples <- 
restore_rset(resamples, rs:B) 363 | if (log_final) { 364 | log_racing( 365 | control, 366 | filters_results, 367 | res$splits, 368 | n_grid, 369 | opt_metric_name 370 | ) 371 | } 372 | log_final <- FALSE 373 | } 374 | 375 | grid_control <- parsnip::condense_control(control, tune::control_grid()) 376 | tmp_res <- 377 | object |> 378 | tune::tune_grid( 379 | tmp_resamples, 380 | param_info = param_info, 381 | grid = new_grid, 382 | metrics = metrics, 383 | eval_time = eval_time, 384 | control = grid_control 385 | ) 386 | 387 | res <- restore_tune(res, tmp_res, opt_metric_time) 388 | 389 | if (nrow(new_grid) > 1) { 390 | filters_results <- test_parameters_gls( 391 | res, 392 | control$alpha, 393 | opt_metric_time 394 | ) 395 | if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) { 396 | filters_results <- tie_breaker( 397 | res, 398 | control, 399 | eval_time = opt_metric_time 400 | ) 401 | } 402 | } else { 403 | # Depending on the value of control$parallel_over we don't need to do 404 | # the remaining loop to get the rs counter to B 405 | max_B <- max(tune::collect_metrics(res)$n) 406 | if (max_B == B) { 407 | break() 408 | } 409 | } 410 | } 411 | 412 | .stash_last_result(res) 413 | 414 | res 415 | } 416 | 417 | # fmt: skip 418 | check_num_resamples <- function(B, min_rs) { 419 | if (B <= min_rs) { 420 | cli::cli_abort( 421 | paste0("The number of resamples (", B, ") needs to be more than the ", 422 | "number of burn-in resamples (", min_rs, ") set by the control ", 423 | "function `control_race()`."), 424 | call = NULL 425 | ) 426 | } 427 | invisible(NULL) 428 | } 429 | -------------------------------------------------------------------------------- /R/tune_race_win_loss.R: -------------------------------------------------------------------------------- 1 | #' Efficient grid search via racing with win/loss statistics 2 | #' 3 | #' [tune_race_win_loss()] computes a set of performance metrics (e.g. 
accuracy or RMSE) 4 | #' for a pre-defined set of tuning parameters that correspond to a model or 5 | #' recipe across one or more resamples of the data. After an initial number of 6 | #' resamples have been evaluated, the process eliminates tuning parameter 7 | #' combinations that are unlikely to be the best results using a statistical 8 | #' model. For each pairwise combinations of tuning parameters, win/loss 9 | #' statistics are calculated and a logistic regression model is used to measure 10 | #' how likely each combination is to win overall. 11 | #' 12 | #' @inheritParams tune_race_anova 13 | #' @references 14 | #' Kuhn, M 2014. "Futility Analysis in the Cross-Validation of Machine Learning 15 | #' Models." \url{https://arxiv.org/abs/1405.6974}. 16 | #' @param ... Not currently used. 17 | #' @details 18 | #' The technical details of this method are described in Kuhn (2014). 19 | #' 20 | #' Racing methods are efficient approaches to grid search. Initially, the 21 | #' function evaluates all tuning parameters on a small initial set of 22 | #' resamples. The `burn_in` argument of [control_race()] sets the number of 23 | #' initial resamples. 24 | #' 25 | #' The performance statistics from the current set of resamples are converted 26 | #' to win/loss/tie results. For example, for two parameters (`j` and `k`) in a 27 | #' classification model that have each been resampled three times: 28 | #' 29 | #' \preformatted{ 30 | #' | area under the ROC curve | 31 | #' ----------------------------- 32 | #' resample | parameter j | parameter k | winner 33 | #' --------------------------------------------- 34 | #' 1 | 0.81 | 0.92 | k 35 | #' 2 | 0.95 | 0.94 | j 36 | #' 3 | 0.79 | 0.81 | k 37 | #' --------------------------------------------- 38 | #' } 39 | #' 40 | #' After the third resample, parameter `k` has a 2:1 win/loss ratio versus `j`. 41 | #' Parameters with equal results are treated as a half-win for each setting. 
42 | #' These statistics are determined for all pairwise combinations of the
43 | #' parameters and a Bradley-Terry model is used to model these win/loss/tie
44 | #' statistics. This model can compute the ability of a parameter combination to
45 | #' win overall. A confidence interval for the winning ability is computed and
46 | #' any settings whose interval includes zero are retained for future resamples
47 | #' (since it is not statistically different from the best results).
48 | #'
49 | #'
50 | #' The next resample is used with the remaining parameter combinations and the
51 | #' statistical analysis is updated. More candidate parameters may be excluded
52 | #' with each new resample that is processed.
53 | #'
54 | #' The [control_race()] function contains a parameter for the significance cutoff
55 | #' applied to the Bradley-Terry model results as well as other relevant arguments.
56 | #'
57 | #' ## Censored regression models
58 | #'
59 | #' With dynamic performance metrics (e.g. Brier or ROC curves), performance is
60 | #' calculated for every value of `eval_time` but the _first_ evaluation time
61 | #' given by the user (e.g., `eval_time[1]`) is analyzed during racing.
62 | #'
63 | #' Also, values of `eval_time` should be less than the largest observed event
64 | #' time in the training data. For many non-parametric models, the results beyond
65 | #' the largest time corresponding to an event are constant (or `NA`).
66 | #' 67 | #' @examples 68 | #' \donttest{ 69 | #' library(parsnip) 70 | #' library(rsample) 71 | #' library(dials) 72 | #' 73 | #' ## ----------------------------------------------------------------------------- 74 | #' 75 | #' if (rlang::is_installed(c("discrim", "modeldata"))) { 76 | #' library(discrim) 77 | #' data(two_class_dat, package = "modeldata") 78 | #' 79 | #' set.seed(6376) 80 | #' rs <- bootstraps(two_class_dat, times = 10) 81 | #' 82 | #' ## ----------------------------------------------------------------------------- 83 | #' 84 | #' # optimize an regularized discriminant analysis model 85 | #' rda_spec <- 86 | #' discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 87 | #' set_engine("klaR") 88 | #' 89 | #' ## ----------------------------------------------------------------------------- 90 | #' 91 | #' ctrl <- control_race(verbose_elim = TRUE) 92 | #' 93 | #' set.seed(11) 94 | #' grid_wl <- 95 | #' rda_spec |> 96 | #' tune_race_win_loss(Class ~ ., resamples = rs, grid = 10, control = ctrl) 97 | #' 98 | #' # Shows only the fully resampled parameters 99 | #' show_best(grid_wl, metric = "roc_auc") 100 | #' 101 | #' plot_race(grid_wl) 102 | #' } 103 | #' } 104 | #' @return An object with primary class `tune_race` in the same standard format 105 | #' as objects produced by [tune::tune_grid()]. 106 | #' @seealso [tune::tune_grid()], [control_race()], [tune_race_anova()] 107 | #' @export 108 | tune_race_win_loss <- function(object, ...) { 109 | UseMethod("tune_race_win_loss") 110 | } 111 | 112 | #' @export 113 | tune_race_win_loss.default <- function(object, ...) { 114 | msg <- paste0( 115 | "The first argument to {.fn tune_race_win_loss} should be either ", 116 | "a model or workflow." 
117 | ) 118 | cli::cli_abort(msg) 119 | } 120 | 121 | #' @export 122 | tune_race_win_loss.recipe <- 123 | function( 124 | object, 125 | model, 126 | resamples, 127 | ..., 128 | param_info = NULL, 129 | grid = 10, 130 | metrics = NULL, 131 | eval_time = NULL, 132 | control = control_race() 133 | ) { 134 | tune::empty_ellipses(...) 135 | 136 | control <- parsnip::condense_control(control, control_race()) 137 | 138 | tune_race_win_loss( 139 | model, 140 | preprocessor = object, 141 | resamples = resamples, 142 | param_info = param_info, 143 | grid = grid, 144 | metrics = metrics, 145 | control = control, 146 | eval_time = eval_time 147 | ) 148 | } 149 | 150 | #' @export 151 | tune_race_win_loss.formula <- 152 | function( 153 | formula, 154 | model, 155 | resamples, 156 | ..., 157 | param_info = NULL, 158 | grid = 10, 159 | metrics = NULL, 160 | eval_time = NULL, 161 | control = control_race() 162 | ) { 163 | tune::empty_ellipses(...) 164 | 165 | control <- parsnip::condense_control(control, control_race()) 166 | 167 | tune_race_win_loss( 168 | model, 169 | preprocessor = formula, 170 | resamples = resamples, 171 | param_info = param_info, 172 | grid = grid, 173 | metrics = metrics, 174 | eval_time = eval_time, 175 | control = control 176 | ) 177 | } 178 | 179 | #' @export 180 | #' @rdname tune_race_win_loss 181 | tune_race_win_loss.model_spec <- 182 | function( 183 | object, 184 | preprocessor, 185 | resamples, 186 | ..., 187 | param_info = NULL, 188 | grid = 10, 189 | metrics = NULL, 190 | eval_time = NULL, 191 | control = control_race() 192 | ) { 193 | if ( 194 | rlang::is_missing(preprocessor) || !tune::is_preprocessor(preprocessor) 195 | ) { 196 | cli::cli_abort( 197 | "To tune a model spec, you must preprocess with a formula, recipe, \\ 198 | or variable specification." 199 | ) 200 | } 201 | 202 | tune::empty_ellipses(...) 
203 | 204 | control <- parsnip::condense_control(control, control_race()) 205 | 206 | wflow <- workflows::add_model(workflows::workflow(), object) 207 | 208 | if (tune::is_recipe(preprocessor)) { 209 | wflow <- workflows::add_recipe(wflow, preprocessor) 210 | } else if (rlang::is_formula(preprocessor)) { 211 | wflow <- workflows::add_formula(wflow, preprocessor) 212 | } 213 | 214 | tune_race_win_loss_workflow( 215 | wflow, 216 | resamples = resamples, 217 | grid = grid, 218 | metrics = metrics, 219 | eval_time = eval_time, 220 | param_info = param_info, 221 | control = control 222 | ) 223 | } 224 | 225 | #' @export 226 | #' @rdname tune_race_win_loss 227 | tune_race_win_loss.workflow <- function( 228 | object, 229 | resamples, 230 | ..., 231 | param_info = NULL, 232 | grid = 10, 233 | metrics = NULL, 234 | eval_time = NULL, 235 | control = control_race() 236 | ) { 237 | tune::empty_ellipses(...) 238 | 239 | control <- parsnip::condense_control(control, control_race()) 240 | 241 | tune_race_win_loss_workflow( 242 | object, 243 | resamples = resamples, 244 | grid = grid, 245 | metrics = metrics, 246 | eval_time = eval_time, 247 | param_info = param_info, 248 | control = control 249 | ) 250 | } 251 | 252 | ## ----------------------------------------------------------------------------- 253 | 254 | tune_race_win_loss_workflow <- 255 | function( 256 | object, 257 | resamples, 258 | param_info = NULL, 259 | grid = 10, 260 | metrics = NULL, 261 | eval_time = NULL, 262 | control = control_race(), 263 | call = caller_env() 264 | ) { 265 | rlang::check_installed("BradleyTerry2") 266 | 267 | B <- nrow(resamples) 268 | if (control$randomize) { 269 | resamples <- randomize_resamples(resamples) 270 | } 271 | resamples <- dplyr::mutate(resamples, .order = dplyr::row_number()) 272 | 273 | min_rs <- control$burn_in 274 | check_num_resamples(B, min_rs) 275 | tmp_resamples <- restore_rset(resamples, 1:min_rs) 276 | 277 | metrics <- tune::check_metrics_arg(metrics, object, call = 
call) 278 | eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call) 279 | 280 | grid_control <- parsnip::condense_control(control, tune::control_grid()) 281 | res <- 282 | object |> 283 | tune::tune_grid( 284 | resamples = tmp_resamples, 285 | param_info = param_info, 286 | grid = grid, 287 | metrics = metrics, 288 | eval_time = eval_time, 289 | control = grid_control 290 | ) 291 | 292 | param_names <- tune::.get_tune_parameter_names(res) 293 | 294 | opt_metric <- tune::first_metric(metrics) 295 | opt_metric_name <- opt_metric$metric 296 | maximize <- opt_metric$direction == "maximize" 297 | 298 | opt_metric_time <- tune::first_eval_time( 299 | metrics, 300 | metric = opt_metric_name, 301 | eval_time = eval_time, 302 | call = call 303 | ) 304 | 305 | racing_obj_log( 306 | opt_metric_name, 307 | opt_metric$direction, 308 | control, 309 | opt_metric_time 310 | ) 311 | 312 | filters_results <- test_parameters_bt(res, control$alpha, opt_metric_time) 313 | n_grid <- nrow(filters_results) 314 | 315 | log_final <- TRUE 316 | num_ties <- 0 317 | for (rs in (min_rs + 1):B) { 318 | if (sum(filters_results$pass) == 2) { 319 | num_ties <- num_ties + 1 320 | } 321 | new_grid <- 322 | filters_results |> 323 | dplyr::filter(pass) |> 324 | dplyr::select(!!!param_names) 325 | 326 | if (nrow(new_grid) > 1) { 327 | tmp_resamples <- restore_rset(resamples, rs) 328 | log_racing( 329 | control, 330 | filters_results, 331 | res$splits, 332 | n_grid, 333 | opt_metric_name 334 | ) 335 | } else { 336 | tmp_resamples <- restore_rset(resamples, rs:B) 337 | if (log_final) { 338 | log_racing( 339 | control, 340 | filters_results, 341 | res$splits, 342 | n_grid, 343 | opt_metric_name 344 | ) 345 | } 346 | log_final <- FALSE 347 | } 348 | 349 | grid_control <- parsnip::condense_control(control, tune::control_grid()) 350 | tmp_res <- 351 | object |> 352 | tune::tune_grid( 353 | resamples = tmp_resamples, 354 | param_info = param_info, 355 | grid = new_grid, 356 | metrics = metrics, 
357 | eval_time = eval_time, 358 | control = grid_control 359 | ) 360 | res <- restore_tune(res, tmp_res, opt_metric_time) 361 | 362 | if (nrow(new_grid) > 1) { 363 | filters_results <- test_parameters_bt( 364 | res, 365 | control$alpha, 366 | opt_metric_time 367 | ) 368 | if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) { 369 | filters_results <- tie_breaker( 370 | res, 371 | control, 372 | eval_time = opt_metric_time 373 | ) 374 | } 375 | } else { 376 | # Depending on the value of control$parallel_over we don't need to do 377 | # the remaining loop to get the rs counter to B 378 | max_B <- max(tune::collect_metrics(res)$n) 379 | if (max_B == B) { 380 | break() 381 | } 382 | } 383 | } 384 | 385 | .stash_last_result(res) 386 | res 387 | } 388 | -------------------------------------------------------------------------------- /R/zzz.R: -------------------------------------------------------------------------------- 1 | .onLoad <- function(libname, pkgname) { 2 | vctrs::s3_register("tune::show_best", "tune_race") 3 | vctrs::s3_register("tune::collect_metrics", "tune_race") 4 | vctrs::s3_register("tune::collect_predictions", "tune_race") 5 | } 6 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # finetune 17 | 18 | 19 | [![R-CMD-check](https://github.com/tidymodels/finetune/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/finetune/actions/workflows/R-CMD-check.yaml) 20 | [![Codecov test coverage](https://codecov.io/gh/tidymodels/finetune/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/finetune?branch=main) 21 | 
[![Lifecycle](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) 22 | [![Codecov test coverage](https://codecov.io/gh/tidymodels/finetune/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/finetune) 23 | 24 | 25 | `finetune` contains some extra functions for model tuning that extend what is currently in the `tune` package. You can install the CRAN version of the package with the following code: 26 | 27 | ```{r, eval = FALSE} 28 | install.packages("finetune") 29 | ``` 30 | 31 | To install the development version of the package, run: 32 | 33 | ```{r, eval = FALSE} 34 | # install.packages("pak") 35 | pak::pak("tidymodels/finetune") 36 | ``` 37 | 38 | There are two main sets of tools in the package: _simulated annealing_ and _racing_. 39 | 40 | Tuning via _simulated annealing_ optimization is an iterative search tool for finding good values: 41 | 42 | ```{r load, include=FALSE} 43 | library(tidymodels) 44 | library(finetune) 45 | library(discrim) 46 | library(rlang) 47 | library(MASS) 48 | ``` 49 | ```{r sa} 50 | library(tidymodels) 51 | library(finetune) 52 | 53 | # Syntax very similar to `tune_grid()` or `tune_bayes()`: 54 | 55 | ## ----------------------------------------------------------------------------- 56 | 57 | data(two_class_dat, package = "modeldata") 58 | 59 | set.seed(1) 60 | rs <- bootstraps(two_class_dat, times = 10) # more resamples usually needed 61 | 62 | # Optimize a regularized discriminant analysis model 63 | library(discrim) 64 | rda_spec <- 65 | discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 66 | set_engine("klaR") 67 | 68 | ## ----------------------------------------------------------------------------- 69 | 70 | set.seed(2) 71 | sa_res <- 72 | rda_spec |> 73 | tune_sim_anneal(Class ~ ., resamples = rs, iter = 20, initial = 4) 74 | show_best(sa_res, metric = "roc_auc", n = 2) 75 | ``` 76 | 77 | The second set of methods are for _racing_. 
We start off by doing a small set of resamples for all of the grid points, then statistically testing to see which ones should be dropped or investigated more. The two methods here are based on those shown in [Kuhn (2014)](https://arxiv.org/abs/1405.6974).
78 |
79 | For example, using an ANOVA-type analysis to filter out parameter combinations:
80 |
81 | ```{r race}
82 | set.seed(3)
83 | grid <-
84 | rda_spec |>
85 | extract_parameter_set_dials() |>
86 | grid_max_entropy(size = 20)
87 |
88 | ctrl <- control_race(verbose_elim = TRUE)
89 |
90 | set.seed(4)
91 | grid_anova <-
92 | rda_spec |>
93 | tune_race_anova(Class ~ ., resamples = rs, grid = grid, control = ctrl)
94 |
95 | show_best(grid_anova, metric = "roc_auc", n = 2)
96 | ```
97 |
98 | `tune_race_win_loss()` can also be used. It treats the tuning parameters as sports teams in a tournament and computes win/loss statistics.
99 |
100 |
101 | ```{r race-wl}
102 | set.seed(4)
103 | grid_win_loss<-
104 | rda_spec |>
105 | tune_race_win_loss(Class ~ ., resamples = rs, grid = grid, control = ctrl)
106 |
107 | show_best(grid_win_loss, metric = "roc_auc", n = 2)
108 | ```
109 |
110 |
111 | ## Contributing
112 |
113 | This project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms.
114 |
115 | - For questions and discussions about tidymodels packages, modeling, and machine learning, please [post on Posit Community](https://forum.posit.co/new-topic?category_id=15&tags=tidymodels,question).
116 |
117 | - If you think you have encountered a bug, please [submit an issue](https://github.com/tidymodels/finetune/issues).
118 |
119 | - Either way, learn how to create and share a [reprex](https://reprex.tidyverse.org/articles/articles/learn-reprex.html) (a minimal, reproducible example), to clearly communicate about your code.
120 | 121 | - Check out further details on [contributing guidelines for tidymodels packages](https://www.tidymodels.org/contribute/) and [how to get help](https://www.tidymodels.org/help/). 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # finetune 5 | 6 | 7 | 8 | [![R-CMD-check](https://github.com/tidymodels/finetune/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/finetune/actions/workflows/R-CMD-check.yaml) 9 | [![Codecov test 10 | coverage](https://codecov.io/gh/tidymodels/finetune/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/finetune?branch=main) 11 | [![Lifecycle](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) 12 | [![Codecov test 13 | coverage](https://codecov.io/gh/tidymodels/finetune/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/finetune) 14 | 15 | 16 | `finetune` contains some extra functions for model tuning that extend 17 | what is currently in the `tune` package. You can install the CRAN 18 | version of the package with the following code: 19 | 20 | ``` r 21 | install.packages("finetune") 22 | ``` 23 | 24 | To install the development version of the package, run: 25 | 26 | ``` r 27 | # install.packages("pak") 28 | pak::pak("tidymodels/finetune") 29 | ``` 30 | 31 | There are two main sets of tools in the package: *simulated annealing* 32 | and *racing*. 
33 | 34 | Tuning via *simulated annealing* optimization is an iterative search 35 | tool for finding good values: 36 | 37 | ``` r 38 | library(tidymodels) 39 | library(finetune) 40 | 41 | # Syntax very similar to `tune_grid()` or `tune_bayes()`: 42 | 43 | ## ----------------------------------------------------------------------------- 44 | 45 | data(two_class_dat, package = "modeldata") 46 | 47 | set.seed(1) 48 | rs <- bootstraps(two_class_dat, times = 10) # more resamples usually needed 49 | 50 | # Optimize a regularized discriminant analysis model 51 | library(discrim) 52 | rda_spec <- 53 | discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 54 | set_engine("klaR") 55 | 56 | ## ----------------------------------------------------------------------------- 57 | 58 | set.seed(2) 59 | sa_res <- 60 | rda_spec |> 61 | tune_sim_anneal(Class ~ ., resamples = rs, iter = 20, initial = 4) 62 | #> Optimizing roc_auc 63 | #> Initial best: 0.88281 64 | #> 1 ◯ accept suboptimal roc_auc=0.87797 (+/-0.004187) 65 | #> 2 + better suboptimal roc_auc=0.87811 (+/-0.004082) 66 | #> 3 ◯ accept suboptimal roc_auc=0.86938 (+/-0.005172) 67 | #> 4 ◯ accept suboptimal roc_auc=0.85949 (+/-0.006067) 68 | #> 5 ◯ accept suboptimal roc_auc=0.84727 (+/-0.006757) 69 | #> 6 ◯ accept suboptimal roc_auc=0.84441 (+/-0.006885) 70 | #> 7 + better suboptimal roc_auc=0.84715 (+/-0.006718) 71 | #> 8 ✖ restart from best roc_auc=0.85368 (+/-0.006366) 72 | #> 9 ◯ accept suboptimal roc_auc=0.88032 (+/-0.00397) 73 | #> 10 ◯ accept suboptimal roc_auc=0.87373 (+/-0.004807) 74 | #> 11 + better suboptimal roc_auc=0.87691 (+/-0.004533) 75 | #> 12 ◯ accept suboptimal roc_auc=0.86149 (+/-0.005802) 76 | #> 13 + better suboptimal roc_auc=0.86304 (+/-0.005684) 77 | #> 14 + better suboptimal roc_auc=0.87479 (+/-0.004721) 78 | #> 15 ◯ accept suboptimal roc_auc=0.86637 (+/-0.005425) 79 | #> 16 ✖ restart from best roc_auc=0.85841 (+/-0.006111) 80 | #> 17 ◯ accept suboptimal roc_auc=0.87862 (+/-0.004139) 
81 | #> 18 + better suboptimal roc_auc=0.88011 (+/-0.004023) 82 | #> 19 ◯ accept suboptimal roc_auc=0.87175 (+/-0.004952) 83 | #> 20 ─ discard suboptimal roc_auc=0.86236 (+/-0.005762) 84 | show_best(sa_res, metric = "roc_auc", n = 2) 85 | #> # A tibble: 2 × 9 86 | #> frac_common_cov frac_identity .metric .estimator mean n std_err .config 87 | #> 88 | #> 1 0.667 0 roc_auc binary 0.883 10 0.00360 initial_… 89 | #> 2 0.793 0.0344 roc_auc binary 0.880 10 0.00397 Iter9 90 | #> # ℹ 1 more variable: .iter 91 | ``` 92 | 93 | The second set of methods are for *racing*. We start off by doing a 94 | small set of resamples for all of the grid points, then statistically 95 | testing to see which ones should be dropped or investigated more. The 96 | two methods here are based on those should in [Kuhn 97 | (2014)](https://arxiv.org/abs/1405.6974). 98 | 99 | For example, using an ANOVA-type analysis to filter out parameter 100 | combinations: 101 | 102 | ``` r 103 | set.seed(3) 104 | grid <- 105 | rda_spec |> 106 | extract_parameter_set_dials() |> 107 | grid_max_entropy(size = 20) 108 | #> Warning: `grid_max_entropy()` was deprecated in dials 1.3.0. 109 | #> ℹ Please use `grid_space_filling()` instead. 110 | #> This warning is displayed once every 8 hours. 111 | #> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was 112 | #> generated. 113 | 114 | ctrl <- control_race(verbose_elim = TRUE) 115 | 116 | set.seed(4) 117 | grid_anova <- 118 | rda_spec |> 119 | tune_race_anova(Class ~ ., resamples = rs, grid = grid, control = ctrl) 120 | #> ℹ Evaluating against the initial 3 burn-in resamples. 121 | #> ℹ Racing will maximize the roc_auc metric. 122 | #> ℹ Resamples are analyzed in a random order. 123 | #> ℹ Bootstrap10: 14 eliminated; 6 candidates remain. 124 | #> 125 | #> ℹ Bootstrap04: 2 eliminated; 4 candidates remain. 126 | #> 127 | #> ℹ Bootstrap03: All but one parameter combination were eliminated. 
128 | 129 | show_best(grid_anova, metric = "roc_auc", n = 2) 130 | #> # A tibble: 1 × 8 131 | #> frac_common_cov frac_identity .metric .estimator mean n std_err .config 132 | #> 133 | #> 1 0.831 0.0207 roc_auc binary 0.881 10 0.00386 Preproce… 134 | ``` 135 | 136 | `tune_race_win_loss()` can also be used. It treats the tuning parameters 137 | as sports teams in a tournament and computed win/loss statistics. 138 | 139 | ``` r 140 | set.seed(4) 141 | grid_win_loss<- 142 | rda_spec |> 143 | tune_race_win_loss(Class ~ ., resamples = rs, grid = grid, control = ctrl) 144 | #> ℹ Racing will maximize the roc_auc metric. 145 | #> ℹ Resamples are analyzed in a random order. 146 | #> ℹ Bootstrap10: 3 eliminated; 17 candidates remain. 147 | #> 148 | #> ℹ Bootstrap04: 2 eliminated; 15 candidates remain. 149 | #> 150 | #> ℹ Bootstrap03: 2 eliminated; 13 candidates remain. 151 | #> 152 | #> ℹ Bootstrap01: 1 eliminated; 12 candidates remain. 153 | #> 154 | #> ℹ Bootstrap07: 1 eliminated; 11 candidates remain. 155 | #> 156 | #> ℹ Bootstrap05: 1 eliminated; 10 candidates remain. 157 | #> 158 | #> ℹ Bootstrap08: 1 eliminated; 9 candidates remain. 159 | 160 | show_best(grid_win_loss, metric = "roc_auc", n = 2) 161 | #> # A tibble: 2 × 8 162 | #> frac_common_cov frac_identity .metric .estimator mean n std_err .config 163 | #> 164 | #> 1 0.831 0.0207 roc_auc binary 0.881 10 0.00386 Preproce… 165 | #> 2 0.119 0.0470 roc_auc binary 0.879 10 0.00387 Preproce… 166 | ``` 167 | 168 | ## Contributing 169 | 170 | This project is released with a [Contributor Code of 171 | Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). 172 | By contributing to this project, you agree to abide by its terms. 173 | 174 | - For questions and discussions about tidymodels packages, modeling, and 175 | machine learning, please [post on Posit 176 | Community](https://forum.posit.co/new-topic?category_id=15&tags=tidymodels,question). 
177 | 178 | - If you think you have encountered a bug, please [submit an 179 | issue](https://github.com/tidymodels/usemodels/issues). 180 | 181 | - Either way, learn how to create and share a 182 | [reprex](https://reprex.tidyverse.org/articles/articles/learn-reprex.html) 183 | (a minimal, reproducible example), to clearly communicate about your 184 | code. 185 | 186 | - Check out further details on [contributing guidelines for tidymodels 187 | packages](https://www.tidymodels.org/contribute/) and [how to get 188 | help](https://www.tidymodels.org/help/). 189 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://finetune.tidymodels.org 2 | 3 | template: 4 | package: tidytemplate 5 | bootstrap: 5 6 | bslib: 7 | danger: "#CA225E" 8 | primary: "#CA225E" 9 | includes: 10 | in_header: | 11 | 12 | 13 | development: 14 | mode: auto 15 | 16 | figures: 17 | fig.width: 8 18 | fig.height: 5.75 19 | 20 | -------------------------------------------------------------------------------- /air.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/finetune/13b1ba2d4b97e63b3a4e08d1e607e8c6917df22a/air.toml -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## R CMD check results 2 | 3 | 0 errors | 0 warnings | 0 notes 4 | 5 | 
-------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | finetune.tidymodels.org -------------------------------------------------------------------------------- /finetune.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | LineEndingConversion: Posix 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /inst/WORDLIST: -------------------------------------------------------------------------------- 1 | Benchmarking 2 | Bohachevsky 3 | CMD 4 | Codecov 5 | Lifecycle 6 | ORCID 7 | PBC 8 | Technometrics 9 | doi 10 | dplyr 11 | funder 12 | ggplot 13 | magrittr 14 | pre 15 | preprocessor 16 | reprex 17 | suboptimal 18 | tibble 19 | tidymodels 20 | unsummarized 21 | wc 22 | -------------------------------------------------------------------------------- /inst/data-raw/sa_cart_test_objects.R: -------------------------------------------------------------------------------- 1 | library(finetune) 2 | library(rpart) 3 | library(dplyr) 4 | library(tune) 5 | library(rsample) 6 | library(parsnip) 7 | library(workflows) 8 | library(ggplot2) 9 | 10 | ## ----------------------------------------------------------------------------- 11 | 12 | data(two_class_dat, package = "modeldata") 13 | 14 | set.seed(5046) 15 | bt <- bootstraps(two_class_dat, times = 5) 16 | 17 | ## 
----------------------------------------------------------------------------- 18 | 19 | cart_mod <- 20 | decision_tree(cost_complexity = tune(), min_n = tune()) |> 21 | set_engine("rpart") |> 22 | set_mode("classification") 23 | 24 | ## ----------------------------------------------------------------------------- 25 | 26 | ctrl <- control_sim_anneal(save_history = TRUE) 27 | 28 | set.seed(2981) 29 | # For reproducibility, set the seed before running. 30 | cart_search <- 31 | cart_mod |> 32 | tune_sim_anneal(Class ~ ., resamples = bt, iter = 12, control = ctrl) 33 | 34 | load(file.path(tempdir(), "sa_history.RData")) 35 | cart_history <- result_history 36 | 37 | save( 38 | cart_history, 39 | cart_search, 40 | file = file.path(testthat::test_path(), "sa_cart_test_objects.RData"), 41 | version = 2, 42 | compress = "xz" 43 | ) 44 | 45 | if (!interactive()) { 46 | q("no") 47 | } 48 | -------------------------------------------------------------------------------- /man/collect_predictions.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/racing_helpers.R 3 | \name{collect_predictions} 4 | \alias{collect_predictions} 5 | \alias{collect_predictions.tune_race} 6 | \alias{collect_metrics.tune_race} 7 | \title{Obtain and format results produced by racing functions} 8 | \usage{ 9 | \method{collect_predictions}{tune_race}( 10 | x, 11 | ..., 12 | summarize = FALSE, 13 | parameters = NULL, 14 | all_configs = FALSE 15 | ) 16 | 17 | \method{collect_metrics}{tune_race}( 18 | x, 19 | ..., 20 | summarize = TRUE, 21 | type = c("long", "wide"), 22 | all_configs = FALSE 23 | ) 24 | } 25 | \arguments{ 26 | \item{x}{The results of \code{\link[tune:tune_grid]{tune_grid()}}, \code{\link[tune:tune_bayes]{tune_bayes()}}, \code{\link[tune:fit_resamples]{fit_resamples()}}, 27 | or \code{\link[tune:last_fit]{last_fit()}}. 
For \code{\link[tune:collect_predictions]{collect_predictions()}}, the control option \code{save_pred = TRUE} should have been used.} 28 | 29 | \item{...}{Not currently used.} 30 | 31 | \item{summarize}{A logical; should metrics be summarized over resamples 32 | (\code{TRUE}) or return the values for each individual resample. Note that, if \code{x} 33 | is created by \code{\link[tune:last_fit]{last_fit()}}, \code{summarize} has no effect. For the other object 34 | types, the method of summarizing predictions is detailed below.} 35 | 36 | \item{parameters}{An optional tibble of tuning parameter values that can be 37 | used to filter the predicted values before processing. This tibble should 38 | only have columns for each tuning parameter identifier (e.g. \code{"my_param"} 39 | if \code{tune("my_param")} was used).} 40 | 41 | \item{all_configs}{A logical: should we return the complete set of model 42 | configurations or just those that made it to the end of the race (the 43 | default).} 44 | 45 | \item{type}{One of \code{"long"} (the default) or \code{"wide"}. When \code{type = "long"}, 46 | output has columns \code{.metric} and one of \code{.estimate} or \code{mean}. 47 | \code{.estimate}/\code{mean} gives the values for the \code{.metric}. When \code{type = "wide"}, 48 | each metric has its own column and the \code{n} and \code{std_err} columns are removed, 49 | if they exist.} 50 | } 51 | \value{ 52 | A tibble. The column names depend on the results and the mode of the 53 | model. 54 | } 55 | \description{ 56 | Obtain and format results produced by racing functions 57 | } 58 | \details{ 59 | For \code{\link[tune:collect_predictions]{tune::collect_metrics()}} and \code{\link[tune:collect_predictions]{tune::collect_predictions()}}, when unsummarized, 60 | there are columns for each tuning parameter (using the \code{id} from \code{\link[hardhat:tune]{hardhat::tune()}}, 61 | if any). 
62 | \code{\link[tune:collect_predictions]{tune::collect_metrics()}} also has columns \code{.metric}, and \code{.estimator}. When the 63 | results are summarized, there are columns for \code{mean}, \code{n}, and \code{std_err}. 64 | When not summarized, the additional columns for the resampling identifier(s) 65 | and \code{.estimate}. 66 | 67 | For \code{\link[tune:collect_predictions]{tune::collect_predictions()}}, there are additional columns for the resampling 68 | identifier(s), columns for the predicted values (e.g., \code{.pred}, 69 | \code{.pred_class}, etc.), and a column for the outcome(s) using the original 70 | column name(s) in the data. 71 | 72 | \code{\link[tune:collect_predictions]{tune::collect_predictions()}} can summarize the various results over 73 | replicate out-of-sample predictions. For example, when using the bootstrap, 74 | each row in the original training set has multiple holdout predictions 75 | (across assessment sets). To convert these results to a format where every 76 | training set same has a single predicted value, the results are averaged 77 | over replicate predictions. 78 | 79 | For regression cases, the numeric predictions are simply averaged. For 80 | classification models, the problem is more complex. When class probabilities 81 | are used, these are averaged and then re-normalized to make sure that they 82 | add to one. If hard class predictions also exist in the data, then these are 83 | determined from the summarized probability estimates (so that they match). 84 | If only hard class predictions are in the results, then the mode is used to 85 | summarize. 86 | 87 | For racing results, it is best to only 88 | collect model configurations that finished the race (i.e., were completely 89 | resampled). Comparing performance metrics for configurations averaged with 90 | different resamples is likely to lead to inappropriate results. 
91 | } 92 | -------------------------------------------------------------------------------- /man/control_race.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/control_race.R 3 | \name{control_race} 4 | \alias{control_race} 5 | \title{Control aspects of the grid search racing process} 6 | \usage{ 7 | control_race( 8 | verbose = FALSE, 9 | verbose_elim = FALSE, 10 | allow_par = TRUE, 11 | extract = NULL, 12 | save_pred = FALSE, 13 | burn_in = 3, 14 | num_ties = 10, 15 | alpha = 0.05, 16 | randomize = TRUE, 17 | pkgs = NULL, 18 | save_workflow = FALSE, 19 | event_level = "first", 20 | parallel_over = "everything", 21 | backend_options = NULL 22 | ) 23 | } 24 | \arguments{ 25 | \item{verbose}{A logical for logging results (other than warnings and errors, 26 | which are always shown) as they are generated during training in a single 27 | R process. When using most parallel backends, this argument typically will 28 | not result in any logging. 
If using a dark IDE theme, some logging messages 29 | might be hard to see; try setting the \code{tidymodels.dark} option with 30 | \code{options(tidymodels.dark = TRUE)} to print lighter colors.} 31 | 32 | \item{verbose_elim}{A logical for whether logging of the elimination of 33 | tuning parameter combinations should occur.} 34 | 35 | \item{allow_par}{A logical to allow parallel processing (if a parallel 36 | backend is registered).} 37 | 38 | \item{extract}{An optional function with at least one argument (or \code{NULL}) 39 | that can be used to retain arbitrary objects from the model fit object, 40 | recipe, or other elements of the workflow.} 41 | 42 | \item{save_pred}{A logical for whether the out-of-sample predictions should 43 | be saved for each model \emph{evaluated}.} 44 | 45 | \item{burn_in}{An integer for how many resamples should be completed for all 46 | grid combinations before parameter filtering begins.} 47 | 48 | \item{num_ties}{An integer for when tie-breaking should occur. If there are 49 | two final parameter combinations being evaluated, \code{num_ties} specified how 50 | many more resampling iterations should be evaluated. After \code{num_ties} more 51 | iterations, the parameter combination with the current best results is 52 | retained.} 53 | 54 | \item{alpha}{The alpha level for a one-sided confidence interval for each 55 | parameter combination.} 56 | 57 | \item{randomize}{Should the resamples be evaluated in a random order? By 58 | default, the resamples are evaluated in a random order so the random number 59 | seed should be control prior to calling this method (to be reproducible). 
60 | For repeated cross-validation the randomization occurs within each repeat.} 61 | 62 | \item{pkgs}{An optional character string of R package names that should be 63 | loaded (by namespace) during parallel processing.} 64 | 65 | \item{save_workflow}{A logical for whether the workflow should be appended 66 | to the output as an attribute.} 67 | 68 | \item{event_level}{A single string containing either \code{"first"} or \code{"second"}. 69 | This argument is passed on to yardstick metric functions when any type 70 | of class prediction is made, and specifies which level of the outcome 71 | is considered the "event".} 72 | 73 | \item{parallel_over}{A single string containing either \code{"resamples"} or 74 | \code{"everything"} describing how to use parallel processing. Alternatively, 75 | \code{NULL} is allowed, which chooses between \code{"resamples"} and \code{"everything"} 76 | automatically. 77 | 78 | If \code{"resamples"}, then tuning will be performed in parallel over resamples 79 | alone. Within each resample, the preprocessor (i.e. recipe or formula) is 80 | processed once, and is then reused across all models that need to be fit. 81 | 82 | If \code{"everything"}, then tuning will be performed in parallel at two levels. 83 | An outer parallel loop will iterate over resamples. Additionally, an 84 | inner parallel loop will iterate over all unique combinations of 85 | preprocessor and model tuning parameters for that specific resample. This 86 | will result in the preprocessor being re-processed multiple times, but 87 | can be faster if that processing is extremely fast. 88 | 89 | If \code{NULL}, chooses \code{"resamples"} if there are more than one resample, 90 | otherwise chooses \code{"everything"} to attempt to maximize core utilization. 91 | 92 | Note that switching between \code{parallel_over} strategies is not guaranteed 93 | to use the same random number generation schemes. 
However, re-tuning a 94 | model using the same \code{parallel_over} strategy is guaranteed to be 95 | reproducible between runs.} 96 | 97 | \item{backend_options}{An object of class \code{"tune_backend_options"} as created 98 | by \code{tune::new_backend_options()}, used to pass arguments to specific tuning 99 | backend. Defaults to \code{NULL} for default backend options.} 100 | } 101 | \value{ 102 | An object of class \code{control_race} that echos the argument values. 103 | } 104 | \description{ 105 | Control aspects of the grid search racing process 106 | } 107 | \examples{ 108 | control_race() 109 | } 110 | -------------------------------------------------------------------------------- /man/control_sim_anneal.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/control_sim_anneal.R 3 | \name{control_sim_anneal} 4 | \alias{control_sim_anneal} 5 | \title{Control aspects of the simulated annealing search process} 6 | \usage{ 7 | control_sim_anneal( 8 | verbose = FALSE, 9 | verbose_iter = TRUE, 10 | no_improve = Inf, 11 | restart = 8L, 12 | radius = c(0.05, 0.15), 13 | flip = 3/4, 14 | cooling_coef = 0.02, 15 | extract = NULL, 16 | save_pred = FALSE, 17 | time_limit = NA, 18 | pkgs = NULL, 19 | save_workflow = FALSE, 20 | save_history = FALSE, 21 | event_level = "first", 22 | parallel_over = NULL, 23 | allow_par = TRUE, 24 | backend_options = NULL 25 | ) 26 | } 27 | \arguments{ 28 | \item{verbose}{A logical for logging results (other than warnings and errors, 29 | which are always shown) as they are generated during training in a single 30 | R process. When using most parallel backends, this argument typically will 31 | not result in any logging. 
If using a dark IDE theme, some logging messages 32 | might be hard to see; try setting the \code{tidymodels.dark} option with 33 | \code{options(tidymodels.dark = TRUE)} to print lighter colors.} 34 | 35 | \item{verbose_iter}{A logical for logging results of the search 36 | process. Defaults to FALSE. If using a dark IDE theme, some logging 37 | messages might be hard to see; try setting the \code{tidymodels.dark} option 38 | with \code{options(tidymodels.dark = TRUE)} to print lighter colors.} 39 | 40 | \item{no_improve}{The integer cutoff for the number of iterations without 41 | better results.} 42 | 43 | \item{restart}{The number of iterations with no improvement before new tuning 44 | parameter candidates are generated from the last, overall best conditions.} 45 | 46 | \item{radius}{Two real numbers on \verb{(0, 1)} describing what a value "in the 47 | neighborhood" of the current result should be. If all numeric parameters were 48 | scaled to be on the \verb{[0, 1]} scale, these values set the min. and max. 49 | of a radius of a circle used to generate new numeric parameter values.} 50 | 51 | \item{flip}{A real number between \verb{[0, 1]} for the probability of changing 52 | any non-numeric parameter values at each iteration.} 53 | 54 | \item{cooling_coef}{A real, positive number to influence the cooling 55 | schedule. Larger values decrease the probability of accepting a sub-optimal 56 | parameter setting.} 57 | 58 | \item{extract}{An optional function with at least one argument (or \code{NULL}) 59 | that can be used to retain arbitrary objects from the model fit object, 60 | recipe, or other elements of the workflow.} 61 | 62 | \item{save_pred}{A logical for whether the out-of-sample predictions should 63 | be saved for each model \emph{evaluated}.} 64 | 65 | \item{time_limit}{A number for the minimum number of \emph{minutes} (elapsed) that 66 | the function should execute. 
The elapsed time is evaluated at internal 67 | checkpoints and, if over time, the results at that time are returned (with 68 | a warning). This means that the \code{time_limit} is not an exact limit, but a 69 | minimum time limit.} 70 | 71 | \item{pkgs}{An optional character string of R package names that should be 72 | loaded (by namespace) during parallel processing.} 73 | 74 | \item{save_workflow}{A logical for whether the workflow should be appended 75 | to the output as an attribute.} 76 | 77 | \item{save_history}{A logical to save the iteration details of the search. 78 | These are saved to \code{tempdir()} named \code{sa_history.RData}. These results are 79 | deleted when the R session ends. This option is only useful for teaching 80 | purposes.} 81 | 82 | \item{event_level}{A single string containing either \code{"first"} or \code{"second"}. 83 | This argument is passed on to yardstick metric functions when any type 84 | of class prediction is made, and specifies which level of the outcome 85 | is considered the "event".} 86 | 87 | \item{parallel_over}{A single string containing either \code{"resamples"} or 88 | \code{"everything"} describing how to use parallel processing. Alternatively, 89 | \code{NULL} is allowed, which chooses between \code{"resamples"} and \code{"everything"} 90 | automatically. 91 | 92 | If \code{"resamples"}, then tuning will be performed in parallel over resamples 93 | alone. Within each resample, the preprocessor (i.e. recipe or formula) is 94 | processed once, and is then reused across all models that need to be fit. 95 | 96 | If \code{"everything"}, then tuning will be performed in parallel at two levels. 97 | An outer parallel loop will iterate over resamples. Additionally, an 98 | inner parallel loop will iterate over all unique combinations of 99 | preprocessor and model tuning parameters for that specific resample. 
This 100 | will result in the preprocessor being re-processed multiple times, but 101 | can be faster if that processing is extremely fast. 102 | 103 | If \code{NULL}, chooses \code{"resamples"} if there are more than one resample, 104 | otherwise chooses \code{"everything"} to attempt to maximize core utilization. 105 | 106 | Note that switching between \code{parallel_over} strategies is not guaranteed 107 | to use the same random number generation schemes. However, re-tuning a 108 | model using the same \code{parallel_over} strategy is guaranteed to be 109 | reproducible between runs.} 110 | 111 | \item{allow_par}{A logical to allow parallel processing (if a parallel 112 | backend is registered).} 113 | 114 | \item{backend_options}{An object of class \code{"tune_backend_options"} as created 115 | by \code{tune::new_backend_options()}, used to pass arguments to specific tuning 116 | backend. Defaults to \code{NULL} for default backend options.} 117 | } 118 | \value{ 119 | An object of class \code{control_sim_anneal} that echos the argument values. 
120 | } 121 | \description{ 122 | Control aspects of the simulated annealing search process 123 | } 124 | \examples{ 125 | control_sim_anneal() 126 | } 127 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/finetune/13b1ba2d4b97e63b3a4e08d1e607e8c6917df22a/man/figures/logo.png -------------------------------------------------------------------------------- /man/finetune-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/finetune-package.R 3 | \docType{package} 4 | \name{finetune-package} 5 | \alias{finetune} 6 | \alias{finetune-package} 7 | \title{finetune: Additional Functions for Model Tuning} 8 | \description{ 9 | \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} 10 | 11 | The ability to tune models is important. 'finetune' enhances the 'tune' package by providing more specialized methods for finding reasonable values of model tuning parameters. Two racing methods described by Kuhn (2014) \doi{10.48550/arXiv.1405.6974} are included. An iterative search method using generalized simulated annealing (Bohachevsky, Johnson and Stein, 1986) \doi{10.1080/00401706.1986.10488128} is also included. 
12 | } 13 | \seealso{ 14 | Useful links: 15 | \itemize{ 16 | \item \url{https://github.com/tidymodels/finetune} 17 | \item \url{https://finetune.tidymodels.org} 18 | \item Report bugs at \url{https://github.com/tidymodels/finetune/issues} 19 | } 20 | 21 | } 22 | \author{ 23 | \strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) 24 | 25 | Other contributors: 26 | \itemize{ 27 | \item Posit Software, PBC (03wc8by49) [copyright holder, funder] 28 | } 29 | 30 | } 31 | \keyword{internal} 32 | -------------------------------------------------------------------------------- /man/plot_race.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/plot_race.R 3 | \name{plot_race} 4 | \alias{plot_race} 5 | \title{Plot racing results} 6 | \usage{ 7 | plot_race(x) 8 | } 9 | \arguments{ 10 | \item{x}{A object with class \code{tune_results}} 11 | } 12 | \value{ 13 | A ggplot object. 14 | } 15 | \description{ 16 | Plot the model results over stages of the racing results. A line is given 17 | for each submodel that was tested. 18 | } 19 | -------------------------------------------------------------------------------- /man/rmd/anova-benchmark.md: -------------------------------------------------------------------------------- 1 | ## Benchmarking results 2 | 3 | To demonstrate, we use a SVM model with the `kernlab` package. 
4 | 5 | ```r 6 | library(kernlab) 7 | library(tidymodels) 8 | library(finetune) 9 | library(doParallel) 10 | 11 | ## ----------------------------------------------------------------------------- 12 | 13 | data(cells, package = "modeldata") 14 | cells <- cells |> select(-case) 15 | 16 | ## ----------------------------------------------------------------------------- 17 | 18 | set.seed(6376) 19 | rs <- bootstraps(cells, times = 25) 20 | ``` 21 | 22 | We'll only tune the model parameters (i.e., not recipe tuning): 23 | 24 | ```r 25 | ## ----------------------------------------------------------------------------- 26 | 27 | svm_spec <- 28 | svm_rbf(cost = tune(), rbf_sigma = tune()) |> 29 | set_engine("kernlab") |> 30 | set_mode("classification") 31 | 32 | svm_rec <- 33 | recipe(class ~ ., data = cells) |> 34 | step_YeoJohnson(all_predictors()) |> 35 | step_normalize(all_predictors()) 36 | 37 | svm_wflow <- 38 | workflow() |> 39 | add_model(svm_spec) |> 40 | add_recipe(svm_rec) 41 | 42 | set.seed(1) 43 | svm_grid <- 44 | svm_spec |> 45 | parameters() |> 46 | grid_latin_hypercube(size = 25) 47 | ``` 48 | 49 | We'll get the times for grid search and ANOVA racing with and without parallel processing: 50 | 51 | ```r 52 | ## ----------------------------------------------------------------------------- 53 | ## Regular grid search 54 | 55 | system.time({ 56 | set.seed(2) 57 | svm_wflow |> tune_grid(resamples = rs, grid = svm_grid) 58 | }) 59 | ``` 60 | 61 | ``` 62 | ## user system elapsed 63 | ## 741.660 19.654 761.357 64 | ``` 65 | 66 | 67 | ```r 68 | ## ----------------------------------------------------------------------------- 69 | ## With racing 70 | 71 | system.time({ 72 | set.seed(2) 73 | svm_wflow |> tune_race_anova(resamples = rs, grid = svm_grid) 74 | }) 75 | ``` 76 | 77 | ``` 78 | ## user system elapsed 79 | ## 133.143 3.675 136.822 80 | ``` 81 | 82 | Speed-up of 5.56-fold for racing. 
83 | 84 | 85 | ```r 86 | ## ----------------------------------------------------------------------------- 87 | ## Parallel processing setup 88 | 89 | cores <- parallel::detectCores(logical = FALSE) 90 | cores 91 | ``` 92 | 93 | ``` 94 | ## [1] 10 95 | ``` 96 | 97 | ```r 98 | cl <- makePSOCKcluster(cores) 99 | registerDoParallel(cl) 100 | ``` 101 | 102 | 103 | ```r 104 | ## ----------------------------------------------------------------------------- 105 | ## Parallel grid search 106 | 107 | system.time({ 108 | set.seed(2) 109 | svm_wflow |> tune_grid(resamples = rs, grid = svm_grid) 110 | }) 111 | ``` 112 | 113 | ``` 114 | ## user system elapsed 115 | ## 1.112 0.190 126.650 116 | ``` 117 | 118 | Parallel processing with grid search was 6.01-fold faster than sequential grid search. 119 | 120 | 121 | ```r 122 | ## ----------------------------------------------------------------------------- 123 | ## Parallel racing 124 | 125 | system.time({ 126 | set.seed(2) 127 | svm_wflow |> tune_race_anova(resamples = rs, grid = svm_grid) 128 | }) 129 | ``` 130 | 131 | ``` 132 | ## user system elapsed 133 | ## 1.908 0.261 21.442 134 | ``` 135 | 136 | Parallel processing with racing was 35.51-fold faster than sequential grid search. 137 | 138 | There is a compounding effect of racing and parallel processing but its magnitude depends on the type of model, number of resamples, number of tuning parameters, and so on. 
139 | 140 | 141 | -------------------------------------------------------------------------------- /man/show_best.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/racing_helpers.R 3 | \name{show_best.tune_race} 4 | \alias{show_best.tune_race} 5 | \title{Investigate best tuning parameters} 6 | \usage{ 7 | \method{show_best}{tune_race}( 8 | x, 9 | ..., 10 | metric = NULL, 11 | eval_time = NULL, 12 | n = 5, 13 | call = rlang::current_env() 14 | ) 15 | } 16 | \arguments{ 17 | \item{x}{The results of \code{\link[tune:tune_grid]{tune_grid()}} or \code{\link[tune:tune_bayes]{tune_bayes()}}.} 18 | 19 | \item{...}{For \code{\link[tune:select_by_one_std_err]{select_by_one_std_err()}} and \code{\link[tune:select_by_pct_loss]{select_by_pct_loss()}}, this 20 | argument is passed directly to \code{\link[dplyr:arrange]{dplyr::arrange()}} so that the user can sort 21 | the models from \emph{most simple to most complex}. That is, for a parameter \code{p}, 22 | pass the unquoted expression \code{p} if smaller values of \code{p} indicate a simpler 23 | model, or \code{desc(p)} if larger values indicate a simpler model. At 24 | least one term is required for these two functions. See the examples below.} 25 | 26 | \item{metric}{A character value for the metric that will be used to sort 27 | the models. (See 28 | \url{https://yardstick.tidymodels.org/articles/metric-types.html} for 29 | more details). Not required if a single metric exists in \code{x}. If there are 30 | multiple metric and none are given, the first in the metric set is used (and 31 | a warning is issued).} 32 | 33 | \item{eval_time}{A single numeric time point where dynamic event time 34 | metrics should be chosen (e.g., the time-dependent ROC curve, etc). The 35 | values should be consistent with the values used to create \code{x}. 
The \code{NULL} 36 | default will automatically use the first evaluation time used by \code{x}.} 37 | 38 | \item{n}{An integer for the maximum number of top results/rows to return.} 39 | 40 | \item{call}{The call to be shown in errors and warnings.} 41 | } 42 | \description{ 43 | \code{\link[tune:show_best]{tune::show_best()}} displays the top sub-models and their performance estimates. 44 | } 45 | \details{ 46 | For racing results (from the \pkg{finetune} package), it is best to only 47 | report configurations that finished the race (i.e., were completely 48 | resampled). Comparing performance metrics for configurations averaged with 49 | different resamples is likely to lead to inappropriate results. 50 | } 51 | -------------------------------------------------------------------------------- /man/tune_race_anova.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tune_race_anova.R 3 | \name{tune_race_anova} 4 | \alias{tune_race_anova} 5 | \alias{tune_race_anova.model_spec} 6 | \alias{tune_race_anova.workflow} 7 | \title{Efficient grid search via racing with ANOVA models} 8 | \usage{ 9 | tune_race_anova(object, ...) 
10 | 11 | \method{tune_race_anova}{model_spec}( 12 | object, 13 | preprocessor, 14 | resamples, 15 | ..., 16 | param_info = NULL, 17 | grid = 10, 18 | metrics = NULL, 19 | eval_time = NULL, 20 | control = control_race() 21 | ) 22 | 23 | \method{tune_race_anova}{workflow}( 24 | object, 25 | resamples, 26 | ..., 27 | param_info = NULL, 28 | grid = 10, 29 | metrics = NULL, 30 | eval_time = NULL, 31 | control = control_race() 32 | ) 33 | } 34 | \arguments{ 35 | \item{object}{A \code{parsnip} model specification or a \code{\link[workflows:workflow]{workflows::workflow()}}.} 36 | 37 | \item{...}{Not currently used.} 38 | 39 | \item{preprocessor}{A traditional model formula or a recipe created using 40 | \code{\link[recipes:recipe]{recipes::recipe()}}. This is only required when \code{object} is not a workflow.} 41 | 42 | \item{resamples}{An \code{rset()} object that has multiple resamples (i.e., is not 43 | a validation set).} 44 | 45 | \item{param_info}{A \code{\link[dials:parameters]{dials::parameters()}} object or \code{NULL}. If none is given, 46 | a parameters set is derived from other arguments. Passing this argument can 47 | be useful when parameter ranges need to be customized.} 48 | 49 | \item{grid}{A data frame of tuning combinations or a positive integer. The 50 | data frame should have columns for each parameter being tuned and rows for 51 | tuning parameter candidates. An integer denotes the number of candidate 52 | parameter sets to be created automatically.} 53 | 54 | \item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}} or \code{NULL}.} 55 | 56 | \item{eval_time}{A numeric vector of time points where dynamic event time 57 | metrics should be computed (e.g. the time-dependent ROC curve, etc). The 58 | values must be non-negative and should probably be no greater than the 59 | largest event time in the training set (See Details below).} 60 | 61 | \item{control}{An object used to modify the tuning process. 
See 62 | \code{\link[=control_race]{control_race()}} for more details.} 63 | } 64 | \value{ 65 | An object with primary class \code{tune_race} in the same standard format 66 | as objects produced by \code{\link[tune:tune_grid]{tune::tune_grid()}}. 67 | } 68 | \description{ 69 | \code{\link[=tune_race_anova]{tune_race_anova()}} computes a set of performance metrics (e.g. accuracy or RMSE) 70 | for a pre-defined set of tuning parameters that correspond to a model or 71 | recipe across one or more resamples of the data. After an initial number of 72 | resamples have been evaluated, the process eliminates tuning parameter 73 | combinations that are unlikely to be the best results using a repeated 74 | measure ANOVA model. 75 | } 76 | \details{ 77 | The technical details of this method are described in Kuhn (2014). 78 | 79 | Racing methods are efficient approaches to grid search. Initially, the 80 | function evaluates all tuning parameters on a small initial set of 81 | resamples. The \code{burn_in} argument of \code{\link[=control_race]{control_race()}} sets the number of 82 | initial resamples. 83 | 84 | The performance statistics from these resamples are analyzed to determine 85 | which tuning parameters are \emph{not} statistically different from the current 86 | best setting. If a parameter is statistically different, it is excluded from 87 | further resampling. 88 | 89 | The next resample is used with the remaining parameter combinations and the 90 | statistical analysis is updated. More candidate parameters may be excluded 91 | with each new resample that is processed. 92 | 93 | This function determines statistical significance using a repeated measures ANOVA 94 | model where the performance statistic (e.g., RMSE, accuracy, etc.) is the 95 | outcome data and the random effect is due to resamples. 
The 96 | \code{\link[=control_race]{control_race()}} function contains a parameter for the significance cutoff 97 | applied to the ANOVA results as well as other relevant arguments. 98 | 99 | There is benefit to using racing methods in conjunction with parallel 100 | processing. The following section shows a benchmark of results for one 101 | dataset and model. 102 | \subsection{Censored regression models}{ 103 | 104 | With dynamic performance metrics (e.g. Brier or ROC curves), performance is 105 | calculated for every value of \code{eval_time} but the \emph{first} evaluation time 106 | given by the user (e.g., \code{eval_time[1]}) is analyzed during racing. 107 | 108 | Also, values of \code{eval_time} should be less than the largest observed event 109 | time in the training data. For many non-parametric models, the results beyond 110 | the largest time corresponding to an event are constant (or \code{NA}). 111 | } 112 | 113 | \subsection{Benchmarking results}{ 114 | 115 | To demonstrate, we use an SVM model with the \code{kernlab} package. 116 | 117 | \if{html}{\out{
}}\preformatted{library(kernlab) 118 | library(tidymodels) 119 | library(finetune) 120 | library(doParallel) 121 | 122 | ## ----------------------------------------------------------------------------- 123 | 124 | data(cells, package = "modeldata") 125 | cells <- cells |> select(-case) 126 | 127 | ## ----------------------------------------------------------------------------- 128 | 129 | set.seed(6376) 130 | rs <- bootstraps(cells, times = 25) 131 | }\if{html}{\out{
}} 132 | 133 | We’ll only tune the model parameters (i.e., not recipe tuning): 134 | 135 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 136 | 137 | svm_spec <- 138 | svm_rbf(cost = tune(), rbf_sigma = tune()) |> 139 | set_engine("kernlab") |> 140 | set_mode("classification") 141 | 142 | svm_rec <- 143 | recipe(class ~ ., data = cells) |> 144 | step_YeoJohnson(all_predictors()) |> 145 | step_normalize(all_predictors()) 146 | 147 | svm_wflow <- 148 | workflow() |> 149 | add_model(svm_spec) |> 150 | add_recipe(svm_rec) 151 | 152 | set.seed(1) 153 | svm_grid <- 154 | svm_spec |> 155 | parameters() |> 156 | grid_latin_hypercube(size = 25) 157 | }\if{html}{\out{
}} 158 | 159 | We’ll get the times for grid search and ANOVA racing with and without 160 | parallel processing: 161 | 162 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 163 | ## Regular grid search 164 | 165 | system.time(\{ 166 | set.seed(2) 167 | svm_wflow |> tune_grid(resamples = rs, grid = svm_grid) 168 | \}) 169 | }\if{html}{\out{
}} 170 | 171 | \if{html}{\out{
}}\preformatted{## user system elapsed 172 | ## 741.660 19.654 761.357 173 | }\if{html}{\out{
}} 174 | 175 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 176 | ## With racing 177 | 178 | system.time(\{ 179 | set.seed(2) 180 | svm_wflow |> tune_race_anova(resamples = rs, grid = svm_grid) 181 | \}) 182 | }\if{html}{\out{
}} 183 | 184 | \if{html}{\out{
}}\preformatted{## user system elapsed 185 | ## 133.143 3.675 136.822 186 | }\if{html}{\out{
}} 187 | 188 | Speed-up of 5.56-fold for racing. 189 | 190 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 191 | ## Parallel processing setup 192 | 193 | cores <- parallel::detectCores(logical = FALSE) 194 | cores 195 | }\if{html}{\out{
}} 196 | 197 | \if{html}{\out{
}}\preformatted{## [1] 10 198 | }\if{html}{\out{
}} 199 | 200 | \if{html}{\out{
}}\preformatted{cl <- makePSOCKcluster(cores) 201 | registerDoParallel(cl) 202 | }\if{html}{\out{
}} 203 | 204 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 205 | ## Parallel grid search 206 | 207 | system.time(\{ 208 | set.seed(2) 209 | svm_wflow |> tune_grid(resamples = rs, grid = svm_grid) 210 | \}) 211 | }\if{html}{\out{
}} 212 | 213 | \if{html}{\out{
}}\preformatted{## user system elapsed 214 | ## 1.112 0.190 126.650 215 | }\if{html}{\out{
}} 216 | 217 | Parallel processing with grid search was 6.01-fold faster than 218 | sequential grid search. 219 | 220 | \if{html}{\out{
}}\preformatted{## ----------------------------------------------------------------------------- 221 | ## Parallel racing 222 | 223 | system.time(\{ 224 | set.seed(2) 225 | svm_wflow |> tune_race_anova(resamples = rs, grid = svm_grid) 226 | \}) 227 | }\if{html}{\out{
}} 228 | 229 | \if{html}{\out{
}}\preformatted{## user system elapsed 230 | ## 1.908 0.261 21.442 231 | }\if{html}{\out{
}} 232 | 233 | Parallel processing with racing was 35.51-fold faster than sequential 234 | grid search. 235 | 236 | There is a compounding effect of racing and parallel processing but its 237 | magnitude depends on the type of model, number of resamples, number of 238 | tuning parameters, and so on. 239 | } 240 | } 241 | \examples{ 242 | \donttest{ 243 | library(parsnip) 244 | library(rsample) 245 | library(dials) 246 | 247 | ## ----------------------------------------------------------------------------- 248 | 249 | if (rlang::is_installed(c("discrim", "lme4", "modeldata"))) { 250 | library(discrim) 251 | data(two_class_dat, package = "modeldata") 252 | 253 | set.seed(6376) 254 | rs <- bootstraps(two_class_dat, times = 10) 255 | 256 | ## ----------------------------------------------------------------------------- 257 | 258 | # optimize an regularized discriminant analysis model 259 | rda_spec <- 260 | discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 261 | set_engine("klaR") 262 | 263 | ## ----------------------------------------------------------------------------- 264 | 265 | ctrl <- control_race(verbose_elim = TRUE) 266 | set.seed(11) 267 | grid_anova <- 268 | rda_spec |> 269 | tune_race_anova(Class ~ ., resamples = rs, grid = 10, control = ctrl) 270 | 271 | # Shows only the fully resampled parameters 272 | show_best(grid_anova, metric = "roc_auc", n = 2) 273 | 274 | plot_race(grid_anova) 275 | } 276 | } 277 | } 278 | \references{ 279 | Kuhn, M 2014. "Futility Analysis in the Cross-Validation of Machine Learning 280 | Models." \url{https://arxiv.org/abs/1405.6974}. 
281 | } 282 | \seealso{ 283 | \code{\link[tune:tune_grid]{tune::tune_grid()}}, \code{\link[=control_race]{control_race()}}, \code{\link[=tune_race_win_loss]{tune_race_win_loss()}} 284 | } 285 | -------------------------------------------------------------------------------- /man/tune_race_win_loss.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tune_race_win_loss.R 3 | \name{tune_race_win_loss} 4 | \alias{tune_race_win_loss} 5 | \alias{tune_race_win_loss.model_spec} 6 | \alias{tune_race_win_loss.workflow} 7 | \title{Efficient grid search via racing with win/loss statistics} 8 | \usage{ 9 | tune_race_win_loss(object, ...) 10 | 11 | \method{tune_race_win_loss}{model_spec}( 12 | object, 13 | preprocessor, 14 | resamples, 15 | ..., 16 | param_info = NULL, 17 | grid = 10, 18 | metrics = NULL, 19 | eval_time = NULL, 20 | control = control_race() 21 | ) 22 | 23 | \method{tune_race_win_loss}{workflow}( 24 | object, 25 | resamples, 26 | ..., 27 | param_info = NULL, 28 | grid = 10, 29 | metrics = NULL, 30 | eval_time = NULL, 31 | control = control_race() 32 | ) 33 | } 34 | \arguments{ 35 | \item{object}{A \code{parsnip} model specification or a \code{\link[workflows:workflow]{workflows::workflow()}}.} 36 | 37 | \item{...}{Not currently used.} 38 | 39 | \item{preprocessor}{A traditional model formula or a recipe created using 40 | \code{\link[recipes:recipe]{recipes::recipe()}}. This is only required when \code{object} is not a workflow.} 41 | 42 | \item{resamples}{An \code{rset()} object that has multiple resamples (i.e., is not 43 | a validation set).} 44 | 45 | \item{param_info}{A \code{\link[dials:parameters]{dials::parameters()}} object or \code{NULL}. If none is given, 46 | a parameters set is derived from other arguments. 
Passing this argument can 47 | be useful when parameter ranges need to be customized.} 48 | 49 | \item{grid}{A data frame of tuning combinations or a positive integer. The 50 | data frame should have columns for each parameter being tuned and rows for 51 | tuning parameter candidates. An integer denotes the number of candidate 52 | parameter sets to be created automatically.} 53 | 54 | \item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}} or \code{NULL}.} 55 | 56 | \item{eval_time}{A numeric vector of time points where dynamic event time 57 | metrics should be computed (e.g. the time-dependent ROC curve, etc). The 58 | values must be non-negative and should probably be no greater than the 59 | largest event time in the training set (See Details below).} 60 | 61 | \item{control}{An object used to modify the tuning process. See 62 | \code{\link[=control_race]{control_race()}} for more details.} 63 | } 64 | \value{ 65 | An object with primary class \code{tune_race} in the same standard format 66 | as objects produced by \code{\link[tune:tune_grid]{tune::tune_grid()}}. 67 | } 68 | \description{ 69 | \code{\link[=tune_race_win_loss]{tune_race_win_loss()}} computes a set of performance metrics (e.g. accuracy or RMSE) 70 | for a pre-defined set of tuning parameters that correspond to a model or 71 | recipe across one or more resamples of the data. After an initial number of 72 | resamples have been evaluated, the process eliminates tuning parameter 73 | combinations that are unlikely to be the best results using a statistical 74 | model. For each pairwise combinations of tuning parameters, win/loss 75 | statistics are calculated and a logistic regression model is used to measure 76 | how likely each combination is to win overall. 77 | } 78 | \details{ 79 | The technical details of this method are described in Kuhn (2014). 80 | 81 | Racing methods are efficient approaches to grid search. 
Initially, the 82 | function evaluates all tuning parameters on a small initial set of 83 | resamples. The \code{burn_in} argument of \code{\link[=control_race]{control_race()}} sets the number of 84 | initial resamples. 85 | 86 | The performance statistics from the current set of resamples are converted 87 | to win/loss/tie results. For example, for two parameters (\code{j} and \code{k}) in a 88 | classification model that have each been resampled three times: 89 | 90 | \preformatted{ 91 | | area under the ROC curve | 92 | ----------------------------- 93 | resample | parameter j | parameter k | winner 94 | --------------------------------------------- 95 | 1 | 0.81 | 0.92 | k 96 | 2 | 0.95 | 0.94 | j 97 | 3 | 0.79 | 0.81 | k 98 | --------------------------------------------- 99 | } 100 | 101 | After the third resample, parameter \code{k} has a 2:1 win/loss ratio versus \code{j}. 102 | Parameters with equal results are treated as a half-win for each setting. 103 | These statistics are determined for all pairwise combinations of the 104 | parameters and a Bradley-Terry model is used to model these win/loss/tie 105 | statistics. This model can compute the ability of a parameter combination to 106 | win overall. A confidence interval for the winning ability is computed and 107 | any settings whose interval includes zero are retained for future resamples 108 | (since it is not statistically different from the best results). 109 | 110 | The next resample is used with the remaining parameter combinations and the 111 | statistical analysis is updated. More candidate parameters may be excluded 112 | with each new resample that is processed. 113 | 114 | The \code{\link[=control_race]{control_race()}} function contains a parameter for the significance cutoff 115 | applied to the Bradley-Terry model results as well as other relevant arguments. 116 | \subsection{Censored regression models}{ 117 | 118 | With dynamic performance metrics (e.g.
Brier or ROC curves), performance is 119 | calculated for every value of \code{eval_time} but the \emph{first} evaluation time 120 | given by the user (e.g., \code{eval_time[1]}) is analyzed during racing. 121 | 122 | Also, values of \code{eval_time} should be less than the largest observed event 123 | time in the training data. For many non-parametric models, the results beyond 124 | the largest time corresponding to an event are constant (or \code{NA}). 125 | } 126 | } 127 | \examples{ 128 | \donttest{ 129 | library(parsnip) 130 | library(rsample) 131 | library(dials) 132 | 133 | ## ----------------------------------------------------------------------------- 134 | 135 | if (rlang::is_installed(c("discrim", "modeldata"))) { 136 | library(discrim) 137 | data(two_class_dat, package = "modeldata") 138 | 139 | set.seed(6376) 140 | rs <- bootstraps(two_class_dat, times = 10) 141 | 142 | ## ----------------------------------------------------------------------------- 143 | 144 | # optimize an regularized discriminant analysis model 145 | rda_spec <- 146 | discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 147 | set_engine("klaR") 148 | 149 | ## ----------------------------------------------------------------------------- 150 | 151 | ctrl <- control_race(verbose_elim = TRUE) 152 | 153 | set.seed(11) 154 | grid_wl <- 155 | rda_spec |> 156 | tune_race_win_loss(Class ~ ., resamples = rs, grid = 10, control = ctrl) 157 | 158 | # Shows only the fully resampled parameters 159 | show_best(grid_wl, metric = "roc_auc") 160 | 161 | plot_race(grid_wl) 162 | } 163 | } 164 | } 165 | \references{ 166 | Kuhn, M 2014. "Futility Analysis in the Cross-Validation of Machine Learning 167 | Models." \url{https://arxiv.org/abs/1405.6974}. 
168 | } 169 | \seealso{ 170 | \code{\link[tune:tune_grid]{tune::tune_grid()}}, \code{\link[=control_race]{control_race()}}, \code{\link[=tune_race_anova]{tune_race_anova()}} 171 | } 172 | -------------------------------------------------------------------------------- /man/tune_sim_anneal.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tune_sim_anneal.R 3 | \name{tune_sim_anneal} 4 | \alias{tune_sim_anneal} 5 | \alias{tune_sim_anneal.model_spec} 6 | \alias{tune_sim_anneal.workflow} 7 | \title{Optimization of model parameters via simulated annealing} 8 | \usage{ 9 | tune_sim_anneal(object, ...) 10 | 11 | \method{tune_sim_anneal}{model_spec}( 12 | object, 13 | preprocessor, 14 | resamples, 15 | ..., 16 | iter = 10, 17 | param_info = NULL, 18 | metrics = NULL, 19 | eval_time = NULL, 20 | initial = 1, 21 | control = control_sim_anneal() 22 | ) 23 | 24 | \method{tune_sim_anneal}{workflow}( 25 | object, 26 | resamples, 27 | ..., 28 | iter = 10, 29 | param_info = NULL, 30 | metrics = NULL, 31 | eval_time = NULL, 32 | initial = 1, 33 | control = control_sim_anneal() 34 | ) 35 | } 36 | \arguments{ 37 | \item{object}{A \code{parsnip} model specification or a \code{\link[workflows:workflow]{workflows::workflow()}}.} 38 | 39 | \item{...}{Not currently used.} 40 | 41 | \item{preprocessor}{A traditional model formula or a recipe created using 42 | \code{\link[recipes:recipe]{recipes::recipe()}}. This is only required when \code{object} is not a workflow.} 43 | 44 | \item{resamples}{An \code{rset()} object.} 45 | 46 | \item{iter}{The maximum number of search iterations.} 47 | 48 | \item{param_info}{A \code{\link[dials:parameters]{dials::parameters()}} object or \code{NULL}. If none is given, 49 | a parameter set is derived from other arguments. 
Passing this argument can 50 | be useful when parameter ranges need to be customized.} 51 | 52 | \item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}} object containing information on how 53 | models will be evaluated for performance. The first metric in \code{metrics} is the 54 | one that will be optimized.} 55 | 56 | \item{eval_time}{A numeric vector of time points where dynamic event time 57 | metrics should be computed (e.g. the time-dependent ROC curve, etc). The 58 | values must be non-negative and should probably be no greater than the 59 | largest event time in the training set (See Details below).} 60 | 61 | \item{initial}{An initial set of results in a tidy format (as would the result 62 | of \code{\link[tune:tune_grid]{tune::tune_grid()}}, \code{\link[tune:tune_bayes]{tune::tune_bayes()}}, \code{\link[=tune_race_win_loss]{tune_race_win_loss()}}, or 63 | \code{\link[=tune_race_anova]{tune_race_anova()}}) or a positive integer. If the initial object was a 64 | sequential search method, the simulated annealing iterations start after the 65 | last iteration of the initial results.} 66 | 67 | \item{control}{The results of \code{\link[=control_sim_anneal]{control_sim_anneal()}}.} 68 | } 69 | \value{ 70 | A tibble of results that mirror those generated by \code{\link[tune:tune_grid]{tune::tune_grid()}}. 71 | However, these results contain an \code{.iter} column and replicate the \code{rset} 72 | object multiple times over iterations (at limited additional memory costs). 73 | } 74 | \description{ 75 | \code{\link[=tune_sim_anneal]{tune_sim_anneal()}} uses an iterative search procedure to generate new 76 | candidate tuning parameter combinations based on previous results. It uses 77 | the generalized simulated annealing method of Bohachevsky, Johnson, and 78 | Stein (1986). 79 | } 80 | \details{ 81 | Simulated annealing is a global optimization method. 
For model tuning, it 82 | can be used to iteratively search the parameter space for optimal tuning 83 | parameter combinations. At each iteration, a new parameter combination is 84 | created by perturbing the current parameters in some small way so that they 85 | are within a small neighborhood. This new parameter combination is used to 86 | fit a model and that model's performance is measured using resampling (or a 87 | simple validation set). 88 | 89 | If the new settings have better results than the current settings, they are 90 | accepted and the process continues. 91 | 92 | If the new settings have worse performance, a probability threshold is 93 | computed for accepting these sub-optimal values. The probability is a 94 | function of \emph{how} sub-optimal the results are as well as how many iterations 95 | have elapsed. This is referred to as the "cooling schedule" for the 96 | algorithm. If the sub-optimal results are accepted, the next iteration's 97 | settings are based on these inferior results. Otherwise, new parameter 98 | values are generated from the previous iteration's settings. 99 | 100 | This process continues for a pre-defined number of iterations and the 101 | overall best settings are recommended for use. The \code{\link[=control_sim_anneal]{control_sim_anneal()}} 102 | function can specify the number of iterations without improvement for early 103 | stopping. Also, that function can be used to specify a \emph{restart} threshold; 104 | if no globally best results have been discovered within a certain number 105 | of iterations, the process can restart using the last known settings that 106 | are globally best. 107 | \subsection{Creating new settings}{ 108 | 109 | For each numeric parameter, the range of possible values is known as well 110 | as any transformations. The current values are transformed and scaled to 111 | have values between zero and one (based on the possible range of values).
A 112 | candidate set of values that are on a sphere with random radii between 113 | \code{rmin} and \code{rmax} are generated. Infeasible values are removed and one value 114 | is chosen at random. This value is back transformed to the original units 115 | and scale and is used as the new settings. The argument \code{radius} of 116 | \code{\link[=control_sim_anneal]{control_sim_anneal()}} controls the range of neighborhood sizes. 117 | 118 | For categorical and integer parameters, each is changed with a pre-defined 119 | probability. The \code{flip} argument of \code{\link[=control_sim_anneal]{control_sim_anneal()}} can be used to 120 | specify this probability. For integer parameters, a nearby integer value is 121 | used. 122 | 123 | Simulated annealing search may not be the preferred method when many of the 124 | parameters are non-numeric or integers with few unique values. In these 125 | cases, it is likely that the same candidate set may be tested more than 126 | once. 127 | } 128 | 129 | \subsection{Cooling schedule}{ 130 | 131 | To determine the probability of accepting a new value, the percent 132 | difference in performance is calculated. If the performance metric is to be 133 | maximized, this would be \code{d = (new-old)/old*100}. The probability is 134 | calculated as \code{p = exp(d * coef * iter)} where \code{coef} is a user-defined 135 | constant that can be used to increase or decrease the probabilities. 136 | 137 | The \code{cooling_coef} of \code{\link[=control_sim_anneal]{control_sim_anneal()}} can be used for this purpose. 138 | } 139 | 140 | \subsection{Termination criterion}{ 141 | 142 | The restart counter is reset when a new global best result is found. 143 | 144 | The termination counter resets when a new global best is located or when a 145 | suboptimal result is improved. 146 | } 147 | 148 | \subsection{Parallelism}{ 149 | 150 | The \code{tune} and \code{finetune} packages currently parallelize over resamples.
151 | Specifying a parallel back-end will improve the generation of the initial 152 | set of sub-models (if any). Each iteration of the search are also run in 153 | parallel if a parallel backend is registered. 154 | } 155 | 156 | \subsection{Censored regression models}{ 157 | 158 | With dynamic performance metrics (e.g. Brier or ROC curves), performance is 159 | calculated for every value of \code{eval_time} but the \emph{first} evaluation time 160 | given by the user (e.g., \code{eval_time[1]}) is used to guide the optimization. 161 | 162 | Also, values of \code{eval_time} should be less than the largest observed event 163 | time in the training data. For many non-parametric models, the results beyond 164 | the largest time corresponding to an event are constant (or \code{NA}). 165 | } 166 | } 167 | \examples{ 168 | \donttest{ 169 | library(finetune) 170 | library(rpart) 171 | library(dplyr) 172 | library(tune) 173 | library(rsample) 174 | library(parsnip) 175 | library(workflows) 176 | library(ggplot2) 177 | 178 | ## ----------------------------------------------------------------------------- 179 | if (rlang::is_installed("modeldata")) { 180 | data(two_class_dat, package = "modeldata") 181 | 182 | set.seed(5046) 183 | bt <- bootstraps(two_class_dat, times = 5) 184 | 185 | ## ----------------------------------------------------------------------------- 186 | 187 | cart_mod <- 188 | decision_tree(cost_complexity = tune(), min_n = tune()) |> 189 | set_engine("rpart") |> 190 | set_mode("classification") 191 | 192 | ## ----------------------------------------------------------------------------- 193 | 194 | # For reproducibility, set the seed before running. 
195 | set.seed(10) 196 | sa_search <- 197 | cart_mod |> 198 | tune_sim_anneal(Class ~ ., resamples = bt, iter = 10) 199 | 200 | autoplot(sa_search, metric = "roc_auc", type = "parameters") + 201 | theme_bw() 202 | 203 | ## ----------------------------------------------------------------------------- 204 | # More iterations. `initial` can be any other tune_* object or an integer 205 | # (for new values). 206 | 207 | set.seed(11) 208 | more_search <- 209 | cart_mod |> 210 | tune_sim_anneal(Class ~ ., resamples = bt, iter = 10, initial = sa_search) 211 | 212 | autoplot(more_search, metric = "roc_auc", type = "performance") + 213 | theme_bw() 214 | } 215 | } 216 | } 217 | \references{ 218 | Bohachevsky, Johnson, and Stein (1986) "Generalized Simulated Annealing for 219 | Function Optimization", \emph{Technometrics}, 28:3, 209-217 220 | } 221 | \seealso{ 222 | \code{\link[tune:tune_grid]{tune::tune_grid()}}, \code{\link[=control_sim_anneal]{control_sim_anneal()}}, \code{\link[yardstick:metric_set]{yardstick::metric_set()}} 223 | } 224 | -------------------------------------------------------------------------------- /tests/spelling.R: -------------------------------------------------------------------------------- 1 | if (requireNamespace("spelling", quietly = TRUE)) { 2 | spelling::spell_check_test( 3 | vignettes = TRUE, error = FALSE, 4 | skip_on_cran = TRUE 5 | ) 6 | } 7 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | 2 | suppressPackageStartupMessages(library(finetune)) 3 | 4 | # CRAN wants packages to be able to be check without the Suggests dependencies 5 | if (rlang::is_installed(c("modeldata", "lme4", "testthat"))) { 6 | suppressPackageStartupMessages(library(testthat)) 7 | test_check("finetune") 8 | } 9 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/anova-filter.md: 
-------------------------------------------------------------------------------- 1 | # anova filtering and logging 2 | 3 | Code 4 | set.seed(129) 5 | anova_mod <- tune_race_anova(spec, mpg ~ ., folds, grid = grid) 6 | 7 | --- 8 | 9 | Code 10 | finetune:::log_racing(control_race(verbose_elim = TRUE), anova_res, 11 | ames_grid_search$splits, 10, "rmse") 12 | Message 13 | i Fold10: 7 eliminated; 3 candidates remain. 14 | 15 | --- 16 | 17 | Code 18 | finetune:::log_racing(control_race(verbose_elim = TRUE), anova_res, 19 | ames_grid_search$splits, 10, "rmse") 20 | Message 21 | i Fold10: 7 eliminated; 3 candidates remain. 22 | 23 | --- 24 | 25 | Code 26 | finetune:::log_racing(control_race(verbose_elim = TRUE), anova_res, 27 | ames_grid_search$splits, 10, "rmse") 28 | Message 29 | i Fold10: 7 eliminated; 3 candidates remain. 30 | 31 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/anova-overall.md: -------------------------------------------------------------------------------- 1 | # formula interface 2 | 3 | Code 4 | set.seed(1) 5 | res <- tune_race_anova(f_wflow, cell_folds, grid = grid_mod, control = control_race( 6 | verbose_elim = TRUE)) 7 | Message 8 | i Evaluating against the initial 3 burn-in resamples. 9 | i Racing will maximize the roc_auc metric. 10 | i Resamples are analyzed in a random order. 11 | i Fold3, Repeat1: 2 eliminated; 2 candidates remain. 12 | i Fold2, Repeat2: All but one parameter combination were eliminated. 13 | 14 | # too few resamples 15 | 16 | The number of resamples (2) needs to be more than the number of burn-in resamples (3) set by the control function `control_race()`. 17 | 18 | --- 19 | 20 | The number of resamples (2) needs to be more than the number of burn-in resamples (3) set by the control function `control_race()`. 
21 | 22 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/race-control.md: -------------------------------------------------------------------------------- 1 | # control_race bad arg passing 2 | 3 | Argument 'verbose' should be a single logical value in `control_race()` 4 | 5 | --- 6 | 7 | Argument 'verbose' should be a single logical value in `control_race()` 8 | 9 | --- 10 | 11 | Argument 'verbose_elim' should be a single logical value in `control_race()` 12 | 13 | --- 14 | 15 | Argument 'verbose_elim' should be a single logical value in `control_race()` 16 | 17 | --- 18 | 19 | Argument 'save_pred' should be a single logical value in `control_race()` 20 | 21 | --- 22 | 23 | Argument 'save_pred' should be a single logical value in `control_race()` 24 | 25 | --- 26 | 27 | Argument 'save_workflow' should be a single logical value in `control_race()` 28 | 29 | --- 30 | 31 | Argument 'save_workflow' should be a single logical value in `control_race()` 32 | 33 | --- 34 | 35 | Argument 'burn_in' should be a single numeric value in `control_race()` 36 | 37 | --- 38 | 39 | Argument 'burn_in' should be a single numeric value in `control_race()` 40 | 41 | --- 42 | 43 | `burn_in` should be at least two. 44 | 45 | --- 46 | 47 | Argument 'num_ties' should be a single numeric value in `control_race()` 48 | 49 | --- 50 | 51 | Argument 'num_ties' should be a single numeric value in `control_race()` 52 | 53 | --- 54 | 55 | Argument 'alpha' should be a single numeric value in `control_race()` 56 | 57 | --- 58 | 59 | Argument 'alpha' should be a single numeric value in `control_race()` 60 | 61 | --- 62 | 63 | `alpha` should be on (0, 1). 
64 | 65 | --- 66 | 67 | Argument 'pkgs' should be a character or NULL in `control_race()` 68 | 69 | --- 70 | 71 | Argument 'extract' should be a function or NULL in `control_race()` 72 | 73 | # casting control_race to control_grid 74 | 75 | Code 76 | parsnip::condense_control(control_race(), control_grid()) 77 | Output 78 | grid/resamples control object 79 | 80 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/sa-control.md: -------------------------------------------------------------------------------- 1 | # control_sim_anneal bad arg passing 2 | 3 | Argument 'verbose' should be a single logical value in `control_sim_anneal()` 4 | 5 | --- 6 | 7 | Argument 'verbose' should be a single logical value in `control_sim_anneal()` 8 | 9 | --- 10 | 11 | Argument 'save_pred' should be a single logical value in `control_sim_anneal()` 12 | 13 | --- 14 | 15 | Argument 'save_pred' should be a single logical value in `control_sim_anneal()` 16 | 17 | --- 18 | 19 | Argument 'save_workflow' should be a single logical value in `control_sim_anneal()` 20 | 21 | --- 22 | 23 | Argument 'save_workflow' should be a single logical value in `control_sim_anneal()` 24 | 25 | --- 26 | 27 | Argument 'no_improve' should be a single numeric or integer value in `control_sim_anneal()` 28 | 29 | --- 30 | 31 | Argument 'no_improve' should be a single numeric or integer value in `control_sim_anneal()` 32 | 33 | --- 34 | 35 | `no_improve` should be > 1. 36 | 37 | --- 38 | 39 | Argument 'restart' should be a single numeric or integer value in `control_sim_anneal()` 40 | 41 | --- 42 | 43 | Argument 'restart' should be a single numeric or integer value in `control_sim_anneal()` 44 | 45 | --- 46 | 47 | `restart` should be > 1. 48 | 49 | --- 50 | 51 | Code 52 | control_sim_anneal(no_improve = 2, restart = 6) 53 | Message 54 | ! Parameter restart is scheduled after 6 poor iterations but the search will stop after 2. 
55 | Output 56 | Simulated annealing control object 57 | 58 | --- 59 | 60 | Argument `radius` should be two numeric values. 61 | 62 | --- 63 | 64 | Argument 'flip' should be a single numeric value in `control_sim_anneal()` 65 | 66 | --- 67 | 68 | Argument 'flip' should be a single numeric value in `control_sim_anneal()` 69 | 70 | --- 71 | 72 | Argument 'cooling_coef' should be a single numeric value in `control_sim_anneal()` 73 | 74 | --- 75 | 76 | Argument 'cooling_coef' should be a single numeric value in `control_sim_anneal()` 77 | 78 | --- 79 | 80 | Argument 'pkgs' should be a character or NULL in `control_sim_anneal()` 81 | 82 | --- 83 | 84 | Argument 'extract' should be a function or NULL in `control_sim_anneal()` 85 | 86 | # casting control_sim_anneal to control_grid 87 | 88 | Code 89 | parsnip::condense_control(control_sim_anneal(), control_grid()) 90 | Output 91 | grid/resamples control object 92 | 93 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/sa-misc.md: -------------------------------------------------------------------------------- 1 | # tune_sim_anneal interfaces 2 | 3 | Code 4 | set.seed(1) 5 | f_res_1 <- tune_sim_anneal(rda_spec, Class ~ ., rs, iter = 3) 6 | Message 7 | Optimizing roc_auc 8 | Initial best: 0.85731 9 | 1 ( ) accept suboptimal roc_auc=0.85682 (+/-0.01022) 10 | 2 ( ) accept suboptimal roc_auc=0.85238 (+/-0.01078) 11 | 3 ( ) accept suboptimal roc_auc=0.85138 (+/-0.0109) 12 | 13 | --- 14 | 15 | Code 16 | set.seed(1) 17 | f_res_2 <- tune_sim_anneal(rda_spec, Class ~ ., rs, iter = 3, param_info = rda_param) 18 | Message 19 | Optimizing roc_auc 20 | Initial best: 0.85325 21 | 1 ( ) accept suboptimal roc_auc=0.85313 (+/-0.0106) 22 | 2 ( ) accept suboptimal roc_auc=0.85181 (+/-0.01065) 23 | 3 ( ) accept suboptimal roc_auc=0.85165 (+/-0.01055) 24 | 25 | --- 26 | 27 | Code 28 | set.seed(1) 29 | f_rec_1 <- tune_sim_anneal(rda_spec, rec, rs, iter = 3) 30 | Message 31 | Optimizing 
roc_auc 32 | Initial best: 0.86616 33 | 1 ( ) accept suboptimal roc_auc=0.86399 (+/-0.01081) 34 | 2 <3 new best roc_auc=0.86768 (+/-0.009563) 35 | 3 <3 new best roc_auc=0.87329 (+/-0.008273) 36 | 37 | --- 38 | 39 | Code 40 | set.seed(1) 41 | f_wflow_1 <- tune_sim_anneal(wflow, rs, iter = 3) 42 | Message 43 | Optimizing roc_auc 44 | Initial best: 0.86616 45 | 1 ( ) accept suboptimal roc_auc=0.86399 (+/-0.01081) 46 | 2 <3 new best roc_auc=0.86768 (+/-0.009563) 47 | 3 <3 new best roc_auc=0.87329 (+/-0.008273) 48 | 49 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/sa-overall.md: -------------------------------------------------------------------------------- 1 | # formula interface 2 | 3 | Code 4 | set.seed(1) 5 | res <- tune_sim_anneal(f_wflow, cell_folds, iter = 2, control = control_sim_anneal( 6 | verbose = TRUE)) 7 | Message 8 | 9 | > Generating a set of 1 initial parameter results 10 | v Initialization complete 11 | 12 | Optimizing roc_auc 13 | Initial best: 0.73008 14 | i Fold1, Repeat1: preprocessor 1/1 15 | v Fold1, Repeat1: preprocessor 1/1 16 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 17 | v Fold1, Repeat1: preprocessor 1/1, model 1/1 18 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (extracts) 19 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (predictions) 20 | i Fold2, Repeat1: preprocessor 1/1 21 | v Fold2, Repeat1: preprocessor 1/1 22 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 23 | v Fold2, Repeat1: preprocessor 1/1, model 1/1 24 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (extracts) 25 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (predictions) 26 | i Fold3, Repeat1: preprocessor 1/1 27 | v Fold3, Repeat1: preprocessor 1/1 28 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 29 | v Fold3, Repeat1: preprocessor 1/1, model 1/1 30 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (extracts) 31 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (predictions) 32 | i Fold1, Repeat2: 
preprocessor 1/1 33 | v Fold1, Repeat2: preprocessor 1/1 34 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 35 | v Fold1, Repeat2: preprocessor 1/1, model 1/1 36 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (extracts) 37 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (predictions) 38 | i Fold2, Repeat2: preprocessor 1/1 39 | v Fold2, Repeat2: preprocessor 1/1 40 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 41 | v Fold2, Repeat2: preprocessor 1/1, model 1/1 42 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (extracts) 43 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (predictions) 44 | i Fold3, Repeat2: preprocessor 1/1 45 | v Fold3, Repeat2: preprocessor 1/1 46 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 47 | v Fold3, Repeat2: preprocessor 1/1, model 1/1 48 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (extracts) 49 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (predictions) 50 | 1 ( ) accept suboptimal roc_auc=0.72145 (+/-0.003605) 51 | i Fold1, Repeat1: preprocessor 1/1 52 | v Fold1, Repeat1: preprocessor 1/1 53 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 54 | v Fold1, Repeat1: preprocessor 1/1, model 1/1 55 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (extracts) 56 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (predictions) 57 | i Fold2, Repeat1: preprocessor 1/1 58 | v Fold2, Repeat1: preprocessor 1/1 59 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 60 | v Fold2, Repeat1: preprocessor 1/1, model 1/1 61 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (extracts) 62 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (predictions) 63 | i Fold3, Repeat1: preprocessor 1/1 64 | v Fold3, Repeat1: preprocessor 1/1 65 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 66 | v Fold3, Repeat1: preprocessor 1/1, model 1/1 67 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (extracts) 68 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (predictions) 69 | i Fold1, Repeat2: preprocessor 1/1 70 | v Fold1, Repeat2: preprocessor 1/1 71 | i Fold1, Repeat2: 
preprocessor 1/1, model 1/1 72 | v Fold1, Repeat2: preprocessor 1/1, model 1/1 73 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (extracts) 74 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (predictions) 75 | i Fold2, Repeat2: preprocessor 1/1 76 | v Fold2, Repeat2: preprocessor 1/1 77 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 78 | v Fold2, Repeat2: preprocessor 1/1, model 1/1 79 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (extracts) 80 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (predictions) 81 | i Fold3, Repeat2: preprocessor 1/1 82 | v Fold3, Repeat2: preprocessor 1/1 83 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 84 | v Fold3, Repeat2: preprocessor 1/1, model 1/1 85 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (extracts) 86 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (predictions) 87 | 2 <3 new best roc_auc=0.73173 (+/-0.003018) 88 | 89 | # variable interface 90 | 91 | Code 92 | set.seed(1) 93 | res <- tune_sim_anneal(var_wflow, cell_folds, iter = 2, control = control_sim_anneal( 94 | verbose = TRUE, verbose_iter = TRUE)) 95 | Message 96 | 97 | > Generating a set of 1 initial parameter results 98 | v Initialization complete 99 | 100 | Optimizing roc_auc 101 | Initial best: 0.73008 102 | i Fold1, Repeat1: preprocessor 1/1 103 | v Fold1, Repeat1: preprocessor 1/1 104 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 105 | v Fold1, Repeat1: preprocessor 1/1, model 1/1 106 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (extracts) 107 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (predictions) 108 | i Fold2, Repeat1: preprocessor 1/1 109 | v Fold2, Repeat1: preprocessor 1/1 110 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 111 | v Fold2, Repeat1: preprocessor 1/1, model 1/1 112 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (extracts) 113 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (predictions) 114 | i Fold3, Repeat1: preprocessor 1/1 115 | v Fold3, Repeat1: preprocessor 1/1 116 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 117 
| v Fold3, Repeat1: preprocessor 1/1, model 1/1 118 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (extracts) 119 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (predictions) 120 | i Fold1, Repeat2: preprocessor 1/1 121 | v Fold1, Repeat2: preprocessor 1/1 122 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 123 | v Fold1, Repeat2: preprocessor 1/1, model 1/1 124 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (extracts) 125 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (predictions) 126 | i Fold2, Repeat2: preprocessor 1/1 127 | v Fold2, Repeat2: preprocessor 1/1 128 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 129 | v Fold2, Repeat2: preprocessor 1/1, model 1/1 130 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (extracts) 131 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (predictions) 132 | i Fold3, Repeat2: preprocessor 1/1 133 | v Fold3, Repeat2: preprocessor 1/1 134 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 135 | v Fold3, Repeat2: preprocessor 1/1, model 1/1 136 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (extracts) 137 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (predictions) 138 | 1 ( ) accept suboptimal roc_auc=0.72145 (+/-0.003605) 139 | i Fold1, Repeat1: preprocessor 1/1 140 | v Fold1, Repeat1: preprocessor 1/1 141 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 142 | v Fold1, Repeat1: preprocessor 1/1, model 1/1 143 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (extracts) 144 | i Fold1, Repeat1: preprocessor 1/1, model 1/1 (predictions) 145 | i Fold2, Repeat1: preprocessor 1/1 146 | v Fold2, Repeat1: preprocessor 1/1 147 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 148 | v Fold2, Repeat1: preprocessor 1/1, model 1/1 149 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (extracts) 150 | i Fold2, Repeat1: preprocessor 1/1, model 1/1 (predictions) 151 | i Fold3, Repeat1: preprocessor 1/1 152 | v Fold3, Repeat1: preprocessor 1/1 153 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 154 | v Fold3, Repeat1: preprocessor 1/1, model 1/1 
155 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (extracts) 156 | i Fold3, Repeat1: preprocessor 1/1, model 1/1 (predictions) 157 | i Fold1, Repeat2: preprocessor 1/1 158 | v Fold1, Repeat2: preprocessor 1/1 159 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 160 | v Fold1, Repeat2: preprocessor 1/1, model 1/1 161 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (extracts) 162 | i Fold1, Repeat2: preprocessor 1/1, model 1/1 (predictions) 163 | i Fold2, Repeat2: preprocessor 1/1 164 | v Fold2, Repeat2: preprocessor 1/1 165 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 166 | v Fold2, Repeat2: preprocessor 1/1, model 1/1 167 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (extracts) 168 | i Fold2, Repeat2: preprocessor 1/1, model 1/1 (predictions) 169 | i Fold3, Repeat2: preprocessor 1/1 170 | v Fold3, Repeat2: preprocessor 1/1 171 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 172 | v Fold3, Repeat2: preprocessor 1/1, model 1/1 173 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (extracts) 174 | i Fold3, Repeat2: preprocessor 1/1, model 1/1 (predictions) 175 | 2 <3 new best roc_auc=0.73173 (+/-0.003018) 176 | 177 | --- 178 | 179 | Code 180 | set.seed(1) 181 | new_res <- tune_sim_anneal(var_wflow, cell_folds, iter = 2, initial = res, 182 | control = control_sim_anneal(verbose = FALSE)) 183 | Message 184 | There were 2 previous iterations 185 | Optimizing roc_auc 186 | 2 v initial roc_auc=0.73173 (+/-0.003018) 187 | 3 <3 new best roc_auc=0.74172 (+/-0.008775) 188 | 4 <3 new best roc_auc=0.7909 (+/-0.009887) 189 | 190 | --- 191 | 192 | Code 193 | set.seed(1) 194 | new_new_res <- tune_sim_anneal(var_wflow, cell_folds, iter = 2, initial = grid_res, 195 | control = control_sim_anneal(verbose = FALSE)) 196 | Message 197 | Optimizing roc_auc 198 | Initial best: 0.84497 199 | 1 <3 new best roc_auc=0.84531 (+/-0.005563) 200 | 2 ( ) accept suboptimal roc_auc=0.83776 (+/-0.007509) 201 | 202 | # unfinalized parameters 203 | 204 | Code 205 | set.seed(40) 206 | rf_res_finetune 
<- tune_sim_anneal(wf_rf, resamples = bt, initial = rf_res) 207 | Message 208 | i Creating pre-processing data to finalize unknown parameter: mtry 209 | Optimizing roc_auc 210 | Initial best: 0.84994 211 | 1 ( ) accept suboptimal roc_auc=0.84375 (+/-0.007727) 212 | 2 + better suboptimal roc_auc=0.84943 (+/-0.007036) 213 | 3 ( ) accept suboptimal roc_auc=0.84371 (+/-0.007903) 214 | 4 + better suboptimal roc_auc=0.84825 (+/-0.008036) 215 | 5 ( ) accept suboptimal roc_auc=0.84479 (+/-0.00814) 216 | 6 + better suboptimal roc_auc=0.84816 (+/-0.007283) 217 | 7 ( ) accept suboptimal roc_auc=0.84381 (+/-0.007999) 218 | 8 <3 new best roc_auc=0.85014 (+/-0.007172) 219 | 9 ( ) accept suboptimal roc_auc=0.84344 (+/-0.007818) 220 | 10 + better suboptimal roc_auc=0.84802 (+/-0.007281) 221 | 222 | --- 223 | 224 | Code 225 | set.seed(40) 226 | rf_res_finetune <- tune_sim_anneal(wf_rf, resamples = bt) 227 | Message 228 | i Creating pre-processing data to finalize unknown parameter: mtry 229 | Optimizing roc_auc 230 | Initial best: 0.84418 231 | 1 <3 new best roc_auc=0.84839 (+/-0.007753) 232 | 2 ( ) accept suboptimal roc_auc=0.84384 (+/-0.008085) 233 | 3 <3 new best roc_auc=0.84857 (+/-0.007615) 234 | 4 ( ) accept suboptimal roc_auc=0.8435 (+/-0.007746) 235 | 5 + better suboptimal roc_auc=0.84804 (+/-0.00774) 236 | 6 ( ) accept suboptimal roc_auc=0.84338 (+/-0.007515) 237 | 7 <3 new best roc_auc=0.84923 (+/-0.007371) 238 | 8 ( ) accept suboptimal roc_auc=0.84389 (+/-0.007938) 239 | 9 <3 new best roc_auc=0.84926 (+/-0.007163) 240 | 10 ( ) accept suboptimal roc_auc=0.84397 (+/-0.00741) 241 | 242 | # incompatible parameter objects 243 | 244 | Code 245 | res <- tune_sim_anneal(car_wflow, param_info = parameter_set_with_smaller_range, 246 | resamples = car_folds, initial = tune_res_with_bigger_range, iter = 2) 247 | Message 248 | Optimizing rmse 249 | 250 | Condition 251 | Error in `tune_sim_anneal()`: 252 | ! 
The range for parameter mtry used when generating initial results isn't compatible with the range supplied in `param_info`. 253 | i Possible values of parameters in `param_info` should encompass all values evaluated in the initial grid. 254 | Message 255 | x Optimization stopped prematurely; returning current results. 256 | 257 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/win-loss-overall.md: -------------------------------------------------------------------------------- 1 | # formula interface 2 | 3 | Code 4 | set.seed(1) 5 | res <- tune_race_win_loss(f_wflow, cell_folds, grid = 5, control = control_race( 6 | verbose_elim = TRUE)) 7 | Message 8 | i Racing will maximize the roc_auc metric. 9 | i Resamples are analyzed in a random order. 10 | i Fold3, Repeat1: 1 eliminated; 4 candidates remain. 11 | i Fold2, Repeat2: 0 eliminated; 4 candidates remain. 12 | i Fold3, Repeat2: 0 eliminated; 4 candidates remain. 13 | 14 | # one player is really bad 15 | 16 | Code 17 | best_res <- show_best(tuning_results) 18 | Condition 19 | Warning in `show_best()`: 20 | No value of `metric` was given; "roc_auc" will be used. 
21 | 22 | -------------------------------------------------------------------------------- /tests/testthat/helper.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(finetune)) 2 | suppressPackageStartupMessages(library(rsample)) 3 | suppressPackageStartupMessages(library(workflows)) 4 | suppressPackageStartupMessages(library(parsnip)) 5 | suppressPackageStartupMessages(library(dplyr)) 6 | suppressPackageStartupMessages(library(recipes)) 7 | suppressPackageStartupMessages(library(tibble)) 8 | suppressPackageStartupMessages(library(yardstick)) 9 | suppressPackageStartupMessages(library(purrr)) 10 | suppressPackageStartupMessages(library(lme4)) 11 | suppressPackageStartupMessages(library(ranger)) 12 | suppressPackageStartupMessages(library(recipes)) 13 | suppressPackageStartupMessages(library(modeldata)) 14 | 15 | # ------------------------------------------------------------------------------ 16 | 17 | data(cells, package = "modeldata") 18 | cells <- modeldata::cells |> dplyr::select(class, contains("ch_1")) 19 | set.seed(33) 20 | cell_folds <- rsample::vfold_cv(cells, v = 3, repeats = 2) 21 | 22 | ## ----------------------------------------------------------------------------- 23 | 24 | cart_spec <- 25 | parsnip::decision_tree( 26 | cost_complexity = parsnip::tune(), 27 | min_n = parsnip::tune() 28 | ) |> 29 | parsnip::set_mode("classification") |> 30 | parsnip::set_engine("rpart") 31 | 32 | cart_rec <- 33 | recipes::recipe(class ~ ., data = cells) |> 34 | recipes::step_normalize(recipes::all_predictors()) |> 35 | recipes::step_pca(recipes::all_predictors(), num_comp = parsnip::tune()) 36 | 37 | ## ----------------------------------------------------------------------------- 38 | 39 | rec_wflow <- 40 | cell_knn <- 41 | workflows::workflow() |> 42 | workflows::add_model(cart_spec) |> 43 | workflows::add_recipe(cart_rec) 44 | 45 | f_wflow <- 46 | cell_knn <- 47 | workflows::workflow() |> 48 | 
workflows::add_model(cart_spec) |> 49 | workflows::add_formula(class ~ .) 50 | 51 | var_wflow <- 52 | cell_knn <- 53 | workflows::workflow() |> 54 | workflows::add_model(cart_spec) |> 55 | workflows::add_variables(class, dplyr::everything()) 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | 60 | grid_mod <- 61 | expand.grid(cost_complexity = c(0.001, 0.0001), min_n = c(3, 4)) 62 | 63 | grid_mod_rec <- 64 | expand.grid(cost_complexity = c(0.001, 0.0001), min_n = 3:4, num_comp = 19:20) 65 | 66 | # ------------------------------------------------------------------------------ 67 | -------------------------------------------------------------------------------- /tests/testthat/sa_cart_test_objects.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/finetune/13b1ba2d4b97e63b3a4e08d1e607e8c6917df22a/tests/testthat/sa_cart_test_objects.RData -------------------------------------------------------------------------------- /tests/testthat/test-anova-filter.R: -------------------------------------------------------------------------------- 1 | ## ----------------------------------------------------------------------------- 2 | 3 | test_that("anova filtering and logging", { 4 | # Skip for < 4.0 due to random number differences 5 | skip_if(getRversion() < "4.0.0") 6 | skip_if_not_installed("Matrix", "1.6-2") 7 | skip_if_not_installed("lme4", "1.1-35.1") 8 | 9 | set.seed(2332) 10 | folds <- vfold_cv(mtcars, v = 5, repeats = 2) 11 | fold_att <- attributes(folds) 12 | spec <- 13 | decision_tree(cost_complexity = tune(), min_n = tune()) |> 14 | set_engine("rpart") |> 15 | set_mode("regression") 16 | wflow <- workflow() |> 17 | add_model(spec) |> 18 | add_formula(mpg ~ .) 
19 | grid <- expand.grid(cost_complexity = c(0.001, 0.01), min_n = c(2:5)) 20 | 21 | ## ----------------------------------------------------------------------------- 22 | 23 | grid_res <- 24 | spec |> tune_grid(mpg ~ ., folds, grid = grid, metrics = metric_set(rmse)) 25 | # Pull out rmse values, format them to emulate the racing tests then 26 | # use lme4 package to create the model results for removing configurations. 27 | 28 | alpha <- 0.0381 29 | 30 | rmse_means <- collect_metrics(grid_res) 31 | configs <- rmse_means$.config[order(rmse_means$mean)] 32 | rmse_vals <- collect_metrics(grid_res, summarize = FALSE) 33 | rmse_configs <- rmse_vals 34 | rmse_configs$.config <- factor(rmse_configs$.config, levels = configs) 35 | rmse_configs <- rmse_configs[, c("id", "id2", ".estimate", ".config")] 36 | rmse_mod <- lmer(.estimate ~ .config + (1 | id2 / id), data = rmse_configs) 37 | rmse_summary <- summary(rmse_mod)$coef 38 | rmse_res <- tibble::as_tibble(rmse_summary) 39 | rmse_res$.config <- gsub("\\.config", "", rownames(rmse_summary)) 40 | rmse_res$.config <- gsub( 41 | "(Intercept)", 42 | configs[1], 43 | rmse_res$.config, 44 | fixed = TRUE 45 | ) 46 | rmse_ci <- confint(rmse_mod, level = 1 - alpha, method = "Wald", quiet = TRUE) 47 | rmse_ci <- rmse_ci[grepl("config", rownames(rmse_ci)), ] 48 | 49 | # ------------------------------------------------------------------------------ 50 | # anova results 51 | 52 | anova_res <- finetune:::fit_anova(grid_res, rmse_configs, alpha = alpha) 53 | expect_equal(anova_res$estimate, rmse_res$Estimate[-1]) 54 | expect_equal(anova_res$lower, unname(rmse_ci[, 1])) 55 | expect_equal(anova_res$upper, unname(rmse_ci[, 2])) 56 | expect_equal(anova_res$.config, configs[-1]) 57 | 58 | # ------------------------------------------------------------------------------ 59 | # top-level anova filter interfaces 60 | 61 | expect_snapshot({ 62 | set.seed(129) 63 | anova_mod <- spec |> tune_race_anova(mpg ~ ., folds, grid = grid) 64 | }) 65 | 
expect_true(inherits(anova_mod, "tune_race")) 66 | expect_true(inherits(anova_mod, "tune_results")) 67 | expect_true(tibble::is_tibble((anova_mod))) 68 | 69 | expect_silent({ 70 | set.seed(129) 71 | anova_wlfow <- 72 | wflow |> 73 | tune_race_anova( 74 | folds, 75 | grid = grid, 76 | control = control_race(verbose_elim = FALSE, save_pred = TRUE) 77 | ) 78 | }) 79 | expect_true(inherits(anova_wlfow, "tune_race")) 80 | expect_true(inherits(anova_wlfow, "tune_results")) 81 | expect_true(tibble::is_tibble((anova_wlfow))) 82 | expect_true(sum(names(anova_wlfow) == ".predictions") == 1) 83 | 84 | ## ----------------------------------------------------------------------------- 85 | ## anova formula 86 | 87 | for (i in 2:nrow(folds)) { 88 | f <- finetune:::lmer_formula(folds |> slice(1:i), fold_att) 89 | if (i < 7) { 90 | expect_equal(f, .estimate ~ .config + (1 | .all_id), ignore_attr = TRUE) 91 | } else { 92 | expect_equal(f, .estimate ~ .config + (1 | id2 / id), ignore_attr = TRUE) 93 | } 94 | } 95 | # This one takes a while to run: 96 | expect_equal(environment(f), rlang::base_env()) 97 | 98 | car_bt <- bootstraps(mtcars, times = 5) 99 | car_att <- attributes(car_bt) 100 | 101 | for (i in 2:nrow(car_bt)) { 102 | f <- finetune:::lmer_formula(car_bt |> slice(1:i), car_att) 103 | expect_equal(f, .estimate ~ .config + (1 | id), ignore_attr = TRUE) 104 | } 105 | expect_equal(environment(f), rlang::base_env()) 106 | 107 | res <- finetune:::refactor_by_mean(rmse_vals, maximize = FALSE) 108 | expect_equal(res, rmse_configs) 109 | 110 | # ------------------------------------------------------------------------------ 111 | 112 | # Use the built-in `ames_grid_search` object to test the object structure and 113 | # printing 114 | 115 | param <- .get_tune_parameter_names(ames_grid_search) 116 | ames_grid_res <- collect_metrics(ames_grid_search) 117 | ames_grid_res <- ames_grid_res[ames_grid_res$.metric == "rmse", ] 118 | 119 | anova_res <- 
finetune:::test_parameters_gls(ames_grid_search) 120 | expect_equal( 121 | names(anova_res), 122 | # fmt: skip 123 | c( 124 | ".config", "lower", "upper", "estimate", "pass", "K", "weight_func", 125 | "dist_power", "lon", "lat" 126 | ) 127 | ) 128 | expect_equal(nrow(anova_res), nrow(ames_grid_res)) 129 | expect_equal(anova_res$lower <= 0, anova_res$pass) 130 | expect_equal( 131 | anova_res |> dplyr::select(!!!param, .config) |> arrange(.config), 132 | ames_grid_res |> dplyr::select(!!!param, .config) |> arrange(.config) 133 | ) 134 | 135 | expect_snapshot( 136 | finetune:::log_racing( 137 | control_race(verbose_elim = TRUE), 138 | anova_res, 139 | ames_grid_search$splits, 140 | 10, 141 | "rmse" 142 | ) 143 | ) 144 | expect_snapshot( 145 | finetune:::log_racing( 146 | control_race(verbose_elim = TRUE), 147 | anova_res, 148 | ames_grid_search$splits, 149 | 10, 150 | "rmse" 151 | ) 152 | ) 153 | expect_snapshot( 154 | finetune:::log_racing( 155 | control_race(verbose_elim = TRUE), 156 | anova_res, 157 | ames_grid_search$splits, 158 | 10, 159 | "rmse" 160 | ) 161 | ) 162 | }) 163 | -------------------------------------------------------------------------------- /tests/testthat/test-anova-overall.R: -------------------------------------------------------------------------------- 1 | test_that("formula interface", { 2 | skip_on_cran() 3 | skip_if_not_installed("Matrix", "1.6-2") 4 | skip_if_not_installed("lme4", "1.1-35.1") 5 | 6 | expect_snapshot({ 7 | set.seed(1) 8 | res <- f_wflow |> 9 | tune_race_anova( 10 | cell_folds, 11 | grid = grid_mod, 12 | control = control_race(verbose_elim = TRUE) 13 | ) 14 | }) 15 | expect_equal( 16 | class(res), 17 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 18 | ) 19 | expect_true(nrow(collect_metrics(res)) < nrow(grid_mod) * 3) 20 | expect_equal(res, .Last.tune.result) 21 | expect_null(.get_tune_eval_times(res)) 22 | expect_null(.get_tune_eval_time_target(res)) 23 | }) 24 | 25 | # 
------------------------------------------------------------------------------ 26 | 27 | test_that("recipe interface", { 28 | skip_on_cran() 29 | skip_if_not_installed("Matrix", "1.6-2") 30 | skip_if_not_installed("lme4", "1.1-35.1") 31 | expect_silent({ 32 | set.seed(1) 33 | res <- rec_wflow |> 34 | tune_race_anova( 35 | cell_folds, 36 | grid = grid_mod_rec, 37 | control = control_race(verbose_elim = FALSE) 38 | ) 39 | }) 40 | expect_equal( 41 | class(res), 42 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 43 | ) 44 | expect_true(nrow(collect_metrics(res)) < nrow(grid_mod) * 3) 45 | expect_equal(res, .Last.tune.result) 46 | }) 47 | 48 | # ------------------------------------------------------------------------------ 49 | 50 | test_that("variable interface", { 51 | skip_on_cran() 52 | skip_if_not_installed("Matrix", "1.6-2") 53 | skip_if_not_installed("lme4", "1.1-35.1") 54 | 55 | expect_silent({ 56 | set.seed(1) 57 | res <- var_wflow |> 58 | tune_race_anova( 59 | cell_folds, 60 | grid = grid_mod, 61 | control = control_race(verbose_elim = FALSE) 62 | ) 63 | }) 64 | expect_equal( 65 | class(res), 66 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 67 | ) 68 | expect_true(nrow(collect_metrics(res)) < nrow(grid_mod) * 3) 69 | expect_equal(res, .Last.tune.result) 70 | }) 71 | 72 | # ------------------------------------------------------------------------------ 73 | 74 | test_that("too few resamples", { 75 | skip_if_not_installed("Matrix", "1.6-2") 76 | skip_if_not_installed("lme4", "1.1-35.1") 77 | 78 | rs <- rsample::vfold_cv(modeldata::cells, v = 2) 79 | expect_snapshot_error( 80 | f_wflow |> 81 | tune_race_anova( 82 | rs, 83 | grid = grid_mod, 84 | control = control_race(verbose_elim = TRUE) 85 | ) 86 | ) 87 | expect_snapshot_error( 88 | f_wflow |> 89 | tune_race_win_loss( 90 | rs, 91 | grid = grid_mod, 92 | control = control_race(verbose_elim = TRUE) 93 | ) 94 | ) 95 | }) 96 | 
-------------------------------------------------------------------------------- /tests/testthat/test-condense_control.R: -------------------------------------------------------------------------------- 1 | test_that("control_race works with condense_control", { 2 | expect_equal( 3 | parsnip::condense_control(control_race(), control_grid()), 4 | control_grid(parallel_over = "everything") 5 | ) 6 | 7 | expect_equal( 8 | parsnip::condense_control(control_race(), control_resamples()), 9 | control_resamples(parallel_over = "everything") 10 | ) 11 | }) 12 | 13 | test_that("control_sim_anneal works with condense_control", { 14 | expect_equal( 15 | parsnip::condense_control(control_sim_anneal(), control_grid()), 16 | control_grid(verbose = FALSE) 17 | ) 18 | 19 | expect_equal( 20 | parsnip::condense_control(control_sim_anneal(), control_resamples()), 21 | control_resamples(verbose = FALSE) 22 | ) 23 | }) 24 | -------------------------------------------------------------------------------- /tests/testthat/test-race-control.R: -------------------------------------------------------------------------------- 1 | ## ----------------------------------------------------------------------------- 2 | 3 | test_that("control_race arg passing", { 4 | expect_equal(control_race(verbose = TRUE)$verbose, TRUE) 5 | expect_equal(control_race(verbose_elim = TRUE)$verbose_elim, TRUE) 6 | expect_equal(control_race(burn_in = 13)$burn_in, 13) 7 | expect_equal(control_race(num_ties = 2)$num_ties, 2) 8 | expect_equal(control_race(alpha = .12)$alpha, .12) 9 | expect_equal(control_race(extract = function(x) x)$extract, function(x) x) 10 | expect_equal(control_race(save_pred = TRUE)$save_pred, TRUE) 11 | expect_equal(control_race(pkgs = "carrot")$pkgs, "carrot") 12 | expect_equal(control_race(save_workflow = TRUE)$save_workflow, TRUE) 13 | }) 14 | 15 | test_that("control_race bad arg passing", { 16 | expect_snapshot_error(control_race(verbose = "TRUE")) 17 | 
expect_snapshot_error(control_race(verbose = rep(TRUE, 2))) 18 | expect_snapshot_error(control_race(verbose_elim = "TRUE")) 19 | expect_snapshot_error(control_race(verbose_elim = rep(TRUE, 2))) 20 | expect_snapshot_error(control_race(save_pred = "TRUE")) 21 | expect_snapshot_error(control_race(save_pred = rep(TRUE, 2))) 22 | expect_snapshot_error(control_race(save_workflow = "TRUE")) 23 | expect_snapshot_error(control_race(save_workflow = rep(TRUE, 2))) 24 | expect_snapshot_error(control_race(burn_in = "yes")) 25 | expect_snapshot_error(control_race(burn_in = 0:1)) 26 | expect_snapshot_error(control_race(burn_in = 1)) 27 | expect_snapshot_error(control_race(num_ties = "yes")) 28 | expect_snapshot_error(control_race(num_ties = 0:1)) 29 | expect_snapshot_error(control_race(alpha = 0:1)) 30 | expect_snapshot_error(control_race(alpha = "huge")) 31 | expect_snapshot_error(control_race(alpha = 1)) 32 | expect_snapshot_error(control_race(pkg = 0:1)) 33 | expect_snapshot_error(control_race(extract = 0:1)) 34 | }) 35 | 36 | test_that("casting control_race to control_grid", { 37 | expect_snapshot(parsnip::condense_control(control_race(), control_grid())) 38 | }) 39 | -------------------------------------------------------------------------------- /tests/testthat/test-race-s3.R: -------------------------------------------------------------------------------- 1 | test_that("racing S3 methods", { 2 | skip_if_not_installed("Matrix", "1.6-2") 3 | skip_if_not_installed("lme4", "1.1-35.1") 4 | skip_if_not_installed("kknn") 5 | 6 | library(purrr) 7 | library(dplyr) 8 | library(parsnip) 9 | library(rsample) 10 | library(recipes) 11 | 12 | knn_mod_power <- 13 | nearest_neighbor(mode = "regression", dist_power = tune()) |> 14 | set_engine("kknn") 15 | 16 | simple_rec <- recipe(mpg ~ ., data = mtcars) 17 | 18 | set.seed(7898) 19 | race_folds <- vfold_cv(mtcars, repeats = 2) 20 | 21 | ctrl_rc <- control_race(save_pred = TRUE) 22 | set.seed(9323) 23 | anova_race <- 24 | tune_race_anova( 
25 | knn_mod_power, 26 | simple_rec, 27 | resamples = race_folds, 28 | grid = tibble::tibble(dist_power = c(1 / 10, 1, 2)), 29 | control = ctrl_rc 30 | ) 31 | 32 | # ------------------------------------------------------------------------------ 33 | # collect metrics 34 | 35 | expect_equal(nrow(collect_metrics(anova_race)), 2) 36 | expect_equal(nrow(collect_metrics(anova_race, all_configs = TRUE)), 6) 37 | expect_equal(nrow(collect_metrics(anova_race, summarize = FALSE)), 2 * 20) 38 | expect_equal( 39 | nrow(collect_metrics(anova_race, summarize = FALSE, all_configs = TRUE)), 40 | nrow(map(anova_race$.metrics, \(x) x) |> list_rbind()) 41 | ) 42 | 43 | # ------------------------------------------------------------------------------ 44 | # collect predictions 45 | 46 | expect_equal( 47 | nrow(collect_predictions( 48 | anova_race, 49 | all_configs = FALSE, 50 | summarize = TRUE 51 | )), 52 | nrow(mtcars) * 1 # 1 config x nrow(mtcars) 53 | ) 54 | expect_equal( 55 | nrow(collect_predictions(anova_race, all_configs = TRUE, summarize = TRUE)), 56 | map(anova_race$.predictions, \(x) x) |> 57 | list_rbind() |> 58 | distinct(.config, .row) |> 59 | nrow() 60 | ) 61 | expect_equal( 62 | nrow(collect_predictions( 63 | anova_race, 64 | all_configs = FALSE, 65 | summarize = FALSE 66 | )), 67 | nrow(mtcars) * 1 * 2 # 1 config x 2 repeats x nrow(mtcars) 68 | ) 69 | expect_equal( 70 | nrow(collect_predictions( 71 | anova_race, 72 | all_configs = TRUE, 73 | summarize = FALSE 74 | )), 75 | nrow(map(anova_race$.predictions, \(x) x) |> list_rbind()) 76 | ) 77 | 78 | # ------------------------------------------------------------------------------ 79 | # show_best and select_best 80 | 81 | expect_equal(nrow(show_best(anova_race, metric = "rmse")), 1) 82 | expect_true(all(show_best(anova_race, metric = "rmse")$n == 20)) 83 | expect_equal(nrow(select_best(anova_race, metric = "rmse")), 1) 84 | expect_equal( 85 | nrow(select_by_pct_loss( 86 | anova_race, 87 | metric = "rmse", 88 | 
dist_power, 89 | limit = 10 90 | )), 91 | 1 92 | ) 93 | expect_equal( 94 | nrow(select_by_one_std_err(anova_race, metric = "rmse", dist_power)), 95 | 1 96 | ) 97 | }) 98 | -------------------------------------------------------------------------------- /tests/testthat/test-random-integer-neighbors.R: -------------------------------------------------------------------------------- 1 | test_that("random integers in range", { 2 | set.seed(123) 3 | parameters <- dials::parameters(list(dials::tree_depth(range = c(2, 3)))) 4 | random_integer_neigbors <- 5 | purrr::map( 6 | 1:500, 7 | \(x) 8 | finetune:::random_integer_neighbor_calc( 9 | tibble::tibble(tree_depth = 3), 10 | parameters, 11 | 0.75, 12 | FALSE 13 | ) 14 | ) |> 15 | purrr::list_rbind() 16 | 17 | expect_true(all(random_integer_neigbors$tree_depth >= 2)) 18 | expect_true(all(random_integer_neigbors$tree_depth <= 3)) 19 | }) 20 | -------------------------------------------------------------------------------- /tests/testthat/test-sa-control.R: -------------------------------------------------------------------------------- 1 | ## ----------------------------------------------------------------------------- 2 | 3 | test_that("control_sim_anneal arg passing", { 4 | expect_equal(control_sim_anneal(verbose = TRUE)$verbose, TRUE) 5 | expect_equal(control_sim_anneal(no_improve = 13)$no_improve, 13L) 6 | expect_equal(control_sim_anneal(restart = 2)$restart, 2) 7 | expect_equal(control_sim_anneal(radius = rep(.12, 2))$radius, rep(.12, 2)) 8 | expect_equal(control_sim_anneal(flip = .122)$flip, .122) 9 | expect_equal(control_sim_anneal(cooling_coef = 1 / 10)$cooling_coef, 1 / 10) 10 | expect_equal( 11 | control_sim_anneal(extract = function(x) x)$extract, 12 | function(x) x 13 | ) 14 | expect_equal(control_sim_anneal(save_pred = TRUE)$save_pred, TRUE) 15 | expect_equal(control_sim_anneal(time_limit = 2)$time_limit, 2) 16 | expect_equal(control_sim_anneal(pkgs = "carrot")$pkgs, "carrot") 17 | 
expect_equal(control_sim_anneal(save_workflow = TRUE)$save_workflow, TRUE) 18 | }) 19 | 20 | test_that("control_sim_anneal bad arg passing", { 21 | expect_snapshot_error(control_sim_anneal(verbose = "TRUE")) 22 | expect_snapshot_error(control_sim_anneal(verbose = rep(TRUE, 2))) 23 | expect_snapshot_error(control_sim_anneal(save_pred = "TRUE")) 24 | expect_snapshot_error(control_sim_anneal(save_pred = rep(TRUE, 2))) 25 | expect_snapshot_error(control_sim_anneal(save_workflow = "TRUE")) 26 | expect_snapshot_error(control_sim_anneal(save_workflow = rep(TRUE, 2))) 27 | expect_snapshot_error(control_sim_anneal(no_improve = "yes")) 28 | expect_snapshot_error(control_sim_anneal(no_improve = 0:1)) 29 | expect_snapshot_error(control_sim_anneal(no_improve = 1)) 30 | expect_snapshot_error(control_sim_anneal(restart = "yes")) 31 | expect_snapshot_error(control_sim_anneal(restart = 0:1)) 32 | expect_snapshot_error(control_sim_anneal(restart = 1)) 33 | expect_snapshot(control_sim_anneal(no_improve = 2, restart = 6)) 34 | expect_snapshot_error(control_sim_anneal(radius = "huge")) 35 | expect_equal(control_sim_anneal(radius = c(-1, .2))$radius, c(0.001, .2)) 36 | expect_equal(control_sim_anneal(radius = c(15, .1))$radius, c(.1, 0.999)) 37 | expect_snapshot_error(control_sim_anneal(flip = 0:1)) 38 | expect_snapshot_error(control_sim_anneal(flip = "huge")) 39 | expect_equal(control_sim_anneal(flip = -1)$flip, 0) 40 | expect_equal(control_sim_anneal(flip = 2)$flip, 1) 41 | expect_snapshot_error(control_sim_anneal(cooling_coef = 0:1)) 42 | expect_snapshot_error(control_sim_anneal(cooling_coef = "huge")) 43 | expect_equal(control_sim_anneal(cooling_coef = -1)$cooling_coef, 0.0001) 44 | expect_equal(control_sim_anneal(cooling_coef = 2)$cooling_coef, 2) 45 | expect_snapshot_error(control_sim_anneal(pkg = 0:1)) 46 | expect_snapshot_error(control_sim_anneal(extract = 0:1)) 47 | }) 48 | 49 | test_that("casting control_sim_anneal to control_grid", { 50 | 
expect_snapshot(parsnip::condense_control( 51 | control_sim_anneal(), 52 | control_grid() 53 | )) 54 | }) 55 | -------------------------------------------------------------------------------- /tests/testthat/test-sa-decision.R: -------------------------------------------------------------------------------- 1 | load(file.path(test_path(), "sa_cart_test_objects.RData")) 2 | 3 | ## ----------------------------------------------------------------------------- 4 | 5 | cart_param <- tune::.get_tune_parameters(cart_search) 6 | cart_metrics <- tune::.get_tune_metrics(cart_search) 7 | cart_outcomes <- tune::.get_tune_outcome_names(cart_search) 8 | cart_rset_info <- attributes(cart_search)$rset_info 9 | 10 | ## ----------------------------------------------------------------------------- 11 | 12 | test_that("simulated annealing decisions", { 13 | for (iter_val in 1:max(cart_history$.iter)) { 14 | iter_hist <- cart_history |> filter(.iter < iter_val) 15 | iter_res <- 16 | cart_search |> 17 | filter(.iter == iter_val) |> 18 | tune:::new_tune_results( 19 | parameters = cart_param, 20 | outcomes = cart_outcomes, 21 | metrics = cart_metrics, 22 | eval_time = NULL, 23 | eval_time_target = NULL, 24 | rset_info = cart_rset_info 25 | ) 26 | iter_new_hist <- finetune:::update_history( 27 | iter_hist, 28 | iter_res, 29 | iter_val, 30 | NULL 31 | ) 32 | iter_new_hist$random[1:nrow(iter_new_hist)] <- cart_history$random[ 33 | 1:nrow(iter_new_hist) 34 | ] 35 | 36 | expect_equal( 37 | iter_new_hist$mean[iter_new_hist$.iter == iter_val], 38 | cart_history$mean[cart_history$.iter == iter_val] 39 | ) 40 | 41 | expect_equal( 42 | iter_new_hist$std_err[iter_new_hist$.iter == iter_val], 43 | cart_history$std_err[cart_history$.iter == iter_val] 44 | ) 45 | 46 | new_sa_res <- 47 | finetune:::sa_decide( 48 | iter_new_hist, 49 | parent = cart_history$.parent[cart_history$.iter == iter_val], 50 | metric = "roc_auc", 51 | maximize = TRUE, 52 | coef = control_sim_anneal()$cooling_coef 53 | ) 54 | 55 | 
expect_equal( 56 | new_sa_res$results[new_sa_res$.iter == iter_val], 57 | cart_history$results[cart_history$.iter == iter_val] 58 | ) 59 | 60 | expect_equal( 61 | new_sa_res$accept[new_sa_res$.iter == iter_val], 62 | cart_history$accept[cart_history$.iter == iter_val] 63 | ) 64 | } 65 | }) 66 | 67 | ## ----------------------------------------------------------------------------- 68 | 69 | test_that("percent difference", { 70 | expect_equal(finetune:::percent_diff(1, 2), 100) 71 | expect_equal(finetune:::percent_diff(1, 1), 0) 72 | expect_equal(finetune:::percent_diff(1, 2, FALSE), -100) 73 | expect_equal(finetune:::percent_diff(1, 1, FALSE), 0) 74 | }) 75 | 76 | 77 | ## ----------------------------------------------------------------------------- 78 | 79 | test_that("acceptance probabilities", { 80 | expect_equal(finetune:::acceptance_prob(1, 2, iter = 1, maximize = TRUE), 1) 81 | expect_equal(finetune:::acceptance_prob(1, 1, iter = 1, maximize = TRUE), 1) 82 | 83 | expect_equal( 84 | finetune:::acceptance_prob(2, 1, iter = 1, maximize = TRUE), 85 | exp(finetune:::percent_diff(2, 1) * 1 * control_sim_anneal()$cooling_coef) 86 | ) 87 | expect_equal( 88 | finetune:::acceptance_prob(2, 1, iter = 10, maximize = TRUE), 89 | exp(finetune:::percent_diff(2, 1) * 10 * control_sim_anneal()$cooling_coef) 90 | ) 91 | 92 | expect_equal(finetune:::acceptance_prob(3, 1, iter = 1, maximize = FALSE), 1) 93 | expect_equal(finetune:::acceptance_prob(3, 1, iter = 1, maximize = FALSE), 1) 94 | 95 | expect_equal( 96 | finetune:::acceptance_prob(1, 3, iter = 1, maximize = FALSE), 97 | exp( 98 | finetune:::percent_diff(1, 3, maximize = FALSE) * 99 | 1 * 100 | control_sim_anneal()$cooling_coef 101 | ) 102 | ) 103 | expect_equal( 104 | finetune:::acceptance_prob(1, 3, iter = 10, maximize = FALSE), 105 | exp( 106 | finetune:::percent_diff(1, 3, maximize = FALSE) * 107 | 10 * 108 | control_sim_anneal()$cooling_coef 109 | ) 110 | ) 111 | }) 112 | 113 | ## 
----------------------------------------------------------------------------- 114 | 115 | test_that("logging results", { 116 | iters <- max(cart_history$.iter) 117 | 118 | for (i in 1:iters) { 119 | expect_message( 120 | finetune:::log_sa_progress( 121 | x = cart_history |> filter(.iter <= i), 122 | metric = "roc_auc", 123 | max_iter = i 124 | ), 125 | regexp = cart_history$results[cart_history$.iter == i] 126 | ) 127 | } 128 | }) 129 | -------------------------------------------------------------------------------- /tests/testthat/test-sa-misc.R: -------------------------------------------------------------------------------- 1 | ## ----------------------------------------------------------------------------- 2 | 3 | test_that("tune_sim_anneal interfaces", { 4 | skip_on_cran() 5 | skip_if_not_installed(c("discrim", "klaR")) 6 | 7 | library(discrim) 8 | data("two_class_dat", package = "modeldata") 9 | 10 | ## ----------------------------------------------------------------------------- 11 | 12 | rda_spec <- 13 | discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> 14 | set_engine("klaR") 15 | 16 | rda_param <- rda_spec |> 17 | extract_parameter_set_dials() |> 18 | update( 19 | frac_common_cov = frac_common_cov(c(.3, .6)), 20 | frac_identity = frac_identity(c(.3, .6)) 21 | ) 22 | 23 | set.seed(813) 24 | rs <- bootstraps(two_class_dat, times = 3) 25 | 26 | rec <- recipe(Class ~ ., data = two_class_dat) |> 27 | step_ns(A, deg_free = tune()) 28 | 29 | # ------------------------------------------------------------------------------ 30 | # formula interface 31 | 32 | expect_snapshot({ 33 | set.seed(1) 34 | f_res_1 <- rda_spec |> tune_sim_anneal(Class ~ ., rs, iter = 3) 35 | }) 36 | 37 | expect_snapshot({ 38 | set.seed(1) 39 | f_res_2 <- rda_spec |> 40 | tune_sim_anneal(Class ~ ., rs, iter = 3, param_info = rda_param) 41 | }) 42 | 43 | expect_true(all(collect_metrics(f_res_2)$frac_common_cov >= 0.3)) 44 | 
expect_true(all(collect_metrics(f_res_2)$frac_common_cov <= 0.6)) 45 | expect_true(all(collect_metrics(f_res_2)$frac_identity >= 0.3)) 46 | expect_true(all(collect_metrics(f_res_2)$frac_identity <= 0.6)) 47 | 48 | # ------------------------------------------------------------------------------ 49 | # recipe interface 50 | 51 | expect_snapshot({ 52 | set.seed(1) 53 | f_rec_1 <- rda_spec |> tune_sim_anneal(rec, rs, iter = 3) 54 | }) 55 | expect_equal(sum(names(collect_metrics(f_rec_1)) == "deg_free"), 1) 56 | expect_equal(sum(names(collect_metrics(f_rec_1)) == "frac_common_cov"), 1) 57 | expect_equal(sum(names(collect_metrics(f_rec_1)) == "frac_identity"), 1) 58 | 59 | # ------------------------------------------------------------------------------ 60 | # workflow interface 61 | 62 | wflow <- 63 | workflow() |> 64 | add_model(rda_spec) |> 65 | add_recipe(rec) 66 | 67 | expect_snapshot({ 68 | set.seed(1) 69 | f_wflow_1 <- wflow |> tune_sim_anneal(rs, iter = 3) 70 | }) 71 | expect_equal(sum(names(collect_metrics(f_wflow_1)) == "deg_free"), 1) 72 | expect_equal(sum(names(collect_metrics(f_wflow_1)) == "frac_common_cov"), 1) 73 | expect_equal(sum(names(collect_metrics(f_wflow_1)) == "frac_identity"), 1) 74 | }) 75 | 76 | ## ----------------------------------------------------------------------------- 77 | 78 | test_that("tune_sim_anneal with wrong type", { 79 | expect_snapshot( 80 | tune_sim_anneal(1), 81 | error = TRUE 82 | ) 83 | }) 84 | -------------------------------------------------------------------------------- /tests/testthat/test-sa-overall.R: -------------------------------------------------------------------------------- 1 | test_that("formula interface", { 2 | skip_on_cran() 3 | expect_snapshot({ 4 | set.seed(1) 5 | res <- f_wflow |> 6 | tune_sim_anneal( 7 | cell_folds, 8 | iter = 2, 9 | control = control_sim_anneal(verbose = TRUE) 10 | ) 11 | }) 12 | expect_equal( 13 | class(res), 14 | c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame") 15 
| ) 16 | expect_true(nrow(collect_metrics(res)) == 9) 17 | expect_equal(res, .Last.tune.result) 18 | expect_null(.get_tune_eval_times(res)) 19 | expect_null(.get_tune_eval_time_target(res)) 20 | }) 21 | 22 | # ------------------------------------------------------------------------------ 23 | 24 | test_that("recipe interface", { 25 | skip_on_cran() 26 | skip_on_os("windows") 27 | skip_on_os("linux") 28 | 29 | expect_silent({ 30 | set.seed(1) 31 | res <- rec_wflow |> 32 | tune_sim_anneal( 33 | cell_folds, 34 | iter = 2, 35 | control = control_sim_anneal(verbose = FALSE, verbose_iter = FALSE) 36 | ) 37 | }) 38 | 39 | expect_equal( 40 | class(res), 41 | c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame") 42 | ) 43 | expect_true(nrow(collect_metrics(res)) == 9) 44 | expect_equal(res, .Last.tune.result) 45 | }) 46 | 47 | # ------------------------------------------------------------------------------ 48 | 49 | test_that("variable interface", { 50 | skip_on_cran() 51 | expect_snapshot({ 52 | set.seed(1) 53 | res <- var_wflow |> 54 | tune_sim_anneal( 55 | cell_folds, 56 | iter = 2, 57 | control = control_sim_anneal(verbose = TRUE, verbose_iter = TRUE) 58 | ) 59 | }) 60 | expect_equal( 61 | class(res), 62 | c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame") 63 | ) 64 | expect_true(nrow(collect_metrics(res)) == 9) 65 | expect_equal(res, .Last.tune.result) 66 | 67 | # Check to see if iterations are picked up when an iterative object is used 68 | # as the initial object 69 | 70 | expect_snapshot({ 71 | set.seed(1) 72 | new_res <- var_wflow |> 73 | tune_sim_anneal( 74 | cell_folds, 75 | iter = 2, 76 | initial = res, 77 | control = control_sim_anneal(verbose = FALSE) 78 | ) 79 | }) 80 | expect_true(nrow(collect_metrics(new_res)) == 15) 81 | expect_true(max(new_res$.iter) == 4) 82 | expect_true(sum(grepl("^initial", collect_metrics(new_res)$.config)) == 9) 83 | expect_equal(new_res, .Last.tune.result) 84 | 85 | # but not for non-iterative 
objects 86 | set.seed(1) 87 | grid_res <- var_wflow |> 88 | tune_grid(cell_folds, grid = 2) 89 | 90 | expect_snapshot({ 91 | set.seed(1) 92 | new_new_res <- var_wflow |> 93 | tune_sim_anneal( 94 | cell_folds, 95 | iter = 2, 96 | initial = grid_res, 97 | control = control_sim_anneal(verbose = FALSE) 98 | ) 99 | }) 100 | expect_true(nrow(collect_metrics(new_new_res)) == 12) 101 | expect_true(max(new_new_res$.iter) == 2) 102 | expect_true(sum(grepl("^initial", collect_metrics(new_new_res)$.config)) == 6) 103 | expect_equal(new_new_res, .Last.tune.result) 104 | }) 105 | 106 | 107 | test_that("unfinalized parameters", { 108 | skip_on_cran() 109 | skip_on_os("windows") 110 | skip_on_os("linux") 111 | 112 | data(two_class_dat, package = "modeldata") 113 | 114 | set.seed(5046) 115 | bt <- bootstraps(two_class_dat, times = 5) 116 | 117 | rec_example <- recipe(Class ~ ., data = two_class_dat) 118 | 119 | # RF 120 | model_rf <- rand_forest(mtry = tune()) |> 121 | set_mode("classification") |> 122 | set_engine("ranger") 123 | 124 | wf_rf <- workflow() |> 125 | add_model(model_rf) |> 126 | add_recipe(rec_example) 127 | 128 | set.seed(30) 129 | rf_res <- wf_rf |> 130 | tune_grid(resamples = bt, grid = 4) 131 | 132 | expect_snapshot({ 133 | set.seed(40) 134 | rf_res_finetune <- wf_rf |> 135 | tune_sim_anneal(resamples = bt, initial = rf_res) 136 | }) 137 | 138 | # don't supply an initial grid (#39) 139 | expect_snapshot({ 140 | set.seed(40) 141 | rf_res_finetune <- wf_rf |> 142 | tune_sim_anneal(resamples = bt) 143 | }) 144 | }) 145 | 146 | test_that("incompatible parameter objects", { 147 | skip_on_cran() 148 | 149 | skip_if_not_installed("ranger") 150 | skip_if_not_installed("modeldata") 151 | skip_if_not_installed("rsample") 152 | 153 | rf_spec <- parsnip::rand_forest(mode = "regression", mtry = tune::tune()) 154 | 155 | set.seed(1) 156 | grid_with_bigger_range <- 157 | dials::grid_space_filling(dials::mtry(range = c(1, 16))) 158 | 159 | set.seed(1) 160 | car_folds <- 
rsample::vfold_cv(car_prices, v = 2) 161 | 162 | car_wflow <- workflows::workflow() |> 163 | workflows::add_formula(Price ~ .) |> 164 | workflows::add_model(rf_spec) 165 | 166 | set.seed(1) 167 | tune_res_with_bigger_range <- tune::tune_grid( 168 | car_wflow, 169 | resamples = car_folds, 170 | grid = grid_with_bigger_range 171 | ) 172 | 173 | set.seed(1) 174 | parameter_set_with_smaller_range <- 175 | dials::parameters(dials::mtry(range = c(1, 5))) 176 | 177 | scrub_best <- function(lines) { 178 | has_best <- grepl("Initial best", lines) 179 | lines[has_best] <- "" 180 | lines 181 | } 182 | 183 | set.seed(1) 184 | expect_snapshot(error = TRUE, transform = scrub_best, { 185 | res <- 186 | tune_sim_anneal( 187 | car_wflow, 188 | param_info = parameter_set_with_smaller_range, 189 | resamples = car_folds, 190 | initial = tune_res_with_bigger_range, 191 | iter = 2 192 | ) 193 | }) 194 | }) 195 | 196 | test_that("set event-level", { 197 | # See issue 40 198 | skip_if_not_installed("rpart") 199 | skip_if_not_installed("modeldata") 200 | skip_if_not_installed("yardstick") 201 | skip_if_not_installed("rsample") 202 | skip_on_cran() 203 | 204 | # ------------------------------------------------------------------------------ 205 | 206 | set.seed(1) 207 | dat <- modeldata::sim_classification(500, intercept = 8) 208 | 209 | # We should get high sensitivity and low specificity when event_level = "first" 210 | # count(dat, class) 211 | # levels(dat$class) 212 | 213 | set.seed(2) 214 | rs <- vfold_cv(dat, strata = class) 215 | 216 | cart_spec <- decision_tree(min_n = tune()) |> set_mode("classification") 217 | 218 | stats <- metric_set(accuracy, sensitivity, specificity) 219 | 220 | # ------------------------------------------------------------------------------ 221 | # high sensitivity and low specificity 222 | 223 | set.seed(3) 224 | cart_res_first <- 225 | cart_spec |> 226 | tune_sim_anneal( 227 | class ~ ., 228 | rs, 229 | control = control_sim_anneal(event_level = "first", 
verbose_iter = FALSE), 230 | metrics = stats 231 | ) 232 | 233 | results_first <- 234 | cart_res_first |> 235 | collect_metrics() |> 236 | dplyr::filter(.metric != "accuracy") |> 237 | dplyr::select(.config, .metric, mean) |> 238 | tidyr::pivot_wider( 239 | id_cols = .config, 240 | names_from = .metric, 241 | values_from = mean 242 | ) 243 | 244 | dir_check <- all(results_first$sensitivity > results_first$specificity) 245 | expect_true(dir_check) 246 | 247 | # ------------------------------------------------------------------------------ 248 | # Now reversed 249 | 250 | set.seed(3) 251 | cart_res_second <- 252 | cart_spec |> 253 | tune_sim_anneal( 254 | class ~ ., 255 | rs, 256 | control = control_sim_anneal( 257 | event_level = "second", 258 | verbose_iter = FALSE 259 | ), 260 | metrics = stats 261 | ) 262 | 263 | results_second <- 264 | cart_res_second |> 265 | collect_metrics() |> 266 | dplyr::filter(.metric != "accuracy") |> 267 | dplyr::select(.config, .metric, mean) |> 268 | tidyr::pivot_wider( 269 | id_cols = .config, 270 | names_from = .metric, 271 | values_from = mean 272 | ) 273 | 274 | rev_dir_check <- all(results_second$sensitivity < results_second$specificity) 275 | expect_true(rev_dir_check) 276 | }) 277 | -------------------------------------------------------------------------------- /tests/testthat/test-sa-perturb.R: -------------------------------------------------------------------------------- 1 | test_that("numerical neighborhood", { 2 | suppressPackageStartupMessages(library(dials)) 3 | 4 | num_prm <- dials::parameters(dials::mixture(), dials::threshold()) 5 | 6 | vals <- tibble::tibble(mixture = 0.5, threshold = 0.5) 7 | set.seed(1) 8 | new_vals <- 9 | finetune:::random_real_neighbor(vals, vals[0, ], num_prm, retain = 100) 10 | 11 | rad <- control_sim_anneal()$radius 12 | 13 | correct_r <- 14 | purrr::map2_dbl( 15 | new_vals$mixture, 16 | new_vals$threshold, 17 | ~ sqrt((.x - .5)^2 + (.y - .5)^2) 18 | ) |> 19 | purrr::map_lgl(\(x) x >= rad[1] 
& x <= rad[2]) 20 | expect_true(all(correct_r)) 21 | 22 | set.seed(1) 23 | prev <- tibble::tibble(mixture = runif(5), threshold = runif(5)) 24 | 25 | set.seed(2) 26 | more_vals <- finetune:::new_in_neighborhood( 27 | vals, 28 | prev, 29 | num_prm, 30 | radius = rep(0.12, 2) 31 | ) 32 | rad_vals <- sqrt((more_vals$mixture - .5)^2 + (more_vals$threshold - .5)^2) 33 | expect_equal(rad_vals, 0.12, tolerance = 0.001) 34 | }) 35 | 36 | test_that("numerical neighborhood boundary filters", { 37 | suppressPackageStartupMessages(library(dials)) 38 | num_prm <- dials::parameters(dials::mixture(), dials::threshold()) 39 | 40 | vals <- tibble::tibble(mixture = 0.05, threshold = 0.05) 41 | set.seed(1) 42 | new_vals <- 43 | finetune:::random_real_neighbor( 44 | vals, 45 | vals[0, ], 46 | num_prm, 47 | retain = 100, 48 | tries = 100, 49 | r = 0.12 50 | ) 51 | expect_true(nrow(new_vals) < 100) 52 | }) 53 | 54 | ## ----------------------------------------------------------------------------- 55 | 56 | test_that("categorical value switching", { 57 | suppressPackageStartupMessages(library(dials)) 58 | cat_prm <- parameters(activation(), weight_func()) 59 | 60 | vals <- tibble::tibble(activation = "relu", weight_func = "biweight") 61 | set.seed(1) 62 | new_vals <- 63 | purrr::map( 64 | 1:1000, 65 | \(x) 66 | finetune:::random_discrete_neighbor( 67 | vals, 68 | cat_prm, 69 | prob = 1 / 4, 70 | change = FALSE 71 | ) 72 | ) |> 73 | purrr::list_rbind() 74 | relu_same <- mean(new_vals$activation == "relu") 75 | biweight_same <- mean(new_vals$weight_func == "biweight") 76 | 77 | expect_true(relu_same > .7 & relu_same < .8) 78 | expect_true(biweight_same > .7 & biweight_same < .8) 79 | 80 | set.seed(1) 81 | prev <- tibble::tibble( 82 | activation = dials::values_activation[1:4], 83 | weight_func = dials::values_weight_func[1:4] 84 | ) 85 | set.seed(2) 86 | must_change <- finetune:::new_in_neighborhood(vals, prev, cat_prm, flip = 1) 87 | expect_true(must_change$activation != "relu") 88 | 
expect_true(must_change$weight_func != "biweight") 89 | }) 90 | 91 | ## ----------------------------------------------------------------------------- 92 | 93 | test_that("reverse-unit encoding", { 94 | suppressPackageStartupMessages(library(dials)) 95 | prm <- 96 | parameters(batch_size(), Laplace(), activation()) |> 97 | update(Laplace = Laplace(c(2, 4)), batch_size = batch_size(c(10, 20))) 98 | unit_vals <- tibble::tibble(batch_size = .1, Laplace = .4, activation = .7) 99 | vals <- finetune:::encode_set_backwards(unit_vals, prm) 100 | expect_true(vals$batch_size > 1) 101 | expect_true(vals$Laplace > 1) 102 | expect_true(is.character(vals$activation)) 103 | }) 104 | -------------------------------------------------------------------------------- /tests/testthat/test-win-loss-filter.R: -------------------------------------------------------------------------------- 1 | test_that("top-level win/loss filter interfaces", { 2 | skip_on_cran() 3 | # Skip for < 4.0 due to random number differences 4 | skip_if(getRversion() < "4.0.0") 5 | 6 | library(dials) 7 | 8 | # ------------------------------------------------------------------------------ 9 | 10 | set.seed(2332) 11 | folds <- vfold_cv(mtcars, v = 5, repeats = 2) 12 | fold_att <- attributes(folds) 13 | spec <- decision_tree(cost_complexity = tune(), min_n = tune()) |> 14 | set_engine("rpart") |> 15 | set_mode("regression") 16 | wflow <- workflow() |> 17 | add_model(spec) |> 18 | add_formula(mpg ~ .) 
19 | grid <- expand.grid(cost_complexity = c(0.001, 0.01), min_n = c(2:5)) 20 | rec <- recipe(mpg ~ ., data = mtcars) |> 21 | step_normalize(all_predictors()) 22 | prm <- extract_parameter_set_dials(wflow) |> update(min_n = min_n(c(2, 20))) 23 | 24 | # ------------------------------------------------------------------------------ 25 | 26 | set.seed(129) 27 | suppressWarnings( 28 | wl_mod <- spec |> tune_race_win_loss(mpg ~ ., folds, grid = grid) 29 | ) 30 | 31 | expect_true(inherits(wl_mod, "tune_race")) 32 | expect_true(inherits(wl_mod, "tune_results")) 33 | expect_true(tibble::is_tibble((wl_mod))) 34 | expect_null(.get_tune_eval_times(wl_mod)) 35 | expect_null(.get_tune_eval_time_target(wl_mod)) 36 | 37 | expect_silent({ 38 | set.seed(129) 39 | suppressWarnings( 40 | wl_wlfow <- 41 | wflow |> 42 | tune_race_win_loss( 43 | folds, 44 | grid = grid, 45 | param_info = prm, 46 | control = control_race(verbose_elim = FALSE, save_pred = TRUE) 47 | ) 48 | ) 49 | }) 50 | 51 | expect_true(inherits(wl_wlfow, "tune_race")) 52 | expect_true(inherits(wl_wlfow, "tune_results")) 53 | expect_true(tibble::is_tibble((wl_wlfow))) 54 | expect_true(sum(names(wl_wlfow) == ".predictions") == 1) 55 | 56 | get_mod <- function(x) workflows::extract_fit_parsnip(x) 57 | 58 | expect_silent({ 59 | set.seed(129) 60 | suppressMessages( 61 | wl_rec <- 62 | spec |> 63 | tune_race_win_loss( 64 | rec, 65 | folds, 66 | grid = expand.grid(cost_complexity = c(.0001, .001), min_n = c(3, 5)), 67 | param_info = prm, 68 | control = control_race( 69 | verbose_elim = FALSE, 70 | extract = get_mod 71 | ) 72 | ) 73 | ) 74 | }) 75 | 76 | expect_true(inherits(wl_rec, "tune_race")) 77 | expect_true(inherits(wl_rec, "tune_results")) 78 | expect_true(tibble::is_tibble((wl_rec))) 79 | expect_true(sum(names(wl_rec) == ".extracts") == 1) 80 | }) 81 | -------------------------------------------------------------------------------- /tests/testthat/test-win-loss-overall.R: 
-------------------------------------------------------------------------------- 1 | test_that("formula interface", { 2 | skip_on_cran() 3 | 4 | expect_snapshot({ 5 | set.seed(1) 6 | res <- f_wflow |> 7 | tune_race_win_loss( 8 | cell_folds, 9 | grid = 5, 10 | control = control_race(verbose_elim = TRUE) 11 | ) 12 | }) 13 | 14 | expect_equal( 15 | class(res), 16 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 17 | ) 18 | expect_true(nrow(collect_metrics(res)) == 12) # this run has one elimination 19 | expect_equal(res, .Last.tune.result) 20 | }) 21 | 22 | # ------------------------------------------------------------------------------ 23 | 24 | test_that("recipe interface", { 25 | skip_on_cran() 26 | expect_silent({ 27 | set.seed(1) 28 | res <- rec_wflow |> 29 | tune_race_win_loss( 30 | cell_folds, 31 | grid = 5, 32 | control = control_race(verbose_elim = FALSE) 33 | ) 34 | }) 35 | expect_equal( 36 | class(res), 37 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 38 | ) 39 | expect_true(nrow(collect_metrics(res)) < 10) 40 | expect_equal(res, .Last.tune.result) 41 | }) 42 | 43 | # ------------------------------------------------------------------------------ 44 | 45 | test_that("variable interface", { 46 | skip_on_cran() 47 | expect_silent({ 48 | set.seed(1) 49 | res <- var_wflow |> 50 | tune_race_win_loss( 51 | cell_folds, 52 | grid = 5, 53 | control = control_race(verbose_elim = FALSE) 54 | ) 55 | }) 56 | expect_equal( 57 | class(res), 58 | c("tune_race", "tune_results", "tbl_df", "tbl", "data.frame") 59 | ) 60 | expect_true(nrow(collect_metrics(res)) == 12) # one elimination 61 | expect_equal(res, .Last.tune.result) 62 | }) 63 | 64 | # ------------------------------------------------------------------------------ 65 | 66 | test_that("one player is really bad", { 67 | skip_on_cran() 68 | skip_if_not_installed("tune", "0.1.5.9001") 69 | 70 | set.seed(1341) 71 | df <- tibble( 72 | x1 = rnorm(500, 1:500), 73 | x2 = sample(c(1:4), size = 
500, replace = T) 74 | ) |> 75 | mutate( 76 | y = rbinom(500, 1, prob = (x1 / max(x1))) |> as.factor() 77 | ) 78 | 79 | set.seed(121) 80 | df_folds <- vfold_cv(df, strata = y) 81 | 82 | rf_spec <- 83 | rand_forest(min_n = tune(), trees = 10) |> 84 | set_engine("ranger") |> 85 | set_mode("classification") 86 | 87 | wf <- workflow() |> 88 | add_formula(y ~ .) |> 89 | add_model(rf_spec) 90 | 91 | grid <- tibble(min_n = c(1, 40)) 92 | ctrl <- control_race(burn_in = 2, alpha = .05, randomize = TRUE) 93 | set.seed(3355) 94 | tuning_results <- tune_race_win_loss( 95 | wf, 96 | resamples = df_folds, 97 | metrics = metric_set(roc_auc), 98 | grid = grid, 99 | control = ctrl 100 | ) 101 | 102 | expect_snapshot(best_res <- show_best(tuning_results)) 103 | expect_true(nrow(best_res) == 1) 104 | }) 105 | --------------------------------------------------------------------------------