├── .Rbuildignore ├── .github ├── .gitignore ├── CODE_OF_CONDUCT.md └── workflows │ ├── R-CMD-check-hard.yaml │ ├── R-CMD-check.yaml │ ├── lock.yaml │ ├── pkgdown.yaml │ ├── pr-commands.yaml │ └── test-coverage.yaml ├── .gitignore ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── 0_utils.R ├── aaa.R ├── activation.R ├── autoplot.R ├── brulee-package.R ├── checks.R ├── coef.R ├── convert_data.R ├── import-standalone-obj-type.R ├── import-standalone-types-check.R ├── linear_reg-fit.R ├── linear_reg-predict.R ├── logistic_reg-fit.R ├── logistic_reg-predict.R ├── mlp-fit.R ├── mlp-predict.R ├── multinomial_reg-fit.R ├── multinomial_reg-predict.R └── schedulers.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── brulee.Rproj ├── codecov.yml ├── inst └── WORDLIST ├── man ├── brulee-autoplot.Rd ├── brulee-coefs.Rd ├── brulee-package.Rd ├── brulee_activations.Rd ├── brulee_linear_reg.Rd ├── brulee_logistic_reg.Rd ├── brulee_mlp.Rd ├── brulee_multinomial_reg.Rd ├── figures │ └── logo.png ├── matrix_to_dataset.Rd ├── predict.brulee_linear_reg.Rd ├── predict.brulee_logistic_reg.Rd ├── predict.brulee_mlp.Rd ├── predict.brulee_multinomial_reg.Rd ├── reexports.Rd └── schedule_decay_time.Rd └── tests ├── spelling.R ├── testthat.R └── testthat ├── _snaps ├── checks.md ├── class-weight.md ├── mlp-regression.md └── schedulers.md ├── test-checks.R ├── test-class-weight.R ├── test-linear_reg-fit.R ├── test-logistic_reg-fit.R ├── test-mlp-activations.R ├── test-mlp-binary.R ├── test-mlp-multinomial.R ├── test-mlp-regression.R ├── test-multinomial_reg-fit.R └── test-schedulers.R /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^brulee\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^CODE_OF_CONDUCT\.md$ 4 | ^LICENSE\.md$ 5 | ^README\.Rmd$ 6 | ^codecov\.yml$ 7 | ^_pkgdown\.yml$ 8 | ^docs$ 9 | ^pkgdown$ 10 | ^\.github$ 11 | ^mnist$ 12 | ^revdep$ 13 | 
-------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at codeofconduct@posit.co. 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 
98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | . 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder][https://github.com/mozilla/inclusion]. 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | . Translations are available at . 125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check-hard.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow only directly installs "hard" dependencies, i.e. Depends, 5 | # Imports, and LinkingTo dependencies. Notably, Suggests dependencies are never 6 | # installed, with the exception of testthat, knitr, and rmarkdown. 
The cache is 7 | # never used to avoid accidentally restoring a cache containing a suggested 8 | # dependency. 9 | on: 10 | push: 11 | branches: [main, master] 12 | pull_request: 13 | 14 | name: R-CMD-check-hard.yaml 15 | 16 | permissions: read-all 17 | 18 | jobs: 19 | check-no-suggests: 20 | runs-on: ${{ matrix.config.os }} 21 | 22 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 23 | 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | config: 28 | - {os: ubuntu-latest, r: 'release'} 29 | 30 | env: 31 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 32 | R_KEEP_PKG_SOURCE: yes 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - uses: r-lib/actions/setup-pandoc@v2 38 | 39 | - uses: r-lib/actions/setup-r@v2 40 | with: 41 | r-version: ${{ matrix.config.r }} 42 | http-user-agent: ${{ matrix.config.http-user-agent }} 43 | use-public-rspm: true 44 | 45 | - uses: r-lib/actions/setup-r-dependencies@v2 46 | with: 47 | dependencies: '"hard"' 48 | cache: false 49 | extra-packages: | 50 | any::rcmdcheck 51 | any::testthat 52 | any::knitr 53 | any::rmarkdown 54 | needs: check 55 | 56 | - uses: r-lib/actions/check-r-package@v2 57 | with: 58 | upload-snapshots: true 59 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 60 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 
7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | branches: [main, master] 12 | 13 | name: R-CMD-check.yaml 14 | 15 | permissions: read-all 16 | 17 | jobs: 18 | R-CMD-check: 19 | runs-on: ${{ matrix.config.os }} 20 | 21 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 22 | 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | config: 27 | - {os: macos-latest, r: 'release'} 28 | 29 | - {os: windows-latest, r: 'release'} 30 | # use 4.0 or 4.1 to check with rtools40's older compiler 31 | #- {os: windows-latest, r: 'oldrel-4'} 32 | 33 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 34 | - {os: ubuntu-latest, r: 'release'} 35 | - {os: ubuntu-latest, r: 'oldrel-1'} 36 | - {os: ubuntu-latest, r: 'oldrel-2'} 37 | - {os: ubuntu-latest, r: 'oldrel-3'} 38 | 39 | env: 40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 41 | R_KEEP_PKG_SOURCE: yes 42 | TORCH_INSTALL: 1 43 | 44 | steps: 45 | - uses: actions/checkout@v4 46 | 47 | - uses: r-lib/actions/setup-pandoc@v2 48 | 49 | - uses: r-lib/actions/setup-r@v2 50 | with: 51 | r-version: ${{ matrix.config.r }} 52 | http-user-agent: ${{ matrix.config.http-user-agent }} 53 | use-public-rspm: true 54 | 55 | - uses: r-lib/actions/setup-r-dependencies@v2 56 | with: 57 | extra-packages: any::rcmdcheck 58 | needs: check 59 | 60 | - uses: r-lib/actions/check-r-package@v2 61 | with: 62 | upload-snapshots: true 63 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 64 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: dessant/lock-threads@v2 12 | with: 13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | 
issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | ) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | ) and link to this issue. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | release: 9 | types: [published] 10 | workflow_dispatch: 11 | 12 | name: pkgdown.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | pkgdown: 18 | runs-on: ubuntu-latest 19 | # Only restrict concurrency for non-PR jobs 20 | concurrency: 21 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 22 | env: 23 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 24 | TORCH_INSTALL: 1 25 | permissions: 26 | contents: write 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - uses: r-lib/actions/setup-pandoc@v2 31 | 32 | - uses: r-lib/actions/setup-r@v2 33 | with: 34 | use-public-rspm: true 35 | 36 | - uses: r-lib/actions/setup-r-dependencies@v2 37 | with: 38 | extra-packages: any::pkgdown, local::. 
39 | needs: website 40 | 41 | - name: Build site 42 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 43 | shell: Rscript {0} 44 | 45 | - name: Deploy to GitHub pages 🚀 46 | if: github.event_name != 'pull_request' 47 | uses: JamesIves/github-pages-deploy-action@v4.5.0 48 | with: 49 | clean: false 50 | branch: gh-pages 51 | folder: docs 52 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ 
secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | run: styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | branches: [main, master] 8 | 9 | name: test-coverage.yaml 10 | 11 | permissions: read-all 12 | 13 | 14 | 15 | jobs: 16 | test-coverage: 17 | runs-on: ubuntu-latest 18 | env: 19 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 20 | TORCH_INSTALL: 1 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - uses: r-lib/actions/setup-r@v2 26 | with: 27 | use-public-rspm: true 28 | 29 | - uses: r-lib/actions/setup-r-dependencies@v2 30 | with: 31 | extra-packages: any::covr, any::xml2 32 | needs: coverage 33 | 34 | - name: Test coverage 35 | run: | 36 | cov <- covr::package_coverage( 37 | quiet = FALSE, 38 | clean = FALSE, 39 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 40 | ) 41 | covr::to_cobertura(cov) 42 | shell: Rscript {0} 43 | 44 | - uses: codecov/codecov-action@v4 45 | with: 46 | fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }} 47 | file: ./cobertura.xml 48 | plugin: noop 49 | disable_search: true 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | 52 | - name: Show testthat output 53 | if: always() 54 | run: | 55 | ## -------------------------------------------------------------------- 56 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 57 | shell: bash 58 | 59 | - name: Upload test results 60 | if: failure() 61 | uses: actions/upload-artifact@v4 62 | with: 63 | name: coverage-test-failures 64 | path: ${{ runner.temp }}/package 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .DS_Store 4 | docs 5 | revdep 6 | -------------------------------------------------------------------------------- /DESCRIPTION: 
-------------------------------------------------------------------------------- 1 | Package: brulee 2 | Title: High-Level Modeling Functions with 'torch' 3 | Version: 0.5.0.9000 4 | Authors@R: c( 5 | person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), 6 | comment = c(ORCID = "0000-0003-2402-136X")), 7 | person("Daniel", "Falbel", , "daniel@posit.co", role = "aut"), 8 | person(given = "Posit Software, PBC", role = c("cph", "fnd")) 9 | ) 10 | Description: Provides high-level modeling functions to define and train 11 | models using the 'torch' R package. Models include linear, logistic, 12 | and multinomial regression as well as multilayer perceptrons. 13 | License: MIT + file LICENSE 14 | URL: https://github.com/tidymodels/brulee, 15 | https://brulee.tidymodels.org/ 16 | BugReports: https://github.com/tidymodels/brulee/issues 17 | Depends: 18 | R (>= 4.1) 19 | Imports: 20 | cli, 21 | coro (>= 1.0.1), 22 | dplyr, 23 | generics, 24 | ggplot2, 25 | glue, 26 | hardhat, 27 | rlang (>= 1.1.1), 28 | stats, 29 | tibble, 30 | torch (>= 0.13.0), 31 | utils 32 | Suggests: 33 | covr, 34 | modeldata, 35 | purrr, 36 | recipes, 37 | spelling, 38 | testthat, 39 | yardstick 40 | Config/Needs/website: tidyverse/tidytemplate 41 | Config/testthat/edition: 3 42 | Encoding: UTF-8 43 | Language: en-US 44 | Roxygen: list(markdown = TRUE) 45 | RoxygenNote: 7.3.2 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2023 2 | COPYRIGHT HOLDER: brulee authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2023 brulee authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), 
to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(autoplot,brulee_linear_reg) 4 | S3method(autoplot,brulee_logistic_reg) 5 | S3method(autoplot,brulee_mlp) 6 | S3method(autoplot,brulee_multinomial_reg) 7 | S3method(brulee_linear_reg,data.frame) 8 | S3method(brulee_linear_reg,default) 9 | S3method(brulee_linear_reg,formula) 10 | S3method(brulee_linear_reg,matrix) 11 | S3method(brulee_linear_reg,recipe) 12 | S3method(brulee_logistic_reg,data.frame) 13 | S3method(brulee_logistic_reg,default) 14 | S3method(brulee_logistic_reg,formula) 15 | S3method(brulee_logistic_reg,matrix) 16 | S3method(brulee_logistic_reg,recipe) 17 | S3method(brulee_mlp,data.frame) 18 | S3method(brulee_mlp,default) 19 | S3method(brulee_mlp,formula) 20 | S3method(brulee_mlp,matrix) 21 | S3method(brulee_mlp,recipe) 22 | S3method(brulee_mlp_two_layer,data.frame) 23 | 
S3method(brulee_mlp_two_layer,default) 24 | S3method(brulee_mlp_two_layer,formula) 25 | S3method(brulee_mlp_two_layer,matrix) 26 | S3method(brulee_mlp_two_layer,recipe) 27 | S3method(brulee_multinomial_reg,data.frame) 28 | S3method(brulee_multinomial_reg,default) 29 | S3method(brulee_multinomial_reg,formula) 30 | S3method(brulee_multinomial_reg,matrix) 31 | S3method(brulee_multinomial_reg,recipe) 32 | S3method(coef,brulee_linear_reg) 33 | S3method(coef,brulee_logistic_reg) 34 | S3method(coef,brulee_mlp) 35 | S3method(coef,brulee_multinomial_reg) 36 | S3method(predict,brulee_linear_reg) 37 | S3method(predict,brulee_logistic_reg) 38 | S3method(predict,brulee_mlp) 39 | S3method(predict,brulee_multinomial_reg) 40 | S3method(print,brulee_linear_reg) 41 | S3method(print,brulee_logistic_reg) 42 | S3method(print,brulee_mlp) 43 | S3method(print,brulee_multinomial_reg) 44 | export(autoplot) 45 | export(brulee_activations) 46 | export(brulee_linear_reg) 47 | export(brulee_logistic_reg) 48 | export(brulee_mlp) 49 | export(brulee_mlp_two_layer) 50 | export(brulee_multinomial_reg) 51 | export(coef) 52 | export(matrix_to_dataset) 53 | export(schedule_cyclic) 54 | export(schedule_decay_expo) 55 | export(schedule_decay_time) 56 | export(schedule_step) 57 | export(set_learn_rate) 58 | export(tunable) 59 | import(rlang) 60 | import(torch) 61 | importFrom(generics,tunable) 62 | importFrom(ggplot2,autoplot) 63 | importFrom(stats,coef) 64 | importFrom(stats,complete.cases) 65 | importFrom(stats,model.matrix) 66 | importFrom(stats,terms) 67 | importFrom(utils,globalVariables) 68 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # brulee (development version) 2 | 3 | * Transition from the magrittr pipe to the base R pipe. 
4 | 5 | # brulee 0.5.0 6 | 7 | * Removed a unit test for numerical overflow since it occurs less frequently and has become increasingly more challenging to reproduce. 8 | 9 | # brulee 0.4.0 10 | 11 | * Added a convenience function, `brulee_mlp_two_layer()`, to more easily fit two-layer networks with parsnip. 12 | 13 | * Various changes and improvements to error and warning messages. 14 | 15 | * Fixed a bug that occurred when linear activation was used for neural networks (#68). 16 | 17 | # brulee 0.3.0 18 | 19 | * Fixed a bug where `coef()` would error if used on a `brulee_logistic_reg()` that was trained with a recipe. (#66) 20 | 21 | * Fixed a bug where SGD was always being used as the optimizer (#61). 22 | 23 | * Additional activation functions were added (#74). 24 | 25 | # brulee 0.2.0 26 | 27 | * Several learning rate schedulers were added to the modeling functions (#12). 28 | 29 | * An `optimizer` was added to [brulee_mlp()], with a new default being LBFGS instead of stochastic gradient descent. 30 | 31 | # brulee 0.1.0 32 | 33 | * Modeling functions gained a `mixture` argument for the proportion of L1 penalty that is used. (#50) 34 | 35 | * Penalization was not occurring when quasi-Newton optimization was chosen. (#50) 36 | 37 | # brulee 0.0.1 38 | 39 | First CRAN release. 40 | -------------------------------------------------------------------------------- /R/0_utils.R: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # used in print methods 4 | 5 | brulee_print <- function(x, ...) 
{ 6 | lvl <- get_levels(x) 7 | if (is.null(lvl)) { 8 | chr_y <- "numeric outcome" 9 | } else { 10 | chr_y <- paste(length(lvl), "classes") 11 | } 12 | cat( 13 | format(x$dims$n, big.mark = ","), "samples,", 14 | format(x$dims$p, big.mark = ","), "features,", 15 | chr_y, "\n" 16 | ) 17 | if (!is.null(x$dims$levels) && !is.null(x$parameters$class_weights)) { 18 | cat("class weights", 19 | paste0( 20 | names(x$parameters$class_weights), 21 | "=", 22 | format(x$parameters$class_weights), 23 | collapse = ", " 24 | ), 25 | "\n") 26 | } 27 | if (x$parameters$penalty > 0) { 28 | cat("weight decay:", x$parameters$penalty, "\n") 29 | } 30 | if (any(names(x$parameters) == "dropout")) { 31 | cat("dropout proportion:", x$parameters$dropout, "\n") 32 | } 33 | cat("batch size:", x$parameters$batch_size, "\n") 34 | 35 | if (all(c("sched", "sched_opt") %in% names(x$parameters))) { 36 | cat_schedule(x$parameters) 37 | } 38 | 39 | if (!is.null(x$loss)) { 40 | it <- x$best_epoch 41 | chr_it <- cli::pluralize("{it} epoch{?s}:") 42 | if(x$parameters$validation > 0) { 43 | if (is.na(x$y_stats$mean)) { 44 | cat("validation loss after", chr_it, 45 | signif(x$loss[it], 3), "\n") 46 | } else { 47 | cat("scaled validation loss after", chr_it, 48 | signif(x$loss[it], 3), "\n") 49 | } 50 | } else { 51 | if (is.na(x$y_stats$mean)) { 52 | cat("training set loss after", chr_it, 53 | signif(x$loss[it], 3), "\n") 54 | } else { 55 | cat("scaled training set loss after", chr_it, 56 | signif(x$loss[it], 3), "\n") 57 | } 58 | } 59 | } 60 | invisible(x) 61 | } 62 | 63 | # ------------------------------------------------------------------------------ 64 | 65 | cat_schedule <- function(x) { 66 | if (x$sched == "none") { 67 | cat("learn rate:", x$learn_rate, "\n") 68 | } else { 69 | .fn <- paste0("schedule_", x$sched) 70 | cl <- rlang::call2(.fn, !!!x$sched_opt) 71 | chr_cl <- rlang::expr_deparse(cl, width = 200) 72 | 73 | cat(gsub("^schedule_", "schedule: ", chr_cl), "\n") 74 | } 75 | invisible(NULL) 76 | 
} 77 | 78 | # ------------------------------------------------------------------------------ 79 | 80 | 81 | model_to_raw <- function(model) { 82 | con <- rawConnection(raw(), open = "w") 83 | on.exit({close(con)}, add = TRUE) 84 | torch::torch_save(model, con) 85 | r <- rawConnectionValue(con) 86 | r 87 | } 88 | 89 | # ------------------------------------------------------------------------------ 90 | 91 | lx_term <- function(norm) { 92 | function(model) { 93 | is_bias <- grepl("bias", names(model$parameters)) 94 | coefs <- model$parameters[!is_bias] 95 | l <- lapply(coefs, function(x) { 96 | torch::torch_sum(norm(x)) 97 | }) 98 | torch::torch_sum(torch::torch_stack(l)) 99 | } 100 | } 101 | 102 | l2_term <- lx_term(function(x) torch::torch_pow(x, 2)) 103 | l1_term <- lx_term(function(x) torch::torch_abs(x)) 104 | 105 | # ------------------------------------------------------------------------- 106 | 107 | make_penalized_loss <- function(loss_fn, model, penalty, mixture) { 108 | force(loss_fn) 109 | function(...) { 110 | loss <- loss_fn(...) 
111 | if (penalty > 0) { 112 | l_term <- mixture * l1_term(model) + (1 - mixture) / 2 * l2_term(model) 113 | loss <- loss + penalty * l_term 114 | } 115 | loss 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /R/aaa.R: -------------------------------------------------------------------------------- 1 | #' @import torch 2 | #' @import rlang 3 | #' @importFrom stats complete.cases model.matrix terms 4 | #' @importFrom utils globalVariables 5 | #' 6 | 7 | #' @importFrom ggplot2 autoplot 8 | #' @export 9 | ggplot2::autoplot 10 | 11 | #' @importFrom generics tunable 12 | #' @export 13 | generics::tunable 14 | 15 | #' @importFrom stats coef 16 | #' @export 17 | stats::coef 18 | 19 | # ------------------------------------------------------------------------------ 20 | 21 | utils::globalVariables( 22 | c( 23 | "object", "iteration", "loss" 24 | ) 25 | ) 26 | 27 | # ------------------------------------------------------------------------------ 28 | 29 | # nocov start 30 | .onAttach <- function(libname, pkgname) { 31 | s3_register("ggplot2::autoplot", "brulee_mlp") 32 | invisible() 33 | } 34 | 35 | # Dynamic reg helper ----------------------------------------------------------- 36 | 37 | # vctrs/register-s3.R 38 | # https://github.com/r-lib/vctrs/blob/master/R/register-s3.R 39 | s3_register <- function(generic, class, method = NULL) { 40 | stopifnot(is.character(generic), length(generic) == 1) 41 | stopifnot(is.character(class), length(class) == 1) 42 | 43 | pieces <- strsplit(generic, "::")[[1]] 44 | stopifnot(length(pieces) == 2) 45 | package <- pieces[[1]] 46 | generic <- pieces[[2]] 47 | 48 | if (is.null(method)) { 49 | method <- get(paste0(generic, ".", class), envir = parent.frame()) 50 | } 51 | stopifnot(is.function(method)) 52 | 53 | if (package %in% loadedNamespaces()) { 54 | registerS3method(generic, class, method, envir = asNamespace(package)) 55 | } 56 | 57 | # Always register hook in case package is later 
unloaded & reloaded 58 | setHook( 59 | packageEvent(package, "onLoad"), 60 | function(...) { 61 | registerS3method(generic, class, method, envir = asNamespace(package)) 62 | } 63 | ) 64 | } 65 | 66 | # nocov end 67 | -------------------------------------------------------------------------------- /R/activation.R: -------------------------------------------------------------------------------- 1 | allowed_activation <- 2 | c("celu", "elu", "gelu", "hardshrink", "hardsigmoid", 3 | "hardtanh", "leaky_relu", "linear", "log_sigmoid", "relu", "relu6", 4 | "rrelu", "selu", "sigmoid", "silu", "softplus", "softshrink", 5 | "softsign", "tanh", "tanhshrink") 6 | 7 | #' Activation functions for neural networks in brulee 8 | #' 9 | #' @return A character vector of values. 10 | #' @export 11 | brulee_activations <- function() { 12 | allowed_activation 13 | } 14 | 15 | get_activation_fn <- function(arg, ...) { 16 | 17 | if (arg == "linear") { 18 | arg <- "identity" 19 | } 20 | 21 | cl <- rlang::call2(paste0("nn_", arg), .ns = "torch") 22 | res <- rlang::eval_bare(cl) 23 | 24 | res 25 | } 26 | -------------------------------------------------------------------------------- /R/autoplot.R: -------------------------------------------------------------------------------- 1 | 2 | # used for autoplots 3 | brulee_plot <- function(object, ...) 
{
  # One row per training epoch; `object$loss` holds the per-epoch loss trace.
  x <- tibble::tibble(iteration = seq_along(object$loss), loss = object$loss)

  # Axis label reflects how the loss was computed: on a validation set when
  # `validation > 0`, otherwise on the training set. A non-missing
  # `y_stats$mean` indicates that the outcome was scaled before fitting, so
  # the loss is on the scaled outcome.
  if (object$parameters$validation > 0) {
    if (is.na(object$y_stats$mean)) {
      lab <- "loss (validation set)"
    } else {
      lab <- "loss (validation set, scaled)"
    }
  } else {
    if (is.na(object$y_stats$mean)) {
      lab <- "loss (training set)"
    } else {
      lab <- "loss (training set, scaled)"
    }
  }

  ggplot2::ggplot(x, ggplot2::aes(x = iteration, y = loss)) +
    ggplot2::geom_line() +
    ggplot2::labs(y = lab) +
    # Mark the epoch whose parameters are used as the final model.
    ggplot2::geom_vline(xintercept = object$best_epoch, lty = 2, col = "green")
}


## -----------------------------------------------------------------------------

#' Plot model loss over epochs
#'
#' @param object A `brulee_mlp`, `brulee_logistic_reg`,
#'  `brulee_multinomial_reg`, or `brulee_linear_reg` object.
#' @param ... Not currently used
#' @return A `ggplot` object.
#' @details This function plots the loss function across the available epochs. A
#'  vertical line shows the epoch with the best loss value.
37 | #' @examples 38 | #' \donttest{ 39 | #' if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 40 | #' library(ggplot2) 41 | #' library(recipes) 42 | #' theme_set(theme_bw()) 43 | #' 44 | #' data(ames, package = "modeldata") 45 | #' 46 | #' ames$Sale_Price <- log10(ames$Sale_Price) 47 | #' 48 | #' set.seed(1) 49 | #' in_train <- sample(1:nrow(ames), 2000) 50 | #' ames_train <- ames[ in_train,] 51 | #' ames_test <- ames[-in_train,] 52 | #' 53 | #' ames_rec <- 54 | #' recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |> 55 | #' step_normalize(all_numeric_predictors()) 56 | #' 57 | #' set.seed(2) 58 | #' fit <- brulee_mlp(ames_rec, data = ames_train, epochs = 50, batch_size = 32) 59 | #' 60 | #' autoplot(fit) 61 | #' } 62 | #' } 63 | #' @name brulee-autoplot 64 | #' @export 65 | autoplot.brulee_mlp <- brulee_plot 66 | 67 | #' @rdname brulee-autoplot 68 | #' @export 69 | autoplot.brulee_logistic_reg <- brulee_plot 70 | 71 | #' @rdname brulee-autoplot 72 | #' @export 73 | autoplot.brulee_multinomial_reg <- brulee_plot 74 | 75 | #' @rdname brulee-autoplot 76 | #' @export 77 | autoplot.brulee_linear_reg <- brulee_plot 78 | 79 | -------------------------------------------------------------------------------- /R/brulee-package.R: -------------------------------------------------------------------------------- 1 | #' @keywords internal 2 | "_PACKAGE" 3 | 4 | # The following block is used by usethis to automatically manage 5 | # roxygen namespace tags. Modify with care! 6 | ## usethis namespace: start 7 | ## usethis namespace: end 8 | NULL 9 | -------------------------------------------------------------------------------- /R/checks.R: -------------------------------------------------------------------------------- 1 | # Additional type checkers designed for testing argument values. 2 | 3 | check_number_whole_vec <- function(x, call = rlang::caller_env(), arg = rlang::caller_arg(x), ...) 
{
  # NOTE: `arg` comes from the signature default. The previous body
  # recomputed `arg = rlang::caller_arg(x)` here, which discarded any
  # caller-supplied argument name and hid it from error messages.
  for (i in x) {
    check_number_whole(i, arg = arg, call = call, ...)
  }
  x <- as.integer(x)
  invisible(x)
}

# Check that `x` is a double vector, optionally allowing missing values.
# Aborts (in the caller's context) otherwise; returns `x` invisibly.
check_number_decimal_vec <- function(x,
                                     arg = rlang::caller_arg(x),
                                     allow_na = FALSE,
                                     call = rlang::caller_env(),
                                     ...) {
  if (!is.double(x)) {
    cli::cli_abort("{.arg {arg}} should be a double vector.")
  }

  if (!allow_na && any(is.na(x))) {
    cli::cli_abort("{.arg {arg}} should not contain missing values.")
  }

  invisible(x)
}

# ------------------------------------------------------------------------------
# soon to be replaced checkers


# Drop rows where `x` or `y` has missing values. When `verbose`, warn how
# many rows were removed, attributing the removal to function `fn`.
check_missing_data <- function(x, y, fn = "some function", verbose = FALSE) {
  compl_data <- complete.cases(x, y)
  if (any(!compl_data)) {
    x <- x[compl_data, , drop = FALSE]
    y <- y[compl_data]
    if (verbose) {
      # (an unused `cl_chr <- as.character()` line was removed here)
      msg <- paste0(fn, "() removed ", sum(!compl_data), " rows of ",
                    "data due to missing values.")
      cli::cli_warn(msg)
    }
  }
  list(x = x, y = y)
}

# Validate the basic structure of the x/y interface: a numeric matrix with
# named columns and a univariate vector/factor outcome.
check_data_att <- function(x, y) {
  hardhat::validate_outcomes_are_univariate(y)

  # check matrices/vectors, matrix type, matrix column names
  if (!is.matrix(x) || !is.numeric(x)) {
    cli::cli_abort("'x' should be a numeric matrix.")
  }
  nms <- colnames(x)
  if (length(nms) != ncol(x)) {
    cli::cli_abort("Every column of 'x' should have a name.")
  }
  if (!is.vector(y) && !is.factor(y)) {
    cli::cli_abort("'y' should be a vector.")
  }
  invisible(NULL)
}


# Common prefix for check_* error messages: "fn() expected 'arg'".
format_msg <- function(fn, arg) {
  if (is.null(fn)) {
    fn <- "The function"
  } else {
    fn <- paste0(fn, "()")
  }
  paste0(fn, " expected '", arg, "'")
}

# Returns TRUE if any value of `x` is outside [x_min, x_max]; `incl` says
# whether the lower/upper bound is inclusive.
check_rng <- function(x, x_min, x_max, incl = c(TRUE, TRUE)) {
  if (incl[[1]]) {
    pass_low <- x >= x_min
  } else {
pass_low <- x > x_min 79 | } 80 | if (incl[[2]]) { 81 | pass_high <- x <= x_max 82 | } else { 83 | pass_high <- x < x_max 84 | } 85 | any(!pass_low | !pass_high) 86 | } 87 | 88 | numeric_loss_values <- c("mse", "poisson", "smooth_l1", "l1") 89 | check_regression_loss <- function(loss_function) { 90 | check_character(loss_function, single = TRUE, vals = numeric_loss_values) 91 | 92 | # TODO return a different format 93 | dplyr::case_when( 94 | loss_function == "poisson" ~ "torch::nnf_poisson_nll_loss", 95 | loss_function == "smooth_l1" ~ "torch::nnf_smooth_l1_loss", 96 | loss_function == "l1" ~ "torch::nnf_l1_loss", 97 | TRUE ~ "torch::nnf_mse_loss" 98 | ) 99 | 100 | } 101 | 102 | check_classification_loss <- function(x) { 103 | 104 | } 105 | 106 | check_optimizer <- function(x) { 107 | 108 | } 109 | 110 | 111 | check_integer <- 112 | function(x, 113 | single = TRUE, 114 | x_min = -Inf, x_max = Inf, incl = c(TRUE, TRUE), 115 | fn = NULL) { 116 | cl <- match.call() 117 | arg <- as.character(cl$x) 118 | 119 | if (!is.integer(x)) { 120 | msg <- paste(format_msg(fn, arg), "to be integer.") 121 | cli::cli_abort(msg) 122 | } 123 | 124 | if (single && length(x) > 1) { 125 | msg <- paste(format_msg(fn, arg), "to be a single integer.") 126 | cli::cli_abort(msg) 127 | } 128 | 129 | out_of_range <- check_rng(x, x_min, x_max, incl) 130 | if (any(out_of_range)) { 131 | msg <- paste0(format_msg(fn, arg), 132 | " to be an integer on ", 133 | ifelse(incl[[1]], "[", "("), x_min, ", ", 134 | x_max, ifelse(incl[[2]], "]", ")"), ".") 135 | cli::cli_abort(msg) 136 | } 137 | 138 | invisible(TRUE) 139 | } 140 | 141 | check_double <- function(x, 142 | single = TRUE, 143 | x_min = -Inf, x_max = Inf, incl = c(TRUE, TRUE), 144 | fn = NULL) { 145 | cl <- match.call() 146 | arg <- as.character(cl$x) 147 | 148 | if (!is.double(x)) { 149 | msg <- paste(format_msg(fn, arg), "to be a double.") 150 | cli::cli_abort(msg) 151 | } 152 | 153 | if (single && length(x) > 1) { 154 | msg <- 
paste(format_msg(fn, arg), "to be a single double.")
    cli::cli_abort(msg)
  }

  out_of_range <- check_rng(x, x_min, x_max, incl)
  if (any(out_of_range)) {
    msg <- paste0(format_msg(fn, arg),
                  " to be a double on ",
                  ifelse(incl[[1]], "[", "("), x_min, ", ",
                  x_max, ifelse(incl[[2]], "]", ")"), ".")
    cli::cli_abort(msg)
  }

  invisible(TRUE)
}

# Check that `x` is character (optionally a single string) and, when `vals`
# is given, that every value is from that set.
check_character <- function(x, single = TRUE, vals = NULL, fn = NULL) {
  cl <- match.call()
  arg <- as.character(cl$x)

  if (!is.character(x)) {
    msg <- paste(format_msg(fn, arg), "to be character.")
    cli::cli_abort(msg)
  }

  if (single && length(x) > 1) {
    msg <- paste(format_msg(fn, arg), "to be a single character string.")
    cli::cli_abort(msg)
  }

  if (!is.null(vals)) {
    if (any(!(x %in% vals))) {
      msg <- paste0(format_msg(fn, arg), " contains an incorrect value.")
      cli::cli_abort(msg)
    }
  }

  invisible(TRUE)
}

# Check that `x` is logical (optionally a single value).
check_logical <- function(x, single = TRUE, fn = NULL) {
  cl <- match.call()
  arg <- as.character(cl$x)

  if (!is.logical(x)) {
    msg <- paste(format_msg(fn, arg), "to be logical.")
    cli::cli_abort(msg)
  }

  if (single && length(x) > 1) {
    msg <- paste(format_msg(fn, arg), "to be a single logical.")
    cli::cli_abort(msg)
  }
  invisible(TRUE)
}


# Validate/normalize class weights for a classification model and return
# them as a torch tensor with one weight per factor level (NULL for
# non-factor outcomes). A single value is interpreted as the weight for the
# minority class, as determined from `xtab` (a table of class frequencies).
check_class_weights <- function(wts, lvls, xtab, fn) {
  if (length(lvls) == 0) {
    return(NULL)
  }

  if (is.null(wts)) {
    wts <- rep(1, length(lvls))
    return(torch::torch_tensor(wts))
  }
  if (!is.numeric(wts)) {
    # Message fixed: previously read "to a numeric vector" (missing "be").
    msg <- paste(format_msg(fn, "class_weights"), "to be a numeric vector")
    cli::cli_abort(msg)
  }

  if (length(wts) == 1) {
    val <- wts
    wts <- rep(1, length(lvls))
    minority <- names(xtab)[which.min(xtab)]
wts[lvls == minority] <- val
    names(wts) <- lvls
  }

  if (length(lvls) != length(wts)) {
    msg <- paste0("There were ", length(wts), " class weights given but ",
                  length(lvls), " were expected.")
    cli::cli_abort(msg)
  }

  nms <- names(wts)
  if (is.null(nms)) {
    names(wts) <- lvls
  } else {
    if (!identical(sort(nms), sort(lvls))) {
      msg <- paste("Names for class weights should be:",
                   paste0("'", lvls, "'", collapse = ", "))
      cli::cli_abort(msg)
    }
    # Reorder the named weights to match the factor level order.
    wts <- wts[lvls]
  }


  torch::torch_tensor(wts)
}

--------------------------------------------------------------------------------
/R/coef.R:
--------------------------------------------------------------------------------
# Retrieve the parameter estimates for a given training epoch, defaulting
# to the best epoch found during fitting. Warns and clamps when `epoch`
# exceeds the number of epochs that were fit.
brulee_coefs <- function(object, epoch = NULL, ...) {
  if (!is.null(epoch) && length(epoch) != 1) {
    cli::cli_abort("'epoch' should be a single integer.")
  }
  max_epochs <- length(object$estimates)

  if (is.null(epoch)) {
    epoch <- object$best_epoch
  } else {
    if (epoch > max_epochs) {
      # The argument is `epoch`; the old message referred to 'epochs'.
      msg <- glue::glue("There were only {max_epochs} epochs fit. Setting 'epoch' to {max_epochs}.")
      cli::cli_warn(msg)
      epoch <- max_epochs
    }

  }
  object$estimates[[epoch]]
}


#' Extract Model Coefficients
#'
#' @param object A model fit from \pkg{brulee}.
#' @param epoch A single integer for the training iteration. If left `NULL`,
#'  the estimates from the best model fit (via internal performance metrics)
#'  are used.
#' @param ... Not currently used.
#' @return For logistic/linear regression, a named vector. For neural networks,
#'  a list of arrays.
#' @examples
#' \donttest{
#' if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "modeldata"))) {
#'
#'  data(ames, package = "modeldata")
#'
#'  ames$Sale_Price <- log10(ames$Sale_Price)
#'
#'  set.seed(1)
#'  in_train <- sample(1:nrow(ames), 2000)
#'  ames_train <- ames[ in_train,]
#'  ames_test <- ames[-in_train,]
#'
#'  # Using recipe
#'  library(recipes)
#'
#'  ames_rec <-
#'   recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |>
#'   step_normalize(all_numeric_predictors())
#'
#'  set.seed(2)
#'  fit <- brulee_linear_reg(ames_rec, data = ames_train,
#'                           epochs = 50, batch_size = 32)
#'
#'  coef(fit)
#'  coef(fit, epoch = 1)
#' }
#' }
#' @name brulee-coefs
#' @export
coef.brulee_logistic_reg <- function(object, epoch = NULL, ...) {
  est <- brulee_coefs(object, epoch)
  # The network has one output unit per class; collapse the two units into
  # a single coefficient vector by differencing (second class minus first).
  betas <- est$fc1.weight[2, ] - est$fc1.weight[1, ]
  alpha <- est$fc1.bias[2] - est$fc1.bias[1]
  res <- c(alpha, betas)
  names(res) <- c("(Intercept)", object$dims$features)
  res
}

#' @rdname brulee-coefs
#' @export
coef.brulee_linear_reg <- function(object, epoch = NULL, ...) {
  est <- brulee_coefs(object, epoch)
  # Single output unit: the bias is the intercept, the first weight row
  # holds the slopes.
  res <- c(est$fc1.bias, est$fc1.weight[1, ])
  names(res) <- c("(Intercept)", object$dims$features)
  res
}

#' @rdname brulee-coefs
#' @export
coef.brulee_mlp <- brulee_coefs

#' @rdname brulee-coefs
#' @export
coef.brulee_multinomial_reg <- function(object, epoch = NULL, ...) {
  est <- brulee_coefs(object, epoch)
  # One column of coefficients per outcome level; the first row is the
  # per-class intercept (bias).
  res <- rbind(est$fc1.bias, t(est$fc1.weight))
  rownames(res) <- c("(Intercept)", object$dims$features)
  colnames(res) <- object$dims$levels
  res
}

--------------------------------------------------------------------------------
/R/convert_data.R:
--------------------------------------------------------------------------------
#' Convert data to torch format
#'
#' For an x/y interface, `matrix_to_dataset()` converts the data to proper
#' encodings then formats the results for consumption by `torch`.
#'
#' @param x A numeric matrix of predictors.
#' @param y A vector. If regression then `y` is numeric. For classification, it
#'  is a factor.
#' @return An R6 index sampler object with classes "training_set",
#'  "dataset", and "R6".
#' @details Missing values should be removed before passing data to this function.
#' @examples
#' if (torch::torch_is_installed()) {
#'  matrix_to_dataset(as.matrix(mtcars[, -1]), mtcars$mpg)
#' }
#' @export
matrix_to_dataset <- function(x, y) {
  x <- torch::torch_tensor(x)
  if (is.factor(y)) {
    # Class outcomes become 1-based integer indices stored as long tensors,
    # the type torch expects for classification targets.
    y <- as.numeric(y)
    # Namespace-qualified for consistency with every other torch call in
    # this file (was a bare `torch_long()`).
    y <- torch::torch_tensor(y, dtype = torch::torch_long())
  } else {
    y <- torch::torch_tensor(y)
  }
  torch::tensor_dataset(x = x, y = y)
}

# ------------------------------------------------------------------------------

# Mean/SD of the outcome used to center and scale `y`; aborts when the
# outcome is constant (SD of zero would make scaling undefined).
scale_stats <- function(x) {
  res <- list(mean = mean(x, na.rm = TRUE), sd = stats::sd(x, na.rm = TRUE))
  if (res$sd == 0) {
    cli::cli_abort("There is no variation in `y`.")
  }
  res
}

# Apply the centering/scaling computed by scale_stats().
scale_y <- function(y, stats) {
  (y - stats$mean)/stats$sd
}

--------------------------------------------------------------------------------
/R/import-standalone-obj-type.R:
--------------------------------------------------------------------------------
# Standalone file: do not edit by hand
# Source:
# ----------------------------------------------------------------------
#
# ---
# repo: r-lib/rlang
# file: standalone-obj-type.R
# last-updated: 2024-02-14
# license: https://unlicense.org
# imports: rlang (>= 1.1.0)
# ---
#
# ## Changelog
#
# 2024-02-14:
# - `obj_type_friendly()` now works for S7 objects.
#
# 2023-05-01:
# - `obj_type_friendly()` now only displays the first class of S3 objects.
#
# 2023-03-30:
# - `stop_input_type()` now handles `I()` input literally in `arg`.
#
# 2022-10-04:
# - `obj_type_friendly(value = TRUE)` now shows numeric scalars
#   literally.
# - `stop_friendly_type()` now takes `show_value`, passed to
#   `obj_type_friendly()` as the `value` argument.
#
# 2022-10-03:
# - Added `allow_na` and `allow_null` arguments.
32 | # - `NULL` is now backticked. 33 | # - Better friendly type for infinities and `NaN`. 34 | # 35 | # 2022-09-16: 36 | # - Unprefixed usage of rlang functions with `rlang::` to 37 | # avoid onLoad issues when called from rlang (#1482). 38 | # 39 | # 2022-08-11: 40 | # - Prefixed usage of rlang functions with `rlang::`. 41 | # 42 | # 2022-06-22: 43 | # - `friendly_type_of()` is now `obj_type_friendly()`. 44 | # - Added `obj_type_oo()`. 45 | # 46 | # 2021-12-20: 47 | # - Added support for scalar values and empty vectors. 48 | # - Added `stop_input_type()` 49 | # 50 | # 2021-06-30: 51 | # - Added support for missing arguments. 52 | # 53 | # 2021-04-19: 54 | # - Added support for matrices and arrays (#141). 55 | # - Added documentation. 56 | # - Added changelog. 57 | # 58 | # nocov start 59 | 60 | #' Return English-friendly type 61 | #' @param x Any R object. 62 | #' @param value Whether to describe the value of `x`. Special values 63 | #' like `NA` or `""` are always described. 64 | #' @param length Whether to mention the length of vectors and lists. 65 | #' @return A string describing the type. Starts with an indefinite 66 | #' article, e.g. "an integer vector". 
67 | #' @noRd 68 | obj_type_friendly <- function(x, value = TRUE) { 69 | if (is_missing(x)) { 70 | return("absent") 71 | } 72 | 73 | if (is.object(x)) { 74 | if (inherits(x, "quosure")) { 75 | type <- "quosure" 76 | } else { 77 | type <- class(x)[[1L]] 78 | } 79 | return(sprintf("a <%s> object", type)) 80 | } 81 | 82 | if (!is_vector(x)) { 83 | return(.rlang_as_friendly_type(typeof(x))) 84 | } 85 | 86 | n_dim <- length(dim(x)) 87 | 88 | if (!n_dim) { 89 | if (!is_list(x) && length(x) == 1) { 90 | if (is_na(x)) { 91 | return(switch( 92 | typeof(x), 93 | logical = "`NA`", 94 | integer = "an integer `NA`", 95 | double = 96 | if (is.nan(x)) { 97 | "`NaN`" 98 | } else { 99 | "a numeric `NA`" 100 | }, 101 | complex = "a complex `NA`", 102 | character = "a character `NA`", 103 | .rlang_stop_unexpected_typeof(x) 104 | )) 105 | } 106 | 107 | show_infinites <- function(x) { 108 | if (x > 0) { 109 | "`Inf`" 110 | } else { 111 | "`-Inf`" 112 | } 113 | } 114 | str_encode <- function(x, width = 30, ...) { 115 | if (nchar(x) > width) { 116 | x <- substr(x, 1, width - 3) 117 | x <- paste0(x, "...") 118 | } 119 | encodeString(x, ...) 
120 | } 121 | 122 | if (value) { 123 | if (is.numeric(x) && is.infinite(x)) { 124 | return(show_infinites(x)) 125 | } 126 | 127 | if (is.numeric(x) || is.complex(x)) { 128 | number <- as.character(round(x, 2)) 129 | what <- if (is.complex(x)) "the complex number" else "the number" 130 | return(paste(what, number)) 131 | } 132 | 133 | return(switch( 134 | typeof(x), 135 | logical = if (x) "`TRUE`" else "`FALSE`", 136 | character = { 137 | what <- if (nzchar(x)) "the string" else "the empty string" 138 | paste(what, str_encode(x, quote = "\"")) 139 | }, 140 | raw = paste("the raw value", as.character(x)), 141 | .rlang_stop_unexpected_typeof(x) 142 | )) 143 | } 144 | 145 | return(switch( 146 | typeof(x), 147 | logical = "a logical value", 148 | integer = "an integer", 149 | double = if (is.infinite(x)) show_infinites(x) else "a number", 150 | complex = "a complex number", 151 | character = if (nzchar(x)) "a string" else "\"\"", 152 | raw = "a raw value", 153 | .rlang_stop_unexpected_typeof(x) 154 | )) 155 | } 156 | 157 | if (length(x) == 0) { 158 | return(switch( 159 | typeof(x), 160 | logical = "an empty logical vector", 161 | integer = "an empty integer vector", 162 | double = "an empty numeric vector", 163 | complex = "an empty complex vector", 164 | character = "an empty character vector", 165 | raw = "an empty raw vector", 166 | list = "an empty list", 167 | .rlang_stop_unexpected_typeof(x) 168 | )) 169 | } 170 | } 171 | 172 | vec_type_friendly(x) 173 | } 174 | 175 | vec_type_friendly <- function(x, length = FALSE) { 176 | if (!is_vector(x)) { 177 | abort("`x` must be a vector.") 178 | } 179 | type <- typeof(x) 180 | n_dim <- length(dim(x)) 181 | 182 | add_length <- function(type) { 183 | if (length && !n_dim) { 184 | paste0(type, sprintf(" of length %s", length(x))) 185 | } else { 186 | type 187 | } 188 | } 189 | 190 | if (type == "list") { 191 | if (n_dim < 2) { 192 | return(add_length("a list")) 193 | } else if (is.data.frame(x)) { 194 | return("a data frame") 
195 | } else if (n_dim == 2) { 196 | return("a list matrix") 197 | } else { 198 | return("a list array") 199 | } 200 | } 201 | 202 | type <- switch( 203 | type, 204 | logical = "a logical %s", 205 | integer = "an integer %s", 206 | numeric = , 207 | double = "a double %s", 208 | complex = "a complex %s", 209 | character = "a character %s", 210 | raw = "a raw %s", 211 | type = paste0("a ", type, " %s") 212 | ) 213 | 214 | if (n_dim < 2) { 215 | kind <- "vector" 216 | } else if (n_dim == 2) { 217 | kind <- "matrix" 218 | } else { 219 | kind <- "array" 220 | } 221 | out <- sprintf(type, kind) 222 | 223 | if (n_dim >= 2) { 224 | out 225 | } else { 226 | add_length(out) 227 | } 228 | } 229 | 230 | .rlang_as_friendly_type <- function(type) { 231 | switch( 232 | type, 233 | 234 | list = "a list", 235 | 236 | NULL = "`NULL`", 237 | environment = "an environment", 238 | externalptr = "a pointer", 239 | weakref = "a weak reference", 240 | S4 = "an S4 object", 241 | 242 | name = , 243 | symbol = "a symbol", 244 | language = "a call", 245 | pairlist = "a pairlist node", 246 | expression = "an expression vector", 247 | 248 | char = "an internal string", 249 | promise = "an internal promise", 250 | ... = "an internal dots object", 251 | any = "an internal `any` object", 252 | bytecode = "an internal bytecode object", 253 | 254 | primitive = , 255 | builtin = , 256 | special = "a primitive function", 257 | closure = "a function", 258 | 259 | type 260 | ) 261 | } 262 | 263 | .rlang_stop_unexpected_typeof <- function(x, call = caller_env()) { 264 | abort( 265 | sprintf("Unexpected type <%s>.", typeof(x)), 266 | call = call 267 | ) 268 | } 269 | 270 | #' Return OO type 271 | #' @param x Any R object. 272 | #' @return One of `"bare"` (for non-OO objects), `"S3"`, `"S4"`, 273 | #' `"R6"`, or `"S7"`. 
274 | #' @noRd 275 | obj_type_oo <- function(x) { 276 | if (!is.object(x)) { 277 | return("bare") 278 | } 279 | 280 | class <- inherits(x, c("R6", "S7_object"), which = TRUE) 281 | 282 | if (class[[1]]) { 283 | "R6" 284 | } else if (class[[2]]) { 285 | "S7" 286 | } else if (isS4(x)) { 287 | "S4" 288 | } else { 289 | "S3" 290 | } 291 | } 292 | 293 | #' @param x The object type which does not conform to `what`. Its 294 | #' `obj_type_friendly()` is taken and mentioned in the error message. 295 | #' @param what The friendly expected type as a string. Can be a 296 | #' character vector of expected types, in which case the error 297 | #' message mentions all of them in an "or" enumeration. 298 | #' @param show_value Passed to `value` argument of `obj_type_friendly()`. 299 | #' @param ... Arguments passed to [abort()]. 300 | #' @inheritParams args_error_context 301 | #' @noRd 302 | stop_input_type <- function(x, 303 | what, 304 | ..., 305 | allow_na = FALSE, 306 | allow_null = FALSE, 307 | show_value = TRUE, 308 | arg = caller_arg(x), 309 | call = caller_env()) { 310 | # From standalone-cli.R 311 | cli <- env_get_list( 312 | nms = c("format_arg", "format_code"), 313 | last = topenv(), 314 | default = function(x) sprintf("`%s`", x), 315 | inherit = TRUE 316 | ) 317 | 318 | if (allow_na) { 319 | what <- c(what, cli$format_code("NA")) 320 | } 321 | if (allow_null) { 322 | what <- c(what, cli$format_code("NULL")) 323 | } 324 | if (length(what)) { 325 | what <- oxford_comma(what) 326 | } 327 | if (inherits(arg, "AsIs")) { 328 | format_arg <- identity 329 | } else { 330 | format_arg <- cli$format_arg 331 | } 332 | 333 | message <- sprintf( 334 | "%s must be %s, not %s.", 335 | format_arg(arg), 336 | what, 337 | obj_type_friendly(x, value = show_value) 338 | ) 339 | 340 | abort(message, ..., call = call, arg = arg) 341 | } 342 | 343 | oxford_comma <- function(chr, sep = ", ", final = "or") { 344 | n <- length(chr) 345 | 346 | if (n < 2) { 347 | return(chr) 348 | } 349 | 350 | 
head <- chr[seq_len(n - 1)] 351 | last <- chr[n] 352 | 353 | head <- paste(head, collapse = sep) 354 | 355 | # Write a or b. But a, b, or c. 356 | if (n > 2) { 357 | paste0(head, sep, final, " ", last) 358 | } else { 359 | paste0(head, " ", final, " ", last) 360 | } 361 | } 362 | 363 | # nocov end 364 | -------------------------------------------------------------------------------- /R/import-standalone-types-check.R: -------------------------------------------------------------------------------- 1 | # Standalone file: do not edit by hand 2 | # Source: 3 | # ---------------------------------------------------------------------- 4 | # 5 | # --- 6 | # repo: r-lib/rlang 7 | # file: standalone-types-check.R 8 | # last-updated: 2023-03-13 9 | # license: https://unlicense.org 10 | # dependencies: standalone-obj-type.R 11 | # imports: rlang (>= 1.1.0) 12 | # --- 13 | # 14 | # ## Changelog 15 | # 16 | # 2024-08-15: 17 | # - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724) 18 | # 19 | # 2023-03-13: 20 | # - Improved error messages of number checkers (@teunbrand) 21 | # - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich). 22 | # - Added `check_data_frame()` (@mgirlich). 23 | # 24 | # 2023-03-07: 25 | # - Added dependency on rlang (>= 1.1.0). 26 | # 27 | # 2023-02-15: 28 | # - Added `check_logical()`. 29 | # 30 | # - `check_bool()`, `check_number_whole()`, and 31 | # `check_number_decimal()` are now implemented in C. 32 | # 33 | # - For efficiency, `check_number_whole()` and 34 | # `check_number_decimal()` now take a `NULL` default for `min` and 35 | # `max`. This makes it possible to bypass unnecessary type-checking 36 | # and comparisons in the default case of no bounds checks. 37 | # 38 | # 2022-10-07: 39 | # - `check_number_whole()` and `_decimal()` no longer treat 40 | # non-numeric types such as factors or dates as numbers. Numeric 41 | # types are detected with `is.numeric()`. 
42 | # 43 | # 2022-10-04: 44 | # - Added `check_name()` that forbids the empty string. 45 | # `check_string()` allows the empty string by default. 46 | # 47 | # 2022-09-28: 48 | # - Removed `what` arguments. 49 | # - Added `allow_na` and `allow_null` arguments. 50 | # - Added `allow_decimal` and `allow_infinite` arguments. 51 | # - Improved errors with absent arguments. 52 | # 53 | # 54 | # 2022-09-16: 55 | # - Unprefixed usage of rlang functions with `rlang::` to 56 | # avoid onLoad issues when called from rlang (#1482). 57 | # 58 | # 2022-08-11: 59 | # - Added changelog. 60 | # 61 | # nocov start 62 | 63 | # Scalars ----------------------------------------------------------------- 64 | 65 | .standalone_types_check_dot_call <- .Call 66 | 67 | check_bool <- function(x, 68 | ..., 69 | allow_na = FALSE, 70 | allow_null = FALSE, 71 | arg = caller_arg(x), 72 | call = caller_env()) { 73 | if (!missing(x) && .standalone_types_check_dot_call(ffi_standalone_is_bool_1.0.7, x, allow_na, allow_null)) { 74 | return(invisible(NULL)) 75 | } 76 | 77 | stop_input_type( 78 | x, 79 | c("`TRUE`", "`FALSE`"), 80 | ..., 81 | allow_na = allow_na, 82 | allow_null = allow_null, 83 | arg = arg, 84 | call = call 85 | ) 86 | } 87 | 88 | check_string <- function(x, 89 | ..., 90 | allow_empty = TRUE, 91 | allow_na = FALSE, 92 | allow_null = FALSE, 93 | arg = caller_arg(x), 94 | call = caller_env()) { 95 | if (!missing(x)) { 96 | is_string <- .rlang_check_is_string( 97 | x, 98 | allow_empty = allow_empty, 99 | allow_na = allow_na, 100 | allow_null = allow_null 101 | ) 102 | if (is_string) { 103 | return(invisible(NULL)) 104 | } 105 | } 106 | 107 | stop_input_type( 108 | x, 109 | "a single string", 110 | ..., 111 | allow_na = allow_na, 112 | allow_null = allow_null, 113 | arg = arg, 114 | call = call 115 | ) 116 | } 117 | 118 | .rlang_check_is_string <- function(x, 119 | allow_empty, 120 | allow_na, 121 | allow_null) { 122 | if (is_string(x)) { 123 | if (allow_empty || !is_string(x, "")) { 124 | 
return(TRUE) 125 | } 126 | } 127 | 128 | if (allow_null && is_null(x)) { 129 | return(TRUE) 130 | } 131 | 132 | if (allow_na && (identical(x, NA) || identical(x, na_chr))) { 133 | return(TRUE) 134 | } 135 | 136 | FALSE 137 | } 138 | 139 | check_name <- function(x, 140 | ..., 141 | allow_null = FALSE, 142 | arg = caller_arg(x), 143 | call = caller_env()) { 144 | if (!missing(x)) { 145 | is_string <- .rlang_check_is_string( 146 | x, 147 | allow_empty = FALSE, 148 | allow_na = FALSE, 149 | allow_null = allow_null 150 | ) 151 | if (is_string) { 152 | return(invisible(NULL)) 153 | } 154 | } 155 | 156 | stop_input_type( 157 | x, 158 | "a valid name", 159 | ..., 160 | allow_na = FALSE, 161 | allow_null = allow_null, 162 | arg = arg, 163 | call = call 164 | ) 165 | } 166 | 167 | IS_NUMBER_true <- 0 168 | IS_NUMBER_false <- 1 169 | IS_NUMBER_oob <- 2 170 | 171 | check_number_decimal <- function(x, 172 | ..., 173 | min = NULL, 174 | max = NULL, 175 | allow_infinite = TRUE, 176 | allow_na = FALSE, 177 | allow_null = FALSE, 178 | arg = caller_arg(x), 179 | call = caller_env()) { 180 | if (missing(x)) { 181 | exit_code <- IS_NUMBER_false 182 | } else if (0 == (exit_code <- .standalone_types_check_dot_call( 183 | ffi_standalone_check_number_1.0.7, 184 | x, 185 | allow_decimal = TRUE, 186 | min, 187 | max, 188 | allow_infinite, 189 | allow_na, 190 | allow_null 191 | ))) { 192 | return(invisible(NULL)) 193 | } 194 | 195 | .stop_not_number( 196 | x, 197 | ..., 198 | exit_code = exit_code, 199 | allow_decimal = TRUE, 200 | min = min, 201 | max = max, 202 | allow_na = allow_na, 203 | allow_null = allow_null, 204 | arg = arg, 205 | call = call 206 | ) 207 | } 208 | 209 | check_number_whole <- function(x, 210 | ..., 211 | min = NULL, 212 | max = NULL, 213 | allow_infinite = FALSE, 214 | allow_na = FALSE, 215 | allow_null = FALSE, 216 | arg = caller_arg(x), 217 | call = caller_env()) { 218 | if (missing(x)) { 219 | exit_code <- IS_NUMBER_false 220 | } else if (0 == (exit_code <- 
.standalone_types_check_dot_call( 221 | ffi_standalone_check_number_1.0.7, 222 | x, 223 | allow_decimal = FALSE, 224 | min, 225 | max, 226 | allow_infinite, 227 | allow_na, 228 | allow_null 229 | ))) { 230 | return(invisible(NULL)) 231 | } 232 | 233 | .stop_not_number( 234 | x, 235 | ..., 236 | exit_code = exit_code, 237 | allow_decimal = FALSE, 238 | min = min, 239 | max = max, 240 | allow_na = allow_na, 241 | allow_null = allow_null, 242 | arg = arg, 243 | call = call 244 | ) 245 | } 246 | 247 | .stop_not_number <- function(x, 248 | ..., 249 | exit_code, 250 | allow_decimal, 251 | min, 252 | max, 253 | allow_na, 254 | allow_null, 255 | arg, 256 | call) { 257 | if (allow_decimal) { 258 | what <- "a number" 259 | } else { 260 | what <- "a whole number" 261 | } 262 | 263 | if (exit_code == IS_NUMBER_oob) { 264 | min <- min %||% -Inf 265 | max <- max %||% Inf 266 | 267 | if (min > -Inf && max < Inf) { 268 | what <- sprintf("%s between %s and %s", what, min, max) 269 | } else if (x < min) { 270 | what <- sprintf("%s larger than or equal to %s", what, min) 271 | } else if (x > max) { 272 | what <- sprintf("%s smaller than or equal to %s", what, max) 273 | } else { 274 | abort("Unexpected state in OOB check", .internal = TRUE) 275 | } 276 | } 277 | 278 | stop_input_type( 279 | x, 280 | what, 281 | ..., 282 | allow_na = allow_na, 283 | allow_null = allow_null, 284 | arg = arg, 285 | call = call 286 | ) 287 | } 288 | 289 | check_symbol <- function(x, 290 | ..., 291 | allow_null = FALSE, 292 | arg = caller_arg(x), 293 | call = caller_env()) { 294 | if (!missing(x)) { 295 | if (is_symbol(x)) { 296 | return(invisible(NULL)) 297 | } 298 | if (allow_null && is_null(x)) { 299 | return(invisible(NULL)) 300 | } 301 | } 302 | 303 | stop_input_type( 304 | x, 305 | "a symbol", 306 | ..., 307 | allow_na = FALSE, 308 | allow_null = allow_null, 309 | arg = arg, 310 | call = call 311 | ) 312 | } 313 | 314 | check_arg <- function(x, 315 | ..., 316 | allow_null = FALSE, 317 | arg = 
caller_arg(x), 318 | call = caller_env()) { 319 | if (!missing(x)) { 320 | if (is_symbol(x)) { 321 | return(invisible(NULL)) 322 | } 323 | if (allow_null && is_null(x)) { 324 | return(invisible(NULL)) 325 | } 326 | } 327 | 328 | stop_input_type( 329 | x, 330 | "an argument name", 331 | ..., 332 | allow_na = FALSE, 333 | allow_null = allow_null, 334 | arg = arg, 335 | call = call 336 | ) 337 | } 338 | 339 | check_call <- function(x, 340 | ..., 341 | allow_null = FALSE, 342 | arg = caller_arg(x), 343 | call = caller_env()) { 344 | if (!missing(x)) { 345 | if (is_call(x)) { 346 | return(invisible(NULL)) 347 | } 348 | if (allow_null && is_null(x)) { 349 | return(invisible(NULL)) 350 | } 351 | } 352 | 353 | stop_input_type( 354 | x, 355 | "a defused call", 356 | ..., 357 | allow_na = FALSE, 358 | allow_null = allow_null, 359 | arg = arg, 360 | call = call 361 | ) 362 | } 363 | 364 | check_environment <- function(x, 365 | ..., 366 | allow_null = FALSE, 367 | arg = caller_arg(x), 368 | call = caller_env()) { 369 | if (!missing(x)) { 370 | if (is_environment(x)) { 371 | return(invisible(NULL)) 372 | } 373 | if (allow_null && is_null(x)) { 374 | return(invisible(NULL)) 375 | } 376 | } 377 | 378 | stop_input_type( 379 | x, 380 | "an environment", 381 | ..., 382 | allow_na = FALSE, 383 | allow_null = allow_null, 384 | arg = arg, 385 | call = call 386 | ) 387 | } 388 | 389 | check_function <- function(x, 390 | ..., 391 | allow_null = FALSE, 392 | arg = caller_arg(x), 393 | call = caller_env()) { 394 | if (!missing(x)) { 395 | if (is_function(x)) { 396 | return(invisible(NULL)) 397 | } 398 | if (allow_null && is_null(x)) { 399 | return(invisible(NULL)) 400 | } 401 | } 402 | 403 | stop_input_type( 404 | x, 405 | "a function", 406 | ..., 407 | allow_na = FALSE, 408 | allow_null = allow_null, 409 | arg = arg, 410 | call = call 411 | ) 412 | } 413 | 414 | check_closure <- function(x, 415 | ..., 416 | allow_null = FALSE, 417 | arg = caller_arg(x), 418 | call = caller_env()) { 419 | 
if (!missing(x)) { 420 | if (is_closure(x)) { 421 | return(invisible(NULL)) 422 | } 423 | if (allow_null && is_null(x)) { 424 | return(invisible(NULL)) 425 | } 426 | } 427 | 428 | stop_input_type( 429 | x, 430 | "an R function", 431 | ..., 432 | allow_na = FALSE, 433 | allow_null = allow_null, 434 | arg = arg, 435 | call = call 436 | ) 437 | } 438 | 439 | check_formula <- function(x, 440 | ..., 441 | allow_null = FALSE, 442 | arg = caller_arg(x), 443 | call = caller_env()) { 444 | if (!missing(x)) { 445 | if (is_formula(x)) { 446 | return(invisible(NULL)) 447 | } 448 | if (allow_null && is_null(x)) { 449 | return(invisible(NULL)) 450 | } 451 | } 452 | 453 | stop_input_type( 454 | x, 455 | "a formula", 456 | ..., 457 | allow_na = FALSE, 458 | allow_null = allow_null, 459 | arg = arg, 460 | call = call 461 | ) 462 | } 463 | 464 | 465 | # Vectors ----------------------------------------------------------------- 466 | 467 | # TODO: Figure out what to do with logical `NA` and `allow_na = TRUE` 468 | 469 | check_character <- function(x, 470 | ..., 471 | allow_na = TRUE, 472 | allow_null = FALSE, 473 | arg = caller_arg(x), 474 | call = caller_env()) { 475 | 476 | if (!missing(x)) { 477 | if (is_character(x)) { 478 | if (!allow_na && any(is.na(x))) { 479 | abort( 480 | sprintf("`%s` can't contain NA values.", arg), 481 | arg = arg, 482 | call = call 483 | ) 484 | } 485 | 486 | return(invisible(NULL)) 487 | } 488 | 489 | if (allow_null && is_null(x)) { 490 | return(invisible(NULL)) 491 | } 492 | } 493 | 494 | stop_input_type( 495 | x, 496 | "a character vector", 497 | ..., 498 | allow_null = allow_null, 499 | arg = arg, 500 | call = call 501 | ) 502 | } 503 | 504 | # check_logical <- function(x, 505 | # ..., 506 | # allow_null = FALSE, 507 | # arg = caller_arg(x), 508 | # call = caller_env()) { 509 | # if (!missing(x)) { 510 | # if (is_logical(x)) { 511 | # return(invisible(NULL)) 512 | # } 513 | # if (allow_null && is_null(x)) { 514 | # return(invisible(NULL)) 515 | # } 
516 | # } 517 | # 518 | # stop_input_type( 519 | # x, 520 | # "a logical vector", 521 | # ..., 522 | # allow_na = FALSE, 523 | # allow_null = allow_null, 524 | # arg = arg, 525 | # call = call 526 | # ) 527 | # } 528 | 529 | check_data_frame <- function(x, 530 | ..., 531 | allow_null = FALSE, 532 | arg = caller_arg(x), 533 | call = caller_env()) { 534 | if (!missing(x)) { 535 | if (is.data.frame(x)) { 536 | return(invisible(NULL)) 537 | } 538 | if (allow_null && is_null(x)) { 539 | return(invisible(NULL)) 540 | } 541 | } 542 | 543 | stop_input_type( 544 | x, 545 | "a data frame", 546 | ..., 547 | allow_null = allow_null, 548 | arg = arg, 549 | call = call 550 | ) 551 | } 552 | 553 | # nocov end 554 | -------------------------------------------------------------------------------- /R/linear_reg-predict.R: -------------------------------------------------------------------------------- 1 | #' Predict from a `brulee_linear_reg` 2 | #' 3 | #' @inheritParams predict.brulee_mlp 4 | #' @param object A `brulee_linear_reg` object. 5 | #' @param type A single character. The type of predictions to generate. 6 | #' Valid options are: 7 | #' 8 | #' - `"numeric"` for numeric predictions. 9 | #' 10 | #' @return 11 | #' 12 | #' A tibble of predictions. The number of rows in the tibble is guaranteed 13 | #' to be the same as the number of rows in `new_data`. 
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed() & rlang::is_installed("recipes")) {
#'
#'  data(ames, package = "modeldata")
#'
#'  ames$Sale_Price <- log10(ames$Sale_Price)
#'
#'  set.seed(1)
#'  in_train <- sample(1:nrow(ames), 2000)
#'  ames_train <- ames[ in_train,]
#'  ames_test  <- ames[-in_train,]
#'
#'  # Using recipe
#'  library(recipes)
#'
#'  ames_rec <-
#'   recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |>
#'   step_normalize(all_numeric_predictors())
#'
#'  set.seed(2)
#'  fit <- brulee_linear_reg(ames_rec, data = ames_train,
#'                           epochs = 50, batch_size = 32)
#'
#'  predict(fit, ames_test)
#' }
#' }
#' @export
predict.brulee_linear_reg <- function(object, new_data, type = NULL, epoch = NULL, ...) {
  forged <- hardhat::forge(new_data, object$blueprint)
  type <- check_type(object, type)
  if (is.null(epoch)) {
    # default to the epoch with the smallest loss
    epoch <- object$best_epoch
  }
  predict_brulee_linear_reg_bridge(type, object, forged$predictors, epoch = epoch)
}

# ------------------------------------------------------------------------------
# Bridge

predict_brulee_linear_reg_bridge <- function(type, model, predictors, epoch) {

  if (!is.matrix(predictors)) {
    predictors <- as.matrix(predictors)
    if (is.character(predictors)) {
      cli::cli_abort(
        paste(
          "There were some non-numeric columns in the predictors.",
          "Please use a formula or recipe to encode all of the predictors as numeric."
        )
      )
    }
  }

  predict_function <- get_linear_reg_predict_function(type)

  max_epoch <- length(model$estimates)
  if (epoch > max_epoch) {
    msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
                 "be made at epoch", epoch, "so last epoch is used.")
    cli::cli_warn(msg)
    # Actually fall back to the last fitted epoch, as the warning promises;
    # otherwise `model$estimates[[epoch]]` fails with a subscript error.
    epoch <- max_epoch
  }

  predictions <- predict_function(model, predictors, epoch)
  hardhat::validate_prediction_size(predictions, predictors)
  predictions
}

get_linear_reg_predict_function <- function(type) {
  # Only numeric predictions exist for linear regression, so `type` is ignored.
  predict_brulee_linear_reg_numeric
}

# ------------------------------------------------------------------------------
# Implementation

predict_brulee_linear_reg_raw <- function(model, predictors, epoch) {
  # convert from raw format
  module <- revive_model(model$model_obj)
  # get the model parameters for the requested epoch
  estimates <- model$estimates[[epoch]]
  # convert to torch representation
  estimates <- lapply(estimates, torch::torch_tensor)
  # stuff back into the model
  module$load_state_dict(estimates)
  # put the model in evaluation mode
  module$eval()
  predictions <- module(torch::torch_tensor(predictors))
  predictions <- as.array(predictions)
  # torch doesn't have a NA type so it returns NaN
  predictions[is.nan(predictions)] <- NA
  predictions
}

predict_brulee_linear_reg_numeric <- function(model, predictors, epoch) {
  predictions <- predict_brulee_linear_reg_raw(model, predictors, epoch)
  # predictions are in standardized units; map back to the outcome scale
  predictions <- predictions * model$y_stats$sd + model$y_stats$mean
  hardhat::spruce_numeric(predictions[, 1])
}

# ------------------------------------------------------------------------------
# /R/logistic_reg-predict.R
# ------------------------------------------------------------------------------

#' Predict from a `brulee_logistic_reg`
#'
#' @inheritParams predict.brulee_mlp
#' @param object A `brulee_logistic_reg` object.
#' @param type A single character. The type of predictions to generate.
#' Valid options are:
#'
#'  - `"class"` for hard class predictions
#'  - `"prob"` for soft class predictions (i.e., class probabilities)
#'
#' @return
#'
#' A tibble of predictions. The number of rows in the tibble is guaranteed
#' to be the same as the number of rows in `new_data`.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) {
#'
#'  library(recipes)
#'  library(yardstick)
#'
#'  data(penguins, package = "modeldata")
#'
#'  penguins <- penguins |> na.omit()
#'
#'  set.seed(122)
#'  in_train <- sample(1:nrow(penguins), 200)
#'  penguins_train <- penguins[ in_train,]
#'  penguins_test  <- penguins[-in_train,]
#'
#'  rec <- recipe(sex ~ ., data = penguins_train) |>
#'   step_dummy(all_nominal_predictors()) |>
#'   step_normalize(all_numeric_predictors())
#'
#'  set.seed(3)
#'  fit <- brulee_logistic_reg(rec, data = penguins_train, epochs = 5)
#'  fit
#'
#'  predict(fit, penguins_test)
#'
#'  predict(fit, penguins_test, type = "prob") |>
#'   bind_cols(penguins_test) |>
#'   roc_curve(sex, .pred_female) |>
#'   autoplot()
#'
#' }
#' }
#' @export
predict.brulee_logistic_reg <- function(object, new_data, type = NULL, epoch = NULL, ...) {
  forged <- hardhat::forge(new_data, object$blueprint)
  type <- check_type(object, type)
  if (is.null(epoch)) {
    # default to the epoch with the smallest loss
    epoch <- object$best_epoch
  }
  predict_brulee_logistic_reg_bridge(type, object, forged$predictors, epoch = epoch)
}

# ------------------------------------------------------------------------------
# Bridge

predict_brulee_logistic_reg_bridge <- function(type, model, predictors, epoch) {

  if (!is.matrix(predictors)) {
    predictors <- as.matrix(predictors)
    if (is.character(predictors)) {
      cli::cli_abort(
        paste(
          "There were some non-numeric columns in the predictors.",
          "Please use a formula or recipe to encode all of the predictors as numeric."
        )
      )
    }
  }

  predict_function <- get_logistic_reg_predict_function(type)

  max_epoch <- length(model$estimates)
  if (epoch > max_epoch) {
    msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
                 "be made at epoch", epoch, "so last epoch is used.")
    cli::cli_warn(msg)
    # Actually fall back to the last fitted epoch, as the warning promises;
    # otherwise `model$estimates[[epoch]]` fails with a subscript error.
    epoch <- max_epoch
  }

  predictions <- predict_function(model, predictors, epoch)
  hardhat::validate_prediction_size(predictions, predictors)
  predictions
}

get_logistic_reg_predict_function <- function(type) {
  # `type` has already been validated by check_type()
  switch(
    type,
    prob = predict_brulee_logistic_reg_prob,
    class = predict_brulee_logistic_reg_class
  )
}

# ------------------------------------------------------------------------------
# Implementation

predict_brulee_logistic_reg_raw <- function(model, predictors, epoch) {
  # convert from raw format
  module <- revive_model(model$model_obj)
  # get the model parameters for the requested epoch
  estimates <- model$estimates[[epoch]]
  # convert to torch representation
  estimates <- lapply(estimates, torch::torch_tensor)
  # stuff back into the model
  module$load_state_dict(estimates)
  # put the model in evaluation mode
  module$eval()
  predictions <- module(torch::torch_tensor(predictors))
  predictions <- as.array(predictions)
  # torch doesn't have a NA type so it returns NaN
  predictions[is.nan(predictions)] <- NA
  predictions
}

predict_brulee_logistic_reg_prob <- function(model, predictors, epoch) {
  predictions <- predict_brulee_logistic_reg_raw(model, predictors, epoch)
  lvs <- get_levels(model)
  hardhat::spruce_prob(pred_levels = lvs, predictions)
}

predict_brulee_logistic_reg_class <- function(model, predictors, epoch) {
  predictions <- predict_brulee_logistic_reg_raw(model, predictors, epoch)
  predictions <- apply(predictions, 1, which.max2) # take the maximum value
  lvs <- get_levels(model)
  hardhat::spruce_class(factor(lvs[predictions], levels = lvs))
}

# ------------------------------------------------------------------------------
# /R/mlp-predict.R
# ------------------------------------------------------------------------------

#' Predict from a `brulee_mlp`
#'
#' @param object A `brulee_mlp` object.
#'
#' @param new_data A data frame or matrix of new predictors.
#' @param epoch An integer for the epoch to make predictions. If this value
#' is larger than the maximum number that was fit, a warning is issued and the
#' parameters from the last epoch are used. If left `NULL`, the epoch
#' associated with the smallest loss is used.
#' @param type A single character. The type of predictions to generate.
#' Valid options are:
#'
#'  - `"numeric"` for numeric predictions.
#'  - `"class"` for hard class predictions
#'  - `"prob"` for soft class predictions (i.e., class probabilities)
#'
#' @param ... Not used, but required for extensibility.
#'
#' @return
#'
#' A tibble of predictions.
#' The number of rows in the tibble is guaranteed
#' to be the same as the number of rows in `new_data`.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "modeldata"))) {
#'  # regression example:
#'
#'  data(ames, package = "modeldata")
#'
#'  ames$Sale_Price <- log10(ames$Sale_Price)
#'
#'  set.seed(1)
#'  in_train <- sample(1:nrow(ames), 2000)
#'  ames_train <- ames[ in_train,]
#'  ames_test  <- ames[-in_train,]
#'
#'  # Using recipe
#'  library(recipes)
#'
#'  ames_rec <-
#'   recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |>
#'   step_normalize(all_numeric_predictors())
#'
#'  set.seed(2)
#'  fit <- brulee_mlp(ames_rec, data = ames_train, epochs = 50, batch_size = 32)
#'
#'  predict(fit, ames_test)
#' }
#' }
#' @export
predict.brulee_mlp <- function(object, new_data, type = NULL, epoch = NULL, ...) {
  forged <- hardhat::forge(new_data, object$blueprint)
  type <- check_type(object, type)
  if (is.null(epoch)) {
    # default to the epoch with the smallest loss
    epoch <- object$best_epoch
  }
  predict_brulee_mlp_bridge(type, object, forged$predictors, epoch = epoch)
}

# ------------------------------------------------------------------------------
# Bridge

predict_brulee_mlp_bridge <- function(type, model, predictors, epoch) {

  if (!is.matrix(predictors)) {
    predictors <- as.matrix(predictors)
    if (is.character(predictors)) {
      cli::cli_abort(
        paste(
          "There were some non-numeric columns in the predictors.",
          "Please use a formula or recipe to encode all of the predictors as numeric."
        )
      )
    }
  }

  predict_function <- get_mlp_predict_function(type)

  max_epoch <- length(model$estimates)
  if (epoch > max_epoch) {
    msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
                 "be made at epoch", epoch, "so last epoch is used.")
    cli::cli_warn(msg)
    # Actually fall back to the last fitted epoch, as the warning promises;
    # otherwise `model$estimates[[epoch]]` fails with a subscript error.
    epoch <- max_epoch
  }

  predictions <- predict_function(model, predictors, epoch)
  hardhat::validate_prediction_size(predictions, predictors)
  predictions
}

get_mlp_predict_function <- function(type) {
  # `type` has already been validated by check_type()
  switch(
    type,
    numeric = predict_brulee_mlp_numeric,
    prob = predict_brulee_mlp_prob,
    class = predict_brulee_mlp_class
  )
}

# ------------------------------------------------------------------------------
# Implementation

# Prepend a column of ones (an intercept) to a matrix; coerces non-arrays first.
add_intercept <- function(x) {
  if (!is.array(x)) {
    x <- as.array(x)
  }
  cbind(rep(1, nrow(x)), x)
}

# Deserialize a torch module that was stored as a raw vector.
revive_model <- function(model) {
  con <- rawConnection(model)
  on.exit({close(con)}, add = TRUE)
  module <- torch::torch_load(con)
  module
}

predict_brulee_mlp_raw <- function(model, predictors, epoch) {
  # convert from raw format
  module <- revive_model(model$model_obj)
  # get the model parameters for the requested epoch
  estimates <- model$estimates[[epoch]]
  # convert to torch representation
  estimates <- lapply(estimates, torch::torch_tensor)

  # stuff back into the model
  module$load_state_dict(estimates)
  module$eval() # put the model in evaluation mode
  predictions <- module(torch::torch_tensor(predictors))
  predictions <- as.array(predictions)
  # torch doesn't have a NA type so it returns NaN
  predictions[is.nan(predictions)] <- NA
  predictions
}

predict_brulee_mlp_numeric <- function(model, predictors, epoch) {
  predictions <- predict_brulee_mlp_raw(model, predictors, epoch)
  # predictions are in standardized units; map back to the outcome scale
  predictions <- predictions * model$y_stats$sd + model$y_stats$mean
  hardhat::spruce_numeric(predictions[, 1])
}

predict_brulee_mlp_prob <- function(model, predictors, epoch) {
  predictions <- predict_brulee_mlp_raw(model, predictors, epoch)
  lvs <- get_levels(model)
  hardhat::spruce_prob(pred_levels = lvs, predictions)
}

predict_brulee_mlp_class <- function(model, predictors, epoch) {
  predictions <- predict_brulee_mlp_raw(model, predictors, epoch)
  predictions <- apply(predictions, 1, which.max2) # take the maximum value
  lvs <- get_levels(model)
  hardhat::spruce_class(factor(lvs[predictions], levels = lvs))
}

# a which.max alternative that returns NA if any
# value is NA
which.max2 <- function(x) {
  if (any(is.na(x))) {
    NA
  } else {
    which.max(x)
  }
}

# get outcome factor levels from a model object
get_levels <- function(model) {
  # Assumes univariate models
  levels(model$blueprint$ptypes$outcomes[[1]])
}

# The prediction types supported across the brulee models.
valid_predict_types <- function() {
  c("numeric", "prob", "class")
}

# Resolve and validate the prediction `type` against the outcome's class:
# a NULL `type` defaults to "class" for factor outcomes and "numeric" for
# numeric outcomes; mismatched combinations raise an error.
check_type <- function(model, type) {

  outcome_ptype <- model$blueprint$ptypes$outcomes[[1]]

  if (is.null(type)) {
    if (is.factor(outcome_ptype))
      type <- "class"
    else if (is.numeric(outcome_ptype))
      type <- "numeric"
    else
      cli::cli_abort(glue::glue("Unknown outcome type '{class(outcome_ptype)}'"))
  }

  type <- rlang::arg_match(type, valid_predict_types())

  if (is.factor(outcome_ptype)) {
    if (!type %in% c("prob", "class"))
      cli::cli_abort(glue::glue("Outcome is factor and the prediction type is '{type}'."))
  } else if (is.numeric(outcome_ptype)) {
    if (type != "numeric")
      cli::cli_abort(glue::glue("Outcome is numeric and the prediction type is '{type}'."))
  }

  type
}
# ------------------------------------------------------------------------------
# /R/multinomial_reg-predict.R
# ------------------------------------------------------------------------------

#' Predict from a `brulee_multinomial_reg`
#'
#' @inheritParams predict.brulee_mlp
#' @param object A `brulee_multinomial_reg` object.
#' @param type A single character. The type of predictions to generate.
#' Valid options are:
#'
#'  - `"class"` for hard class predictions
#'  - `"prob"` for soft class predictions (i.e., class probabilities)
#'
#' @return
#'
#' A tibble of predictions. The number of rows in the tibble is guaranteed
#' to be the same as the number of rows in `new_data`.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) {
#'
#'  library(recipes)
#'  library(yardstick)
#'
#'  data(penguins, package = "modeldata")
#'
#'  penguins <- penguins |> na.omit()
#'
#'  set.seed(122)
#'  in_train <- sample(1:nrow(penguins), 200)
#'  penguins_train <- penguins[ in_train,]
#'  penguins_test  <- penguins[-in_train,]
#'
#'  rec <- recipe(island ~ ., data = penguins_train) |>
#'   step_dummy(species, sex) |>
#'   step_normalize(all_numeric_predictors())
#'
#'  set.seed(3)
#'  fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5)
#'  fit
#'
#'  predict(fit, penguins_test) |>
#'   bind_cols(penguins_test) |>
#'   conf_mat(island, .pred_class)
#' }
#' }
#' @export
predict.brulee_multinomial_reg <- function(object, new_data, type = NULL, epoch = NULL, ...) {
  forged <- hardhat::forge(new_data, object$blueprint)
  type <- check_type(object, type)
  if (is.null(epoch)) {
    # default to the epoch with the smallest loss
    epoch <- object$best_epoch
  }
  predict_brulee_multinomial_reg_bridge(type, object, forged$predictors, epoch = epoch)
}

# ------------------------------------------------------------------------------
# Bridge

predict_brulee_multinomial_reg_bridge <- function(type, model, predictors, epoch) {

  if (!is.matrix(predictors)) {
    predictors <- as.matrix(predictors)
    if (is.character(predictors)) {
      cli::cli_abort(
        paste(
          "There were some non-numeric columns in the predictors.",
          "Please use a formula or recipe to encode all of the predictors as numeric."
        )
      )
    }
  }

  predict_function <- get_multinomial_reg_predict_function(type)

  max_epoch <- length(model$estimates)
  if (epoch > max_epoch) {
    msg <- paste("The model fit only", max_epoch, "epochs; predictions cannot",
                 "be made at epoch", epoch, "so last epoch is used.")
    cli::cli_warn(msg)
    # Actually fall back to the last fitted epoch, as the warning promises;
    # otherwise `model$estimates[[epoch]]` fails with a subscript error.
    epoch <- max_epoch
  }

  predictions <- predict_function(model, predictors, epoch)
  hardhat::validate_prediction_size(predictions, predictors)
  predictions
}

get_multinomial_reg_predict_function <- function(type) {
  # `type` has already been validated by check_type()
  switch(
    type,
    prob = predict_brulee_multinomial_reg_prob,
    class = predict_brulee_multinomial_reg_class
  )
}

# ------------------------------------------------------------------------------
# Implementation

predict_brulee_multinomial_reg_raw <- function(model, predictors, epoch) {
  # convert from raw format
  module <- revive_model(model$model_obj)
  # get the model parameters for the requested epoch
  estimates <- model$estimates[[epoch]]
  # convert to torch representation
  estimates <- lapply(estimates, torch::torch_tensor)
  # stuff back into the model
  module$load_state_dict(estimates)
  # put the model in evaluation mode
  module$eval()
  predictions <- module(torch::torch_tensor(predictors))
  predictions <- as.array(predictions)
  # torch doesn't have a NA type so it returns NaN
  predictions[is.nan(predictions)] <- NA
  predictions
}

predict_brulee_multinomial_reg_prob <- function(model, predictors, epoch) {
  predictions <- predict_brulee_multinomial_reg_raw(model, predictors, epoch)
  lvs <- get_levels(model)
  hardhat::spruce_prob(pred_levels = lvs, predictions)
}

predict_brulee_multinomial_reg_class <- function(model, predictors, epoch) {
  predictions <- predict_brulee_multinomial_reg_raw(model, predictors, epoch)
  predictions <- apply(predictions, 1, which.max2) # take the maximum value
  lvs <- get_levels(model)
  hardhat::spruce_class(factor(lvs[predictions], levels = lvs))
}

# ------------------------------------------------------------------------------
# /R/schedulers.R
# ------------------------------------------------------------------------------

#' Change the learning rate over time
#'
#' Learning rate schedulers alter the learning rate to adjust as training
#' proceeds. In most cases, the learning rate decreases as epochs increase.
#' The `schedule_*()` functions are individual schedulers and
#' [set_learn_rate()] is a general interface.
#' @param epoch An integer for the number of training epochs (zero being the
#' initial value).
#' @param initial A positive numeric value for the starting learning rate.
#' @param decay A positive numeric constant for decreasing the rate (see
#' Details below).
#' @param reduction A positive numeric constant stating the proportional decrease
#' in the learning rate occurring at every `steps` epochs.
#' @param steps The number of epochs before the learning rate changes.
#' @param largest The maximum learning rate in the cycle.
#' @param step_size The half-length of a cycle.
#' @param learn_rate A constant learning rate (when no scheduler is used).
#' @param type A single character value for the type of scheduler. Possible
#' values are: "decay_time", "decay_expo", "none", "cyclic", and "step".
#' @param ... Arguments to pass to the individual scheduler functions (e.g.
#' `reduction`).
#' @return A numeric value for the updated learning rate.
#' @details
#' The details for how the schedulers change the rates:
#'
#' * `schedule_decay_time()`: \eqn{rate(epoch) = initial/(1 + decay \times epoch)}
#' * `schedule_decay_expo()`: \eqn{rate(epoch) = initial\exp(-decay \times epoch)}
#' * `schedule_step()`: \eqn{rate(epoch) = initial \times reduction^{floor(epoch / steps)}}
#' * `schedule_cyclic()`: \eqn{cycle = floor( 1 + (epoch / 2 / step size) )},
#' \eqn{x = abs( ( epoch / step size ) - ( 2 * cycle) + 1 )}, and
#' \eqn{rate(epoch) = initial + ( largest - initial ) * \max( 0, 1 - x)}
#'
#'
#' @seealso [brulee_mlp()]
#' @examples
#' if (rlang::is_installed("purrr")) {
#'  library(ggplot2)
#'  library(dplyr)
#'  library(purrr)
#'
#'  iters <- 0:50
#'
#'  bind_rows(
#'   tibble(epoch = iters, rate = map_dbl(iters, schedule_decay_time), type = "decay_time"),
#'   tibble(epoch = iters, rate = map_dbl(iters, schedule_decay_expo), type = "decay_expo"),
#'   tibble(epoch = iters, rate = map_dbl(iters, schedule_step), type = "step"),
#'   tibble(epoch = iters, rate = map_dbl(iters, schedule_cyclic), type = "cyclic")
#'  ) |>
#'   ggplot(aes(epoch, rate)) +
#'   geom_line() +
#'   facet_wrap(~ type)
#'
#' }
#'
#' @export

schedule_decay_time <- function(epoch, initial = 0.1, decay = 1) {
  check_rate_arg_value(initial)
  check_rate_arg_value(decay)
  # inverse-time decay
  initial / (1 + decay * epoch)
}

#' @export
#' @rdname schedule_decay_time
schedule_decay_expo <- function(epoch, initial = 0.1, decay = 1) {
  check_rate_arg_value(initial)
  check_rate_arg_value(decay)
  # exponential decay
  initial * exp(-decay * epoch)
}

#' @export
#' @rdname schedule_decay_time
schedule_step <- function(epoch, initial = 0.1, reduction = 1/2, steps = 5) {
  check_rate_arg_value(initial)
  check_rate_arg_value(reduction)
  check_rate_arg_value(steps)
  # shrink by `reduction` every `steps` epochs
  initial * reduction^floor(epoch / steps)
}

#' @export
#' @rdname schedule_decay_time
schedule_cyclic <- function(epoch, initial = 0.001, largest = 0.1, step_size = 5) {
  check_rate_arg_value(initial)
  check_rate_arg_value(largest)
  check_rate_arg_value(step_size)

  if (largest < initial) {
    # Swap so that `initial` is always the lower bound of the cycle. The
    # previous code overwrote `largest` with `initial` before saving its
    # value, so both bounds collapsed to `initial` and the rate was constant.
    tmp <- initial
    initial <- largest
    largest <- tmp
  } else if (largest == initial) {
    # avoid a degenerate (zero-amplitude) cycle
    initial <- initial / 10
  }

  # triangular cyclic schedule
  cycle <- floor( 1 + (epoch / 2 / step_size) )
  x <- abs( ( epoch / step_size ) - ( 2 * cycle) + 1 )
  initial + ( largest - initial ) * max( 0, 1 - x)
}

# Learning rate can be either static (via rate_schedule == "none") or dynamic.
# Either way, set_learn_rate() figures this out and sets it accordingly.

#' @export
#' @rdname schedule_decay_time
set_learn_rate <- function(epoch, learn_rate, type = "none", ...) {
  types <- c("decay_time", "decay_expo", "none", "step", "cyclic")
  # errors informatively when `type` is not one of `types`
  type <- rlang::arg_match0(type, types, arg_nm = "type")
  if (type == "none") {
    return(learn_rate)
  }

  # build and evaluate a call to the matching `schedule_*()` function
  fn <- paste0("schedule_", type)
  args <- list(...)

  cl <- rlang::call2(fn, epoch = epoch, !!!args)
  rlang::eval_tidy(cl)
}

# ------------------------------------------------------------------------------

# Validate that a scheduler argument is a single positive numeric value,
# using the caller's argument name in the error message.
check_rate_arg_value <- function(x) {
  nm <- as.character(match.call()$x)
  if (is.null(x) || !is.numeric(x) || length(x) != 1 || any(x <= 0)) {
    msg <- paste0("Argument '", nm, "' should be a single positive value.")
    cli::cli_abort(msg)
  }
  invisible(NULL)
}

# ------------------------------------------------------------------------------
# /README.Rmd
# ------------------------------------------------------------------------------

---
output: github_document
---



```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```

# brulee a dish of creme brulee on a striped background


[![R-CMD-check](https://github.com/tidymodels/brulee/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/brulee/actions/workflows/R-CMD-check.yaml)
[![Codecov test coverage](https://codecov.io/gh/tidymodels/brulee/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/brulee?branch=main)
[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html)


The R `brulee` package contains several basic modeling functions that use the `torch` package infrastructure, such as:

* [neural networks](https://brulee.tidymodels.org/reference/brulee_mlp.html)
* [linear regression](https://brulee.tidymodels.org/reference/brulee_linear_reg.html)
* [logistic regression](https://brulee.tidymodels.org/reference/brulee_logistic_reg.html)
* [multinomial regression](https://brulee.tidymodels.org/reference/brulee_multinomial_reg.html)


##
Installation 33 | 34 | You can install the released version of brulee from [CRAN](https://CRAN.R-project.org) with: 35 | 36 | ``` r 37 | install.packages("brulee") 38 | ``` 39 | 40 | And the development version from [GitHub](https://github.com/tidymodels/brulee) with: 41 | 42 | ``` r 43 | # install.packages("pak") 44 | pak::pak("tidymodels/brulee") 45 | ``` 46 | ## Example 47 | 48 | `brulee` has formula, x/y, and recipe user interfaces for each function. For example: 49 | 50 | ```{r load, include = FALSE} 51 | library(brulee) 52 | library(yardstick) 53 | library(recipes) 54 | ``` 55 | ```{r class-fit-form} 56 | library(brulee) 57 | library(recipes) 58 | library(yardstick) 59 | 60 | data(bivariate, package = "modeldata") 61 | set.seed(20) 62 | nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train, 63 | epochs = 150, hidden_units = 3) 64 | 65 | # We use the tidymodels semantics to always return a tibble when predicting 66 | predict(nn_log_biv, bivariate_test, type = "prob") |> 67 | bind_cols(bivariate_test) |> 68 | roc_auc(Class, .pred_One) 69 | ``` 70 | 71 | A recipe can also be used if the data require some sort of preprocessing (e.g., indicator variables, transformations, or standardization): 72 | 73 | ```{r class-fit-rec} 74 | library(recipes) 75 | 76 | rec <- 77 | recipe(Class ~ ., data = bivariate_train) |> 78 | step_YeoJohnson(all_numeric_predictors()) |> 79 | step_normalize(all_numeric_predictors()) 80 | 81 | set.seed(20) 82 | nn_rec_biv <- brulee_mlp(rec, data = bivariate_train, 83 | epochs = 150, hidden_units = 3) 84 | 85 | # A little better 86 | predict(nn_rec_biv, bivariate_test, type = "prob") |> 87 | bind_cols(bivariate_test) |> 88 | roc_auc(Class, .pred_One) 89 | ``` 90 | 91 | ## Code of Conduct 92 | 93 | Please note that the brulee project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. 
94 | 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # brulee a dish of creme brulee on a striped background 5 | 6 | 7 | 8 | [![R-CMD-check](https://github.com/tidymodels/brulee/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/tidymodels/brulee/actions/workflows/R-CMD-check.yaml) 9 | [![Codecov test 10 | coverage](https://codecov.io/gh/tidymodels/brulee/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/brulee?branch=main) 11 | [![Lifecycle: 12 | experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) 13 | 14 | 15 | The R `brulee` package contains several basic modeling functions that 16 | use the `torch` package infrastructure, such as: 17 | 18 | - [neural 19 | networks](https://brulee.tidymodels.org/reference/brulee_mlp.html) 20 | - [linear 21 | regression](https://brulee.tidymodels.org/reference/brulee_linear_reg.html) 22 | - [logistic 23 | regression](https://brulee.tidymodels.org/reference/brulee_logistic_reg.html) 24 | - [multinomial 25 | regression](https://brulee.tidymodels.org/reference/brulee_multinomial_reg.html) 26 | 27 | ## Installation 28 | 29 | You can install the released version of brulee from 30 | [CRAN](https://CRAN.R-project.org) with: 31 | 32 | ``` r 33 | install.packages("brulee") 34 | ``` 35 | 36 | And the development version from 37 | [GitHub](https://github.com/tidymodels/brulee) with: 38 | 39 | ``` r 40 | # install.packages("pak") 41 | pak::pak("tidymodels/brulee") 42 | ``` 43 | 44 | ## Example 45 | 46 | `brulee` has formula, x/y, and recipe user interfaces for each function. 
47 | For example: 48 | 49 | ``` r 50 | library(brulee) 51 | library(recipes) 52 | library(yardstick) 53 | 54 | data(bivariate, package = "modeldata") 55 | set.seed(20) 56 | nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train, 57 | epochs = 150, hidden_units = 3) 58 | 59 | # We use the tidymodels semantics to always return a tibble when predicting 60 | predict(nn_log_biv, bivariate_test, type = "prob") |> 61 | bind_cols(bivariate_test) |> 62 | roc_auc(Class, .pred_One) 63 | #> # A tibble: 1 × 3 64 | #> .metric .estimator .estimate 65 | #> 66 | #> 1 roc_auc binary 0.837 67 | ``` 68 | 69 | A recipe can also be used if the data require some sort of preprocessing 70 | (e.g., indicator variables, transformations, or standardization): 71 | 72 | ``` r 73 | library(recipes) 74 | 75 | rec <- 76 | recipe(Class ~ ., data = bivariate_train) |> 77 | step_YeoJohnson(all_numeric_predictors()) |> 78 | step_normalize(all_numeric_predictors()) 79 | 80 | set.seed(20) 81 | nn_rec_biv <- brulee_mlp(rec, data = bivariate_train, 82 | epochs = 150, hidden_units = 3) 83 | 84 | # A little better 85 | predict(nn_rec_biv, bivariate_test, type = "prob") |> 86 | bind_cols(bivariate_test) |> 87 | roc_auc(Class, .pred_One) 88 | #> # A tibble: 1 × 3 89 | #> .metric .estimator .estimate 90 | #> 91 | #> 1 roc_auc binary 0.866 92 | ``` 93 | 94 | ## Code of Conduct 95 | 96 | Please note that the brulee project is released with a [Contributor Code 97 | of 98 | Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). 99 | By contributing to this project, you agree to abide by its terms. 
100 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://brulee.tidymodels.org/ 2 | 3 | template: 4 | package: tidytemplate 5 | bootstrap: 5 6 | bslib: 7 | primary: "#CA225E" 8 | 9 | includes: 10 | in_header: | 11 | 12 | 13 | development: 14 | mode: auto 15 | 16 | figures: 17 | fig.width: 8 18 | fig.height: 5.75 19 | -------------------------------------------------------------------------------- /brulee.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | ProjectId: 75a9e66f-e023-4afe-99fe-dc7ad5b6a3a8 3 | 4 | RestoreWorkspace: No 5 | SaveWorkspace: No 6 | AlwaysSaveHistory: Default 7 | 8 | EnableCodeIndexing: Yes 9 | UseSpacesForTab: Yes 10 | NumSpacesForTab: 1 11 | Encoding: UTF-8 12 | 13 | RnwWeave: knitr 14 | LaTeX: pdfLaTeX 15 | 16 | AutoAppendNewline: Yes 17 | StripTrailingWhitespace: Yes 18 | LineEndingConversion: Posix 19 | 20 | BuildType: Package 21 | PackageUseDevtools: Yes 22 | PackageInstallArgs: --no-multiarch --with-keep.source 23 | PackageRoxygenize: rd,collate,namespace 24 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /inst/WORDLIST: -------------------------------------------------------------------------------- 1 | CMD 2 | Codecov 3 | LBFGS 4 | Lifecycle 5 | ORCID 6 | PBC 7 | SGD 8 | extensibility 9 | funder 10 | magrittr 11 | mlp 12 | multilayer 13 | perceptrons 14 | relu 15 | sigmoid 16 | tanh 17 | tibble 18 | 
tidymodels 19 | -------------------------------------------------------------------------------- /man/brulee-autoplot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/autoplot.R 3 | \name{brulee-autoplot} 4 | \alias{brulee-autoplot} 5 | \alias{autoplot.brulee_mlp} 6 | \alias{autoplot.brulee_logistic_reg} 7 | \alias{autoplot.brulee_multinomial_reg} 8 | \alias{autoplot.brulee_linear_reg} 9 | \title{Plot model loss over epochs} 10 | \usage{ 11 | \method{autoplot}{brulee_mlp}(object, ...) 12 | 13 | \method{autoplot}{brulee_logistic_reg}(object, ...) 14 | 15 | \method{autoplot}{brulee_multinomial_reg}(object, ...) 16 | 17 | \method{autoplot}{brulee_linear_reg}(object, ...) 18 | } 19 | \arguments{ 20 | \item{object}{A \code{brulee_mlp}, \code{brulee_logistic_reg}, 21 | \code{brulee_multinomial_reg}, or \code{brulee_linear_reg} object.} 22 | 23 | \item{...}{Not currently used} 24 | } 25 | \value{ 26 | A \code{ggplot} object. 27 | } 28 | \description{ 29 | Plot model loss over epochs 30 | } 31 | \details{ 32 | This function plots the loss function across the available epochs. A 33 | vertical line shows the epoch with the best loss value. 
34 | } 35 | \examples{ 36 | \donttest{ 37 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 38 | library(ggplot2) 39 | library(recipes) 40 | theme_set(theme_bw()) 41 | 42 | data(ames, package = "modeldata") 43 | 44 | ames$Sale_Price <- log10(ames$Sale_Price) 45 | 46 | set.seed(1) 47 | in_train <- sample(1:nrow(ames), 2000) 48 | ames_train <- ames[ in_train,] 49 | ames_test <- ames[-in_train,] 50 | 51 | ames_rec <- 52 | recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |> 53 | step_normalize(all_numeric_predictors()) 54 | 55 | set.seed(2) 56 | fit <- brulee_mlp(ames_rec, data = ames_train, epochs = 50, batch_size = 32) 57 | 58 | autoplot(fit) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /man/brulee-coefs.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/coef.R 3 | \name{brulee-coefs} 4 | \alias{brulee-coefs} 5 | \alias{coef.brulee_logistic_reg} 6 | \alias{coef.brulee_linear_reg} 7 | \alias{coef.brulee_mlp} 8 | \alias{coef.brulee_multinomial_reg} 9 | \title{Extract Model Coefficients} 10 | \usage{ 11 | \method{coef}{brulee_logistic_reg}(object, epoch = NULL, ...) 12 | 13 | \method{coef}{brulee_linear_reg}(object, epoch = NULL, ...) 14 | 15 | \method{coef}{brulee_mlp}(object, epoch = NULL, ...) 16 | 17 | \method{coef}{brulee_multinomial_reg}(object, epoch = NULL, ...) 18 | } 19 | \arguments{ 20 | \item{object}{A model fit from \pkg{brulee}.} 21 | 22 | \item{epoch}{A single integer for the training iteration. If left \code{NULL}, 23 | the estimates from the best model fit (via internal performance metrics).} 24 | 25 | \item{...}{Not currently used.} 26 | } 27 | \value{ 28 | For logistic/linear regression, a named vector. For neural networks, 29 | a list of arrays. 
30 | } 31 | \description{ 32 | Extract Model Coefficients 33 | } 34 | \examples{ 35 | \donttest{ 36 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "modeldata"))) { 37 | 38 | data(ames, package = "modeldata") 39 | 40 | ames$Sale_Price <- log10(ames$Sale_Price) 41 | 42 | set.seed(1) 43 | in_train <- sample(1:nrow(ames), 2000) 44 | ames_train <- ames[ in_train,] 45 | ames_test <- ames[-in_train,] 46 | 47 | # Using recipe 48 | library(recipes) 49 | 50 | ames_rec <- 51 | recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |> 52 | step_normalize(all_numeric_predictors()) 53 | 54 | set.seed(2) 55 | fit <- brulee_linear_reg(ames_rec, data = ames_train, 56 | epochs = 50, batch_size = 32) 57 | 58 | coef(fit) 59 | coef(fit, epoch = 1) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /man/brulee-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/brulee-package.R 3 | \docType{package} 4 | \name{brulee-package} 5 | \alias{brulee} 6 | \alias{brulee-package} 7 | \title{brulee: High-Level Modeling Functions with 'torch'} 8 | \description{ 9 | \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} 10 | 11 | Provides high-level modeling functions to define and train models using the 'torch' R package. Models include linear, logistic, and multinomial regression as well as multilayer perceptrons. 
12 | } 13 | \seealso{ 14 | Useful links: 15 | \itemize{ 16 | \item \url{https://github.com/tidymodels/brulee} 17 | \item \url{https://brulee.tidymodels.org/} 18 | \item Report bugs at \url{https://github.com/tidymodels/brulee/issues} 19 | } 20 | 21 | } 22 | \author{ 23 | \strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) 24 | 25 | Authors: 26 | \itemize{ 27 | \item Daniel Falbel \email{daniel@posit.co} 28 | } 29 | 30 | Other contributors: 31 | \itemize{ 32 | \item Posit Software, PBC [copyright holder, funder] 33 | } 34 | 35 | } 36 | \keyword{internal} 37 | -------------------------------------------------------------------------------- /man/brulee_activations.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/activation.R 3 | \name{brulee_activations} 4 | \alias{brulee_activations} 5 | \title{Activation functions for neural networks in brulee} 6 | \usage{ 7 | brulee_activations() 8 | } 9 | \value{ 10 | A character vector of values. 11 | } 12 | \description{ 13 | Activation functions for neural networks in brulee 14 | } 15 | -------------------------------------------------------------------------------- /man/brulee_linear_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/linear_reg-fit.R 3 | \name{brulee_linear_reg} 4 | \alias{brulee_linear_reg} 5 | \alias{brulee_linear_reg.default} 6 | \alias{brulee_linear_reg.data.frame} 7 | \alias{brulee_linear_reg.matrix} 8 | \alias{brulee_linear_reg.formula} 9 | \alias{brulee_linear_reg.recipe} 10 | \title{Fit a linear regression model} 11 | \usage{ 12 | brulee_linear_reg(x, ...) 13 | 14 | \method{brulee_linear_reg}{default}(x, ...) 
15 | 16 | \method{brulee_linear_reg}{data.frame}( 17 | x, 18 | y, 19 | epochs = 20L, 20 | penalty = 0.001, 21 | mixture = 0, 22 | validation = 0.1, 23 | optimizer = "LBFGS", 24 | learn_rate = 1, 25 | momentum = 0, 26 | batch_size = NULL, 27 | stop_iter = 5, 28 | verbose = FALSE, 29 | ... 30 | ) 31 | 32 | \method{brulee_linear_reg}{matrix}( 33 | x, 34 | y, 35 | epochs = 20L, 36 | penalty = 0.001, 37 | mixture = 0, 38 | validation = 0.1, 39 | optimizer = "LBFGS", 40 | learn_rate = 1, 41 | momentum = 0, 42 | batch_size = NULL, 43 | stop_iter = 5, 44 | verbose = FALSE, 45 | ... 46 | ) 47 | 48 | \method{brulee_linear_reg}{formula}( 49 | formula, 50 | data, 51 | epochs = 20L, 52 | penalty = 0.001, 53 | mixture = 0, 54 | validation = 0.1, 55 | optimizer = "LBFGS", 56 | learn_rate = 1, 57 | momentum = 0, 58 | batch_size = NULL, 59 | stop_iter = 5, 60 | verbose = FALSE, 61 | ... 62 | ) 63 | 64 | \method{brulee_linear_reg}{recipe}( 65 | x, 66 | data, 67 | epochs = 20L, 68 | penalty = 0.001, 69 | mixture = 0, 70 | validation = 0.1, 71 | optimizer = "LBFGS", 72 | learn_rate = 1, 73 | momentum = 0, 74 | batch_size = NULL, 75 | stop_iter = 5, 76 | verbose = FALSE, 77 | ... 78 | ) 79 | } 80 | \arguments{ 81 | \item{x}{Depending on the context: 82 | \itemize{ 83 | \item A \strong{data frame} of predictors. 84 | \item A \strong{matrix} of predictors. 85 | \item A \strong{recipe} specifying a set of preprocessing steps 86 | created from \code{\link[recipes:recipe]{recipes::recipe()}}. 87 | } 88 | 89 | The predictor data should be standardized (e.g. centered or scaled).} 90 | 91 | \item{...}{Options to pass to the learning rate schedulers via 92 | \code{\link[=set_learn_rate]{set_learn_rate()}}. 
For example, the \code{reduction} or \code{steps} arguments to 93 | \code{\link[=schedule_step]{schedule_step()}} could be passed here.} 94 | 95 | \item{y}{When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome 96 | specified as: 97 | \itemize{ 98 | \item A \strong{data frame} with 1 numeric column. 99 | \item A \strong{matrix} with 1 numeric column. 100 | \item A numeric \strong{vector}. 101 | }} 102 | 103 | \item{epochs}{An integer for the number of epochs of training.} 104 | 105 | \item{penalty}{The amount of weight decay (i.e., L2 regularization).} 106 | 107 | \item{mixture}{Proportion of Lasso Penalty (type: double, default: 0.0). A 108 | value of mixture = 1 corresponds to a pure lasso model, while mixture = 0 109 | indicates ridge regression (a.k.a weight decay).} 110 | 111 | \item{validation}{The proportion of the data randomly assigned to a 112 | validation set.} 113 | 114 | \item{optimizer}{The method used in the optimization procedure. Possible choices 115 | are 'LBFGS' and 'SGD'. Default is 'LBFGS'.} 116 | 117 | \item{learn_rate}{A positive number that controls the initial rapidity that 118 | the model moves along the descent path. Values around 0.1 or less are 119 | typical.} 120 | 121 | \item{momentum}{A positive number usually on \verb{[0.50, 0.99]} for the momentum 122 | parameter in gradient descent. (\code{optimizer = "SGD"} only)} 123 | 124 | \item{batch_size}{An integer for the number of training set points in each 125 | batch. 
(\code{optimizer = "SGD"} only)} 126 | 127 | \item{stop_iter}{A non-negative integer for how many iterations with no 128 | improvement before stopping.} 129 | 130 | \item{verbose}{A logical that prints out the iteration history.} 131 | 132 | \item{formula}{A formula specifying the outcome term(s) on the left-hand side, 133 | and the predictor term(s) on the right-hand side.} 134 | 135 | \item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as: 136 | \itemize{ 137 | \item A \strong{data frame} containing both the predictors and the outcome. 138 | }} 139 | } 140 | \value{ 141 | A \code{brulee_linear_reg} object with elements: 142 | \itemize{ 143 | \item \code{models_obj}: a serialized raw vector for the torch module. 144 | \item \code{estimates}: a list of matrices with the model parameter estimates per 145 | epoch. 146 | \item \code{best_epoch}: an integer for the epoch with the smallest loss. 147 | \item \code{loss}: A vector of loss values (MSE) at each epoch. 148 | \item \code{dim}: A list of data dimensions. 149 | \item \code{y_stats}: A list of summary statistics for numeric outcomes. 150 | \item \code{parameters}: A list of some tuning parameter values. 151 | \item \code{blueprint}: The \code{hardhat} blueprint data. 152 | } 153 | } 154 | \description{ 155 | \code{brulee_linear_reg()} fits a linear regression model. 156 | } 157 | \details{ 158 | This function fits a linear combination of coefficients and predictors to 159 | model the numeric outcome. The training process optimizes the 160 | mean squared error loss function. 161 | 162 | The function internally standardizes the outcome data to have mean zero and 163 | a standard deviation of one. The prediction function creates predictions on 164 | the original scale. 165 | 166 | By default, training halts when the validation loss increases for at least 167 | \code{stop_iter} iterations. If \code{validation = 0} the training set loss is used. 
168 | 169 | The \emph{predictors} data should all be numeric and encoded in the same units (e.g. 170 | standardized to the same range or distribution). If there are factor 171 | predictors, use a recipe or formula to create indicator variables (or some 172 | other method) to make them numeric. Predictors should be in the same units 173 | before training. 174 | 175 | The model objects are saved for each epoch so that the number of epochs can 176 | be efficiently tuned. Both the \code{\link[=coef]{coef()}} and \code{\link[=predict]{predict()}} methods for this 177 | model have an \code{epoch} argument (which defaults to the epoch with the best 178 | loss value). 179 | 180 | The use of the L1 penalty (a.k.a. the lasso penalty) does \emph{not} force 181 | parameters to be strictly zero (as it does in packages such as \pkg{glmnet}). 182 | The zeroing out of parameters is a specific feature the optimization method 183 | used in those packages. 184 | } 185 | \examples{ 186 | \donttest{ 187 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 188 | 189 | ## ----------------------------------------------------------------------------- 190 | 191 | library(recipes) 192 | library(yardstick) 193 | 194 | data(ames, package = "modeldata") 195 | 196 | ames$Sale_Price <- log10(ames$Sale_Price) 197 | 198 | set.seed(122) 199 | in_train <- sample(1:nrow(ames), 2000) 200 | ames_train <- ames[ in_train,] 201 | ames_test <- ames[-in_train,] 202 | 203 | 204 | # Using matrices 205 | set.seed(1) 206 | brulee_linear_reg(x = as.matrix(ames_train[, c("Longitude", "Latitude")]), 207 | y = ames_train$Sale_Price, 208 | penalty = 0.10, epochs = 1, batch_size = 64) 209 | 210 | # Using recipe 211 | library(recipes) 212 | 213 | ames_rec <- 214 | recipe(Sale_Price ~ Bldg_Type + Neighborhood + Year_Built + Gr_Liv_Area + 215 | Full_Bath + Year_Sold + Lot_Area + Central_Air + Longitude + Latitude, 216 | data = ames_train) |> 217 | # Transform some highly 
skewed predictors 218 | step_BoxCox(Lot_Area, Gr_Liv_Area) |> 219 | # Lump some rarely occurring categories into "other" 220 | step_other(Neighborhood, threshold = 0.05) |> 221 | # Encode categorical predictors as binary. 222 | step_dummy(all_nominal_predictors(), one_hot = TRUE) |> 223 | # Add an interaction effect: 224 | step_interact(~ starts_with("Central_Air"):Year_Built) |> 225 | step_zv(all_predictors()) |> 226 | step_normalize(all_numeric_predictors()) 227 | 228 | set.seed(2) 229 | fit <- brulee_linear_reg(ames_rec, data = ames_train, 230 | epochs = 5, batch_size = 32) 231 | fit 232 | 233 | autoplot(fit) 234 | 235 | library(ggplot2) 236 | 237 | predict(fit, ames_test) |> 238 | bind_cols(ames_test) |> 239 | ggplot(aes(x = .pred, y = Sale_Price)) + 240 | geom_abline(col = "green") + 241 | geom_point(alpha = .3) + 242 | lims(x = c(4, 6), y = c(4, 6)) + 243 | coord_fixed(ratio = 1) 244 | 245 | library(yardstick) 246 | predict(fit, ames_test) |> 247 | bind_cols(ames_test) |> 248 | rmse(Sale_Price, .pred) 249 | 250 | } 251 | 252 | } 253 | } 254 | \seealso{ 255 | \code{\link[=predict.brulee_linear_reg]{predict.brulee_linear_reg()}}, \code{\link[=coef.brulee_linear_reg]{coef.brulee_linear_reg()}}, 256 | \code{\link[=autoplot.brulee_linear_reg]{autoplot.brulee_linear_reg()}} 257 | } 258 | -------------------------------------------------------------------------------- /man/brulee_logistic_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/logistic_reg-fit.R 3 | \name{brulee_logistic_reg} 4 | \alias{brulee_logistic_reg} 5 | \alias{brulee_logistic_reg.default} 6 | \alias{brulee_logistic_reg.data.frame} 7 | \alias{brulee_logistic_reg.matrix} 8 | \alias{brulee_logistic_reg.formula} 9 | \alias{brulee_logistic_reg.recipe} 10 | \title{Fit a logistic regression model} 11 | \usage{ 12 | brulee_logistic_reg(x, ...) 
13 | 14 | \method{brulee_logistic_reg}{default}(x, ...) 15 | 16 | \method{brulee_logistic_reg}{data.frame}( 17 | x, 18 | y, 19 | epochs = 20L, 20 | penalty = 0.001, 21 | mixture = 0, 22 | validation = 0.1, 23 | optimizer = "LBFGS", 24 | learn_rate = 1, 25 | momentum = 0, 26 | batch_size = NULL, 27 | class_weights = NULL, 28 | stop_iter = 5, 29 | verbose = FALSE, 30 | ... 31 | ) 32 | 33 | \method{brulee_logistic_reg}{matrix}( 34 | x, 35 | y, 36 | epochs = 20L, 37 | penalty = 0.001, 38 | mixture = 0, 39 | validation = 0.1, 40 | optimizer = "LBFGS", 41 | learn_rate = 1, 42 | momentum = 0, 43 | batch_size = NULL, 44 | class_weights = NULL, 45 | stop_iter = 5, 46 | verbose = FALSE, 47 | ... 48 | ) 49 | 50 | \method{brulee_logistic_reg}{formula}( 51 | formula, 52 | data, 53 | epochs = 20L, 54 | penalty = 0.001, 55 | mixture = 0, 56 | validation = 0.1, 57 | optimizer = "LBFGS", 58 | learn_rate = 1, 59 | momentum = 0, 60 | batch_size = NULL, 61 | class_weights = NULL, 62 | stop_iter = 5, 63 | verbose = FALSE, 64 | ... 65 | ) 66 | 67 | \method{brulee_logistic_reg}{recipe}( 68 | x, 69 | data, 70 | epochs = 20L, 71 | penalty = 0.001, 72 | mixture = 0, 73 | validation = 0.1, 74 | optimizer = "LBFGS", 75 | learn_rate = 1, 76 | momentum = 0, 77 | batch_size = NULL, 78 | class_weights = NULL, 79 | stop_iter = 5, 80 | verbose = FALSE, 81 | ... 82 | ) 83 | } 84 | \arguments{ 85 | \item{x}{Depending on the context: 86 | \itemize{ 87 | \item A \strong{data frame} of predictors. 88 | \item A \strong{matrix} of predictors. 89 | \item A \strong{recipe} specifying a set of preprocessing steps 90 | created from \code{\link[recipes:recipe]{recipes::recipe()}}. 91 | } 92 | 93 | The predictor data should be standardized (e.g. centered or scaled).} 94 | 95 | \item{...}{Options to pass to the learning rate schedulers via 96 | \code{\link[=set_learn_rate]{set_learn_rate()}}. 
For example, the \code{reduction} or \code{steps} arguments to 97 | \code{\link[=schedule_step]{schedule_step()}} could be passed here.} 98 | 99 | \item{y}{When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome 100 | specified as: 101 | \itemize{ 102 | \item A \strong{data frame} with 1 factor column (with two levels). 103 | \item A \strong{matrix} with 1 factor column (with two levels). 104 | \item A factor \strong{vector} (with two levels). 105 | }} 106 | 107 | \item{epochs}{An integer for the number of epochs of training.} 108 | 109 | \item{penalty}{The amount of weight decay (i.e., L2 regularization).} 110 | 111 | \item{mixture}{Proportion of Lasso Penalty (type: double, default: 0.0). A 112 | value of mixture = 1 corresponds to a pure lasso model, while mixture = 0 113 | indicates ridge regression (a.k.a weight decay).} 114 | 115 | \item{validation}{The proportion of the data randomly assigned to a 116 | validation set.} 117 | 118 | \item{optimizer}{The method used in the optimization procedure. Possible choices 119 | are 'LBFGS' and 'SGD'. Default is 'LBFGS'.} 120 | 121 | \item{learn_rate}{A positive number that controls the rapidity that the model 122 | moves along the descent path. Values around 0.1 or less are typical. 123 | (\code{optimizer = "SGD"} only)} 124 | 125 | \item{momentum}{A positive number usually on \verb{[0.50, 0.99]} for the momentum 126 | parameter in gradient descent. (\code{optimizer = "SGD"} only)} 127 | 128 | \item{batch_size}{An integer for the number of training set points in each 129 | batch. (\code{optimizer = "SGD"} only)} 130 | 131 | \item{class_weights}{Numeric class weights (classification only). The value 132 | can be: 133 | \itemize{ 134 | \item A named numeric vector (in any order) where the names are the outcome 135 | factor levels. 136 | \item An unnamed numeric vector assumed to be in the same order as the outcome 137 | factor levels. 
138 | \item A single numeric value for the least frequent class in the training data 139 | and all other classes receive a weight of one. 140 | }} 141 | 142 | \item{stop_iter}{A non-negative integer for how many iterations with no 143 | improvement before stopping.} 144 | 145 | \item{verbose}{A logical that prints out the iteration history.} 146 | 147 | \item{formula}{A formula specifying the outcome term(s) on the left-hand side, 148 | and the predictor term(s) on the right-hand side.} 149 | 150 | \item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as: 151 | \itemize{ 152 | \item A \strong{data frame} containing both the predictors and the outcome. 153 | }} 154 | } 155 | \value{ 156 | A \code{brulee_logistic_reg} object with elements: 157 | \itemize{ 158 | \item \code{models_obj}: a serialized raw vector for the torch module. 159 | \item \code{estimates}: a list of matrices with the model parameter estimates per 160 | epoch. 161 | \item \code{best_epoch}: an integer for the epoch with the smallest loss. 162 | \item \code{loss}: A vector of loss values (MSE for regression, negative log- 163 | likelihood for classification) at each epoch. 164 | \item \code{dim}: A list of data dimensions. 165 | \item \code{parameters}: A list of some tuning parameter values. 166 | \item \code{blueprint}: The \code{hardhat} blueprint data. 167 | } 168 | } 169 | \description{ 170 | \code{brulee_logistic_reg()} fits a model. 171 | } 172 | \details{ 173 | This function fits a linear combination of coefficients and predictors to 174 | model the log odds of the classes. The training process optimizes the 175 | cross-entropy loss function (a.k.a Bernoulli loss). 176 | 177 | By default, training halts when the validation loss increases for at least 178 | \code{stop_iter} iterations. If \code{validation = 0} the training set loss is used. 179 | 180 | The \emph{predictors} data should all be numeric and encoded in the same units (e.g. 
181 | standardized to the same range or distribution). If there are factor 182 | predictors, use a recipe or formula to create indicator variables (or some 183 | other method) to make them numeric. Predictors should be in the same units 184 | before training. 185 | 186 | The model objects are saved for each epoch so that the number of epochs can 187 | be efficiently tuned. Both the \code{\link[=coef]{coef()}} and \code{\link[=predict]{predict()}} methods for this 188 | model have an \code{epoch} argument (which defaults to the epoch with the best 189 | loss value). 190 | 191 | The use of the L1 penalty (a.k.a. the lasso penalty) does \emph{not} force 192 | parameters to be strictly zero (as it does in packages such as \pkg{glmnet}). 193 | The zeroing out of parameters is a specific feature the optimization method 194 | used in those packages. 195 | } 196 | \examples{ 197 | \donttest{ 198 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 199 | 200 | library(recipes) 201 | library(yardstick) 202 | 203 | ## ----------------------------------------------------------------------------- 204 | # increase # epochs to get better results 205 | 206 | data(cells, package = "modeldata") 207 | 208 | cells$case <- NULL 209 | 210 | set.seed(122) 211 | in_train <- sample(1:nrow(cells), 1000) 212 | cells_train <- cells[ in_train,] 213 | cells_test <- cells[-in_train,] 214 | 215 | # Using matrices 216 | set.seed(1) 217 | brulee_logistic_reg(x = as.matrix(cells_train[, c("fiber_width_ch_1", "width_ch_1")]), 218 | y = cells_train$class, 219 | penalty = 0.10, epochs = 3) 220 | 221 | # Using recipe 222 | library(recipes) 223 | 224 | cells_rec <- 225 | recipe(class ~ ., data = cells_train) |> 226 | # Transform some highly skewed predictors 227 | step_YeoJohnson(all_numeric_predictors()) |> 228 | step_normalize(all_numeric_predictors()) |> 229 | step_pca(all_numeric_predictors(), num_comp = 10) 230 | 231 | set.seed(2) 232 | fit <- 
brulee_logistic_reg(cells_rec, data = cells_train, 233 | penalty = .01, epochs = 5) 234 | fit 235 | 236 | autoplot(fit) 237 | 238 | library(yardstick) 239 | predict(fit, cells_test, type = "prob") |> 240 | bind_cols(cells_test) |> 241 | roc_auc(class, .pred_PS) 242 | } 243 | } 244 | } 245 | \seealso{ 246 | \code{\link[=predict.brulee_logistic_reg]{predict.brulee_logistic_reg()}}, \code{\link[=coef.brulee_logistic_reg]{coef.brulee_logistic_reg()}}, 247 | \code{\link[=autoplot.brulee_logistic_reg]{autoplot.brulee_logistic_reg()}} 248 | } 249 | -------------------------------------------------------------------------------- /man/brulee_multinomial_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/multinomial_reg-fit.R 3 | \name{brulee_multinomial_reg} 4 | \alias{brulee_multinomial_reg} 5 | \alias{brulee_multinomial_reg.default} 6 | \alias{brulee_multinomial_reg.data.frame} 7 | \alias{brulee_multinomial_reg.matrix} 8 | \alias{brulee_multinomial_reg.formula} 9 | \alias{brulee_multinomial_reg.recipe} 10 | \title{Fit a multinomial regression model} 11 | \usage{ 12 | brulee_multinomial_reg(x, ...) 13 | 14 | \method{brulee_multinomial_reg}{default}(x, ...) 15 | 16 | \method{brulee_multinomial_reg}{data.frame}( 17 | x, 18 | y, 19 | epochs = 20L, 20 | penalty = 0.001, 21 | mixture = 0, 22 | validation = 0.1, 23 | optimizer = "LBFGS", 24 | learn_rate = 1, 25 | momentum = 0, 26 | batch_size = NULL, 27 | class_weights = NULL, 28 | stop_iter = 5, 29 | verbose = FALSE, 30 | ... 31 | ) 32 | 33 | \method{brulee_multinomial_reg}{matrix}( 34 | x, 35 | y, 36 | epochs = 20L, 37 | penalty = 0.001, 38 | mixture = 0, 39 | validation = 0.1, 40 | optimizer = "LBFGS", 41 | learn_rate = 1, 42 | momentum = 0, 43 | batch_size = NULL, 44 | class_weights = NULL, 45 | stop_iter = 5, 46 | verbose = FALSE, 47 | ... 
48 | ) 49 | 50 | \method{brulee_multinomial_reg}{formula}( 51 | formula, 52 | data, 53 | epochs = 20L, 54 | penalty = 0.001, 55 | mixture = 0, 56 | validation = 0.1, 57 | optimizer = "LBFGS", 58 | learn_rate = 1, 59 | momentum = 0, 60 | batch_size = NULL, 61 | class_weights = NULL, 62 | stop_iter = 5, 63 | verbose = FALSE, 64 | ... 65 | ) 66 | 67 | \method{brulee_multinomial_reg}{recipe}( 68 | x, 69 | data, 70 | epochs = 20L, 71 | penalty = 0.001, 72 | mixture = 0, 73 | validation = 0.1, 74 | optimizer = "LBFGS", 75 | learn_rate = 1, 76 | momentum = 0, 77 | batch_size = NULL, 78 | class_weights = NULL, 79 | stop_iter = 5, 80 | verbose = FALSE, 81 | ... 82 | ) 83 | } 84 | \arguments{ 85 | \item{x}{Depending on the context: 86 | \itemize{ 87 | \item A \strong{data frame} of predictors. 88 | \item A \strong{matrix} of predictors. 89 | \item A \strong{recipe} specifying a set of preprocessing steps 90 | created from \code{\link[recipes:recipe]{recipes::recipe()}}. 91 | } 92 | 93 | The predictor data should be standardized (e.g. centered or scaled).} 94 | 95 | \item{...}{Options to pass to the learning rate schedulers via 96 | \code{\link[=set_learn_rate]{set_learn_rate()}}. For example, the \code{reduction} or \code{steps} arguments to 97 | \code{\link[=schedule_step]{schedule_step()}} could be passed here.} 98 | 99 | \item{y}{When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome 100 | specified as: 101 | \itemize{ 102 | \item A \strong{data frame} with 1 factor column (with three or more levels). 103 | \item A \strong{matrix} with 1 factor column (with three or more levels). 104 | \item A factor \strong{vector} (with three or more levels). 105 | }} 106 | 107 | \item{epochs}{An integer for the number of epochs of training.} 108 | 109 | \item{penalty}{The amount of weight decay (i.e., L2 regularization).} 110 | 111 | \item{mixture}{Proportion of Lasso Penalty (type: double, default: 0.0). 
A 112 | value of mixture = 1 corresponds to a pure lasso model, while mixture = 0 113 | indicates ridge regression (a.k.a weight decay).} 114 | 115 | \item{validation}{The proportion of the data randomly assigned to a 116 | validation set.} 117 | 118 | \item{optimizer}{The method used in the optimization procedure. Possible choices 119 | are 'LBFGS' and 'SGD'. Default is 'LBFGS'.} 120 | 121 | \item{learn_rate}{A positive number that controls the rapidity that the model 122 | moves along the descent path. Values around 0.1 or less are typical. 123 | (\code{optimizer = "SGD"} only)} 124 | 125 | \item{momentum}{A positive number usually on \verb{[0.50, 0.99]} for the momentum 126 | parameter in gradient descent. (\code{optimizer = "SGD"} only)} 127 | 128 | \item{batch_size}{An integer for the number of training set points in each 129 | batch. (\code{optimizer = "SGD"} only)} 130 | 131 | \item{class_weights}{Numeric class weights (classification only). The value 132 | can be: 133 | \itemize{ 134 | \item A named numeric vector (in any order) where the names are the outcome 135 | factor levels. 136 | \item An unnamed numeric vector assumed to be in the same order as the outcome 137 | factor levels. 138 | \item A single numeric value for the least frequent class in the training data 139 | and all other classes receive a weight of one. 140 | }} 141 | 142 | \item{stop_iter}{A non-negative integer for how many iterations with no 143 | improvement before stopping.} 144 | 145 | \item{verbose}{A logical that prints out the iteration history.} 146 | 147 | \item{formula}{A formula specifying the outcome term(s) on the left-hand side, 148 | and the predictor term(s) on the right-hand side.} 149 | 150 | \item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as: 151 | \itemize{ 152 | \item A \strong{data frame} containing both the predictors and the outcome. 
153 | }} 154 | } 155 | \value{ 156 | A \code{brulee_multinomial_reg} object with elements: 157 | \itemize{ 158 | \item \code{models_obj}: a serialized raw vector for the torch module. 159 | \item \code{estimates}: a list of matrices with the model parameter estimates per 160 | epoch. 161 | \item \code{best_epoch}: an integer for the epoch with the smallest loss. 162 | \item \code{loss}: A vector of loss values (MSE for regression, negative log- 163 | likelihood for classification) at each epoch. 164 | \item \code{dim}: A list of data dimensions. 165 | \item \code{parameters}: A list of some tuning parameter values. 166 | \item \code{blueprint}: The \code{hardhat} blueprint data. 167 | } 168 | } 169 | \description{ 170 | \code{brulee_multinomial_reg()} fits a model. 171 | } 172 | \details{ 173 | This function fits a linear combination of coefficients and predictors to 174 | model the log of the class probabilities. The training process optimizes the 175 | cross-entropy loss function. 176 | 177 | By default, training halts when the validation loss increases for at least 178 | \code{stop_iter} iterations. If \code{validation = 0} the training set loss is used. 179 | 180 | The \emph{predictors} data should all be numeric and encoded in the same units (e.g. 181 | standardized to the same range or distribution). If there are factor 182 | predictors, use a recipe or formula to create indicator variables (or some 183 | other method) to make them numeric. Predictors should be in the same units 184 | before training. 185 | 186 | The model objects are saved for each epoch so that the number of epochs can 187 | be efficiently tuned. Both the \code{\link[=coef]{coef()}} and \code{\link[=predict]{predict()}} methods for this 188 | model have an \code{epoch} argument (which defaults to the epoch with the best 189 | loss value). 190 | 191 | The use of the L1 penalty (a.k.a. 
the lasso penalty) does \emph{not} force 192 | parameters to be strictly zero (as it does in packages such as \pkg{glmnet}). 193 | The zeroing out of parameters is a specific feature the optimization method 194 | used in those packages. 195 | } 196 | \examples{ 197 | \donttest{ 198 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 199 | 200 | library(recipes) 201 | library(yardstick) 202 | 203 | data(penguins, package = "modeldata") 204 | 205 | penguins <- penguins |> na.omit() 206 | 207 | set.seed(122) 208 | in_train <- sample(1:nrow(penguins), 200) 209 | penguins_train <- penguins[ in_train,] 210 | penguins_test <- penguins[-in_train,] 211 | 212 | rec <- recipe(island ~ ., data = penguins_train) |> 213 | step_dummy(species, sex) |> 214 | step_normalize(all_predictors()) 215 | 216 | set.seed(3) 217 | fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5) 218 | fit 219 | 220 | predict(fit, penguins_test) |> 221 | bind_cols(penguins_test) |> 222 | conf_mat(island, .pred_class) 223 | } 224 | } 225 | } 226 | \seealso{ 227 | \code{\link[=predict.brulee_multinomial_reg]{predict.brulee_multinomial_reg()}}, \code{\link[=coef.brulee_multinomial_reg]{coef.brulee_multinomial_reg()}}, 228 | \code{\link[=autoplot.brulee_multinomial_reg]{autoplot.brulee_multinomial_reg()}} 229 | } 230 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/brulee/a45f8d48d43a31636a2768b1101ce58528fec7e8/man/figures/logo.png -------------------------------------------------------------------------------- /man/matrix_to_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/convert_data.R 3 | \name{matrix_to_dataset} 4 | 
\alias{matrix_to_dataset} 5 | \title{Convert data to torch format} 6 | \usage{ 7 | matrix_to_dataset(x, y) 8 | } 9 | \arguments{ 10 | \item{x}{A numeric matrix of predictors.} 11 | 12 | \item{y}{A vector. If regression then \code{y} is numeric. For classification, it 13 | is a factor.} 14 | } 15 | \value{ 16 | An R6 index sampler object with classes "training_set", 17 | "dataset", and "R6". 18 | } 19 | \description{ 20 | For an x/y interface, \code{matrix_to_dataset()} converts the data to proper 21 | encodings then formats the results for consumption by \code{torch}. 22 | } 23 | \details{ 24 | Missing values should be removed before passing data to this function. 25 | } 26 | \examples{ 27 | if (torch::torch_is_installed()) { 28 | matrix_to_dataset(as.matrix(mtcars[, -1]), mtcars$mpg) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /man/predict.brulee_linear_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/linear_reg-predict.R 3 | \name{predict.brulee_linear_reg} 4 | \alias{predict.brulee_linear_reg} 5 | \title{Predict from a \code{brulee_linear_reg}} 6 | \usage{ 7 | \method{predict}{brulee_linear_reg}(object, new_data, type = NULL, epoch = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{object}{A \code{brulee_linear_reg} object.} 11 | 12 | \item{new_data}{A data frame or matrix of new predictors.} 13 | 14 | \item{type}{A single character. The type of predictions to generate. 15 | Valid options are: 16 | \itemize{ 17 | \item \code{"numeric"} for numeric predictions. 18 | }} 19 | 20 | \item{epoch}{An integer for the epoch to make predictions. If this value 21 | is larger than the maximum number that was fit, a warning is issued and the 22 | parameters from the last epoch are used. 
If left \code{NULL}, the epoch 23 | associated with the smallest loss is used.} 24 | 25 | \item{...}{Not used, but required for extensibility.} 26 | } 27 | \value{ 28 | A tibble of predictions. The number of rows in the tibble is guaranteed 29 | to be the same as the number of rows in \code{new_data}. 30 | } 31 | \description{ 32 | Predict from a \code{brulee_linear_reg} 33 | } 34 | \examples{ 35 | \donttest{ 36 | if (torch::torch_is_installed() & rlang::is_installed("recipes")) { 37 | 38 | data(ames, package = "modeldata") 39 | 40 | ames$Sale_Price <- log10(ames$Sale_Price) 41 | 42 | set.seed(1) 43 | in_train <- sample(1:nrow(ames), 2000) 44 | ames_train <- ames[ in_train,] 45 | ames_test <- ames[-in_train,] 46 | 47 | # Using recipe 48 | library(recipes) 49 | 50 | ames_rec <- 51 | recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |> 52 | step_normalize(all_numeric_predictors()) 53 | 54 | set.seed(2) 55 | fit <- brulee_linear_reg(ames_rec, data = ames_train, 56 | epochs = 50, batch_size = 32) 57 | 58 | predict(fit, ames_test) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /man/predict.brulee_logistic_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/logistic_reg-predict.R 3 | \name{predict.brulee_logistic_reg} 4 | \alias{predict.brulee_logistic_reg} 5 | \title{Predict from a \code{brulee_logistic_reg}} 6 | \usage{ 7 | \method{predict}{brulee_logistic_reg}(object, new_data, type = NULL, epoch = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{object}{A \code{brulee_logistic_reg} object.} 11 | 12 | \item{new_data}{A data frame or matrix of new predictors.} 13 | 14 | \item{type}{A single character. The type of predictions to generate. 
15 | Valid options are: 16 | \itemize{ 17 | \item \code{"class"} for hard class predictions 18 | \item \code{"prob"} for soft class predictions (i.e., class probabilities) 19 | }} 20 | 21 | \item{epoch}{An integer for the epoch to make predictions. If this value 22 | is larger than the maximum number that was fit, a warning is issued and the 23 | parameters from the last epoch are used. If left \code{NULL}, the epoch 24 | associated with the smallest loss is used.} 25 | 26 | \item{...}{Not used, but required for extensibility.} 27 | } 28 | \value{ 29 | A tibble of predictions. The number of rows in the tibble is guaranteed 30 | to be the same as the number of rows in \code{new_data}. 31 | } 32 | \description{ 33 | Predict from a \code{brulee_logistic_reg} 34 | } 35 | \examples{ 36 | \donttest{ 37 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 38 | 39 | library(recipes) 40 | library(yardstick) 41 | 42 | data(penguins, package = "modeldata") 43 | 44 | penguins <- penguins |> na.omit() 45 | 46 | set.seed(122) 47 | in_train <- sample(1:nrow(penguins), 200) 48 | penguins_train <- penguins[ in_train,] 49 | penguins_test <- penguins[-in_train,] 50 | 51 | rec <- recipe(sex ~ ., data = penguins_train) |> 52 | step_dummy(all_nominal_predictors()) |> 53 | step_normalize(all_numeric_predictors()) 54 | 55 | set.seed(3) 56 | fit <- brulee_logistic_reg(rec, data = penguins_train, epochs = 5) 57 | fit 58 | 59 | predict(fit, penguins_test) 60 | 61 | predict(fit, penguins_test, type = "prob") |> 62 | bind_cols(penguins_test) |> 63 | roc_curve(sex, .pred_female) |> 64 | autoplot() 65 | 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /man/predict.brulee_mlp.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/mlp-predict.R 3 | \name{predict.brulee_mlp} 4 | 
\alias{predict.brulee_mlp} 5 | \title{Predict from a \code{brulee_mlp}} 6 | \usage{ 7 | \method{predict}{brulee_mlp}(object, new_data, type = NULL, epoch = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{object}{A \code{brulee_mlp} object.} 11 | 12 | \item{new_data}{A data frame or matrix of new predictors.} 13 | 14 | \item{type}{A single character. The type of predictions to generate. 15 | Valid options are: 16 | \itemize{ 17 | \item \code{"numeric"} for numeric predictions. 18 | \item \code{"class"} for hard class predictions 19 | \item \code{"prob"} for soft class predictions (i.e., class probabilities) 20 | }} 21 | 22 | \item{epoch}{An integer for the epoch to make predictions. If this value 23 | is larger than the maximum number that was fit, a warning is issued and the 24 | parameters from the last epoch are used. If left \code{NULL}, the epoch 25 | associated with the smallest loss is used.} 26 | 27 | \item{...}{Not used, but required for extensibility.} 28 | } 29 | \value{ 30 | A tibble of predictions. The number of rows in the tibble is guaranteed 31 | to be the same as the number of rows in \code{new_data}. 
32 | } 33 | \description{ 34 | Predict from a \code{brulee_mlp} 35 | } 36 | \examples{ 37 | \donttest{ 38 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "modeldata"))) { 39 | # regression example: 40 | 41 | data(ames, package = "modeldata") 42 | 43 | ames$Sale_Price <- log10(ames$Sale_Price) 44 | 45 | set.seed(1) 46 | in_train <- sample(1:nrow(ames), 2000) 47 | ames_train <- ames[ in_train,] 48 | ames_test <- ames[-in_train,] 49 | 50 | # Using recipe 51 | library(recipes) 52 | 53 | ames_rec <- 54 | recipe(Sale_Price ~ Longitude + Latitude, data = ames_train) |> 55 | step_normalize(all_numeric_predictors()) 56 | 57 | set.seed(2) 58 | fit <- brulee_mlp(ames_rec, data = ames_train, epochs = 50, batch_size = 32) 59 | 60 | predict(fit, ames_test) 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /man/predict.brulee_multinomial_reg.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/multinomial_reg-predict.R 3 | \name{predict.brulee_multinomial_reg} 4 | \alias{predict.brulee_multinomial_reg} 5 | \title{Predict from a \code{brulee_multinomial_reg}} 6 | \usage{ 7 | \method{predict}{brulee_multinomial_reg}(object, new_data, type = NULL, epoch = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{object}{A \code{brulee_multinomial_reg} object.} 11 | 12 | \item{new_data}{A data frame or matrix of new predictors.} 13 | 14 | \item{type}{A single character. The type of predictions to generate. 15 | Valid options are: 16 | \itemize{ 17 | \item \code{"class"} for hard class predictions 18 | \item \code{"prob"} for soft class predictions (i.e., class probabilities) 19 | }} 20 | 21 | \item{epoch}{An integer for the epoch to make predictions. If this value 22 | is larger than the maximum number that was fit, a warning is issued and the 23 | parameters from the last epoch are used. 
If left \code{NULL}, the epoch 24 | associated with the smallest loss is used.} 25 | 26 | \item{...}{Not used, but required for extensibility.} 27 | } 28 | \value{ 29 | A tibble of predictions. The number of rows in the tibble is guaranteed 30 | to be the same as the number of rows in \code{new_data}. 31 | } 32 | \description{ 33 | Predict from a \code{brulee_multinomial_reg} 34 | } 35 | \examples{ 36 | \donttest{ 37 | if (torch::torch_is_installed() & rlang::is_installed(c("recipes", "yardstick", "modeldata"))) { 38 | 39 | library(recipes) 40 | library(yardstick) 41 | 42 | data(penguins, package = "modeldata") 43 | 44 | penguins <- penguins |> na.omit() 45 | 46 | set.seed(122) 47 | in_train <- sample(1:nrow(penguins), 200) 48 | penguins_train <- penguins[ in_train,] 49 | penguins_test <- penguins[-in_train,] 50 | 51 | rec <- recipe(island ~ ., data = penguins_train) |> 52 | step_dummy(species, sex) |> 53 | step_normalize(all_numeric_predictors()) 54 | 55 | set.seed(3) 56 | fit <- brulee_multinomial_reg(rec, data = penguins_train, epochs = 5) 57 | fit 58 | 59 | predict(fit, penguins_test) |> 60 | bind_cols(penguins_test) |> 61 | conf_mat(island, .pred_class) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/aaa.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{autoplot} 7 | \alias{tunable} 8 | \alias{coef} 9 | \title{Objects exported from other packages} 10 | \keyword{internal} 11 | \description{ 12 | These objects are imported from other packages. Follow the links 13 | below to see their documentation. 
14 | 15 | \describe{ 16 | \item{generics}{\code{\link[generics]{tunable}}} 17 | 18 | \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} 19 | 20 | \item{stats}{\code{\link[stats]{coef}}} 21 | }} 22 | 23 | -------------------------------------------------------------------------------- /man/schedule_decay_time.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/schedulers.R 3 | \name{schedule_decay_time} 4 | \alias{schedule_decay_time} 5 | \alias{schedule_decay_expo} 6 | \alias{schedule_step} 7 | \alias{schedule_cyclic} 8 | \alias{set_learn_rate} 9 | \title{Change the learning rate over time} 10 | \usage{ 11 | schedule_decay_time(epoch, initial = 0.1, decay = 1) 12 | 13 | schedule_decay_expo(epoch, initial = 0.1, decay = 1) 14 | 15 | schedule_step(epoch, initial = 0.1, reduction = 1/2, steps = 5) 16 | 17 | schedule_cyclic(epoch, initial = 0.001, largest = 0.1, step_size = 5) 18 | 19 | set_learn_rate(epoch, learn_rate, type = "none", ...) 20 | } 21 | \arguments{ 22 | \item{epoch}{An integer for the number of training epochs (zero being the 23 | initial value),} 24 | 25 | \item{initial}{A positive numeric value for the starting learning rate.} 26 | 27 | \item{decay}{A positive numeric constant for decreasing the rate (see 28 | Details below).} 29 | 30 | \item{reduction}{A positive numeric constant stating the proportional decrease 31 | in the learning rate occurring at every \code{steps} epochs.} 32 | 33 | \item{steps}{The number of epochs before the learning rate changes.} 34 | 35 | \item{largest}{The maximum learning rate in the cycle.} 36 | 37 | \item{step_size}{The half-length of a cycle.} 38 | 39 | \item{learn_rate}{A constant learning rate (when no scheduler is used),} 40 | 41 | \item{type}{A single character value for the type of scheduler. 
Possible 42 | values are: "decay_time", "decay_expo", "none", "cyclic", and "step".} 43 | 44 | \item{...}{Arguments to pass to the individual scheduler functions (e.g. 45 | \code{reduction}).} 46 | } 47 | \value{ 48 | A numeric value for the updated learning rate. 49 | } 50 | \description{ 51 | Learning rate schedulers alter the learning rate to adjust as training 52 | proceeds. In most cases, the learning rate decreases as epochs increase. 53 | The \verb{schedule_*()} functions are individual schedulers and 54 | \code{\link[=set_learn_rate]{set_learn_rate()}} is a general interface. 55 | } 56 | \details{ 57 | The details for how the schedulers change the rates: 58 | \itemize{ 59 | \item \code{schedule_decay_time()}: \eqn{rate(epoch) = initial/(1 + decay \times epoch)} 60 | \item \code{schedule_decay_expo()}: \eqn{rate(epoch) = initial\exp(-decay \times epoch)} 61 | \item \code{schedule_step()}: \eqn{rate(epoch) = initial \times reduction^{floor(epoch / steps)}} 62 | \item \code{schedule_cyclic()}: \eqn{cycle = floor( 1 + (epoch / 2 / step size) )}, 63 | \eqn{x = abs( ( epoch / step size ) - ( 2 * cycle) + 1 )}, and 64 | \eqn{rate(epoch) = initial + ( largest - initial ) * \max( 0, 1 - x)} 65 | } 66 | } 67 | \examples{ 68 | if (rlang::is_installed("purrr")) { 69 | library(ggplot2) 70 | library(dplyr) 71 | library(purrr) 72 | 73 | iters <- 0:50 74 | 75 | bind_rows( 76 | tibble(epoch = iters, rate = map_dbl(iters, schedule_decay_time), type = "decay_time"), 77 | tibble(epoch = iters, rate = map_dbl(iters, schedule_decay_expo), type = "decay_expo"), 78 | tibble(epoch = iters, rate = map_dbl(iters, schedule_step), type = "step"), 79 | tibble(epoch = iters, rate = map_dbl(iters, schedule_cyclic), type = "cyclic") 80 | ) |> 81 | ggplot(aes(epoch, rate)) + 82 | geom_line() + 83 | facet_wrap(~ type) 84 | 85 | } 86 | 87 | } 88 | \seealso{ 89 | \code{\link[=brulee_mlp]{brulee_mlp()}} 90 | } 91 | -------------------------------------------------------------------------------- 
/tests/spelling.R: -------------------------------------------------------------------------------- 1 | if(requireNamespace('spelling', quietly = TRUE)) 2 | spelling::spell_check_test(vignettes = TRUE, error = FALSE, 3 | skip_on_cran = TRUE) 4 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(brulee) 3 | library(tibble) 4 | 5 | RNGkind("Mersenne-Twister") 6 | 7 | if (torch::torch_is_installed()) { 8 | test_check("brulee") 9 | } 10 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/checks.md: -------------------------------------------------------------------------------- 1 | # checking double vectors 2 | 3 | Code 4 | check_number_decimal_vec(letters) 5 | Condition 6 | Error in `check_number_decimal_vec()`: 7 | ! `letters` should be a double vector. 8 | 9 | --- 10 | 11 | Code 12 | check_number_decimal_vec(variable) 13 | Condition 14 | Error in `check_number_decimal_vec()`: 15 | ! `variable` should not contain missing values. 16 | 17 | --- 18 | 19 | Code 20 | check_number_decimal_vec(variable) 21 | Condition 22 | Error in `check_number_decimal_vec()`: 23 | ! `variable` should be a double vector. 24 | 25 | # checking whole number vectors 26 | 27 | Code 28 | check_number_whole_vec(variable) 29 | Condition 30 | Error: 31 | ! `variable` must be a whole number, not the number 0.5. 32 | 33 | --- 34 | 35 | Code 36 | check_number_whole_vec(variable) 37 | Condition 38 | Error: 39 | ! `variable` must be a whole number, not an integer `NA`. 
40 | 41 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/class-weight.md: -------------------------------------------------------------------------------- 1 | # setting class weights 2 | 3 | Code 4 | brulee:::check_class_weights("a", lvls, cls_xtab, "fabulous") 5 | Condition 6 | Error in `brulee:::check_class_weights()`: 7 | ! fabulous() expected 'class_weights' to a numeric vector 8 | 9 | --- 10 | 11 | Code 12 | brulee:::check_class_weights(c(1, 6.25), lvls, cls_xtab, "fabulous") 13 | Condition 14 | Error in `brulee:::check_class_weights()`: 15 | ! There were 2 class weights given but 3 were expected. 16 | 17 | --- 18 | 19 | Code 20 | brulee:::check_class_weights(bad_wts, lvls, cls_xtab, "fabulous") 21 | Condition 22 | Error in `brulee:::check_class_weights()`: 23 | ! Names for class weights should be: 'one', 'two', 'three' 24 | 25 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/mlp-regression.md: -------------------------------------------------------------------------------- 1 | # bad args 2 | 3 | Code 4 | brulee_mlp(reg_x_mat, reg_y, epochs = NA) 5 | Condition 6 | Error in `check_integer()`: 7 | ! brulee_mlp() expected 'epochs' to be integer. 8 | 9 | --- 10 | 11 | Code 12 | brulee_mlp(reg_x_mat, reg_y, epochs = 1:2) 13 | Condition 14 | Error in `check_integer()`: 15 | ! brulee_mlp() expected 'epochs' to be a single integer. 16 | 17 | --- 18 | 19 | Code 20 | brulee_mlp(reg_x_mat, reg_y, epochs = 0L) 21 | Condition 22 | Error in `check_integer()`: 23 | ! brulee_mlp() expected 'epochs' to be an integer on [1, Inf]. 24 | 25 | --- 26 | 27 | Code 28 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, hidden_units = NA) 29 | Condition 30 | Error in `check_integer()`: 31 | ! brulee_mlp() expected 'hidden_units' to be integer. 
32 | 33 | --- 34 | 35 | Code 36 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, hidden_units = -1L) 37 | Condition 38 | Error in `check_integer()`: 39 | ! brulee_mlp() expected 'hidden_units' to be an integer on [1, Inf]. 40 | 41 | --- 42 | 43 | Code 44 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, activation = NA) 45 | Condition 46 | Error in `brulee_mlp_bridge()`: 47 | ! `activation` should be one of: celu, elu, gelu, hardshrink, hardsigmoid, hardtanh, leaky_relu, linear, log_sigmoid, relu, relu6, rrelu, selu, sigmoid, silu, softplus, softshrink, softsign, tanh, and tanhshrink, not NA. 48 | 49 | --- 50 | 51 | Code 52 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = NA) 53 | Condition 54 | Error in `check_double()`: 55 | ! brulee_mlp() expected 'penalty' to be a double. 56 | 57 | --- 58 | 59 | Code 60 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = runif(2)) 61 | Condition 62 | Error in `check_double()`: 63 | ! brulee_mlp() expected 'penalty' to be a single double. 64 | 65 | --- 66 | 67 | Code 68 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = -1.1) 69 | Condition 70 | Error in `check_double()`: 71 | ! brulee_mlp() expected 'penalty' to be a double on [0, Inf]. 72 | 73 | --- 74 | 75 | Code 76 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = NA) 77 | Condition 78 | Error in `check_double()`: 79 | ! brulee_mlp() expected 'dropout' to be a double. 80 | 81 | --- 82 | 83 | Code 84 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = runif(2)) 85 | Condition 86 | Error in `check_double()`: 87 | ! brulee_mlp() expected 'dropout' to be a single double. 88 | 89 | --- 90 | 91 | Code 92 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = -1.1) 93 | Condition 94 | Error in `check_double()`: 95 | ! brulee_mlp() expected 'dropout' to be a double on [0, 1). 96 | 97 | --- 98 | 99 | Code 100 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = 1) 101 | Condition 102 | Error in `check_double()`: 103 | ! brulee_mlp() expected 'dropout' to be a double on [0, 1). 
104 | 105 | --- 106 | 107 | Code 108 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = NA) 109 | Condition 110 | Error in `check_double()`: 111 | ! brulee_mlp() expected 'validation' to be a double. 112 | 113 | --- 114 | 115 | Code 116 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = runif(2)) 117 | Condition 118 | Error in `check_double()`: 119 | ! brulee_mlp() expected 'validation' to be a single double. 120 | 121 | --- 122 | 123 | Code 124 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = -1.1) 125 | Condition 126 | Error in `check_double()`: 127 | ! brulee_mlp() expected 'validation' to be a double on [0, 1). 128 | 129 | --- 130 | 131 | Code 132 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = 1) 133 | Condition 134 | Error in `check_double()`: 135 | ! brulee_mlp() expected 'validation' to be a double on [0, 1). 136 | 137 | --- 138 | 139 | Code 140 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = NA) 141 | Condition 142 | Error in `check_double()`: 143 | ! brulee_mlp() expected 'learn_rate' to be a double. 144 | 145 | --- 146 | 147 | Code 148 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = runif(2)) 149 | Condition 150 | Error in `check_double()`: 151 | ! brulee_mlp() expected 'learn_rate' to be a single double. 152 | 153 | --- 154 | 155 | Code 156 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = -1.1) 157 | Condition 158 | Error in `check_double()`: 159 | ! brulee_mlp() expected 'learn_rate' to be a double on (0, Inf]. 160 | 161 | --- 162 | 163 | Code 164 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, verbose = 2) 165 | Condition 166 | Error in `check_logical()`: 167 | ! brulee_mlp() expected 'verbose' to be logical. 168 | 169 | --- 170 | 171 | Code 172 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, verbose = rep(TRUE, 10)) 173 | Condition 174 | Error in `check_logical()`: 175 | ! brulee_mlp() expected 'verbose' to be a single logical. 
176 | 177 | --- 178 | 179 | Code 180 | brulee:::new_brulee_mlp(model_obj = bad_models$model_obj, estimates = bad_models$ 181 | estimates, best_epoch = bad_models$best_epoch, loss = bad_models$loss, dims = bad_models$ 182 | dims, y_stats = bad_models$y_stats, parameters = bad_models$parameters, 183 | blueprint = bad_models$blueprint) 184 | Condition 185 | Error in `brulee:::new_brulee_mlp()`: 186 | ! 'model_obj' should be a raw vector. 187 | 188 | --- 189 | 190 | Code 191 | brulee:::new_brulee_mlp(model_obj = bad_est$model_obj, estimates = bad_est$ 192 | estimates, best_epoch = bad_est$best_epoch, loss = bad_est$loss, dims = bad_est$ 193 | dims, y_stats = bad_est$y_stats, parameters = bad_est$parameters, blueprint = bad_est$ 194 | blueprint) 195 | Condition 196 | Error in `brulee:::new_brulee_mlp()`: 197 | ! 'parameters' should be a list 198 | 199 | --- 200 | 201 | Code 202 | brulee:::new_brulee_mlp(model_obj = bad_loss$model_obj, estimates = bad_loss$ 203 | estimates, best_epoch = bad_loss$best_epoch, loss = bad_loss$loss, dims = bad_loss$ 204 | dims, y_stats = bad_loss$y_stats, parameters = bad_loss$parameters, 205 | blueprint = bad_loss$blueprint) 206 | Condition 207 | Error in `brulee:::new_brulee_mlp()`: 208 | ! 'loss' should be a numeric vector 209 | 210 | --- 211 | 212 | Code 213 | brulee:::new_brulee_mlp(model_obj = bad_dims$model_obj, estimates = bad_dims$ 214 | estimates, best_epoch = bad_dims$best_epoch, loss = bad_dims$loss, dims = bad_dims$ 215 | dims, y_stats = bad_dims$y_stats, parameters = bad_dims$parameters, 216 | blueprint = bad_dims$blueprint) 217 | Condition 218 | Error in `brulee:::new_brulee_mlp()`: 219 | ! 
'dims' should be a list 220 | 221 | --- 222 | 223 | Code 224 | brulee:::new_brulee_mlp(model_obj = bad_parameters$model_obj, estimates = bad_parameters$ 225 | estimates, best_epoch = bad_parameters$best_epoch, loss = bad_parameters$loss, 226 | dims = bad_parameters$dims, y_stats = bad_parameters$y_stats, parameters = bad_parameters$ 227 | parameters, blueprint = bad_parameters$blueprint) 228 | Condition 229 | Error in `brulee:::new_brulee_mlp()`: 230 | ! 'dims' should be a list 231 | 232 | --- 233 | 234 | Code 235 | brulee:::new_brulee_mlp(model_obj = bad_blueprint$model_obj, estimates = bad_blueprint$ 236 | estimates, best_epoch = bad_blueprint$best_epoch, loss = bad_blueprint$loss, 237 | dims = bad_blueprint$dims, y_stats = bad_blueprint$y_stats, parameters = bad_blueprint$ 238 | parameters, blueprint = bad_blueprint$blueprint) 239 | Condition 240 | Error in `brulee:::new_brulee_mlp()`: 241 | ! 'blueprint' should be a hardhat blueprint 242 | 243 | # variable hidden_units length 244 | 245 | Code 246 | model <- brulee_mlp(x, y, hidden_units = c(2, 3, 4), epochs = 1, activation = c( 247 | "relu", "tanh")) 248 | Condition 249 | Error in `brulee_mlp_bridge()`: 250 | ! 'activation' must be a single value or a vector with the same length as 'hidden_units' 251 | 252 | --- 253 | 254 | Code 255 | model <- brulee_mlp(x, y, hidden_units = c(1), epochs = 1, activation = c( 256 | "relu", "tanh")) 257 | Condition 258 | Error in `brulee_mlp_bridge()`: 259 | ! 'activation' must be a single value or a vector with the same length as 'hidden_units' 260 | 261 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/schedulers.md: -------------------------------------------------------------------------------- 1 | # scheduling functions 2 | 3 | Argument 'initial' should be a single positive value. 4 | 5 | --- 6 | 7 | Argument 'decay' should be a single positive value. 
8 | 9 | --- 10 | 11 | Argument 'initial' should be a single positive value. 12 | 13 | --- 14 | 15 | Argument 'decay' should be a single positive value. 16 | 17 | --- 18 | 19 | Argument 'initial' should be a single positive value. 20 | 21 | --- 22 | 23 | Argument 'reduction' should be a single positive value. 24 | 25 | --- 26 | 27 | Argument 'steps' should be a single positive value. 28 | 29 | --- 30 | 31 | Argument 'step_size' should be a single positive value. 32 | 33 | --- 34 | 35 | Argument 'largest' should be a single positive value. 36 | 37 | --- 38 | 39 | Argument 'initial' should be a single positive value. 40 | 41 | --- 42 | 43 | `type` must be one of "decay_time", "decay_expo", "none", "step", or "cyclic", not "random". 44 | 45 | -------------------------------------------------------------------------------- /tests/testthat/test-checks.R: -------------------------------------------------------------------------------- 1 | test_that("checking double vectors", { 2 | variable <- seq(0, 1, length = 3) 3 | expect_silent(check_number_decimal_vec(variable)) 4 | expect_silent(check_number_decimal_vec(variable[1])) 5 | 6 | expect_snapshot(check_number_decimal_vec(letters), error = TRUE) 7 | 8 | variable <- NA_real_ 9 | expect_snapshot(check_number_decimal_vec(variable), error = TRUE) 10 | expect_silent(check_number_decimal_vec(variable, allow_na = TRUE)) 11 | 12 | variable <- 1L 13 | expect_snapshot(check_number_decimal_vec(variable), error = TRUE) 14 | 15 | }) 16 | 17 | test_that("checking whole number vectors", { 18 | variable <- 1:2 19 | expect_silent(check_number_whole_vec(variable)) 20 | expect_silent(check_number_whole_vec(variable[1])) 21 | 22 | variable <- seq(0, 1, length.out = 3) 23 | expect_snapshot(check_number_whole_vec(variable), error = TRUE) 24 | 25 | variable <- NA_integer_ 26 | expect_snapshot(check_number_whole_vec(variable), error = TRUE) 27 | expect_silent(check_number_whole_vec(variable, allow_na = TRUE)) 28 | 29 | }) 30 | 
-------------------------------------------------------------------------------- /tests/testthat/test-class-weight.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("setting class weights", { 3 | skip_if_not(torch::torch_is_installed()) 4 | skip_if_not_installed("modeldata") 5 | 6 | suppressPackageStartupMessages(library(dplyr)) 7 | 8 | # ------------------------------------------------------------------------------ 9 | 10 | set.seed(585) 11 | mnl_tr <- 12 | modeldata::sim_multinomial( 13 | 1000, 14 | ~ -0.5 + 0.6 * abs(A), 15 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 16 | ~ -0.6 * A + 0.50 * B - A * B) 17 | 18 | lvls <- levels(mnl_tr$class) 19 | num_class <- length(lvls) 20 | cls_xtab <- table(mnl_tr$class) 21 | min_class <- names(sort(cls_xtab))[1] 22 | 23 | cls_wts <- rep(1, num_class) 24 | names(cls_wts) <- lvls 25 | cls_wts[names(cls_wts) == min_class] <- 10 26 | 27 | bad_wts <- cls_wts 28 | names(bad_wts) <- letters[1:num_class] 29 | 30 | # ------------------------------------------------------------------------------ 31 | 32 | expect_equal( 33 | brulee:::check_class_weights(1.0, lvls, cls_xtab, "fabulous") |> 34 | as.numeric(), 35 | rep(1, num_class) 36 | ) 37 | 38 | expect_s3_class( 39 | brulee:::check_class_weights(1.0, lvls, cls_xtab, "fabulous"), 40 | "torch_tensor" 41 | ) 42 | 43 | expect_equal( 44 | brulee:::check_class_weights(NULL, lvls, cls_xtab, "fabulous") |> 45 | as.numeric(), 46 | rep(1, num_class) 47 | ) 48 | 49 | expect_equal( 50 | brulee:::check_class_weights(6.25, lvls, cls_xtab, "fabulous") |> 51 | as.numeric(), 52 | c(1, 6.25, 1) 53 | ) 54 | 55 | expect_equal( 56 | brulee:::check_class_weights(c(1, 6.25, 1), lvls, cls_xtab, "fabulous") |> 57 | as.numeric(), 58 | c(1, 6.25, 1) 59 | ) 60 | 61 | expect_null( 62 | brulee:::check_class_weights(1, character(0), cls_xtab, "fabulous") 63 | ) 64 | 65 | expect_snapshot( 66 | brulee:::check_class_weights("a", lvls, cls_xtab, "fabulous"), 67 
| error = TRUE 68 | ) 69 | 70 | expect_snapshot( 71 | brulee:::check_class_weights(c(1, 6.25), lvls, cls_xtab, "fabulous"), 72 | error = TRUE 73 | ) 74 | 75 | expect_snapshot( 76 | brulee:::check_class_weights(bad_wts, lvls, cls_xtab, "fabulous"), 77 | error = TRUE 78 | ) 79 | 80 | }) 81 | -------------------------------------------------------------------------------- /tests/testthat/test-linear_reg-fit.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("basic linear regression LBFGS", { 3 | skip_if_not(torch::torch_is_installed()) 4 | 5 | skip_if_not_installed("yardstick") 6 | 7 | suppressPackageStartupMessages(library(dplyr)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | set.seed(1) 12 | lin_tr <- tibble::tibble( 13 | x1 = runif(1000), 14 | x2 = runif(1000), 15 | outcome = 3 + 2 * x1 + 3 * x2 16 | ) 17 | lin_te <- tibble::tibble( 18 | x1 = runif(1000), 19 | x2 = runif(1000), 20 | outcome = 3 + 2 * x1 + 3 * x2 21 | ) 22 | 23 | # ------------------------------------------------------------------------------ 24 | 25 | lm_fit <- lm(outcome ~ ., data = lin_tr) 26 | 27 | expect_error({ 28 | set.seed(392) 29 | lin_fit_lbfgs <- 30 | brulee_linear_reg(outcome ~ ., lin_tr, penlaty = 0)}, 31 | regex = NA) 32 | 33 | expect_equal( 34 | unname(coef(lm_fit)), 35 | unname(coef(lin_fit_lbfgs)), 36 | tolerance = .1 37 | ) 38 | 39 | expect_error( 40 | lin_pred_lbfgs <- 41 | predict(lin_fit_lbfgs, lin_te) |> 42 | bind_cols(lin_te), 43 | regex = NA) 44 | 45 | exp_str <- 46 | structure( 47 | list( 48 | .pred = numeric(0), 49 | x1 = numeric(0), 50 | x2 = numeric(0), 51 | outcome = numeric(0)), 52 | row.names = integer(0), 53 | class = c("tbl_df", "tbl", "data.frame")) 54 | 55 | expect_equal(lin_pred_lbfgs[0,], exp_str) 56 | expect_equal(nrow(lin_pred_lbfgs), nrow(lin_te)) 57 | 58 | # Did it learn anything? 
59 | lin_brier_lbfgs <- 60 | lin_pred_lbfgs |> 61 | yardstick::rmse(outcome, .pred) 62 | 63 | set.seed(382) 64 | shuffled <- 65 | lin_pred_lbfgs |> 66 | mutate(outcome = sample(outcome)) |> 67 | yardstick::rmse(outcome, .pred) 68 | 69 | expect_true(lin_brier_lbfgs$.estimate < shuffled$.estimate ) 70 | }) 71 | 72 | test_that("basic Linear regression sgd", { 73 | skip_if_not(torch::torch_is_installed()) 74 | 75 | skip_if_not_installed("yardstick") 76 | 77 | suppressPackageStartupMessages(library(dplyr)) 78 | 79 | # ------------------------------------------------------------------------------ 80 | 81 | set.seed(1) 82 | lin_tr <- tibble::tibble( 83 | x1 = runif(1000), 84 | x2 = runif(1000), 85 | outcome = 3 + 2 * x1 + 3 * x2 86 | ) 87 | lin_te <- tibble::tibble( 88 | x1 = runif(1000), 89 | x2 = runif(1000), 90 | outcome = 3 + 2 * x1 + 3 * x2 91 | ) 92 | 93 | # ------------------------------------------------------------------------------ 94 | 95 | lm_fit <- lm(outcome ~ ., data = lin_tr) 96 | 97 | expect_error({ 98 | set.seed(392) 99 | lin_fit_sgd <- 100 | brulee_linear_reg( 101 | outcome ~ ., 102 | lin_tr, 103 | penlaty = 0, 104 | epochs = 500, 105 | batch_size = 2^5, 106 | learn_rate = 0.1, 107 | optimizer = "SGD", 108 | stop_iter = 20 109 | )}, 110 | regex = NA) 111 | 112 | expect_equal( 113 | unname(coef(lm_fit)), 114 | unname(coef(lin_fit_sgd)), 115 | tolerance = .1 116 | ) 117 | 118 | expect_error( 119 | lin_pred_sgd <- 120 | predict(lin_fit_sgd, lin_te) |> 121 | bind_cols(lin_te), 122 | regex = NA) 123 | 124 | exp_str <- 125 | structure( 126 | list( 127 | .pred = numeric(0), 128 | x1 = numeric(0), 129 | x2 = numeric(0), 130 | outcome = numeric(0)), 131 | row.names = integer(0), 132 | class = c("tbl_df", "tbl", "data.frame")) 133 | 134 | expect_equal(lin_pred_sgd[0,], exp_str) 135 | expect_equal(nrow(lin_pred_sgd), nrow(lin_te)) 136 | 137 | # Did it learn anything? 
138 | lin_brier_sgd <- 139 | lin_pred_sgd |> 140 | yardstick::rmse(outcome, .pred) 141 | 142 | set.seed(382) 143 | shuffled <- 144 | lin_pred_sgd |> 145 | mutate(outcome = sample(outcome)) |> 146 | yardstick::rmse(outcome, .pred) 147 | 148 | expect_true(lin_brier_sgd$.estimate < shuffled$.estimate) 149 | }) 150 | -------------------------------------------------------------------------------- /tests/testthat/test-logistic_reg-fit.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("basic logistic regression LBFGS", { 3 | skip_if_not(torch::torch_is_installed()) 4 | skip_if_not_installed("modeldata") 5 | skip_if_not_installed("yardstick") 6 | 7 | suppressPackageStartupMessages(library(dplyr)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | set.seed(585) 12 | bin_tr <- modeldata::sim_logistic(5000, ~ -1 - 3 * A + 5 * B) 13 | bin_te <- modeldata::sim_logistic(1000, ~ -1 - 3 * A + 5 * B) 14 | num_class <- length(levels(bin_tr$class)) 15 | 16 | # ------------------------------------------------------------------------------ 17 | 18 | glm_fit <- glm(class ~ ., data = bin_tr, family = "binomial") 19 | 20 | expect_error({ 21 | set.seed(392) 22 | bin_fit_lbfgs <- 23 | brulee_logistic_reg(class ~ ., bin_tr, penlaty = 0, epochs = 1)}, 24 | regex = NA) 25 | 26 | expect_equal( 27 | unname(coef(glm_fit)), 28 | unname(coef(bin_fit_lbfgs)), 29 | tolerance = 1 30 | ) 31 | 32 | expect_error( 33 | bin_pred_lbfgs <- 34 | predict(bin_fit_lbfgs,bin_te) |> 35 | bind_cols(predict(bin_fit_lbfgs,bin_te, type = "prob")) |> 36 | bind_cols(bin_te), 37 | regex = NA) 38 | 39 | fact_str <- structure(integer(0), levels = c("one", "two"), class = "factor") 40 | exp_str <- 41 | structure( 42 | list(.pred_class = 43 | fact_str, 44 | .pred_one = numeric(0), 45 | .pred_two = numeric(0), 46 | A = numeric(0), 47 | B = numeric(0), 48 | class = fact_str), 49 | row.names = integer(0), 50 | class = 
c("tbl_df", "tbl", "data.frame")) 51 | 52 | expect_equal(bin_pred_lbfgs[0,], exp_str) 53 | expect_equal(nrow(bin_pred_lbfgs), nrow(bin_te)) 54 | 55 | # Did it learn anything? 56 | bin_brier_lbfgs <- 57 | bin_pred_lbfgs |> 58 | yardstick::brier_class(class, .pred_one) 59 | 60 | expect_true(bin_brier_lbfgs$.estimate < (1 - 1/num_class)^2) 61 | }) 62 | 63 | test_that("basic logistic regression SGD", { 64 | skip_if_not(torch::torch_is_installed()) 65 | skip_if_not_installed("modeldata") 66 | skip_if_not_installed("yardstick") 67 | 68 | suppressPackageStartupMessages(library(dplyr)) 69 | 70 | # ------------------------------------------------------------------------------ 71 | 72 | set.seed(585) 73 | bin_tr <- modeldata::sim_logistic(5000, ~ -1 - 3 * A + 5 * B) 74 | bin_te <- modeldata::sim_logistic(1000, ~ -1 - 3 * A + 5 * B) 75 | num_class <- length(levels(bin_tr$class)) 76 | 77 | # ------------------------------------------------------------------------------ 78 | 79 | expect_error({ 80 | set.seed(392) 81 | bin_fit_sgd <- 82 | brulee_logistic_reg(class ~ ., 83 | bin_tr, 84 | epochs = 500, 85 | penalty = 0, 86 | dropout = .1, 87 | optimize = "SGD", 88 | batch_size = 2^5, 89 | learn_rate = 0.1)}, 90 | regex = NA) 91 | 92 | glm_fit <- glm(class ~ ., data = bin_tr, family = "binomial") 93 | 94 | expect_equal( 95 | unname(coef(glm_fit)), 96 | unname(coef(bin_fit_sgd)), 97 | tolerance = .5 98 | ) 99 | 100 | expect_error( 101 | bin_pred_sgd <- 102 | predict(bin_fit_sgd,bin_te) |> 103 | bind_cols(predict(bin_fit_sgd,bin_te, type = "prob")) |> 104 | bind_cols(bin_te), 105 | regex = NA) 106 | 107 | # Did it learn anything? 
108 | bin_brier_sgd <- 109 | bin_pred_sgd |> 110 | yardstick::brier_class(class, .pred_one) 111 | 112 | expect_true(bin_brier_sgd$.estimate < (1 - 1/num_class)^2) 113 | }) 114 | 115 | test_that("coef works when recipes are used", { 116 | skip_if_not(torch::torch_is_installed()) 117 | skip_if_not_installed("modeldata") 118 | skip_if_not_installed("recipes") 119 | skip_if(packageVersion("rlang") < "1.0.0") 120 | skip_on_os(c("windows", "linux", "solaris")) 121 | 122 | data("lending_club", package = "modeldata") 123 | lending_club <- head(lending_club, 1000) 124 | 125 | rec <- 126 | recipes::recipe(Class ~ revol_util + open_il_24m + emp_length, 127 | data = lending_club) |> 128 | recipes::step_dummy(emp_length, one_hot = TRUE) |> 129 | recipes::step_normalize(recipes::all_predictors()) 130 | 131 | fit_rec <- brulee_logistic_reg(rec, lending_club, epochs = 10L) 132 | 133 | coefs <- coef(fit_rec) 134 | expect_true(all(is.numeric(coefs))) 135 | expect_identical( 136 | names(coefs), 137 | c( 138 | "(Intercept)", "revol_util", "open_il_24m", 139 | paste0("emp_length_", levels(lending_club$emp_length)) 140 | ) 141 | ) 142 | }) 143 | 144 | 145 | # ------------------------------------------------------------------------------ 146 | 147 | test_that("logistic regression class weights", { 148 | skip_if_not(torch::torch_is_installed()) 149 | skip_if_not_installed("modeldata") 150 | skip_if_not_installed("yardstick") 151 | 152 | suppressPackageStartupMessages(library(dplyr)) 153 | 154 | # ------------------------------------------------------------------------------ 155 | 156 | set.seed(585) 157 | bin_tr <- modeldata::sim_logistic(5000, ~ -5 - 3 * A + 5 * B) 158 | bin_te <- modeldata::sim_logistic(1000, ~ -5 - 3 * A + 5 * B) 159 | num_class <- length(levels(bin_tr$class)) 160 | 161 | num_class <- length(levels(bin_tr$class)) 162 | cls_xtab <- table(bin_tr$class) 163 | min_class <- names(sort(cls_xtab))[1] 164 | cls_wts <- rep(1, num_class) 165 | names(cls_wts) <- 
levels(bin_tr$class) 166 | cls_wts[names(cls_wts) == min_class] <- 10 167 | 168 | # ------------------------------------------------------------------------------ 169 | 170 | expect_error({ 171 | set.seed(392) 172 | bin_fit_lbfgs_wts <- 173 | brulee_logistic_reg(class ~ ., 174 | bin_tr, 175 | epochs = 30, 176 | mixture = 0.5, 177 | rate_schedule = "decay_time", 178 | class_weights = cls_wts, 179 | learn_rate = 0.1)}, 180 | regex = NA) 181 | 182 | expect_error( 183 | bin_pred_lbfgs_wts <- 184 | predict(bin_fit_lbfgs_wts,bin_te) |> 185 | bind_cols(predict(bin_fit_lbfgs_wts,bin_te, type = "prob")) |> 186 | bind_cols(bin_te), 187 | regex = NA) 188 | 189 | ### matched unweighted model 190 | 191 | expect_error({ 192 | set.seed(392) 193 | bin_fit_lbfgs_unwt <- 194 | brulee_logistic_reg(class ~ ., 195 | bin_tr, 196 | epochs = 30, 197 | mixture = 0.5, 198 | rate_schedule = "decay_time", 199 | learn_rate = 0.1)}, 200 | regex = NA) 201 | 202 | expect_error( 203 | bin_pred_lbfgs_unwt <- 204 | predict(bin_fit_lbfgs_unwt,bin_te) |> 205 | bind_cols(predict(bin_fit_lbfgs_unwt,bin_te, type = "prob")) |> 206 | bind_cols(bin_te), 207 | regex = NA) 208 | 209 | # did weighting predict the majority class more often? 
210 | expect_true( 211 | sum(bin_pred_lbfgs_wts$.pred_class == min_class) > 212 | sum(bin_pred_lbfgs_unwt$.pred_class == min_class) 213 | ) 214 | 215 | }) 216 | 217 | 218 | -------------------------------------------------------------------------------- /tests/testthat/test-mlp-activations.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("activation functions", { 3 | skip_if(!torch::torch_is_installed()) 4 | skip_if_not_installed("modeldata") 5 | 6 | # ------------------------------------------------------------------------------ 7 | 8 | set.seed(1) 9 | df <- modeldata::sim_regression(500) 10 | 11 | acts <- brulee_activations() 12 | acts <- acts[acts != "linear"] 13 | 14 | for (i in acts) { 15 | expect_error({ 16 | set.seed(2) 17 | model <- brulee_mlp(outcome ~ ., data = df[1:400,], 18 | activation = i, 19 | learn_rate = 0.05, 20 | hidden_units = 10L) 21 | 22 | }, 23 | regex = NA 24 | ) 25 | 26 | r_sq <- cor(predict(model, df[401:500, -1])$.pred, df$outcome[401:500])^2 27 | 28 | # These do very poorly on this problems 29 | pass <- c("tanhshrink") 30 | 31 | if (!(i %in% pass)) { 32 | expect_true(r_sq > 0.1) 33 | } 34 | } 35 | 36 | }) 37 | 38 | -------------------------------------------------------------------------------- /tests/testthat/test-mlp-binary.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("basic binomial mlp LBFGS", { 3 | skip_if_not(torch::torch_is_installed()) 4 | 5 | skip_if_not_installed("modeldata") 6 | skip_if_not_installed("yardstick") 7 | 8 | suppressPackageStartupMessages(library(dplyr)) 9 | suppressPackageStartupMessages(library(recipes)) 10 | 11 | # ------------------------------------------------------------------------------ 12 | 13 | set.seed(585) 14 | bin_tr <- modeldata::sim_classification(5000) 15 | bin_te <- modeldata::sim_classification(1000) 16 | 17 | rec <- 18 | recipe(class ~ ., data = bin_tr) |> 19 | 
step_normalize(all_predictors()) 20 | num_class <- length(levels(bin_tr$class)) 21 | 22 | # ------------------------------------------------------------------------------ 23 | 24 | expect_error({ 25 | set.seed(392) 26 | bin_fit_f_lbfgs <- 27 | brulee_mlp(class ~ ., 28 | bin_tr, 29 | epochs = 200, 30 | hidden_units = 5, 31 | rate_schedule = "cyclic", 32 | learn_rate = 0.1)}, 33 | regex = NA) 34 | 35 | 36 | expect_error({ 37 | set.seed(392) 38 | bin_fit_lbfgs <- 39 | brulee_mlp(rec, 40 | bin_tr, 41 | epochs = 200, 42 | hidden_units = 5, 43 | rate_schedule = "cyclic", 44 | learn_rate = 0.1)}, 45 | regex = NA) 46 | 47 | expect_error( 48 | bin_pred_lbfgs <- 49 | predict(bin_fit_lbfgs, bin_te) |> 50 | bind_cols(predict(bin_fit_lbfgs, bin_te, type = "prob")) |> 51 | bind_cols(bin_te) |> 52 | select(starts_with(".pred"), class), 53 | regex = NA) 54 | 55 | fact_str <- structure(integer(0), levels = c("class_1", "class_2"), class = "factor") 56 | exp_str <- 57 | structure( 58 | list(.pred_class = 59 | fact_str, 60 | .pred_class_1 = numeric(0), 61 | .pred_class_2 = numeric(0), 62 | class = fact_str), 63 | row.names = integer(0), 64 | class = c("tbl_df", "tbl", "data.frame")) 65 | 66 | expect_equal(bin_pred_lbfgs[0,], exp_str) 67 | expect_equal(nrow(bin_pred_lbfgs), nrow(bin_te)) 68 | 69 | # Did it learn anything? 
70 | bin_brier_lbfgs <- 71 | bin_pred_lbfgs |> 72 | yardstick::brier_class(class, .pred_class_1) 73 | 74 | expect_true(bin_brier_lbfgs$.estimate < (1 - 1/num_class)^2) 75 | }) 76 | 77 | 78 | test_that("basic binomial mlp SGD", { 79 | skip_if_not(torch::torch_is_installed()) 80 | 81 | skip_if_not_installed("modeldata") 82 | skip_if_not_installed("yardstick") 83 | 84 | suppressPackageStartupMessages(library(dplyr)) 85 | suppressPackageStartupMessages(library(recipes)) 86 | 87 | # ------------------------------------------------------------------------------ 88 | 89 | set.seed(585) 90 | bin_tr <- modeldata::sim_classification(5000) 91 | bin_te <- modeldata::sim_classification(1000) 92 | 93 | rec <- 94 | recipe(class ~ ., data = bin_tr) |> 95 | step_normalize(all_predictors()) 96 | num_class <- length(levels(bin_tr$class)) 97 | 98 | # ------------------------------------------------------------------------------ 99 | 100 | expect_error({ 101 | set.seed(392) 102 | bin_fit_f_sgd <- 103 | brulee_mlp(class ~ ., 104 | bin_tr, 105 | epochs = 200, 106 | penalty = 0, 107 | dropout = .1, 108 | hidden_units = 5, 109 | optimize = "SGD", 110 | batch_size = 64, 111 | momentum = 0.5, 112 | learn_rate = 0.1)}, 113 | regex = NA) 114 | 115 | 116 | expect_error({ 117 | set.seed(392) 118 | bin_fit_sgd <- 119 | brulee_mlp(rec, 120 | bin_tr, 121 | epochs = 200, 122 | penalty = 0, 123 | dropout = .1, 124 | hidden_units = 5, 125 | optimize = "SGD", 126 | batch_size = 64, 127 | momentum = 0.5, 128 | learn_rate = 0.1)}, 129 | regex = NA) 130 | 131 | expect_error( 132 | bin_pred_sgd <- 133 | predict(bin_fit_sgd, bin_te) |> 134 | bind_cols(predict(bin_fit_sgd, bin_te, type = "prob")) |> 135 | bind_cols(bin_te) |> 136 | select(starts_with(".pred"), class), 137 | regex = NA) 138 | 139 | fact_str <- structure(integer(0), levels = c("class_1", "class_2"), class = "factor") 140 | exp_str <- 141 | structure( 142 | list(.pred_class = 143 | fact_str, 144 | .pred_class_1 = numeric(0), 145 | .pred_class_2 
= numeric(0), 146 | class = fact_str), 147 | row.names = integer(0), 148 | class = c("tbl_df", "tbl", "data.frame")) 149 | 150 | expect_equal(bin_pred_sgd[0,], exp_str) 151 | expect_equal(nrow(bin_pred_sgd), nrow(bin_te)) 152 | 153 | # Did it learn anything? 154 | bin_brier_sgd <- 155 | bin_pred_sgd |> 156 | yardstick::brier_class(class, .pred_class_1) 157 | 158 | expect_true(bin_brier_sgd$.estimate < (1 - 1/num_class)^2) 159 | }) 160 | 161 | 162 | test_that("binomial mlp case weights", { 163 | skip_if_not(torch::torch_is_installed()) 164 | 165 | skip_if_not_installed("modeldata") 166 | skip_if_not_installed("yardstick") 167 | 168 | suppressPackageStartupMessages(library(dplyr)) 169 | suppressPackageStartupMessages(library(recipes)) 170 | 171 | # ------------------------------------------------------------------------------ 172 | 173 | set.seed(585) 174 | bin_tr <- modeldata::sim_classification(5000, intercept = 1) 175 | bin_te <- modeldata::sim_classification(1000, intercept = 1) 176 | 177 | rec <- 178 | recipe(class ~ ., data = bin_tr) |> 179 | step_normalize(all_predictors()) 180 | num_class <- length(levels(bin_tr$class)) 181 | 182 | # ------------------------------------------------------------------------------ 183 | 184 | expect_error({ 185 | set.seed(392) 186 | weighted <- 187 | brulee_mlp(rec, 188 | bin_tr, 189 | epochs = 200, 190 | hidden_units = 5, 191 | rate_schedule = "cyclic", 192 | class_weights = 10, 193 | learn_rate = 0.1)}, 194 | regex = NA) 195 | 196 | expect_error( 197 | weighted_pred <- 198 | predict(weighted, bin_te) |> 199 | bind_cols(predict(weighted, bin_te, type = "prob")) |> 200 | bind_cols(bin_te) |> 201 | select(starts_with(".pred"), class), 202 | regex = NA) 203 | 204 | 205 | expect_error({ 206 | set.seed(392) 207 | unweighted <- 208 | brulee_mlp(rec, 209 | bin_tr, 210 | epochs = 200, 211 | hidden_units = 5, 212 | rate_schedule = "cyclic", 213 | learn_rate = 0.1)}, 214 | regex = NA) 215 | 216 | expect_error( 217 | unweighted_pred <- 
218 | predict(unweighted, bin_te) |> 219 | bind_cols(predict(unweighted, bin_te, type = "prob")) |> 220 | bind_cols(bin_te) |> 221 | select(starts_with(".pred"), class), 222 | regex = NA) 223 | 224 | expect_true( 225 | sum(weighted_pred$.pred_class == "class_2") > 226 | sum(unweighted_pred$.pred_class == "class_2") 227 | ) 228 | }) 229 | 230 | test_that('linear activations', { 231 | # See https://github.com/tidymodels/brulee/issues/68 232 | skip_if(!torch::torch_is_installed()) 233 | skip_if_not_installed("modeldata") 234 | 235 | data(bivariate, package = "modeldata") 236 | set.seed(20) 237 | nn_log_biv <- 238 | try( 239 | brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train, 240 | epochs = 150, hidden_units = 3, activation = "linear"), 241 | silent = TRUE) 242 | expect_s3_class(nn_log_biv, "brulee_mlp") 243 | 244 | }) 245 | -------------------------------------------------------------------------------- /tests/testthat/test-mlp-multinomial.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("basic multinomial mlp LBFGS", { 3 | skip_if_not(torch::torch_is_installed()) 4 | skip_if_not_installed("modeldata") 5 | skip_if_not_installed("yardstick") 6 | 7 | suppressPackageStartupMessages(library(dplyr)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | set.seed(585) 12 | mnl_tr <- 13 | modeldata::sim_multinomial( 14 | 1000, 15 | ~ -0.5 + 0.6 * abs(A), 16 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 17 | ~ -0.6 * A + 0.50 * B - A * B) 18 | mnl_te <- 19 | modeldata::sim_multinomial( 20 | 200, 21 | ~ -0.5 + 0.6 * abs(A), 22 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 23 | ~ -0.6 * A + 0.50 * B - A * B) 24 | num_class <- length(levels(mnl_tr$class)) 25 | 26 | # ------------------------------------------------------------------------------ 27 | 28 | expect_error({ 29 | set.seed(392) 30 | mnl_fit_lbfgs <- 31 | brulee_mlp(class ~ ., 32 | mnl_tr, 33 | epochs = 
10, 34 | hidden_units = 5, 35 | rate_schedule = "cyclic", 36 | learn_rate = 0.1)}, 37 | regex = NA) 38 | 39 | expect_error( 40 | mnl_pred_lbfgs <- 41 | predict(mnl_fit_lbfgs, mnl_te) |> 42 | bind_cols(predict(mnl_fit_lbfgs, mnl_te, type = "prob")) |> 43 | bind_cols(mnl_te), 44 | regex = NA) 45 | 46 | fact_str <- structure(integer(0), levels = c("one", "two", "three"), class = "factor") 47 | exp_str <- 48 | structure( 49 | list(.pred_class = 50 | fact_str, 51 | .pred_one = numeric(0), 52 | .pred_two = numeric(0), 53 | .pred_three = numeric(0), 54 | A = numeric(0), 55 | B = numeric(0), 56 | class = fact_str), 57 | row.names = integer(0), 58 | class = c("tbl_df", "tbl", "data.frame")) 59 | 60 | expect_equal(mnl_pred_lbfgs[0,], exp_str) 61 | expect_equal(nrow(mnl_pred_lbfgs), nrow(mnl_te)) 62 | 63 | # Did it learn anything? 64 | mnl_brier_lbfgs <- 65 | mnl_pred_lbfgs |> 66 | yardstick::brier_class(class, .pred_one, .pred_two, .pred_three) 67 | 68 | expect_true(mnl_brier_lbfgs$.estimate < (1 - 1/num_class)^2) 69 | }) 70 | 71 | test_that("basic multinomial mlp SGD", { 72 | skip_if_not(torch::torch_is_installed()) 73 | skip_if_not_installed("modeldata") 74 | skip_if_not_installed("yardstick") 75 | 76 | suppressPackageStartupMessages(library(dplyr)) 77 | 78 | # ------------------------------------------------------------------------------ 79 | 80 | set.seed(585) 81 | mnl_tr <- 82 | modeldata::sim_multinomial( 83 | 1000, 84 | ~ -0.5 + 0.6 * abs(A), 85 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 86 | ~ -0.6 * A + 0.50 * B - A * B) 87 | mnl_te <- 88 | modeldata::sim_multinomial( 89 | 200, 90 | ~ -0.5 + 0.6 * abs(A), 91 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 92 | ~ -0.6 * A + 0.50 * B - A * B) 93 | num_class <- length(levels(mnl_tr$class)) 94 | 95 | # ------------------------------------------------------------------------------ 96 | 97 | expect_error({ 98 | set.seed(392) 99 | mnl_fit_sgd <- 100 | brulee_mlp(class ~ ., 101 | mnl_tr, 102 | epochs = 200, 103 | 
penalty = 0, 104 | dropout = .1, 105 | hidden_units = 5, 106 | optimize = "SGD", 107 | batch_size = 64, 108 | momentum = 0.5, 109 | learn_rate = 0.1)}, 110 | regex = NA) 111 | 112 | expect_error( 113 | mnl_pred_sgd <- 114 | predict(mnl_fit_sgd, mnl_te) |> 115 | bind_cols(predict(mnl_fit_sgd, mnl_te, type = "prob")) |> 116 | bind_cols(mnl_te), 117 | regex = NA) 118 | 119 | # Did it learn anything? 120 | mnl_brier_sgd <- 121 | mnl_pred_sgd |> 122 | yardstick::brier_class(class, .pred_one, .pred_two, .pred_three) 123 | 124 | expect_true(mnl_brier_sgd$.estimate < (1 - 1/num_class)^2) 125 | }) 126 | 127 | 128 | # ------------------------------------------------------------------------------ 129 | 130 | test_that("multinomial mlp class weights", { 131 | skip_if_not(torch::torch_is_installed()) 132 | skip_if_not_installed("modeldata") 133 | skip_if_not_installed("yardstick") 134 | 135 | suppressPackageStartupMessages(library(dplyr)) 136 | 137 | # ------------------------------------------------------------------------------ 138 | 139 | set.seed(585) 140 | mnl_tr <- 141 | modeldata::sim_multinomial( 142 | 1000, 143 | ~ -0.5 + 0.6 * abs(A), 144 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 145 | ~ -0.6 * A + 0.50 * B - A * B) 146 | mnl_te <- 147 | modeldata::sim_multinomial( 148 | 200, 149 | ~ -0.5 + 0.6 * abs(A), 150 | ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), 151 | ~ -0.6 * A + 0.50 * B - A * B) 152 | 153 | num_class <- length(levels(mnl_tr$class)) 154 | cls_xtab <- table(mnl_tr$class) 155 | min_class <- names(sort(cls_xtab))[1] 156 | cls_wts <- rep(1, num_class) 157 | names(cls_wts) <- levels(mnl_tr$class) 158 | cls_wts[names(cls_wts) == min_class] <- 10 159 | 160 | # ------------------------------------------------------------------------------ 161 | 162 | expect_error({ 163 | set.seed(392) 164 | mnl_fit_lbfgs_wts <- 165 | brulee_mlp(class ~ ., 166 | mnl_tr, 167 | epochs = 30, 168 | hidden_units = 5, 169 | rate_schedule = "decay_time", 170 | class_weights = 
cls_wts, 171 | stop_iter = 100, 172 | learn_rate = 0.1)}, 173 | regex = NA) 174 | 175 | expect_error( 176 | mnl_pred_lbfgs_wts <- 177 | predict(mnl_fit_lbfgs_wts, mnl_te) |> 178 | bind_cols(predict(mnl_fit_lbfgs_wts, mnl_te, type = "prob")) |> 179 | bind_cols(mnl_te), 180 | regex = NA) 181 | 182 | mnl_brier_lbfgs_wts <- 183 | mnl_pred_lbfgs_wts |> 184 | yardstick::brier_class(class, .pred_one, .pred_two, .pred_three) 185 | 186 | expect_true(mnl_brier_lbfgs_wts$.estimate < (1 - 1/num_class)^2) 187 | 188 | ### matched unweighted model 189 | 190 | expect_error({ 191 | set.seed(392) 192 | mnl_fit_lbfgs_unwt <- 193 | brulee_mlp(class ~ ., 194 | mnl_tr, 195 | epochs = 30, 196 | hidden_units = 5, 197 | rate_schedule = "decay_time", 198 | stop_iter = 100, 199 | learn_rate = 0.1)}, 200 | regex = NA) 201 | 202 | expect_error( 203 | mnl_pred_lbfgs_unwt <- 204 | predict(mnl_fit_lbfgs_unwt, mnl_te) |> 205 | bind_cols(predict(mnl_fit_lbfgs_unwt, mnl_te, type = "prob")) |> 206 | bind_cols(mnl_te), 207 | regex = NA) 208 | 209 | # did weighting predict the majority class more often? 
210 | expect_true( 211 | sum(mnl_pred_lbfgs_wts$.pred_class == min_class) > 212 | sum(mnl_pred_lbfgs_unwt$.pred_class == min_class) 213 | ) 214 | 215 | }) 216 | 217 | 218 | -------------------------------------------------------------------------------- /tests/testthat/test-mlp-regression.R: -------------------------------------------------------------------------------- 1 | 2 | test_that('basic regression mlp LBFGS', { 3 | skip_if(!torch::torch_is_installed()) 4 | 5 | skip_if_not_installed("modeldata") 6 | skip_if_not_installed("yardstick") 7 | skip_if_not_installed("recipes") 8 | 9 | suppressPackageStartupMessages(library(dplyr)) 10 | suppressPackageStartupMessages(library(recipes)) 11 | 12 | # ------------------------------------------------------------------------------ 13 | 14 | set.seed(585) 15 | reg_tr <- modeldata::sim_regression(5000) 16 | reg_te <- modeldata::sim_regression(1000) 17 | 18 | reg_tr_x_df <- reg_tr[, -1] 19 | reg_tr_x_mat <- as.matrix(reg_tr_x_df) 20 | reg_tr_y <- reg_tr$outcome 21 | 22 | reg_rec <- 23 | recipe(outcome ~ ., data = reg_tr) |> 24 | step_normalize(all_predictors()) 25 | 26 | # ------------------------------------------------------------------------------ 27 | 28 | # matrix x 29 | expect_error({ 30 | set.seed(1) 31 | mlp_reg_mat_lbfgs_fit <- 32 | brulee_mlp(reg_tr_x_mat, reg_tr_y, mixture = 0, learn_rate = .1)}, 33 | regex = NA 34 | ) 35 | 36 | # data frame x (all numeric) 37 | expect_error( 38 | mlp_reg_df_lbfgs_fit <- brulee_mlp(reg_tr_x_df, reg_tr_y, validation = .2), 39 | regex = NA 40 | ) 41 | 42 | # formula (mixed) 43 | expect_error({ 44 | set.seed(8373) 45 | mlp_reg_f_lbfgs_fit <- brulee_mlp(outcome ~ ., reg_tr)}, 46 | regex = NA 47 | ) 48 | 49 | # recipe 50 | expect_error({ 51 | set.seed(8373) 52 | mlp_reg_rec_lbfgs_fit <- brulee_mlp(reg_rec, reg_tr)}, 53 | regex = NA 54 | ) 55 | 56 | expect_error( 57 | reg_pred_lbfgs <- 58 | predict(mlp_reg_rec_lbfgs_fit, reg_te) |> 59 | bind_cols(reg_te) |> 60 | 
select(-starts_with("predictor")), 61 | regex = NA) 62 | 63 | exp_str <- 64 | structure(list(.pred = numeric(0), outcome = numeric(0)), 65 | row.names = integer(0), class = c("tbl_df", "tbl", "data.frame")) 66 | 67 | expect_equal(reg_pred_lbfgs[0,], exp_str) 68 | expect_equal(nrow(reg_pred_lbfgs), nrow(reg_te)) 69 | 70 | # Did it learn anything? 71 | reg_rmse_lbfgs <- 72 | reg_pred_lbfgs |> 73 | yardstick::rmse(outcome, .pred) 74 | 75 | set.seed(382) 76 | shuffled <- 77 | reg_pred_lbfgs |> 78 | mutate(outcome = sample(outcome)) |> 79 | yardstick::rmse(outcome, .pred) 80 | 81 | expect_true(reg_rmse_lbfgs$.estimate < shuffled$.estimate ) 82 | }) 83 | 84 | 85 | test_that('bad args', { 86 | skip_if(!torch::torch_is_installed()) 87 | 88 | skip_if_not_installed("recipes") 89 | 90 | suppressPackageStartupMessages(library(dplyr)) 91 | suppressPackageStartupMessages(library(recipes)) 92 | 93 | # ------------------------------------------------------------------------------ 94 | 95 | data(ames, package = "modeldata") 96 | 97 | ames$Sale_Price <- log10(ames$Sale_Price) 98 | 99 | reg_x_df <- ames[, c("Longitude", "Latitude")] 100 | reg_x_df_mixed <- ames[, c("Longitude", "Latitude", "Alley")] 101 | reg_x_mat <- as.matrix(reg_x_df) 102 | reg_y <- ames$Sale_Price 103 | reg_smol <- ames[, c("Longitude", "Latitude", "Alley", "Sale_Price")] 104 | 105 | reg_rec <- 106 | recipe(Sale_Price ~ Longitude + Latitude + Alley, data = ames) |> 107 | step_dummy(Alley) |> 108 | step_normalize(all_predictors()) 109 | 110 | # ------------------------------------------------------------------------------ 111 | 112 | expect_snapshot( 113 | brulee_mlp(reg_x_mat, reg_y, epochs = NA), 114 | error = TRUE 115 | ) 116 | expect_snapshot( 117 | brulee_mlp(reg_x_mat, reg_y, epochs = 1:2), 118 | error = TRUE 119 | ) 120 | expect_snapshot( 121 | brulee_mlp(reg_x_mat, reg_y, epochs = 0L), 122 | error = TRUE 123 | ) 124 | expect_error( 125 | brulee_mlp(reg_x_mat, reg_y, epochs = 2), 126 | regex = NA 127 | ) 
128 | 129 | expect_snapshot( 130 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, hidden_units = NA), 131 | error = TRUE 132 | ) 133 | 134 | expect_snapshot( 135 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, hidden_units = -1L), 136 | error = TRUE 137 | ) 138 | expect_error( 139 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, hidden_units = 2), 140 | regex = NA 141 | ) 142 | 143 | expect_snapshot( 144 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, activation = NA), 145 | error = TRUE 146 | ) 147 | 148 | expect_snapshot( 149 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = NA), 150 | error = TRUE 151 | ) 152 | expect_snapshot( 153 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = runif(2)), 154 | error = TRUE 155 | ) 156 | expect_snapshot( 157 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, penalty = -1.1), 158 | error = TRUE 159 | ) 160 | 161 | expect_snapshot( 162 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = NA), 163 | error = TRUE 164 | ) 165 | expect_snapshot( 166 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = runif(2)), 167 | error = TRUE 168 | ) 169 | expect_snapshot( 170 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = -1.1), 171 | error = TRUE 172 | ) 173 | expect_snapshot( 174 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = 1.0), 175 | error = TRUE 176 | ) 177 | expect_error( 178 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, dropout = 0), 179 | regex = NA 180 | ) 181 | 182 | expect_snapshot( 183 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = NA), 184 | error = TRUE 185 | ) 186 | expect_snapshot( 187 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = runif(2)), 188 | error = TRUE 189 | ) 190 | expect_snapshot( 191 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = -1.1), 192 | error = TRUE 193 | ) 194 | expect_snapshot( 195 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = 1.0), 196 | error = TRUE 197 | ) 198 | expect_error( 199 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, validation = 0), 200 | regex = NA 201 | ) 202 | 203 | 204 | 
expect_snapshot( 205 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = NA), 206 | error = TRUE 207 | ) 208 | expect_snapshot( 209 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = runif(2)), 210 | error = TRUE 211 | ) 212 | expect_snapshot( 213 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, learn_rate = -1.1), 214 | error = TRUE 215 | ) 216 | 217 | expect_snapshot( 218 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, verbose = 2), 219 | error = TRUE 220 | ) 221 | expect_snapshot( 222 | brulee_mlp(reg_x_mat, reg_y, epochs = 2, verbose = rep(TRUE, 10)), 223 | error = TRUE 224 | ) 225 | # ------------------------------------------------------------------------------ 226 | 227 | fit_mat <- brulee_mlp(reg_x_mat, reg_y, epochs = 10L) 228 | 229 | bad_models <- fit_mat 230 | bad_models$model_obj <- "potato!" 231 | expect_snapshot( 232 | brulee:::new_brulee_mlp( 233 | model_obj = bad_models$model_obj, 234 | estimates = bad_models$estimates, 235 | best_epoch = bad_models$best_epoch, 236 | loss = bad_models$loss, 237 | dims = bad_models$dims, 238 | y_stats = bad_models$y_stats, 239 | parameters = bad_models$parameters, 240 | blueprint = bad_models$blueprint 241 | ), 242 | error = TRUE 243 | ) 244 | 245 | bad_est <- fit_mat 246 | bad_est$estimates <- "potato!" 247 | expect_snapshot( 248 | brulee:::new_brulee_mlp( 249 | model_obj = bad_est$model_obj, 250 | estimates = bad_est$estimates, 251 | best_epoch = bad_est$best_epoch, 252 | loss = bad_est$loss, 253 | dims = bad_est$dims, 254 | y_stats = bad_est$y_stats, 255 | parameters = bad_est$parameters, 256 | blueprint = bad_est$blueprint 257 | ), 258 | error = TRUE 259 | ) 260 | 261 | bad_loss <- fit_mat 262 | bad_loss$loss <- "potato!" 
263 | expect_snapshot( 264 | brulee:::new_brulee_mlp( 265 | model_obj = bad_loss$model_obj, 266 | estimates = bad_loss$estimates, 267 | best_epoch = bad_loss$best_epoch, 268 | loss = bad_loss$loss, 269 | dims = bad_loss$dims, 270 | y_stats = bad_loss$y_stats, 271 | parameters = bad_loss$parameters, 272 | blueprint = bad_loss$blueprint 273 | ), 274 | error = TRUE 275 | ) 276 | 277 | bad_dims <- fit_mat 278 | bad_dims$dims <- "mountainous" 279 | expect_snapshot( 280 | brulee:::new_brulee_mlp( 281 | model_obj = bad_dims$model_obj, 282 | estimates = bad_dims$estimates, 283 | best_epoch = bad_dims$best_epoch, 284 | loss = bad_dims$loss, 285 | dims = bad_dims$dims, 286 | y_stats = bad_dims$y_stats, 287 | parameters = bad_dims$parameters, 288 | blueprint = bad_dims$blueprint 289 | ), 290 | error = TRUE 291 | ) 292 | 293 | 294 | bad_parameters <- fit_mat 295 | bad_parameters$dims <- "mitten" 296 | expect_snapshot( 297 | brulee:::new_brulee_mlp( 298 | model_obj = bad_parameters$model_obj, 299 | estimates = bad_parameters$estimates, 300 | best_epoch = bad_parameters$best_epoch, 301 | loss = bad_parameters$loss, 302 | dims = bad_parameters$dims, 303 | y_stats = bad_parameters$y_stats, 304 | parameters = bad_parameters$parameters, 305 | blueprint = bad_parameters$blueprint 306 | ), 307 | error = TRUE 308 | ) 309 | 310 | 311 | bad_blueprint <- fit_mat 312 | bad_blueprint$blueprint <- "adorable" 313 | expect_snapshot( 314 | brulee:::new_brulee_mlp( 315 | model_obj = bad_blueprint$model_obj, 316 | estimates = bad_blueprint$estimates, 317 | best_epoch = bad_blueprint$best_epoch, 318 | loss = bad_blueprint$loss, 319 | dims = bad_blueprint$dims, 320 | y_stats = bad_blueprint$y_stats, 321 | parameters = bad_blueprint$parameters, 322 | blueprint = bad_blueprint$blueprint 323 | ), 324 | error = TRUE 325 | ) 326 | }) 327 | 328 | test_that("mlp learns something", { 329 | skip_if(!torch::torch_is_installed()) 330 | 331 | # 
------------------------------------------------------------------------------ 332 | 333 | set.seed(1) 334 | x <- data.frame(x = rnorm(1000)) 335 | y <- 2 * x$x 336 | 337 | set.seed(2) 338 | model <- brulee_mlp(x, y, 339 | batch_size = 25, 340 | epochs = 50, 341 | optimizer = "SGD", 342 | activation = "relu", 343 | hidden_units = 5L, 344 | learn_rate = 0.1, 345 | dropout = 0) 346 | 347 | expect_true(tail(model$loss, 1) < 0.03) 348 | 349 | }) 350 | test_that("variable hidden_units length", { 351 | skip_if(!torch::torch_is_installed()) 352 | 353 | x <- data.frame(x = rnorm(1000)) 354 | y <- 2 * x$x 355 | 356 | expect_error( 357 | model <- brulee_mlp(x, y, hidden_units = c(2, 3), epochs = 1), 358 | regexp = NA 359 | ) 360 | 361 | expect_equal(length(unlist(coef(model))), (1*2 + 2) + (2*3 + 3) + (3*1 + 1)) 362 | 363 | 364 | expect_snapshot( 365 | model <- brulee_mlp(x, y, hidden_units = c(2, 3, 4), epochs = 1, 366 | activation = c("relu", "tanh")), 367 | error = TRUE 368 | ) 369 | 370 | expect_snapshot( 371 | model <- brulee_mlp(x, y, hidden_units = c(1), epochs = 1, 372 | activation = c("relu", "tanh")), 373 | error = TRUE 374 | ) 375 | 376 | }) 377 | 378 | 379 | test_that('two-layer networks', { 380 | skip_if(!torch::torch_is_installed()) 381 | 382 | skip_if_not_installed("modeldata") 383 | skip_if_not_installed("yardstick") 384 | skip_if_not_installed("recipes") 385 | 386 | suppressPackageStartupMessages(library(dplyr)) 387 | suppressPackageStartupMessages(library(recipes)) 388 | 389 | # ------------------------------------------------------------------------------ 390 | 391 | set.seed(585) 392 | reg_tr <- modeldata::sim_regression(5000) 393 | reg_te <- modeldata::sim_regression(1000) 394 | 395 | reg_tr_x_df <- reg_tr[, -1] 396 | reg_tr_x_mat <- as.matrix(reg_tr_x_df) 397 | reg_tr_y <- reg_tr$outcome 398 | 399 | reg_rec <- 400 | recipe(outcome ~ ., data = reg_tr) |> 401 | step_normalize(all_predictors()) 402 | 403 | # 
------------------------------------------------------------------------------ 404 | 405 | # matrix x 406 | expect_error({ 407 | set.seed(1) 408 | mlp_reg_mat_two_fit <- 409 | brulee_mlp_two_layer( 410 | reg_tr_x_mat, 411 | reg_tr_y, 412 | mixture = 0, 413 | learn_rate = .1, 414 | hidden_units = 5, 415 | hidden_units_2 = 10, 416 | activation = "relu", 417 | activation_2 = "elu" 418 | ) 419 | }, 420 | regex = NA) 421 | 422 | expect_error({ 423 | set.seed(1) 424 | mlp_reg_mat_two_check_fit <- 425 | brulee_mlp( 426 | reg_tr_x_mat, 427 | reg_tr_y, 428 | mixture = 0, 429 | learn_rate = .1, 430 | hidden_units = c(5, 10), 431 | activation = c("relu", "elu") 432 | ) 433 | }, 434 | regex = NA) 435 | 436 | expect_equal(mlp_reg_mat_two_fit$loss, mlp_reg_mat_two_check_fit$loss) 437 | 438 | # data frame x (all numeric) 439 | expect_error( 440 | mlp_reg_df_two_fit <- 441 | brulee_mlp_two_layer( 442 | reg_tr_x_df, 443 | reg_tr_y, 444 | validation = .2, 445 | hidden_units = 5, 446 | hidden_units_2 = 10, 447 | activation = "celu", 448 | activation_2 = "gelu" 449 | ), 450 | regex = NA 451 | ) 452 | 453 | # formula (mixed) 454 | expect_error({ 455 | set.seed(8373) 456 | mlp_reg_f_two_fit <- brulee_mlp_two_layer( 457 | outcome ~ ., 458 | reg_tr, 459 | hidden_units = 5, 460 | hidden_units_2 = 10, 461 | activation = "hardshrink", 462 | activation_2 = "hardsigmoid" 463 | ) 464 | }, 465 | regex = NA) 466 | 467 | # recipe 468 | expect_error({ 469 | set.seed(8373) 470 | mlp_reg_rec_two_fit <- brulee_mlp_two_layer( 471 | reg_rec, 472 | reg_tr, 473 | hidden_units = 5, 474 | hidden_units_2 = 10, 475 | activation = "hardtanh", 476 | activation_2 = "sigmoid" 477 | ) 478 | }, 479 | regex = NA) 480 | 481 | }) 482 | -------------------------------------------------------------------------------- /tests/testthat/test-multinomial_reg-fit.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("basic multinomial regression LBFGS", { 3 | 
skip_if_not(torch::torch_is_installed()) 4 | skip_if_not_installed("modeldata") 5 | skip_if_not_installed("yardstick") 6 | 7 | suppressPackageStartupMessages(library(dplyr)) 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | set.seed(585) 12 | mnl_tr <- 13 | modeldata::sim_multinomial( 14 | 1000, 15 | ~ -0.5 + 0.6 * A, 16 | ~ .1 * B, 17 | ~ -0.6 * A + 0.50 * B) 18 | mnl_te <- 19 | modeldata::sim_multinomial( 20 | 200, 21 | ~ -0.5 + 0.6 * A, 22 | ~ .1 * B, 23 | ~ -0.6 * A + 0.50 * B) 24 | num_class <- length(levels(mnl_tr$class)) 25 | 26 | # ------------------------------------------------------------------------------ 27 | 28 | expect_error({ 29 | set.seed(392) 30 | mnl_fit_lbfgs <- 31 | brulee_multinomial_reg(class ~ ., 32 | mnl_tr, 33 | epochs = 200, 34 | rate_schedule = "cyclic", 35 | learn_rate = 0.1)}, 36 | regex = NA) 37 | 38 | expect_error( 39 | mnl_pred_lbfgs <- 40 | predict(mnl_fit_lbfgs, mnl_te) |> 41 | bind_cols(predict(mnl_fit_lbfgs, mnl_te, type = "prob")) |> 42 | bind_cols(mnl_te), 43 | regex = NA) 44 | 45 | fact_str <- structure(integer(0), levels = c("one", "two", "three"), class = "factor") 46 | exp_str <- 47 | structure( 48 | list(.pred_class = 49 | fact_str, 50 | .pred_one = numeric(0), 51 | .pred_two = numeric(0), 52 | .pred_three = numeric(0), 53 | A = numeric(0), 54 | B = numeric(0), 55 | class = fact_str), 56 | row.names = integer(0), 57 | class = c("tbl_df", "tbl", "data.frame")) 58 | 59 | expect_equal(mnl_pred_lbfgs[0,], exp_str) 60 | expect_equal(nrow(mnl_pred_lbfgs), nrow(mnl_te)) 61 | 62 | # Did it learn anything? 
63 | mnl_brier_lbfgs <- 64 | mnl_pred_lbfgs |> 65 | yardstick::brier_class(class, .pred_one, .pred_two, .pred_three) 66 | 67 | expect_true(mnl_brier_lbfgs$.estimate < (1 - 1/num_class)^2) 68 | }) 69 | 70 | test_that("basic multinomial regression SGD", { 71 | skip_if_not(torch::torch_is_installed()) 72 | skip_if_not_installed("modeldata") 73 | skip_if_not_installed("yardstick") 74 | 75 | suppressPackageStartupMessages(library(dplyr)) 76 | 77 | # ------------------------------------------------------------------------------ 78 | 79 | set.seed(585) 80 | mnl_tr <- 81 | modeldata::sim_multinomial( 82 | 1000, 83 | ~ -0.5 + 0.6 * A, 84 | ~ .1 * B, 85 | ~ -0.6 * A + 0.50 * B) 86 | mnl_te <- 87 | modeldata::sim_multinomial( 88 | 200, 89 | ~ -0.5 + 0.6 * A, 90 | ~ .1 * B, 91 | ~ -0.6 * A + 0.50 * B) 92 | num_class <- length(levels(mnl_tr$class)) 93 | 94 | # ------------------------------------------------------------------------------ 95 | 96 | expect_error({ 97 | set.seed(392) 98 | mnl_fit_sgd <- 99 | brulee_multinomial_reg(class ~ ., 100 | mnl_tr, 101 | epochs = 200, 102 | penalty = 0, 103 | dropout = .1, 104 | optimize = "SGD", 105 | batch_size = 64, 106 | momentum = 0.5, 107 | learn_rate = 0.1)}, 108 | regex = NA) 109 | 110 | expect_error( 111 | mnl_pred_sgd <- 112 | predict(mnl_fit_sgd, mnl_te) |> 113 | bind_cols(predict(mnl_fit_sgd, mnl_te, type = "prob")) |> 114 | bind_cols(mnl_te), 115 | regex = NA) 116 | 117 | # Did it learn anything? 
118 | mnl_brier_sgd <- 119 | mnl_pred_sgd |> 120 | yardstick::brier_class(class, .pred_one, .pred_two, .pred_three) 121 | 122 | expect_true(mnl_brier_sgd$.estimate < (1 - 1/num_class)^2) 123 | }) 124 | 125 | 126 | # ------------------------------------------------------------------------------ 127 | 128 | test_that("multinomial regression class weights", { 129 | skip_if_not(torch::torch_is_installed()) 130 | skip_if_not_installed("modeldata") 131 | skip_if_not_installed("yardstick") 132 | 133 | suppressPackageStartupMessages(library(dplyr)) 134 | 135 | # ------------------------------------------------------------------------------ 136 | 137 | set.seed(585) 138 | mnl_tr <- 139 | modeldata::sim_multinomial( 140 | 1000, 141 | ~ -0.5 + 0.6 * A, 142 | ~ .1 * B, 143 | ~ -0.6 * A + 0.50 * B) 144 | mnl_te <- 145 | modeldata::sim_multinomial( 146 | 200, 147 | ~ -0.5 + 0.6 * A, 148 | ~ .1 * B, 149 | ~ -0.6 * A + 0.50 * B) 150 | 151 | num_class <- length(levels(mnl_tr$class)) 152 | cls_xtab <- table(mnl_tr$class) 153 | min_class <- names(sort(cls_xtab))[1] 154 | cls_wts <- rep(1, num_class) 155 | names(cls_wts) <- levels(mnl_tr$class) 156 | cls_wts[names(cls_wts) == min_class] <- 10 157 | 158 | # ------------------------------------------------------------------------------ 159 | 160 | expect_error({ 161 | set.seed(392) 162 | mnl_fit_lbfgs_wts <- 163 | brulee_multinomial_reg(class ~ ., 164 | mnl_tr, 165 | epochs = 30, 166 | mixture = 0.5, 167 | rate_schedule = "decay_time", 168 | class_weights = cls_wts, 169 | learn_rate = 0.1)}, 170 | regex = NA) 171 | 172 | expect_error( 173 | mnl_pred_lbfgs_wts <- 174 | predict(mnl_fit_lbfgs_wts, mnl_te) |> 175 | bind_cols(predict(mnl_fit_lbfgs_wts, mnl_te, type = "prob")) |> 176 | bind_cols(mnl_te), 177 | regex = NA) 178 | 179 | ### matched unweighted model 180 | 181 | expect_error({ 182 | set.seed(392) 183 | mnl_fit_lbfgs_unwt <- 184 | brulee_multinomial_reg(class ~ ., 185 | mnl_tr, 186 | epochs = 30, 187 | mixture = 0.5, 188 | 
rate_schedule = "decay_time", 189 | learn_rate = 0.1)}, 190 | regex = NA) 191 | 192 | expect_error( 193 | mnl_pred_lbfgs_unwt <- 194 | predict(mnl_fit_lbfgs_unwt, mnl_te) |> 195 | bind_cols(predict(mnl_fit_lbfgs_unwt, mnl_te, type = "prob")) |> 196 | bind_cols(mnl_te), 197 | regex = NA) 198 | 199 | # did weighting predict the majority class more often? 200 | expect_true( 201 | sum(mnl_pred_lbfgs_wts$.pred_class == min_class) > 202 | sum(mnl_pred_lbfgs_unwt$.pred_class == min_class) 203 | ) 204 | 205 | }) 206 | 207 | 208 | -------------------------------------------------------------------------------- /tests/testthat/test-schedulers.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | test_that("scheduling functions", { 4 | skip_if_not_installed("purrr") 5 | library(purrr) 6 | 7 | x <- 0:100 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | expect_equal( 12 | map_dbl(x, schedule_decay_expo), 13 | 0.1 * exp(-x) 14 | ) 15 | 16 | expect_equal( 17 | map_dbl(x, schedule_decay_expo, initial = 1/3, decay = 7/8), 18 | 1 / 3 * exp(-7 / 8 * x) 19 | ) 20 | 21 | expect_snapshot_error(schedule_decay_expo(1, initial = -1)) 22 | expect_snapshot_error(schedule_decay_expo(1, decay = -1)) 23 | 24 | # ------------------------------------------------------------------------------ 25 | 26 | expect_equal( 27 | map_dbl(x, schedule_decay_time), 28 | 0.1 / (1 + x) 29 | ) 30 | 31 | expect_equal( 32 | map_dbl(x, schedule_decay_time, initial = 1/3, decay = 7/8), 33 | 1 / 3 / (1 + 7 / 8 * x) 34 | ) 35 | 36 | expect_snapshot_error(schedule_decay_time(1, initial = -1)) 37 | expect_snapshot_error(schedule_decay_time(1, decay = -1)) 38 | 39 | # ------------------------------------------------------------------------------ 40 | 41 | expect_equal( 42 | map_dbl(x, schedule_step), 43 | 0.1 * (1 / 2) ^ floor(x / 5) 44 | ) 45 | 46 | expect_equal( 47 | map_dbl(x, schedule_step, initial = 1/3, reduction = 7/8, 
steps = 3), 48 | 1 / 3 * (7 / 8) ^ floor(x / 3) 49 | ) 50 | 51 | expect_snapshot_error(schedule_step(1, initial = -1)) 52 | expect_snapshot_error(schedule_step(1, reduction = -1)) 53 | expect_snapshot_error(schedule_step(1, steps = -1)) 54 | 55 | # ------------------------------------------------------------------------------ 56 | 57 | expect_true( all(map_dbl(x[x %% 10 == 0], schedule_cyclic) == 0.001) ) 58 | 59 | inc <- 0.0198 60 | expect_equal( 61 | abs(diff(map_dbl(x, schedule_cyclic))), 62 | rep(inc, 100), 63 | tolerance = 0.001 64 | ) 65 | 66 | expect_equal( 67 | sign(diff(map_dbl(x, schedule_cyclic))), 68 | rep(rep(c(1, -1), each = 5), times = 10), 69 | tolerance = 0.001 70 | ) 71 | 72 | expect_true( all(map_dbl(x[x %% 20 == 0], schedule_cyclic, step_size = 10) == 0.001) ) 73 | 74 | 75 | expect_snapshot_error(schedule_cyclic(1, step_size = -1)) 76 | expect_snapshot_error(schedule_cyclic(1, largest = -1)) 77 | 78 | # ------------------------------------------------------------------------------ 79 | 80 | expect_equal(set_learn_rate(.x, 1, type = "none"), 1) 81 | expect_equal(set_learn_rate(.x, 0.01, type = "none", potato = 1), .01) 82 | 83 | expect_equal( 84 | map_dbl(x, schedule_decay_time, initial = 1/3, decay = 7/8), 85 | map_dbl(x, ~ set_learn_rate(.x, 0.1, "decay_time", initial = 1/3, decay = 7/8)) 86 | ) 87 | 88 | expect_equal( 89 | map_dbl(x, schedule_decay_expo, initial = 1/3, decay = 7/8), 90 | map_dbl(x, ~ set_learn_rate(.x, 0.1, "decay_expo", initial = 1/3, decay = 7/8)) 91 | ) 92 | 93 | expect_equal( 94 | map_dbl(x, schedule_step, initial = 1/3, reduction = 7/8, steps = 3), 95 | map_dbl(x, ~ set_learn_rate(.x, 0.1, "step", initial = 1/3, reduction = 7/8, steps = 3)) 96 | ) 97 | 98 | expect_snapshot_error(set_learn_rate(1, 1, type = "decay_time", initial = -1)) 99 | expect_snapshot_error(set_learn_rate(1, 1, type = "random")) 100 | 101 | }) 102 | --------------------------------------------------------------------------------